In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S4L import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Sat Mar 13 21:03:52 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 56%   59C    P8    44W / 250W |   3265MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 52%   61C    P2   112W / 250W |   1800MiB / 11019MiB |     34%      Defaul

In [3]:
# Directory that the Logger reads from / writes to for this run.
save_dir = '/data/save/model_S4L'

# new=False is passed through to the project Logger; presumably this attaches to
# an existing run directory instead of creating a fresh one -- confirm in utils.Logger.
logger = Logger(
    save_dir=save_dir,
    new=False,
)
print('done')

done


In [4]:
# List the contents of save_dir (IPython expands $save_dir from the Python variable).
!ls $save_dir

data.json    save_100817  save_65000  save_75000  save_85000  save_94076
save_100000  save_60377   save_70000  save_80000  save_90000  save_95000


In [5]:
# Build the combined model from the STT and TTS hyper-parameter sets and move it to the GPU.
# NOTE(review): the model is constructed with mode='train' although this notebook
# only evaluates it -- confirm that this is the intended mode.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()

# Checkpoint step handed to logger.load below.
step = 100000

# Report the parameter memory of the whole model and of two sub-modules.
# element_size() gives the bytes per element of each parameter's dtype, so the
# figure stays correct if some parameters are not float32 (for float32 it is 4).
for label, module in (
    ("Model", model),
    ("TTS", model.tts),
    ("MelDecoder", model.tts.mel_decoder),
):
    size = sizeof_fmt(sum(p.numel() * p.element_size() for p in module.parameters()))
    print(f"{label} size {size}")

# Restore the checkpoint for `step`. logger.load returns three values; only the
# first (the model) is kept. The trailing None fills the third argument of
# logger.load -- presumably an optimizer slot; confirm in utils.Logger.
model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 243.8MiB
TTS size 156.2MiB
MelDecoder size 94.9MiB
loaded : 100000
100000
done


In [6]:
# Test split and the collate function, both configured from the TTS hyper-parameters.
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

# One sample per batch, in dataset order. With batch_size=1 there is never an
# incomplete final batch, so drop_last=True does not discard anything here.
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    num_workers=1,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fddb7136700>


In [7]:
def to_cuda(batch):
    """Move the text/mel entries of ``batch`` to the GPU and return the batch.

    The entries under 'text', 'text_lengths', 'mels' and 'mel_lengths' are each
    replaced by the result of calling ``.cuda()`` on them. The dict is modified
    in place and the same object is returned; any other keys are left untouched.

    Raises:
        KeyError: if one of the four expected keys is missing.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()

    return batch

In [13]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Per-batch losses collected over the whole test loader.
ce_losses, recon_losses, kl_losses = [], [], []

# Evaluation mode, and no autograd bookkeeping during the forward passes.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)

        # Pull each scalar loss out of the output dicts as a Python number.
        ce_loss, recon_loss, kl_loss = (
            stt_outputs['loss'].item(),
            tts_outputs['recon_loss'].item(),
            tts_outputs['kl_loss'].item(),
        )
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

print('done')

0 0.05272858962416649
1 0.05557969585061073
2 0.05278673395514488
3 0.05208509415388107
4 0.054941438138484955
5 0.05106896907091141
6 0.05389751121401787
7 0.053070854395627975
8 0.051431771367788315
9 0.050334177911281586
10 0.052468348294496536
11 0.04768550395965576
12 0.05029032751917839
13 0.054752666503190994
14 0.05152641609311104
15 0.053270258009433746
16 0.06171341985464096
17 0.05198776721954346
18 0.05307072773575783
19 0.05128839239478111
20 0.05234328284859657
21 0.05551302805542946
22 0.047177769243717194
23 0.05117900297045708
24 0.05248289182782173
25 0.05093435198068619
26 0.05164570361375809
27 0.04600267484784126
28 0.05417192354798317
29 0.05225495621562004
30 0.05178442224860191
31 0.05376122519373894
32 0.05019319802522659
33 0.05119403451681137
34 0.07113347202539444
35 0.053953882306814194
36 0.05186666548252106
37 0.05534675717353821
38 0.052564386278390884
39 0.05061803013086319
40 0.05019264295697212
41 0.05269106850028038
42 0.05647304281592369
43 0.058270

344 0.04667556658387184
345 0.043855130672454834
346 0.04519238695502281
347 0.04903966560959816
done


In [14]:
# Mean of each collected loss list, printed in the order: CE, reconstruction, KL.
for loss_values in (ce_losses, recon_losses, kl_losses):
    print(np.mean(loss_values))

0.13321736186010943
0.0489136982366614
0.01571730160604006
