In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Fri Mar 19 12:56:01 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 54%   62C    P2   103W / 250W |   1678MiB / 11016MiB |     30%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 56%   66C    P2   184W / 250W |   3645MiB / 11019MiB |     88%      Defaul

In [3]:
save_dir = '/data/save/model_W5G'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_140000  save_20000   save_250000  save_312251  save_45000
save_100000  save_145000  save_200000  save_255000  save_315000  save_50000
save_100003  save_15000   save_205000  save_260000  save_320000  save_55000
save_105000  save_150000  save_207498  save_265000  save_320618  save_60000
save_108400  save_155000  save_20951   save_270000  save_325000  save_65000
save_110000  save_160000  save_210000  save_275000  save_330000  save_67743
save_112252  save_165000  save_213814  save_275660  save_330312  save_70000
save_112360  save_165513  save_215000  save_280000  save_335000  save_75000
save_113679  save_169764  save_220000  save_285000  save_340000  save_77732
save_115000  save_170000  save_224267  save_290000  save_345000  save_80000
save_120000  save_175000  save_225000  save_293000  save_348351  save_85000
save_125000  save_180000  save_230000  save_295000  save_35000	 save_85982
save_130000  save_185000  save_235000  save_30000   save_38005	 save_90000
s

In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 400000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 132.9MiB
TTS size 45.2MiB
MelDecoder size 28.7MiB
loaded : 400000
400000
done


In [9]:
trainset = LJDataset(tts_hparams, split='train')
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

train_loader = torch.utils.data.DataLoader(trainset, num_workers=1, #tts_hparams.num_workers, 
                          shuffle=True, sampler=None, batch_size=tts_hparams.batch_size, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(train_loader)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f6730f1cee0>
<torch.utils.data.dataloader.DataLoader object at 0x7f672d4d6d30>


In [10]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [11]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        model.eval()
        
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.02426605112850666
1 0.022838551551103592
2 0.02337535470724106
3 0.025763504207134247
4 0.027001088485121727
5 0.024851160123944283
6 0.02446964755654335
7 0.024855243042111397
8 0.02707437239587307
9 0.024453584104776382
10 0.02621963806450367
11 0.02769777551293373
12 0.026133226230740547
13 0.027506912127137184
14 0.02485697716474533
15 0.02594061754643917
16 0.02725783735513687
17 0.027164652943611145
18 0.02730446495115757
19 0.026360634714365005
20 0.026033712550997734
21 0.026163220405578613
22 0.024749964475631714
23 0.026501616463065147
24 0.027214596047997475
25 0.026498988270759583
26 0.0263333972543478
27 0.02510468102991581
28 0.026536645367741585
29 0.026459338143467903
30 0.025635959580540657
31 0.02333526685833931
32 0.022524336352944374
33 0.02690778858959675
34 0.027661962434649467
35 0.025010909885168076
36 0.02521771937608719
37 0.02509400062263012
38 0.024197649210691452
39 0.025675790384411812
40 0.02577301487326622
41 0.026644811034202576
42 0.024513034150004

338 0.026502545922994614
339 0.027018845081329346
340 0.028190163895487785
341 0.025733202695846558
342 0.02662874013185501
343 0.02658284455537796
344 0.02667674608528614
345 0.02469494193792343
346 0.027871057391166687
347 0.025258779525756836
348 0.025096306577324867
349 0.02541038766503334
350 0.027150887995958328
351 0.024090472608804703
352 0.02454935386776924
353 0.025550242513418198
354 0.02747110091149807
355 0.0257249902933836
356 0.026379013434052467
357 0.027974244207143784
358 0.02880275249481201
359 0.023139843717217445
360 0.025396475568413734
361 0.02504163235425949
362 0.027628596872091293
363 0.028561430051922798
364 0.025943424552679062
365 0.023778000846505165
366 0.024771038442850113
367 0.026474859565496445
368 0.028862327337265015
369 0.024619750678539276
370 0.026295192539691925
371 0.026579245924949646
372 0.025796834379434586
373 0.027061140164732933
374 0.025946825742721558
375 0.024021003395318985
376 0.02628871612250805
377 0.025415925309062004
378 0.027078