In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S4L import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Sat Mar 13 21:58:21 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 57%   65C    P2    96W / 250W |    256MiB / 11016MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 37%   59C    P0    50W / 250W |      0MiB / 11019MiB |      0%      Default |
|       

In [3]:
save_dir = '/data/save/model_S4L'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_100817  save_65000  save_75000  save_85000  save_94076
save_100000  save_60377   save_70000  save_80000  save_90000  save_95000


In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 100000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 243.8MiB
TTS size 156.2MiB
MelDecoder size 94.9MiB
loaded : 100000
100000
done


In [6]:
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7ff786e29ac0>


In [7]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [8]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        model.eval()
        
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.053579915314912796
1 0.05577331408858299
2 0.05418180674314499
3 0.0505235530436039
4 0.05432279035449028
5 0.05032835900783539
6 0.053301308304071426
7 0.052974503487348557
8 0.05195525661110878
9 0.05141054466366768
10 0.051849573850631714
11 0.04942905530333519
12 0.049922797828912735
13 0.053881313651800156
14 0.05044994503259659
15 0.052898701280355453
16 0.06357131153345108
17 0.05134565010666847
18 0.05227970331907272
19 0.0526837520301342
20 0.052169203758239746
21 0.05378590524196625
22 0.048659149557352066
23 0.05182451009750366
24 0.054230693727731705
25 0.052053097635507584
26 0.051507368683815
27 0.04693121835589409
28 0.05309556424617767
29 0.0522792674601078
30 0.05274364352226257
31 0.05327274277806282
32 0.051350805908441544
33 0.0489443801343441
34 0.06983238458633423
35 0.05263146385550499
36 0.05331539362668991
37 0.05516939237713814
38 0.052530910819768906
39 0.049828317016363144
40 0.05102500319480896
41 0.05195726826786995
42 0.05799698457121849
43 0.05826881

343 0.04273485019803047
344 0.04578794538974762
345 0.04235602915287018
346 0.04523347318172455
347 0.04936463385820389
done
0.13321736186010943
0.04891854572784284
0.015724523621878917
0.27639047394710026
