In [1]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including ones from torch / DataLoader that may point at real problems.
# Consider narrowing it (e.g. `category=` / `module=` arguments) once the noisy
# source that prompted it is known.
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
# Restrict this kernel to physical GPU 1 (see the nvidia-smi listing above).
# CUDA_VISIBLE_DEVICES is only read when the CUDA context is created, i.e. at the
# first CUDA call such as `model.cuda()`; importing torch alone does not create it.
# Re-running this cell after CUDA has been initialised would otherwise be silently
# ignored, so fail loudly instead. An exception is used rather than a warning
# because warnings are filtered to 'ignore' in the first cell.
GPU_ID = "1"
if torch.cuda.is_initialized() and os.environ.get("CUDA_VISIBLE_DEVICES") != GPU_ID:
    raise RuntimeError(
        f"CUDA is already initialised; setting CUDA_VISIBLE_DEVICES={GPU_ID!r} now has no effect. "
        "Restart the kernel and run this cell before any CUDA call."
    )
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

Fri Mar 19 12:55:20 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 43%   61C    P0    74W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 48%   56C    P2   133W / 250W |   1802MiB / 11019MiB |     41%      Defaul

In [3]:
# Run directory for this model. Per the `ls` below, it holds data.json and the
# save_<step> checkpoints.
save_dir = '/data/save/model_S5G'

# new=False presumably re-opens the existing run rather than starting a fresh
# one; confirm against utils.Logger.
logger = Logger(new=False, save_dir=save_dir)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_135000  save_185000  save_239735  save_290000  save_360000
save_100000  save_137981  save_185146  save_239809  save_295000  save_362255
save_100189  save_138253  save_190000  save_240000  save_300000  save_363423
save_100481  save_140000  save_195000  save_242209  save_305000  save_365000
save_100728  save_145000  save_196835  save_245000  save_306501  save_370000
save_104221  save_146358  save_199144  save_247546  save_310000  save_375000
save_105000  save_147866  save_200000  save_248082  save_315000  save_380000
save_110000  save_150000  save_201253  save_250000  save_320000  save_385000
save_115000  save_151422  save_205000  save_255000  save_324281  save_387320
save_120000  save_155000  save_210000  save_258351  save_325000  save_400000
save_123593  save_160000  save_215000  save_258531  save_330000  save_75000
save_123671  save_165000  save_219392  save_260000  save_335000  save_80000
save_123692  save_170000  save_220000  save_265000  save_340000  s

In [5]:
def param_bytes(module):
    """Return the total storage, in bytes, of ``module``'s parameters.

    Uses each tensor's own ``element_size()`` instead of assuming 4-byte float32,
    so the figure stays correct for half-precision or mixed-dtype models (and is
    unchanged for an all-float32 model).
    """
    return sum(p.numel() * p.element_size() for p in module.parameters())


stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()

# Checkpoint step to restore. The `ls` above lists a matching save_400000 entry;
# how Logger.load maps step -> path is defined in utils.Logger.
step = 400000
LOAD_CHECKPOINT = True  # set False to inspect the freshly initialised weights

# Report parameter memory for the whole model and two of its sub-modules.
for label, module in [("Model", model),
                      ("TTS", model.tts),
                      ("MelDecoder", model.tts.mel_decoder)]:
    size = sizeof_fmt(param_bytes(module))
    print(f"{label} size {size}")

if LOAD_CHECKPOINT:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 283.0MiB
TTS size 195.3MiB
MelDecoder size 119.0MiB
loaded : 400000
400000
done


In [9]:
trainset = LJDataset(tts_hparams, split='train')
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

# Options common to both loaders. num_workers is fixed at 1 here instead of
# tts_hparams.num_workers.
_loader_kwargs = dict(
    num_workers=1,
    sampler=None,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)

# Training split: shuffled, full training batch size.
train_loader = torch.utils.data.DataLoader(
    trainset, shuffle=True, batch_size=tts_hparams.batch_size, **_loader_kwargs
)
print(train_loader)

# Validation split: fixed order, one item per batch.
test_loader = torch.utils.data.DataLoader(
    testset, shuffle=False, batch_size=1, **_loader_kwargs
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f19548e8fa0>
<torch.utils.data.dataloader.DataLoader object at 0x7f194951cd60>


In [10]:
def to_cuda(batch):
    """Move the text/mel entries of a collated batch onto the CUDA device.

    Only 'text', 'text_lengths', 'mels' and 'mel_lengths' are moved; any other
    keys are left untouched. The dict is updated in place and also returned.
    A missing key raises KeyError.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()
    return batch

In [11]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Evaluate the loaded checkpoint over the training loader. For each batch, record:
#   - the STT loss ('loss') and the TTS 'recon_loss' / 'kl_loss';
#   - how much model.inference output varies across repeated draws for the
#     same input.
# train_loader uses drop_last=True, so the final partial batch is not included
# in the averages.
N_DRAWS = 10        # inference draws per batch for the spread estimate
TEMPERATURE = 1.0   # forwarded to model.inference
CLIP = 2            # forwarded to model.inference
SEED = 0            # seeds torch's global RNG


def sample_std(model, batch, alignments, n_draws=N_DRAWS):
    """Mean element-wise std of ``model.inference`` outputs across ``n_draws`` draws.

    The draws are stacked on a new leading axis, so the std is taken across
    draws of the *same* element. Concatenating on dim 0 instead would fold the
    different utterances of a batch into the spread whenever batch_size > 1.
    """
    draws = []
    for _draw in range(n_draws):
        samples, _ = model.inference(batch['text'], batch['mels'].size(2), alignments,
                                     temperature=TEMPERATURE, clip=CLIP)
        draws.append(samples)
    return torch.stack(draws, dim=0).std(dim=0).mean().item()


# Makes the shuffled batch order repeatable across re-runs. It also makes the
# sampled outputs repeatable, assuming model.inference draws from torch's RNG.
torch.manual_seed(SEED)

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
# Eval mode is set once; no code visible in this cell switches it back, so the
# per-batch repeat in the loop is unnecessary.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        sample_stds.append(sample_std(model, batch, stt_outputs["alignments"]))

print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.022026753053069115
1 0.023143725469708443
2 0.022638076916337013
3 0.019688107073307037
4 0.022250201553106308
5 0.021832916885614395
6 0.021876441314816475
7 0.023222211748361588
8 0.022262725979089737
9 0.023279467597603798
10 0.025753909721970558
11 0.02520756423473358
12 0.021154791116714478
13 0.021061643958091736
14 0.022246038541197777
15 0.021627606824040413
16 0.023495744913816452
17 0.02398715540766716
18 0.02214910462498665
19 0.0219180416315794
20 0.021468905732035637
21 0.021053005009889603
22 0.019942184910178185
23 0.02233804762363434
24 0.02131296694278717
25 0.0219963900744915
26 0.023628881201148033
27 0.021848473697900772
28 0.023121705278754234
29 0.021480225026607513
30 0.0216981153935194
31 0.021800700575113297
32 0.023176835849881172
33 0.020909324288368225
34 0.022546594962477684
35 0.022049767896533012
36 0.02314898744225502
37 0.023999875411391258
38 0.024075450375676155
39 0.022620564326643944
40 0.02205466665327549
41 0.023304320871829987
42 0.0233996026

338 0.024554356932640076
339 0.021125182509422302
340 0.02081279829144478
341 0.021791130304336548
342 0.02172834612429142
343 0.021340904757380486
344 0.024392636492848396
345 0.022142978385090828
346 0.02368086948990822
347 0.02344701811671257
348 0.02167026326060295
349 0.022661268711090088
350 0.02358611300587654
351 0.021893152967095375
352 0.021914055570960045
353 0.022352205589413643
354 0.021395012736320496
355 0.02091001532971859
356 0.019794469699263573
357 0.020154280588030815
358 0.022168368101119995
359 0.02251810021698475
360 0.02427622117102146
361 0.021778812631964684
362 0.023538541048765182
363 0.019926361739635468
364 0.020464381203055382
365 0.0211489275097847
366 0.021440884098410606
367 0.022309815511107445
368 0.022372350096702576
369 0.021567273885011673
370 0.022497011348605156
371 0.023387735709547997
372 0.023825019598007202
373 0.024781253188848495
374 0.023304995149374008
375 0.021476926282048225
376 0.02065306529402733
377 0.022474532946944237
378 0.022265