In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Mon Mar 15 17:19:14 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 54%   61C    P0    74W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 52%   54C    P8    29W / 250W |   1846MiB / 11019MiB |      0%      Default |
|       

In [3]:
save_dir = '/data/save/model_W5G'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_120000  save_155000  save_195000  save_38005	save_75000
save_100000  save_125000  save_160000  save_196255  save_40000	save_77732
save_100003  save_130000  save_165000  save_20000   save_42922	save_80000
save_105000  save_131258  save_165513  save_200000  save_45000	save_85000
save_108400  save_135000  save_169764  save_205000  save_50000	save_85982
save_110000  save_138640  save_170000  save_207498  save_55000	save_90000
save_112252  save_140000  save_175000  save_20951   save_60000	save_95000
save_112360  save_145000  save_180000  save_25000   save_65000	save_96773
save_113679  save_15000   save_185000  save_30000   save_67743
save_115000  save_150000  save_190000  save_35000   save_70000


In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 200000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 132.9MiB
TTS size 45.2MiB
MelDecoder size 28.7MiB
loaded : 200000
200000
done


In [8]:
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7ff1d0913190>


In [9]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [10]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        model.eval()
        
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.04715722054243088
1 0.04360160604119301
2 0.04405490309000015
3 0.04687179997563362
4 0.04639536887407303
5 0.04458916187286377
6 0.0459538996219635
7 0.04425394535064697
8 0.04579484835267067
9 0.04171115159988403
10 0.044945962727069855
11 0.04520116746425629
12 0.043805088847875595
13 0.043837498873472214
14 0.044868044555187225
15 0.04634243622422218
16 0.048601824790239334
17 0.04515663534402847
18 0.04852774366736412
19 0.04520774260163307
20 0.04550336301326752
21 0.044040337204933167
22 0.04459396004676819
23 0.04510466009378433
24 0.0460255965590477
25 0.04939733445644379
26 0.04588458687067032
27 0.04361061006784439
28 0.04475771635770798
29 0.04610371217131615
30 0.04492596164345741
31 0.04231884330511093
32 0.04899398237466812
33 0.04238775745034218
34 0.04774579033255577
35 0.043092645704746246
36 0.043703820556402206
37 0.04540408402681351
38 0.04505961015820503
39 0.046187520027160645
40 0.04666530340909958
41 0.04288794845342636
42 0.04285387322306633
43 0.041312310

344 0.04733552411198616
345 0.04731189087033272
346 0.04863543435931206
347 0.04986785724759102
348 0.046485304832458496
349 0.045285362750291824
350 0.04387057200074196
351 0.04827858507633209
352 0.045822951942682266
353 0.044787198305130005
354 0.04568511247634888
355 0.043563976883888245
356 0.04339488968253136
357 0.04518371820449829
358 0.04471335932612419
359 0.04785897210240364
360 0.04757748544216156
361 0.046178046613931656
362 0.046913884580135345
363 0.04420328885316849
364 0.04619733244180679
365 0.04993663355708122
366 0.04496149346232414
367 0.047340746968984604
368 0.05094803497195244
369 0.04774472489953041
370 0.04521750286221504
371 0.0464811846613884
372 0.04678346589207649
373 0.04572725668549538
374 0.04608190059661865
375 0.04847012832760811
376 0.047918159514665604
377 0.045605968683958054
378 0.04420149326324463
379 0.04516441002488136
380 0.045572903007268906
381 0.04782181605696678
382 0.046306610107421875
383 0.04603544622659683
384 0.04853209853172302
385 0