In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_W4G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Sat Mar 13 20:48:20 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.04    Driver Version: 455.23.04    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    Off  | 00000000:19:00.0 Off |                  N/A |
| 75%   64C    P2   271W / 370W |  12878MiB / 24268MiB |     57%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:68:00.0 Off |                  N/A |
| 67%   80C    P2   286W / 350W |  12431MiB / 24265MiB |    100%      Defaul

In [3]:
save_dir = 'save/model_W4G'
logger = Logger(save_dir=save_dir, new=False)
print('done')

done


In [13]:
!ls $save_dir

data.json   save_20000	save_40000  save_55000	save_70000  save_85000
save_0	    save_25000	save_45000  save_58657	save_75000  save_85150
save_10000  save_30000	save_5000   save_60000	save_80000  save_90000
save_15000  save_35000	save_50000  save_65000	save_82879  save_95000


In [14]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()
step = 95000

size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
print(f"Model size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.parameters()))
print(f"TTS size {size}")

size = sizeof_fmt(4 * sum(p.numel() for p in model.tts.mel_decoder.parameters()))
print(f"MelDecoder size {size}")

if True:
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 124.1MiB
TTS size 36.5MiB
MelDecoder size 22.9MiB
loaded : 95000
95000
done


In [15]:
testset = LJDataset(tts_hparams, split='test')
collate_fn = TextMelCollate(tts_hparams)

test_loader = torch.utils.data.DataLoader(testset, num_workers=1, 
                          shuffle=False, sampler=None, batch_size=1, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f74f0d9fd90>


In [16]:
def to_cuda(batch):
    batch['text'] = batch['text'].cuda()
    batch['text_lengths'] = batch['text_lengths'].cuda()
    batch['mels'] = batch['mels'].cuda()
    batch['mel_lengths'] = batch['mel_lengths'].cuda()
    
    return batch

In [17]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

recon_losses = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        recon_loss = tts_outputs['recon_loss'].item()
        print(i, recon_loss)
        recon_losses.append(recon_loss)
        
print('done')

0 0.05051849037408829
1 0.05046040192246437
2 0.04937683045864105
3 0.0468469113111496
4 0.05073393136262894
5 0.04750362038612366
6 0.048118092119693756
7 0.049034588038921356
8 0.04871154576539993
9 0.049938857555389404
10 0.04947817325592041
11 0.045689281076192856
12 0.046406496316194534
13 0.052509941160678864
14 0.04731038212776184
15 0.05133011192083359
16 0.05782129615545273
17 0.04817279055714607
18 0.05089019984006882
19 0.04777495935559273
20 0.05020926892757416
21 0.04995884373784065
22 0.04550570622086525
23 0.05027031525969505
24 0.04870418459177017
25 0.04783337563276291
26 0.04773690924048424
27 0.04534708335995674
28 0.050113532692193985
29 0.04902051389217377
30 0.04929393157362938
31 0.05014122650027275
32 0.048422038555145264
33 0.04685681685805321
34 0.05223092436790466
35 0.04868811368942261
36 0.04824455454945564
37 0.049792446196079254
38 0.04826692119240761
39 0.04556838050484657
40 0.047954071313142776
41 0.04925279691815376
42 0.05215505510568619
43 0.0533707

344 0.0437898263335228
345 0.04034622386097908
346 0.043044958263635635
347 0.04522157087922096
done


In [18]:
print(np.mean(recon_losses))

0.045737755880958735
