In [None]:
import logging
import torch
import torchaudio

from torch import Tensor
from typing import List


# Configure the root logger once for the whole notebook: timestamped,
# INFO-level messages with a left-aligned, fixed-width level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [None]:
class args:
    """Notebook-wide configuration, used as a plain namespace of class attributes."""

    # File the training checkpoint is written to / resumed from.
    checkpoint = '/checkpoint/vincentqb/checkpoint/checkpoint-notebook-201007-1700.pth.tar'
    # Dataset location. These two attributes are read by the dataset-loading
    # cell (args.dataset_root / args.folder_in_archive); the values match the
    # ones used when the raw training set is opened earlier in the notebook.
    dataset_root = '/datasets01/librispeech/'
    folder_in_archive = '062419'
    # Dataset splits for training, validation and (optional) testing.
    dataset_train = 'train-clean-100'
    dataset_valid = 'dev-clean'
    dataset_test = 'test-clean'
    batch_size = 128
    # The training loop iterates over range(start_epoch, max_epoch).
    start_epoch = 0
    max_epoch = 1000
    # Selects the preprocessing and the model input layout ('mfcc' or 'waveform').
    model_input_type = 'mfcc'
    print_freq = 1
    reduce_lr_valid = True
    # When True, training state is restored from `checkpoint`.
    resume = False
    # Number of DataLoader worker processes.
    workers = 0

In this tutorial, we build an automatic speech recognition (ASR) training loop using pytorch and torchaudio. The goal is thus to recognize the text from an audio recording. The dataset used is [LibriSpeech](http://www.openslr.org/12/) and the model is [Wav2Letter](https://arxiv.org/abs/1609.03193). Both are available in torchaudio.

Each data point in the dataset consists of an audio recording of someone reading a target sentence, along with supplemental information. More precisely, a data point is `(waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)`, and we only need the waveform and the utterance from it.


In [13]:
# LIBRISPEECH is imported in this cell because the cell runs before the one
# that otherwise imports it; on a fresh kernel the name would be undefined.
from torchaudio.datasets import LIBRISPEECH

# Open the (already downloaded) training split.
kwargs = {
    "url": 'train-clean-100',
    "root": '/datasets01/librispeech/',
    "folder_in_archive": '062419',
    "download": False
}
dataset_train = LIBRISPEECH(**kwargs)

For example,

In [18]:
# Unpack the first training item; `sample_rate` is reused below when the
# preprocessing transform is built.
waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = dataset_train[0]

We use only the first channel of the waveform, preprocess using the MFCC transform, and then normalize the result.

In [15]:
from torchaudio.transforms import MFCC


def get_transforms(model_input_type, sample_rate=None, melkwargs=None):
    """Build the preprocessing function applied to a raw waveform.

    Args:
        model_input_type (str): when equal to ``"mfcc"`` an MFCC transform is
            created and applied; any other value skips that step.
        sample_rate (int, optional): forwarded to ``MFCC``.
        melkwargs (dict, optional): forwarded to ``MFCC``; its ``"n_mels"``
            entry is also used as the number of MFCC coefficients. Must be
            provided when ``model_input_type`` is ``"mfcc"``.

    Returns:
        A function taking a waveform tensor, keeping its first channel,
        optionally computing MFCCs, and standardizing along the last dimension.
    """
    use_mfcc = model_input_type == "mfcc"

    if use_mfcc:
        mfcc = MFCC(
            sample_rate=sample_rate,
            n_mfcc=melkwargs["n_mels"],
            melkwargs=melkwargs,
        )

    def transforms(tensor):
        features = tensor[0, ...]  # keep the first channel only
        if use_mfcc:
            features = mfcc(features)
        # Standardize along the last dimension.
        centered = features - features.mean(-1, keepdim=True)
        return centered / features.std(-1, keepdim=True)

    return transforms


# Settings forwarded to the MFCC transform by get_transforms; "n_mels" also
# sets the number of MFCC coefficients.
melkwargs = {
    "n_fft": 400,
    "n_mels": 13,
    "hop_length": 160,
}
model_input_type = "mfcc"

# `sample_rate` comes from the data point unpacked in an earlier cell.
transforms = get_transforms(
    model_input_type,
    sample_rate,  # 16000 Hz for librispeech
    melkwargs,
)

In [17]:
# Preprocess the waveform of the first training item; the result is the cell output.
transforms(dataset_train[0][0])

tensor([[-1.4363, -1.4952, -1.4389,  ..., -0.0213,  0.0154,  0.0940],
        [ 0.1158, -0.1005, -0.1491,  ..., -1.1537, -1.1546, -1.1718],
        [ 0.4979,  0.4646,  0.4829,  ..., -0.8638, -0.8657, -0.8790],
        ...,
        [ 0.3266,  0.2912, -0.3310,  ...,  0.3128,  0.1431,  0.5881],
        [ 0.7309,  0.3005, -0.3413,  ...,  0.7735, -0.1685, -0.1231],
        [-0.7672,  0.4788, -0.0693,  ...,  0.5826,  0.2591, -0.7708]])

We also introduce data augmentation techniques from SpecAugment during training. They are available in torchaudio.

In [None]:
def get_augmentations(freq_mask, time_mask):
    """Build the augmentation function applied to preprocessed features.

    Args:
        freq_mask: ``freq_mask_param`` for ``FrequencyMasking``; a falsy value
            (e.g. 0) disables frequency masking.
        time_mask: ``time_mask_param`` for ``TimeMasking``; a falsy value
            disables time masking.

    Returns:
        A function applying the enabled maskings, frequency first and then
        time. With both disabled it returns its input unchanged.
    """
    stages = []
    if freq_mask:
        stages.append(torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask))
    if time_mask:
        stages.append(torchaudio.transforms.TimeMasking(time_mask_param=time_mask))

    def transforms(tensor):
        for stage in stages:
            tensor = stage(tensor)
        return tensor

    return transforms


# Both mask sizes are 0, so the returned function leaves its input unchanged.
augmentations = get_augmentations(freq_mask=0, time_mask=0)

In [None]:
# The augmentation function takes the preprocessed features as its argument;
# calling it with no argument raises TypeError.
augmentations(transforms(dataset_train[0][0]))

The utterance needs to be encoded into a tensor.

In [None]:
import itertools
from collections.abc import Iterable


class LanguageModel:
    """Character-level codec between strings and integer label sequences.

    Each character of ``labels`` is assigned its position in ``labels`` as
    its integer label.

    Args:
        labels: iterable of single characters (e.g. a string).
        char_blank (str): character standing for the blank label; it is
            removed from decoded strings.
        char_space (str): character separating words.
    """

    def __init__(self, labels, char_blank, char_space):

        self.char_space = char_space
        self.char_blank = char_blank

        self._decode_map = dict(enumerate(labels))
        self._encode_map = {char: index for index, char in self._decode_map.items()}

    def encode(self, listlike):
        """Encode a string into a list of labels (recursing into non-string iterables).

        Raises:
            KeyError: if a character is not part of ``labels``.
        """
        if not isinstance(listlike, str):
            return [self.encode(i) for i in listlike]
        # Plain lookup. The previous version added the blank's index to every
        # label, which is only a no-op when the blank is label 0 and otherwise
        # makes encode() disagree with decode().
        return [self._encode_map[char] for char in listlike]

    def decode(self, tensor):
        """Decode a sequence of labels (or a nested sequence of them) into text.

        Consecutive repeated labels are collapsed into one and blanks are then
        removed, so ``decode(encode(s))`` differs from ``s`` when ``s``
        contains doubled characters.
        """
        if len(tensor) > 0 and isinstance(tensor[0], Iterable):
            return [self.decode(t) for t in tensor]

        # not idempotent, since clean string
        chars = (self._decode_map[i] for i in tensor)
        collapsed = "".join(char for char, _ in itertools.groupby(chars))
        return collapsed.replace(self.char_blank, "")

    def __len__(self):
        """Number of distinct labels."""
        return len(self._encode_map)

In [None]:
# Text preprocessing

# `string` is imported in this cell so that it also runs on a fresh kernel;
# the notebook otherwise imports it only in a later cell.
import string

# Label set: blank first, then space, apostrophe and the lowercase letters.
char_blank = "*"
char_space = " "
char_apostrophe = "'"
labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
language_model = LanguageModel(labels, char_blank, char_space)

In [None]:
# `string` is imported in this cell so that it also runs on a fresh kernel;
# the notebook otherwise imports it only in a later cell.
import string

# Transforms

melkwargs = {
    "n_fft": 400,
    "n_mels": 13,
    "hop_length": 160,
}
sample_rate_original = 16000

transforms = get_transforms(args.model_input_type, sample_rate_original, melkwargs)
# Both mask sizes are 0, so no augmentation is applied.
augmentations = get_augmentations(freq_mask=0, time_mask=0)

# Text preprocessing

# Label set: blank first, then space, apostrophe and the lowercase letters.
char_blank = "*"
char_space = " "
char_apostrophe = "'"
labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
language_model = LanguageModel(labels, char_blank, char_space)

In [None]:
import pandas as pd
import pprint
import shutil
import time
from collections import defaultdict


class Logger(defaultdict):
    """Dictionary of metrics that can print, save and tabulate itself.

    Missing keys default to ``0.0`` so metrics can be accumulated with ``+=``.
    Three bookkeeping entries are maintained: ``"name"``, ``"elapsed time"``
    (seconds since construction, refreshed on every report) and
    ``"iteration"`` (number of ``flush`` calls).

    Args:
        name: stored under the ``"name"`` key.
        print_freq (int): report on every ``print_freq``-th call to ``flush``.
        disable (bool): when True nothing is printed (file and dataframe
            outputs are still written).
        filename (str, optional): file to which each report is appended.
        dataframe (pandas.DataFrame, optional): frame to which each report is
            appended as a row. Appending creates a new frame, available as
            ``self.dataframe``; the frame passed in is not modified.
    """

    def __init__(self, name, print_freq=1, disable=False, filename=None, dataframe=None):
        super().__init__(float)

        self.disable = disable
        self.print_freq = print_freq
        self.filename = filename
        self.dataframe = dataframe

        self._name = "name"
        self._time = "elapsed time"
        self._iteration = "iteration"

        # The start time lives outside the dictionary so the elapsed time can
        # be recomputed any number of times without corrupting itself.
        self._start = time.monotonic()

        self[self._name] = name
        self[self._time] = 0.0
        self[self._iteration] = 0

    def _update_elapsed(self):
        """Refresh the "elapsed time" entry."""
        self[self._time] = time.monotonic() - self._start

    def __repr__(self):
        self._update_elapsed()
        return dict.__repr__(self)

    def __str__(self):
        return pprint.pformat(self)

    def flush(self):
        """Count one iteration and report if it is a multiple of ``print_freq``."""
        self[self._iteration] += 1
        if not self[self._iteration] % self.print_freq:
            if self.filename is not None:
                self._append_to_file()
            if self.dataframe is not None:
                self._append_to_pandas()
            if not self.disable:
                print(self, flush=True)

    def _append_to_file(self):
        """Append the textual report as one more entry of ``self.filename``."""
        with open(self.filename, "a") as f:
            f.write(str(self) + "\n")

    def _append_to_pandas(self):
        """Append the current metrics as a new row of ``self.dataframe``."""
        self._update_elapsed()
        row = pd.DataFrame([dict(self)])
        # DataFrame.append is not available in recent pandas; concat is the
        # equivalent that works across versions.
        self.dataframe = pd.concat([self.dataframe, row], ignore_index=True)


def save_checkpoint(state, is_best, filename):
    """Save ``state`` to ``filename`` without ever leaving a truncated file there.

    The state is first written to ``filename + ".temp"`` and then moved over
    ``filename``, so an interruption during ``torch.save`` cannot corrupt an
    existing checkpoint.

    Args:
        state: object passed to ``torch.save``.
        is_best (bool): when True the checkpoint is also copied next to
            ``filename`` under the same base name prefixed with ``best_``.
        filename (str): destination path.
    """
    # Local import: this cell runs before the notebook imports os.
    import os

    temp_filename = filename + ".temp"

    # A stale temporary file may be left over from an interrupted save.
    if os.path.isfile(temp_filename):
        os.remove(temp_filename)

    torch.save(state, temp_filename)
    os.replace(temp_filename, filename)

    if is_best:
        # Prefix the base name only, so the copy lands next to `filename`
        # even when `filename` contains a directory.
        directory, basename = os.path.split(filename)
        shutil.copyfile(filename, os.path.join(directory, "best_" + basename))

    logging.warning("Checkpoint: saved")


def count_parameters(model):
    """Return the total number of elements over the parameters of ``model`` that require gradients."""
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

In [None]:
from torchaudio.datasets import LIBRISPEECH

# Change backend for flac files
# NOTE(review): this global backend switch may be deprecated or absent in
# newer torchaudio releases — confirm against the installed version.
torchaudio.set_audio_backend("soundfile")


def pad_sequence(sequences: List[Tensor], padding_value: float = 0.0) -> Tensor:
    r"""Stack tensors of varying last-dimension length into one padded batch.

    Every tensor in ``sequences`` has size ``* x L`` where ``*`` is any number
    of leading dimensions (including none) shared by all tensors, and ``L`` is
    the tensor's own length. The result has size ``B x * x T`` where ``B`` is
    ``len(sequences)`` and ``T`` is the largest ``L``; positions beyond a
    tensor's own length hold ``padding_value``.

    Example:
        >>> a = torch.ones(300, 25)
        >>> b = torch.ones(300, 22)
        >>> c = torch.ones(300, 15)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([3, 300, 25])

    Note:
        The leading dimensions, dtype and device are taken from the first
        tensor; all tensors are assumed to agree on them.

    Arguments:
        sequences (list[Tensor]): non-empty list of variable length sequences.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``B x * x T``
    """
    reference = sequences[0]
    leading_dims = tuple(reference.shape[:-1])
    longest = max(seq.size(-1) for seq in sequences)

    padded = reference.new_full((len(sequences), *leading_dims, longest), padding_value)
    for row, seq in enumerate(sequences):
        # Write into a slice of the output; the remainder keeps the padding value.
        padded[row, ..., : seq.size(-1)] = seq

    return padded


class ProcessedLIBRISPEECH(LIBRISPEECH):
    """LibriSpeech dataset yielding cached ``(features, encoded target)`` pairs.

    Args:
        transforms: callable applied to the waveform of each item.
        encode: callable turning the lower-cased utterance into a list of
            integer labels.
        *args, **kwargs: forwarded to ``LIBRISPEECH``.

    Note:
        Every processed item is kept in memory after its first access.
    """

    def __init__(self, transforms, encode, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = transforms
        self.encode = encode
        # One slot per item of this dataset (the length is only known once
        # the parent class has been initialised).
        self._cache = [None] * len(self)

    def __getitem__(self, key):
        # Process each item once and serve later accesses from the cache.
        if self._cache[key] is None:
            item = super().__getitem__(key)
            item = self.process_datapoint(item)
            self._cache[key] = item
        return self._cache[key]

    def process_datapoint(self, item):
        """
        Consume a LibriSpeech data point tuple:
        (waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id).
        - transforms are applied to waveform. Output tensor shape (freq, time).
        - target gets transformed into lower case, and encoded into a one dimensional long tensor.

        Returns the pair ``(transformed, target)``; the target tensor is
        created on the same device as the transformed waveform.
        """
        transformed = item[0]
        target = item[2].lower()

        transformed = self.transforms(transformed)

        target = self.encode(target)
        target = torch.tensor(target, dtype=torch.long,
                              device=transformed.device)

        return transformed, target


def collate_factory(model_length_function, transforms=None):
    """Build the ``collate_fn`` used by the data loaders.

    Args:
        model_length_function: callable mapping one (transformed) feature
            tensor to an integer length; its results form ``tensors_lengths``.
        transforms (optional): callable applied to the features of every
            item; defaults to an empty ``torch.nn.Sequential``.

    Returns:
        A function turning a list of ``(features, target)`` items into
        ``(tensors, targets, tensors_lengths, target_lengths)`` where
        ``tensors`` and ``targets`` are padded batches.
    """
    apply_transforms = torch.nn.Sequential() if transforms is None else transforms

    def collate_fn(batch):

        # Features: transform, measure, then pad to a common length.
        features = [apply_transforms(item[0]) for item in batch]
        feature_lengths = [model_length_function(feature) for feature in features]
        tensors_lengths = torch.tensor(
            feature_lengths, dtype=torch.long, device=features[0].device
        )
        tensors = pad_sequence(features)

        # Targets: measure the unpadded lengths, then pad.
        raw_targets = [item[1] for item in batch]
        raw_target_lengths = [raw_target.shape[0] for raw_target in raw_targets]
        target_lengths = torch.tensor(
            raw_target_lengths, dtype=torch.long, device=tensors.device
        )
        targets = pad_sequence(raw_targets)

        return tensors, targets, tensors_lengths, target_lengths

    return collate_fn

In [None]:
from torch import topk


class GreedyDecoder:
    def __call__(self, outputs):
        """Greedy decoding: pick the highest-scoring class at every position.

        Args:
            outputs (torch.Tensor): scores whose last dimension indexes the
                classes (including blank).

        Returns:
            torch.Tensor: class labels, with the shape of ``outputs`` minus
            its last dimension.
        """
        best = topk(outputs, k=1, dim=-1).indices
        # Drop the singleton "k" dimension left by topk.
        return best[..., 0]

In [None]:
from typing import List, Union


def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]) -> int:
    """
    Compute the edit distance between a reference and a hypothesis.

    Deletions, insertions and substitutions each cost one. The result is not
    normalized; divide by the length of the reference if a rate is wanted.

    Args:
        r (str or List[str]): the reference list or string to compare.
        h (str or List[str]): the hypothesis, the predicted list or string, to compare.
    Returns:
        int: The distance between the reference and the hypothesis.
    """

    # Only two rows of the dynamic-programming table are kept at a time.
    width = len(h) + 1
    previous = list(range(width))
    current = [0] * width

    for row, ref_item in enumerate(r, start=1):
        current[0] = row
        for col, hyp_item in enumerate(h, start=1):
            if ref_item == hyp_item:
                current[col] = previous[col - 1]
                continue
            current[col] = 1 + min(
                previous[col - 1],  # substitution
                current[col - 1],  # insertion
                previous[col],  # deletion
            )

        # The row just filled becomes the reference row for the next one.
        previous, current = current, previous

    return previous[-1]

In [None]:
import logging
import os
import string


def get_augmentations(freq_mask, time_mask):
    """Return a function applying the requested SpecAugment-style maskings.

    Frequency masking (``freq_mask_param=freq_mask``) is applied first and
    time masking (``time_mask_param=time_mask``) second; each is skipped when
    its parameter is falsy, e.g. 0.
    """
    maskings = []
    if freq_mask:
        maskings.append(torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask))
    if time_mask:
        maskings.append(torchaudio.transforms.TimeMasking(time_mask_param=time_mask))

    def transforms(tensor):
        for masking in maskings:
            tensor = masking(tensor)
        return tensor

    return transforms


def model_length_function_constructor(model_input_type):
    """Return the function computing a length from a feature tensor's last dimension.

    Args:
        model_input_type (str): ``"waveform"`` or ``"mfcc"``.

    Returns:
        A function of one tensor returning an ``int``:
        ``shape[-1] // 160 // 2 + 1`` for ``"waveform"`` and
        ``shape[-1] // 2 + 1`` for ``"mfcc"``.

    Raises:
        NotImplementedError: for any other ``model_input_type``.
    """
    if model_input_type == "waveform":

        def length_from_waveform(tensor):
            return int(tensor.shape[-1]) // 160 // 2 + 1

        return length_from_waveform

    if model_input_type == "mfcc":

        def length_from_mfcc(tensor):
            return int(tensor.shape[-1]) // 2 + 1

        return length_from_mfcc

    raise NotImplementedError(
        f"Selected model input type {model_input_type} not supported"
    )


def record_losses(outputs, targets, decoder, language_model, loss_value, metric):
    """Update ``metric`` with the loss, CER and WER statistics of one batch.

    Args:
        outputs: model outputs shaped (input length, batch size, number of
            classes including blank).
        targets: batch of encoded targets; ``targets.tolist()`` is decoded
            with ``language_model.decode``.
        decoder: callable mapping the batch-first outputs to class labels.
        language_model: object providing ``decode`` and ``char_space``.
        loss_value (float): loss of the batch.
        metric: mapping whose missing keys default to zero (entries are
            accumulated with ``+=``).

    Returns:
        The running ``"epoch loss"`` stored in ``metric``.
    """

    # outputs: input length, batch size, number of classes (including blank)
    metric["batch size"] = outputs.shape[1]
    metric["cumulative batch size"] += metric["batch size"]

    # Record loss

    metric["cumulative loss"] += loss_value
    metric["epoch loss"] = metric["cumulative loss"] / metric["cumulative batch size"]
    metric["batch loss"] = loss_value / metric["batch size"]

    # Decode output

    output = outputs.transpose(0, 1).to("cpu")
    output = decoder(output)

    # Compute CER

    output = language_model.decode(output.tolist())
    target = language_model.decode(targets.tolist())

    cers = sum(levenshtein_distance(t, o) for t, o in zip(target, output))
    n = sum(len(t) for t in target)

    metric["total chars"] += n
    metric["cumulative char errors"] += cers
    # max(..., 1) avoids a ZeroDivisionError when every decoded target is empty.
    metric["batch cer"] = cers / max(n, 1)
    metric["epoch cer"] = metric["cumulative char errors"] / max(metric["total chars"], 1)

    # Print a few output/target pairs

    print_length = 20
    # At most two examples; a batch may hold fewer than two items.
    for i in range(min(2, len(output), len(target))):
        output_print = output[i].ljust(print_length)[:print_length]
        target_print = target[i].ljust(print_length)[:print_length]
        logging.info("Target: %s    | Output: %s", target_print, output_print)

    # Compute WER

    output = [o.split(language_model.char_space) for o in output]
    target = [t.split(language_model.char_space) for t in target]

    wers = sum(levenshtein_distance(t, o) for t, o in zip(target, output))
    n = sum(len(t) for t in target)

    metric["total words"] += n
    metric["cumulative word errors"] += wers
    metric["batch wer"] = wers / max(n, 1)
    metric["epoch wer"] = metric["cumulative word errors"] / max(metric["total words"], 1)

    return metric["epoch loss"]


def train_one_epoch(
    model,
    criterion,
    optimizer,
    scheduler,
    data_loader,
    decoder,
    language_model,
    device,
    metric,
):
    """Run one optimisation pass over ``data_loader``.

    For every batch: forward pass, loss, gradient step, then the batch
    statistics are recorded in ``metric`` and ``metric.flush()`` is called.
    ``scheduler`` is accepted but not used here.
    """

    model.train()

    batches = bg_iterator(data_loader, maxsize=2)
    for inputs, targets, tensors_lengths, target_lengths in batches:

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # keep batch first for data parallel
        batch_first = model(inputs)
        outputs = batch_first.transpose(-1, -2).transpose(0, 1)

        # CTC
        # outputs: input length, batch size, number of classes (including blank)
        # targets: batch size, max target length
        # input_lengths: batch size
        # target_lengths: batch size

        loss = criterion(outputs, targets, tensors_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The running epoch loss returned here is not needed during training.
        record_losses(outputs, targets, decoder, language_model, loss.item(), metric)

        current_lr = optimizer.param_groups[0]["lr"]
        metric["lr"] = current_lr
        metric["channel size"] = inputs.shape[1]
        metric["time size"] = inputs.shape[-1]
        metric.flush()


def evaluate(
    model,
    criterion,
    data_loader,
    decoder,
    language_model,
    device,
    epoch,
    metric,
):
    """Evaluate ``model`` on ``data_loader`` without computing gradients.

    Batch statistics are accumulated in ``metric``, which is flushed once
    after the last batch. ``epoch`` is accepted but not used here.

    Returns:
        The running epoch loss reported for the last batch.
    """

    model.eval()

    with torch.no_grad():

        batches = bg_iterator(data_loader, maxsize=2)
        for inputs, targets, tensors_lengths, target_lengths in batches:

            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # keep batch first for data parallel
            batch_first = model(inputs)
            outputs = batch_first.transpose(-1, -2).transpose(0, 1)

            # CTC
            # outputs: input length, batch size, number of classes (including blank)
            # targets: batch size, max target length
            # input_lengths: batch size
            # target_lengths: batch size

            loss = criterion(outputs, targets, tensors_lengths, target_lengths)

            avg_loss = record_losses(
                outputs, targets, decoder, language_model, loss.item(), metric
            )

        metric.flush()

    return avg_loss

In [None]:
from torch.optim import Adadelta
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wav2letter import Wav2Letter


logging.info("Start")

# Change backend for flac files
torchaudio.set_audio_backend("soundfile")


# Dataset loader

kwargs = {
    "transforms": transforms,
    "encode": language_model.encode,
    "root": args.dataset_root,
    "folder_in_archive": args.folder_in_archive,
    "download": False
}
# The split is passed as `url=`: the first positional parameter of
# ProcessedLIBRISPEECH is `transforms`, which is already supplied via kwargs.
dataset_train = ProcessedLIBRISPEECH(url=args.dataset_train, **kwargs)
dataset_valid = ProcessedLIBRISPEECH(url=args.dataset_valid, **kwargs)
if args.dataset_test:
    dataset_test = ProcessedLIBRISPEECH(url=args.dataset_test, **kwargs)

model_length_function = model_length_function_constructor(args.model_input_type)
# Augmentations are applied to the training batches only.
collate_fn_train = collate_factory(model_length_function, augmentations)
collate_fn_valid = collate_factory(model_length_function)

loader_train = DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    collate_fn=collate_fn_train,
    num_workers=args.workers,
    pin_memory=True,
    shuffle=True,
    drop_last=True,
)
loader_valid = DataLoader(
    dataset_valid,
    batch_size=args.batch_size,
    collate_fn=collate_fn_valid,
    num_workers=args.workers,
    pin_memory=True,
    shuffle=False,
    drop_last=False,
)
if args.dataset_test:
    loader_test = DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        num_workers=args.workers,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

# Decoder

decoder = GreedyDecoder()

# Model

model = Wav2Letter(
    num_classes=len(language_model),
    input_type=args.model_input_type,
    num_features=melkwargs["n_mels"],
)

devices = ["cuda" if torch.cuda.is_available() else "cpu"]
model = model.to(devices[0], non_blocking=True)
model = torch.nn.DataParallel(model)

n = count_parameters(model)
logging.info("Number of parameters: %s", n)

# Optimizer

optimizer = Adadelta(
    model.parameters(),
    lr=0.6,
    weight_decay=1e-05,
    eps=1e-08,
    rho=0.95,
)
scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)

# Loss

encoded_char_blank = language_model.encode(char_blank)[0]
criterion = torch.nn.CTCLoss(blank=encoded_char_blank, zero_infinity=False, reduction="sum")
# Start from +inf so the first validation loss always counts as the best so
# far; with a summed CTC loss a fixed starting value of 1.0 may never be beaten.
best_loss = float("inf")

In [None]:
# Setup checkpoint

# Guard the file check so an unset (empty or None) checkpoint path does not
# reach os.path.isfile.
checkpoint_exists = bool(args.checkpoint) and os.path.isfile(args.checkpoint)

if args.checkpoint and checkpoint_exists and args.resume:
    # Resume: restore epoch counter, best loss, model, optimizer and scheduler.
    logging.info("Checkpoint loading %s", args.checkpoint)
    # map_location lets a checkpoint saved on another device be loaded here.
    checkpoint = torch.load(args.checkpoint, map_location=devices[0])

    args.start_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]

    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])

    logging.info(
        "Checkpoint loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
    )
elif args.checkpoint and checkpoint_exists:
    # Refuse to silently overwrite an existing checkpoint.
    raise RuntimeError(
        "Checkpoint already exists. Set resume to True, or manually delete existing file."
    )
elif args.checkpoint and args.resume:
    raise RuntimeError("Checkpoint not found")
elif args.checkpoint:
    # Fresh run: write an initial checkpoint before training starts.
    save_checkpoint(
        {
            "epoch": args.start_epoch,
            "state_dict": model.state_dict(),
            "best_loss": best_loss,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        False,
        args.checkpoint,
    )
elif not args.checkpoint and args.resume:
    raise RuntimeError("Checkpoint not provided. Use checkpoint to specify.")

In [None]:
dataframe_log = pd.DataFrame()

for epoch in range(args.start_epoch, args.max_epoch):

    logging.info("Epoch: %s", epoch)

    # Training

    metric = Logger("train", dataframe=dataframe_log)
    metric["epoch"] = epoch

    train_one_epoch(
        model,
        criterion,
        optimizer,
        scheduler,
        loader_train,
        decoder,
        language_model,
        devices[0],
        metric,
    )
    # Logger rebinds its own `dataframe` attribute when it appends a row, so
    # the grown frame is read back to keep a single log across phases and epochs.
    dataframe_log = metric.dataframe

    # Validation

    metric = Logger("validation", dataframe=dataframe_log)
    metric["epoch"] = epoch

    loss = evaluate(
        model,
        criterion,
        loader_valid,
        decoder,
        language_model,
        devices[0],
        epoch,
        metric,
    )
    dataframe_log = metric.dataframe

    # The learning rate is reduced when the validation loss plateaus.
    scheduler.step(loss)

    is_best = loss < best_loss
    best_loss = min(loss, best_loss)
    if args.checkpoint:
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            args.checkpoint,
        )

    # Test: `loader_test` only exists when a test split is configured.
    if args.dataset_test:
        metric = Logger("test", dataframe=dataframe_log)
        metric["epoch"] = epoch

        evaluate(
            model,
            criterion,
            loader_test,
            decoder,
            language_model,
            devices[0],
            epoch,
            metric,
        )
        dataframe_log = metric.dataframe

logging.info("End")