<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/06-speech/FragmentVC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FragmentVC

## 0. Info

### Paper
* title: FragmentVC: Any-to-Any Voice Conversion by End-to-End Extracting and Fusing Fine-Grained Voice Fragments With Attention
* authors: Yist Y. Lin et al.
* url: https://arxiv.org/abs/2010.14150

### Feats
* train data: VCTK
* test data: LJSpeech

### Refs
* https://github.com/yistLin/FragmentVC

## 1. Setup

In [None]:
import os
import wandb
import IPython
import easydict
import numpy as np
import soundfile as sf
from glob import glob
from tqdm.auto import tqdm
from collections import defaultdict

import librosa
from scipy.signal import lfilter

import torch
from transformers import Wav2Vec2Model
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import get_scheduler

2022-10-23 00:22:23.518152: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-23 00:22:23.656886: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-10-23 00:22:23.694170: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-23 00:22:24.395572: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; 

In [None]:
# Experiment configuration shared by the cells below.
cfg = easydict.EasyDict(
    wav2vec_model_name = 'facebook/wav2vec2-base-960h',  # checkpoint name passed to Wav2Vec2Model.from_pretrained
    device = 'cuda',
    rawdir = '/mnt/vctk',  # raw audio root, globbed as <rawdir>/<speaker>/audio/*.wav
    featdir = 'vctk-prep',  # output root for the cached <speaker>/<utterance>.pt feature files
    
    num_refs = 10,  # reference utterances sampled per training item
    batch_size = 16,
    num_warmup_steps = 1000,
    num_training_steps = 250000,
    phases = [50000, 150000],  # phases[0]: step where references are switched on; phases[1]: step where self_exclude reaches 1.0
    
    lr = 1e-4,
)

## 2. Data

### 2.1. Preprocess

In [None]:
def load_wav(
    audio_path, 
    sample_rate=16000, 
    trim = True
):
    """Load an audio file, peak-normalise it and optionally trim silence.

    Args:
        audio_path: path handed to ``librosa.load``.
        sample_rate: sampling rate the audio is loaded at.
        trim: when true, cut leading/trailing silence (25 dB threshold) while
            keeping a 0.1 s margin on each side.

    Returns:
        The normalised waveform. The trimmed slice is only used when it is
        longer than 1000 samples; otherwise the untrimmed waveform is returned.
    """
    signal, _ = librosa.load(audio_path, sr=sample_rate)
    # The small constant avoids dividing by zero on an all-zero signal.
    peak = np.abs(signal).max() + 1e-6
    signal = signal / peak
    if not trim:
        return signal

    margin = 0.1 * sample_rate
    _, (lo, hi) = librosa.effects.trim(signal, top_db=25, frame_length=512, hop_length=128)
    lo = int(max(0, lo - margin))
    hi = int(min(len(signal), hi + margin))
    if hi - lo > 1000:  # guard against an (almost) empty slice
        return signal[lo:hi]
    return signal


def log_mel_spectrogram(
    x: np.ndarray,
    preemph: float = 0.97,
    sample_rate: int = 16000,
    n_mels: int = 80,
    n_fft: int = 1304,
    hop_length: int = 326,
    win_length: int = 1304,
    f_min: int = 80,
) -> np.ndarray:
    """Compute a log-mel spectrogram of a waveform.

    The signal is pre-emphasised with the FIR filter ``[1, -preemph]``, turned
    into an STFT magnitude, projected onto a mel filterbank and log-compressed
    (with a ``1e-9`` floor).

    Returns:
        The transposed log-mel matrix, i.e. one row per STFT frame and one
        column per mel band.
    """
    emphasized = lfilter([1, -preemph], [1], x)
    spectrum = librosa.stft(emphasized, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    filterbank = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=f_min)
    mel_energy = np.dot(filterbank, np.abs(spectrum))
    return np.log(mel_energy + 1e-9).T

In [None]:
# Load the pretrained wav2vec 2.0 encoder, put it in eval mode, disable
# gradients on its parameters and move it to the configured device.
wav2vec = Wav2Vec2Model.from_pretrained(cfg.wav2vec_model_name)
_ = wav2vec.eval().requires_grad_(False).to(cfg.device)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Feature caching: for every <rawdir>/<speaker>/audio/*.wav file, compute the
# log-mel spectrogram and the wav2vec last hidden state, and save both as a
# dict to <featdir>/<speaker>/<audio_name>.pt.
speaker_dirs = glob(f'{cfg.rawdir}/*')
for sdir in tqdm(speaker_dirs):
    speaker = sdir.split('/')[-1]
    speaker_audio_paths = glob(f'{sdir}/audio/*.wav')
    for apath in speaker_audio_paths:
        # File name up to the first dot.
        audio_name = apath.split('/')[-1].split('.')[0]
        
        wav = load_wav(apath)
        mel = log_mel_spectrogram(wav)
        mel = torch.FloatTensor(mel)
        # unsqueeze(0) adds the batch axis; [0] removes it again before moving to CPU.
        feat = wav2vec(torch.FloatTensor(wav).to(cfg.device).unsqueeze(0)).last_hidden_state[0].cpu()
        
        save_path = f'{cfg.featdir}/{speaker}/{audio_name}.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({'mel': mel, 'feat': feat}, save_path)

### 2.2. Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Utterance dataset over cached ``<datadir>/<speaker>/<utterance>.pt`` files.

    Every file is expected to hold a dict with a ``'mel'`` and a ``'feat'``
    tensor. An item pairs one utterance's ``feat``/``mel`` with the mels of
    other utterances from the same speaker, concatenated along the first axis.

    Args:
        datadir: root directory containing one sub-directory per speaker.
        num_refs: maximum number of reference utterances sampled per item.
    """

    def __init__(self, datadir, num_refs):
        self.datadir = datadir
        self.num_refs = num_refs
        self.speaker_to_files, self.file_to_speaker, self.file_paths = self.prepare()

    def prepare(self):
        """Index all feature files.

        Returns:
            ``(speaker_to_files, file_to_speaker, file_paths)`` where files are
            identified by their position in ``file_paths``.
        """
        speaker_to_files, file_to_speaker, file_paths = defaultdict(list), {}, []
        # glob order is filesystem-dependent; sorting keeps file indices (and
        # therefore any seeded split of the dataset) reproducible.
        speaker_dirs = sorted(glob(f'{self.datadir}/*'))
        for sdir in tqdm(speaker_dirs):
            speaker = os.path.basename(sdir)
            for fpath in sorted(glob(f'{sdir}/*.pt')):
                fidx = len(file_paths)
                speaker_to_files[speaker].append(fidx)
                file_to_speaker[fidx] = speaker
                file_paths.append(fpath)
        return speaker_to_files, file_to_speaker, file_paths

    def __len__(self):
        return len(self.file_paths)

    def read_file(self, idx):
        """Load file ``idx`` and return its ``(feat, mel)`` pair."""
        data = torch.load(self.file_paths[idx])
        return data['feat'], data['mel']

    def __getitem__(self, idx):
        """Return ``(content_feat, reference_mels, target_mel)`` for utterance ``idx``.

        References are drawn without replacement from the speaker's other
        utterances. If fewer than ``num_refs`` are available, all of them are
        used; if the speaker has no other utterance, the utterance itself
        serves as the only reference instead of raising.
        """
        cnt_feat, target_mel = self.read_file(idx)

        speaker = self.file_to_speaker[idx]
        candidates = [f for f in self.speaker_to_files[speaker] if f != idx]
        if candidates:
            # np.random.choice raises when asked for more samples than exist.
            num_samples = min(self.num_refs, len(candidates))
            ref_files = np.random.choice(candidates, num_samples, replace=False)
        else:
            ref_files = [idx]
        ref_mels = torch.cat([self.read_file(f)[1] for f in ref_files], dim=0)

        return cnt_feat, ref_mels, target_mel
    
    
def collate_fn(batch):
    """Pad a list of ``(src, ref, tgt)`` items into batched tensors.

    Sources are zero-padded; reference and target mels are padded with ``-20``
    and transposed so that the padded (length) axis comes last. Each mask is
    boolean with ``True`` on padded positions.

    Returns:
        ``(srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens)``
        where ``overlap_lens[i]`` is ``min(len(src_i), len(tgt_i))``.
    """
    def padding_mask(lengths, max_len):
        # (1, max_len) positions compared against (batch, 1) lengths.
        positions = torch.arange(max_len).unsqueeze(0)
        return positions >= torch.tensor(lengths).unsqueeze(1)

    srcs, refs, tgts = zip(*batch)

    src_lens, ref_lens, tgt_lens = (
        [len(seq) for seq in group] for group in (srcs, refs, tgts)
    )
    overlap_lens = [min(pair) for pair in zip(src_lens, tgt_lens)]

    srcs = pad_sequence(srcs, batch_first=True)  # (batch, max_src_len, wav2vec_dim)
    refs = pad_sequence(refs, batch_first=True, padding_value=-20).transpose(1, 2)  # (batch, mel_dim, max_ref_len)
    tgts = pad_sequence(tgts, batch_first=True, padding_value=-20).transpose(1, 2)  # (batch, mel_dim, max_tgt_len)

    src_masks = padding_mask(src_lens, srcs.size(1))  # (batch, max_src_len)
    ref_masks = padding_mask(ref_lens, refs.size(2))  # (batch, max_ref_len)
    tgt_masks = padding_mask(tgt_lens, tgts.size(2))  # (batch, max_tgt_len)

    return srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens

In [None]:
# Build the dataset over the cached features and split it 90% / 10% into
# train and eval subsets (random split).
dataset = Dataset(
    datadir = cfg.featdir,
    num_refs = cfg.num_refs,
)

train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size])

In [None]:
# Both loaders pad with collate_fn and drop the last incomplete batch; only the
# training loader shuffles.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=cfg.batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)

In [None]:
# Pull a single training batch (useful for inspecting shapes interactively).
srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens = next(iter(train_loader))

## 3. Model

In [None]:
class Smoother(nn.Module):
    """Encoder layer: self-attention followed by a convolutional feed-forward.

    Both sub-layers are applied post-norm (residual add, then ``LayerNorm``).
    Inputs and outputs are sequence-first: ``(seq_len, batch, d_model)``.

    Args:
        d_model: feature size of the input and output.
        nhead: number of attention heads.
        d_hid: hidden channels of the convolutional feed-forward.
        dropout: dropout probability for attention and both residual branches.
    """

    def __init__(
        self, 
        d_model, 
        nhead, 
        d_hid, 
        dropout=0.1
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.conv1 = nn.Conv1d(d_model, d_hid, 9, padding=4)
        self.conv2 = nn.Conv1d(d_hid, d_model, 1, padding=0)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self, 
        src, 
        src_mask=None, 
        src_key_padding_mask = None,
        is_causal=False,
    ):
        """Apply the layer.

        Args:
            src: ``(seq_len, batch, d_model)`` input.
            src_mask: optional attention mask forwarded as ``attn_mask``.
            src_key_padding_mask: optional ``(batch, seq_len)`` padding mask.
            is_causal: accepted only for compatibility with
                ``nn.TransformerEncoder``, which passes this keyword to every
                layer in recent PyTorch releases. It is ignored; causality, if
                wanted, must be expressed through ``src_mask``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Conv1d expects (batch, channels, length).
        src2 = src.transpose(0, 1).transpose(1, 2)
        src2 = self.conv2(F.relu(self.conv1(src2)))
        src2 = src2.transpose(1, 2).transpose(0, 1)

        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class Extractor(nn.Module):
    """Decoder-style layer: self-attention, cross-attention, conv feed-forward.

    All tensors are sequence-first. Each sub-layer is followed by a
    ``LayerNorm``; with ``no_residual=True`` the cross-attention output
    replaces the running state instead of being added to it.

    Args:
        d_model: feature size of queries and memory.
        nhead: number of attention heads.
        d_hid: hidden channels of the convolutional feed-forward.
        dropout: dropout probability.
        no_residual: drop the residual connection around cross-attention.
    """

    def __init__(
        self, 
        d_model, 
        nhead, 
        d_hid, 
        dropout=0.1, 
        no_residual=False
    ):
        super().__init__()
        self.no_residual = no_residual

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.conv1 = nn.Conv1d(d_model, d_hid, kernel_size=9, padding=4)
        self.conv2 = nn.Conv1d(d_hid, d_model, kernel_size=1, padding=0)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask = None,
        memory_mask = None,
        tgt_key_padding_mask = None,
        memory_key_padding_mask = None,
    ):
        """Run the layer.

        Args:
            tgt: ``(tgt_len, batch, d_model)`` query sequence.
            memory: ``(mem_len, batch, d_model)`` sequence attended to.
            tgt_mask / memory_mask: optional attention masks.
            tgt_key_padding_mask / memory_key_padding_mask: optional padding masks.

        Returns:
            ``(output, attn)``: the transformed ``tgt`` and the cross-attention
            weights returned by ``nn.MultiheadAttention``.
        """
        # 1) self-attention
        attended, _ = self.self_attn(
            tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = self.norm1(tgt + self.dropout1(attended))

        # 2) cross-attention over the memory
        fused, attn = self.cross_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        fused = self.dropout2(fused)
        if not self.no_residual:
            fused = tgt + fused
        tgt = self.norm2(fused)

        # 3) convolutional feed-forward; Conv1d wants (batch, channels, length)
        channels_first = tgt.permute(1, 2, 0)
        ffn = self.conv2(F.relu(self.conv1(channels_first)))
        ffn = ffn.permute(2, 0, 1)
        tgt = self.norm3(tgt + self.dropout3(ffn))

        return tgt, attn

In [None]:
class UnetBlock(nn.Module):
    """U-Net style fusion of content features with reference mel features.

    The 80-channel reference mels go through three stacked convolutions; the
    768-dim content features go through a prenet and are then refined by three
    ``Extractor`` layers that attend to the conv outputs from deepest to
    shallowest (conv3, conv2, conv1).
    """

    def __init__(self, d_model: int):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode="replicate")
        self.conv1 = nn.Conv1d(80, d_model, **conv_kwargs)
        self.conv2 = nn.Conv1d(d_model, d_model, **conv_kwargs)
        self.conv3 = nn.Conv1d(d_model, d_model, **conv_kwargs)

        self.prenet = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, d_model),
        )

        # Only the first extractor drops the residual around cross-attention.
        self.extractor1 = Extractor(d_model, nhead=2, d_hid=1024, no_residual=True)
        self.extractor2 = Extractor(d_model, nhead=2, d_hid=1024)
        self.extractor3 = Extractor(d_model, nhead=2, d_hid=1024)

    def forward(
        self,
        srcs,
        refs,
        src_masks = None,
        ref_masks = None,
    ):
        """Forward function.
        Args:
            srcs: (batch, src_len, 768)
            src_masks: (batch, src_len)
            refs: (batch, 80, ref_len)
            ref_masks: (batch, ref_len)
        Returns:
            out: (src_len, batch, d_model)
            attns: list with the three cross-attention maps, in extractor order
        """
        # (batch, src_len, d_model) -> (src_len, batch, d_model)
        out = self.prenet(srcs).transpose(0, 1)

        # Each conv output keeps the (batch, d_model, ref_len) layout.
        ref1 = self.conv1(refs)
        ref2 = self.conv2(F.relu(ref1))
        ref3 = self.conv3(F.relu(ref2))

        attns = []
        stages = (
            (self.extractor1, ref3),
            (self.extractor2, ref2),
            (self.extractor3, ref1),
        )
        for extractor, ref in stages:
            memory = ref.permute(2, 0, 1)  # (ref_len, batch, d_model)
            out, attn = extractor(
                out,
                memory,
                tgt_key_padding_mask=src_masks,
                memory_key_padding_mask=ref_masks,
            )
            attns.append(attn)

        return out, attns


class FragmentVC(nn.Module):
    """Full model: U-Net fusion, smoother stack, mel projection and post-net.

    The post-net is a five-stage 1-D conv stack (80 -> 512 -> 512 -> 512 -> 512
    -> 80) whose output is added to the projected mel as a residual refinement.
    """

    def __init__(self, d_model=512):
        super().__init__()
        self.unet = UnetBlock(d_model)
        self.smoothers = nn.TransformerEncoder(Smoother(d_model, 2, 1024), num_layers=3)
        self.mel_linear = nn.Linear(d_model, 80)

        channels = [80, 512, 512, 512, 512, 80]
        final_stage = len(channels) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            layers.append(nn.Conv1d(c_in, c_out, kernel_size=5, padding=2))
            layers.append(nn.BatchNorm1d(c_out))
            if stage != final_stage:  # the last stage has no non-linearity
                layers.append(nn.Tanh())
            layers.append(nn.Dropout(0.5))
        self.post_net = nn.Sequential(*layers)

    def forward(
        self,
        srcs,
        refs,
        src_masks = None,
        ref_masks = None,
    ):
        """Forward function.
        Args:
            srcs: (batch, src_len, 768)
            src_masks: (batch, src_len)
            refs: (batch, 80, ref_len)
            ref_masks: (batch, ref_len)
        Returns:
            out: (batch, 80, src_len) predicted mel
            attns: attention maps returned by the U-Net block
        """
        hidden, attns = self.unet(srcs, refs, src_masks=src_masks, ref_masks=ref_masks)  # (src_len, batch, d_model)
        hidden = self.smoothers(hidden, src_key_padding_mask=src_masks)  # (src_len, batch, d_model)
        coarse = self.mel_linear(hidden).permute(1, 2, 0)  # (batch, 80, src_len)
        return coarse + self.post_net(coarse), attns

In [None]:
# Instantiate the model with its default d_model and move it to the device.
model = FragmentVC().to(cfg.device)

In [None]:
# Restore weights from 'fragmentvc.pt' (the file the training loop saves to).
model.load_state_dict(torch.load('fragmentvc.pt'))

In [None]:
# L1 loss between predicted and target mels; a single AdamW group over all
# parameters (the training loop later replaces this optimizer at phases[0]).
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

In [None]:
# Cosine learning-rate schedule with warmup over the full training length.
scheduler = get_scheduler('cosine', optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=cfg.num_training_steps)

## 4. Train

In [None]:
def model_fn(batch, model, criterion, ref_included, self_exclude, device):
    """Run one batch through the model and return the mean per-utterance loss.

    Reference selection:
      * ``ref_included`` false: the target mel itself is the only reference.
      * ``ref_included`` true: with probability ``1 - self_exclude`` the target
        mel is appended to the other-utterance references; otherwise the target
        is excluded and only the other utterances are used.

    Args:
        batch: output of ``collate_fn``:
            ``(srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens)``.
        model: callable as ``model(srcs, refs, src_masks=..., ref_masks=...)``
            returning ``(outs, attns)``.
        criterion: loss applied to ``(prediction, target)`` slices.
        ref_included: whether other-utterance references are used at all.
        self_exclude: probability in ``[0, 1]`` of dropping the target from the
            references when ``ref_included`` is true.
        device: device the tensors of ``batch`` are moved to.

    Returns:
        Scalar loss averaged over the utterances of the batch.
    """
    batch = [b.to(device) if isinstance(b, torch.Tensor) else b for b in batch]
    srcs, src_masks, refs, ref_masks, tgts, tgt_masks, overlap_lens = batch

    if not ref_included:
        refs, ref_masks = tgts, tgt_masks
    elif np.random.rand() >= self_exclude:
        # Length is the last axis of both the mels and the masks.
        refs = torch.cat((refs, tgts), dim=-1)
        ref_masks = torch.cat((ref_masks, tgt_masks), dim=-1)
    # Otherwise the target is excluded: keep refs / ref_masks (other utterances only).

    outs, _ = model(srcs, refs, src_masks=src_masks, ref_masks=ref_masks)

    # Compare only the frames present in both prediction and target.
    losses = [
        criterion(out[:, :overlap_len], tgt[:, :overlap_len])
        for out, tgt, overlap_len in zip(outs.unbind(), tgts.unbind(), overlap_lens)
    ]
    return sum(losses) / len(losses)

In [None]:
# Start a Weights & Biases run; the training loop logs the loss to it.
wandb.init(project='voice-conversion')

In [None]:
# Training loop. The step counter starts at 50000 (== cfg.phases[0]):
#   * at st == phases[0] references are enabled and the optimizer/scheduler are
#     rebuilt with a much smaller learning rate for the U-Net block;
#   * after phases[0], self_exclude ramps linearly up to 1.0 at phases[1].
self_exclude = 0.0
ref_included = False

train_iter = iter(train_loader)
pbar = tqdm(range(50000, cfg.num_training_steps+1))
for st in pbar:
    try:
        batch = next(train_iter)
    except StopIteration:
        # Loader exhausted: start a new epoch. Any other error (dataset bugs,
        # KeyboardInterrupt, ...) is deliberately left to propagate.
        train_iter = iter(train_loader)
        batch = next(train_iter)
    
    loss = model_fn(batch, model, criterion, ref_included, self_exclude, cfg.device)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    
    log = {'loss': loss.item()}
    pbar.set_postfix(log)
    wandb.log(log)

    if st == cfg.phases[0]:
        ref_included = True
        optimizer = torch.optim.AdamW(
            [
                {"params": model.unet.parameters(), "lr": 1e-6},
                {"params": model.smoothers.parameters()},
                {"params": model.mel_linear.parameters()},
                {"params": model.post_net.parameters()},
            ],
            lr=cfg.lr,
        )
        
        # Fresh cosine schedule covering only the remaining steps.
        scheduler = get_scheduler('cosine', optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=cfg.num_training_steps-cfg.phases[0])
        
        
    if st > cfg.phases[0]:
        self_exclude = min(1., (st - cfg.phases[0]) / (cfg.phases[1] - cfg.phases[0]))
        
    
    if st % 1000 == 0:
        torch.save(model.state_dict(), 'fragmentvc.pt')

## 5. Test

In [None]:
# Reload the frozen wav2vec 2.0 encoder for inference (eval mode, no gradients).
wav2vec = Wav2Vec2Model.from_pretrained(cfg.wav2vec_model_name)
_ = wav2vec.eval().requires_grad_(False).to(cfg.device)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Build the model in eval mode with gradients disabled, then load the trained
# weights from 'fragmentvc.pt'.
model = FragmentVC().eval().requires_grad_(False).to(cfg.device)
model.load_state_dict(torch.load('fragmentvc.pt'))

<All keys matched successfully>

In [None]:
# Load the TorchScript vocoder from 'vocoder.pt' and move it to the device.
vocoder = torch.jit.load("vocoder.pt")
_ = vocoder.eval().to(cfg.device)

In [None]:
# Candidate source utterances: every VCTK wav file.
src_audios = glob('/mnt/vctk/*/audio/*.wav')

In [None]:
# Pick one random source utterance and extract its wav2vec features.
src_audio = np.random.choice(src_audios)
src_wav = load_wav(src_audio)
# [0] drops the batch axis of the encoder output ...
src_feat = wav2vec(torch.FloatTensor(src_wav).to(cfg.device).unsqueeze(0)).last_hidden_state[0]
# ... and unsqueeze(0) adds it back, giving the batched input the model expects.
src_feat = src_feat.unsqueeze(0).to(cfg.device)

In [None]:
# Sample 20 distinct LJSpeech files to serve as target-speaker references.
tgt_audios = glob('/mnt/ljspeech/audio/*')
tgt_audios = np.random.choice(tgt_audios, 20, replace=False)

In [None]:
# Compute a log-mel per reference, concatenate them along the first axis, then
# add a batch axis and swap the last two axes to match the model's `refs` input.
tgt_wavs = [load_wav(a) for a in tgt_audios]
tgt_mels = [log_mel_spectrogram(w) for w in tgt_wavs]
tgt_mels = torch.cat([torch.FloatTensor(m) for m in tgt_mels], dim=0)
tgt_mels = tgt_mels.unsqueeze(0).transpose(2,1).to(cfg.device)

In [None]:
# Convert: predict the mel of the source content from the target references.
with torch.no_grad():
    out_mels, _ = model(src_feat, tgt_mels)
# Swap the last two axes and drop the batch axis.
out_mels = out_mels.transpose(1, 2).squeeze(0)

In [None]:
# Synthesize a waveform from the predicted mel and write it as 16 kHz audio.
# NOTE(review): `generate` is assumed to take a list of mels and return one
# waveform per entry - confirm against the vocoder's TorchScript interface.
with torch.no_grad():
    wavs = vocoder.generate([out_mels])[0]
wavs =  wavs.cpu().numpy()
sf.write('output.wav', wavs, 16000)

  


In [None]:
# Listen to the source utterance.
IPython.display.Audio(src_audio)

In [None]:
# Listen to the first of the sampled reference utterances.
IPython.display.Audio(tgt_audios[0])

In [None]:
# Listen to the converted output.
IPython.display.Audio('output.wav')