In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
# Show the GPUs on this machine and their current load (IPython shell escape).
!nvidia-smi
# Restrict this process to GPU index 0.
# NOTE(review): this is set after `import torch`; the variable only takes effect
# if CUDA has not been initialised in this process yet — confirm that nothing
# imported in the first cell touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Sat Mar 13 21:58:24 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 56%   65C    P2    96W / 250W |   1138MiB / 11016MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 52%   60C    P2    73W / 250W |    634MiB / 11019MiB |      8%      Defaul

In [3]:
# Directory that holds the run's checkpoints (hard-coded absolute path; adjust per machine).
save_dir = '/data/save/model_S5G'
# Project logger bound to that directory; it is later used to load a checkpoint.
# NOTE(review): presumably `new=False` reuses the existing directory instead of
# starting a fresh run — confirm against utils.Logger.
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
# List the contents of the save directory; IPython expands `$save_dir` from the
# Python variable of the same name before running the shell command.
!ls $save_dir

data.json    save_104221  save_123593  save_130000  save_75000	save_98268
save_100000  save_105000  save_123671  save_132767  save_80000	save_98508
save_100189  save_110000  save_123692  save_133754  save_85000
save_100481  save_115000  save_125000  save_135000  save_90000
save_100728  save_120000  save_128431  save_137981  save_95000


In [5]:
# Checkpoint step handed to `logger.load` further down.
step = 100000

# Build the combined model from its two hyper-parameter sets and move it to the GPU.
# NOTE(review): the model is constructed with mode='train' although this notebook
# only evaluates it — confirm that this is intended.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train').cuda()


def _param_bytes(module):
    """Return 4 times the total number of parameter elements in ``module``.

    The factor 4 presumably stands for 4 bytes per element (32-bit parameters);
    the actual parameter dtype is not checked.
    """
    return 4 * sum(p.numel() for p in module.parameters())


# Report the footprint of the whole model and of two of its sub-modules.
size = sizeof_fmt(_param_bytes(model))
print(f"Model size {size}")

size = sizeof_fmt(_param_bytes(model.tts))
print(f"TTS size {size}")

size = sizeof_fmt(_param_bytes(model.tts.mel_decoder))
print(f"MelDecoder size {size}")

# Manual switch: flip to False to skip loading the checkpoint.
# `logger.load` returns three values; only the first (the model) is kept.
if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 283.0MiB
TTS size 195.3MiB
MelDecoder size 119.0MiB
loaded : 100000
100000
done


In [6]:
# Test split of the dataset plus the collate function that assembles batches,
# both configured from the TTS hyper-parameters.
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

# One item per batch, visited in dataset order.
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    num_workers=1,
    collate_fn=collate_fn,
    pin_memory=False,
    drop_last=True,  # no effect while batch_size is 1: no batch can be incomplete
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7feca8343280>


In [7]:
def to_cuda(batch):
    """Move the text and mel entries of a batch onto the GPU.

    The entries under 'text', 'text_lengths', 'mels' and 'mel_lengths' are each
    replaced by the result of calling ``.cuda()`` on them. The mapping is
    modified in place and also returned; any other keys are left untouched.

    Raises:
        KeyError: if one of the four keys is missing from ``batch``.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()

    return batch

In [8]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Evaluate the loaded model on every batch of `test_loader`.
# For each batch this records three losses from a forward pass and a
# "sample spread" statistic computed from 10 calls to `model.inference`.
# NOTE(review): no random seed is set in this cell, so the sampling-based
# statistic may differ between runs if `model.inference` is stochastic.

# Per-batch accumulators; their means are printed at the end of the cell.
ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
# Switch to evaluation mode (assuming Model is an nn.Module) and disable
# gradient tracking for the whole loop.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        # Forward pass: returns one output dict for the STT branch and one for the TTS branch.
        stt_outputs, tts_outputs = model(batch)
        # `ce_loss` is the STT branch's 'loss' entry (presumably a cross-entropy — confirm in model code).
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        # Progress print: batch index and its reconstruction loss.
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        # NOTE(review): repeats the call made before the loop; looks redundant
        # unless `model(batch)` changes the module's mode — confirm before removing.
        model.eval()
        
        # Draw 10 outputs for the same text. The target length is the size of
        # dimension 2 of the batch's mels, and the alignments come from the STT
        # outputs of the forward pass above. The second return value is discarded
        # into `_` (this also rebinds the loop variable, which is harmless here).
        # NOTE(review): 10, temperature=1.0 and clip=2 are hard-coded.
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        # Concatenate the draws along dim 0, take the standard deviation across
        # dim 0, then average it over all remaining elements into one scalar.
        # With batch_size=1 dim 0 indexes the 10 draws — assumes `samples` is
        # batch-first; TODO confirm.
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

# Test-set means, printed unlabelled in this order:
# STT loss, reconstruction loss, KL loss, sample standard deviation.
print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.054238639771938324
1 0.05548226088285446
2 0.05481388792395592
3 0.049313537776470184
4 0.05553204193711281
5 0.050328329205513
6 0.051741573959589005
7 0.05133785307407379
8 0.052792470902204514
9 0.05060416832566261
10 0.05493341013789177
11 0.04925775155425072
12 0.048077184706926346
13 0.053336337208747864
14 0.05012590438127518
15 0.055695295333862305
16 0.058303095400333405
17 0.05095816031098366
18 0.050051622092723846
19 0.05093294382095337
20 0.05248141661286354
21 0.06291216611862183
22 0.04887690022587776
23 0.05346674099564552
24 0.05205389857292175
25 0.05104187875986099
26 0.051297347992658615
27 0.046963032335042953
28 0.057194869965314865
29 0.050954997539520264
30 0.05279622972011566
31 0.05351537466049194
32 0.050626110285520554
33 0.04942727088928223
34 0.05219503864645958
35 0.05340008810162544
36 0.05295513570308685
37 0.052949629724025726
38 0.05130462348461151
39 0.049308087676763535
40 0.05057826265692711
41 0.0526852123439312
42 0.05403364077210426
43 0.055

343 0.04319530725479126
344 0.0455731600522995
345 0.043707702308893204
346 0.044133905321359634
347 0.04946621134877205
done
0.075774136133762
0.048824842665986766
0.01528154366820965
0.274776464991871
