In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [3]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Mon Mar 15 17:09:50 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 67%   80C    P2   116W / 250W |   1686MiB / 11016MiB |     27%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 49%   68C    P0    53W / 250W |      0MiB / 11019MiB |      0%      Defaul

In [4]:
save_dir = '/data/save/model_S5G'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [5]:
!ls $save_dir

data.json    save_123671  save_145000  save_177763  save_205000  save_75000
save_100000  save_123692  save_146358  save_180000  save_210000  save_80000
save_100189  save_125000  save_147866  save_184582  save_215000  save_85000
save_100481  save_128431  save_150000  save_185000  save_219392  save_90000
save_100728  save_130000  save_151422  save_185146  save_220000  save_95000
save_104221  save_132767  save_155000  save_190000  save_225000  save_98268
save_105000  save_133754  save_160000  save_195000  save_228147  save_98508
save_110000  save_135000  save_165000  save_196835  save_230000
save_115000  save_137981  save_170000  save_199144  save_232353
save_120000  save_138253  save_170017  save_200000  save_235000
save_123593  save_140000  save_175000  save_201253  save_239735


In [6]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 200000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 283.0MiB
TTS size 195.3MiB
MelDecoder size 119.0MiB
loaded : 200000
200000
done


In [10]:
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f4fc2dc3ee0>


In [11]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [12]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        model.eval()
        
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.055360015481710434
1 0.053552135825157166
2 0.05508439242839813
3 0.056137412786483765
4 0.05431392788887024
5 0.05758911743760109
6 0.06160475313663483
7 0.05369510129094124
8 0.056176479905843735
9 0.05095161870121956
10 0.05450725927948952
11 0.05487013980746269
12 0.052675072103738785
13 0.054262734949588776
14 0.049854032695293427
15 0.05551094189286232
16 0.06074507161974907
17 0.06093020737171173
18 0.05794273689389229
19 0.05755509436130524
20 0.0579526349902153
21 0.05392006039619446
22 0.054319337010383606
23 0.054613299667835236
24 0.061138980090618134
25 0.06648078560829163
26 0.054453495889902115
27 0.05611124634742737
28 0.057615846395492554
29 0.05390026047825813
30 0.05634808912873268
31 0.05243983492255211
32 0.05960921570658684
33 0.052072349935770035
34 0.0536082424223423
35 0.05515662953257561
36 0.049184199422597885
37 0.059639569371938705
38 0.05464186146855354
39 0.0530659481883049
40 0.06321576237678528
41 0.05493394285440445
42 0.05035127326846123
43 0.0491

343 0.05254712328314781
344 0.056176427751779556
345 0.05591726303100586
346 0.05744358152151108
347 0.06076287478208542
348 0.05237362161278725
349 0.05329403281211853
350 0.050966519862413406
351 0.05878521874547005
352 0.053727272897958755
353 0.054442744702100754
354 0.05451289936900139
355 0.052459511905908585
356 0.05328791216015816
357 0.0536615364253521
358 0.05119894817471504
359 0.054957397282123566
360 0.05504031851887703
361 0.051629021763801575
362 0.054395899176597595
363 0.05013427138328552
364 0.056110456585884094
365 0.05895585939288139
366 0.054290443658828735
367 0.054786454886198044
368 0.05810907483100891
369 0.05533989146351814
370 0.05611250922083855
371 0.059923816472291946
372 0.056331876665353775
373 0.05623917281627655
374 0.052947234362363815
375 0.05636025592684746
376 0.06006895750761032
377 0.05161459371447563
378 0.0520128533244133
379 0.0525168851017952
380 0.05468742921948433
381 0.05898292735219002
382 0.05364711582660675
383 0.05615691840648651
384 0