In [1]:
# Install the third-party packages this notebook relies on (arpa, youtokentome and
# the Hugging Face `datasets` library are imported/used in later cells).
# NOTE(review): versions are unpinned; for reproducibility prefer
# `%pip install -q pkg==x.y.z` so the install targets the running kernel.
# NOTE(review): pydub and ffmpeg are not imported in the visible cells; the PyPI
# `ffmpeg` package is not the ffmpeg binary itself -- confirm these are needed.
!pip install arpa
!pip install pydub
!pip install ffmpeg
!pip install datasets
!pip install youtokentome

  Building wheel for youtokentome (setup.py) ... [?25l[?25hdone
  Created wheel for youtokentome: filename=youtokentome-1.0.6-cp310-cp310-linux_x86_64.whl size=1927629 sha256=c93a75c538c9b3b45c2c0d5088d8add6b0ff8ef196eca5fd85f001e287161aea
  Stored in directory: /root/.cache/pip/wheels/df/85/f8/301d2ba45f43f30bed2fe413efa760bc726b8b660ed9c2900c
Successfully built youtokentome
Installing collected packages: youtokentome
Successfully installed youtokentome-1.0.6


In [2]:
# Log in to the Hugging Face Hub; the Common Voice loads below pass
# use_auth_token=True and rely on the token saved by this login.
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from datasets import Audio, DatasetDict, load_dataset

# Empty container; it is not used in the visible cells.
common_voice = DatasetDict()

# Russian Common Voice 11.0: load the train split, then the test split.
# use_auth_token=True makes `datasets` use the token saved by the login cell.
common_voice_train, common_voice_test = (
    load_dataset("mozilla-foundation/common_voice_11_0", "ru", split=split_name, use_auth_token=True)
    for split_name in ("train", "test")
)

In [5]:
# Have `datasets` decode every clip at 16 kHz, the rate the feature extractor assumes.
common_voice_train, common_voice_test = (
    split_ds.cast_column("audio", Audio(sampling_rate=16000))
    for split_ds in (common_voice_train, common_voice_test)
)

In [6]:
# Drop speaker/demographic metadata; only 'path', 'audio' and 'sentence' remain.
common_voice_train, common_voice_test = (
    split_ds.remove_columns(
        ['client_id', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment']
    )
    for split_ds in (common_voice_train, common_voice_test)
)

In [7]:
# Show the remaining features and row counts of both splits.
print(common_voice_train, common_voice_test, sep="\n")

Dataset({
    features: ['path', 'audio', 'sentence'],
    num_rows: 22862
})
Dataset({
    features: ['path', 'audio', 'sentence'],
    num_rows: 9630
})


In [8]:
import os
import random
import shutil
import string
import time
from collections import defaultdict
from typing import List, Tuple, TypeVar, Optional, Callable, Iterable
import youtokentome as yttm

import arpa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from matplotlib.colors import LogNorm
from torch import optim
from tqdm.notebook import tqdm
from torch import Tensor
import torch.nn.init as init
import math
from torchvision import transforms
from torch import distributions
import librosa
import editdistance
from torch.utils.data import DataLoader, Dataset
import pathlib

In [9]:
class QuartzNetBlock(nn.Module):
    """A QuartzNet block: `repeat` Conv1d+BatchNorm sub-blocks with an optional residual.

    ReLU + Dropout(p=0.2) sit between consecutive sub-blocks. If `residual` is
    True, a 1x1 Conv1d + BatchNorm projection of the block input is added to
    the output of the last sub-block. The result then goes through a final
    ReLU + Dropout(p=0.2).

    Args:
        feat_in: number of input channels.
        filters: number of output channels of every sub-block.
        repeat: number of Conv1d+BatchNorm sub-blocks.
        kernel_size: temporal kernel size of every convolution.
        stride: stride applied by every sub-block. The residual projection is
            not strided, so `residual=True` only lines up when stride == 1.
        dilation: dilation of every (depthwise) convolution.
        residual: whether to add the projected block input before the final activation.
        separable: use a depthwise Conv1d followed by a pointwise (1x1) Conv1d
            instead of a single dense Conv1d.
    """

    def __init__(self, feat_in, filters, repeat=3, kernel_size=11, stride=1,
                 dilation=1, residual=True, separable=False):
        super().__init__()
        self.res = nn.Sequential(nn.Conv1d(feat_in, filters, kernel_size=1),
                                 nn.BatchNorm1d(filters)) if residual else None
        self.conv = nn.ModuleList()
        for idx in range(repeat):
            self.conv.extend(
                self._get_conv_bn_layer(
                    feat_in,
                    filters,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    separable=separable))
            # Non-linearity between sub-blocks, but not after the last one (its output
            # may still receive the residual). Applied regardless of `residual`, so
            # stacked non-residual sub-blocks are not collapsed into one linear map.
            if idx != repeat - 1:
                self.conv.extend([nn.ReLU(), nn.Dropout(p=0.2)])
            feat_in = filters
        self.out = nn.Sequential(nn.ReLU(), nn.Dropout(p=0.2))

    def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size,
                           stride=1, dilation=1, separable=False):
        """Return the layer list [conv(s)..., BatchNorm1d] for one sub-block."""
        # A dilated kernel spans dilation*(kernel_size-1)+1 samples. Padding half of
        # dilation*(kernel_size-1) on each side keeps the length unchanged at stride 1
        # (exact for odd kernel sizes). The former `(dilation*kernel_size)//2 - 1`
        # over-padded for dilation >= 4, breaking the residual add.
        if dilation > 1:
            same_padding = dilation * (kernel_size - 1) // 2
        else:
            same_padding = kernel_size // 2
        if separable:
            layers = [
                # depthwise: one filter per input channel
                nn.Conv1d(in_channels, in_channels, kernel_size,
                          groups=in_channels, stride=stride, dilation=dilation, padding=same_padding),
                # pointwise: mix channels
                nn.Conv1d(in_channels, out_channels, kernel_size=1)
            ]
        else:
            layers = [
                nn.Conv1d(in_channels, out_channels, kernel_size,
                          stride=stride, dilation=dilation, padding=same_padding)
            ]
        layers.append(nn.BatchNorm1d(out_channels))
        return layers

    def forward(self, inputs):
        """Run the sub-blocks, add the residual projection (if any), then ReLU + Dropout."""
        inputs_for_res = inputs
        for layer in self.conv:
            inputs = layer(inputs)
        if self.res is not None:
            inputs = inputs + self.res(inputs_for_res)
        return self.out(inputs)


class QuartzNet(nn.Module):
    """A stack of QuartzNetBlocks followed by a 1x1 Conv1d classifier.

    Args:
        quartznet_conf: sequence of per-block dicts with the keys 'filters',
            'repeat', 'kernel', 'stride', 'dilation', 'residual', 'separable'.
        feat_in: number of channels of the input features.
        num_classes: output channels of the final 1x1 convolution.

    forward() returns the final convolution's raw (un-normalised) scores.
    """

    def __init__(self, quartznet_conf, feat_in, num_classes):
        super().__init__()
        # Block i reads the channels produced by block i-1; the first reads feat_in.
        in_channels = [feat_in] + [cfg['filters'] for cfg in quartznet_conf]
        blocks = [
            QuartzNetBlock(c_in,
                           cfg['filters'],
                           repeat=cfg['repeat'],
                           kernel_size=cfg['kernel'],
                           stride=cfg['stride'],
                           dilation=cfg['dilation'],
                           residual=cfg['residual'],
                           separable=cfg['separable'])
            for c_in, cfg in zip(in_channels, quartznet_conf)
        ]
        classifier = nn.Conv1d(in_channels[-1], num_classes, kernel_size=1)
        self.layers = nn.Sequential(*blocks, classifier)

    def forward(self, inputs):
        return self.layers(inputs)

In [10]:
# Mel front end shared by the train and test pipelines: 16 kHz input, 512-point FFT,
# 20 ms windows (320 samples) every 10 ms (160 samples), 64 mel bands over 0-8 kHz.
mel_spectrogramer = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=512,
    win_length=320, hop_length=160,
    f_min=0, f_max=8000,
    n_mels=64,
)


class AddNormalNoise(object):
    """Randomly add zero-mean Gaussian noise to a waveform, then clip to [-1, 1].

    Args:
        std: standard deviation of the noise (passed as `scale` to
            torch.distributions.Normal). It is stored as `self.var` for backward
            compatibility; despite that name it is a standard deviation.
        p: probability of adding noise on a given call (default 0.5).

    The input tensor is never modified. Every call returns a new, clamped tensor
    with the same dtype as the input, whether or not noise was added.
    """

    def __init__(self, std=0.01, p=0.5):
        self.var = std
        self.p = p

    def __call__(self, wav):
        noiser = distributions.Normal(0, self.var)
        if np.random.uniform() < self.p:
            # Out-of-place add so the caller's waveform is left untouched.
            wav = wav + noiser.sample(wav.size()).to(wav.dtype)
        return wav.clamp(-1, 1)


class TimeStretch(object):
    """Randomly time-stretch a waveform with librosa on about half of the calls.

    A stretch rate is drawn uniformly from [min_scale, max_scale) on every call,
    even when the stretch is skipped, so NumPy's global RNG advances by two draws
    per call. When no stretch is applied, the result is `torch.from_numpy(wav.numpy())`
    and shares memory with `wav`.
    """

    def __init__(self):
        self.min_scale, self.max_scale = 0.9, 1.1

    def __call__(self, wav):
        rate = np.random.uniform(self.min_scale, self.max_scale, 1)[0]
        apply_stretch = np.random.uniform() < 0.5
        samples = wav.numpy()
        if apply_stretch:
            samples = librosa.effects.time_stretch(samples, rate=rate)
        return torch.from_numpy(samples)


class PitchShifting(object):
    """Randomly pitch-shift a 16 kHz waveform with librosa on about half of the calls.

    `n_steps` is drawn uniformly from [min_shift, max_shift) on every call, even
    when the shift is skipped, so NumPy's global RNG advances by two draws per
    call. When no shift is applied, the result shares memory with `wav`.
    """

    def __init__(self):
        self.sample_rate, self.min_shift, self.max_shift = 16000, -3, 3

    def __call__(self, wav):
        n_steps = np.random.uniform(self.min_shift, self.max_shift, 1)[0]
        apply_shift = np.random.uniform() < 0.5
        samples = wav.numpy()
        if apply_shift:
            samples = librosa.effects.pitch_shift(samples, sr=self.sample_rate, n_steps=n_steps)
        return torch.from_numpy(samples)


class MelSpectrogram(object):
    """Callable wrapper: cast a waveform to float32 and run `mel_spectrogramer` on it."""

    def __call__(self, wav):
        return mel_spectrogramer(wav.float())


class NormalizePerFeature(object):
    """Log-compress a spectrogram and standardise it along dim 1.

    Values are clamped to >= 1e-18 before the log. Statistics are then taken over
    dim 1 (time, for an (n_mels, time) input): subtract the mean and divide by the
    unbiased std plus 1e-5.
    """

    def __call__(self, spec):
        log_spec = spec.clamp(min=1e-18).log()
        row_mean = log_spec.mean(dim=1, keepdim=True)
        row_std = log_spec.std(dim=1, keepdim=True) + 1e-5
        return (log_spec - row_mean) / row_std


# Use torchvision's Compose under a private alias. The dict below is bound to the
# name `transforms`, which shadows the torchvision module, so using the bare module
# name here would fail when this cell is re-run.
from torchvision import transforms as tv_transforms

transforms = {
    # Waveform augmentations (each applied to roughly half of the samples), then the
    # same feature extraction as 'test'. Train and eval features must be on the same
    # scale, so the pipeline ends with MelSpectrogram + NormalizePerFeature as well.
    'train': tv_transforms.Compose([
        AddNormalNoise(),
        PitchShifting(),
        TimeStretch(),
        MelSpectrogram(),
        NormalizePerFeature(),
    ]),
    # Feature extraction only: mel spectrogram -> per-feature normalised log-mel.
    'test': tv_transforms.Compose([
        MelSpectrogram(),
        NormalizePerFeature(),
    ]),
}


def collate_fn(batch):
    """Pad a list of dataset samples into batch tensors.

    Each sample is (features, feature_len, targets, target_len). Features are
    zero-padded along their first dim (time) and returned as (batch, n_feats, time).
    Targets are zero-padded to (batch, max_target_len). Both length lists come back
    as int64 tensors.
    """
    features, feature_lens, labels, label_lens = zip(*batch)
    padded_features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return (
        padded_features.permute(0, 2, 1),
        torch.Tensor(feature_lens).long(),
        padded_labels,
        torch.Tensor(label_lens).long(),
    )

In [11]:
PUNCTUATION = string.punctuation + '—–«»−…‑'


def prepare_bpe(bpe_path="bpe_model", vocab_size=120):
    """Train the YouTokenToMe BPE model for CTC targets once, then load it.

    If `bpe_path` does not exist yet, the `sentence` column of
    `common_voice_train` is lower-cased, stripped and cleaned of PUNCTUATION. The
    result is written to a temporary UTF-8 text file, and a BPE model with
    `vocab_size` tokens is trained on it and saved to `bpe_path`. Later calls only
    load the existing model file.

    Args:
        bpe_path: file the BPE model is stored in / loaded from.
        vocab_size: BPE vocabulary size; it must match the acoustic model's
            `num_classes` (120 in this notebook).

    Returns:
        The loaded `yttm.BPE` model.
    """
    if not os.path.exists(bpe_path):
        train_data_path = 'bpe_texts.txt'
        strip_punctuation = str.maketrans('', '', PUNCTUATION)
        try:
            # Read only the text column: indexing whole rows (common_voice_train[i])
            # would also decode every audio clip. Explicit UTF-8 because the
            # transcripts are Cyrillic and the locale default may not be UTF-8.
            with open(train_data_path, "w", encoding="utf-8") as f:
                for sentence in common_voice_train['sentence']:
                    f.write(sentence.lower().strip().translate(strip_punctuation) + "\n")
            yttm.BPE.train(data=train_data_path, vocab_size=vocab_size, model=bpe_path)
        finally:
            # Portable cleanup instead of shelling out to `rm`.
            if os.path.exists(train_data_path):
                os.remove(train_data_path)
    return yttm.BPE(model=bpe_path)

In [12]:
class SpeechDataset(Dataset):
    """Map-style dataset yielding (features, feature_len, targets, target_len).

    Args:
        df: indexable dataset whose rows have a 'sentence' string and an 'audio'
            dict with an 'array' waveform (here a Hugging Face `datasets.Dataset`).
        transform: callable mapping a waveform tensor to features. This code
            assumes the output is (n_mels, time), given the `.shape[1]` / `.T` use below.
        bpe_dropout_prob: BPE-dropout probability used when encoding targets.
            The default 0.2 keeps the previous behaviour; pass 0.0 for
            deterministic targets (e.g. on evaluation data).
    """

    def __init__(self, df, transform=None, bpe_dropout_prob=0.2):
        self.df = df
        self.labels = df['sentence']
        self.transform = transform
        self.bpe_dropout_prob = bpe_dropout_prob
        self.bpe = prepare_bpe()

    def __getitem__(self, idx):
        # Normalise the transcript the same way prepare_bpe() does for BPE training.
        text = self.labels[idx].lower().strip().translate(str.maketrans('', '', PUNCTUATION))
        token_ids = self.bpe.encode(text, dropout_prob=self.bpe_dropout_prob)
        # CTC targets are class indices, so keep them integral.
        targets = torch.tensor(token_ids, dtype=torch.long)
        # Fetch the row once; for a HF dataset each row access decodes the audio again.
        audio = self.df[idx]['audio']
        wav = torch.tensor(audio['array'])
        features = self.transform(wav.transpose(-1, 0))
        num_frames = features.shape[1]
        # The first QuartzNet block in model_config uses stride 2, so the model
        # emits about half as many frames as there are feature frames.
        return features.T.float(), num_frames // 2, targets, len(token_ids)

    def __len__(self):
        return len(self.labels)

In [13]:
class CerWer():
    """Greedy CTC decoding plus character/word error rates over BPE token ids.

    Token ids are turned into subword strings with the YouTokenToMe model from
    `prepare_bpe()`. Words are the non-empty pieces between `space_simbol`
    markers (default '▁').

    Args:
        blank_index: token id treated as the CTC blank and dropped when decoding.
        space_simbol: subword marker used to split decoded strings into words.
    """

    def __init__(self, blank_index=0, space_simbol='▁'):
        bpe = prepare_bpe()
        self.idx2char = bpe.id_to_subword
        self.blank_index = blank_index
        self.space_simbol = space_simbol

    def _split_words(self, text):
        """Split a decoded subword string into words, dropping empty pieces.

        When a decoded string starts with `space_simbol` (as word-initial BPE
        pieces do), a plain split yields a leading '' "word" that would inflate
        the WER denominator.
        """
        return [word for word in text.rstrip().split(self.space_simbol) if word]

    def __call__(self, predicts, targets, inputs_length, targets_length):
        """Score a batch of greedy predictions against padded targets.

        Args:
            predicts: per-sample sequences of predicted token ids (one per output frame).
            targets: per-sample (padded) sequences of target token ids.
            inputs_length: number of valid prediction frames for each sample.
            targets_length: number of valid target tokens for each sample.

        Returns:
            (cer_sum, wer_sum, last_predicted_string, last_target_string). The
            rates are summed over the batch, not averaged. For an empty batch the
            strings are ''. An empty reference is divided by 1 instead of 0.
        """
        cer = 0.0
        wer = 0.0
        predict_string = target_string = ''
        for predict, target, input_length, target_length in zip(predicts, targets, inputs_length, targets_length):
            predict_string = self.process_string(predict, input_length, remove_repetitions=True)
            target_string = self.process_string(target, target_length)

            predict_words = self._split_words(predict_string)
            target_words = self._split_words(target_string)

            dist = editdistance.eval(target_string, predict_string)
            dist_word = editdistance.eval(target_words, predict_words)

            # Guard against empty references (e.g. a transcript that was only punctuation).
            cer += dist / max(len(target_string), 1)
            wer += dist_word / max(len(target_words), 1)
        return cer, wer, predict_string, target_string

    def process_string(self, sequence, length, remove_repetitions=False):
        """Concatenate the subwords of the first `length` ids, skipping blanks.

        With `remove_repetitions`, an id whose subword equals that of the
        immediately preceding raw id is skipped as well (CTC collapse). A
        blank in between therefore lets a repeated subword through.
        """
        blank = self.idx2char(self.blank_index)
        pieces = []
        for i in range(length):
            char = self.idx2char(sequence[i])
            if char == blank:
                continue
            if remove_repetitions and i != 0 and char == self.idx2char(sequence[i - 1]):
                continue
            pieces.append(char)
        return ''.join(pieces)

    def inference(self, predicts, input_len):
        """Greedy-decode one sequence of predicted ids into a list of words.

        :param predicts: predicted token ids, one per output frame.
        :param input_len: number of valid frames in `predicts`.
        :return: decoded words, without the `space_simbol` markers.
        """
        predict_string = self.process_string(predicts, input_len, remove_repetitions=True)
        return self._split_words(predict_string)

In [14]:
# Training samples go through the augmenting 'train' pipeline; evaluation samples
# get feature extraction only (the 'test' pipeline).
# NOTE(review): SpeechDataset encodes targets with BPE-dropout for both splits, so
# validation targets are not segmented deterministically -- confirm this is intended.
train_dataset = SpeechDataset(common_voice_train, transform=transforms['train'])
test_dataset = SpeechDataset(common_voice_test, transform=transforms['test'])

In [15]:
# Both loaders share batch size, worker count and padding collate; only the
# training loader reshuffles every epoch (shuffle=False is DataLoader's default).
train_dataloader, test_dataloader = (
    DataLoader(split_ds, batch_size=32, num_workers=4, shuffle=reshuffle, collate_fn=collate_fn)
    for split_ds, reshuffle in ((train_dataset, True), (test_dataset, False))
)



In [16]:
# QuartzNet architecture: one dict per QuartzNetBlock, consumed in order by QuartzNet
# (which appends its own 1x1 classifier conv after the last entry).
# The first block uses stride 2; that is why SpeechDataset halves the frame count
# (`// 2`) it reports as the CTC input length.
model_config = [
    {'filters': 256, 'repeat': 1, 'kernel': 33, 'stride': 2, 'dilation': 1, 'residual': False, 'separable': True},
    {'filters': 256, 'repeat': 5, 'kernel': 33, 'stride': 1, 'dilation': 1, 'residual': True, 'separable': True},
    {'filters': 256, 'repeat': 5, 'kernel': 39, 'stride': 1, 'dilation': 1, 'residual': True, 'separable': True},
    {'filters': 512, 'repeat': 5, 'kernel': 51, 'stride': 1, 'dilation': 1, 'residual': True, 'separable': True},
    {'filters': 512, 'repeat': 5, 'kernel': 63, 'stride': 1, 'dilation': 1, 'residual': True, 'separable': True},
    {'filters': 512, 'repeat': 5, 'kernel': 75, 'stride': 1, 'dilation': 1, 'residual': True, 'separable': True},
    {'filters': 512, 'repeat': 1, 'kernel': 87, 'stride': 1, 'dilation': 2, 'residual': False, 'separable': True},
    {'filters': 1024, 'repeat': 1, 'kernel': 1, 'stride': 1, 'dilation': 1, 'residual': False, 'separable': False}
]

In [16]:
# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [17]:
# Display the selected device to confirm whether a GPU is in use.
device

device(type='cuda', index=0)

In [None]:
# Number of passes over the training set; it also sizes the cosine LR schedule.
NUM_EPOCHS = 50

# 120 output classes = BPE vocabulary size; 64 input channels = mel bands.
model = QuartzNet(quartznet_conf=model_config, num_classes=120, feat_in=64)
model.to(device)
# zero_infinity: samples whose targets cannot be aligned (infinite loss) contribute 0.
criterion = nn.CTCLoss(zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
# The scheduler is stepped once per batch, so T_max is the total number of training
# batches. Sizing it for 100 epochs while running 50 would stop the cosine halfway,
# at roughly half the initial LR.
num_steps = len(train_dataloader) * NUM_EPOCHS
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=0.00001)
cerwer = CerWer()
# CerWer returns per-batch sums, so epoch metrics are normalised per sample.
num_train_samples = max(len(train_dataloader.dataset), 1)
num_val_samples = max(len(test_dataloader.dataset), 1)

for epoch in range(NUM_EPOCHS):
    train_cer, train_wer, val_wer, val_cer = 0.0, 0.0, 0.0, 0.0
    train_losses = []
    model.train()
    for inputs, inputs_length, targets, targets_length in tqdm(train_dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # (batch, classes, time) -> (time, batch, classes), the layout nn.CTCLoss expects.
        outputs = model(inputs).permute(2, 0, 1)
        optimizer.zero_grad()
        loss = criterion(outputs.log_softmax(dim=2), targets, inputs_length, targets_length)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 15)
        optimizer.step()
        lr_scheduler.step()
        train_losses.append(loss.item())
        # Greedy decoding: most likely class per frame, transposed to (batch, time).
        best_path = outputs.argmax(dim=2)
        batch_cer, batch_wer, train_decoded_words, train_target_words = cerwer(
            best_path.T.cpu().numpy(), targets.cpu().numpy(), inputs_length, targets_length)
        train_wer += batch_wer
        train_cer += batch_cer

    model.eval()
    val_losses = []
    with torch.no_grad():
        for inputs, inputs_length, targets, targets_length in tqdm(test_dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs).permute(2, 0, 1)
            loss = criterion(outputs.log_softmax(dim=2), targets, inputs_length, targets_length)
            val_losses.append(loss.item())
            best_path = outputs.argmax(dim=2)
            batch_cer, batch_wer, val_decoded_words, val_target_words = cerwer(
                best_path.T.cpu().numpy(), targets.cpu().numpy(), inputs_length, targets_length)
            val_wer += batch_wer
            val_cer += batch_cer

    torch.save(model.state_dict(), f"model{epoch}.pth")
    print(
        f"epoch {epoch}: "
        f"train loss {np.mean(train_losses):.4f}, CER {train_cer / num_train_samples:.4f}, "
        f"WER {train_wer / num_train_samples:.4f} | "
        f"val loss {np.mean(val_losses):.4f}, CER {val_cer / num_val_samples:.4f}, "
        f"WER {val_wer / num_val_samples:.4f}"
    )

  0%|          | 0/715 [00:00<?, ?it/s]



  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/715 [00:00<?, ?it/s]

In [None]:
# Transcribe one audio file with a saved checkpoint and write the text next to it.
# A separate name keeps the training cell's global `device` intact.
inference_device = torch.device('cpu')
model = QuartzNet(quartznet_conf=model_config, num_classes=120, feat_in=64)
model.load_state_dict(torch.load("model12.pth", map_location=inference_device))
model = model.eval()
wav_file = "ac470735-fc2f-4e39-85ab-4ff1a2702de8 (1).wav"
wav, sr = torchaudio.load(wav_file)
# torchaudio.load returns (channels, samples) at the file's native rate: downmix to
# mono and resample to the 16 kHz the features and model were built for.
wav = wav.mean(dim=0)
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
features = transforms['test'](wav)
len_input = features.shape[1]
cerwer = CerWer()
with torch.no_grad():
    output = model(features.unsqueeze(0))
output = output.permute(2, 0, 1)  # (time, batch=1, classes)
_, max_probs = torch.max(output, 2)
# Column 0 is the single batch item; the first block's stride 2 halves the frame count.
decoded_words = cerwer.inference(max_probs[:, 0].numpy(), len_input // 2)
# Swap only the extension (str.replace would rewrite every "wav" in the path).
txt_path = pathlib.Path(wav_file).with_suffix(".txt")
with open(txt_path, "w", encoding="utf-8") as txt_file:
    txt_file.write(" ".join(decoded_words) + "\n")