In [1]:
import warnings
# Silence all Python warnings raised from here on (this also hides deprecation notices).
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W5G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
# List the GPUs on this machine before picking one.
!nvidia-smi
# Restrict CUDA to device index 0.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

Sat Mar 13 20:56:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 55%   66C    P0    97W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 30%   59C    P0    50W / 250W |      0MiB / 11019MiB |      0%      Defaul

In [6]:
# Directory the checkpoints for this run are read from.
save_dir = "/data/save/model_W5G"
# NOTE(review): new=False presumably re-attaches to the existing directory
# instead of starting a fresh one — confirm against utils.Logger.
logger = Logger(new=False, save_dir=save_dir)
print("done")

done


In [7]:
# Show what is stored in the checkpoint directory.
!ls {save_dir}

data.json    save_20951  save_40000  save_60000  save_77732  save_95000
save_100000  save_25000  save_42922  save_65000  save_80000  save_96773
save_100003  save_30000  save_45000  save_67743  save_85000
save_15000   save_35000  save_50000  save_70000  save_85982
save_20000   save_38005  save_55000  save_75000  save_90000


In [8]:
# Build the combined STT/TTS model from the W5G hyper-parameters and move it to the GPU.
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train').cuda()

# Checkpoint step passed to logger.load below.
step = 100000
# Flip to False to skip restoring the checkpoint and keep the freshly built weights.
load_checkpoint = True


def _param_size(module):
    """Return the formatted total size of ``module``'s parameters.

    Each parameter contributes ``numel() * element_size()`` bytes, so the
    figure stays correct even if some parameters are not 4-byte floats.
    """
    n_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    return sizeof_fmt(n_bytes)


# Report the parameter footprint of the whole model and of two sub-modules.
for label, module in (
    ("Model", model),
    ("TTS", model.tts),
    ("MelDecoder", model.tts.mel_decoder),
):
    size = _param_size(module)
    print(f"{label} size {size}")

if load_checkpoint:
    # logger.load returns three values; only the model is kept.
    # NOTE(review): the trailing None is presumably the optimizer slot — confirm in utils.Logger.
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 132.9MiB
TTS size 45.2MiB
MelDecoder size 28.7MiB
loaded : 100000
100000
done


In [9]:
# Test split of the dataset plus the collate function that batches its items.
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

# One item per batch, no shuffling, a single worker process.
loader_kwargs = dict(
    batch_size=1,
    shuffle=False,
    sampler=None,
    num_workers=1,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)
test_loader = torch.utils.data.DataLoader(testset, **loader_kwargs)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fbf87cd3fd0>


In [10]:
def to_cuda(batch):
    """Replace the text/mel entries of ``batch`` with their ``.cuda()`` results.

    The entries ``'text'``, ``'text_lengths'``, ``'mels'`` and
    ``'mel_lengths'`` are overwritten in place, in that order; any other
    keys are left untouched.

    Args:
        batch: Mapping that contains the four keys listed above.

    Returns:
        The same ``batch`` object that was passed in, after modification.

    Raises:
        KeyError: If one of the four keys is missing.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        batch[key] = batch[key].cuda()

    return batch

In [15]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# One entry per test batch is appended to each of these lists.
ce_losses, recon_losses, kl_losses = [], [], []

model.eval()
# Forward passes only, so gradient tracking is switched off.
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)

        # Pull the three losses out of the model outputs as plain Python numbers.
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        for history, value in (
            (ce_losses, ce_loss),
            (recon_losses, recon_loss),
            (kl_losses, kl_loss),
        ):
            history.append(value)

print('done')

0 0.0511164553463459
1 0.05334679037332535
2 0.04996704310178757
3 0.0455244816839695
4 0.04983947426080704
5 0.04694046080112457
6 0.0477580688893795
7 0.04764760658144951
8 0.04883931949734688
9 0.04869336262345314
10 0.04870795086026192
11 0.046197354793548584
12 0.04644682630896568
13 0.05081651732325554
14 0.04799611493945122
15 0.05097957327961922
16 0.05274451896548271
17 0.048027075827121735
18 0.050488341599702835
19 0.04839976876974106
20 0.04963965341448784
21 0.05072908475995064
22 0.045613065361976624
23 0.05006881058216095
24 0.049521613866090775
25 0.048272572457790375
26 0.04660419002175331
27 0.04385516047477722
28 0.04969530925154686
29 0.04959821701049805
30 0.04940313100814819
31 0.04851392284035683
32 0.04737941548228264
33 0.04469595476984978
34 0.04860522598028183
35 0.04907471314072609
36 0.04989916458725929
37 0.04996095597743988
38 0.046461526304483414
39 0.04553047567605972
40 0.04668311029672623
41 0.04796278476715088
42 0.05322438105940819
43 0.051545288413

344 0.042558733373880386
345 0.040530264377593994
346 0.04186986759305
347 0.0447218120098114
done


In [16]:
# Mean of each loss over the collected values, printed in the order
# ce_losses, recon_losses, kl_losses.
for losses in (ce_losses, recon_losses, kl_losses):
    print(np.mean(losses))

0.11090015389211542
0.04535754048533138
0.015004534860430607
