In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W4G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Fri Mar 19 12:55:43 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 37%   61C    P0    73W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 53%   62C    P2   187W / 250W |   3645MiB / 11019MiB |     82%      Defaul

In [5]:
# Directory holding the saved checkpoints for this run; the Logger is opened
# on it with new=False.
save_dir = '/data/save/model_W4G'
logger = Logger(new=False, save_dir=save_dir)
print('done')

done


In [6]:
!ls $save_dir

data.json    save_101436  save_102141  save_400000
save_100000  save_101938  save_102658


In [7]:
# Build the combined STT/TTS model, move it to the GPU, report parameter
# sizes and restore the checkpoint saved at `step`.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train').cuda()
step = 400000

# Sizes are reported as 4 bytes per parameter (presumably float32 weights).
size = sizeof_fmt(4 * sum(weight.numel() for weight in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(weight.numel() for weight in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(weight.numel() for weight in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

# Restore the weights for `step`; None is passed as the third argument and the
# two extra return values are discarded.
model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 124.1MiB
TTS size 36.5MiB
MelDecoder size 22.9MiB
loaded : 400000
400000
done


In [12]:
# Datasets for the two splits plus the collate function shared by both loaders.
trainset = LJDataset(tts_hparams, split='train')
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

# Training loader: shuffled, full batches only. A single worker is used here
# rather than tts_hparams.num_workers.
train_loader = torch.utils.data.DataLoader(
    trainset,
    num_workers=1,
    shuffle=True,
    sampler=None,
    batch_size=tts_hparams.batch_size,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)
print(train_loader)

# Validation loader: one utterance per batch, fixed order.
test_loader = torch.utils.data.DataLoader(
    testset,
    num_workers=1,
    shuffle=False,
    sampler=None,
    batch_size=1,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7ff0e4387fd0>
<torch.utils.data.dataloader.DataLoader object at 0x7ff0609bbe80>


In [13]:
def to_cuda(batch):
    """Move the text/mel entries of ``batch`` to the GPU.

    The entries under 'text', 'text_lengths', 'mels' and 'mel_lengths' are
    each replaced by the result of calling ``.cuda()`` on them. The dict is
    modified in place and the same dict object is returned; any other keys
    are left untouched.
    """
    for field in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[field] = batch[field].cuda()

    return batch

In [14]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Evaluate the restored checkpoint over `train_loader`, collecting per batch:
#   - the STT loss                        (stt_outputs['loss'])
#   - the TTS reconstruction loss         (tts_outputs['recon_loss'])
#   - the TTS KL loss                     (tts_outputs['kl_loss'])
#   - the spread of repeated inference draws for the same input
# The means of the four lists are printed at the end.
ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

        # Kept from the original cell: re-assert eval mode before sampling.
        model.eval()

        # Draw 10 samples for the same text, target length and alignments.
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        # Stack along a NEW leading axis so the std below is taken across the
        # 10 draws only. Concatenating along dim 0 would merge the draws with
        # the existing leading (batch) axis and mix different utterances into
        # the statistic whenever the batch holds more than one item. With a
        # single-item batch both forms give the same number.
        samples_list = torch.stack(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)


print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.02667085826396942
1 0.023400668054819107
2 0.023837624117732048
3 0.02691640332341194
4 0.02382315695285797
5 0.025670312345027924
6 0.02382320538163185
7 0.02772974595427513
8 0.02745291404426098
9 0.025503579527139664
10 0.02708583138883114
11 0.025199536234140396
12 0.026328686624765396
13 0.02591957524418831
14 0.027820836752653122
15 0.0263991542160511
16 0.023914894089102745
17 0.024729019030928612
18 0.02718309499323368
19 0.026684589684009552
20 0.02692275308072567
21 0.027878224849700928
22 0.025882594287395477
23 0.027538947761058807
24 0.025294728577136993
25 0.027478905394673347
26 0.027889825403690338
27 0.027582192793488503
28 0.023624958470463753
29 0.029210040345788002
30 0.028289226815104485
31 0.025709498673677444
32 0.0254405215382576
33 0.025958798825740814
34 0.027683133259415627
35 0.02773411199450493
36 0.026453088968992233
37 0.025053178891539574
38 0.024989325553178787
39 0.02535027079284191
40 0.027588099241256714
41 0.026543449610471725
42 0.0250306278467

337 0.02441144548356533
338 0.025161778554320335
339 0.0252109132707119
340 0.02484969049692154
341 0.025519205257296562
342 0.026870617642998695
343 0.027341043576598167
344 0.024248510599136353
345 0.025592323392629623
346 0.027269523590803146
347 0.02558806538581848
348 0.027248511090874672
349 0.029391568154096603
350 0.024167723953723907
351 0.023876497521996498
352 0.02791108749806881
353 0.024208970367908478
354 0.02634073793888092
355 0.027660205960273743
356 0.02697041817009449
357 0.02435673587024212
358 0.02516515925526619
359 0.024938052520155907
360 0.025893010199069977
361 0.023968210443854332
362 0.02672012336552143
363 0.02571903169155121
364 0.0235916581004858
365 0.024370871484279633
366 0.026865074411034584
367 0.02603326365351677
368 0.027656931430101395
369 0.026097552850842476
370 0.02479543536901474
371 0.024610823020339012
372 0.027288910001516342
373 0.024811336770653725
374 0.026877425611019135
375 0.02671832963824272
376 0.02656613290309906
377 0.026910964399

In [18]:
# Mean reconstruction loss over whatever `recon_losses` currently holds in the
# kernel.
# NOTE(review): this cell's execution count is not adjacent to the cell that
# fills the list, so the value may come from a different run - confirm.
print(np.asarray(recon_losses).mean())

0.045737755880958735
