<a href="https://colab.research.google.com/github/pielie34/quartznet-implementation/blob/main/QN_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install youtokentome
!pip install pyyaml
!pip install easydict
!pip install wandb
!pip install torchaudio
!pip install librosa
!pip install python-Levenshtein
!pip install audiomentations
!pip install pytorch_warmup

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
#!git clone --recursive https://github.com/parlance/ctcdecode.git
#!cd ctcdecode && pip install .

In [None]:
import os
import random
import numpy as np
import string
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

from torch import nn
import torch
from torch.utils import data
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torch.utils.data.dataloader import default_collate
import torchaudio
from torchvision.transforms import Normalize

import pytorch_warmup as warmup
import youtokentome as yttm
from audiomentations import TimeStretch, PitchShift, AddGaussianNoise
from functools import partial
import yaml
from easydict import EasyDict as edict
import wandb

In [None]:
import Levenshtein as Lev

# Data

## Dataset

In [None]:
class LibriDataset(torchaudio.datasets.LIBRISPEECH):
    """LibriSpeech dataset whose samples are dicts passed through ``transforms``.

    Args:
        transforms: Callable applied to every sample dict. It receives
            ``{'audio', 'text', 'sample_rate'}`` in :meth:`__getitem__` and
            ``{'text'}`` in :meth:`get_text`.
        *args, **kwargs: Forwarded unchanged to
            ``torchaudio.datasets.LIBRISPEECH``.
    """

    def __init__(self, transforms, *args, **kwargs):
        # Create the root directory up front when a download was requested.
        if kwargs.get('download', False):
            os.makedirs(kwargs['root'], exist_ok=True)
        super().__init__(*args, **kwargs)
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the transformed sample dict for item ``idx``."""
        audio, sample_rate, text, _, _, _ = super().__getitem__(idx)
        return self.transforms({'audio': audio, 'text': text, 'sample_rate': sample_rate})

    def get_text(self, idx):
        """Return the transformed transcript of item ``idx`` without loading audio.

        Relies on the parent's private ``_walker``, ``_path`` and ``_ext_txt``
        attributes to locate the chapter transcript file.

        Raises:
            FileNotFoundError: If the transcript file has no line for the item.
        """
        fileid = self._walker[idx]
        # Unpacking doubles as a check that the id has exactly three parts.
        speaker_id, chapter_id, _ = fileid.split("-")

        transcript_name = speaker_id + "-" + chapter_id + self._ext_txt
        transcript_path = os.path.join(self._path, speaker_id, chapter_id, transcript_name)

        # Explicit encoding keeps the read independent of the platform locale.
        with open(transcript_path, encoding="utf-8") as ft:
            for line in ft:
                fileid_text, utterance = line.strip().split(" ", 1)
                if fileid == fileid_text:
                    break
            else:
                # Translation not found
                raise FileNotFoundError("Translation not found for " + fileid)

        return self.transforms({'text': utterance})['text']


def get_dataset(config, transforms=lambda x: x, part='train'):
    """Build the LibriSpeech dataset for the requested part.

    Args:
        config: Accepted for interface compatibility; not read here.
        transforms: Per-sample transform handed to ``LibriDataset``.
        part: ``'train'``, ``'val'`` or ``'bpe'``.

    Returns:
        The dataset for ``'train'`` and ``'val'``; for ``'bpe'`` a tuple of
        the dataset and the list of all its indices.

    Raises:
        ValueError: If ``part`` is not one of the known names.
    """
    # 'bpe' reads the same split as 'train'.
    split_urls = {
        'train': 'train-clean-100',
        'val': 'dev-clean',
        'bpe': 'train-clean-100',
    }
    url = split_urls.get(part) if isinstance(part, str) else None
    if url is None:
        raise ValueError('Unknown')

    dataset = LibriDataset(root='DB/LibriSpeech', url=url, download=True, transforms=transforms)
    if part == 'bpe':
        return dataset, list(range(len(dataset)))
    return dataset

## Collate

In [None]:
def no_pad_collate(batch):
    """Regroup a list of sample dicts into one dict of lists, without padding.

    The keys are taken from the first sample; every sample must provide them.
    """
    return {field: [sample[field] for sample in batch] for field in batch[0].keys()}


def gpu_collate(batch):
    """Pad every field of a batch of sample dicts with ``pad_sequence``.

    For each key of the first sample the per-sample values are collected and
    padded with ``torch.nn.utils.rnn.pad_sequence`` (default arguments). When
    the first value has length below 2, every value of that field first gets
    a leading axis via ``value[None]``.
    """
    collated = {}
    for field in batch[0].keys():
        values = [sample[field] for sample in batch]
        if len(values[0]) < 2:
            values = [value[None] for value in values]
        collated[field] = torch.nn.utils.rnn.pad_sequence(values)
    return collated


def collate_fn(batch):
    """Zero-pad array fields along their last axis, then collate to tensors.

    Scalar fields (per ``np.isscalar``) are passed through untouched. Array
    fields must be 1-D or 2-D; each is right-padded on its last axis up to the
    longest value of that field in the batch. Every field is finally handed to
    ``default_collate``.
    """
    fields = batch[0].keys()

    # Pass 1: longest last-axis length per non-scalar field.
    longest = {field: 0 for field in fields}
    for sample in batch:
        for field in fields:
            value = sample[field]
            if np.isscalar(value):
                continue
            longest[field] = max(longest[field], value.shape[-1])

    # Pass 2: pad every array up to that length.
    columns = {field: [] for field in fields}
    for sample in batch:
        for field in fields:
            value = sample[field]
            if np.isscalar(value):
                columns[field].append(value)
                continue
            rank = len(value.shape)
            assert rank == 1 or rank == 2
            missing = longest[field] - value.shape[-1]
            widths = (0, missing) if rank == 1 else ((0, 0), (0, missing))
            columns[field].append(np.pad(value, widths, mode='constant'))

    # Pass 3: let default_collate turn the lists into tensors.
    return {field: default_collate(columns[field]) for field in fields}

## Transforms

In [None]:
# Characters removed from transcripts by TextPreprocess: ASCII punctuation
# plus a handful of extra dash, guillemet, minus, ellipsis and hyphen symbols.
PUNCTUATION = string.punctuation + '—–«»−…‑'


class Compose(object):
    """Composes several transforms together.

    Two calling conventions are supported:

    * dict transforms, called as ``t(data)`` and returning the new ``data``;
    * waveform transforms whose call signature has a ``sample_rate``
      parameter, called as ``t(data['audio'], sample_rate=data['sample_rate'])``
      with the result stored back into ``data['audio']``.

    The convention is chosen from the transform's signature rather than by
    catching ``TypeError``, so a ``TypeError`` raised inside a transform
    propagates unchanged instead of triggering a second, misleading call.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    @staticmethod
    def _takes_waveform(transform):
        """Return True when ``transform`` declares a ``sample_rate`` parameter."""
        import inspect

        try:
            parameters = inspect.signature(transform).parameters
        except (TypeError, ValueError):
            # No introspectable signature: treat it as a dict transform.
            return False
        return 'sample_rate' in parameters

    def __call__(self, data):
        for t in self.transforms:
            if self._takes_waveform(t):
                data['audio'] = t(data['audio'], sample_rate=data['sample_rate'])
            else:
                data = t(data)
        return data


class AudioSqueeze:
    """Replace ``data['audio']`` with the result of its ``.squeeze(0)``."""

    def __call__(self, data):
        waveform = data['audio']
        data['audio'] = waveform.squeeze(0)
        return data


class BPEtexts:
    """Encode ``data['text']`` with a BPE model into a numpy array.

    Args:
        bpe: Object exposing ``encode(text, dropout_prob=...)``.
        dropout_prob: Value forwarded to ``encode`` on every call.
    """

    def __init__(self, bpe, dropout_prob=0):
        self.bpe, self.dropout_prob = bpe, dropout_prob

    def __call__(self, data):
        encoded = self.bpe.encode(data['text'], dropout_prob=self.dropout_prob)
        data['text'] = np.array(encoded)
        return data


class TextPreprocess:
    """Lower-case and strip ``data['text']``, then delete PUNCTUATION characters."""

    def __call__(self, data):
        normalized = data['text'].lower().strip()
        deletion_table = str.maketrans('', '', PUNCTUATION)
        data['text'] = normalized.translate(deletion_table)
        return data


class ToNumpy:
    """Convert ``data['audio']`` to a numpy array via ``np.array``."""

    def __call__(self, data):
        audio_array = np.array(data['audio'])
        data['audio'] = audio_array
        return data

# on gpu:

class ToGpu:
    """Turn every item of every batch field into a tensor on ``device``.

    Expects a dict of lists; returns a new dict with the same keys whose
    items went through ``np.array`` -> ``torch.from_numpy`` -> ``.to(device)``.
    """

    def __init__(self, device):
        self.device = device

    def __call__(self, data):
        moved = {}
        for field, items in data.items():
            moved[field] = [
                torch.from_numpy(np.array(item)).to(self.device) for item in items
            ]
        return moved


class AddLengths:
    """Add per-item length tensors to a batch dict.

    ``input_lengths`` holds the last-axis size of every ``data['audio']`` item
    and ``target_lengths`` the first-axis size of every ``data['text']`` item;
    both are moved to the device of the first audio item.
    """

    def __call__(self, data):
        audio_sizes = [item.shape[-1] for item in data['audio']]
        data['input_lengths'] = torch.tensor(audio_sizes).to(data['audio'][0].device)

        text_sizes = [item.shape[0] for item in data['text']]
        data['target_lengths'] = torch.tensor(text_sizes).to(data['audio'][0].device)
        return data


class Pad:
    """Pad every field of a batch dict into one batch-first tensor.

    Items with fewer than two axes are padded along their only axis. Items
    with two axes are padded along their last axis (they are permuted so that
    axis comes first for ``pad_sequence`` and permuted back afterwards).
    """

    def __call__(self, data):
        pad = torch.nn.utils.rnn.pad_sequence
        padded_batch = {}
        for field, items in data.items():
            if len(items[0].shape) >= 2:
                time_first = [item.permute(1, 0) for item in items]
                padded_batch[field] = pad(time_first, batch_first=True).permute(0, 2, 1)
            else:
                # Add a trailing axis for pad_sequence and drop it again after.
                widened = [item[..., None] for item in items]
                padded_batch[field] = pad(widened, batch_first=True)[..., 0]
        return padded_batch


class MelSpectrogram(torchaudio.transforms.MelSpectrogram):
    """Mel spectrogram transform that operates on a batch dict.

    Every item of the ``data['audio']`` list is replaced in place by the
    parent transform's output for that item.
    """

    def forward(self, data):
        clips = data['audio']
        for idx, waveform in enumerate(clips):
            clips[idx] = super().forward(waveform)
        return data


class NormalizedMelSpectrogram(torchaudio.transforms.MelSpectrogram):
    """Mel spectrogram over a batch dict with optional log + normalization.

    Args:
        normalize: ``'to05'`` applies ``Normalize([0.5], [0.5])``;
            ``'touniform'`` subtracts the mean and divides by the standard
            deviation taken over ``dim=1``; anything else disables both the
            logarithm and the normalization.
        *args, **kwargs: Forwarded to ``torchaudio.transforms.MelSpectrogram``.
    """

    def __init__(self, normalize=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def standardize(x):
            # NOTE(review): forward() passes the features with an extra leading
            # axis, so dim=1 here is presumably the mel axis (per-frame
            # statistics) rather than time — confirm this is intended.
            centered = x - torch.mean(x, dim=1, keepdim=True)
            return centered / (torch.std(x, dim=1, keepdim=True) + 1e-18)

        if normalize == 'to05':
            self.normalize = Normalize([0.5], [0.5])
        elif normalize == 'touniform':
            self.normalize = standardize
        else:
            self.normalize = None

    def forward(self, data):
        clips = data['audio']
        for idx, waveform in enumerate(clips):
            features = super().forward(waveform)
            if self.normalize is not None:
                # Clamp before the log so exact zeros do not become -inf.
                log_features = torch.log(torch.clamp(features, min=1e-18))
                features = self.normalize(log_features[None])[0]
            clips[idx] = features
        return data


class MaskSpectrogram(object):
    """Zero out one random frequency band and one random time span per item.

    Args:
        frequency_mask_max_percentage: Upper bound of the band height as a
            fraction of the first axis.
        time_mask_max_percentage: Upper bound of the span width as a fraction
            of the second axis.
        probability: Chance that a given item is masked at all.
    """

    def __init__(self, frequency_mask_max_percentage=0.3, time_mask_max_percentage=0.1, probability=1.0):
        # Attribute names are historical; both hold maximum fractions.
        self.frequency_mask_probability = frequency_mask_max_percentage
        self.time_mask_probability = time_mask_max_percentage
        self.probability = probability

    def __call__(self, data):
        for spec in data['audio']:
            if not random.random() < self.probability:
                continue
            nu, tau = spec.shape

            # Frequency band: random height, then random start row.
            band = random.randint(0, int(self.frequency_mask_probability * nu))
            band_start = random.randint(0, nu - band)
            spec[band_start:band_start + band, :] = 0

            # Time span: random width, then random start column.
            span = random.randint(0, int(self.time_mask_probability * tau))
            span_start = random.randint(0, tau - span)
            spec[:, span_start:span_start + span] = 0

        return data

# Utils

In [None]:
def fix_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    cuDNN is switched off (``enabled = False``) and flagged deterministic.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.enabled = False

def remove_from_dict(the_dict, keys):
    """Delete ``keys`` from ``the_dict`` in place, ignoring absent ones.

    Returns the same dict object that was passed in.
    """
    for unwanted in keys:
        if unwanted in the_dict:
            del the_dict[unwanted]
    return the_dict

def prepare_bpe(config, vocab_size=80, model_path='yttm.bpe'):
    """Train a YouTokenToMe BPE model on the training transcripts and load it.

    The transcripts of the 'bpe' dataset part are written one per line to a
    temporary text file, which is removed again whether or not training
    succeeds.

    Args:
        config: Passed through to ``get_dataset``.
        vocab_size: Target BPE vocabulary size.
        model_path: Where the trained model is written and loaded from.

    Returns:
        The loaded ``yttm.BPE`` model.
    """
    dataset, ids = get_dataset(config, part='bpe', transforms=TextPreprocess())
    train_data_path = 'bpe_texts.txt'
    try:
        with open(train_data_path, "w", encoding="utf-8") as f:
            # Only the training part is used to fit the tokenizer.
            for i in ids:
                f.write(f"{dataset.get_text(i)}\n")
        yttm.BPE.train(data=train_data_path, vocab_size=vocab_size, model=model_path)
    finally:
        # os.remove is portable, unlike shelling out to `rm`.
        if os.path.exists(train_data_path):
            os.remove(train_data_path)

    return yttm.BPE(model=model_path)

# Model

## Encoder

GroupShuffle для каждого сверточного слоя

In [None]:
class GroupShuffle(nn.Module):
    """Interleave channels across groups (channel shuffle).

    The channel axis is viewed as ``(groups, channels_per_group)``, the two
    are swapped, and the result is flattened back to a single channel axis.

    Args:
        groups: Number of channel groups.
        channels: Total number of channels; ``channels // groups`` per group.
    """

    def __init__(self, groups, channels):
        super().__init__()
        self.groups = groups
        self.channels_per_group = channels // groups

    def forward(self, x):
        steps = x.shape[-1]
        grouped = x.view(-1, self.groups, self.channels_per_group, steps)
        swapped = grouped.transpose(1, 2).contiguous()
        return swapped.view(-1, self.groups * self.channels_per_group, steps)

1D сверточный слой + нормализация

In [None]:
def get_conv_bn_layer(in_channels, out_channels, kernel_size=11,
                     stride=1, dilation=1, padding=0, bias=False,
                     groups=1, separable=False,
                     normalization="batch", norm_groups=1):
    """Build a ``Conv1d`` (optionally separable) followed by a normalization.

    With ``separable=True`` the convolution is split into a depthwise conv
    (``groups=in_channels``) and a 1x1 pointwise conv. ``normalization`` picks
    ``BatchNorm1d`` ("batch") or a ``GroupNorm`` with ``norm_groups`` groups
    ("group"), one group per channel ("instance") or a single group ("layer").
    ``norm_groups == -1`` means one group per output channel. When
    ``groups > 1`` a ``GroupShuffle`` is appended.

    Raises:
        ValueError: If ``normalization`` is not one of the four known names.
    """
    if norm_groups == -1:
        norm_groups = out_channels

    conv_kwargs = dict(stride=stride, dilation=dilation, padding=padding, bias=bias)
    if separable:
        depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                              groups=in_channels, **conv_kwargs)
        pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1,
                              stride=1, dilation=1, padding=0, bias=bias, groups=groups)
        modules = [depthwise, pointwise]
    else:
        modules = [nn.Conv1d(in_channels, out_channels, kernel_size,
                             groups=groups, **conv_kwargs)]

    if normalization == "batch":
        modules.append(nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1))
    elif normalization in ("group", "instance", "layer"):
        # All three are GroupNorm variants that differ only in group count.
        if normalization == "group":
            num_groups = norm_groups
        elif normalization == "instance":
            num_groups = out_channels
        else:
            num_groups = 1
        modules.append(nn.GroupNorm(num_groups=num_groups, num_channels=out_channels))
    else:
        raise ValueError(
            f"Normalization method ({normalization}) does not match"
            f" one of [batch, layer, group, instance].")

    if groups > 1:
        modules.append(GroupShuffle(groups, out_channels))
    return nn.Sequential(*modules)

Активация + Dropout

In [None]:
def get_act_dropout_layer(drop_prob=0.2, activation='relu'):
    """Build an activation followed by dropout.

    Args:
        drop_prob: Dropout probability.
        activation: ``'relu'`` for ``nn.ReLU``; ``None`` or ``'tanh'`` for
            ``nn.Hardtanh`` clipped to [0, 20]; or an ``nn.Module`` instance
            used as is.

    Returns:
        ``nn.Sequential(activation, nn.Dropout(p=drop_prob))``.

    Raises:
        ValueError: If ``activation`` is none of the accepted values. Without
            this check an unknown name would reach ``nn.Sequential`` as a
            plain string and fail there with an unrelated error.
    """
    if activation is None or activation == 'tanh':
        act = nn.Hardtanh(min_val=0.0, max_val=20.0)
    elif activation == 'relu':
        act = nn.ReLU()
    elif isinstance(activation, nn.Module):
        act = activation
    else:
        raise ValueError(
            f"Unknown activation: {activation!r}; expected 'relu', 'tanh', "
            f"None or an nn.Module instance")
    return nn.Sequential(act, nn.Dropout(p=drop_prob))

Базовый блок c параметром R (repeat) и residual connection

In [None]:
class MainBlock(nn.Module):
    """Stack of ``repeat`` conv-norm-activation-dropout units with an optional residual.

    Each unit is ``get_conv_bn_layer`` followed by ``get_act_dropout_layer``.
    When ``residual`` is true, a 1x1 conv + norm of the block input is added
    to the stack output. A final activation + dropout is applied in all cases.

    NOTE(review): the last unit of the stack already ends with activation and
    dropout, so both are applied twice on the way out — confirm this is
    intended. The residual branch always uses stride 1, so it will not match
    the stack output length when ``stride > 1`` — confirm such configs are
    never combined with ``residual=True``.

    Args:
        inplanes: Input channel count.
        planes: Output channel count of every unit.
        repeat: Number of units in the stack.
        kernel_size, stride, dilation: Convolution hyper-parameters, also used
            to derive the padding via ``get_same_padding``.
        residual: Whether to add the residual branch.
        dropout: Dropout probability of every activation/dropout pair.
        activation: Activation spec passed to ``get_act_dropout_layer``.
        groups, separable, normalization, norm_groups: Passed to
            ``get_conv_bn_layer`` (the residual branch only receives the two
            normalization arguments).
    """

    def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1, residual=True,
             dilation=1, dropout=0.2, activation='relu',
             groups=1, separable=False, normalization="batch",
             norm_groups=1):
        super().__init__()
        padding_val = get_same_padding(kernel_size, stride, dilation)

        temp_planes = inplanes
        net = []
        for _ in range(repeat):
            net.append(
                get_conv_bn_layer(
                    temp_planes,
                    planes,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    padding=padding_val,
                    groups=groups,
                    separable=separable,
                    normalization=normalization,
                    norm_groups=norm_groups)
            )
            net.append(
                get_act_dropout_layer(dropout, activation)
            )
            # Only the first unit changes the channel count.
            temp_planes = planes
        self.net = nn.Sequential(*net)
        self.residual = residual
        if self.residual:
            self.residual_layer = get_conv_bn_layer(
                                inplanes,
                                planes,
                                kernel_size=1,
                                normalization=normalization,
                                norm_groups=norm_groups)
        self.out = get_act_dropout_layer(dropout, activation)

    def forward(self, x):
        out = self.net(x)
        if self.residual:
            # Out-of-place add: never mutate a tensor autograd may still need.
            out = out + self.residual_layer(x)
        return self.out(out)

Padding

In [None]:
def get_same_padding(kernel_size, stride, dilation):
    """Return the padding that keeps a 1-D convolution length-preserving.

    A dilated kernel spans ``dilation * (kernel_size - 1) + 1`` samples, so
    half of ``dilation * (kernel_size - 1)`` is needed on each side. For
    dilation 2 (any kernel) and dilation 3 with an odd kernel this equals the
    previous ``(dilation * kernel_size) // 2 - 1``; that formula over-padded
    from dilation 4 upwards. With ``dilation == 1`` the result is
    ``kernel_size // 2``.

    Raises:
        ValueError: If both ``stride`` and ``dilation`` are greater than 1.
    """
    if stride > 1 and dilation > 1:
        raise ValueError("Only stride OR dilation may be greater than 1")
    if dilation > 1:
        return (dilation * (kernel_size - 1)) // 2
    return kernel_size // 2

Инициализация весов

In [None]:
def init_weights(m, mode='xavier_uniform'):
    """Re-initialize a module in place; meant for ``nn.Module.apply``.

    ``nn.Conv1d`` weights are filled according to ``mode`` (Xavier with gain
    1.0 or Kaiming for ReLU, uniform or normal). ``nn.BatchNorm1d`` running
    statistics and affine parameters are reset. Other modules are left alone.

    Raises:
        ValueError: If ``m`` is a ``Conv1d`` and ``mode`` is unknown.
    """
    if isinstance(m, nn.BatchNorm1d):
        if m.track_running_stats:
            m.running_mean.zero_()
            m.running_var.fill_(1)
            m.num_batches_tracked.zero_()
        if m.affine:
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        return

    if not isinstance(m, nn.Conv1d):
        return

    if mode in ('xavier_uniform', 'xavier_normal'):
        fill = nn.init.xavier_uniform_ if mode == 'xavier_uniform' else nn.init.xavier_normal_
        fill(m.weight, gain=1.0)
    elif mode in ('kaiming_uniform', 'kaiming_normal'):
        fill = nn.init.kaiming_uniform_ if mode == 'kaiming_uniform' else nn.init.kaiming_normal_
        fill(m.weight, nonlinearity="relu")
    else:
        raise ValueError("Unknown Initialization mode: {0}".format(mode))

Сложим все B слоев вместе

In [None]:
class QuartzNet(nn.Module):
    """QuartzNet-style encoder of ``MainBlock`` layers plus a 1x1 conv classifier.

    Args:
        model_config: Iterable of per-block dicts. Required keys: ``filters``,
            ``repeat``, ``kernel``, ``stride``, ``dilation``. Optional keys:
            ``dropout`` (default 0.0), ``residual`` (default True),
            ``groups`` (default 1), ``separable`` (default False).
        feat_in: Number of input feature channels (before frame splicing).
        vocab_size: Number of output classes.
        activation: Activation spec forwarded to every block.
        normalization_mode: Normalization name forwarded to every block.
        norm_groups: Group count forwarded to every block.
        frame_splicing: Multiplier applied to ``feat_in``.
        init_mode: Mode passed to ``init_weights`` for all sub-modules.
        **kwargs: Accepted and ignored.

    Attributes:
        stride: Product of the ``stride`` values of all blocks.
    """

    def __init__(
            self,
            model_config,
            feat_in,
            vocab_size,
            activation='relu',
            normalization_mode="batch",
            norm_groups=-1,
            frame_splicing=1,
            init_mode='xavier_uniform',
            **kwargs
    ):
        super().__init__()

        feat_in = feat_in * frame_splicing
        self.stride = 1

        layers = []
        for lcfg in model_config:
            self.stride *= lcfg['stride']
            layers.append(
                MainBlock(feat_in,
                    lcfg['filters'],
                    repeat=lcfg['repeat'],
                    kernel_size=lcfg['kernel'],
                    stride=lcfg['stride'],
                    dilation=lcfg['dilation'],
                    dropout=lcfg.get('dropout', 0.0),
                    residual=lcfg.get('residual', True),
                    groups=lcfg.get('groups', 1),
                    separable=lcfg.get('separable', False),
                    normalization=normalization_mode,
                    norm_groups=norm_groups,
                    activation=activation))
            # The next block consumes what this one produces.
            feat_in = lcfg['filters']

        self.encoder = nn.Sequential(*layers)
        # Size the classifier from the last block's filters instead of a
        # hard-coded 1024, so other configs work too.
        self.classify = nn.Conv1d(feat_in, vocab_size,
                      kernel_size=1, bias=True)
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, audio_signal):
        """Return per-frame class logits, laid out as batch x classes x time."""
        feat = self.encoder(audio_signal)
        # BxCxT
        return self.classify(feat)

    def load_weights(self, path, map_location='cpu'):
        """Load a state dict non-strictly and print the key-matching report."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        weights = torch.load(path, map_location=map_location)
        print(self.load_state_dict(weights, strict=False))

In [None]:
# Block layout passed to QuartzNet as `model_config`: one dict per MainBlock.
# The first block downsamples (stride 2), the next five repeat 5 times with a
# residual connection, followed by a dilated block and a final 1x1 block.
_quartznet5x5_config = [
    {'filters': 256, 'repeat': 1, 'kernel': 33, 'stride': 2, 'dilation': 1, 'dropout': 0.2, 'residual': False, 'separable': True},

    {'filters': 256, 'repeat': 5, 'kernel': 33, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': True, 'separable': True},

    {'filters': 256, 'repeat': 5, 'kernel': 39, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': True, 'separable': True},

    {'filters': 512, 'repeat': 5, 'kernel': 51, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': True, 'separable': True},

    {'filters': 512, 'repeat': 5, 'kernel': 63, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': True, 'separable': True},

    {'filters': 512, 'repeat': 5, 'kernel': 75, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': True, 'separable': True},

    {'filters': 512, 'repeat': 1, 'kernel': 87, 'stride': 1, 'dilation': 2, 'dropout': 0.2, 'residual': False, 'separable': True},

    {'filters': 1024, 'repeat': 1, 'kernel': 1, 'stride': 1, 'dilation': 1, 'dropout': 0.2, 'residual': False, 'separable': False}
]

In [None]:
# Build the network once with 64 input features and 28 output classes.
scratch = QuartzNet(_quartznet5x5_config, 64, 28)

## Decoder

In [None]:
class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.
    Arguments:
        bpe: tokenizer exposing ``vocab()`` (list of subwords) and
            ``id_to_subword(int)``.
        blank_index (int, optional): index of the CTC blank token. Defaults to 0.
        space_simbol (str, optional): subword that marks a word boundary.
            Defaults to '▁'.
    Raises:
        ValueError: if ``space_simbol`` is not in the vocabulary.
    """

    def __init__(self, bpe, blank_index=0, space_simbol='▁'):
        self.labels = labels = bpe.vocab()
        print(labels)
        self.int_to_char = bpe.id_to_subword
        self.blank_index = blank_index
        self.space_simbol = space_simbol
        if self.space_simbol not in labels:
            raise ValueError(
                f"space symbol {space_simbol!r} not found in the BPE vocabulary")
        self.space_index = labels.index(self.space_simbol)

    def _split_words(self, sentence):
        """Split on the space symbol, dropping the empty pieces that a leading
        or doubled separator produces."""
        return [word for word in sentence.split(self.space_simbol) if word]

    def wer(self, s1, s2):
        """
        Computes the Word Error Rate, defined as the edit distance between the
        two provided sentences after tokenizing to words, divided by the number
        of words in s1 (at least 1, so an empty s1 does not divide by zero).
        Arguments:
            s1 (string): reference sentence, words separated by the space symbol
            s2 (string): hypothesis sentence, words separated by the space symbol
        """
        ref_words = self._split_words(s1)
        hyp_words = self._split_words(s2)

        # The Levenshtein package only accepts strings, so map every distinct
        # word to one character.
        word2char = {word: chr(i) for i, word in enumerate(set(ref_words + hyp_words))}
        ref = ''.join(word2char[w] for w in ref_words)
        hyp = ''.join(word2char[w] for w in hyp_words)

        return Lev.distance(ref, hyp) / max(len(ref_words), 1)

    def cer(self, s1, s2):
        """
        Computes the Character Error Rate: the edit distance between the two
        sentences with space symbols removed, divided by the length of s1
        with space symbols removed (at least 1).
        Arguments:
            s1 (string): reference sentence
            s2 (string): hypothesis sentence
        """
        ref = s1.replace(self.space_simbol, '')
        hyp = s2.replace(self.space_simbol, '')
        # Normalize by the same string the distance was computed on.
        return Lev.distance(ref, hyp) / max(len(ref), 1)

    def convert_to_strings(self, sequences, sizes=None, remove_repetitions=False, return_offsets=False):
        """Given a list of numeric sequences, returns the corresponding strings.

        ``sizes`` optionally limits how many leading elements of each sequence
        are decoded. With ``return_offsets`` a ``(strings, offsets)`` tuple is
        returned instead of just the strings.
        """
        strings = []
        offsets = [] if return_offsets else None
        for idx, sequence in enumerate(sequences):
            seq_len = sizes[idx] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(sequence, seq_len, remove_repetitions)
            strings.append(string)  # We only return one path
            if return_offsets:
                offsets.append(string_offsets)
        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self, sequence, size, remove_repetitions=False):
        """Decode the first ``size`` ids of ``sequence`` into a string.

        Blank tokens are dropped. With ``remove_repetitions`` a token equal to
        the one right before it is dropped as well. Returns the string and an
        int tensor with the positions of the kept tokens.
        """
        blank = self.int_to_char(self.blank_index)
        pieces = []
        offsets = []
        previous = None
        for i in range(size):
            char = self.int_to_char(sequence[i].item())
            is_repeat = remove_repetitions and i != 0 and char == previous
            previous = char
            if char == blank or is_repeat:
                continue
            pieces.append(char)
            offsets.append(i)
        return ''.join(pieces), torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                            is the probability of character c at time t
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        """
        raise NotImplementedError


class BeamCTCDecoder(Decoder):
    """CTC beam-search decoder backed by the ``ctcdecode`` package.

    Arguments:
        bpe: tokenizer, see ``Decoder``.
        lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
        num_processes: forwarded to ``ctcdecode.CTCBeamDecoder``
            (``lm_path`` as ``model_path``).
        blank_index (int, optional): index of the CTC blank token; used both
            by the base class and as ``blank_id`` of the beam decoder.
    Raises:
        ImportError: if ``ctcdecode`` is not installed.
    """

    def __init__(self, bpe, lm_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
                 num_processes=4, blank_index=0):
        # Pass blank_index on; it used to be silently replaced by the default.
        super().__init__(bpe=bpe, blank_index=blank_index)
        try:
            from ctcdecode import CTCBeamDecoder
        except ImportError as err:
            raise ImportError("BeamCTCDecoder requires the ctcdecode package.") from err
        self._decoder = CTCBeamDecoder(self.labels,
                model_path=lm_path, alpha=alpha, beta=beta, cutoff_top_n=cutoff_top_n,
                cutoff_prob=cutoff_prob, beam_width=beam_width, num_processes=num_processes, blank_id=self.blank_index)

    def convert_to_strings_ctc(self, out, seq_len):
        """Turn beam results into nested lists of strings (batch x beams)."""
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(self.int_to_char(x.item()) for x in utt[0:size])
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results

    def convert_tensor_ctc(self, offsets, sizes):
        """Trim every beam's offsets to its length (batch x beams of tensors)."""
        results = []
        for b, batch in enumerate(offsets):
            utterances = []
            for p, utt in enumerate(batch):
                size = sizes[b][p]
                if size > 0:
                    utterances.append(utt[0:size])
                else:
                    utterances.append(torch.tensor([], dtype=torch.int))
            results.append(utterances)
        return results

    def decode(self, probs, sizes=None):
        """
        Decodes probability output using ctcdecode package.
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                            is the probability of character c at time t
            sizes: Size of each sequence in the mini-batch
        Returns:
            string: sequences of the model's best guess for the transcription
        """
        probs = probs.cpu()
        out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes)
        strings = self.convert_to_strings_ctc(out, seq_lens)
        # Keep only the first beam of every batch item.
        return [item[0] for item in strings]


class GreedyDecoder(Decoder):
    """Best-path CTC decoder: per-step argmax, then collapse repeats and blanks."""

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.
        Arguments:
            probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription on inputs
        """
        best_ids = torch.max(probs, 2)[1]
        best_ids = best_ids.view(best_ids.size(0), best_ids.size(1))
        decoded = self.convert_to_strings(
            best_ids, sizes, remove_repetitions=True, return_offsets=True)
        # The offsets in decoded[1] are not part of this method's result.
        return decoded[0]

# Train

In [None]:
# Seed the RNGs, then train and load the BPE tokenizer at notebook level.
# `config` is a None placeholder that is only passed through.
fix_seeds(42)
config = None
bpe = prepare_bpe(config)

In [None]:
def train(config):
    """Train QuartzNet-5x5 with CTC on a LibriSpeech subset and keep the best checkpoint.

    Every 50th utterance of train-clean-100 is used for training and every
    50th of dev-clean for validation. Per-sample transforms run inside the
    DataLoader workers; mel-spectrogram extraction, masking and padding run
    on the training device. After each epoch the model is validated and its
    weights are written to ``logs/`` when the validation WER improves.

    Args:
        config: Passed through to ``prepare_bpe`` and ``get_dataset``.
    """
    fix_seeds(seed=42)
    num_epoches = 2
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Asking for more workers than there are CPUs makes DataLoader warn.
    num_workers = min(16, os.cpu_count() or 1)
    bpe = prepare_bpe(config)

    transforms_train = Compose([
            TextPreprocess(),
            ToNumpy(),
            BPEtexts(bpe=bpe, dropout_prob=0.05),
            AudioSqueeze(),
            AddGaussianNoise(
                min_amplitude=0.001,
                max_amplitude=0.015,
                p=0.5
            ),
            TimeStretch(
                min_rate=0.8,
                max_rate=1.25,
                p=0.5
            ),
            PitchShift(
                min_semitones=-4,
                max_semitones=4,
                p=0.5
            )
    ])

    batch_transforms_train = Compose([
            ToGpu(device),
            NormalizedMelSpectrogram(
                sample_rate=16000,
                n_mels=64,
                normalize='touniform'
            ).to(device),
            MaskSpectrogram(
                probability=0.5,
                time_mask_max_percentage=0.05,
                frequency_mask_max_percentage=0.15
            ),
            AddLengths(),
            Pad()
    ])

    transforms_val = Compose([
            TextPreprocess(),
            ToNumpy(),
            BPEtexts(bpe=bpe),
            AudioSqueeze()
    ])

    batch_transforms_val = Compose([
            ToGpu(device),
            NormalizedMelSpectrogram(
                sample_rate=16000,
                n_mels=64,
                normalize='touniform'
            ).to(device),
            AddLengths(),
            Pad()
    ])

    # load datasets
    train_dataset = get_dataset(config, transforms=transforms_train, part='train')
    val_dataset = get_dataset(config, transforms=transforms_val, part='val')

    # take subsets: every 50th utterance
    trainsubset = torch.utils.data.Subset(train_dataset, list(range(0, len(train_dataset), 50)))
    valsubset = torch.utils.data.Subset(val_dataset, list(range(0, len(val_dataset), 50)))

    # Shuffle the training data so batches are not drawn in dataset order.
    train_dataloader = DataLoader(trainsubset, num_workers=num_workers,
                batch_size=32, shuffle=True, collate_fn=no_pad_collate)

    val_dataloader = DataLoader(valsubset, num_workers=num_workers,
                batch_size=1, collate_fn=no_pad_collate)

    # Size the output layer from the tokenizer so the two cannot drift apart.
    model = QuartzNet(_quartznet5x5_config, 64, len(bpe.vocab()))
    # Move to the device before the optimizer captures the parameters.
    model = model.to(device)

    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0001)
    num_steps = len(train_dataloader) * num_epoches
    print('num steps:', num_steps)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    decoder = GreedyDecoder(bpe=bpe)

    def forward_loss(batch):
        """Run the model on a padded batch; return logits, output lengths and CTC loss."""
        logits = model(batch['audio'])
        output_length = torch.ceil(batch['input_lengths'].float() / model.stride).int()
        loss = criterion(logits.permute(2, 0, 1).log_softmax(dim=2), batch['text'],
                         output_length, batch['target_lengths'])
        return logits, output_length, loss

    def error_rates(logits, output_length, batch):
        """Return (mean WER, mean CER) of greedy decoding against the batch targets."""
        # Decode on CPU copies: the decoder reads one element at a time.
        with torch.no_grad():
            probs = logits.detach().permute(0, 2, 1).softmax(dim=2).cpu()
        target_strings = decoder.convert_to_strings(
            batch['text'].cpu(), sizes=batch['target_lengths'].tolist())
        # Limit decoding to the valid frames of every item, not the padding.
        decoded_output = decoder.decode(probs, sizes=output_length.tolist())
        pairs = list(zip(target_strings, decoded_output))
        wer = np.mean([decoder.wer(true, pred) for true, pred in pairs])
        cer = np.mean([decoder.cer(true, pred) for true, pred in pairs])
        return wer, cer

    prev_wer = float('inf')
    for epoch_idx in tqdm(range(num_epoches)):
        # train:
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            batch = batch_transforms_train(batch)
            optimizer.zero_grad()
            logits, output_length, loss = forward_loss(batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 15)
            optimizer.step()
            lr_scheduler.step()

            if batch_idx % 20 == 0:
                wer, cer = error_rates(logits, output_length, batch)
                tqdm.write(f'epoch {epoch_idx} batch {batch_idx}: '
                           f'loss={loss.item():.4f} wer={wer:.4f} cer={cer:.4f}')

        # validate:
        model.eval()
        val_stats = defaultdict(list)
        for batch in val_dataloader:
            batch = batch_transforms_val(batch)
            with torch.no_grad():
                logits, output_length, loss = forward_loss(batch)
            wer, cer = error_rates(logits, output_length, batch)
            val_stats['val_loss'].append(loss.item())
            val_stats['wer'].append(wer)
            val_stats['cer'].append(cer)
        val_stats = {k: np.mean(v) for k, v in val_stats.items()}
        tqdm.write(f'epoch {epoch_idx} validation: {val_stats}')

        # save model, TODO: save optimizer:
        # An empty validation loader yields no 'wer'; skip the checkpoint then.
        val_wer = val_stats.get('wer')
        if val_wer is not None and val_wer < prev_wer:
            os.makedirs('logs', exist_ok=True)
            prev_wer = val_wer
            torch.save(
                model.state_dict(),
                os.path.join('logs', f'model_{epoch_idx}_{prev_wer}.pth')
            )

In [None]:
# Run the full training loop; the config argument is a None placeholder.
train(None)

  cpuset_checked))


QuartzNet(
  (encoder): Sequential(
    (0): MainBlock(
      (net): Sequential(
        (0): Sequential(
          (0): Conv1d(64, 64, kernel_size=(33,), stride=(2,), padding=(16,), groups=64, bias=False)
          (1): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
          (2): BatchNorm1d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): Sequential(
          (0): ReLU()
          (1): Dropout(p=0.2, inplace=False)
        )
      )
      (out): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.2, inplace=False)
      )
    )
    (1): MainBlock(
      (net): Sequential(
        (0): Sequential(
          (0): Conv1d(256, 256, kernel_size=(33,), stride=(1,), padding=(16,), groups=256, bias=False)
          (1): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
          (2): BatchNorm1d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): Sequential(
          (0): ReLU()
 

100%|██████████| 2/2 [03:47<00:00, 113.78s/it]
