In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S4G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
# Show the GPUs on this machine and their current memory use.
!nvidia-smi
# Restrict this process to GPU index 1.
# NOTE(review): this has to take effect before anything initializes CUDA
# (the first .cuda() call is in the model-building cell) — confirm when
# re-running cells out of order.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Sat Mar 13 21:03:19 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 57%   59C    P8    43W / 250W |   3265MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 51%   54C    P8    22W / 250W |   1252MiB / 11019MiB |      0%      Defaul

In [3]:
# Directory the Logger works in; checkpoints are later loaded from it via logger.load.
save_dir = '/data/save/model_S4G'
# NOTE(review): new=False presumably attaches to the existing directory instead of
# starting a fresh one — confirm against utils.Logger.
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
# List what is stored in save_dir ($save_dir is filled in from the Python variable).
!ls $save_dir

data.json    save_100096  save_65000  save_75000  save_85000  save_95000
save_100000  save_61887   save_70000  save_80000  save_90000  save_96091


In [5]:
# Build the model from the STT and TTS hyper-parameters and move it to the GPU.
# NOTE(review): mode='train' is passed although this notebook only evaluates
# (model.eval() is called before the test loop) — confirm that is intended.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()

# Checkpoint step to restore. Set LOAD_CHECKPOINT to False to keep the
# freshly initialised weights instead.
step = 100000
LOAD_CHECKPOINT = True

# Report parameter memory for the whole model and two of its sub-modules.
# element_size() is the actual number of bytes per element, so the figure is
# also right for non-float32 parameters (a fixed factor of 4 assumes float32).
for _label, _module in (
    ("Model", model),
    ("TTS", model.tts),
    ("MelDecoder", model.tts.mel_decoder),
):
    size = sizeof_fmt(sum(p.numel() * p.element_size() for p in _module.parameters()))
    print(f"{_label} size {size}")

if LOAD_CHECKPOINT:
    # logger.load returns three values; only the restored model is used here.
    # The trailing None is presumably the optimizer slot — confirm in utils.Logger.
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 243.8MiB
TTS size 156.2MiB
MelDecoder size 94.9MiB
loaded : 100000
100000
done


In [6]:
# Test split plus the collate function that assembles its batches.
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

# One sample per batch, in dataset order.
# drop_last is False so that no test sample is silently discarded if
# batch_size is ever raised above 1; with batch_size=1 this yields exactly
# the same batches as drop_last=True.
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    num_workers=1,
    pin_memory=False,
    drop_last=False,
    collate_fn=collate_fn,
)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f2d04390b50>


In [7]:
def to_cuda(batch):
    """Move the text and mel entries of ``batch`` onto the GPU.

    The entries ``text``, ``text_lengths``, ``mels`` and ``mel_lengths`` are
    replaced in place by the result of their ``.cuda()`` call; every other
    entry is left untouched.

    Returns the same ``batch`` object that was passed in.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()
    return batch

In [12]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# One entry per test batch for each of the three losses.
ce_losses, recon_losses, kl_losses = [], [], []

# Evaluation pass: eval mode, no gradient tracking.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)

        # .item() turns each loss into a plain Python number.
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

print('done')

0 0.052864715456962585
1 0.0544963963329792
2 0.05121661722660065
3 0.05131535977125168
4 0.05239732936024666
5 0.049047697335481644
6 0.050235550850629807
7 0.05173332989215851
8 0.051472630351781845
9 0.053450167179107666
10 0.051915910094976425
11 0.04981265962123871
12 0.04805361106991768
13 0.05450405180454254
14 0.05013661086559296
15 0.0553056001663208
16 0.05948622524738312
17 0.051566798239946365
18 0.05343294516205788
19 0.051177579909563065
20 0.05261608585715294
21 0.0541599839925766
22 0.04765339940786362
23 0.051880404353141785
24 0.05120086297392845
25 0.05179828777909279
26 0.051317375153303146
27 0.047384414821863174
28 0.05447123944759369
29 0.05219137296080589
30 0.051852062344551086
31 0.051634132862091064
32 0.04912685602903366
33 0.0482877641916275
34 0.06482241302728653
35 0.05529700964689255
36 0.05304110050201416
37 0.05337778478860855
38 0.0508897565305233
39 0.04937813803553581
40 0.052334263920784
41 0.0520012266933918
42 0.05646272003650665
43 0.05498070642

344 0.04655865207314491
345 0.045198775827884674
346 0.044914428144693375
347 0.04988231509923935
done


In [13]:
# Mean CE, reconstruction and KL loss over the collected batches, one per line.
print(
    np.mean(ce_losses),
    np.mean(recon_losses),
    np.mean(kl_losses),
    sep='\n',
)

0.12111433699945745
0.04851238742128186
0.015692466489927864
