In [2]:
import json
import os
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchaudio
from comet_ml import Experiment
from ctcdecode import CTCBeamDecoder

# Dataset Downloading

In [3]:
DATASET_FOLDER = 'LibriSpeech'

In [4]:
def download_data(dataset_folder, download=False):
    if not os.path.isdir("./data"):
        os.makedirs("./data")

    train_dataset = torchaudio.datasets.LIBRISPEECH(root='./data', folder_in_archive=dataset_folder, url='train-clean-100', download=download)
    test_dataset = torchaudio.datasets.LIBRISPEECH(root='./data', folder_in_archive=dataset_folder, url='test-clean', download=download)
    return train_dataset, test_dataset

In [5]:
download_data(DATASET_FOLDER, True)

(<torchaudio.datasets.librispeech.LIBRISPEECH at 0x7ff0f125f610>,
 <torchaudio.datasets.librispeech.LIBRISPEECH at 0x7ff1cf56e280>)

In [6]:
train_dataset, test_dataset = download_data(DATASET_FOLDER)
print(len(train_dataset), len(test_dataset))

28539 2620


# Data Processing

In [7]:
char_map_str = """
' 0
<BLANK> 1
a 2
b 3
c 4
d 5
e 6
f 7
g 8
h 9
i 10
j 11
k 12
l 13
m 14
n 15
o 16
p 17
q 18
r 19
s 20
t 21
u 22
v 23
w 24
x 25
y 26
z 27
"""

BLANK_LABEL = None  # to be assigned in TextTranform constructor

class TextTransform:
    """ Maps characters to their indices, and vice versa """
    def __init__(self):
        self.char_map = {}
        self.index_map = {}
        global BLANK_LABEL
        for line in char_map_str.strip().split('\n'):
            ch, index = line.split()
            self.char_map[ch] = int(index)
            self.index_map[int(index)] = ch
            if not BLANK_LABEL and ch == '<BLANK>':
                BLANK_LABEL = int(index)
        self.index_map[BLANK_LABEL] = ' '

    def text_to_int(self, text: list[str]):
        """ Use a character map and convert text to an integer sequence """
        int_sequence = []
        for c in text:
            if c == ' ':
                ind = self.char_map['<BLANK>']
            else:
                ind = self.char_map[c]
            int_sequence.append(ind)
        return int_sequence

    def int_to_text(self, labels: list[int]):
        """ Use a character map and convert integer labels to an text sequence """
        string = []
        for i in labels:
            string.append(self.index_map[i])
        return ''.join(string).replace('<BLANK>', ' ').strip()


# TODO: SpecAugment (masking augmentations)
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
    torchaudio.transforms.TimeMasking(time_mask_param=35)
)

valid_audio_transforms = torchaudio.transforms.MelSpectrogram()

text_transform = TextTransform()



In [8]:
print(BLANK_LABEL)

1


In [9]:
# testing the code above
word_start = "yes"
index = text_transform.text_to_int(word_start)
word_recovered = text_transform.int_to_text(index)

print(word_start, "-->", index, "-->", word_recovered)

yes --> [26, 6, 20] --> yes


The function __data_processing()__ will be called in Data Loaders' __collate_fn__.

Data is represented as tuple(wave, sample_rate, utterance (label), speaker id, utterance number)

In [10]:
sample = train_dataset.__getitem__(n=2)
sample

(tensor([[ 0.0052,  0.0074,  0.0113,  ..., -0.0007, -0.0039, -0.0058]]),
 16000,
 "FOR NOT EVEN A BROOK COULD RUN PAST MISSUS RACHEL LYNDE'S DOOR WITHOUT DUE REGARD FOR DECENCY AND DECORUM IT PROBABLY WAS CONSCIOUS THAT MISSUS RACHEL WAS SITTING AT HER WINDOW KEEPING A SHARP EYE ON EVERYTHING THAT PASSED FROM BROOKS AND CHILDREN UP",
 103,
 1240,
 2)

In [11]:
def data_processing(data, data_type="train"):
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    for (waveform, _, utterance, _, _, _) in data:
        if data_type == 'train':
            spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        else:
            spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        # labels are lists of integer character ids
        label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
        labels.append(label)
        # input_lengths, label_lengths are used in loss function
        input_lengths.append(spec.shape[0]//2)
        label_lengths.append(len(label))

    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths

In [12]:
# testing
data_processing((sample,))

(tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [1.0323e-01, 5.5455e-02, 5.8218e-02,  ..., 4.6377e-02,
            4.9086e-02, 5.8545e-02],
           [5.5582e-01, 2.9858e-01, 3.1346e-01,  ..., 2.4971e-01,
            2.6429e-01, 3.1522e-01],
           ...,
           [6.0961e-05, 9.4463e-05, 1.3510e-04,  ..., 1.9498e-04,
            4.0945e-04, 6.1479e-05],
           [4.1905e-05, 1.5869e-04, 1.1976e-04,  ..., 1.0217e-04,
            3.5421e-04, 8.3099e-05],
           [2.9035e-04, 6.6924e-05, 1.1111e-04,  ..., 8.5560e-05,
            3.0291e-05, 4.2126e-05]]]]),
 tensor([[ 7., 16., 19.,  1., 15., 16., 21.,  1.,  6., 23.,  6., 15.,  1.,  2.,
           1.,  3., 19., 16., 16., 12.,  1.,  4., 16., 22., 13.,  5.,  1., 19.,
          22., 15.,  1., 17.,  2., 20., 21.,  1., 14., 10., 20., 20., 22., 20.,
           1., 19.,  2.,  4.,  9.,  6., 13.,  1., 13., 26., 15.,  5.,  6.,  0.,
          20.,  1.,  5., 16., 16., 19.,  1

## CER and WER metrics

In [13]:
def _levenshtein_distance(ref, hyp):
    """Levenshtein distance is a string metric for measuring the difference
    between two sequences. Informally, the levenshtein disctance is defined as
    the minimum number of single-character edits (substitutions, insertions or
    deletions) required to change one word into the other. We can naturally
    extend the edits to word level when calculate levenshtein disctance for
    two sentences.
    """
    m = len(ref)
    n = len(hyp)

    # special case
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # use O(min(m, n)) space
    distance = np.zeros((2, n + 1), dtype=np.int32)

    # initialize distance matrix
    for j in range(0,n + 1):
        distance[0][j] = j

    # calculate levenshtein distance
    for i in range(1, m + 1):
        prev_row_idx = (i - 1) % 2
        cur_row_idx = i % 2
        distance[cur_row_idx][0] = i
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
            else:
                s_num = distance[prev_row_idx][j - 1] + 1
                i_num = distance[cur_row_idx][j - 1] + 1
                d_num = distance[prev_row_idx][j] + 1
                distance[cur_row_idx][j] = min(s_num, i_num, d_num)

    return distance[m % 2][n]


def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence in word-level.
    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether case-sensitive or not.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Levenshtein distance and word number of reference sentence.
    :rtype: list
    """
    if ignore_case == True:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    ref_words = reference.split(delimiter)
    hyp_words = hypothesis.split(delimiter)

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)


def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence in char-level.
    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether case-sensitive or not.
    :type ignore_case: bool
    :param remove_space: Whether remove internal space characters
    :type remove_space: bool
    :return: Levenshtein distance and length of reference sentence.
    :rtype: list
    """
    if ignore_case == True:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    join_char = ' '
    if remove_space == True:
        join_char = ''

    reference = join_char.join(filter(None, reference.split(' ')))
    hypothesis = join_char.join(filter(None, hypothesis.split(' ')))

    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)


def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER). WER compares reference text and
    hypothesis text in word-level. WER is defined as:
    .. math::
        WER = (Sw + Dw + Iw) / Nw
    where
    .. code-block:: text
        Sw is the number of words subsituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference
    We can use levenshtein distance to calculate WER. Please draw an attention
    that empty items will be removed when splitting sentences by delimiter.
    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether case-sensitive or not.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If word number of reference is zero.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
                                         delimiter)

    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    wer = float(edit_distance) / ref_len
    return wer


def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate charactor error rate (CER). CER compares reference text and
    hypothesis text in char-level. CER is defined as:
    .. math::
        CER = (Sc + Dc + Ic) / Nc
    where
    .. code-block:: text
        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted
        Nc is the number of characters in the reference
    We can use levenshtein distance to calculate CER. Chinese input should be
    encoded to unicode. Please draw an attention that the leading and tailing
    space characters will be truncated and multiple consecutive space
    characters in a sentence will be replaced by one space character.
    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether case-sensitive or not.
    :type ignore_case: bool
    :param remove_space: Whether remove internal space characters
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference length is zero.
    """
    edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
                                         remove_space)

    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    cer = float(edit_distance) / ref_len
    return cer

# Building a Model

## NN Architecture

We use Layer Normalization, not Batch Normalization, because BN is hard to use with sequence data, with small batch sizes, and it's hard to paralellize a NN with BN.

This is due to the dependency on batches. Layer Normalization removes this dependency. It computes the normalization based on the layers inside of the batches.

LN briefly: Input values in all neurons in the same layer are normalized for each data sample.
So, all values in neurons of the same layer will have the same mean and variance.

LN is can deal with sequence data, doesn't depend on batch size, and is easily paralellized.
However, LN sometimes performs worse than BN with CNNs.

In [14]:
class CNNLayerNorm(nn.Module):
    """Layer Normalization"""
    
    def __init__(self, n_features):
        super(CNNLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=n_features)
        """About normalized_shape parameter of nn.LayerNorm:
        If a single integer is used, it is treated as a singleton list, and this module will normalize
        over the last dimension which is expected to be of that specific size.
        """

    def forward(self, x):
        # x (batch, channel, feature, time)
        x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
        x = self.layer_norm(x)
        return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) 

class ResidualCNN(nn.Module):
    """ Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
        except with layer norm instead of batch norm """
    
    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_features):
        super(ResidualCNN, self).__init__()

        self.layer_norm1 = CNNLayerNorm(n_features)
        self.dropout1 = nn.Dropout(dropout)
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)

        self.layer_norm2 = CNNLayerNorm(n_features)
        self.dropout2 = nn.Dropout(dropout)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)

    def forward(self, x):
        residual = x  # (batch, channel, feature, time)
        out = self.layer_norm1(x)
        out = F.gelu(out)
        out = self.dropout1(out)
        out = self.cnn1(out)
        out = self.layer_norm2(out)
        out = F.gelu(out)
        out = self.dropout2(out)
        out = self.cnn2(out)
        out += residual
        return out # (batch, channel, feature, time)
        
class BidirectionalGRU(nn.Module):

    def __init__(self, rnn_type, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()

        if rnn_type == "GRU":
            self.rnn_cell = nn.GRU
        elif rnn_type == "LSTM":
            self.rnn_cell = nn.LSTM

        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.BiGRU = self.rnn_cell(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.layer_norm(x)
        out = F.gelu(out)
        out, _ = self.BiGRU(out)
        out = self.dropout(out)
        return out


class SpeechRecognitionModel(nn.Module):
    """Speech Recognition Model Inspired by DeepSpeech 2"""

    def __init__(self, rnn_type, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_features, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        n_features = n_features // 2

        # TODO: purpose of this conv layer
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)  # cnn for extracting heirachal features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_features=n_features) 
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_features*32, rnn_dim)
        """TODO: как я понял, у нас число фичей rnn_dim берётся одиночное для первой GRU, т.к. к ней не перетекает скрытое состояние.
        Каждая следующая GRU получает 2*rnn_dim фичей, т.к. к самим фичам конкатенируется скрытое состояние такой же размерности (hidden_size=rnn_dim)
        """
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_type=rnn_type, 
                             rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2) # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x

# Decoder

## Greedy Decoder

In [15]:
def GreedyDecoder(output, labels, label_lengths, blank_label=BLANK_LABEL, collapse_repeated=True):
    output = F.log_softmax(output, dim=2)

    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    targets = []
    
    for i, args in enumerate(arg_maxes):
        decode = []
        targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
        blank_ctr = 0
        for j, index in enumerate(args):
            if collapse_repeated and index == blank_label:
                blank_ctr += 1
            else:
                if blank_ctr > 2:
                    decode.append(blank_label)
                blank_ctr = 0

                if collapse_repeated and j != 0 and index == args[j - 1]:
                    continue
                
                decode.append(index.item())
        
        decodes.append(text_transform.int_to_text(decode))

    return decodes, targets

## Beam Search Decoder

In [16]:
characters = list(text_transform.char_map.keys())

beam_decoder = CTCBeamDecoder(
    labels=characters,
    cutoff_top_n=len(characters),   # do not discard any characters from beam search
    cutoff_prob=1.0,    # cutoff probability in pruning (1.0 means no pruning)
    beam_width=100,
    num_processes=4,
    blank_id=28,
    log_probs_input=True
)

In [17]:
def BeamSearchDecoder(output):
    top_beams = []
    decodings = []

    beam_results, beam_scores, timesteps, out_lens = beam_decoder.decode(output)
    
    for i in range(len(beam_results)):
        top_beam = beam_results[i][0][:out_lens[i][0]]
        decoding = text_transform.int_to_text(top_beam.tolist())

        top_beams.append(top_beam)
        decodings.append(decoding)
        
    return top_beams, decodings

# Training

In [18]:
def get_device():
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")
    return device, use_cuda

In [19]:
class IterMeter(object):
    """Keeps track of total iterations. Used for Comet.ml"""
    def __init__(self):
        self.val = 0

    def step(self):
        self.val += 1

    def get(self):
        return self.val


def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment):
    model.train()
    data_len = len(train_loader.dataset)
    with experiment.train():
        for batch_idx, _data in enumerate(train_loader):
            spectrograms, labels, input_lengths, label_lengths = _data 
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            optimizer.zero_grad()

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1) # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()

            experiment.log_metric('loss', loss.item(), step=iter_meter.get())
            experiment.log_metric('learning_rate', scheduler.get_last_lr(), step=iter_meter.get())

            optimizer.step()
            scheduler.step()   # scheduler step for oncecycle scheduler
            iter_meter.step()

            if batch_idx % 100 == 0 or batch_idx == data_len:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(spectrograms), data_len,
                    100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader, criterion, epoch, iter_meter, experiment):
    print('\nEvaluating…')
    model.eval()
    
    test_loss = 0
    test_cer_greedy, test_wer_greedy = [], []
    test_cer_beam, test_wer_beam = [], []

    with experiment.test():
        with torch.no_grad():
            for I, _data in enumerate(test_loader):
                spectrograms, labels, input_lengths, label_lengths = _data 
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                output = model(spectrograms)  # (batch, time, n_class)
                # output = F.log_softmax(output, dim=2)
                # output = output.transpose(0, 1) # (time, batch, n_class)

                output_for_loss = F.log_softmax(output, dim=2).transpose(0, 1)
                loss = criterion(output_for_loss, labels, input_lengths, label_lengths)
                test_loss += loss.item() / len(test_loader)

                greedy_preds, decoded_targets = GreedyDecoder(output, labels, label_lengths)
                beams, beam_preds = BeamSearchDecoder(output)

                for j in range(len(greedy_preds)):
                    test_cer_greedy.append(cer(decoded_targets[j], greedy_preds[j]))
                    test_wer_greedy.append(wer(decoded_targets[j], greedy_preds[j]))

                for j in range(len(beam_preds)):
                    test_cer_beam.append(cer(decoded_targets[j], beam_preds[j]))
                    test_wer_beam.append(wer(decoded_targets[j], beam_preds[j]))

    experiment.log_metric('test_loss', test_loss, step=iter_meter.get())
    print('Test set: Average loss: {:.4f}'.format(test_loss))

    avg_cer_greedy = sum(test_cer_greedy)/len(test_cer_greedy)
    avg_wer_greedy = sum(test_wer_greedy)/len(test_wer_greedy)
    experiment.log_metric('cer_greedy', avg_cer_greedy, step=iter_meter.get())
    experiment.log_metric('wer_greedy', avg_wer_greedy, step=iter_meter.get())
    print('Average Greedy CER: {:4f} Average Greedy WER: {:.4f}\n'.format(avg_cer_greedy, avg_wer_greedy))

    avg_cer_beam = sum(test_cer_beam)/len(test_cer_beam)
    avg_wer_beam = sum(test_wer_beam)/len(test_wer_beam)
    experiment.log_metric('cer_beam', avg_cer_beam, step=iter_meter.get())
    experiment.log_metric('wer_beam', avg_wer_beam, step=iter_meter.get())
    print('Average Beam CER: {:4f} Average Beam WER: {:.4f}\n'.format(avg_cer_beam, avg_wer_beam))


def train_test(hparams, experiment=Experiment(api_key='dummy_key', disabled=True)):

    experiment.log_parameters(hparams)

    train_dataset, test_dataset = download_data(DATASET_FOLDER)

    device, use_cuda = get_device()

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = data.DataLoader(dataset=train_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=True,
                                collate_fn=lambda x: data_processing(x, 'train'),
                                **kwargs)
    test_loader = data.DataLoader(dataset=test_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, 'test'),
                                **kwargs)

    model = SpeechRecognitionModel(
        hparams['rnn_type'], hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        ).to(device)

    print(model)
    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

    optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
    #optimizer = optim.Adam(model.parameters(), hparams['learning_rate'])
    #optimizer = optim.SGD(model.parameters(), hparams['learning_rate'], momentum=0.1, nesterov=True)
    criterion = nn.CTCLoss(blank=BLANK_LABEL).to(device)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hparams['scheduler_step_size'], gamma=hparams['scheduler_gamma'], verbose=True)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                                            steps_per_epoch=int(len(train_loader)),
                                            epochs=hparams['epochs'],
                                            anneal_strategy='linear')

    iter_meter = IterMeter()
    for epoch in range(1, hparams['epochs'] + 1):
        train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment)
        test(model, device, test_loader, criterion, epoch, iter_meter, experiment)
        #scheduler.step()   # scheduler step for stepLR scheduler

    return model



Experiments are monitored using Comet.ml:

In [19]:
hparams = {
    "rnn_type": "LSTM",
    "n_cnn_layers": 3,
    "n_rnn_layers": 3,
    "n_feats": 128,
    "stride": 2,
    "rnn_dim": 512,
    "n_class": len(characters),
    "dropout": 0.05,
    "learning_rate": 5e-4,
    "batch_size": 8,
    "epochs": 3,
}

comet_api_key = "KddYvSKPDO9U8K1lZUIUCgHjT"
project_name = "speech-recognition"
experiment_name = f"rnn:{hparams['rnn_type']}-epochs:{hparams['epochs']}"

if comet_api_key:
    experiment = Experiment(api_key=comet_api_key, project_name=project_name, parse_args=False)
    experiment.set_name(experiment_name)
    # experiment.display()
else:
    experiment = Experiment(api_key='dummy_key', disabled=True)

COMET INFO: Experiment is live on comet.com https://www.comet.com/pigeongcc/speech-recognition/2f6228b4ba604ab2b8d72bc8b05f8b5d



Run the training loop:

In [None]:
model = train_test(hparams, experiment)

In [None]:
experiment.end()

# Serialization

In [20]:
if not os.path.isdir("./models"):
    os.makedirs("./models")

def get_model_paths(experiment_name):
    model_folder_path = f'./models/{experiment_name}'
    model_path = f'{model_folder_path}/{experiment_name}.pth'
    hparams_path = f'{model_folder_path}/{experiment_name}.json'
    return model_folder_path, model_path, hparams_path

## Saving the Model

In [21]:
def save_model(experiment_name):
	model_folder_path, model_path, hparams_path = get_model_paths(experiment_name)

	# create folders if not exist
	if not os.path.isdir(model_folder_path):
		if not os.path.isdir("./models"):
			os.makedirs("./models")
		os.makedirs(model_folder_path)
		
	torch.save(model.state_dict(), model_path)

	with open(hparams_path, 'w') as f:
		json.dump(hparams, f)

In [22]:
# save_model(experiment_name)

## Loading the Model

In [23]:
def load_model(experiment_name):
    _, model_path, hparams_path = get_model_paths(experiment_name)
    
    device, _ = get_device()

    with open(hparams_path, 'r') as f:
        hparams = json.load(f)

    if 'rnn_type' not in hparams:
        hparams['rnn_type'] = "GRU"

    model = SpeechRecognitionModel(
        hparams['rnn_type'], hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        ).to(device)

    model.load_state_dict(torch.load(model_path))

    return model, hparams

In [24]:
experiment_name = f"asr-lch-optim:adamw-scheduler:oncecycle-data:full-epochs:30"
model_loaded, hparams_loaded = load_model(experiment_name)

# Inference

In [27]:
def infer(model,
          sample_idx: int = None,
          sample_path: int = None,
          collapse_repeated: bool = True):
    device, _ = get_device()

    if sample_path is not None:
        waveform, _ = torchaudio.load(sample_path)
        sample = (waveform, None, "<INFERENCE_SAMPLE>", None, None, None)
    elif sample_idx is not None:
        sample = test_dataset.__getitem__(n=sample_idx)

    spectrogram, label, input_length, label_length = \
        data_processing((sample,))
    spectrogram, label = spectrogram.to(device), label.to(device)

    model.eval()

    output = model(spectrogram)

    greedy_pred, label = GreedyDecoder(output, label, label_length, collapse_repeated=collapse_repeated)
    beams, beam_preds = BeamSearchDecoder(output)
    
    print(f"Negative log likelihood matrix shape: {output.shape}")
    print("\nGREEDY DECODING")
    print(f"Decoded indices:\n{torch.argmax(output, dim=2)}")
    print()
    print(f"Target (len {len(label[0])}): {label}")
    print(f"Prediction (len {len(greedy_pred[0])}): {greedy_pred}")

    print("\nBEAM SEARCH DECODING")
    print(f"Top beam:\n{beams}")
    print()
    print(f"Target (len {len(label[0])}): {label}")
    print(f"Prediction (len {len(beam_preds[0])}): {beam_preds}")

In [29]:
infer(model_loaded, sample_path='data/sample.wav')

ValueError: not enough values to unpack (expected 6, got 5)

In [28]:
infer(model_loaded, sample_idx=2000)

Negative log likelihood matrix shape: torch.Size([1, 396, 28])

GREEDY DECODING
Decoded indices:
tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 24,  6,
         13, 13,  1, 13,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  1,  1,  1,  1,
          1,  1,  1,  4, 16, 16, 15,  1,  1,  1,  1,  1, 21,  7, 10, 20,  1,  4,
          6,  5,  1,  1,  1, 21,  1,  2, 21,  1,  1,  1, 21,  9,  6,  1,  1,  1,
          1,  1,  1,  3, 16,  2, 16,  5,  1,  1,  6,  5,  1,  1,  1,  1,  1, 22,
         17,  1,  1,  1,  1,  1,  1,  1,  1,  9, 16, 22,  1, 20,  6,  1,  1,  1,
          1,  1,  1,  1,  1, 14, 26, 20, 20,  1,  1,  1, 21, 21, 19, 19, 26,  1,
          1, 20,  1,  1,  1,  1,  1,  1,  1,  9,  2, 17,  1,  1, 17, 17,  6, 15,
          6,  5,  1,  1,  1,  1,  1,  1,  1, 15, 15, 16, 21,  1,  1,  1,  1,  1,
          1,  1,  2, 19, 19,  1,  1,  1, 13,  1, 13, 10,  1,  2, 19,  1,  1,  1,
          1,  1, 21,  9,  2, 15,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,