In [3]:
import os
import json
import torch
from librosa import load
from env import AttrDict
from meldataset import mel_spectrogram
from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator

In [4]:
# Load the model hyper-parameters from the JSON config and wrap them in an
# AttrDict so later cells can use attribute access (h.n_fft, h.hop_size, ...).
config_file = 'config_v3.json'  # relative to the notebook's working directory
with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)

# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA,
# instead of failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Instantiate the generator (configured by `h`) and the two discriminators,
# placing every module on `device`. MPD is constructed before MSD, as before.
generator = Generator(h).to(device)
MPD, MSD = [
    disc_cls().to(device)
    for disc_cls in (MultiPeriodDiscriminator, MultiScaleDiscriminator)
]

In [6]:
# Load one test utterance as a float waveform on `device`.
# Pass the config's sampling rate explicitly: without `sr=`, librosa.load
# resamples to its own default rate regardless of h.sampling_rate, which would
# silently mismatch the mel settings for any config with a different rate.
wavfile = 'test_files/4_5600.wav'
wav, sr = load(wavfile, sr=h.sampling_rate)
# librosa returns a NumPy array; convert straight to a float32 tensor on the
# target device (replaces the legacy torch.FloatTensor(...) constructor).
wav = torch.as_tensor(wav, dtype=torch.float32, device=device)
print(wav.shape, sr)

torch.Size([93184]) 22050


In [12]:
def get_mel(x):
    """Compute the mel spectrogram of waveform `x` with the settings in the global config `h`.

    The STFT/mel parameters (n_fft, num_mels, sampling_rate, hop_size,
    win_size, fmin, fmax) are forwarded positionally to `mel_spectrogram`.
    In this notebook a (1, T) waveform produced a (1, num_mels, frames) tensor.
    """
    mel_params = (
        h.n_fft,
        h.num_mels,
        h.sampling_rate,
        h.hop_size,
        h.win_size,
        h.fmin,
        h.fmax,
    )
    return mel_spectrogram(x, *mel_params)

# Prepend a batch axis (equivalent to wav.unsqueeze(0)) and compute the mel
# spectrogram; the recorded run printed (1, 80, 364).
x = get_mel(wav[None])
print(x.size())

torch.Size([1, 80, 364])


In [11]:
# Synthesize a waveform from the mel spectrogram `x` with the generator.
# In the recorded run a (1, 80, 364) mel became a (1, 1, 93184) waveform,
# i.e. 256 output samples per mel frame.
# NOTE(review): autograd is enabled here; for pure shape inspection,
# `with torch.no_grad():` would avoid retaining the activation graph.
# NOTE(review): this cell's execution count (In [11]) is later than the MPD/MSD
# cells (In [9], In [10]), so those outputs may come from an older y_g_hat —
# re-run the notebook top to bottom to confirm.
y_g_hat = generator(x)

torch.Size([1, 80, 364])
torch.Size([1, 256, 364])
torch.Size([1, 128, 2912])
torch.Size([1, 64, 23296])
torch.Size([1, 32, 93184])
torch.Size([1, 1, 93184])


In [9]:
# Build the (1, 1, T) real waveform that is paired with the generated one.
# The generator emits frames x upsampling-factor samples (256 per frame in the
# run above), which is not guaranteed to equal len(wav) when the wav length is
# not a multiple of the hop size. Trim the real signal to the generated length
# and fail loudly if the shapes still disagree, so the discriminators (and any
# later real-vs-fake comparison) never see misaligned signals.
nwav = wav[..., :y_g_hat.size(-1)].reshape(1, 1, -1)
assert nwav.shape == y_g_hat.shape, (
    f"real {tuple(nwav.shape)} vs generated {tuple(y_g_hat.shape)}"
)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = MPD(nwav, y_g_hat)

torch.Size([1, 1, 93184])
torch.Size([1, 1, 46592, 2])
torch.Size([1, 32, 15531, 2])
torch.Size([1, 128, 5177, 2])
torch.Size([1, 512, 1726, 2])
torch.Size([1, 1024, 576, 2])
torch.Size([1, 1024, 576, 2])
torch.Size([1, 1152])
torch.Size([1, 1, 93184])
torch.Size([1, 1, 46592, 2])
torch.Size([1, 32, 15531, 2])
torch.Size([1, 128, 5177, 2])
torch.Size([1, 512, 1726, 2])
torch.Size([1, 1024, 576, 2])
torch.Size([1, 1024, 576, 2])
torch.Size([1, 1152])
torch.Size([1, 1, 93184])
torch.Size([1, 1, 31062, 3])
torch.Size([1, 32, 10354, 3])
torch.Size([1, 128, 3452, 3])
torch.Size([1, 512, 1151, 3])
torch.Size([1, 1024, 384, 3])
torch.Size([1, 1024, 384, 3])
torch.Size([1, 1152])
torch.Size([1, 1, 93184])
torch.Size([1, 1, 31062, 3])
torch.Size([1, 32, 10354, 3])
torch.Size([1, 128, 3452, 3])
torch.Size([1, 512, 1151, 3])
torch.Size([1, 1024, 384, 3])
torch.Size([1, 1024, 384, 3])
torch.Size([1, 1152])
torch.Size([1, 1, 93184])
torch.Size([1, 1, 18637, 5])
torch.Size([1, 32, 6213, 5])
torch.Si

In [10]:
# Run the MultiScaleDiscriminator on the real waveform `nwav` and the
# generated waveform `y_g_hat`.
# NOTE(review): this rebinds y_d_rs, y_d_gs, fmap_rs, fmap_gs, discarding the
# MultiPeriodDiscriminator results from the previous cell; use distinct names
# if both sets of outputs are needed later.
y_d_rs, y_d_gs, fmap_rs, fmap_gs = MSD(nwav, y_g_hat)

torch.Size([1, 1, 93184])
torch.Size([1, 128, 93184])
torch.Size([1, 128, 46592])
torch.Size([1, 256, 23296])
torch.Size([1, 512, 5824])
torch.Size([1, 1024, 1456])
torch.Size([1, 1024, 1456])
torch.Size([1, 1024, 1456])
torch.Size([1, 1456])
torch.Size([1, 1, 93184])
torch.Size([1, 128, 93184])
torch.Size([1, 128, 46592])
torch.Size([1, 256, 23296])
torch.Size([1, 512, 5824])
torch.Size([1, 1024, 1456])
torch.Size([1, 1024, 1456])
torch.Size([1, 1024, 1456])
torch.Size([1, 1456])
torch.Size([1, 1, 46593])
torch.Size([1, 128, 46593])
torch.Size([1, 128, 23297])
torch.Size([1, 256, 11649])
torch.Size([1, 512, 2913])
torch.Size([1, 1024, 729])
torch.Size([1, 1024, 729])
torch.Size([1, 1024, 729])
torch.Size([1, 729])
torch.Size([1, 1, 46593])
torch.Size([1, 128, 46593])
torch.Size([1, 128, 23297])
torch.Size([1, 256, 11649])
torch.Size([1, 512, 2913])
torch.Size([1, 1024, 729])
torch.Size([1, 1024, 729])
torch.Size([1, 1024, 729])
torch.Size([1, 729])
torch.Size([1, 1, 23297])
torch.Size