In [17]:
# NOTE: the star import (provides WaveNetModel) is kept first, as in the
# original cell, so every explicit import below wins on a name clash.
from model.wavenet_model import *

# Standard library
import os

# Third-party
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pyworld as pw
import soundfile as sf
import torch

# Project-local
import hparams
from data.data_util import decode_harmonic
from data.dataset import NpssDataset
from data.preprocess import process_wav

# FFT length handed to decode_harmonic / pw.decode_aperiodicity when decoding
# coded features back to full-resolution spectra (2 ** 11 == 2048).
fft_size = 2 ** 11

def load_latest_model_from(mtype, location, checkpoint=None):
    """Build a WaveNetModel for one timbre sub-model and load a checkpoint into it.

    By default the most recently created regular file in ``location``
    (ordered by ``os.path.getctime``) is loaded.

    Args:
        mtype: 0 -> ``hparams.create_harmonic_hparams()``,
            1 -> ``hparams.create_aperiodic_hparams()``,
            anything else -> ``hparams.create_vuv_hparams()``.
        location: directory containing checkpoint files.
        checkpoint: optional explicit checkpoint path. When given, it is loaded
            instead of the newest file in ``location``.

    Returns:
        The ``WaveNetModel``, moved to CUDA if available (else CPU), with the
        checkpoint's ``'state_dict'`` loaded.

    Raises:
        FileNotFoundError: if no ``checkpoint`` is given and ``location``
            contains no regular files.
    """
    if checkpoint is None:
        candidates = [os.path.join(location, name) for name in os.listdir(location)]
        # Ignore sub-directories; only files can be checkpoints.
        candidates = [path for path in candidates if os.path.isfile(path)]
        if not candidates:
            raise FileNotFoundError("no checkpoint files found in " + str(location))
        checkpoint = max(candidates, key=os.path.getctime)

    print("load model " + str(checkpoint))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if mtype == 0:
        hparam = hparams.create_harmonic_hparams()
    elif mtype == 1:
        hparam = hparams.create_aperiodic_hparams()
    else:
        hparam = hparams.create_vuv_hparams()

    model = WaveNetModel(hparam, device).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    states = torch.load(checkpoint, map_location=device)
    model.load_state_dict(states['state_dict'])

    return model


def load_timbre(path, m_type, mx, mn):
    """Load a stored, min-max-normalized feature file, plot it and decode it.

    Args:
        path: ``.npy`` file holding the normalized coded features.
        m_type: 1 decodes as aperiodicity with ``pw.decode_aperiodicity``;
            any other value decodes with ``decode_harmonic``.
        mx, mn: max/min used to undo the normalization (``x * (mx - mn) + mn``).

    Returns:
        The decoded features (decoded with FFT length ``fft_size``).
    """
    coded = np.load(path).astype(np.double)

    # Undo the min-max normalization.
    coded = coded * (mx - mn) + mn

    # origin='lower' (not the invalid 'bottom', rejected by matplotlib >= 3.3).
    plt.imshow(np.transpose(coded), aspect='auto', origin='lower', interpolation='none')
    plt.show()

    # Decode only with the decoder that matches the feature type; previously
    # decode_harmonic also ran (and was discarded) for aperiodic input.
    if m_type == 1:
        return pw.decode_aperiodicity(coded, 32000, fft_size)
    return decode_harmonic(coded, fft_size)


#  type 0:harmonic, 1:aperiodic,
def generate_timbre(m_type, mx, mn, condition, cat_input=None, init_input=None):
    """Run a timbre WaveNet, plot its de-normalized output and decode it.

    Args:
        m_type: 0 = harmonic model (newest checkpoint in ``snapshots/harmonic0_0003``),
            1 = aperiodic model (newest checkpoint in ``snapshots/aperiodic``).
            Other values load from the harmonic directory and return no decoding.
        mx, mn: max/min used to undo the min-max normalization of the output.
        condition: conditioning tensor forwarded to ``model.generate``.
        cat_input: optional extra input forwarded to ``model.generate``.
        init_input: optional initial input forwarded to ``model.generate``.

    Returns:
        ``(decode_sp, raw_gen)``: the decoded features (``None`` unless
        ``m_type`` is 0 or 1) and the raw, still-normalized generator output.
    """
    model_path = 'snapshots/aperiodic' if m_type == 1 else 'snapshots/harmonic0_0003'
    model = load_latest_model_from(m_type, model_path)
    raw_gen = model.generate(condition, cat_input, init_input)

    # Swap axes back to the layout of the stored .npy features, then
    # de-normalize. detach() is a no-op unless the output tracks gradients.
    sample = raw_gen.detach().transpose(0, 1).cpu().numpy().astype(np.double) * (mx - mn) + mn

    # origin='lower' (not the invalid 'bottom', rejected by matplotlib >= 3.3).
    plt.imshow(np.transpose(sample), aspect='auto', origin='lower', interpolation='none')
    plt.show()

    decode_sp = None
    if m_type == 0:
        decode_sp = decode_harmonic(sample, fft_size)
    elif m_type == 1:
        decode_sp = pw.decode_aperiodicity(sample, 32000, fft_size)

    return decode_sp, raw_gen

def generate_vuv(condition, cat_input, init_input=None):
    """Generate a V/UV sequence with the newest checkpoint in ``snapshots/vuv``.

    ``condition``, ``cat_input`` and ``init_input`` are forwarded unchanged to
    the model's ``generate``; the squeezed result is returned as a NumPy
    ``uint8`` array.
    """
    vuv_model = load_latest_model_from(2, 'snapshots/vuv')
    generated = vuv_model.generate(condition, cat_input, init_input)
    return generated.squeeze().cpu().numpy().astype(np.uint8)


def get_ap_cat():
    """Load the stored test ``sp`` features of excerpt 015 as a tensor.

    Despite the name, this reads the ``sp`` file; presumably the result is meant
    as ``cat_input`` for the aperiodic model. The array is converted with
    ``torch.Tensor`` and its two axes are swapped.
    """
    sp_file = 'data/timbre_model/test/sp/nitech_jp_song070_f001_015_sp.npy'
    features = np.load(sp_file)
    features = features.astype(np.double)
    as_tensor = torch.Tensor(features)
    return as_tensor.transpose(0, 1)

def get_vuv_cat():
    """Stack the stored test ``ap`` and ``sp`` features of excerpt 015.

    Each ``.npy`` file is converted with ``torch.Tensor`` and its two axes are
    swapped; the results are concatenated along dim 0 with ``ap`` first and
    ``sp`` second.
    """
    sp_file = 'data/timbre_model/test/sp/nitech_jp_song070_f001_015_sp.npy'
    ap_file = 'data/timbre_model/test/ap/nitech_jp_song070_f001_015_ap.npy'

    def _load_swapped(npy_path):
        # .npy -> float64 array -> torch tensor with its two axes swapped.
        return torch.Tensor(np.load(npy_path).astype(np.double)).transpose(0, 1)

    sp_part = _load_swapped(sp_file)
    ap_part = _load_swapped(ap_file)
    return torch.cat([ap_part, sp_part], dim=0)



def get_first_input(song_name):
    """Load the stored test ``sp`` features of excerpt ``song_name`` as a tensor.

    Presumably intended as a generator's ``init_input``. The array is converted
    with ``torch.Tensor`` and its two axes are swapped.
    """
    sp_file = 'data/timbre_model/test/sp/nitech_jp_song070_f001_' + song_name + '_sp.npy'
    features = np.load(sp_file)
    features = features.astype(np.double)
    as_tensor = torch.Tensor(features)
    return as_tensor.transpose(0, 1)


def get_condition(song_name):
    """Load the conditioning features of test excerpt ``song_name``.

    Reads ``data/timbre_model/test/condition/nitech_jp_song070_f001_<song_name>_condi.npy``,
    converts it with ``torch.Tensor`` (default float dtype) and swaps its two axes.
    """
    c_path = 'data/timbre_model/test/condition/nitech_jp_song070_f001_'+song_name+'_condi.npy'
    # np.float (an alias of float64) was removed in NumPy 1.24 and raised
    # AttributeError there; use np.float64 explicitly.
    condition = np.load(c_path).astype(np.float64)
    return torch.Tensor(condition).transpose(0, 1)




In [18]:
def gen_song(song_name, t, epoch, show_reference=False):
    """Generate and plot timbre features for one test excerpt, then write a
    WORLD re-synthesis of that excerpt's source recording.

    Args:
        song_name: excerpt id such as ``'040_0'``; selects the condition file
            and ``data/cut_raw/nitech_jp_song070_f001_<song_name>.raw``.
        t: free-form tag inserted into the output wav file name.
        epoch: checkpoint tag inserted into the output wav file name.
        show_reference: if True, also plot the stored test sp/ap/vuv features
            of the excerpt next to each generated one, for visual comparison.

    Output:
        ``./data/gen_wav/<epoch>epoch_<t>_<song_name>.wav`` at 32 kHz. The
        directory is created if missing.

    NOTE(review): the generated sp/ap/vuv are only plotted. The wav written
    here is synthesized from the features ``process_wav`` extracts from the
    ORIGINAL recording, not from the generated features.
    """
    sp_min, sp_max, ap_min, ap_max = np.load('data/timbre_model/min_max_record.npy')
    condi = get_condition(song_name)

    # Harmonic model (m_type 0). origin='lower' replaces the invalid 'bottom',
    # which matplotlib >= 3.3 rejects.
    sp, raw_sp = generate_timbre(0, sp_max, sp_min, condi, None, None)
    plt.imshow(np.log(np.transpose(sp)), aspect='auto', origin='lower', interpolation='none')
    plt.show()
    if show_reference:
        ref_sp = load_timbre('data/timbre_model/test/sp/nitech_jp_song070_f001_' + song_name + '_sp.npy',
                             0, sp_max, sp_min)
        plt.imshow(np.log(np.transpose(ref_sp)), aspect='auto', origin='lower', interpolation='none')
        plt.show()

    # Aperiodic model (m_type 1), given the raw harmonic output as cat_input.
    ap, raw_ap = generate_timbre(1, ap_max, ap_min, condi, raw_sp, None)
    plt.imshow(np.log(np.transpose(ap)), aspect='auto', origin='lower', interpolation='none')
    plt.show()
    if show_reference:
        ref_ap = load_timbre('data/timbre_model/test/ap/nitech_jp_song070_f001_' + song_name + '_ap.npy',
                             1, ap_max, ap_min)
        plt.imshow(np.log(np.transpose(ref_ap)), aspect='auto', origin='lower', interpolation='none')
        plt.show()

    # V/UV model, given the raw aperiodic and harmonic outputs stacked on dim 0.
    gen_cat = torch.cat((raw_ap, raw_sp), 0)
    vuv = generate_vuv(condi, gen_cat)
    plt.plot(vuv)
    plt.show()
    if show_reference:
        plt.plot(np.load('data/timbre_model/test/vuv/nitech_jp_song070_f001_' + song_name + '_vuv.npy'))
        plt.show()

    path = 'data/cut_raw/nitech_jp_song070_f001_' + song_name + '.raw'
    _f0, _sp, _code_sp, _ap, _code_ap = process_wav(path)
    # Re-synthesize the original recording from its WORLD analysis.
    synthesized = pw.synthesize(_f0, _sp, _ap, 32000, pw.default_frame_period)
    # Write the re-synthesized original speech; make sure the folder exists.
    out_dir = './data/gen_wav/'
    os.makedirs(out_dir, exist_ok=True)
    sf.write(out_dir + epoch + 'epoch_' + t + '_' + song_name + '.wav', synthesized, 32000)

# Experiment tags embedded in every output wav file name.
t = 'origin_allow0.05'
epoch = '00_1650'

# Render the test excerpts 040_0 .. 040_11 one after another.
# To render a single excerpt instead, call gen_song(song_name, t, epoch) once.
song_name = '040_0'
for i in range(12):
    song_name = '040_{}'.format(i)
    gen_song(song_name, t, epoch)

load model snapshots/harmonic0_0003/harm0_0003_1649_2019-04-30_19-11-32
one generating step does take approximately 0.00586141586303711 seconds)


KeyboardInterrupt: 