In [15]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W4L import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [16]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Sat Mar 13 21:47:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.04    Driver Version: 455.23.04    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    Off  | 00000000:19:00.0 Off |                  N/A |
|  0%   51C    P8    34W / 370W |   1441MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:68:00.0 Off |                  N/A |
| 30%   40C    P0    70W / 350W |      0MiB / 24265MiB |      3%      Default |
|       

In [17]:
save_dir = 'save/model_W4L'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [18]:
!ls $save_dir

data.json    save_15000  save_35000  save_50000  save_65000  save_85000
save_0	     save_20000  save_40000  save_55000  save_70000  save_90000
save_10000   save_25000  save_45000  save_58396  save_75000  save_95000
save_100000  save_30000  save_5000   save_60000  save_80000


In [19]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 100000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 124.1MiB
TTS size 36.5MiB
MelDecoder size 22.9MiB
loaded : 100000
100000
done


In [20]:
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f04183d3c90>


In [21]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [25]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)
        
        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)
        
        model.eval()
        
        samples_list = []
        for _ in range(10):
            samples, _ = model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"], temperature=1.0, clip=2)
            samples_list.append(samples)
        samples_list = torch.cat(samples_list, dim=0)
        sample_std = torch.std(samples_list, dim=0).mean().item()
        sample_stds.append(sample_std)
            
        
print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.05117952823638916
1 0.05328477919101715
2 0.050713758915662766
3 0.05012204498052597
4 0.05137081444263458
5 0.048514194786548615
6 0.048756733536720276
7 0.04893988370895386
8 0.04954352602362633
9 0.04991484433412552
10 0.049470651894807816
11 0.046213723719120026
12 0.04611681401729584
13 0.05175754055380821
14 0.04774719849228859
15 0.05252053588628769
16 0.05569665879011154
17 0.04767701029777527
18 0.053088244050741196
19 0.04904475808143616
20 0.05004804953932762
21 0.0522751621901989
22 0.04533199593424797
23 0.05048172548413277
24 0.04923214390873909
25 0.04895111173391342
26 0.04730668663978577
27 0.04635137319564819
28 0.052181679755449295
29 0.04959711804986
30 0.04889385402202606
31 0.048393238335847855
32 0.04846106469631195
33 0.045911602675914764
34 0.05393170565366745
35 0.05001062527298927
36 0.050474829971790314
37 0.0512676015496254
38 0.04823332279920578
39 0.0471385233104229
40 0.048709459602832794
41 0.04884311929345131
42 0.05511021614074707
43 0.05468567833

343 0.0408361442387104
344 0.04355400428175926
345 0.04189413785934448
346 0.0417381227016449
347 0.04582593962550163
done


In [26]:
print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0.11727559972173472
0.046232935608546626
0.014900211046781691
0.2766318055319375
