In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
from torch import nn

from hparams.hparams_S4G import create_hparams
from model import Model
from datasets import LJDataset, TextMelCollate
from utils import sizeof_fmt, Logger

In [2]:
!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Fri Mar 19 12:55:02 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:2E:00.0 Off |                  N/A |
| 44%   61C    P0    96W / 250W |      0MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:2F:00.0 Off |                  N/A |
| 26%   50C    P0    46W / 250W |      0MiB / 11019MiB |      0%      Default |
|       

In [3]:
# Run directory for the S4G model; it contains data.json and save_<step> entries.
save_dir = '/data/save/model_S4G'
# NOTE(review): new=False presumably reopens this existing run directory instead of
# starting a fresh one — confirm against utils.Logger.
logger = Logger(new=False, save_dir=save_dir)
print('done')

done


In [4]:
!ls $save_dir

data.json    save_100667  save_61887  save_75000  save_90000
save_100000  save_102025  save_65000  save_80000  save_95000
save_100096  save_400000  save_70000  save_85000  save_96091


In [5]:
stt_hparams, tts_hparams = create_hparams()
model = Model(stt_hparams, tts_hparams, mode='train')
model = model.cuda()

# Checkpoint step to restore (the run directory lists a save_400000 entry).
step = 400000
# Set to False to inspect the freshly initialised model instead of the checkpoint.
LOAD_CHECKPOINT = True


def param_bytes(module):
    """Return the storage size, in bytes, of all parameters of ``module``.

    Uses each parameter's actual ``element_size()`` rather than assuming
    4 bytes per value, so the figure is identical for fp32 weights and still
    correct if the model is cast to fp16/bf16.
    """
    return sum(p.numel() * p.element_size() for p in module.parameters())


for name, part in [("Model", model), ("TTS", model.tts), ("MelDecoder", model.tts.mel_decoder)]:
    print(f"{name} size {sizeof_fmt(param_bytes(part))}")

if LOAD_CHECKPOINT:
    # logger.load returns a 3-tuple; only the model is kept here (no optimizer is passed).
    model, _, _ = logger.load(step, model, None)
print(step)

print('done')

Model size 243.8MiB
TTS size 156.2MiB
MelDecoder size 94.9MiB
loaded : 400000
400000
done


In [10]:
trainset = LJDataset(tts_hparams, split='train')
testset = LJDataset(tts_hparams, split='valid')
collate_fn = TextMelCollate(tts_hparams)

# Options shared by both loaders. num_workers is pinned to 1 instead of
# tts_hparams.num_workers. NOTE(review): with shuffle=True and drop_last=True, each
# pass over train_loader skips a different partial final batch.
loader_kwargs = dict(num_workers=1, sampler=None, pin_memory=False,
                     drop_last=True, collate_fn=collate_fn)

train_loader = torch.utils.data.DataLoader(
    trainset, shuffle=True, batch_size=tts_hparams.batch_size, **loader_kwargs)
print(train_loader)

test_loader = torch.utils.data.DataLoader(
    testset, shuffle=False, batch_size=1, **loader_kwargs)
print(test_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7efb7098eb50>
<torch.utils.data.dataloader.DataLoader object at 0x7efb70d51fd0>


In [11]:
def to_cuda(batch, device=None):
    """Move the model-input tensors of a collated batch onto a device.

    Only the four fields ``'text'``, ``'text_lengths'``, ``'mels'`` and
    ``'mel_lengths'`` are moved; any other entries are left untouched. The
    dict is updated in place and also returned, so ``batch = to_cuda(batch)``
    works as before.

    Args:
        batch: dict holding the four fields above (in this notebook it comes
            from a DataLoader using ``TextMelCollate``).
        device: target device. ``None`` (the default) keeps the original
            behaviour of calling ``.cuda()`` (current CUDA device). Any other
            value is forwarded to ``Tensor.to``, e.g. ``'cpu'`` for debugging
            without a GPU.

    Returns:
        The same ``batch`` dict, with the four fields replaced by the moved tensors.

    Raises:
        KeyError: if one of the four fields is missing from ``batch``.
    """
    for key in ('text', 'text_lengths', 'mels', 'mel_lengths'):
        value = batch[key]
        batch[key] = value.cuda() if device is None else value.to(device)
    return batch

In [12]:
import matplotlib.pyplot as plt
import librosa.display
import numpy as np

# Number of stochastic decodings drawn per batch to estimate output diversity.
N_SAMPLES = 10

ce_losses = []
recon_losses = []
kl_losses = []
sample_stds = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        batch = to_cuda(batch)
        stt_outputs, tts_outputs = model(batch)
        ce_loss = stt_outputs['loss'].item()
        recon_loss = tts_outputs['recon_loss'].item()
        kl_loss = tts_outputs['kl_loss'].item()
        print(i, recon_loss)

        ce_losses.append(ce_loss)
        recon_losses.append(recon_loss)
        kl_losses.append(kl_loss)

        # NOTE(review): eval() is re-asserted every iteration; redundant unless model(...)
        # switches the mode back — confirm and drop.
        model.eval()

        # N_SAMPLES decodings of the same inputs (temperature=1.0, clip=2); only the
        # first element of each inference() result is kept, as before.
        samples_list = [
            model.inference(batch['text'], batch['mels'].size(2), stt_outputs["alignments"],
                            temperature=1.0, clip=2)[0]
            for _ in range(N_SAMPLES)
        ]
        # Stack on a NEW leading axis so the std is taken across the N_SAMPLES draws of
        # each element. The previous torch.cat(dim=0) merged the draws into the existing
        # first (batch) axis. With batch_size > 1 the std then also measured differences
        # *between utterances* and overstated sample diversity. For batch_size == 1 both
        # give the same number.
        # NOTE(review): frames past each utterance's mel_lengths (padding up to
        # batch['mels'].size(2)) are still included in the mean.
        samples = torch.stack(samples_list, dim=0)
        sample_std = torch.std(samples, dim=0).mean().item()
        sample_stds.append(sample_std)

print('done')

print(np.mean(ce_losses))
print(np.mean(recon_losses))
print(np.mean(kl_losses))
print(np.mean(sample_stds))

0 0.022382887080311775
1 0.023449862375855446
2 0.025559375062584877
3 0.022370530292391777
4 0.02345271222293377
5 0.023348787799477577
6 0.023410581052303314
7 0.021123075857758522
8 0.022145019844174385
9 0.022328414022922516
10 0.02523878403007984
11 0.025300167500972748
12 0.023659462109208107
13 0.02310110256075859
14 0.02213812991976738
15 0.024974577128887177
16 0.024835029616951942
17 0.023937776684761047
18 0.025353282690048218
19 0.02510518953204155
20 0.021854959428310394
21 0.022878659889101982
22 0.021198203787207603
23 0.024732433259487152
24 0.02315761335194111
25 0.026018014177680016
26 0.02376648783683777
27 0.023244759067893028
28 0.023344721645116806
29 0.024323996156454086
30 0.024355165660381317
31 0.02530411258339882
32 0.024242132902145386
33 0.023628029972314835
34 0.02310008555650711
35 0.024252530187368393
36 0.023968582972884178
37 0.023101666942238808
38 0.02426758036017418
39 0.022403167560696602
40 0.0250279288738966
41 0.024806970730423927
42 0.023003494

338 0.025207098573446274
339 0.02434537000954151
340 0.022503377869725227
341 0.023177873343229294
342 0.024008246138691902
343 0.021935803815722466
344 0.023487508296966553
345 0.02543039247393608
346 0.024731161072850227
347 0.0256471149623394
348 0.02332654967904091
349 0.022623606026172638
350 0.021224310621619225
351 0.022512972354888916
352 0.021250782534480095
353 0.022557398304343224
354 0.022394316270947456
355 0.025877567008137703
356 0.02083621546626091
357 0.02211798168718815
358 0.023378843441605568
359 0.024565227329730988
360 0.021742798388004303
361 0.02292608842253685
362 0.023887598887085915
363 0.02292718179523945
364 0.023267745971679688
365 0.023426825180649757
366 0.022565539926290512
367 0.023529935628175735
368 0.02575341798365116
369 0.025642869994044304
370 0.022154318168759346
371 0.023595545440912247
372 0.024525586515665054
373 0.02134658396244049
374 0.02236410416662693
375 0.026104576885700226
376 0.023626994341611862
377 0.025683073326945305
378 0.023310

In [13]:
# NOTE(review): scratch cell — a bare float literal whose only effect is to echo its own
# value as the cell output. What this number refers to is not recorded here; give it a
# named constant with an explanation, or delete the cell.
5.611293990595001e-05

5.611293990595001e-05

In [15]:
# NOTE(review): scratch cell — a bare float literal that just echoes its value (shown as
# 5e-05). What it is compared against is not recorded here; give it a named constant with
# an explanation, or delete the cell.
0.00005

5e-05