In [1]:
import os
import sys
import glob
#import subprocess

import numpy as np
import torch

repos_mlps_dir = os.path.dirname(os.getcwd())
sys.path.append(repos_mlps_dir)
import default_settings as default
from data.preprocess import Preprocessor
from utils import load_config, load_model, make_directory, set_seed

sys.path.append(default.repos_dir)
from vocoders import hifigan as hfg

from sak import display as dp
from vocoders import default_settings as default_vcd

In [2]:
class Inference():
    """Run a restored model on note / phoneme sequences and return mel frames.

    Holds the preprocessor built from the preprocessing config and the model
    returned by ``load_model(checkpoint_path, eval=True)``.
    """

    def __init__(
        self,
        preprocessor_config_path,
        checkpoint_path,
        device=None
        ):
        """Load the preprocessing config, the preprocessor and the model.

        Args:
            preprocessor_config_path: path handed to ``load_config``.
            checkpoint_path: path handed to ``load_model``.
            device: device the inputs are moved to before the forward pass.
                When None, the notebook-level ``device`` variable is looked up
                at inference time. The model is not moved here -- presumably
                ``load_model`` already places it on the right device; TODO
                confirm.
        """
        self.preprocessor_config = load_config(preprocessor_config_path)
        self.preprocessor = Preprocessor(self.preprocessor_config)
        self.model = load_model(checkpoint_path, eval=True)
        self.device = device


    def _input_device(self):
        """Return the device that input tensors are moved to."""
        chosen = getattr(self, "device", None)
        if chosen is not None:
            return chosen
        # No explicit device was given: use the notebook-level `device` setting.
        return device


    def load_notes(self, mid_path, txt_path):
        """Return ``(notes, phonemes)`` as produced by
        ``Preprocessor.prepare_inference(mid_path, txt_path)``."""
        notes, phonemes = self.preprocessor.prepare_inference(
            mid_path, txt_path
        )
        return notes, phonemes


    def inference(self, notes, phonemes, mel_path=None):
        """Predict mel frames for a whole sequence.

        The sequence is zero-padded up to a multiple of ``self.model.seq_len``,
        reshaped into ``(n_chunks, seq_len)`` batches, passed through the
        model in a single call, and the rows added by the padding are dropped
        again.

        Args:
            notes: 1-D tensor of note indices.
            phonemes: 1-D tensor of phoneme indices, same length as ``notes``.
            mel_path: if not None, the prediction is also saved there with
                ``np.save``.

        Returns:
            Tensor of shape ``(1, mel_dim, frames)``, where ``mel_dim`` is the
            size of the model output's last axis. It is produced under
            ``torch.no_grad()`` and therefore carries no autograd graph.

        Raises:
            ValueError: if ``notes`` is empty.
        """
        target = self._input_device()
        chunk_size = self.model.seq_len
        total_len = len(notes)
        if total_len == 0:
            # reshape(-1, chunk_size) on an empty tensor fails with an opaque error.
            raise ValueError("notes is empty: nothing to run inference on")

        notes = notes.to(target)
        phonemes = phonemes.to(target)

        # Pad with index 0 so that the length becomes a multiple of chunk_size.
        pad_size = (-total_len) % chunk_size
        if pad_size:
            padding = torch.zeros(pad_size, dtype=torch.long, device=target)
            phonemes = torch.cat((phonemes, padding))
            notes = torch.cat((notes, padding))

        batch_phonemes = phonemes.reshape(-1, chunk_size)
        batch_notes = notes.reshape(-1, chunk_size)

        ## inference.
        # No gradients are needed here; skipping them saves memory and time.
        with torch.no_grad():
            preds = self.model(batch_notes, batch_phonemes)

        # Take the feature size from the output itself so that both the padded
        # and the unpadded case are flattened the same way.
        mel_dim = preds.size(-1)
        preds = preds.reshape(-1, mel_dim)
        if pad_size:
            preds = preds[:-pad_size]
        preds = preds.transpose(0, 1).unsqueeze(0)

        if mel_path is not None:
            # .cpu() is required before .numpy() when running on 'cuda'.
            np.save(mel_path, preds.detach().cpu().numpy())

        return preds

## Initialize.

In [3]:
## experimental settings.
device = "cpu" # 'cpu' or 'cuda'

# MLP_singer.
# Project root. Set the SVS_DIR environment variable to run on another
# machine; the default is the original location.
svs_dir  = os.environ.get('SVS_DIR', r'/home/akikun/projects/svs')
work_dir = os.path.join(svs_dir, 'csd_hifigan_universal')

checkpoint_path = os.path.join(work_dir, 'checkpoints', 'trained', 'best.pt')
preprocessor_config_path = os.path.join(work_dir, 'configs', 'preprocess.json')

# hifi_gan.
# Pre-trained generator checkpoint. Set HIFIGAN_CHECKPOINT to override.
checkpoint_hifigan_path = os.environ.get(
    'HIFIGAN_CHECKPOINT',
    r'/home/akikun/common/pre_trained_vocoders/hifigan/g_02500000',
)

In [4]:
## svs
infer = Inference(
    preprocessor_config_path,
    checkpoint_path
)

## hifigan.
# `h`, `generator` and `device_vocode` are module-level names used by svs();
# no `global` statement is needed at the top level of a cell.
h = hfg.load_config(default_vcd.hifigan_config_path)
# Load the checkpoint chosen in the settings cell rather than the package
# default, so that the setting actually takes effect.
generator, device_vocode = hfg.load_model(checkpoint_hifigan_path, h)

Loading '/home/akikun/common/pre_trained_vocoders/hifigan/g_02500000'
Complete.
Removing weight norm...


## Inference.

In [5]:
def svs(mid_path, txt_path, mel_path, pred_path):
    """Run the full pipeline for one song.

    Loads notes and phonemes from ``mid_path`` / ``txt_path``, predicts the mel
    with the notebook-level ``infer`` object (which also saves it to
    ``mel_path``), and passes the result to ``hfg.vocode`` together with
    ``pred_path`` -- presumably the output wav location; see the vocoder module.

    Relies on the notebook-level ``infer``, ``generator``, ``h`` and
    ``device_vocode`` having been initialised.
    """
    ## load notes and phonemes.
    notes, phonemes = infer.load_notes(mid_path, txt_path)

    ## inference.
    preds = infer.inference(notes, phonemes, mel_path=mel_path)

    ## vocode.
    # Hand the prediction over as float32 on the vocoder's device directly,
    # without a round trip through numpy.
    mel = preds.detach().to(device=device_vocode, dtype=torch.float32)
    hfg.vocode(mel, generator, h, pred_path)

In [6]:
## directories.

# data directory that contains the 'wav', 'mid' and 'txt' sub-directories.
data_dir = os.path.join(svs_dir, 'data')

# the directory where the predicted wav will be saved.
# exist_ok=True creates it only when missing, without a separate existence check.
pred_dir = os.path.join(work_dir, 'preds')
os.makedirs(pred_dir, exist_ok=True)

# the directory where the predicted mel will be saved.
mel_dir = os.path.join(work_dir, 'mel')
os.makedirs(mel_dir, exist_ok=True)

In [None]:
## wav id to infer.
#song = 'kr001a'
# glob returns files in arbitrary order; sort so that runs are reproducible.
wav_paths = sorted(glob.glob(os.path.join(data_dir, 'wav', '*.wav')))
# splitext drops only the trailing extension, whereas str.replace would also
# remove a '.wav' occurring in the middle of a file name.
song_ids  = [os.path.splitext(os.path.basename(wav_path))[0] for wav_path in wav_paths]

for song_id in song_ids:
    print(song_id)

    # input
    mid_path = os.path.join(data_dir, "mid", f"{song_id}.mid")
    txt_path = os.path.join(data_dir, "txt", f"{song_id}.txt")

    # output
    mel_path  = os.path.join(mel_dir, f"{song_id}.npy")
    pred_path = os.path.join(pred_dir, f"{song_id}.wav")

    # inference.
    svs(mid_path, txt_path, mel_path, pred_path)

kr001a
kr001b
kr002a
kr002b
kr003a
kr003b
kr004a
kr004b
kr005a
kr005b
kr006a
kr006b
kr007a
kr007b
kr008a
kr008b
kr009a
kr009b
kr010a
kr010b
kr011a
kr011b
kr012a
kr012b
kr013a
kr013b
kr014a
kr014b
kr015a
kr015b
kr016a
kr016b
kr017a
kr017b
kr018a
kr018b
kr019a
kr019b
kr020a
kr020b
kr021a
kr021b
kr022a
kr022b
kr023a
kr023b
kr024a
kr024b
kr025a
kr025b
kr026a
kr026b
kr027a
kr027b
kr028a


## Listen to the result.

In [8]:
# Song to listen to. `song` was never defined in this notebook; default to the
# last song processed by the inference loop. Replace with any id, e.g. 'kr001a'.
song = song_id

# outputs written by this notebook for that song.
pred_path = os.path.join(pred_dir, f"{song}.wav")
mel_path  = os.path.join(mel_dir, f"{song}.npy")

# reference output to compare against.
pred0_dir = os.path.join(work_dir, 'box', 'wav')
pred0_path = os.path.join(pred0_dir, f"{song}_generated_e2e.wav")
dp.disp_wav(pred_path)
dp.disp_wav(pred0_path)

# the saved mel has a leading axis of size 1; drop it for display.
mel = np.load(mel_path)
mel = np.squeeze(mel, axis=0)
dp.disp_spectrogram(mel)

In [9]:
# from sak import signal_processing as sp
# p = sp.load_wav(pred_path)
# r = sp.load_wav(raw_path)
# print(np.shape(p))
# print(np.shape(r))