In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import configparser
import matplotlib.pyplot as plt

import glob
import os
import numpy as np
import librosa
import argparse
import json
import torch
from scipy.io.wavfile import write
import sys
sys.path.append("../../hifi-gan/")
from env import AttrDict
from meldataset import MAX_WAV_VALUE, MelDataset, get_dataset_filelist, mel_spectrogram
from models import Generator
from train import load_checkpoint, scan_checkpoint

h = None
device = None

import sys
sys.path.append("../")
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
    discriminator_loss
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint

torch.backends.cudnn.benchmark = True

sys.path.append("../../cuhksz-phd/sho_util/pyfiles/")
from pytorch import cuda2numpy
from basic import plot_spectrogram
from sound import play_audio

In [None]:
# Inference setup: paths, optional feature scaler, hyper-parameters, validation
# dataset and the HiFi-GAN generator.
# `a` is only an attribute bag (read by get_dataset_filelist and below), so a
# plain argparse.Namespace is used rather than a ConfigParser.
a = argparse.Namespace()
base_dir = "../"
data_dir = f"{base_dir}LibriTTS/"

############ Checkpoint ###########
# Earlier checkpoints, kept for reference:
# a.checkpoint_file = '/mntcephfs/data/audiow/shoinoue/Model/models/hifigan/LibriTTS/g_00025000'
# a.checkpoint_file = '/mntcephfs/data/audiow/shoinoue/Model/models/hifigan/LibriTTS2/g_00031000'
# a.checkpoint_file = '/mntcephfs/data/audiow/shoinoue/Model/models/hifigan/LibriTTS3/g_00220000'
a.checkpoint_file = '/mntcephfs/data/audiow/shoinoue/Model/models/hifigan/LibriTTS4/g_00150000'
###################################

a.input_wavs_dir = "/mntcephfs/"
a.input_mels_dir = "/mntcephfs/"
a.input_training_file = os.path.join(data_dir, "training.txt")
a.input_validation_file = os.path.join(data_dir, "validation.txt")

a.fine_tuning = False

# Optional mel scaler. It used to be loaded only for LibriTTS4-6 checkpoints;
# it is now loaded unconditionally. Every consumer checks `scaler is not None`,
# so a failed load falls back to unnormalized mels, but we warn instead of
# silently continuing (a checkpoint trained on normalized mels will sound wrong).
# NOTE: joblib.load unpickles the file; only point this at trusted files.
scaler = None
scaler_filename = "../../seq2seq-vc/notebooks/ckpts/scalers/LibriTTS-R_hifiganmel.save"
try:
    import joblib
    scaler = joblib.load(scaler_filename)
except (ImportError, OSError, ValueError) as err:
    print(f"WARNING: scaler not loaded from {scaler_filename!r} ({err}); using unnormalized mels.")

# Hyper-parameters live next to the checkpoint as config.json.
config_file = os.path.join(os.path.dirname(a.checkpoint_file), 'config.json')
with open(config_file) as f:
    h = AttrDict(json.load(f))

torch.manual_seed(h.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

training_filelist, validation_filelist = get_dataset_filelist(a)
# Pass `training_filelist` instead to inspect training utterances.
# The two positional False flags are presumably split=False, shuffle=False
# (upstream hifi-gan MelDataset argument order) — confirm against meldataset.py.
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                      h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                      fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                      base_mels_path=a.input_mels_dir,
                      scaler=scaler,
                      )

generator = Generator(h).to(device)
state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

In [None]:
def _scaler_stats(x, fitted_scaler=None):
    """Return the scaler's (mean, scale) as tensors that broadcast against ``x``.

    The per-bin statistics are reshaped to ``(1, n_bins, 1)``, so ``x`` is
    presumably ``(batch, n_mels, frames)``. The tensors are placed on ``x``'s
    device. For floating-point ``x`` they also take ``x``'s dtype, so the result
    keeps the input precision instead of being promoted to float64.

    Falls back to the notebook-level ``scaler`` when ``fitted_scaler`` is None.

    Raises:
        ValueError: if no scaler is available.
    """
    if fitted_scaler is None:
        fitted_scaler = scaler
    if fitted_scaler is None:
        raise ValueError("No scaler is loaded; cannot (de)normalize mel features.")
    dtype = x.dtype if torch.is_floating_point(x) else None
    mean_ = torch.as_tensor(fitted_scaler.mean_.reshape(1, -1, 1), dtype=dtype, device=x.device)
    scale_ = torch.as_tensor(fitted_scaler.scale_.reshape(1, -1, 1), dtype=dtype, device=x.device)
    return mean_, scale_

def transform(x, fitted_scaler=None):
    """Standardize ``x`` with the scaler: ``(x - mean) / scale`` per mel bin."""
    mean_, scale_ = _scaler_stats(x, fitted_scaler)
    return (x - mean_) / scale_

def inverse_transform(x, fitted_scaler=None):
    """Undo :func:`transform`: ``x * scale + mean`` per mel bin."""
    mean_, scale_ = _scaler_stats(x, fitted_scaler)
    return x * scale_ + mean_

def generate_audio(mel):
    """Vocode a single mel spectrogram with the loaded HiFi-GAN generator.

    Args:
        mel: mel tensor without a batch axis (presumably ``(n_mels, frames)``);
            a batch axis of size 1 is added here. It is fed to the generator
            unchanged, so it must already be in the feature space the checkpoint
            expects (scaler-normalized when a scaler is in use).

    Returns:
        ``(audio, resynth_mel)``:
        - ``audio``: int16 numpy waveform, scaled by ``MAX_WAV_VALUE`` and clipped
          to the int16 range.
        - ``resynth_mel``: the mel recomputed from the generated waveform (no
          scaler applied), first batch item, as produced by ``cuda2numpy``.
    """
    int16_info = np.iinfo(np.int16)
    with torch.no_grad():
        batch = mel.unsqueeze(0).to(device)
        wav = generator(batch)
        # Recompute a mel from the synthesized waveform so it can be compared
        # against the input; no scaler is applied to it.
        resynth_mel = mel_spectrogram(wav.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                      h.hop_size, h.win_size,
                                      h.fmin, h.fmax_for_loss)
        # Clip before the int16 cast. MAX_WAV_VALUE is presumably 32768.0
        # (upstream hifi-gan), so a saturated +1.0 sample would overflow int16
        # and wrap around to -32768.
        pcm = (wav.squeeze() * MAX_WAV_VALUE).clamp(int16_info.min, int16_info.max)
        audio = pcm.cpu().numpy().astype('int16')
    return audio, cuda2numpy(resynth_mel)[0]

In [None]:
%matplotlib inline
num = 5
width = 2
length = 1
for i in range(num):
    print("------------------------------------------------------------------")
    x, y, _, _ = validset[int(len(validset)/500)+i]
    audio, mel = generate_audio(x)
    fig = plt.figure(figsize=(4*width, 3*length))
    if scaler!=None:
        x = inverse_transform(x.unsqueeze(0))[0]
    M = cuda2numpy(x)
    ax = plot_spectrogram(M, fig, (length, width, 1), title="source", title_font=12)
    M = mel
    ax = plot_spectrogram(M, fig, (length, width, 2), title="target", title_font=12)
    plt.show()
    
    play_audio(cuda2numpy(y), h.sampling_rate)
    play_audio(audio, h.sampling_rate)

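## Specific files
Vocode precomputed mel spectrograms for individual recordings and compare the results with the original audio.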

In [None]:
# Inputs for the specific-files check. Each wav in `wavs_dir` is paired with a
# precomputed mel in `mels_dir` that shares its file stem (`<stem>.npy`).
# LibriTTS-R alternatives (those mels were named `<stem>_hifiganmel.npy`):
# wavs_dir = "/mntcephfs/lee_dataset/tts/LibriTTS_R/"
# mels_dir = "/mntcephfs/data/audiow/shoinoue/Dataset/LibriTTS_R/features/"
wavs_dir = "/mntcephfs/lab_data/shoinoue/Dataset/PD-AST/SLT/English/wav/"
mels_dir = "/mntcephfs/lab_data/shoinoue/Dataset/PD-AST/SLT/English/hifiganmel/"

In [None]:
# Every .wav directly inside wavs_dir, in deterministic (sorted) order.
files = sorted(glob.glob(wavs_dir + "*.wav"))

In [None]:
%matplotlib inline
num = 3
width = 2
length = 1
for i in range(num):
    print("------------------------------------------------------------------")
    wavfile = files[i]
    wav = librosa.load(wavfile, sr=h.sampling_rate)[0]
    # melfile = mels_dir + wavfile[len(wavs_dir):-4] + "_hifiganmel.npy"
    melfile = mels_dir + wavfile[len(wavs_dir):-4] + ".npy"
    melarray = torch.from_numpy(np.load(melfile))
    if scaler!=None:
        melarray = transform(melarray.unsqueeze(0))[0].to(melarray.dtype)
    audio, mel = generate_audio(melarray)
    fig = plt.figure(figsize=(4*width, 3*length))
    M = cuda2numpy(inverse_transform(melarray.unsqueeze(0))[0])
    ax = plot_spectrogram(M, fig, (length, width, 1), title="source", title_font=12)
    M = mel
    ax = plot_spectrogram(M, fig, (length, width, 2), title="target", title_font=12)
    plt.show()
    
    play_audio(wav, h.sampling_rate)
    play_audio(audio, h.sampling_rate)