In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Sat Mar 13 20:57:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 57%   68C    P2   116W / 250W |   1686MiB / 11016MiB |     32%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 32%   60C    P0    49W / 250W |      0MiB / 11019MiB |      0%      Defaul

In [4]:
save_dir = '/data/save/model_S5G'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [5]:
!ls $save_dir

data.json    save_104221  save_123593  save_130000  save_75000	save_98268
save_100000  save_105000  save_123671  save_132767  save_80000	save_98508
save_100189  save_110000  save_123692  save_133754  save_85000
save_100481  save_115000  save_125000  save_135000  save_90000
save_100728  save_120000  save_128431  save_137981  save_95000


In [6]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 100000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 283.0MiB
TTS size 195.3MiB
MelDecoder size 119.0MiB
loaded : 100000
100000
done


In [7]:
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f22e54cd1f0>


In [8]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [14]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
print('done')

0 0.0541694238781929
1 0.055544547736644745
2 0.054068852216005325
3 0.04986480996012688
4 0.054749783128499985
5 0.050514306873083115
6 0.05077545717358589
7 0.05132405087351799
8 0.05162058025598526
9 0.050808876752853394
10 0.05451950803399086
11 0.04633314535021782
12 0.04832647368311882
13 0.053541649132966995
14 0.050139591097831726
15 0.055215924978256226
16 0.05895960330963135
17 0.05089988559484482
18 0.05035927891731262
19 0.050800297409296036
20 0.05246121436357498
21 0.05786382034420967
22 0.04842120036482811
23 0.05350016430020332
24 0.05234437808394432
25 0.05177300050854683
26 0.052850205451250076
27 0.04785546287894249
28 0.055966414511203766
29 0.051394328474998474
30 0.052268486469984055
31 0.052504729479551315
32 0.050145067274570465
33 0.049484848976135254
34 0.0515691339969635
35 0.05443224683403969
36 0.051830559968948364
37 0.053973257541656494
38 0.05145539343357086
39 0.04824238643050194
40 0.051234904676675797
41 0.05279374122619629
42 0.05373697355389595
43 0

344 0.044365569949150085
345 0.043613310903310776
346 0.0440126396715641
347 0.04910952225327492
done


In [15]:
print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))

0.075774136133762
0.048784357335033085
0.01530142380299325
