In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S4G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Sat Mar 13 21:58:27 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 57%   67C    P2   129W / 250W |   2589MiB / 11016MiB |     27%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 52%   61C    P2   124W / 250W |   1536MiB / 11019MiB |      4%      Defaul

In [3]:
# Checkpoint directory of the S4G run. Set the S4G_SAVE_DIR environment variable
# to point at a different run without editing the notebook.
save_dir = os.environ.get('S4G_SAVE_DIR', '/data/save/model_S4G')
# new=False presumably attaches to the existing run directory rather than
# starting a fresh one -- confirm against utils.Logger.
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_100096  save_65000  save_75000  save_85000  save_95000
save_100000  save_61887   save_70000  save_80000  save_90000  save_96091


In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()

# Checkpoint step to evaluate (presumably loaded from save_<step> in save_dir).
step = 100000
# Set to False to evaluate the freshly constructed model without loading weights.
LOAD_CHECKPOINT = True


def param_nbytes(module):
    """Return the total storage, in bytes, of all parameters of ``module``.

    Uses each parameter's own element size instead of assuming 4-byte
    float32, so the report stays correct for half-precision weights.
    """
    return sum(p.numel() * p.element_size() for p in module.parameters())


for label, submodule in (
    ("Model", model),
    ("TTS", model.tts),
    ("MelDecoder", model.tts.mel_decoder),
):
    print(f"{label} size {sizeof_fmt(param_nbytes(submodule))}")

if LOAD_CHECKPOINT:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 243.8MiB
TTS size 156.2MiB
MelDecoder size 94.9MiB
loaded : 100000
100000
done


In [6]:
# Held-out split, iterated in order. The evaluation loop below is written for
# one utterance per batch.
EVAL_BATCH_SIZE = 1

testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
    # Never drop the tail of the test set: with drop_last=True any batch size
    # that does not divide len(testset) would silently skip utterances.
    drop_last=False,
    collate_fn=collate_fn,
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fb3d02a25b0>


In [7]:
def to_cuda(batch):
    """Move the text and mel entries of a collated batch onto the GPU.

    The dict is modified in place and also returned for convenience. Only the
    four keys below are moved; any other entries are left untouched. A missing
    key raises KeyError.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()
    return batch

In [8]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Sampling configuration for the diversity estimate: for each utterance the
# decoder is sampled N_INFERENCE_SAMPLES times and the per-element std across
# those samples is averaged.
N_INFERENCE_SAMPLES = 10
SAMPLE_TEMPERATURE = 1.0
SAMPLE_CLIP = 2
SEED = 0

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []

# Seed the RNG so the stochastic sampling (and hence sample_stds) is repeatable.
torch.manual_seed(SEED)
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

        # Repeatedly sample the decoder for the same text and alignment.
        samples = []
        for _ in range(N_INFERENCE_SAMPLES):
            sample, _unused = model.inference(
                batch['text'], batch['mels'].size(2), stt_outputs["alignments"],
                temperature=SAMPLE_TEMPERATURE, clip=SAMPLE_CLIP,
            )
            samples.append(sample)
        # Stack on a new leading axis so the std is taken across samples only;
        # concatenating on dim 0 would mix different utterances once batch_size > 1.
        sample_std = torch.stack(samples, dim=0).std(dim=0).mean().item()
        sample_stds.append(sample_std)

print('done')

print('mean CE loss    :', np.mean(ce_losses))
print('mean recon loss :', np.mean(recon_losses))
print('mean KL loss    :', np.mean(kl_losses))
print('mean sample std :', np.mean(sample_stds))

0 0.05386738479137421
1 0.05480808764696121
2 0.051323115825653076
3 0.05175279080867767
4 0.05337399244308472
5 0.04903625324368477
6 0.05092424154281616
7 0.05146056041121483
8 0.052417006343603134
9 0.05441701412200928
10 0.05192522704601288
11 0.04811131954193115
12 0.04949634149670601
13 0.05234526842832565
14 0.04983225464820862
15 0.05503953993320465
16 0.05789179354906082
17 0.05157534033060074
18 0.05468757823109627
19 0.051282040774822235
20 0.05174049362540245
21 0.05738746002316475
22 0.04748603329062462
23 0.051538992673158646
24 0.052718427032232285
25 0.05030687153339386
26 0.050646401941776276
27 0.046920403838157654
28 0.05291299149394035
29 0.051630500704050064
30 0.05245724692940712
31 0.05297877639532089
32 0.050569768995046616
33 0.04811513051390648
34 0.06570583581924438
35 0.05495858192443848
36 0.052461713552474976
37 0.05372517928481102
38 0.05085716024041176
39 0.0490577258169651
40 0.05176404118537903
41 0.05186259001493454
42 0.0578429289162159
43 0.05447462

343 0.04319709166884422
344 0.04631969332695007
345 0.044675376266241074
346 0.045138660818338394
347 0.04924171417951584
done
0.12111433699945745
0.04852541181761986
0.015687642067713642
0.28284479189535666
