<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/06-speech/NANSY.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NANSY

## 0. Info

### paper
* title: Neural Analysis and Synthesis: Reconstructing Speech from Self-Supervised Representations
* authors: Hyeong-Seok Choi et al.
* url: https://arxiv.org/abs/2110.14513

### feats
* dataset: AI Hub

### refs
* https://github.com/dhchoi99/NANSY
* https://github.com/jik876/hifi-gan

## 1. Setup

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import os
import math
import wandb
import random
import IPython
import easydict
import numpy as np
from glob import glob
from tqdm.auto import tqdm
from omegaconf import OmegaConf, DictConfig

import librosa
import soundfile as sf
from librosa.util import normalize
from librosa.filters import mel as librosa_mel_fn
import parselmouth
import scipy.signal
import torchaudio.functional as AF

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import Wav2Vec2ForPreTraining

In [None]:
# Notebook-wide runtime configuration.
# `device` is a torch device string (here the GPU with index 2); adjust it to
# match the local machine before running the cells below.
cfg = easydict.EasyDict(
    device = 'cuda:2'
)

## 2. Data

### 2.1. functional

In [None]:
# Default arguments for Praat's "Change gender" command, used as fallbacks by
# the augmentation helpers below.
# NOTE(review): a pitch median of 0.0 presumably means "leave the median
# unchanged" in Praat, and a ratio/factor of 1.0 is the identity — confirm
# against the Praat manual.
PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT = 0.0
PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT = 1.0
PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT = 1.0
PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT = 1.0
PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT = 1.0


def wav_to_Sound(wav, sampling_frequency: int = 22050) -> parselmouth.Sound:
    r"""Convert waveform data into a ``parselmouth.Sound``.

    Relevant ``parselmouth.Sound`` constructors:
    # __init__(self: parselmouth.Sound, other: parselmouth.Sound) -> None \
    # __init__(self: parselmouth.Sound, values: numpy.ndarray[numpy.float64], sampling_frequency: Positive[float] = 44100.0, start_time: float = 0.0) -> None \
    # __init__(self: parselmouth.Sound, file_path: str) -> None

    Args:
        wav: a ``parselmouth.Sound`` (returned unchanged), a ``numpy.ndarray``,
            a ``list`` of samples, or a ``torch.Tensor``.
        sampling_frequency: sample rate attached to array-like input; ignored
            when ``wav`` already is a ``Sound``.
    Returns:
        sound: parselmouth.Sound
    Raises:
        NotImplementedError: if ``wav`` has an unsupported type.
    """
    if isinstance(wav, parselmouth.Sound):
        return wav
    if isinstance(wav, torch.Tensor):
        # Accept tensors as well, mirroring what wav_to_Tensor accepts.
        wav = wav.detach().cpu().numpy()
    if isinstance(wav, (np.ndarray, list)):
        # np.asarray is a no-op for arrays and converts lists exactly once.
        return parselmouth.Sound(np.asarray(wav), sampling_frequency=sampling_frequency)
    raise NotImplementedError(f'unsupported wav type: {type(wav).__name__}')


def wav_to_Tensor(wav) -> torch.Tensor:
    """Return ``wav`` as a ``torch.Tensor``.

    Tensors are passed through untouched; numpy arrays and the ``values`` of a
    ``parselmouth.Sound`` are wrapped with ``torch.from_numpy`` (no copy).

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(wav, torch.Tensor):
        return wav
    if isinstance(wav, np.ndarray):
        return torch.from_numpy(wav)
    if isinstance(wav, parselmouth.Sound):
        return torch.from_numpy(wav.values)
    raise NotImplementedError


def get_pitch_median(wav, sr: int = None):
    """Extract the Praat pitch object of ``wav`` and its median pitch.

    Args:
        wav: anything accepted by ``wav_to_Sound``.
        sr: sample rate forwarded to ``wav_to_Sound``. It is not used when
            ``wav`` already is a ``parselmouth.Sound``; for raw sample data a
            real sample rate should be supplied.
    Returns:
        (pitch, pitch_median): the Praat Pitch object and the 0.5 quantile of
        the pitch in Hertz as returned by Praat (callers in this file guard
        against a NaN median).
    Raises:
        Any exception raised by parselmouth/Praat propagates unchanged.
    """
    sound = wav_to_Sound(wav, sr)
    # "To Pitch" is called with 0.8 / 75, 75 and 600
    # (presumably time step, pitch floor and pitch ceiling — see Praat manual).
    pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600)
    pitch_median = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz")
    return pitch, pitch_median


def change_gender(
        sound, pitch=None,
        formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
        new_pitch_median: float = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT,
        pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
        duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT, ) -> parselmouth.Sound:
    """Run Praat's "Change gender" command on ``sound``.

    Args:
        sound: the ``parselmouth.Sound`` to transform.
        pitch: optional pre-computed Praat Pitch object. When given, the
            command is invoked on the ``(sound, pitch)`` pair; otherwise Praat
            is called on ``sound`` alone with the extra arguments 75 and 600
            (presumably the pitch floor/ceiling used for its own analysis).
        formant_shift_ratio: forwarded to Praat.
        new_pitch_median: forwarded to Praat.
        pitch_range_ratio: forwarded to Praat.
        duration_factor: forwarded to Praat.
    Returns:
        The transformed ``parselmouth.Sound``.
    Raises:
        Any exception raised by parselmouth/Praat propagates unchanged.
    """
    manipulation = (formant_shift_ratio, new_pitch_median, pitch_range_ratio, duration_factor)
    if pitch is None:
        return parselmouth.praat.call(sound, "Change gender", 75, 600, *manipulation)
    return parselmouth.praat.call((sound, pitch), "Change gender", *manipulation)


def apply_formant_and_pitch_shift(
        sound: parselmouth.Sound,
        formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
        pitch_shift_ratio: float = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT,
        pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
        duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT) -> parselmouth.Sound:
    r"""uses praat 'Change Gender' backend to manipulate pitch and formant
        'Change Gender' function: praat -> Sound Object -> Convert -> Change Gender
        see Help of Praat for more details
        # https://github.com/YannickJadoul/Parselmouth/issues/25#issuecomment-608632887 might help

    Args:
        sound: input ``parselmouth.Sound``.
        formant_shift_ratio: formant scaling passed to ``change_gender``.
        pitch_shift_ratio: factor applied to the measured pitch median; when it
            is exactly 1 no pitch analysis is performed.
        pitch_range_ratio: pitch range scaling passed to ``change_gender``.
        duration_factor: duration scaling passed to ``change_gender``.
    Returns:
        The manipulated ``parselmouth.Sound``.
    Raises:
        Any exception raised by parselmouth/Praat propagates unchanged.
    """
    pitch = None
    new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
    if pitch_shift_ratio != 1.:
        pitch, pitch_median = get_pitch_median(sound, None)
        new_pitch_median = pitch_median * pitch_shift_ratio

        # https://github.com/praat/praat/issues/1926#issuecomment-974909408
        # Predict the minimum pitch after shifting and range scaling; a
        # negative value would make Praat fail, so fall back to the defaults.
        pitch_minimum = parselmouth.praat.call(pitch, "Get minimum", 0.0, 0.0, "Hertz", "Parabolic")
        scaled_minimum = pitch_minimum * pitch_shift_ratio
        resulting_minimum = new_pitch_median + (scaled_minimum - new_pitch_median) * pitch_range_ratio

        # A NaN median (no usable pitch estimate) is handled the same way.
        if resulting_minimum < 0 or math.isnan(new_pitch_median):
            new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
            pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT

    return change_gender(
        sound, pitch,
        formant_shift_ratio, new_pitch_median,
        pitch_range_ratio, duration_factor)


# fs & pr
def formant_and_pitch_shift(sound: parselmouth.Sound) -> parselmouth.Sound:
    r"""Apply randomly drawn formant shift, pitch shift and pitch range scaling.

    Implements formant shifting (fs) plus pitch randomization (pr) from the
    paper. Each ratio is drawn uniformly from ``[1, upper]`` and replaced by
    its reciprocal with probability one half.
    """

    def _draw_ratio(upper):
        # First the magnitude, then the coin flip deciding on the reciprocal.
        ratio = random.uniform(1, upper)
        invert = random.uniform(-1, 1) > 0
        return 1 / ratio if invert else ratio

    formant_ratio = _draw_ratio(1.4)
    shift_ratio = _draw_ratio(2)
    range_ratio = _draw_ratio(1.5)

    return apply_formant_and_pitch_shift(
        sound,
        formant_shift_ratio=formant_ratio,
        pitch_shift_ratio=shift_ratio,
        pitch_range_ratio=range_ratio,
        duration_factor=1.
    )


# fs
def formant_shift(sound: parselmouth.Sound) -> parselmouth.Sound:
    """Apply a random formant shift (fs in the paper); pitch is left alone.

    Args:
        sound: parselmouth Sound object
    Returns:
        The formant-shifted ``parselmouth.Sound``.
    """
    ratio = random.uniform(1, 1.4)
    # Half of the time shift in the opposite direction.
    if random.uniform(-1, 1) > 0:
        ratio = 1 / ratio

    return apply_formant_and_pitch_shift(sound, formant_shift_ratio=ratio)


def power_ratio(r: float, a: float, b: float):
    """Interpolate geometrically between ``a`` (r = 0) and ``b`` (r = 1).

    Returns ``a * (b / a) ** r``.
    """
    span = b / a
    return a * math.pow(span, r)


# peq
def parametric_equalizer(wav: torch.Tensor, sr: int) -> torch.Tensor:
    """Random parametric equalizer (peq).

    Chains ten biquad filters with random gain (uniform in [-12, 12] dB) and
    random Q (geometric between 2 and 5): eight peaking filters, then one
    high-shelf and one low-shelf filter. Centre frequencies are spaced
    geometrically starting at 60 Hz.

    Args:
        wav: waveform tensor to filter.
        sr: sample rate of ``wav``.
    Returns:
        The filtered waveform.
    """
    low_cut, high_cut = 60., 10000.
    q_lo, q_hi = 2, 5
    n_bands = 8 + 2  # 8 for peak, 2 for high/low

    center_freqs = [
        power_ratio(float(band) / n_bands, low_cut, high_cut)
        for band in range(n_bands)
    ]
    q_factors = [
        power_ratio(random.uniform(0, 1), q_lo, q_hi)
        for _ in range(n_bands)
    ]
    gains_db = [random.uniform(-12, 12) for _ in range(n_bands)]

    # Application order: peaking bands 1..8, high shelf on the last band,
    # low shelf on the first band.
    schedule = [('peak', band) for band in range(1, 9)]
    schedule += [('high', -1), ('low', 0)]

    for kind, band in schedule:
        wav = apply_iir_filter(
            wav,
            ftype=kind,
            dBgain=gains_db[band],
            cutoff_freq=center_freqs[band],
            sample_rate=sr,
            Q=q_factors[band]
        )

    return wav


# implemented using the cookbook https://webaudio.github.io/Audio-EQ-Cookbook/audio-eq-cookbook.html
def lowShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q):
    """Biquad coefficients of a low-shelf filter (Audio EQ Cookbook).

    Returns:
        (b0, b1, b2, a0, a1, a2), not normalised by ``a0``.
    """
    amp = math.pow(10, dBgain / 40.)
    omega = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(omega) / 2 / Q

    cos_w = math.cos(omega)
    amp_plus = amp + 1
    amp_minus = amp - 1
    beta = 2 * math.sqrt(amp) * alpha

    numerator = (
        amp * (amp_plus - amp_minus * cos_w + beta),
        2 * amp * (amp_minus - amp_plus * cos_w),
        amp * (amp_plus - amp_minus * cos_w - beta),
    )
    denominator = (
        amp_plus + amp_minus * cos_w + beta,
        -2 * (amp_minus + amp_plus * cos_w),
        amp_plus + amp_minus * cos_w - beta,
    )
    return (*numerator, *denominator)


def highShelf_coeffs(dBgain, cutoff_freq, sample_rate, Q):
    """Biquad coefficients of a high-shelf filter (Audio EQ Cookbook).

    Returns:
        (b0, b1, b2, a0, a1, a2), not normalised by ``a0``.
    """
    amp = math.pow(10, dBgain / 40.)
    omega = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(omega) / 2 / Q

    cos_w = math.cos(omega)
    amp_plus = amp + 1
    amp_minus = amp - 1
    beta = 2 * math.sqrt(amp) * alpha

    numerator = (
        amp * (amp_plus + amp_minus * cos_w + beta),
        -2 * amp * (amp_minus + amp_plus * cos_w),
        amp * (amp_plus + amp_minus * cos_w - beta),
    )
    denominator = (
        amp_plus - amp_minus * cos_w + beta,
        2 * (amp_minus - amp_plus * cos_w),
        amp_plus - amp_minus * cos_w - beta,
    )
    return (*numerator, *denominator)


def peaking_coeffs(dBgain, cutoff_freq, sample_rate, Q):
    """Biquad coefficients of a peaking EQ filter (Audio EQ Cookbook).

    Returns:
        (b0, b1, b2, a0, a1, a2), not normalised by ``a0``.
    """
    amp = math.pow(10, dBgain / 40.)
    omega = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(omega) / 2 / Q

    # The middle coefficient is shared by numerator and denominator.
    middle = -2 * math.cos(omega)
    boosted = alpha * amp
    damped = alpha / amp

    return 1 + boosted, middle, 1 - boosted, 1 + damped, middle, 1 - damped


def apply_iir_filter(wav: torch.Tensor, ftype, dBgain, cutoff_freq, sample_rate, Q, torch_backend=True):
    """Filter ``wav`` with one second-order (biquad) IIR section.

    Args:
        wav: waveform tensor.
        ftype: 'low' (low shelf), 'high' (high shelf) or 'peak' (peaking EQ).
        dBgain, cutoff_freq, sample_rate, Q: filter design parameters handed
            to the matching ``*_coeffs`` function.
        torch_backend: use ``torchaudio.functional.biquad`` when true,
            otherwise ``scipy.signal.lfilter``.
    Returns:
        The filtered waveform as a tensor.
    Raises:
        NotImplementedError: for an unknown ``ftype``.
    """
    if ftype == 'low':
        design = lowShelf_coeffs
    elif ftype == 'high':
        design = highShelf_coeffs
    elif ftype == 'peak':
        design = peaking_coeffs
    else:
        raise NotImplementedError
    b0, b1, b2, a0, a1, a2 = design(dBgain, cutoff_freq, sample_rate, Q)

    if not torch_backend:
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter_zi.html
        samples = wav.numpy()
        numerator = np.asarray([b0, b1, b2])
        denominator = np.asarray([a0, a1, a2])
        # Initial filter state scaled by the first sample.
        state = scipy.signal.lfilter_zi(numerator, denominator) * samples[0]
        filtered, _ = scipy.signal.lfilter(numerator, denominator, samples, zi=state)
        return torch.from_numpy(filtered)

    return AF.biquad(wav, b0, b1, b2, a0, a1, a2)


# Short aliases for two of the augmentations defined above:
# peq = parametric equalizer, fs = formant shifting.
peq = parametric_equalizer
fs = formant_shift


def g(wav: torch.Tensor, sr: int) -> torch.Tensor:
    r"""Perturbation ``g``: parametric equalizer followed by formant shifting.

    Args:
        wav: waveform tensor (converted with ``.numpy()``).
        sr: sample rate of ``wav``.
    Returns:
        The perturbed waveform as a float tensor with dim 0 squeezed.
    """
    equalized = peq(wav, sr)
    shifted = formant_shift(wav_to_Sound(equalized.numpy(), sampling_frequency=sr))
    return torch.from_numpy(shifted.values).float().squeeze(0)


def f(wav: torch.Tensor, sr: int) -> torch.Tensor:
    r"""Perturbation ``f``: parametric equalizer, then pitch randomization and
    formant shifting.

    Args:
        wav: waveform tensor (converted with ``.numpy()``).
        sr: sample rate of ``wav``.
    Returns:
        The perturbed waveform as a float tensor with dim 0 squeezed.
    """
    equalized = peq(wav, sr)
    shifted = formant_and_pitch_shift(wav_to_Sound(equalized.numpy(), sampling_frequency=sr))
    return torch.from_numpy(shifted.values).float().squeeze(0)

### 2.2. mel

In [None]:
# Full-scale magnitude of 16-bit PCM audio (2 ** 15); not referenced by the
# functions in this cell.
MAX_WAV_VALUE = 32768.0
# Module-level caches filled lazily by mel_spectrogram(): mel filterbank
# matrices and Hann windows, so they are not rebuilt for every call.
mel_basis = {}
hann_window = {}

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (numpy): ``log(max(x, clip_val) * C)``."""
    floored = np.clip(x, clip_val, None)
    return np.log(floored * C)


def dynamic_range_decompression(x, C=1):
    """Invert ``dynamic_range_compression`` (numpy): ``exp(x) / C``."""
    expanded = np.exp(x)
    return expanded / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (torch): ``log(clamp(x, min=clip_val) * C)``."""
    floored = x.clamp(min=clip_val)
    return (floored * C).log()


def dynamic_range_decompression_torch(x, C=1):
    """Invert ``dynamic_range_compression_torch``: ``exp(x) / C``."""
    expanded = x.exp()
    return expanded / C


def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes with the default C and clip value."""
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    """Undo ``spectral_normalize_torch`` (exponentiate with the default C)."""
    return dynamic_range_decompression_torch(magnitudes)


def mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050, hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False):
    """Compute a log-mel spectrogram (HiFi-GAN style) of a waveform batch.

    Args:
        y: waveform tensor of shape (B, T).
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: sample rate used to build the mel filterbank.
        hop_size: STFT hop length in samples.
        win_size: Hann window length in samples.
        fmin, fmax: frequency range of the mel filterbank.
        center: forwarded to ``torch.stft``; the signal is already
            reflect-padded by ``(n_fft - hop_size) / 2`` on both sides here.
    Returns:
        Log-compressed mel spectrogram of shape (B, num_mels, frames).
    """
    global mel_basis, hann_window
    device_key = str(y.device)
    # Key the caches on every parameter that changes their contents, so a hit
    # is actually possible and a different configuration never reuses a stale
    # filterbank or window.
    mel_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_key)
    window_key = (win_size, device_key)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # return_complex must be given explicitly for real input on current
    # PyTorch; view_as_real restores the trailing (real, imag) axis that the
    # magnitude computation below expects.
    spec = torch.stft(y, n_fft=n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    # Magnitude with a small epsilon so the square root stays differentiable at 0.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec

### 2.3. dataset

In [None]:
def pad_audio(x: torch.Tensor, length: int, value: float = 0., pad_at: str = 'end') -> torch.Tensor:
    r"""pads value to audio data, at last dimension
    params:
        x: torch.Tensor of shape (..., T)
        length: int, length to pad
        value: float, value to pad
        pad_at: str, 'start' or 'end' (case-insensitive, surrounding
            whitespace ignored)
    returns:
        y: padded torch.Tensor of shape (..., T+length)
    raises:
        NotImplementedError: if pad_at is neither 'start' nor 'end'
    """
    side = pad_at.strip().lower()
    if side not in ('start', 'end'):
        raise NotImplementedError(f"pad_at must be 'start' or 'end', got {pad_at!r}")

    # Build the padding on x's device so torch.cat also works for GPU inputs.
    # The dtype stays the default float type, so type promotion in torch.cat
    # behaves exactly as before.
    filler = torch.full(
        (*x.shape[:-1], length), value,
        dtype=torch.get_default_dtype(), device=x.device)

    pieces = [x, filler] if side == 'end' else [filler, x]
    return torch.cat(pieces, dim=-1)


def crop_audio(x: torch.Tensor, start: int, end: int, padding_value: float = 0.) -> torch.Tensor:
    r"""Slice ``x[..., start:end]``, padding wherever the window leaves the data.

    params:
        x: torch.Tensor of shape (..., T)
        start: int, first index of the window (may be negative = before the data)
        end: int, one past the last index of the window (may exceed T)
        padding_value: float, fill value for the parts outside ``[0, T]``
    returns:
        y: torch.Tensor of shape (..., end-start)
    """
    total = x.shape[-1]
    span = end - start

    def _all_padding():
        # Window lies completely outside the data.
        return torch.ones(size=(*x.shape[:-1], span), dtype=torch.float, device=x.device) * padding_value

    if start < 0:
        if end < 0:
            y = _all_padding()
        else:
            overshoots = end > total
            y = x if overshoots else x[..., :end]
            y = pad_audio(y, -start, padding_value, pad_at='start')
            if overshoots:
                y = pad_audio(y, end - total, padding_value, pad_at='end')
    elif end <= total:
        # Window entirely inside the data.
        y = x[..., start:end]
    elif start > total:
        y = _all_padding()
    else:
        y = pad_audio(x[..., start:], end - total, padding_value, pad_at='end')

    assert y.shape[-1] == span, f'{x.shape}, {start}, {end}, {y.shape}'
    return y

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Random positive/negative utterance pairs for NANSY training.

    Every item draws one utterance of a random "positive" speaker (with mel
    target and perturbed copies) and one utterance of a different "negative"
    speaker. Files are looked up as ``<datadir>/<speaker>/*/*.wav``.
    """
    # Minimum waveform length at 22.05 kHz (= minimum_mel_length * hop_size);
    # shorter recordings are zero-padded up to this length.
    yin_window_22k = 34816
    minimum_audio_length = yin_window_22k
    # Number of mel frames reserved when drawing a random crop start.
    minimum_mel_length = 136
    # Length of the mel crop in frames.
    mel_window = 128
    # STFT hop in samples at 22.05 kHz.
    hop_size = 256
    audio_window_22k = 32768 # mel_window * hop_size
    # Crop length at 16 kHz (hard-coded; roughly the same duration as
    # audio_window_22k).
    audio_window_16k = 23778


    def __init__(self, datadir, speaker_to_files=None):
        """
        Args:
            datadir: root directory; each entry is treated as a speaker.
            speaker_to_files: optional pre-filled ``{speaker: [wav paths]}``
                cache. A fresh dict is created per instance when omitted.
        """
        self.datadir = datadir
        self.speakers = os.listdir(datadir)
        # Never share one default dict between instances.
        self.speaker_to_files = {} if speaker_to_files is None else speaker_to_files

    def __len__(self):
        # Nominal epoch size: items are sampled randomly, independent of idx.
        return 1000

    def get_files(self, speaker):
        """Return (and cache) the wav files of ``speaker``."""
        files = self.speaker_to_files.get(speaker)
        if files is None:
            files = glob(f'{self.datadir}/{speaker}/*/*.wav')
            self.speaker_to_files[speaker] = files
        return files


    def get_time_idxs(self, mel_start: int):
        """Translate a mel start frame into crop boundaries.

        Returns:
            (mel_start, mel_end, t_start, w_start_16k, w_start_22k,
             w_end_16k, w_end_22k, w_end_22k_yin), where ``t_start`` is in
            seconds and the ``w_*`` values are sample indices.
        """
        mel_end = mel_start + self.mel_window
        # Sample offsets are computed with integer arithmetic; going through
        # float seconds and back can truncate to one sample too few.
        w_start_22k = mel_start * self.hop_size
        t_start = w_start_22k / 22050.
        w_start_16k = w_start_22k * 16000 // 22050
        w_end_22k = w_start_22k + self.audio_window_22k
        w_end_22k_yin = w_start_22k + self.yin_window_22k
        w_end_16k = w_start_16k + self.audio_window_16k
        return mel_start, mel_end, t_start, w_start_16k, w_start_22k, w_end_16k, w_end_22k, w_end_22k_yin


    def _load_audio(self, path):
        """Load ``path`` at 22.05 kHz plus a 16 kHz resample, as float tensors.

        Only the 22.05 kHz waveform is zero-padded to ``minimum_audio_length``;
        the 16 kHz waveform is resampled from the unpadded signal.
        """
        wav_22k, _ = librosa.load(path, sr=22050)
        wav_16k = librosa.resample(wav_22k, orig_sr=22050, target_sr=16000)

        wav_22k = torch.from_numpy(wav_22k).float()
        wav_16k = torch.from_numpy(wav_16k).float()

        if wav_22k.shape[-1] < self.minimum_audio_length:
            wav_22k = F.pad(wav_22k, (0, self.minimum_audio_length - wav_22k.size(-1)), mode='constant', value=0.0)
        return wav_22k, wav_16k


    def get_pos(self, speaker):
        """Sample the positive example: mel target, raw and perturbed crops."""
        data = {}
        files = self.get_files(speaker)
        file = np.random.choice(files)
        wav_22k, wav_16k = self._load_audio(file)

        _, pitch_median = get_pitch_median(wav_22k.numpy(), sr=22050)
        # NOTE(review): the key is spelled 'ptich'; kept as-is because code
        # consuming the batch may index by this exact key.
        data['ptich_median_pos'] = pitch_median

        mel_22k = mel_spectrogram(wav_22k.unsqueeze(0))[0] # (80, T)
        mel_start = random.randint(0, mel_22k.size(-1) - self.minimum_mel_length)
        mel_s, mel_e, _, s_16k, s_22k, e_16k, e_22k, _ = self.get_time_idxs(mel_start)
        data['gt_mel_22k'] = crop_audio(mel_22k, mel_s, mel_e)

        data['gt_audio_16k'] = crop_audio(wav_16k, s_16k, e_16k)
        # Perturb the full utterance first, then crop the same window.
        wav_16k_f = f(wav_16k, sr=16000)
        data['gt_audio_16k_f'] = crop_audio(wav_16k_f, s_16k, e_16k)

        data['gt_audio_22k'] = crop_audio(wav_22k, s_22k, e_22k)
        wav_22k_g = g(wav_22k, sr=22050)
        data['gt_audio_22k_g'] = crop_audio(wav_22k_g, s_22k, e_22k)
        return data


    def get_neg(self, speaker):
        """Sample the negative example: raw 16 kHz and 22.05 kHz crops."""
        data = {}
        files = self.get_files(speaker)
        file = np.random.choice(files)
        wav_22k, wav_16k = self._load_audio(file)

        _, pitch_median = get_pitch_median(wav_22k.numpy(), sr=22050)
        # NOTE(review): same 'ptich' spelling as in get_pos, kept on purpose.
        data['ptich_median_neg'] = pitch_median

        # The mel spectrogram is only needed for its frame count here.
        mel_22k = mel_spectrogram(wav_22k.unsqueeze(0))[0] # (80, T)
        mel_start = random.randint(0, mel_22k.size(-1) - self.minimum_mel_length)
        _, _, _, s_16k, s_22k, e_16k, e_22k, _ = self.get_time_idxs(mel_start)

        data['gt_audio_16k_negative'] = crop_audio(wav_16k, s_16k, e_16k)
        data['gt_audio_22k_negative'] = crop_audio(wav_22k, s_22k, e_22k)
        return data


    def __getitem__(self, idx):
        """Return the merged dict of one positive and one negative sample.

        Raises:
            ValueError: if fewer than two distinct speakers are available
                (rejection sampling below could never terminate).
        """
        if len(set(self.speakers)) < 2:
            raise ValueError('at least two speakers are required to draw a negative sample')

        pos_speaker = np.random.choice(self.speakers)
        neg_speaker = pos_speaker
        while neg_speaker == pos_speaker:
            neg_speaker = np.random.choice(self.speakers)

        pos_item = self.get_pos(pos_speaker)
        neg_item = self.get_neg(neg_speaker)
        item = dict(**pos_item, **neg_item)
        return item

    def get_data(self, fpath):
        """Load one file and return the crop starting at frame 0, with a
        leading batch dimension added to every entry."""
        data = {}
        wav_22k, wav_16k = self._load_audio(fpath)

        mel_22k = mel_spectrogram(wav_22k.unsqueeze(0))[0] # (80, T)
        mel_s, mel_e, _, s_16k, s_22k, e_16k, e_22k, _ = self.get_time_idxs(0)

        data['gt_mel_22k'] = crop_audio(mel_22k, mel_s, mel_e).unsqueeze(0)
        data['gt_audio_16k'] = crop_audio(wav_16k, s_16k, e_16k).unsqueeze(0)
        data['gt_audio_22k'] = crop_audio(wav_22k, s_22k, e_22k).unsqueeze(0)
        return data

In [None]:
# Root directory of the corpus; Dataset treats every entry in it as a speaker
# and globs `<datadir>/<speaker>/*/*.wav` for that speaker's utterances.
datadir = '/mnt/tts/ko-aihub/ko-aihub-emotion2021/'
dataset = Dataset(datadir)
# Samples are drawn randomly inside Dataset.__getitem__ regardless of the
# index, so `shuffle` only affects the (unused) index order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

In [None]:
# Sanity check: fetch a single batch and print the shape of every entry.
batch = next(iter(dataloader))
for k,v in batch.items():
    print(k, v.shape)

## 3. Model

### 3.1. Analysis

In [None]:
# yin
def differenceFunction(x, N, tau_max):
    """
    Compute difference function of data x. This corresponds to equation (6) in [1]
    This solution is implemented directly with Numpy fft.
    :param x: audio data
    :param N: length of data (unused; the length is taken from ``x`` itself)
    :param tau_max: integration window size (clipped to the length of ``x``)
    :return: difference function, one value per lag in ``[0, tau_max)``
    :rtype: numpy.ndarray
    """

    x = np.array(x, np.float64)
    w = x.size
    tau_max = min(tau_max, w)
    # Prefix sums of the signal energy: x_cumsum[k] = sum(x[:k] ** 2).
    x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum()))
    # Zero-pad to an FFT-friendly length of at least w + tau_max so the
    # circular autocorrelation equals the linear one for all needed lags.
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(k * 2 ** p2 for k in nice_numbers if k * 2 ** p2 >= size)
    fc = np.fft.rfft(x, size_pad)
    # Pass the length explicitly: size_pad can be odd (25 or 27) for short
    # inputs, and irfft would otherwise assume an even output length.
    conv = np.fft.irfft(fc * fc.conjugate(), size_pad)[:tau_max]
    # d(tau) = sum(x[:w-tau]^2) + sum(x[tau:]^2) - 2 * autocorr(tau)
    return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv


def cumulativeMeanNormalizedDifferenceFunction(df, N, eps=1e-8):
    """
    Cumulative mean normalized difference function (CMND), equation (8) in [1].

    For lag tau >= 1 the value is ``df[tau] * tau / (sum(df[1:tau + 1]) + eps)``;
    the value at lag 0 is fixed to 1.
    :param df: Difference function
    :param N: length of data (number of lags in ``df``)
    :param eps: small constant added to the denominator
    :return: cumulative mean normalized difference function
    :rtype: numpy.ndarray
    """
    # Silences numpy divide/invalid warnings (this changes global numpy state).
    np.seterr(divide='ignore', invalid='ignore')
    # scipy method, assert df>0 for all element
    tail = df[1:]
    lags = np.arange(1, N)
    running_sum = np.cumsum(tail).astype(float) + eps
    normalized = tail * lags / running_sum
    return np.concatenate(([1.], normalized))


def differenceFunctionBatch(xs: np.ndarray, N, tau_max):
    """numpy backend batch-wise differenceFunction
    Args:
        xs: audio segments, np.ndarray of shape (B x t)
        N: unused; the segment length is taken from ``xs``
        tau_max: number of lags to compute (clipped to the segment length)
    Returns:
        y: dF. np.ndarray of shape (B x tau_max)
    """
    xs = xs.astype(np.float64)
    w = xs.shape[-1]
    tau_max = min(tau_max, w)
    # Per-row prefix sums of the energy, with a leading zero column.
    x_cumsum = np.concatenate((np.zeros((xs.shape[0], 1)), (xs * xs).cumsum(axis=-1)), axis=-1)  # B x (w + 1)
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(k * 2 ** p2 for k in nice_numbers if k * 2 ** p2 >= size)

    # Autocorrelation of every row in one vectorised FFT call. The inverse
    # length is given explicitly because size_pad can be odd for short inputs.
    fcs = np.fft.rfft(xs, size_pad, axis=-1)
    convs = np.fft.irfft(fcs * fcs.conjugate(), size_pad, axis=-1)[:, :tau_max]

    y = x_cumsum[:, w:w - tau_max:-1] + x_cumsum[:, w, np.newaxis] - x_cumsum[:, :tau_max] - 2 * convs
    return y


def cumulativeMeanNormalizedDifferenceFunctionBatch(dFs, N, eps=1e-8):
    """numpy backend batch-wise cumulative Mean Normalized Difference Functions
    Args:
        dFs: differenceFunctions. np.ndarray of shape (B x tau_max)
        N: number of lags (tau_max)
        eps: small constant added to the denominator
    Returns:
        cMNDFs: np.ndarray of shape (B x tau_max); the value at lag 0 is 1
    """
    lags = np.arange(1, N)[np.newaxis, ...]
    cumsum = np.cumsum(dFs[:, 1:], axis=-1).astype(float)
    cMNDFs = dFs[:, 1:] * lags / (cumsum + eps)
    # The normalised function is defined as 1 at lag 0; the single-signal and
    # torch variants in this file prepend 1 as well, so a thresholding step
    # never mistakes lag 0 for a perfect period.
    cMNDFs = np.concatenate((np.ones((cMNDFs.shape[0], 1)), cMNDFs), axis=1)
    return cMNDFs


def differenceFunctionTorch(xs: torch.Tensor, N, tau_max) -> torch.Tensor:
    """pytorch backend batch-wise differenceFunction
    has 1e-4 level error with input shape of (32, 22050*1.5)
    Args:
        xs: audio segments, tensor of shape (B x t)
        N: unused; the segment length is taken from ``xs``
        tau_max: number of lags to compute (clipped to the segment length)
    Returns:
        y: difference functions, float64 tensor of shape (B x tau_max)
    """
    xs = xs.double()
    w = xs.shape[-1]
    tau_max = min(tau_max, w)
    # Per-row prefix sums of the energy, with a leading zero column created
    # directly on xs's device and dtype.
    x_cumsum = torch.cat(
        (xs.new_zeros((xs.shape[0], 1)), (xs * xs).cumsum(dim=-1, dtype=torch.double)),
        dim=-1)  # B x (w + 1)
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(k * 2 ** p2 for k in nice_numbers if k * 2 ** p2 >= size)

    fcs = torch.fft.rfft(xs, n=size_pad, dim=-1)
    # The inverse length is given explicitly because size_pad can be odd for
    # short inputs, where irfft's default (even) length would be wrong.
    convs = torch.fft.irfft(fcs * fcs.conj(), n=size_pad, dim=-1)[:, :tau_max]
    # Equivalent of x_cumsum[:, w:w - tau_max:-1] (torch has no negative steps).
    y1 = torch.flip(x_cumsum[:, w - tau_max + 1:w + 1], dims=[-1])
    y = y1 + x_cumsum[:, w, None] - x_cumsum[:, :tau_max] - 2 * convs
    return y


def cumulativeMeanNormalizedDifferenceFunctionTorch(dfs: torch.Tensor, N, eps=1e-8) -> torch.Tensor:
    """pytorch backend batch-wise cumulative mean normalized difference function

    Args:
        dfs: difference functions, tensor of shape (B x tau_max)
        N: number of lags (tau_max)
        eps: small constant added to the denominator
    Returns:
        tensor of shape (B x tau_max) whose first column (lag 0) is 1
    """
    tail = dfs[:, 1:]
    lags = torch.arange(1, N, device=dfs.device, dtype=torch.float64)
    running_sum = torch.cumsum(tail, dim=-1, dtype=torch.float64).to(dfs.device)

    normalized = tail * lags / (running_sum + eps)
    lag_zero = torch.ones(normalized.shape[0], 1, device=dfs.device)
    return torch.cat((lag_zero, normalized), dim=-1)

In [None]:
# ecapa
class Conv1D_ReLU_BN(nn.Module):
    """1-D convolution followed by ReLU and batch normalisation."""

    def __init__(self, c_in, c_out, ks, stride, padding, dilation):
        super().__init__()

        conv = nn.Conv1d(c_in, c_out, ks, stride, padding, dilation)
        activation = nn.ReLU(inplace=True)
        norm = nn.BatchNorm1d(c_out)
        self.network = nn.Sequential(conv, activation, norm)

    def forward(self, x):
        """Apply conv -> ReLU -> BN to ``x`` (B x c_in x t)."""
        return self.network(x)


class Res2_Conv1D(nn.Module):
    """Res2Net-style multi-scale 1-D convolution.

    The ``c`` input channels are split into ``scale`` groups of equal width.
    The first group is passed through unchanged; every further group is
    convolved and batch-normalised, receiving the previous group's output as
    an additional input from the third group onwards.
    """

    def __init__(self, c, scale, ks, stride, padding, dilation):
        super().__init__()
        assert c % scale == 0
        self.c = c
        self.scale = scale
        self.width = c // scale

        branches = scale - 1  # the first group has no conv/bn
        self.convs = nn.ModuleList(
            [nn.Conv1d(self.width, self.width, ks, stride, padding, dilation) for _ in range(branches)]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(self.width) for _ in range(branches)])

    def forward(self, x):
        """
        param x: (B x c x d)
        """
        groups = torch.split(x, self.width, dim=1)  # channel-wise split
        outputs = [groups[0]]

        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns), start=1):
            branch_in = groups[idx] if idx == 1 else groups[idx] + outputs[-1]
            outputs.append(bn(conv(branch_in)))

        return torch.cat(outputs, dim=1)  # channel-wise concat


class Res2_Conv1D_ReLU_BN(nn.Module):
    """``Res2_Conv1D`` followed by ReLU and batch normalisation."""

    def __init__(self, channel, scale, ks, stride, padding, dilation):
        super().__init__()

        multi_scale = Res2_Conv1D(channel, scale, ks, stride, padding, dilation)
        activation = nn.ReLU(inplace=True)
        norm = nn.BatchNorm1d(channel)
        self.network = nn.Sequential(multi_scale, activation, norm)

    def forward(self, x):
        """Apply the block to ``x`` (B x channel x t)."""
        return self.network(x)


class SE_Block(nn.Module):
    """Squeeze-and-excitation: rescale channels by a learned per-channel gate."""

    def __init__(self, c_in, c_mid):
        super().__init__()

        squeeze = nn.Linear(c_in, c_mid)
        excite = nn.Linear(c_mid, c_in)
        self.network = nn.Sequential(
            squeeze,
            nn.ReLU(inplace=True),
            excite,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale ``x`` (B x c_in x t) by gates computed from its time average."""
        pooled = x.mean(dim=-1)
        gate = self.network(pooled).unsqueeze(-1)
        return x * gate


class SE_Res2_Block(nn.Module):
    """Residual block: 1x1 conv, Res2 dilated conv, 1x1 conv, then SE gating."""

    def __init__(self, channel, scale, ks, stride, padding, dilation):
        super().__init__()
        stages = [
            Conv1D_ReLU_BN(channel, channel, 1, 1, 0, 1),
            Res2_Conv1D_ReLU_BN(channel, scale, ks, stride, padding, dilation),
            Conv1D_ReLU_BN(channel, channel, 1, 1, 0, 1),
            SE_Block(channel, channel),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``network(x) + x`` for ``x`` of shape (B x channel x t)."""
        transformed = self.network(x)
        return transformed + x


class AttentiveStatisticPool(nn.Module):
    """Attention-weighted mean and standard deviation over the time axis."""

    def __init__(self, c_in, c_mid):
        super().__init__()

        bottleneck = nn.Conv1d(c_in, c_mid, kernel_size=1)
        expand = nn.Conv1d(c_mid, c_in, kernel_size=1)
        self.network = nn.Sequential(
            bottleneck,
            nn.Tanh(),  # seems like most implementations uses tanh?
            expand,
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Pool ``x`` (B x C x t) into (B x 2C): weighted mean, then weighted std."""
        weights = self.network(x)  # softmax over time, per channel
        mean = (weights * x).sum(dim=-1)
        second_moment = (weights * x ** 2).sum(dim=-1)
        variance = second_moment - mean ** 2
        # Clamp before the square root to avoid NaNs from tiny negative values.
        std = variance.clamp(min=1e-9).sqrt()
        return torch.cat([mean, std], dim=-1)


class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN embedding network.

    A front convolution, three SE-Res2 blocks with dilation 2, 3 and 4 whose
    outputs are concatenated, a grouped 1x1 convolution, attentive statistics
    pooling and a final linear projection to ``c_out`` dimensions.
    """

    def __init__(self, c_in=80, c_mid=512, c_out=192):
        super().__init__()

        self.layer1 = Conv1D_ReLU_BN(c_in, c_mid, 5, 1, 2, 1)
        # (channel, scale, kernel, stride, padding, dilation)
        self.layer2 = SE_Res2_Block(c_mid, 8, 3, 1, 2, 2)
        self.layer3 = SE_Res2_Block(c_mid, 8, 3, 1, 3, 3)
        self.layer4 = SE_Res2_Block(c_mid, 8, 3, 1, 4, 4)

        # Figure 2 in https://arxiv.org/pdf/2005.07143.pdf seems like groupconv?
        aggregate = nn.Conv1d(c_mid * 3, 1536, kernel_size=1, groups=3)
        pooling = AttentiveStatisticPool(1536, 128)
        self.network = nn.Sequential(aggregate, pooling)

        self.bn1 = nn.BatchNorm1d(3072)  # pooled mean + std: 2 * 1536
        self.linear = nn.Linear(3072, c_out)
        self.bn2 = nn.BatchNorm1d(c_out)

    def forward(self, x):
        """Map ``x`` (B x C x t) to an embedding of shape (B x c_out)."""
        y1 = self.layer1(x)
        y2 = self.layer2(y1) + y1
        sum12 = y1 + y2
        y3 = self.layer3(sum12) + y1 + y2
        y4 = self.layer4(sum12 + y3) + y1 + y2 + y3

        stacked = torch.cat([y2, y3, y4], dim=1)  # channel-wise concat
        pooled = self.network(stacked)

        # BatchNorm1d is applied on (B x C x 1) views of the pooled vectors.
        normed = self.bn1(pooled.unsqueeze(-1)).squeeze(-1)
        projected = self.linear(normed)
        return self.bn2(projected.unsqueeze(-1)).squeeze(-1)

In [None]:
class Linguistic(torch.nn.Module):
    """Linguistic feature extractor: hidden state 12 of a wav2vec 2.0 model.

    The wrapped model is always run under ``torch.no_grad()``, and ``train()``
    deliberately does not propagate the mode to it.
    """

    def __init__(self, wav2vec2):
        super().__init__()
        self.wav2vec2 = wav2vec2

    def forward(self, x):
        """
        Args:
            x: torch.Tensor of shape (B x t)
        Returns:
            y: torch.Tensor of shape(B x C x t)
        """
        with torch.no_grad():
            hidden_states = self.wav2vec2(x, output_hidden_states=True).hidden_states
        features = hidden_states[12]  # B x t x C(1024)
        return features.permute((0, 2, 1))  # B x t x C -> B x C x t

    def train(self, mode: bool = True):
        """Set only this module's own ``training`` flag; children are untouched."""
        if isinstance(mode, bool):
            self.training = mode
            return self
        raise ValueError("training mode is expected to be boolean")


class Speaker(torch.nn.Module):
    """Speaker encoder: wav2vec2 hidden state 1 fed through an ECAPA-TDNN.

    Only the ECAPA-TDNN (``self.spk``) receives gradients here; wav2vec2 is
    queried under ``torch.no_grad()``.
    """

    def __init__(self, wav2vec2):
        super().__init__()
        self.wav2vec2 = wav2vec2
        self.spk = ECAPA_TDNN(c_in=1024, c_mid=512, c_out=192)

    def forward(self, x):
        """Compute an L2-normalised speaker embedding.

        Args:
            x: torch.Tensor of shape (B x t)
        Returns:
            torch.Tensor of shape (B x 192)
        """
        with torch.no_grad():
            hidden_states = self.wav2vec2(x, output_hidden_states=True).hidden_states
        frames = hidden_states[1].permute(0, 2, 1)  # B x t x C(1024) -> B x C x t
        embedding = self.spk(frames)  # B x C(1024) x t -> B x D(192)
        return torch.nn.functional.normalize(embedding, p=2, dim=-1)

    def train(self, mode: bool = True):
        """Switch the mode of this module and of ``self.spk`` only.

        The wrapped wav2vec2 is left untouched so it stays in the mode chosen
        by its owner.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        self.spk.train(mode)
        return self


class Energy(torch.nn.Module):
    """Energy feature: mean of a (log-)mel spectrogram over the frequency axis."""

    def forward(self, mel):
        """Average ``mel`` over dimension 1, keeping that dimension.

        Args:
            mel: torch.Tensor of shape (B x C x t), frequency bins on dim 1
        Returns:
            torch.Tensor of shape (B x 1 x t)
        """
        return mel.mean(dim=1, keepdim=True)


class Pitch(torch.nn.Module):
    """Yingram (pitch feature) computation, exposed as static methods."""

    @staticmethod
    def midi_to_lag(m: int, sr: int, semitone_range: float = 12):
        """converts midi-to-lag, eq. (4)
        Args:
            m: midi note number (69 maps to 440 Hz)
            sr: sample_rate
            semitone_range: number of steps per octave
        Returns:
            lag: time lag(tau, c(m)) in samples calculated from midi, eq. (4)
        """
        f = 440 * math.pow(2, (m - 69) / semitone_range)
        lag = sr / f
        return lag

    @staticmethod
    def yingram_from_cmndf(cmndfs: torch.Tensor, ms: list, sr: int = 22050) -> torch.Tensor:
        """ yingram calculator from cMNDFs(cumulative Mean Normalized Difference Functions)

        Each midi value is converted to a (generally fractional) lag and the
        cMNDF is linearly interpolated between the two neighbouring integer lags.

        Args:
            cmndfs: torch.Tensor indexed as [frame, lag]
                calculated cumulative mean normalized difference function
                for details, see models/yin.py or eq. (1) and (2)
            ms: list of midi(int)
            sr: sampling rate
        Returns:
            y:
                calculated batch yingram, one column per entry of ``ms``
        """
        c_ms = np.asarray([Pitch.midi_to_lag(m, sr) for m in ms])
        c_ms = torch.from_numpy(c_ms).to(cmndfs.device)
        c_ms_floor = torch.floor(c_ms).long()
        c_ms_ceil = torch.ceil(c_ms).long()

        lower = cmndfs[:, c_ms_floor]
        upper = cmndfs[:, c_ms_ceil]
        # ceil - floor is 1 for a fractional lag, so no division is needed. For an
        # exactly integer lag it is 0: the old division then produced 0 / 0 = NaN,
        # whereas here the weight is simply 0 and the value at that lag is returned.
        weight = (c_ms - c_ms_floor).unsqueeze(0)
        return lower + (upper - lower) * weight

    @staticmethod
    def yingram(x: torch.Tensor, W: int = 2048, tau_max: int = 2048, sr: int = 22050, w_step: int = 256):
        """calculates yingram from raw audio (multi segment)
        Args:
            x: raw audio, torch.Tensor of shape (t)
            W: yingram Window Size
            tau_max: largest lag handed to the difference-function helpers
            sr: sampling rate
            w_step: hop between consecutive windows, in samples
        Returns:
            yingram: one row per window and one column per midi bin 5..84,
                i.e. presumably (t' x 80); ``yingram_batch`` permutes this
                to (B x 80 x t').
        """
        # Window start offsets; a start is only used if it is strictly below len(x) - W.
        starts = range(0, x.shape[-1] - W, w_step)
        frames = torch.stack([x[..., t:t + W] for t in starts], dim=0).to(x.device)

        # If not using gpu, or torch not compatible, implemented numpy batch function is still fine
        dfs = differenceFunctionTorch(frames, frames.shape[-1], tau_max)
        cmndfs = cumulativeMeanNormalizedDifferenceFunctionTorch(dfs, tau_max)

        midis = list(range(5, 85))
        return Pitch.yingram_from_cmndf(cmndfs, midis, sr)

    @staticmethod
    def yingram_batch(x: torch.Tensor, W: int = 2048, tau_max: int = 2048, sr: int = 22050, w_step: int = 256):
        """calculates yingram from batch raw audio.
        currently calculates batch-wise through for loop, but seems it can be implemented to act batch-wise
        Args:
            x: torch.Tensor of shape (B x t)
            W: yingram Window Size
            tau_max: largest lag handed to the difference-function helpers
            sr: sampling rate
            w_step: hop between consecutive windows, in samples
        Returns:
            yingram: yingram. float32 torch.Tensor of shape (B x 80 x t')
        """
        per_item = [Pitch.yingram(wav, W, tau_max, sr, w_step) for wav in x]
        result = torch.stack(per_item, dim=0).float()
        return result.permute((0, 2, 1)).to(x.device)


class Analysis(torch.nn.Module):
    """Container for all analysis modules (linguistic, speaker, energy, pitch)."""

    def __init__(self, conf=None):
        """Load the pretrained wav2vec2 once and share it between the feature extractors.

        Args:
            conf: accepted for interface compatibility; not read here.
        """
        super().__init__()

        # One frozen, eval-mode backbone shared by the linguistic and speaker branches.
        backbone = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-xlsr-53")
        backbone.eval()
        backbone.requires_grad_(False)
        self.wav2vec2 = backbone

        self.linguistic = Linguistic(backbone)
        self.speaker = Speaker(backbone)
        self.energy = Energy()
        self.pitch = Pitch()

### 3.2. Synthesis

In [None]:
# hifigan

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

# from utils import init_weights, get_padding

# Negative slope passed to F.leaky_relu by the HiFi-GAN blocks and discriminators below.
LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when ``m`` is a convolution.

    A module counts as a convolution if its class name contains "Conv";
    every other module is left untouched. Intended for ``Module.apply``.
    """
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    """Return the padding ``dilation * (kernel_size - 1) / 2`` truncated to int.

    For an odd kernel this keeps a stride-1 dilated convolution's output the
    same length as its input.
    """
    return int(dilation * (kernel_size - 1) / 2)


class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block: three (dilated conv, plain conv) pairs with skips.

    Args:
        h: configuration object, stored on the module.
        channels: channel count, unchanged by the block.
        kernel_size: kernel size of every convolution.
        dilation: dilation rates of the first convolution of each pair
            (the first three entries are used).
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h

        def make_conv(rate):
            # Weight-normalised stride-1 convolution padded via get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual pairs in sequence."""
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for layer in (*self.convs1, *self.convs2):
            remove_weight_norm(layer)


class ResBlock2(torch.nn.Module):
    """Lighter HiFi-GAN residual block: two dilated convolutions, each with a skip.

    Args:
        h: configuration object, stored on the module.
        channels: channel count, unchanged by the block.
        kernel_size: kernel size of both convolutions.
        dilation: dilation rates (the first two entries are used).
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(2)
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply both residual convolutions in sequence."""
        for conv in self.convs:
            x = conv(F.leaky_relu(x, LRELU_SLOPE)) + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for layer in self.convs:
            remove_weight_norm(layer)


class HiFiGANGenerator(torch.nn.Module):
    """HiFi-GAN generator: 80-channel mel input upsampled to a 1-channel waveform.

    Args:
        h: configuration exposing ``resblock``, ``resblock_kernel_sizes``,
            ``resblock_dilation_sizes``, ``upsample_rates``,
            ``upsample_kernel_sizes`` and ``upsample_initial_channel``.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        width = h.upsample_initial_channel
        self.conv_pre = weight_norm(Conv1d(80, width, 7, 1, padding=3))
        block_cls = ResBlock1 if h.resblock == '1' else ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            up = ConvTranspose1d(width // (2 ** stage), width // (2 ** (stage + 1)),
                                 kernel, rate, padding=(kernel - rate) // 2)
            self.ups.append(weight_norm(up))

        # One group of residual blocks (one per kernel size) after every stage.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = width // (2 ** (stage + 1))
            for kernel, dilations in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(block_cls(h, ch, kernel, dilations))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Generate a waveform in (-1, 1) from the mel input ``x``."""
        x = self.conv_pre(x)
        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, LRELU_SLOPE))

            # Average the outputs of this stage's residual blocks.
            first = stage * self.num_kernels
            xs = None
            for offset in range(self.num_kernels):
                out = self.resblocks[first + offset](x)
                if xs is None:
                    xs = out
                else:
                    xs += out
            x = xs / self.num_kernels

        # This activation uses F.leaky_relu's default slope rather than LRELU_SLOPE.
        x = F.leaky_relu(x)
        return torch.tanh(self.conv_post(x))

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer of the generator."""
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into (t / period, period) and applies 2-D convs.

    Args:
        period: fold width.
        kernel_size: kernel height of the convolutions along the folded time axis.
        stride: stride of the first four convolutions along that axis.
        use_spectral_norm: use spectral norm instead of weight norm.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        # Equality test kept as-is so non-bool arguments select the same wrapper as before.
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        widths = (1, 32, 128, 512, 1024)
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score ``x`` (B x C x t).

        Returns:
            (scores flattened per batch item, list of the feature map after
            every layer including the final one)
        """
        fmap = []

        # 1d to 2d: reflect-pad the tail so the length is a multiple of the period.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of period discriminators with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Score real audio ``y`` and generated audio ``y_hat`` with every sub-discriminator.

        Returns:
            (real scores, generated scores, real feature maps, generated
            feature maps), each a list with one entry per sub-discriminator.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            real_score, real_fmap = disc(y)
            fake_score, fake_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            fmap_rs.append(real_fmap)
            y_d_gs.append(fake_score)
            fmap_gs.append(fake_fmap)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of strided, grouped 1-D convolutions.

    Args:
        use_spectral_norm: use spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Equality test kept as-is so non-bool arguments select the same wrapper as before.
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        # (in_channels, out_channels, kernel, stride, groups, padding)
        specs = (
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, kernel, stride, groups=groups, padding=pad))
            for c_in, c_out, kernel, stride, groups, pad in specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Score ``x``.

        Returns:
            (scores flattened per batch item, list of the feature map after
            every layer including the final one)
        """
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap


class MultiScaleDiscriminator(torch.nn.Module):
    """Three scale discriminators; the 2nd and 3rd see progressively average-pooled audio."""

    def __init__(self):
        super().__init__()
        # Only the first discriminator uses spectral norm.
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2) for _ in range(2)])

    def forward(self, y, y_hat):
        """Score real audio ``y`` and generated audio ``y_hat`` at every scale.

        Returns:
            (real scores, generated scores, real feature maps, generated
            feature maps), each a list with one entry per sub-discriminator.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for index, disc in enumerate(self.discriminators):
            if index > 0:
                # Pooling is cumulative: each scale pools the previous scale's signal.
                pool = self.meanpools[index - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            real_score, real_fmap = disc(y)
            fake_score, fake_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            fmap_rs.append(real_fmap)
            y_d_gs.append(fake_score)
            fmap_gs.append(fake_fmap)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss: twice the summed mean L1 distance between feature maps.

    Args:
        fmap_r: per-discriminator lists of feature maps for real audio.
        fmap_g: per-discriminator lists of feature maps for generated audio.
    Returns:
        Scalar tensor, or the int 0 when there are no feature maps.
    """
    distances = (
        torch.mean(torch.abs(real - fake))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real, fake in zip(real_maps, fake_maps)
    )
    return sum(distances) * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss summed over sub-discriminators.

    Real outputs are pushed towards 1 and generated outputs towards 0.

    Returns:
        (total loss, per-discriminator real losses as floats,
        per-discriminator generated losses as floats)
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        fake_term = torch.mean(fake_out ** 2)
        total = total + (real_term + fake_term)
        r_losses.append(real_term.item())
        g_losses.append(fake_term.item())

    return total, r_losses, g_losses


def generator_loss(disc_outputs):
    """Least-squares generator loss: push every discriminator output towards 1.

    Returns:
        (total loss, list of the per-discriminator loss tensors)
    """
    gen_losses = [torch.mean((1 - out) ** 2) for out in disc_outputs]
    return sum(gen_losses), gen_losses

In [None]:
class ConditionalLayerNorm(nn.Module):
    """Normalisation whose scale and bias are predicted from an embedding.

    ``x`` is standardised along its last axis and then modulated by one scalar
    scale and one scalar bias per batch item, both linear functions of the
    embedding.

    Args:
        embedding_dim: size of the conditioning embedding.
        normalize_embedding: L2-normalise the embedding before the projections.
        eps: added to the variance before the square root so constant input
            does not divide by zero.
    """

    def __init__(self, embedding_dim: int, normalize_embedding: bool = True, eps: float = 1e-5):
        super(ConditionalLayerNorm, self).__init__()
        self.normalize_embedding = normalize_embedding
        self.eps = eps

        self.linear_scale = nn.Linear(embedding_dim, 1)
        self.linear_bias = nn.Linear(embedding_dim, 1)

    def forward(self, x, embedding):
        """Normalise ``x`` (B x C x t) conditioned on ``embedding`` (B x embedding_dim)."""
        if self.normalize_embedding:
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
        scale = self.linear_scale(embedding).unsqueeze(-1)  # shape: (B, 1, 1)
        bias = self.linear_bias(embedding).unsqueeze(-1)  # shape: (B, 1, 1)

        # Divide by the standard deviation; the previous code divided by the raw
        # variance, which is not unit-variance and gave NaN for constant input.
        # NOTE(review): statistics are taken over the last (time) axis as before;
        # confirm this is the intended axis rather than the channel axis.
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        out = (x - mean) / torch.sqrt(var + self.eps)
        return scale * out + bias


class ConvGLU(nn.Module):
    """Dropout -> dilated conv -> GLU with a residual skip, optionally followed by conditional norm.

    Args:
        channel: input and output channels.
        ks: convolution kernel size.
        dilation: convolution dilation.
        embedding_dim: embedding size for the conditional layer norm.
        use_cLN: create and apply a ``ConditionalLayerNorm``.
    """

    def __init__(self, channel, ks, dilation, embedding_dim=192, use_cLN=False):
        super().__init__()

        self.dropout = nn.Dropout()
        # The convolution doubles the channels; the GLU halves them again.
        same_padding = (ks - 1) // 2 * dilation
        self.conv = nn.Conv1d(channel, channel * 2, kernel_size=ks, stride=1,
                              padding=same_padding, dilation=dilation)
        self.glu = nn.GLU(dim=1)  # channel-wise

        self.use_cLN = use_cLN
        if use_cLN:
            self.norm = ConditionalLayerNorm(embedding_dim)

    def forward(self, x, speaker_embedding=None):
        """Process ``x``; the norm runs only if enabled and an embedding is given."""
        gated = self.glu(self.conv(self.dropout(x)))
        y = gated + x

        if not self.use_cLN or speaker_embedding is None:
            return y
        return self.norm(y, speaker_embedding)


class PreConv(nn.Module):
    """Three 1x1 convolutions with LeakyReLU and dropout after the first two.

    Args:
        c_in: input channels.
        c_mid: channels of the two hidden layers.
        c_out: output channels.
    """

    def __init__(self, c_in, c_mid, c_out):
        super().__init__()
        layers = [nn.Conv1d(c_in, c_mid, kernel_size=1, dilation=1)]
        for width in (c_mid, c_out):
            layers.append(nn.LeakyReLU())
            layers.append(nn.Dropout())
            layers.append(nn.Conv1d(c_mid, width, kernel_size=1, dilation=1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x``."""
        return self.network(x)


class Generator(nn.Module):
    """Mel generator: an unconditioned stack, then an energy- and speaker-conditioned stack.

    Args:
        c_in: channels of the input feature (wav2vec feature or yingram).
        c_preconv: hidden channels of the ``PreConv`` front end.
        c_mid: channels of the ConvGLU stacks (the conditioned stack uses
            ``c_mid + 1`` because the energy channel is appended).
        c_out: output (mel) channels.
    """

    # Dilation schedule shared by both ConvGLU stacks. The original notebook also
    # carried a third, commented-out (1, 3, 9, 27) group; it is not instantiated.
    _DILATIONS = (1, 3, 9, 27, 1, 3, 9, 27, 1, 3)

    def __init__(self, c_in=1024, c_preconv=512, c_mid=512, c_out=80):
        super(Generator, self).__init__()

        self.network1 = nn.Sequential(
            PreConv(c_in, c_preconv, c_mid),
            *[ConvGLU(c_mid, ks=3, dilation=d, use_cLN=False) for d in self._DILATIONS],
        )

        self.LR = nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv1d(c_mid + 1, c_mid + 1, kernel_size=1, stride=1))

        self.network3 = nn.ModuleList(
            [ConvGLU(c_mid + 1, ks=3, dilation=d, use_cLN=True) for d in self._DILATIONS]
        )

        self.lastConv = nn.Conv1d(c_mid + 1, c_out, kernel_size=1, dilation=1)

    def forward(self, x, energy, speaker_embedding):
        """
        Args:
            x: wav2vec feature or yingram. torch.Tensor of shape (B x C x t)
            energy: energy. torch.Tensor of shape (B x 1 x t')
            speaker_embedding: embedding. torch.Tensor of shape (B x d), the
                layout ``ConditionalLayerNorm`` projects with its Linear layers
        Returns:
            torch.Tensor of shape (B x c_out x t')
        """
        y = self.network1(x)

        # Resample the features to the energy's frame count before concatenating.
        y = F.interpolate(y, energy.shape[-1])
        y = torch.cat((y, energy), dim=1)  # channel-wise concat
        y = self.LR(y)

        for module in self.network3:  # doing this since sequential takes only 1 input
            y = module(y, speaker_embedding)
        return self.lastConv(y)


class Synthesis(nn.Module):
    """Filter and source mel generators plus a frozen HiFi-GAN vocoder."""

    def __init__(self):
        super(Synthesis, self).__init__()
        self.filter_generator = Generator(1024, 512, 128, 80)
        self.source_generator = Generator(50, 512, 128, 80)
        self.set_vocoder()

    def set_vocoder(self):
        """Load the pretrained HiFi-GAN generator from ./hifi-gan/UNIVERSAL_V1 and freeze it."""
        path_config = './hifi-gan/UNIVERSAL_V1/config.json'
        path_ckpt = './hifi-gan/UNIVERSAL_V1/g_02500000'

        hifigan_config = OmegaConf.load(path_config)
        self.vocoder = HiFiGANGenerator(hifigan_config)

        # Map to CPU so loading does not depend on the device the checkpoint was
        # saved from; the module is moved to its target device later via .to().
        state_dict_g = torch.load(path_ckpt, map_location='cpu')
        self.vocoder.load_state_dict(state_dict_g['generator'])
        self.vocoder.eval()

        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Smallest mel value of two seconds of silence, used by _denormalize.
        zero_audio = torch.zeros(44100).float()
        zero_mel = mel_spectrogram(
            zero_audio.unsqueeze(0),
            1024, 80, 22050, 256, 1024, 0, 8000
        )
        self.mel_padding_value = torch.min(zero_mel)

    def _denormalize(self, spec):
        """Map ``spec`` linearly so that 0 -> mel_padding_value and 1 -> 0."""
        return spec * -self.mel_padding_value + self.mel_padding_value

    def forward(self, lps, s, e, ps):
        """Generate a mel spectrogram and vocode it.

        Args:
            lps: linguistic features, input of the filter generator.
            s: speaker embedding.
            e: energy.
            ps: pitch (yingram) features, input of the source generator.
        Returns:
            dict with 'mel_filter', 'mel_source', 'gen_mel' (their sum) and
            'audio_gen' (vocoder output, computed without gradients).
        """
        result = {}
        result['mel_filter'] = self.filter_generator(lps, e, s)
        result['mel_source'] = self.source_generator(ps, e, s)
        result['gen_mel'] = result['mel_filter'] + result['mel_source']
        with torch.no_grad():
            # hifigan_mel = self._denormalize(result['gen_mel'])
            hifigan_mel = result['gen_mel']
            result['audio_gen'] = self.vocoder(hifigan_mel)
        return result

    def train(self, mode: bool = True):
        """Switch the mode of the two generators only; the vocoder stays in eval mode."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        self.filter_generator.train(mode)
        self.source_generator.train(mode)
        return self

### 3.3. Discriminator

In [None]:
class ResBlock(nn.Module):
    """Two dilated (rate 3) convolutions with LeakyReLU pre-activation and a 1x1 conv skip.

    Args:
        c_in: input channels.
        c_mid: channels between the two dilated convolutions.
        c_out: output channels.
    """

    def __init__(self, c_in, c_mid=128, c_out=128):
        super().__init__()
        kernel, rate = 3, 3
        pad = (kernel - 1) // 2 * rate  # keeps the sequence length unchanged

        self.leaky_relu1 = nn.LeakyReLU()
        self.conv1 = nn.Conv1d(c_in, c_mid, kernel_size=kernel, stride=1, padding=pad, dilation=rate)

        self.leaky_relu2 = nn.LeakyReLU()
        self.conv2 = nn.Conv1d(c_mid, c_out, kernel_size=kernel, stride=1, padding=pad, dilation=rate)

        # 1x1 projection so the skip path matches c_out.
        self.conv3 = nn.Conv1d(c_in, c_out, kernel_size=1, dilation=1)

    def forward(self, x):
        """Return the residual branch plus the projected skip."""
        hidden = self.conv1(self.leaky_relu1(x))
        hidden = self.conv2(self.leaky_relu2(hidden))
        return hidden + self.conv3(x)


class Discriminator(nn.Module):
    """Speaker-conditional mel discriminator.

    An unconditional per-frame score (``psi``) is combined with projections of
    a 192-channel feature map (``res``) onto the positive and negative speaker
    embeddings.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        c_in = 80
        c_mid = 128
        c_out = 192

        self.phi = nn.Sequential(
            nn.Conv1d(c_in, c_mid, kernel_size=3, stride=1, padding=1, dilation=1),
            *[ResBlock(c_mid, c_mid, c_mid) for _ in range(5)],
        )
        self.res = ResBlock(c_mid, c_mid, c_out)

        self.psi = nn.Conv1d(c_mid, 1, kernel_size=3, stride=1, padding=1, dilation=1)

    def forward(self, mel, positive, negative):
        """
        Args:
            mel: mel spectrogram, torch.Tensor of shape (B x C x T)
            positive: positive speaker embedding, torch.Tensor of shape (B x d)
            negative: negative speaker embedding, torch.Tensor of shape (B x d)
        Returns:
            per-frame scores, torch.Tensor of shape (B x T)
        """
        # Shared trunk: run phi once and feed both heads from the same features.
        features = self.phi(mel)
        pred1 = self.psi(features)  # B x 1 x T
        pred = self.res(features)  # B x d x T

        pred2 = torch.bmm(positive.unsqueeze(1), pred)  # B x 1 x T
        pred3 = torch.bmm(negative.unsqueeze(1), pred)  # B x 1 x T
        result = pred1 + pred2 - pred3
        result = result.squeeze(1)
        # result = torch.mean(result, dim=-1)
        return result


### 3.4. Loss

In [None]:
class GANLoss(nn.Module):
    """GAN objective selectable between 'lsgan', 'vanilla', 'wgangp' and 'hinge'.

    Args:
        conf: object exposing ``gan_mode``, ``real`` and ``fake`` (the label
            values used by the 'lsgan' and 'vanilla' modes).
    Raises:
        NotImplementedError: if ``conf.gan_mode`` is not one of the four modes.
    """

    def __init__(self, conf):
        super().__init__()
        self.conf = conf
        self.register_buffer('real_label', torch.tensor(self.conf.real))
        self.register_buffer('fake_label', torch.tensor(self.conf.fake))

        gan_mode = self.conf.gan_mode
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            criterion = nn.MSELoss()
        elif gan_mode == 'vanilla':
            criterion = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp' or gan_mode == 'hinge':
            # These modes are computed directly from the prediction in __call__.
            criterion = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
        self.loss = criterion

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size and type as the prediction.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - whether the ground-truth label is "real" or "fake"
        Returns:
            The matching label value broadcast to the shape of ``prediction``.
        """
        label = self.real_label if target_is_real else self.fake_label
        return label.type_as(prediction).expand_as(prediction)

    def __call__(self, prediction, target_is_real, for_discriminator=True):
        """Calculate the loss for a discriminator output and a ground-truth label.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - whether the ground-truth label is "real" or "fake"
            for_discriminator (bool) - - only read in 'hinge' mode: False selects the
                generator objective (negated mean prediction)
        Returns:
            the calculated loss.
        """
        mode = self.gan_mode
        if mode in ('lsgan', 'vanilla'):
            target = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target)
        elif mode == 'wgangp':
            mean_score = prediction.mean()
            loss = -mean_score if target_is_real else mean_score
        elif mode == 'hinge':
            if for_discriminator:
                margin = 1.0 - prediction if target_is_real else 1.0 + prediction
                loss = torch.relu(margin).mean()
            else:
                loss = -prediction.mean()
        return loss

## 4. Train

In [None]:
def step(batch):
    """Run one forward pass and compute generator- and discriminator-side losses.

    Uses the module-level ``analysis``, ``synthesis``, ``discriminator`` and
    ``gan_loss`` objects.

    Args:
        batch: dict providing 'gt_audio_16k_f', 'gt_audio_16k',
            'gt_audio_16k_negative', 'gt_mel_22k' and 'gt_audio_22k_g'.
    Returns:
        (loss, logs): ``loss`` holds 'mel', 'backward' (mel + adversarial term
        for the generator side), 'D_gen_forG', 'D_gen_forD', 'D_gt_forD' and
        'D_backward'; ``logs`` holds the batch plus every intermediate tensor.
    """
    logs = dict(batch)
    loss = {}

    # Analysis features.
    logs['lps'] = analysis.linguistic(batch['gt_audio_16k_f'])
    logs['s_pos'] = analysis.speaker(batch['gt_audio_16k'])
    logs['s_neg'] = analysis.speaker(batch['gt_audio_16k_negative'])

    with torch.no_grad():
        logs['e'] = analysis.energy(batch['gt_mel_22k'])
        yingram = analysis.pitch.yingram_batch(batch['gt_audio_22k_g'])
        logs['ps'] = yingram[:, 19:69]  # keep 50 of the 80 yingram bins

    # Synthesis.
    logs.update(synthesis(logs['lps'], logs['s_pos'], logs['e'], logs['ps']))

    loss['mel'] = F.l1_loss(logs['gen_mel'], logs['gt_mel_22k'])
    loss['backward'] = loss['mel']

    # G step: the generated mel should be scored as real.
    scores_for_g = discriminator(logs['gen_mel'], logs['s_pos'], logs['s_neg'])
    loss['D_gen_forG'] = gan_loss(scores_for_g, True, False)
    loss['backward'] = loss['backward'] + 1 * loss['D_gen_forG']

    # D step: detach so the discriminator loss cannot reach generator/speaker weights.
    for key in ('gen_mel', 's_pos', 's_neg'):
        logs[key] = logs[key].detach()
    scores_fake = discriminator(logs['gen_mel'], logs['s_pos'], logs['s_neg'])
    scores_real = discriminator(logs['gt_mel_22k'], logs['s_pos'], logs['s_neg'])

    loss['D_gen_forD'] = gan_loss(scores_fake, False, True)
    loss['D_gt_forD'] = gan_loss(scores_real, True, True)
    loss['D_backward'] = loss['D_gen_forD'] + loss['D_gt_forD']

    return loss, logs

In [None]:
# Analysis modules on the training device, with their own Adam optimizer.
analysis = Analysis().to(cfg.device)
analysis_optimizer = torch.optim.Adam(
    analysis.parameters(),
    lr=1e-4,
    betas=[0.5, 0.9],
)

In [None]:
# Synthesis modules on the training device, with their own Adam optimizer.
synthesis = Synthesis().to(cfg.device)
synthesis_optimizer = torch.optim.Adam(
    synthesis.parameters(),
    lr=1e-4,
    betas=[0.5, 0.9],
)

In [None]:
# Discriminator on the training device, with its own Adam optimizer.
discriminator = Discriminator().to(cfg.device)
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-4,
    betas=[0.5, 0.9],
)

In [None]:
# Least-squares GAN objective with labels real = 1 and fake = 0; used by step().
gan_loss = GANLoss(DictConfig({'gan_mode': 'lsgan', 'real': 1., 'fake': 0.}))

In [None]:
# Start a Weights & Biases run; the training loop logs its losses to it.
wandb.init(project='voice-conversion')

In [None]:
# Training: 50 epochs, alternating a generator-side and a discriminator update
# per batch, with one checkpoint written (and overwritten) per epoch.
for ep in range(50):
    pbar = tqdm(dataloader)
    for batch in pbar:
        batch = {k: v.to(cfg.device) for k, v in batch.items()}

        loss, _ = step(batch)

        # Generator side (analysis + synthesis): mel L1 plus adversarial term.
        analysis_optimizer.zero_grad()
        synthesis_optimizer.zero_grad()
        loss['backward'].backward()
        analysis_optimizer.step()
        synthesis_optimizer.step()

        # Discriminator: the backward pass above also left generator-loss
        # gradients on the discriminator weights, so clear them and train the
        # discriminator on its own objective. Previously 'D_backward' was never
        # backpropagated and the discriminator was stepped with those gradients.
        discriminator_optimizer.zero_grad()
        loss['D_backward'].backward()
        discriminator_optimizer.step()

        log = {k: v.item() for k, v in loss.items()}
        pbar.set_postfix(log)
        wandb.log(log)

    state_dict = {
        'analysis': analysis.state_dict(),
        'synthesis': synthesis.state_dict(),
        'discriminator': discriminator.state_dict(),
    }
    torch.save(state_dict, 'nansy.pt')

## 5. Test

In [None]:
# Load onto the CPU so this works regardless of which device the checkpoint was
# saved from; the modules are moved to cfg.device after load_state_dict.
state_dict = torch.load('nansy.pt', map_location='cpu')

In [None]:
# Rebuild the analysis modules, restore the trained weights and freeze them for inference.
analysis = Analysis()
analysis.load_state_dict(state_dict['analysis'])
analysis.eval()
analysis.requires_grad_(False)
_ = analysis.to(cfg.device)

In [None]:
# Rebuild the synthesis modules, restore the trained weights and freeze them for inference.
synthesis = Synthesis()
synthesis.load_state_dict(state_dict['synthesis'])
synthesis.eval()
synthesis.requires_grad_(False)
_ = synthesis.to(cfg.device)

In [None]:
# The original patterns contained no wildcard, so glob() could return at most
# the directory itself and the [10] / [20] lookups below could not succeed.
# Collect the audio files recursively, sorted so the indices are reproducible.
# NOTE(review): assumes the corpora are stored as .wav -- adjust the extension if not.
src_audios = sorted(glob('/mnt/tts/vctk/**/*.wav', recursive=True))
tgt_audios = sorted(glob('/mnt/tts/kss/**/*.wav', recursive=True))

In [None]:
# Pick one source and one target utterance, load their features and move them to the device.
src_file, tgt_file = src_audios[10], tgt_audios[20]

src = dataset.get_data(src_file)
tgt = dataset.get_data(tgt_file)

src = {name: tensor.to(cfg.device) for name, tensor in src.items()}
tgt = {name: tensor.to(cfg.device) for name, tensor in tgt.items()}

In [None]:
# Linguistic content comes from the source; speaker, energy and pitch from the target.
# NOTE(review): energy and pitch are taken from the target utterance, so the
# output follows the target's length and prosody -- confirm this is intended.
lps = analysis.linguistic(src['gt_audio_16k'])
s = analysis.speaker(tgt['gt_audio_16k'])
e = analysis.energy(tgt['gt_mel_22k'])
ps = analysis.pitch.yingram_batch(tgt['gt_audio_22k'])[:, 19:69]  # same 50 bins as in training

In [None]:
# Synthesize, convert the (-1, 1) waveform to 16-bit PCM and write it to disk.
result = synthesis(lps, s, e, ps)
audio = result['audio_gen']
audio = audio.squeeze()
# Clamp before the cast: a saturated tanh output of exactly 1.0 scales to 32768,
# which does not fit in int16.
audio = (audio * 32768.0).clamp(-32768.0, 32767.0)
audio = audio.cpu().numpy().astype('int16')

sf.write('output.wav', audio, 22050)

In [None]:
# Inline player for the source utterance.
IPython.display.Audio(src_file)

In [None]:
# Inline player for the target utterance.
IPython.display.Audio(tgt_file)

In [None]:
# Inline player for the synthesized result written above.
IPython.display.Audio('output.wav')