In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
# Show the GPUs on this host (IPython shell escape), then expose only device 0
# to this process via CUDA_VISIBLE_DEVICES.
# NOTE(review): torch was already imported in the first cell; this setting is
# presumably only honoured if CUDA has not been initialised yet — confirm that
# nothing touches CUDA before this cell runs.
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Sat Mar 13 21:58:18 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 51%   65C    P0    97W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 30%   59C    P0    49W / 250W |      0MiB / 11019MiB |      0%      Default |
|       

In [3]:
# Directory holding the checkpoints of the W5G run; the logger created here is
# what later cells use to restore weights from it.
save_dir = '/data/save/model_W5G'

# NOTE(review): new=False presumably attaches to the existing directory instead
# of starting a fresh one — confirm against utils.Logger.
logger = Logger(new=False, save_dir=save_dir)

print('done')

done


In [4]:
# List the contents of the checkpoint directory (IPython expands $save_dir from
# the Python variable) so a valid `step` can be picked for loading.
!ls $save_dir

data.json    save_20951  save_40000  save_60000  save_77732  save_95000
save_100000  save_25000  save_42922  save_65000  save_80000  save_96773
save_100003  save_30000  save_45000  save_67743  save_85000
save_15000   save_35000  save_50000  save_70000  save_85982
save_20000   save_38005  save_55000  save_75000  save_90000


In [5]:
# Build the joint STT/TTS model on the GPU, report how much memory its
# parameters occupy, and restore the weights saved at `step`.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 100000  # checkpoint step handed to logger.load below
LOAD_CHECKPOINT = True  # set to False to keep the freshly initialised weights


def _parameter_bytes(module):
    """Return the number of bytes taken by all parameters of ``module``.

    Uses each parameter's own element size instead of assuming 4 bytes, so the
    figure stays correct if some parameters are not float32.
    """
    return sum(p.numel() * p.element_size() for p in module.parameters())


# Same three reports as before, driven from one table instead of three
# copy-pasted expressions.
for label, module in (
    ("Model", model),
    ("TTS", model.tts),
    ("MelDecoder", model.tts.mel_decoder),
):
    size = sizeof_fmt(_parameter_bytes(module))
    print(f"{label} size {size}")

if LOAD_CHECKPOINT:
    # logger.load returns a 3-tuple; only the first element (the model) is kept.
    # NOTE(review): the trailing None is presumably the optimizer slot — confirm
    # against utils.Logger.load.
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 132.9MiB
TTS size 45.2MiB
MelDecoder size 28.7MiB
loaded : 100000
100000
done


In [6]:
# Test-split loader: one utterance per batch, served in dataset order, so each
# batch consumed downstream corresponds to exactly one test item.
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

loader_options = dict(
    num_workers=1,
    shuffle=False,
    sampler=None,
    batch_size=1,
    pin_memory=False,
    drop_last=True,  # never triggers at batch_size=1: no batch can be incomplete
    collate_fn=collate_fn,
)
test_loader = torch.utils.data.DataLoader(testset, **loader_options)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7efd703a4400>


In [7]:
def to_cuda(batch):
    """Move the text and mel entries of a collated batch to the GPU.

    The values stored under ``text``, ``text_lengths``, ``mels`` and
    ``mel_lengths`` are each replaced by the result of calling ``.cuda()`` on
    them. The dict is modified in place and also returned; any other keys are
    left untouched.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()
    return batch

In [8]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Evaluate the restored model on the test split. For every utterance this
# records the STT cross-entropy loss and the TTS reconstruction / KL losses,
# then synthesises the utterance several times and records how much those
# syntheses differ from one another (element-wise std across the draws).
EVAL_SEED = 0             # fixed seed so the sampled spread is reproducible on re-run
NUM_SAMPLES = 10          # syntheses drawn per utterance for the spread estimate
SAMPLE_TEMPERATURE = 1.0  # forwarded to model.inference as `temperature`
SAMPLE_CLIP = 2           # forwarded to model.inference as `clip`

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []

# Seed torch before any sampling: without this, the spread metric (and any
# stochastic part of the losses) changes every time the cell is executed.
torch.manual_seed(EVAL_SEED)

model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

        # NOTE(review): eval mode was already set before the loop; this call is
        # only needed if model(batch) can switch the mode back — confirm and
        # drop it if not.
        model.eval()

        # Draw several syntheses for the same text, target length (mel frame
        # count, taken from dim 2 of the ground-truth mels) and STT alignments.
        samples_list = []
        for _ in range(NUM_SAMPLES):
            samples, _ = model.inference(
                batch['text'],
                batch['mels'].size(2),
                stt_outputs["alignments"],
                temperature=SAMPLE_TEMPERATURE,
                clip=SAMPLE_CLIP,
            )
            samples_list.append(samples)
        # Concatenating along dim 0 relies on batch_size=1, so that dim 0 of the
        # result indexes the draws (assumes samples' first dim is the batch dim
        # — TODO confirm).
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)


print('done')

# Test-set averages, in order: CE loss, reconstruction loss, KL loss, sample spread.
print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.05123396962881088
1 0.05228736251592636
2 0.050009384751319885
3 0.044017791748046875
4 0.04984109848737717
5 0.04742718115448952
6 0.048378605395555496
7 0.04735332727432251
8 0.04977024346590042
9 0.04879579693078995
10 0.049075689166784286
11 0.04574919864535332
12 0.045801471918821335
13 0.049523577094078064
14 0.047937288880348206
15 0.05100790038704872
16 0.05178368464112282
17 0.04863189160823822
18 0.04938217252492905
19 0.04892591014504433
20 0.04840421676635742
21 0.05158684402704239
22 0.045918211340904236
23 0.05128055438399315
24 0.048318710178136826
25 0.04856890067458153
26 0.04713268578052521
27 0.043809302151203156
28 0.048022374510765076
29 0.0494854710996151
30 0.04859326779842377
31 0.048781171441078186
32 0.04797697812318802
33 0.04451070725917816
34 0.04695217311382294
35 0.04895850270986557
36 0.04811578243970871
37 0.05001342296600342
38 0.04852979630231857
39 0.04456247389316559
40 0.04733704403042793
41 0.04877296835184097
42 0.05254240706562996
43 0.05202

343 0.04076365754008293
344 0.04368073493242264
345 0.039602067321538925
346 0.04253411293029785
347 0.043406080454587936
done
0.11090015389211542
0.0453229375804464
0.015003038059127912
0.2793006016605202
