<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/06-speech/S2VC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# S2VC

## 0. Info

### Paper
* title: S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations
* authors: Jheng-hao Lin et al.
* url: https://arxiv.org/abs/2104.02901

### Feats
* vocoder: HiFiGAN

### Refs
* https://github.com/howard1337/S2VC
* https://github.com/jik876/hifi-gan

## 1. Setup

In [None]:
import os
import IPython
import easydict
import numpy as np
from glob import glob
from tqdm.auto import tqdm
from einops import rearrange
from collections import defaultdict

import librosa
import soundfile as sf
from scipy.signal import lfilter
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import get_scheduler

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Global experiment configuration. `cfg.mel` is read with attribute access by
# get_mel (cfg.mel.n_fft, cfg.mel.hop_size, ...).
cfg = easydict.EasyDict(
    rawdir = '/mnt/vctk',    # raw corpus root, globbed as <rawdir>/<speaker>/audio/*.wav
    datadir = 'vctk-prep',   # preprocessed output, written as <datadir>/<speaker>/<utterance>.pt
    # Use the GPU when one is present; otherwise fall back to CPU so the
    # notebook still runs end to end instead of failing on the first .to().
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    feature = 'modified_cpc',  # entry name passed to torch.hub.load("s3prl/s3prl", ...)

    lr = 1e-4,
    num_warmup_steps = 1000,
    num_training_steps = 150000,

    # Mel-spectrogram settings; the same values appear in the vocoder config `h`.
    mel = {
        "segment_size": 8192,
        "num_mels": 80,
        "num_freq": 1025,
        "n_fft": 1024,
        "hop_size": 256,
        "win_size": 1024,
        "sampling_rate": 22050,
        "fmin": 0,
        "fmax": 8000,
        "fmax_for_loss": None,
    }
)

## 2. Data

### 2.1. Preprocess

In [None]:
# Scale factor between normalised float audio and the integer sample range:
# load_wav divides by it, and the synthesis cell multiplies by it before casting.
MAX_WAV_VALUE = 32768.0
# Module-level caches that mel_spectrogram fills in and reads back.
mel_basis = {}
hann_window = {}

def load_wav(fpath):
    """Read a wav file and scale its samples by ``1 / MAX_WAV_VALUE``.

    Args:
        fpath: path handed to ``scipy.io.wavfile.read``.

    Returns:
        ``(samples, sample_rate)`` where ``samples`` is the array read from
        disk divided by ``MAX_WAV_VALUE``.
    """
    sample_rate, samples = read(fpath)
    scaled = samples / MAX_WAV_VALUE
    return scaled, sample_rate


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise.

    Flooring at ``clip_val`` keeps the logarithm finite for zero or negative inputs.
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()


def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` via ``dynamic_range_compression_torch`` with its default arguments."""
    return dynamic_range_compression_torch(magnitudes)

def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of a batch of waveforms.

    Args:
        y: waveform tensor of shape (batch, samples). A message is printed
            when values fall outside [-1, 1].
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: sample rate of ``y`` in Hz.
        hop_size: STFT hop length in samples.
        win_size: Hann window length in samples.
        fmin, fmax: frequency range of the mel filterbank.
        center: forwarded to ``torch.stft``. The signal is reflect-padded by
            ``(n_fft - hop_size) / 2`` on each side beforehand in either case.

    Returns:
        Tensor of shape (batch, num_mels, frames).
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    device_name = str(y.device)
    # The cache keys cover every argument that shapes the cached tensor. The
    # previous membership test used `fmax` alone, which never matched the
    # stored key, so the filterbank and window were rebuilt on every call.
    basis_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_name)
    window_key = (win_size, device_name)
    if basis_key not in mel_basis:
        # Keyword arguments: recent librosa releases make these keyword-only.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # return_complex must be given explicitly on recent PyTorch; view_as_real
    # restores the trailing (real, imag) axis the magnitude computation expects.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    # Magnitude, with a small epsilon so the square root stays differentiable at zero.
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec

def get_mel(x, mel_cfg):
    """Run ``mel_spectrogram`` on ``x`` using the fields of ``mel_cfg``.

    ``mel_cfg`` must expose ``n_fft``, ``num_mels``, ``sampling_rate``,
    ``hop_size``, ``win_size``, ``fmin`` and ``fmax`` as attributes.
    """
    return mel_spectrogram(
        x,
        n_fft=mel_cfg.n_fft,
        num_mels=mel_cfg.num_mels,
        sampling_rate=mel_cfg.sampling_rate,
        hop_size=mel_cfg.hop_size,
        win_size=mel_cfg.win_size,
        fmin=mel_cfg.fmin,
        fmax=mel_cfg.fmax,
    )

In [None]:
# Load the upstream feature extractor named by cfg.feature from the s3prl torch.hub repo.
extractor = torch.hub.load("s3prl/s3prl", cfg.feature, refresh=True)
# Freeze it (no gradients, eval mode) and move it to the configured device.
_ = extractor.requires_grad_(False).eval().to(cfg.device)

In [None]:
# Preprocess every utterance under cfg.rawdir into
# <cfg.datadir>/<speaker>/<utterance>.pt, holding its log-mel spectrogram
# ('mel') and its upstream features ('feat').
# NOTE(review): mel frames and upstream feature frames are computed at different
# sample rates/hops, so their sequence lengths presumably differ; training only
# compares the overlapping prefix (see overlap_lens in collate_fn) — confirm
# this alignment is intended.
speaker_dirs = sorted(glob(f'{cfg.rawdir}/*'))  # sorted for a reproducible order
for sdir in tqdm(speaker_dirs):
    speaker = os.path.basename(sdir)
    speaker_file_paths = sorted(glob(f'{sdir}/audio/*.wav'))
    for fpath in speaker_file_paths:
        # splitext strips only the final extension, so dotted names such as
        # "a.b.wav" keep their full stem and cannot collide on one output file.
        fname = os.path.splitext(os.path.basename(fpath))[0]
        wav, sr = load_wav(fpath)

        # Resample to the rate the mel config declares, not a duplicated literal.
        mel_wav = librosa.resample(wav, orig_sr=sr, target_sr=cfg.mel.sampling_rate)
        mel_wav = torch.FloatTensor(mel_wav).to(cfg.device).unsqueeze(0)
        mel = get_mel(mel_wav, cfg.mel)[0].cpu() # (num_mels, seqlen)
        mel = mel.T # (seqlen, num_mels)

        # The upstream extractor is fed 16 kHz audio.
        feat_wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
        feat_wav = torch.FloatTensor(feat_wav).to(cfg.device).unsqueeze(0)
        feat = extractor(feat_wav)['hidden_states'][0][0].cpu() # (seqlen, 256)

        save_path = f'{cfg.datadir}/{speaker}/{fname}.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({'mel': mel, 'feat': feat}, save_path)

### 2.2. Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Preprocessed utterances grouped by speaker.

    Expects ``<datadir>/<speaker>/*.pt`` files, each a dict with ``'mel'`` and
    ``'feat'`` tensors. An item is ``(src_feat, ref_feats, tgt_mel)``:
    features and mel of one utterance, plus the concatenated features of up to
    ``num_refs`` other utterances of the same speaker.
    """

    def __init__(self, datadir, num_refs):
        self.datadir = datadir
        self.num_refs = num_refs
        self.speaker_to_files, self.file_to_speaker, self.file_paths = self.prepare()

    def prepare(self):
        """Index all ``.pt`` files under ``datadir``.

        Returns:
            ``(speaker_to_files, file_to_speaker, file_paths)`` where files are
            identified by their position in ``file_paths``.
        """
        speaker_to_files, file_to_speaker, file_paths = defaultdict(list), {}, []
        # Sorted so that file indices are stable across runs.
        speaker_dirs = sorted(glob(f'{self.datadir}/*'))
        for sdir in tqdm(speaker_dirs):
            speaker = os.path.basename(sdir)
            speaker_file_paths = sorted(glob(f'{sdir}/*.pt'))
            for fpath in speaker_file_paths:
                fidx = len(file_paths)
                speaker_to_files[speaker].append(fidx)
                file_to_speaker[fidx] = speaker
                file_paths.append(fpath)
        return speaker_to_files, file_to_speaker, file_paths

    def __len__(self):
        return len(self.file_paths)

    def read_file(self, idx):
        """Load file ``idx`` and return ``(feat, mel)``."""
        fpath = self.file_paths[idx]
        data = torch.load(fpath)
        mel, feat = data['mel'], data['feat']
        return feat, mel

    def __getitem__(self, idx):
        src_feat, tgt_mel = self.read_file(idx)

        speaker = self.file_to_speaker[idx]
        candidates = [f for f in self.speaker_to_files[speaker] if f != idx]
        if not candidates:
            # Single-utterance speaker: the utterance is its own reference.
            candidates = [idx]
        # Sampling without replacement needs at least as many candidates as
        # draws; use fewer references instead of raising for small speakers.
        num_refs = min(self.num_refs, len(candidates))
        ref_files = np.random.choice(candidates, num_refs, replace=False)
        ref_feats = [self.read_file(f)[0] for f in ref_files]
        ref_feats = torch.cat(ref_feats, dim=0)

        return src_feat, ref_feats, tgt_mel
    
    
def collate_fn(batch):
    """Pad a list of ``(src_feat, ref_feats, tgt_mel)`` items into batch tensors.

    Returns:
        srcs: (batch, max_src_len, dim), zero-padded.
        src_masks: (batch, max_src_len) bool, True on padded positions.
        tgts: (batch, dim, max_ref_len), reference sequences padded with -20.
        tgt_masks: (batch, max_ref_len) bool, True on padded positions.
        tgt_mels: (batch, mel_dim, max_mel_len), padded with -20.
        overlap_lens: per item, ``min(src_len, mel_len)``.
    """
    src_seqs, ref_seqs, mel_seqs = zip(*batch)
    device = src_seqs[0].device

    def _padding_mask(lengths, max_len):
        # True wherever the position index lies beyond the sequence's own length.
        return torch.stack([torch.arange(max_len, device=device) >= n for n in lengths])

    src_lens = [len(seq) for seq in src_seqs]
    ref_lens = [len(seq) for seq in ref_seqs]
    mel_lens = [len(seq) for seq in mel_seqs]
    overlap_lens = [min(pair) for pair in zip(src_lens, mel_lens)]

    srcs = pad_sequence(src_seqs, batch_first=True)
    src_masks = _padding_mask(src_lens, srcs.size(1))

    # Padded as (batch, len, dim), then moved to channels-first.
    tgts = pad_sequence(ref_seqs, batch_first=True, padding_value=-20).transpose(1, 2)
    tgt_masks = _padding_mask(ref_lens, tgts.size(2))

    tgt_mels = pad_sequence(mel_seqs, batch_first=True, padding_value=-20).transpose(1, 2)

    return srcs, src_masks, tgts, tgt_masks, tgt_mels, overlap_lens

In [None]:
# Each item pairs one utterance with the features of 5 other utterances by the same speaker.
dataset = Dataset(cfg.datadir, num_refs=5)
# collate_fn pads the variable-length items and builds the padding masks.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
# Pull a single batch to sanity-check the collated outputs.
batch = next(iter(dataloader))

In [None]:
# Shapes of the batch tensors; the last element (overlap_lens) is a plain list, so it is skipped.
[b.shape for b in batch[:-1]]

## 3. Model

In [None]:
class Smoother(nn.Module):
    """Transformer-encoder-style layer whose feed-forward part is a 1-D convolution pair.

    Self-attention -> add & norm -> Conv1d(k=9) / ReLU / Conv1d(k=1) -> add & norm.
    Inputs and outputs are sequence-first: (seq_len, batch, d_model).
    """

    def __init__(self, d_model: int, nhead: int, d_hid: int, dropout=0.1):
        super(Smoother, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.conv1 = nn.Conv1d(d_model, d_hid, 9, padding=4)
        self.conv2 = nn.Conv1d(d_hid, d_model, 1, padding=0)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Apply the layer.

        Args:
            src: (seq_len, batch, d_model).
            src_mask: optional attention mask forwarded to the self-attention.
            src_key_padding_mask: optional (batch, seq_len) padding mask.
            is_causal: accepted only because recent ``nn.TransformerEncoder``
                versions pass this keyword to every layer; it is not used here.

        Returns:
            Tensor with the same shape as ``src``.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Conv1d wants (batch, channels, seq_len); switch layout and back.
        src2 = src.permute(1, 2, 0)
        src2 = self.conv2(F.relu(self.conv1(src2)))
        src2 = src2.permute(2, 0, 1)

        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    
class Extractor(nn.Module):
    """Convolutional Transformer Decoder Layer.

    Self-attention over the target, cross-attention from target to memory
    (optionally computed in a low-dimensional bottleneck space), then a
    convolutional feed-forward block. Tensors are sequence-first:
    ``tgt`` is (tgt_len, batch, d_model), ``memory`` is (mem_len, batch, d_model).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        d_hid: int,
        bottleneck_dim: int,
        dropout=0.1,
        no_residual=False,
        bottleneck=False,
    ):
        super(Extractor, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Without the bottleneck the cross-attention sees d_model-sized inputs,
        # so it must be built for d_model; it used to be sized bottleneck_dim
        # unconditionally and failed whenever the two dimensions differed.
        cross_dim = bottleneck_dim if bottleneck else d_model
        self.cross_attn = nn.MultiheadAttention(cross_dim, nhead, dropout=dropout)
        self.out_proj = nn.Linear(d_model, d_model)

        self.conv1 = nn.Conv1d(d_model, d_hid, 9, padding=4)
        self.conv2 = nn.Conv1d(d_hid, d_model, 1, padding=0)

        self.bottleneck = bottleneck
        # Both projections are always created, whether or not `bottleneck` is set.
        self.tgt_bottleneck = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, bottleneck_dim),
        )

        self.memory_bottleneck = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, bottleneck_dim),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.no_residual = no_residual

    def forward(
        self,
        tgt,
        memory,
        tgt_mask = None,
        memory_mask = None,
        tgt_key_padding_mask = None,
        memory_key_padding_mask = None,
    ):
        """Apply the layer.

        Returns:
            ``(tgt, attn)``: the updated target, shaped like the input ``tgt``,
            and the cross-attention weights returned by ``self.cross_attn``.
        """
        # multi-head self attention
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]

        # add & norm
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # bottleneck feature of target and references
        if self.bottleneck:
            tgt_compact = self.tgt_bottleneck(tgt)
            memory_compact = self.memory_bottleneck(memory)
        else:
            tgt_compact = tgt
            memory_compact = memory

        # multi-head cross attention
        tgt2, attn = self.cross_attn(
            tgt_compact,
            memory_compact,
            memory_compact,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )

        if self.bottleneck and attn is not None:
            # The bottleneck attention only supplies the weights: re-apply them
            # to the full-width memory, then project.
            memory_bf = memory.transpose(0, 1)                  # (batch, mem_len, d_model)
            tgt2 = torch.bmm(attn, memory_bf).transpose(0, 1)   # (tgt_len, batch, d_model)
            tgt2 = self.out_proj(tgt2)

        # add & norm
        if self.no_residual:
            tgt = self.dropout2(tgt2)
        else:
            tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # conv1d (channels-first layout, then back to sequence-first)
        tgt2 = tgt.permute(1, 2, 0)
        tgt2 = self.conv2(F.relu(self.conv1(tgt2)))
        tgt2 = tgt2.permute(2, 0, 1)

        # add & norm
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt, attn

In [None]:
class S2VC(nn.Module):
    """Voice-conversion model: UnetBlock -> 3 Smoother layers -> linear mel head -> residual post-net."""

    def __init__(self, input_dim, ref_dim, d_model=512):
        super().__init__()
        self.unet = UnetBlock(d_model, input_dim, ref_dim)

        self.smoothers = nn.TransformerEncoder(Smoother(d_model, 2, 1024), num_layers=3)

        self.mel_linear = nn.Linear(d_model, 80)

        # Post-net: four Conv/BatchNorm/Tanh/Dropout stages, then a final
        # Conv/BatchNorm/Dropout stage that maps back to 80 channels.
        post_layers = []
        in_channels = 80
        for _ in range(4):
            post_layers += [
                nn.Conv1d(in_channels, 512, kernel_size=5, padding=2),
                nn.BatchNorm1d(512),
                nn.Tanh(),
                nn.Dropout(0.5),
            ]
            in_channels = 512
        post_layers += [
            nn.Conv1d(512, 80, kernel_size=5, padding=2),
            nn.BatchNorm1d(80),
            nn.Dropout(0.5),
        ]
        self.post_net = nn.Sequential(*post_layers)

    def forward(
        self,
        srcs,
        refs,
        src_masks = None,
        ref_masks = None,
    ):
        """Forward function.
        Args:
            srcs: (batch, src_len, input_dim)
            src_masks: (batch, src_len)
            refs: (batch, ref_dim, ref_len)
            ref_masks: (batch, ref_len)
        Returns:
            (mel, attns): mel of shape (batch, 80, src_len) and the list of
            attention maps produced by the UnetBlock.
        """
        # (src_len, batch, d_model)
        hidden, attns = self.unet(srcs, refs, src_masks=src_masks, ref_masks=ref_masks)

        # (src_len, batch, d_model)
        hidden = self.smoothers(hidden, src_key_padding_mask=src_masks)

        # (src_len, batch, 80) -> (batch, 80, src_len)
        coarse = self.mel_linear(hidden).permute(1, 2, 0)

        # The post-net predicts a residual on top of the coarse mel.
        out = coarse + self.post_net(coarse)

        return out, attns



class SelfAttentionPooling(nn.Module):
    """Pool a sequence into one vector with learned softmax attention weights."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, batch_rep, att_mask = None):
        """Pool over the sequence axis.

        Args:
            batch_rep: (batch, seq_len, input_dim).
            att_mask: optional bool tensor (batch, seq_len); True marks
                positions that must receive no attention weight.

        Returns:
            (batch, input_dim) weighted sum of ``batch_rep`` over ``seq_len``.
        """
        att_logits = self.linear(batch_rep).squeeze(-1)
        if att_mask is not None:
            # Masked logits need a very large negative value so softmax gives
            # them ~0 weight. The previous fill value, 1e-20, is effectively 0
            # and left padded frames fully weighted.
            att_logits = att_logits.masked_fill(att_mask, torch.finfo(att_logits.dtype).min)
        att_w = F.softmax(att_logits, dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep

    
class SourceEncoder(nn.Module):
    """Encode source features with four Linear/ReLU/BatchNorm stages and scale them by a pooled reference embedding."""

    def __init__(self, d_model: int, input_dim: int):
        super(SourceEncoder, self).__init__()
        self.lin1 = nn.Linear(input_dim, input_dim)
        self.lin2 = nn.Linear(input_dim, d_model)
        self.lin3 = nn.Linear(d_model, d_model)
        self.lin4 = nn.Linear(d_model, d_model)

        self.bn1 = nn.BatchNorm1d(input_dim)
        self.bn2 = nn.BatchNorm1d(d_model)
        self.bn3 = nn.BatchNorm1d(d_model)
        self.bn4 = nn.BatchNorm1d(d_model)

        # Dropout probability is 0.0, so these layers currently pass values through.
        self.dropout1 = nn.Dropout(0.0)
        self.dropout2 = nn.Dropout(0.0)
        self.dropout3 = nn.Dropout(0.0)
        self.dropout4 = nn.Dropout(0.0)

        self.SAP = SelfAttentionPooling(d_model)
        self.proj = nn.Linear(d_model, d_model)
        linear_gain = torch.nn.init.calculate_gain('linear')
        torch.nn.init.xavier_uniform_(self.proj.weight, gain=linear_gain)

    def forward(
        self,
        srcs,
        refs,
        src_masks = None,
        ref_masks = None
    ):
        """Encode ``srcs`` and modulate the result with the reference embedding.

        Args:
            srcs: (batch, src_len, input_dim).
            refs: (batch, d_model, ref_len) reference representation.
            src_masks: accepted but not used by this module.
            ref_masks: optional mask forwarded to the attention pooling.

        Returns:
            (batch, src_len, d_model).
        """
        stages = (
            (self.lin1, self.bn1, self.dropout1),
            (self.lin2, self.bn2, self.dropout2),
            (self.lin3, self.bn3, self.dropout3),
            (self.lin4, self.bn4, self.dropout4),
        )
        tgt = srcs
        for linear, norm, drop in stages:
            # BatchNorm1d normalises over axis 1, hence the transposes around it.
            activated = F.relu(linear(tgt)).transpose(1, 2)
            tgt = drop(norm(activated)).transpose(1, 2)

        pooled = self.SAP(refs.transpose(1, 2), ref_masks)
        spk_embed = F.relu(self.proj(pooled)).unsqueeze(1)
        tgt *= spk_embed
        return tgt


class UnetBlock(nn.Module):
    """Encode the references with three convolutions and let the encoded source attend to them through one Extractor."""

    def __init__(self, d_model: int, input_dim: int, ref_dim: int):
        super(UnetBlock, self).__init__()
        conv_kwargs = dict(padding=1, padding_mode="replicate")
        self.conv1 = nn.Conv1d(ref_dim, d_model, 3, **conv_kwargs)
        self.conv2 = nn.Conv1d(d_model, d_model, 3, **conv_kwargs)
        self.conv3 = nn.Conv1d(d_model, d_model, 3, **conv_kwargs)
        self.extractor1 = Extractor(
            d_model=d_model,
            nhead=2,
            d_hid=1024,
            bottleneck_dim=4,
            no_residual=True,
            bottleneck=True,
        )
        self.src_encoder = SourceEncoder(d_model, input_dim)

    def forward(
        self,
        srcs,
        refs,
        src_masks = None,
        ref_masks = None,
    ):
        """Return ``(out, [attn])`` where ``out`` is sequence-first (src_len, batch, d_model).

        Args:
            srcs: (batch, src_len, input_dim).
            refs: (batch, ref_dim, ref_len).
            src_masks, ref_masks: optional padding masks passed on as key padding masks.
        """
        # conv1 -> ReLU -> conv2 -> ReLU -> conv3
        ref_hidden = self.conv1(refs)
        for conv in (self.conv2, self.conv3):
            ref_hidden = conv(F.relu(ref_hidden))

        # The source encoder works batch-first; the Extractor is sequence-first.
        tgt = self.src_encoder(srcs, ref_hidden, src_masks, ref_masks).transpose(0, 1)
        memory = ref_hidden.permute(2, 0, 1)  # (ref_len, batch, d_model)

        out, attn1 = self.extractor1(
            tgt,
            memory,
            tgt_key_padding_mask=src_masks,
            memory_key_padding_mask=ref_masks,
        )
        return out, [attn1]

In [None]:
# Source and reference inputs are both the 256-dim upstream features saved during preprocessing.
model = S2VC(input_dim=256, ref_dim=256)
# Uncomment to resume from a previously saved checkpoint.
# _ = model.load_state_dict(torch.load('s2vc.pt'))
_ = model.to(cfg.device)

In [None]:
# L1 distance between predicted and target mel frames.
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
# Cosine schedule with warmup, stepped once per training step.
scheduler = get_scheduler('cosine', optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=cfg.num_training_steps)

## 4. Train

In [None]:
def model_fn(batch, model, criterion, device):
    """Forward a batch through model and compute the training loss.

    Args:
        batch: output of ``collate_fn``:
            ``(srcs, src_masks, tgts, tgt_masks, tgt_mels, overlap_lens)``.
        model: callable returning ``(outs, attns)`` with ``outs`` shaped
            (batch, mel_dim, src_len) and ``attns`` a list of attention maps.
        criterion: loss applied per item to ``(prediction, target)``.
        device: device every tensor in ``batch`` is moved to.

    Returns:
        ``(loss, attns_plot)``: the mean per-item loss, and for each attention
        map the first item's map cropped to its overlap length on both axes.
    """
    # isinstance also covers Tensor subclasses; non-tensors (overlap_lens) pass through.
    batch = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]

    srcs, src_masks, tgts, tgt_masks, tgt_mels, overlap_lens = batch

    # The collated same-speaker utterances act as the references.
    refs = tgts
    ref_masks = tgt_masks

    outs, attns = model(srcs, refs, src_masks=src_masks, ref_masks=ref_masks)

    # Prediction and target can differ in length; compare only the shared prefix.
    # The loss does not depend on the attention maps, so they are not iterated here.
    losses = [
        criterion(out[:, :overlap_len], tgt_mel[:, :overlap_len])
        for out, tgt_mel, overlap_len in zip(outs.unbind(), tgt_mels.unbind(), overlap_lens)
    ]

    # Skip missing maps explicitly instead of hiding every error behind a bare except.
    first_len = overlap_lens[0]
    attns_plot = [attn[0][:first_len, :first_len] for attn in attns if attn is not None]

    return sum(losses) / len(losses), attns_plot

In [None]:
def _infinite_batches(loader):
    """Yield batches forever, restarting `loader` each time it is exhausted."""
    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            # An empty loader would otherwise spin here forever.
            raise RuntimeError('dataloader produced no batches')


# Training is step-based (cfg.num_training_steps), so the dataloader is cycled
# explicitly. This replaces a bare `except:` around next(), which also restarted
# the iterator on genuine data-loading errors and on KeyboardInterrupt.
model.train()
batches = _infinite_batches(dataloader)

pbar = tqdm(range(1, cfg.num_training_steps+1))
for st in pbar:
    batch = next(batches)

    loss, attns = model_fn(batch, model, criterion, cfg.device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    log = {'loss': loss.item()}
    pbar.set_postfix(log)

    # Overwrite the checkpoint every 1000 steps.
    if st % 1000 == 0:
        torch.save(model.state_dict(), 's2vc.pt')

## 5. Test

### 5.1. Vocoder

In [None]:
# Negative slope passed to F.leaky_relu by the vocoder blocks.
LRELU_SLOPE = 0.1

def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when the module's class name contains "Conv".

    Modules of any other class are left untouched. Intended for ``Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    """Return ``int(dilation * (kernel_size - 1) / 2)``, the padding that keeps the length of an odd-kernel, stride-1 convolution unchanged."""
    span = dilation * (kernel_size - 1)
    return int(span / 2)


class ResBlock1(nn.Module):
    """Residual block of three (dilated conv -> plain conv) pairs, each pair wrapped in a skip connection."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h

        def _conv(rate):
            # Length-preserving, weight-normalised convolution with the given dilation.
            return nn.utils.weight_norm(
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                          padding=get_padding(kernel_size, rate))
            )

        # First conv of each pair: dilations taken from `dilation`.
        self.convs1 = nn.ModuleList([_conv(dilation[i]) for i in range(3)])
        self.convs1.apply(init_weights)

        # Second conv of each pair: always dilation 1.
        self.convs2 = nn.ModuleList([_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from every convolution."""
        for conv in list(self.convs1) + list(self.convs2):
            nn.utils.remove_weight_norm(conv)


class ResBlock2(nn.Module):
    """Residual block of two dilated convolutions, each wrapped in its own skip connection."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h

        def _conv(rate):
            # Length-preserving, weight-normalised convolution with the given dilation.
            return nn.utils.weight_norm(
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                          padding=get_padding(kernel_size, rate))
            )

        self.convs = nn.ModuleList([_conv(dilation[i]) for i in range(2)])
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            branch = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from every convolution."""
        for conv in self.convs:
            nn.utils.remove_weight_norm(conv)


class Generator(nn.Module):
    """Vocoder generator: maps an 80-channel mel spectrogram to a waveform.

    ``conv_pre`` is followed by one stage per entry of ``h.upsample_rates``;
    each stage is a transposed convolution whose output is passed through
    ``len(h.resblock_kernel_sizes)`` residual blocks that are averaged.
    ``conv_post`` and ``tanh`` produce the single-channel output.
    """

    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        # The channel count halves at every upsampling stage.
        base = h.upsample_initial_channel
        self.ups = nn.ModuleList([
            nn.utils.weight_norm(
                nn.ConvTranspose1d(base // (2 ** i), base // (2 ** (i + 1)), k, u, padding=(k - u) // 2)
            )
            for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes))
        ])

        # Stored flat: stage i owns entries [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(h, ch, k, d))

        # `ch` is the channel count of the last stage.
        self.conv_post = nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = self.ups[i](F.leaky_relu(x, LRELU_SLOPE))
            stage_blocks = [self.resblocks[i * self.num_kernels + j] for j in range(self.num_kernels)]
            xs = None
            for block in stage_blocks:
                y = block(x)
                if xs is None:
                    xs = y
                else:
                    xs += y
            x = xs / self.num_kernels
        # This call passes no slope, so leaky_relu's own default applies here
        # rather than LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from every layer."""
        print('Removing weight norm...')
        for up in self.ups:
            nn.utils.remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        nn.utils.remove_weight_norm(self.conv_pre)
        nn.utils.remove_weight_norm(self.conv_post)

In [None]:
# Vocoder hyper-parameters. Generator reads resblock, upsample_rates,
# upsample_kernel_sizes, upsample_initial_channel, resblock_kernel_sizes and
# resblock_dilation_sizes; the synthesis cell reads sampling_rate.
h = easydict.EasyDict(
    resblock = "1",
    num_gpus = 0,
    batch_size = 16,
    learning_rate = 0.0002,
    adam_b1 = 0.8,
    adam_b2 = 0.99,
    lr_decay = 0.999,
    seed = 1234,

    upsample_rates = [8,8,2,2],
    upsample_kernel_sizes = [16,16,4,4],
    upsample_initial_channel = 512,
    resblock_kernel_sizes = [3,7,11],
    resblock_dilation_sizes = [[1,3,5], [1,3,5], [1,3,5]],

    segment_size = 8192,
    num_mels = 80,
    num_freq = 1025,
    n_fft = 1024,
    hop_size = 256,
    win_size = 1024,
    sampling_rate = 22050,
    fmin = 0,
    fmax = 8000,
    fmax_for_loss = None,

    num_workers = 4,

    dist_config = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    },
)

In [None]:
# Same upstream extractor as in the preprocessing section, reloaded so the test section can run by itself.
extractor = torch.hub.load("s3prl/s3prl", cfg.feature, refresh=True)
_ = extractor.requires_grad_(False).eval().to(cfg.device)

Using cache found in /root/.cache/torch/hub/s3prl_s3prl_main


In [None]:
# Build the vocoder and load its pretrained generator weights.
generator = Generator(h)
# map_location='cpu' lets the checkpoint load even when the device it was saved
# from is not available; the module is moved to cfg.device afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt = torch.load('g_02500000', map_location='cpu')['generator']
generator.load_state_dict(ckpt)
# The checkpoint holds weight-norm parameters, so the norm is removed only after loading.
generator.remove_weight_norm()
_ = generator.eval().requires_grad_(False).to(cfg.device)

Removing weight norm...


In [None]:
# Rebuild the conversion model and restore the weights saved by the training loop.
model = S2VC(256, 256)
# map_location='cpu' lets the checkpoint load even when the device it was saved
# from is not available; the module is moved to cfg.device afterwards.
ckpt = torch.load('s2vc.pt', map_location='cpu')
model.load_state_dict(ckpt)
_ = model.eval().requires_grad_(False).to(cfg.device)

In [None]:
# Candidate source utterances: the same corpus layout the preprocessing cell
# reads, taken from cfg.rawdir rather than a duplicated absolute path.
src_audios = glob(f'{cfg.rawdir}/*/audio/*.wav')

In [None]:
# Pick a random source utterance and extract its upstream features at 16 kHz.
src_audio = np.random.choice(src_audios)
src_wav, sr = load_wav(src_audio)
src_wav = librosa.resample(src_wav, orig_sr=sr, target_sr=16000)
src_wav = torch.FloatTensor(src_wav).to(cfg.device).unsqueeze(0)
src_feat = extractor(src_wav)['hidden_states'][0] # NOTE(review): only one [0] here vs. [0][0] in preprocessing, so this presumably keeps the batch axis: (1, seqlen, 256) — confirm

In [None]:
# Reference utterances of the target speaker.
tgt_audios = glob('/mnt/ljspeech/audio/*')
# Use at most 20 references, but never more than are available; sampling
# without replacement raises when asked for more items than exist.
tgt_audios = np.random.choice(tgt_audios, min(20, len(tgt_audios)), replace=False)

tgt_wavs = []
tgt_feats = []
for a in tgt_audios:
    wav, sr = load_wav(a)
    wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
    wav = torch.FloatTensor(wav).to(cfg.device)
    tgt_wavs.append(wav)
    # Extract one utterance at a time, exactly as the preprocessing cell does,
    # so the reference is a concatenation of real frames only. A single batched
    # call returns one tensor for utterances of different lengths, which
    # presumably includes padded frames.
    tgt_feats.append(extractor(wav.unsqueeze(0))['hidden_states'][0][0])

# (total_len, dim) -> (1, dim, total_len): the channels-first layout the model's `refs` argument takes.
tgt_feats = torch.cat(tgt_feats, dim=0)
tgt_feats = rearrange(tgt_feats, 's d -> 1 d s')

In [None]:
# Convert: source content + target-speaker references -> mel -> waveform.
mel, attn = model(src_feat, tgt_feats)
audio = generator(mel)

audio = audio.squeeze()
audio = audio * MAX_WAV_VALUE
# The generator ends in tanh, so a saturated sample scales to exactly 32768,
# which does not fit in int16. Clamp to the valid range before casting.
audio = audio.clamp(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1)
audio = audio.cpu().numpy().astype('int16')

sf.write('output.wav', audio, h.sampling_rate)

In [None]:
# Source utterance (provides the spoken content).
IPython.display.Audio(src_audio)

In [None]:
# First of the sampled reference utterances (target voice).
IPython.display.Audio(tgt_audios[0])

In [None]:
# Converted audio written by the synthesis cell.
IPython.display.Audio('output.wav')