In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S5L import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Sat Mar 13 21:16:20 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.39       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN RTX           Off  | 00000000:0A:00.0 Off |                  N/A |
| 54%   75C    P2   285W / 280W |  17732MiB / 24220MiB |     55%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN RTX           Off  | 00000000:42:00.0 Off |                  N/A |
| 30%   60C    P0    55W / 280W |      0MiB / 24217MiB |      0%      Defaul

In [3]:
save_dir = '/data/save/model_S5L'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_15000  save_40000  save_50000  save_75000  save_95000
save_0	     save_20000  save_40493  save_55000  save_80000
save_10000   save_25000  save_40840  save_60000  save_85000
save_100000  save_30000  save_45000  save_65000  save_90000
save_100808  save_35000  save_5000   save_70000  save_94571


In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 100000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 283.0MiB
TTS size 195.3MiB
MelDecoder size 119.0MiB
loaded : 100000
100000
done


In [6]:
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f97bcef1950>


In [7]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [8]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
print('done')

0 0.053443748503923416
1 0.05689402297139168
2 0.054812222719192505
3 0.050693854689598083
4 0.05591539666056633
5 0.05216245725750923
6 0.05295998603105545
7 0.05111776292324066
8 0.05358664318919182
9 0.05311176925897598
10 0.05503297969698906
11 0.048743873834609985
12 0.04986262321472168
13 0.05695251002907753
14 0.05277096852660179
15 0.05429377034306526
16 0.05827256664633751
17 0.05206287279725075
18 0.053244393318891525
19 0.05135001242160797
20 0.055212657898664474
21 0.05682828277349472
22 0.051175039261579514
23 0.055411446839571
24 0.053938329219818115
25 0.05218488350510597
26 0.05062031373381615
27 0.0489133857190609
28 0.055943362414836884
29 0.052995465695858
30 0.05407753214240074
31 0.05633395165205002
32 0.052949223667383194
33 0.050998199731111526
34 0.0535462312400341
35 0.05361344292759895
36 0.052428144961595535
37 0.05498763173818588
38 0.05533332750201225
39 0.049471013247966766
40 0.05211900174617767
41 0.05202941969037056
42 0.056318093091249466
43 0.05765515

345 0.04445688799023628
346 0.04618540778756142
347 0.05070817470550537
done


In [9]:
print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))

0.11566263598693884
0.050034213918207705
0.015279741984545842
