In [None]:
cd ..

In [None]:
# load packages
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import matplotlib.pyplot as plt
from utils.ASR.models import ASRCNN
from utils.JDC.model import JDCNet
from models import Generator, MappingNetwork, StyleEncoder
import soundfile as sf
import IPython.display as ipd
import pyworld
from tqdm import tqdm
% matplotlib inline

In [None]:
# Speaker IDs. Order matters: a speaker's index in this list is used as its integer
# label (see the speakers.index(...) lookup further down).
speakers = (
    ['F101', 'F102', 'F103', 'F104', 'F105', 'F106', 'F107', 'F108', 'F109', 'F110']
    + ['M101', 'M102', 'M103', 'M104', 'M105', 'M106', 'M107', 'M108', 'M109', 'M110']
    + ['FAF', 'FFS', 'FKM', 'FKN', 'FKS', 'FMS', 'FSU', 'FTK', 'FYM', 'FYN']
    + ['MAU', 'MHT', 'MMS', 'MMY', 'MNM', 'MSH', 'MTK', 'MTM', 'MTT', 'MXM']
)
print(len(speakers))

# Mel-spectrogram front end shared by every preprocess() call.
# NOTE(review): sample_rate is not passed, so torchaudio's default applies even though
# audio is loaded at 24 kHz in this notebook — presumably this matches how the
# checkpoints were trained; confirm before changing it.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80,
    n_fft=2048,
    win_length=1200,
    hop_length=300,
)

# Offset and scale applied to the log-mel values in preprocess().
mean = -4
std = 4


def preprocess(wave):
    """Turn a NumPy waveform into a normalised log-mel tensor.

    Args:
        wave: NumPy array of audio samples (converted with ``torch.from_numpy``).

    Returns:
        ``(log(1e-5 + mel) - mean) / std`` with one extra leading dimension,
        where ``mel`` is the output of the module-level ``to_mel`` transform.
    """
    waveform = torch.from_numpy(wave).float()
    # 1e-5 keeps the logarithm finite where the mel energy is zero.
    log_mel = torch.log(1e-5 + to_mel(waveform).unsqueeze(0))
    return (log_mel - mean) / std


def build_model(model_params=None):
    """Instantiate the generator, mapping network and style encoder.

    Args:
        model_params: mapping providing ``dim_in``, ``style_dim``,
            ``max_conv_dim``, ``w_hpf``, ``F0_channel``, ``latent_dim`` and
            ``num_domains``. ``None`` (the default) is treated as an empty
            mapping, which is what the previous ``{}`` default meant.

    Returns:
        A ``Munch`` with the keys ``generator``, ``mapping_network`` and
        ``style_encoder``. No weights are loaded here.
    """
    # A None sentinel avoids sharing one dict object between calls.
    args = Munch(model_params if model_params is not None else {})

    generator = Generator(
        args.dim_in,
        args.style_dim,
        args.max_conv_dim,
        w_hpf=args.w_hpf,
        F0_channel=args.F0_channel,
    )
    mapping_network = MappingNetwork(
        args.latent_dim,
        args.style_dim,
        args.num_domains,
        hidden_dim=args.max_conv_dim,
    )
    style_encoder = StyleEncoder(
        args.dim_in,
        args.style_dim,
        args.num_domains,
        args.max_conv_dim,
    )

    return Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder)


def compute_style(model, speaker_dicts):
    """Compute one style embedding per entry of ``speaker_dicts``.

    Args:
        model: object exposing ``mapping_network`` and ``style_encoder``.
        speaker_dicts: dict mapping an arbitrary key to ``(path, speaker)``.
            With an empty ``path`` the embedding is sampled from the mapping
            network using a random latent vector; otherwise the audio file at
            ``path`` is loaded at 24 kHz and passed through the style encoder.
            ``speaker`` is the integer speaker label.

    Returns:
        dict mapping each key to ``(ref, label)``: the embedding returned by
        the network and the label tensor that was fed to it.
    """
    reference_embeddings = {}

    # Inference only. Previously just the style-encoder branch ran under no_grad, so
    # embeddings from the mapping network still required grad and could not be turned
    # into NumPy arrays by the caller.
    with torch.no_grad():
        for key, (path, speaker) in speaker_dicts.items():
            label = torch.LongTensor([speaker])
            if path == "":
                # Use the device the network lives on instead of a hard-coded 'cuda'.
                device = next(model.mapping_network.parameters()).device
                label = label.to(device)
                latent_dim = model.mapping_network.shared[0].in_features
                latent = torch.randn(1, latent_dim).to(device)
                ref = model.mapping_network(latent, label)
            else:
                print(path)
                device = next(model.style_encoder.parameters()).device
                # load() already returns audio at the requested 24 kHz, so the old
                # "sr != 24000" resample branch could never run; the result of the old
                # silence-trim call was never used. Both were dropped.
                wave, _ = librosa.load(path, sr=24000)
                mel_tensor = preprocess(wave).to(device)
                # The label stays on the CPU in this branch, exactly as before.
                ref = model.style_encoder(mel_tensor.unsqueeze(1), label)
            reference_embeddings[key] = (ref, label)

    return reference_embeddings


## Load models

In [None]:
# load F0 model
F0_model = JDCNet(num_class=1, seq_len=192)
# Deserialize on the CPU first (as done for the StarGANv2 checkpoint below) so loading
# does not depend on the CUDA device index the checkpoint was saved from; the model is
# moved to the GPU afterwards.
# NOTE(review): this path is spelled "Utils/" while the imports at the top of the
# notebook use "utils." — on a case-sensitive filesystem only one spelling can exist;
# confirm which one is right.
params = torch.load("Utils/JDC/bst.t7", map_location='cpu')['net']
F0_model.load_state_dict(params)
F0_model.eval()
F0_model = F0_model.to('cuda')

# load vocoder
from parallel_wavegan.utils import load_model

vocoder = load_model("Vocoder/checkpoint-400000steps.pkl").to('cuda')
vocoder.remove_weight_norm()
vocoder.eval()  # one eval() call is enough; it used to be issued twice

# load starganv2
model_path = 'Models/atr/epoch_00032.pth'

with open('Configs/config.yml') as f:
    starganv2_config = yaml.safe_load(f)
starganv2 = build_model(model_params=starganv2_config["model_params"])
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
params = torch.load(model_path, map_location='cpu')
print("Epochs:", params["epochs"])

# Restore the EMA weights into every sub-network, switch it to inference mode and
# move it to the GPU. A plain loop replaces the side-effect-only list comprehensions.
params = params['model_ema']
for net_name in list(starganv2):
    net = starganv2[net_name]
    net.load_state_dict(params[net_name])
    net.eval()
    starganv2[net_name] = net.to('cuda')

In [None]:
# Speaker subsets. The four gender-specific lists are written out once and the two
# combined lists are assembled from them (female first, then male).
speakers_female_normal = [
    'F101', 'F102', 'F103', 'F104', 'F105', 'F106', 'F107', 'F108', 'F109', 'F110',
]
speakers_male_normal = [
    'M101', 'M102', 'M103', 'M104', 'M105', 'M106', 'M107', 'M108', 'M109', 'M110',
]
speakers_female_prof = [
    'FAF', 'FFS', 'FKM', 'FKN', 'FKS', 'FMS', 'FSU', 'FTK', 'FYM', 'FYN',
]
speakers_male_prof = [
    'MAU', 'MHT', 'MMS', 'MMY', 'MNM', 'MSH', 'MTK', 'MTM', 'MTT', 'MXM',
]
speakers_normal = speakers_female_normal + speakers_male_normal
speakers_prof = speakers_female_prof + speakers_male_prof

## Generate speaker embedding

In [None]:
# Reference-based embeddings: every speaker gets a non-empty wav path, so
# compute_style() uses the style encoder for all of them (the mapping network is only
# used for entries whose path is empty).
speaker_dicts = {}
for speaker_idx, s in enumerate(speakers):
    # enumerate() yields the same index speakers.index(s) returned, without rescanning
    # the list for every speaker.
    speaker_dicts[s] = ("data/ATR_processed/wav24/%s/1.wav" % s, speaker_idx)
reference_embeddings = compute_style(starganv2, speaker_dicts)

# Stack one embedding per speaker into a 2-D array, in dict (= speakers) order.
embedding = np.array([reference_embeddings[k][0].squeeze().cpu().numpy() for k in reference_embeddings])
label = list(reference_embeddings.keys())
print(embedding.shape)

# Project the embeddings onto their first 6 principal components.
from sklearn.decomposition import PCA
pca = PCA(n_components=6, svd_solver='arpack')
emb_pca = pca.fit_transform(embedding)
print(emb_pca.shape)

## Interactive demo

In [None]:
from ipywidgets import interact, interact_manual

# Embedding of speaker 'M105' expressed in the PCA space fitted on all speakers.
ref, _ = reference_embeddings['M105']
print(ref.shape)
ref_pca = pca.transform(ref.cpu().numpy()).squeeze()

# Per-component extremes over every speaker embedding; the interactive demo uses them
# as slider bounds.
ref_pca_max = emb_pca.max(axis=0)
ref_pca_min = emb_pca.min(axis=0)
print(ref_pca_min.shape)

In [None]:
# Source utterance to convert, loaded at 24 kHz.
wav_path = "data/samples/M105/M105SF_A01.AD.wav"
audio, source_sr = librosa.load(wav_path, sr=24000)
# 'cuda' rather than 'cuda:0': the models in this notebook are all moved with
# .to('cuda'), so the input now follows the same current-device selection they do.
source = preprocess(audio).to('cuda')
# F0 features of the source; passed to the generator as its F0 input.
with torch.no_grad():
    f0_feat = F0_model.get_feature_GAN(source.unsqueeze(1))

In [None]:
@interact(emb_dim_0=(ref_pca_min[0], ref_pca_max[0], 0.1),
          # emb_dim_1 uses a fixed range; the data-driven alternative would be
          # (ref_pca_min[1], ref_pca_max[1], 0.1).
          emb_dim_1=(-10, 10, 0.1),
          emb_dim_2=(ref_pca_min[2], ref_pca_max[2], 0.1),
          emb_dim_3=(ref_pca_min[3], ref_pca_max[3], 0.1),
          emb_dim_4=(ref_pca_min[4], ref_pca_max[4], 0.1))
def voice_conversion(emb_dim_0=ref_pca[0],
                     emb_dim_1=ref_pca[1],
                     emb_dim_2=ref_pca[2],
                     emb_dim_3=ref_pca[3],
                     emb_dim_4=ref_pca[4]):
    """Convert ``source`` with a style embedding edited in PCA space and plot it.

    The five arguments replace the first five PCA coordinates of the reference
    embedding ``ref_pca``; any remaining coordinates keep the reference value.
    The edited point is mapped back with ``pca.inverse_transform`` and fed to
    the generator together with ``f0_feat``.

    Side effects:
        Stores the converted mel-spectrogram (NumPy array) in the global
        ``mel_hat`` and draws it with matplotlib.
    """
    global mel_hat

    # Work on a copy: assigning through an alias would overwrite the module-level
    # ref_pca every time a slider moves.
    tar_emb_pca = ref_pca.copy()
    tar_emb_pca[:5] = np.array([emb_dim_0, emb_dim_1, emb_dim_2, emb_dim_3, emb_dim_4])
    tar_emb = torch.from_numpy(pca.inverse_transform(tar_emb_pca)).float().cuda()[None,]

    with torch.no_grad():
        mel_hat = starganv2.generator(source.unsqueeze(1), tar_emb, F0=f0_feat).squeeze().cpu().numpy()

    plt.figure(dpi=100, figsize=(5, 3))
    plt.imshow(mel_hat, aspect='auto', origin='lower', cmap='inferno')
    plt.colorbar()
    plt.title("Converted")
    plt.show()

## Synthesize waveform

In [None]:
# Vocode the mel-spectrogram most recently stored in `mel_hat` by voice_conversion().
with torch.no_grad():
    out = torch.from_numpy(mel_hat).float()
    # Add a leading axis, swap the last two axes, then drop all size-1 axes.
    out = out.unsqueeze(0).transpose(-1, -2).squeeze()
    out = out.to('cuda')
    y_out = vocoder.inference(out).view(-1).cpu().numpy()

# Play the source utterance and the converted result, both at 24 kHz.
print("Original")
ipd.display(ipd.Audio(audio, rate=24000))
print("Converted")
ipd.display(ipd.Audio(y_out, rate=24000))
