In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
from models import Generator
from stft import TorchSTFT

from Utils.JDC.model import JDCNet

In [None]:
# Module-level placeholders that later cells rebind:
#   h      - hyper-parameter object read by get_mel() via attribute access
#   device - torch device that checkpoints and models are mapped/moved to
h = None
device = None


def load_checkpoint(filepath, device):
    """Load a serialized checkpoint from disk.

    Args:
        filepath: Path of the checkpoint file; it must be an existing regular file.
        device: Passed to ``torch.load`` as ``map_location`` so stored tensors
            are remapped onto this device.

    Returns:
        Whatever object ``torch.load`` deserializes from ``filepath``.

    Raises:
        AssertionError: If ``filepath`` is not an existing file. Note that
            assertions are skipped when Python runs with ``-O``.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from sources you trust.
    assert os.path.isfile(filepath)

    print("Loading '{}'".format(filepath))
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")

    return loaded


def get_mel(x):
    """Compute mel-spectrogram features for ``x`` using the global config ``h``.

    The global ``h`` must already be bound to the loaded configuration; while
    it is still ``None`` the attribute lookups below raise ``AttributeError``.

    Args:
        x: Audio input forwarded unchanged as the first argument of
            ``mel_spectrogram``.

    Returns:
        The value returned by ``mel_spectrogram``.
    """
    # Positional order expected by mel_spectrogram after the audio argument.
    mel_args = (
        h.n_fft,
        h.num_mels,
        h.sampling_rate,
        h.hop_size,
        h.win_size,
        h.fmin,
        h.fmax,
    )
    return mel_spectrogram(x, *mel_args)


def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically last path in ``cp_dir`` starting with ``prefix``.

    Args:
        cp_dir: Directory to search.
        prefix: File-name prefix; every entry matching ``prefix + '*'`` is a
            candidate.

    Returns:
        The matching path that sorts last as a string, or ``''`` when nothing
        matches. Ordering is plain string ordering, so numeric suffixes only
        rank correctly if they are zero-padded to a common width.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '*'))
    if not candidates:
        return ''
    candidates.sort()
    return candidates[-1]

In [None]:
# Instantiate the JDC network that is handed to the Generator constructor below.
# No weights are loaded into it directly in this notebook.
# NOTE(review): presumably its parameters are restored as part of the generator
# checkpoint - confirm.
F0_model = JDCNet(num_class=1, seq_len=192)

In [None]:
# Checkpoint directory: config.json is read from here and it is searched for
# generator checkpoint files whose names start with 'g_'.
cp_path = "cp_hifigan"

In [None]:
# Read the JSON config stored in the checkpoint directory (raw text first).
with open(cp_path + "/config.json") as f:
    data = f.read()

# Parse the text and wrap the resulting dict in AttrDict, rebinding the global
# `h` that get_mel() and the model construction read via attribute access
# (e.g. h.n_fft, h.gen_istft_n_fft).
json_config = json.loads(data)
h = AttrDict(json_config)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs end-to-end on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the generator from the loaded config and the F0 network, on `device`.
generator = Generator(h, F0_model).to(device)

# STFT helper used for the inverse transform at synthesis time; the window
# length is set equal to the filter (FFT) length from the config.
stft = TorchSTFT(
    filter_length=h.gen_istft_n_fft,
    hop_length=h.gen_istft_hop_size,
    win_length=h.gen_istft_n_fft,
).to(device)

In [None]:
# Pick the generator checkpoint in cp_path whose name sorts last among 'g_*'.
cp_g = scan_checkpoint(cp_path, 'g_')
if not cp_g:
    # scan_checkpoint returns '' when nothing matches; fail here with an
    # actionable message instead of a bare assertion inside load_checkpoint.
    raise FileNotFoundError(
        "No generator checkpoint matching 'g_*' found in '{}'".format(cp_path)
    )

# Restore the generator weights stored under the 'generator' key.
state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])

# Weight norm is removed only after the weights are loaded, then the model is
# switched to evaluation mode for inference.
generator.remove_weight_norm()
_ = generator.eval()  # assign to `_` so the cell does not echo the module repr

### Resynthesis

In [None]:
# Pick a file to resynthesize. The path is relative, so it is resolved against
# the notebook's current working directory.
path = os.path.join("LJSpeech-1.1/wavs", "LJ049-0163.wav")

In [None]:
# Load the audio samples; `sr` is presumably the file's sampling rate.
# NOTE(review): `sr` is not compared with h.sampling_rate here - confirm the
# input file matches the rate the config expects.
wav, sr = load_wav(path)
# Scale by MAX_WAV_VALUE - presumably maps integer PCM samples to roughly
# [-1, 1]; TODO confirm against meldataset.
wav = wav / MAX_WAV_VALUE
wav = torch.FloatTensor(wav).to(device)
# unsqueeze(0) adds a leading dimension of size 1 before computing the mel features.
x = get_mel(wav.unsqueeze(0))

In [None]:
import IPython.display as ipd

# Inference only: disable gradient tracking while running the generator and
# the inverse STFT.
with torch.no_grad():
    spec, phase = generator(x)
    y_g_hat = stft.inverse(spec, phase)
    audio = y_g_hat.squeeze() * MAX_WAV_VALUE
    # Clamp to the int16 range before casting; without this, samples that
    # fall outside the range overflow in the cast and produce loud clicks.
    int16_info = torch.iinfo(torch.int16)
    audio = audio.clamp(int16_info.min, int16_info.max)
    audio = audio.cpu().numpy().astype('int16')

# Play back at the rate the model was configured for rather than a
# hard-coded constant.
print('Synthesized:')
ipd.display(ipd.Audio(audio, rate=h.sampling_rate))

# The reference recording is played at the rate reported by load_wav.
print('Original:')
ipd.display(ipd.Audio(wav.cpu(), rate=sr))