In [None]:
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2

--2021-12-16 10:05:57--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 174.138.79.61
Connecting to data.keithito.com (data.keithito.com)|174.138.79.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [application/octet-stream]
Saving to: ‘LJSpeech-1.1.tar.bz2’


2021-12-16 10:06:55 (46.0 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]



In [None]:
!pip install librosa



In [None]:
!pip install torch==1.10.0+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [None]:
import torch
import librosa
import torchaudio
from torch import nn
from dataclasses import dataclass
from typing import Tuple, Optional, List, Union
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt

# Model + utils

In [None]:
@dataclass
class Batch:
    """One collated batch of samples, as assembled by `LJSpeechCollator`.

    Note that the field name `waveforn_length` (sic) is part of the public interface.
    """

    # Padded waveforms stacked along a leading batch axis.
    waveform: torch.Tensor
    # Per-sample waveform lengths (number of samples before padding).
    waveforn_length: torch.Tensor
    # Transcripts, one per sample.
    transcript: List[str]
    # Padded token ids stacked along a leading batch axis.
    tokens: torch.Tensor
    # Per-sample token counts as reported by the tokenizer.
    token_lengths: torch.Tensor
    # Optional extra tensor; nothing in this notebook fills it in.
    durations: Optional[torch.Tensor] = None

    def to(self, device: torch.device) -> 'Batch':
        """Move `waveform` and `tokens` to `device`, mutating and returning this batch.

        The length tensors and `durations` are left where they are.
        """
        for field_name in ("waveform", "tokens"):
            moved = getattr(self, field_name).to(device)
            setattr(self, field_name, moved)
        return self


class LJSpeechCollator:
    """Collate callable that merges dataset items into a single padded `Batch`."""

    def __call__(self, instances: List[Tuple]) -> Batch:
        """
        :param instances: items shaped like `LJSpeechDataset.__getitem__` output, i.e.
            `(waveform, waveforn_length, transcript, tokens, token_lengths)` tuples
        :return: `Batch` with padded, batch-first `waveform` / `tokens` and concatenated lengths
        """
        # Regroup the per-item tuples into one tuple per field.
        waveforms, wave_lens, transcript, token_seqs, token_lens = zip(*instances)

        # Index [0] drops each item's leading axis; `pad_sequence` pads to the longest
        # entry and the transpose puts the batch axis first.
        padded_waveform = pad_sequence([wav[0] for wav in waveforms]).transpose(0, 1)
        all_wave_lens = torch.cat(wave_lens)

        padded_tokens = pad_sequence([seq[0] for seq in token_seqs]).transpose(0, 1)
        all_token_lens = torch.cat(token_lens)

        return Batch(
            waveform=padded_waveform,
            waveforn_length=all_wave_lens,
            transcript=transcript,
            tokens=padded_tokens,
            token_lengths=all_token_lens,
        )

In [None]:
class LJSpeechDataset(torchaudio.datasets.LJSPEECH):
    """`torchaudio` LJSpeech dataset that additionally tokenizes every transcript."""

    def __init__(self, root):
        """
        :param root: dataset root directory, forwarded to the `torchaudio` base class
        """
        super().__init__(root=root)
        # Text processor shipped with torchaudio's character-level Tacotron2 LJSpeech bundle.
        bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
        self._tokenizer = bundle.get_text_processor()

    def __getitem__(self, index: int):
        """
        :return: `(waveform, waveforn_length, transcript, tokens, token_lengths)` where
            `transcript` is the fourth element of the base-class item
        """
        # The second and third elements of the base-class item are not used here.
        waveform, _unused_1, _unused_2, transcript = super().__getitem__(index)

        num_samples = waveform.shape[-1]
        waveforn_length = torch.tensor([num_samples]).int()

        tokens, token_lengths = self._tokenizer(transcript)

        return waveform, waveforn_length, transcript, tokens, token_lengths

    def decode(self, tokens, lengths):
        """Turn token id sequences back into strings.

        Only the first `length` ids of each sequence are decoded, which skips padding.

        :return: list with one decoded string per sequence
        """
        return [
            "".join(self._tokenizer.tokens[token_id] for token_id in sequence[:length])
            for sequence, length in zip(tokens, lengths)
        ]

In [None]:
@dataclass
class MelSpectrogramConfig:
    """Default settings read by `MelSpectrogram` when building its transform and mel basis."""

    sr: int = 22_050
    win_length: int = 1_024
    hop_length: int = 256
    n_fft: int = 1_024
    f_min: int = 0
    f_max: int = 8_000
    n_mels: int = 80
    power: float = 1.0

    # Log-mel value obtained for silent input: `MelSpectrogram.forward` clamps at 1e-5
    # before taking the log, and log(1e-5) is approximately this number.
    pad_value: float = -11.5129251


class MelSpectrogram(nn.Module):
    """Log-mel featurizer built on `torchaudio.transforms.MelSpectrogram` with a `librosa` mel basis."""

    def __init__(self, config: MelSpectrogramConfig):
        """
        :param config: spectrogram settings; also kept on the module as `self.config`
        """
        super().__init__()
        self.config = config

        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sr,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_fft=config.n_fft,
            f_min=config.f_min,
            f_max=config.f_max,
            n_mels=config.n_mels
        )

        # There is no way to set power in constructor in 0.5.0 version.
        transform.spectrogram.power = config.power

        # Default `torchaudio` mel basis uses HTK formula. In order to be compatible with WaveGlow
        # we decided to use Slaney one instead (as well as `librosa` does by default).
        librosa_basis = librosa.filters.mel(
            sr=config.sr,
            n_fft=config.n_fft,
            n_mels=config.n_mels,
            fmin=config.f_min,
            fmax=config.f_max
        )
        # Transposed to match the layout of torchaudio's filterbank buffer, then copied in place.
        transform.mel_scale.fb.copy_(torch.tensor(librosa_basis.T))

        self.mel_spectrogram = transform

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        :param audio: Expected shape is [B, T]
        :return: Shape is [B, n_mels, T']
        """
        mel = self.mel_spectrogram(audio)
        # Floor at 1e-5 before the log so silence maps to a finite value; both ops are in place.
        mel.clamp_(min=1e-5).log_()
        return mel

In [None]:
class ResBlock(nn.Module):
    """Sequence of residual branches, each a stack of LeakyReLU -> dilated Conv1d pairs.

    `self.layers[i]` is an `nn.Sequential` holding one (LeakyReLU, Conv1d) pair per entry
    of `dilations[i]`.
    """

    def __init__(self, hidden_size, kernel_size, dilations=[[1, 1], [3, 1], [5, 1]]):
        """
        :param hidden_size: channel count of every convolution (input and output)
        :param kernel_size: kernel size of every convolution
        :param dilations: one list of dilation rates per residual branch
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for branch_dilations in dilations:
            stages = []
            for rate in branch_dilations:
                # With an odd kernel this padding keeps the time length unchanged.
                same_pad = (rate * (kernel_size - 1)) // 2
                stages += [
                    nn.LeakyReLU(0.1),
                    nn.Conv1d(hidden_size, hidden_size, kernel_size, dilation=rate, padding=same_pad),
                ]
            self.layers.append(nn.Sequential(*stages))

    def forward(self, input):
        """Apply the branches in order, adding each branch's output to its own input."""
        out = input
        for branch in self.layers:
            out = out + branch(out)
        return out


class MRF(nn.Module):
    """Averages the outputs of several parallel `ResBlock`s that differ in kernel size.

    Every branch receives the same input. The branch count and the divisor of the average
    both follow `kernel_sizes`, instead of being fixed at three.
    """

    def __init__(self, hidden_size, kernel_sizes=(3, 7, 11)):
        """
        :param hidden_size: channel count passed to every `ResBlock`
        :param kernel_sizes: kernel size of each parallel branch; must not be empty
        :raises ValueError: if `kernel_sizes` is empty
        """
        super().__init__()
        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        self.layers = nn.ModuleList([ResBlock(hidden_size, k) for k in kernel_sizes])

    def forward(self, input):
        """Return the mean of all branch outputs for `input`."""
        out = self.layers[0](input)
        for branch in self.layers[1:]:
            out = out + branch(input)
        return out / len(self.layers)


class Generator(nn.Module):
    """Upsampling generator: input conv -> N x (LeakyReLU, ConvTranspose1d, MRF) -> output conv + tanh.

    The number of upsampling stages follows `len(kernel_sizes)` and the channel count is
    halved at every stage. A stage with kernel size `k` uses stride `k // 2` and padding
    `k // 4`, so for `k` divisible by 4 it multiplies the time length by `k // 2`
    (the defaults give 8 * 8 * 2 * 2 = 256).
    """

    def __init__(self, in_ch, kernel_sizes=(16, 16, 4, 4), hidden_size=512):
        """
        :param in_ch: number of input channels
        :param kernel_sizes: transposed-convolution kernel size of each upsampling stage
        :param hidden_size: channel count after the input convolution
        :raises ValueError: if `kernel_sizes` is empty or `hidden_size` cannot be halved
            once per stage
        """
        super().__init__()
        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")

        n_stages = len(kernel_sizes)
        last_ch = hidden_size // (2 ** n_stages)
        if last_ch < 1:
            raise ValueError(
                f"hidden_size={hidden_size} is too small for {n_stages} halving stages"
            )

        self.input_conv = nn.Conv1d(in_ch, hidden_size, 7, padding=3)

        stages = []
        for i, k in enumerate(kernel_sizes):
            ch_in = hidden_size // (2 ** i)
            ch_out = hidden_size // (2 ** (i + 1))
            stages.append(nn.Sequential(
                nn.LeakyReLU(0.1),
                nn.ConvTranspose1d(ch_in, ch_out, k, stride=k // 2, padding=k // 4),
                MRF(ch_out),
            ))
        self.layers = nn.Sequential(*stages)

        self.out_conv = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv1d(last_ch, 1, 7, padding=3),
            nn.Tanh()
        )

    def forward(self, input):
        """
        :param input: tensor of shape [B, in_ch, T]
        :return: tensor of shape [B, 1, T'] with values in [-1, 1] (tanh output)
        """
        out = self.input_conv(input)
        out = self.layers(out)
        out = self.out_conv(out)
        return out

In [None]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.12.7-py2.py3-none-any.whl (1.7 MB)
[?25l[K     |▏                               | 10 kB 27.1 MB/s eta 0:00:01[K     |▍                               | 20 kB 33.9 MB/s eta 0:00:01[K     |▋                               | 30 kB 23.1 MB/s eta 0:00:01[K     |▊                               | 40 kB 18.7 MB/s eta 0:00:01[K     |█                               | 51 kB 10.7 MB/s eta 0:00:01[K     |█▏                              | 61 kB 10.3 MB/s eta 0:00:01[K     |█▍                              | 71 kB 10.6 MB/s eta 0:00:01[K     |█▌                              | 81 kB 11.8 MB/s eta 0:00:01[K     |█▊                              | 92 kB 9.1 MB/s eta 0:00:01[K     |██                              | 102 kB 9.9 MB/s eta 0:00:01[K     |██▏                             | 112 kB 9.9 MB/s eta 0:00:01[K     |██▎                             | 122 kB 9.9 MB/s eta 0:00:01[K     |██▌                             | 133 kB 9.9 MB/s eta 0:00:

In [None]:
import wandb


class WanDBWriter:
    """Thin wrapper around the Weights & Biases calls used for logging in this notebook."""

    def __init__(self, project='hifi_project', api_key=None):
        """Log in to W&B and start a run.

        :param project: project name passed to `wandb.init`
        :param api_key: optional W&B API key. When left as None, `wandb.login` resolves
            credentials itself (e.g. the `WANDB_API_KEY` environment variable, a stored
            login, or an interactive prompt).
        """
        # Credentials must never be hard-coded in the notebook.
        # NOTE(review): an API key was previously committed in this cell; treat it as
        # leaked and rotate it.
        wandb.login(key=api_key)
        wandb.init(project=project)

    def add_metrics(self, metrics):
        """Log a dict of metrics as one W&B step."""
        wandb.log(metrics)

    def add_audio(self, pred, true, transcript):
        """Log predicted and reference audio, captioned with `transcript`.

        Both inputs are squeezed and converted with `.numpy()`; the sample rate is 22050.
        """
        wandb.log({
            'pred audio': wandb.Audio(pred.squeeze().numpy(), sample_rate=22050, caption=transcript),
            'true audio': wandb.Audio(true.squeeze().numpy(), sample_rate=22050, caption=transcript)
        })

    def add_spectrogram(self, pred, true, transcript):
        """Log predicted and reference spectrograms as images, captioned with `transcript`.

        Both inputs are squeezed and converted with `.numpy()`.
        """
        wandb.log({
            'pred spectrogram': wandb.Image(pred.squeeze().numpy(), caption=transcript),
            'true spectrogram': wandb.Image(true.squeeze().numpy(), caption=transcript)
        })

# Train

In [None]:
# Create the W&B logger. Constructing it logs in to W&B and starts a run in the
# writer's default project, so this cell has network side effects.
logger = WanDBWriter()



VBox(children=(Label(value=' 24.40MB of 24.40MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
Learning rate,██▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁
Loss,█▆▄▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Learning rate,4e-05
Loss,15.13535


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Log-mel featurizer with the default configuration, placed on the chosen device.
featurizer = MelSpectrogram(config=MelSpectrogramConfig())
featurizer = featurizer.to(device)

In [None]:
from torch.utils.data import DataLoader
from itertools import islice
# Batches of three samples from the LJSpeech copy in the current directory,
# padded and packed into `Batch` objects by the collator.
dataloader = DataLoader(
    dataset=LJSpeechDataset(root='.'),
    batch_size=3,
    collate_fn=LJSpeechCollator(),
)

In [None]:
# Take only the first batch from the dataloader; it is reused as the sole training example below.
dummy_batch = list(islice(dataloader, 1))[0]
# Padded waveforms of that batch, moved to the training device.
waveform = dummy_batch.waveform.to(device)
# Log-mel spectrograms of the waveforms; these are both the generator input and the loss target.
mels = featurizer(waveform)

In [None]:
from tqdm import tqdm
import random

# Fit the generator to the single cached batch (`mels` / `waveform`): every iteration
# reuses the same input, so `n_epochs` is effectively the number of optimizer steps.

# Scale applied to the L1 mel loss (presumably the HiFi-GAN mel-loss weight — confirm).
MEL_LOSS_WEIGHT = 45
# Audio and spectrogram samples are logged every this many steps.
SAMPLE_LOG_EVERY = 50

loss_func = nn.L1Loss()
n_epochs = 7000
# The generator consumes mel spectrograms, so its input width follows the featurizer config.
model = Generator(featurizer.config.n_mels).to(device)
model.train()
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, betas=(0.8, 0.99))
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, 0.999)

# Number of samples in the cached batch; used to pick a sample for logging.
batch_size = mels.shape[0]

for i in tqdm(range(n_epochs)):
    pred = model(mels)
    pred_mels = featurizer(pred.squeeze(1))
    opt.zero_grad()
    # The generated waveform can produce a different number of frames than the target,
    # so the loss is computed on the frames both spectrograms have.
    min_mel = min(mels.shape[2], pred_mels.shape[2])
    loss = loss_func(mels[:, :, :min_mel], pred_mels[:, :, :min_mel]) * MEL_LOSS_WEIGHT
    loss.backward()
    opt.step()
    scheduler.step()
    logger.add_metrics({"Loss": loss.item(), "Learning rate": scheduler.get_last_lr()[0]})
    if (i + 1) % SAMPLE_LOG_EVERY == 0:
        # Pick one sample of the batch; the index range follows the actual batch size.
        rand_idx = random.randint(0, batch_size - 1)
        tr = dummy_batch.transcript[rand_idx]
        mel_t = mels[rand_idx].detach()
        mel_p = pred_mels[rand_idx].detach()

        aud_t = waveform[rand_idx]
        aud_p = pred[rand_idx].detach()
        logger.add_audio(aud_p.cpu(), aud_t.cpu(), tr)
        logger.add_spectrogram(mel_p.cpu(), mel_t.cpu(), tr)

100%|██████████| 7000/7000 [2:18:32<00:00,  1.19s/it]
