In [None]:
import logging

# Configure the root logger for the whole notebook: timestamped messages with a
# left-aligned level name, INFO and above. `logging` is imported here because
# this cell runs before the cell that imports it further down.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [None]:
class Args:
    """Static stand-in for the command-line arguments read by ``main(rank, args)``.

    Every attribute is a plain class attribute, so the class itself can be
    passed as ``args``.
    """

    batch_size = 128
    bins = 13
    checkpoint = '/checkpoint/vincentqb/checkpoint/checkpoint-31271449-94c4cbddeb0034797a6c9a37d84f6e58.pth.tar'
    clip_grad = None
    dataset_folder_in_archive = '062419'
    dataset_root = '/datasets01/librispeech/'
    dataset_train = ['train-clean-100']
    dataset_valid = ['dev-clean']
    decoder = 'greedy'
    distributed = False
    # Read by main() -> setup_distributed() when ``distributed`` is True; they
    # were previously missing, which raised AttributeError in that mode.
    distributed_master_addr = 'localhost'
    distributed_master_port = '29500'
    dropout = 0.0
    epochs = 1000
    eps = 1e-08
    freq_mask = 0
    gamma = 0.99
    hidden_channels = 2000
    hop_length = 160
    learning_rate = 0.6
    model_input_type = 'mfcc'
    momentum = 0.8
    normalize = True
    optimizer = 'adadelta'
    output = None
    print_freq = 1
    progress_bar = False
    reduce_lr_valid = True
    reduction = 'sum'
    resume = False
    rho = 0.95
    scheduler = 'reduceonplateau'
    seed = 0
    speechcommands = False
    start_epoch = 0
    time_mask = 0
    weight_decay = 1e-05
    win_length = 400
    workers = 0
    world_size = 8

In [None]:
import itertools
import logging
from typing import List

import torch
from torch import Tensor
from torchaudio.datasets import LIBRISPEECH

from speechcommands import SPEECHCOMMANDS


def pad_sequence(sequences, padding_value=0.0):
    # type: (List[Tensor], float) -> Tensor
    r"""Pad a list of variable length Tensors along their last dimension.

    ``pad_sequence`` stacks a list of Tensors along a new leading (batch)
    dimension and right-pads each one along its last dimension with
    ``padding_value`` so that all have equal length. If the inputs are
    sequences of size ``* x L`` then the output is of size ``B x * x T``.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of leading dimensions, including none.

    Unlike :func:`torch.nn.utils.rnn.pad_sequence`, which pads the first
    dimension, this function pads the *last* dimension and always puts the
    batch dimension first.

    Example:
        >>> a = torch.ones(300, 25)
        >>> b = torch.ones(300, 22)
        >>> c = torch.ones(300, 15)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([3, 300, 25])

    Note:
        This function assumes that every dimension except the last, as well as
        dtype and device, are the same for all Tensors in ``sequences``; they
        are taken from the first one.

    Arguments:
        sequences (list[Tensor]): non-empty list of variable length sequences.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``B x * x T``

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequence expects a non-empty list of tensors")

    # Shape, dtype and device of the output are taken from the first tensor.
    leading_dims = sequences[0].size()[:-1]
    max_len = max([s.size(-1) for s in sequences])
    out_dims = (len(sequences),) + leading_dims + (max_len,)

    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(-1)
        # use index notation to prevent duplicate references to the tensor
        out_tensor[i, ..., :length] = tensor

    return out_tensor


class IterableMemoryCache:
    """Wrap an iterable so that the values it yields are memoized.

    The first pass pulls values lazily from the wrapped iterable and stores
    them. Once the wrapped iterable is exhausted, later passes replay the
    stored values without touching it again. A pass started before exhaustion
    first replays what is cached and then continues pulling fresh values.
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self._iter = iter(iterable)
        self._done = False
        self._values = []

    def __iter__(self):
        if self._done:
            return iter(self._values)
        # Replay cached values, then keep consuming the shared underlying iterator.
        return itertools.chain(self._values, self._gen_iter())

    def _gen_iter(self):
        for new_value in self._iter:
            self._values.append(new_value)
            yield new_value
        self._done = True

    def __len__(self):
        # Delegates to the wrapped iterable, which must support len(). (The
        # attribute is ``iterable``; ``_iterable`` was never set.)
        return len(self.iterable)


class MapMemoryCache(torch.utils.data.Dataset):
    """
    Memoizing wrapper around a map-style dataset.

    Each index is fetched from the wrapped dataset the first time it is
    requested and served from an in-memory list afterwards. ``None`` marks an
    empty slot, so an item that is itself ``None`` is fetched again on every
    access.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None for _ in range(len(dataset))]

    def __getitem__(self, n):
        cached = self._cache[n]
        if cached is not None:
            return cached
        fetched = self.dataset[n]
        self._cache[n] = fetched
        return fetched

    def __len__(self):
        return len(self.dataset)


class Processed(torch.utils.data.Dataset):
    """Map-style dataset that featurizes waveforms and encodes transcripts on access.

    Args:
        dataset: indexable source whose items are tuples with the waveform at
            position 0 and the transcript string at position 2 (LibriSpeech
            tuple layout).
        transforms: callable applied to the waveform.
        encode: callable mapping a lower-cased transcript to a list of ints.
    """

    def __init__(self, dataset, transforms, encode):
        self.dataset = dataset
        self.transforms = transforms
        self.encode = encode

    def __getitem__(self, key):
        return self.process_datapoint(self.dataset[key])

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        """
        Consume a LibriSpeech data point tuple:
        (waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id).
        - Transforms are applied to waveform. Output tensor shape (freq, time).
        - target gets transformed into lower case, and encoded into a one dimensional long tensor
          placed on the same device as the transformed waveform.
        """
        utterance = item[2]

        transformed = self.transforms(item[0])
        target = self.encode(utterance.lower())

        # Per-sample trace: DEBUG because this runs for every item fetched.
        # (It used to log at INFO with "Target"/"Output" labels swapped onto the
        # encoded ids and the raw text.)
        print_length = 20
        logging.debug(
            "Utterance: %s    | Encoded: %s",
            utterance.ljust(print_length)[:print_length],
            target[:print_length],
        )

        target = torch.tensor(target, dtype=torch.long, device=transformed.device)

        return transformed, target


def split_process_librispeech(
    datasets, transforms, language_model, root, folder_in_archive,
):
    """Build one cached, processed LibriSpeech dataset per entry of ``datasets``.

    Args:
        datasets: iterable whose entries are a LibriSpeech split name (``str``)
            or a list of split names to concatenate into one dataset.
        transforms: a single waveform transform applied to every split, or a
            list holding one transform per split name of an entry. Extra
            trailing transforms are ignored.
        language_model: object whose ``encode`` maps a transcript to indices.
        root: dataset root passed to ``LIBRISPEECH``.
        folder_in_archive: ``folder_in_archive`` passed to ``LIBRISPEECH``.

    Returns:
        tuple: one ``MapMemoryCache``-wrapped ``ConcatDataset`` per entry.

    Raises:
        ValueError: if ``transforms`` is a list shorter than an entry's list of
            split names (previously the extra splits were silently dropped).
    """

    def create(tags, cache=True):
        # ``cache`` is unused; every dataset is wrapped in MapMemoryCache below.
        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            if len(transforms) < len(tags):
                raise ValueError(
                    f"Got {len(transforms)} transforms for {len(tags)} splits {tags}"
                )
            transform_list = transforms
        else:
            # Reuse the single transform for every split; zipping against a
            # one-element list used to keep only the first split.
            transform_list = [transforms] * len(tags)

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root, tag, folder_in_archive=folder_in_archive, download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        return MapMemoryCache(data)

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)


def split_process_speechcommands(
    datasets, transforms, language_model, root,
):
    """Build one cached, processed SpeechCommands dataset per entry of ``datasets``.

    Args:
        datasets: iterable whose entries are a split name (``str``) or a list
            of split names to concatenate into one dataset.
        transforms: a single waveform transform applied to every split, or a
            list holding one transform per split name of an entry. Extra
            trailing transforms are ignored.
        language_model: object whose ``encode`` maps a transcript to indices.
        root: dataset root passed to ``SPEECHCOMMANDS``.

    Returns:
        tuple: one ``MapMemoryCache``-wrapped ``ConcatDataset`` per entry.

    Raises:
        ValueError: if ``transforms`` is a list shorter than an entry's list of
            split names (previously the extra splits were silently dropped).
    """

    def create(tags, cache=True):
        # ``cache`` is unused; every dataset is wrapped in MapMemoryCache below.
        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            if len(transforms) < len(tags):
                raise ValueError(
                    f"Got {len(transforms)} transforms for {len(tags)} splits {tags}"
                )
            transform_list = transforms
        else:
            # Reuse the single transform for every split; zipping against a
            # one-element list used to keep only the first split.
            transform_list = [transforms] * len(tags)

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    SPEECHCOMMANDS(root, split=tag, download=False,),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        return MapMemoryCache(data)

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)


def collate_factory(model_length_function, transforms=None):
    """Return a ``collate_fn`` that batches ``(input, target)`` samples for CTC.

    Args:
        model_length_function: maps one (transformed) input tensor to the
            length the model will produce for it.
        transforms: optional module applied to every input (e.g. augmentations);
            an empty ``torch.nn.Sequential`` is used when ``None``.

    The returned function yields ``(inputs, targets, input_lengths,
    target_lengths)``. Inputs and targets are padded along their last
    dimension by ``pad_sequence``; the length tensors are ``long``.
    """
    augment = torch.nn.Sequential() if transforms is None else transforms

    def collate_fn(batch):
        # Per-sample transforms first, so lengths reflect what the model sees.
        inputs = [augment(sample[0]) for sample in batch]
        input_lengths = torch.tensor(
            [model_length_function(x) for x in inputs],
            dtype=torch.long,
            device=inputs[0].device,
        )
        padded_inputs = pad_sequence(inputs)

        labels = [sample[1] for sample in batch]
        label_lengths = torch.tensor(
            [label.shape[0] for label in labels],
            dtype=torch.long,
            device=padded_inputs.device,
        )
        padded_labels = pad_sequence(labels)

        return padded_inputs, padded_labels, input_lengths, label_lengths

    return collate_fn


In [None]:
import itertools
from collections.abc import Iterable


class LanguageModel:
    """Character-level label map used to encode transcripts and decode predictions.

    Args:
        labels: sequence of labels; a label's position is its index.
        char_blank (str): the CTC blank symbol; removed by ``decode``.
        char_space (str): the word separator; stored for callers (not used by
            this class itself).
    """

    def __init__(self, labels, char_blank, char_space):

        self.char_space = char_space
        self.char_blank = char_blank

        # index -> label, and label -> index (last occurrence wins on duplicates).
        self._decode_map = dict(enumerate(labels))
        self._encode_map = {label: index for index, label in enumerate(labels)}

    def encode(self, listlike):
        """Map a string to its list of label indices.

        Non-string iterables are encoded element-wise, recursively. Raises
        KeyError for a character that is not a label.
        """
        if not isinstance(listlike, str):
            return [self.encode(i) for i in listlike]
        # Plain lookup. The previous version added the blank's index to every
        # character, which only round-tripped when the blank was label 0.
        return [self._encode_map[char] for char in listlike]

    def decode(self, tensor):
        """Map label indices back to text, CTC style.

        Consecutive repeated labels are merged and blanks are removed, so this
        is not the exact inverse of ``encode`` for text with doubled letters.
        A sequence of iterables is decoded element-wise.
        """
        if len(tensor) > 0 and isinstance(tensor[0], Iterable):
            return [self.decode(t) for t in tensor]

        # not idempotent, since clean string
        x = (self._decode_map[i] for i in tensor)
        x = "".join(i for i, _ in itertools.groupby(x))
        x = x.replace(self.char_blank, "")
        return x

    def __len__(self):
        return len(self._encode_map)


In [None]:
from collections import Counter

import torch
from torch import topk
from tqdm import tqdm


class GreedyIterableDecoder:
    """Greedy CTC decoder returning blank-free label sequences per batch item.

    Args:
        blank_label (int): index of the CTC blank; frames predicting it are dropped.
        collapse_repeated (bool): if True, a label equal to the previous frame's
            label is dropped, so a blank between two equal labels keeps both.
    """

    def __init__(self, blank_label=0, collapse_repeated=True):
        self.blank_label = blank_label
        self.collapse_repeated = collapse_repeated

    def __call__(self, output):
        """Decode a batch of per-frame class scores.

        Args:
            output (torch.Tensor): scores with classes on the last dimension.
                The first dimension is iterated as the batch (presumably
                ``(batch, time, classes)`` — confirm against callers).

        Returns:
            torch.Tensor: ``long`` tensor of shape ``(batch, longest decode)``,
            right-padded with 0.
        """
        arg_maxes = torch.argmax(output, dim=-1)
        decodes = []
        for frames in arg_maxes:
            labels = frames.tolist()
            decode = [
                label
                for j, label in enumerate(labels)
                if label != self.blank_label
                and not (self.collapse_repeated and j != 0 and label == labels[j - 1])
            ]
            # Explicit dtype: an empty list would otherwise become a float tensor.
            decodes.append(torch.tensor(decode, dtype=torch.long))
        return torch.nn.utils.rnn.pad_sequence(decodes, batch_first=True)


class GreedyDecoder:
    """Frame-wise greedy decoder: picks the top-scoring class at every position."""

    def __call__(self, outputs):
        """Return the index of the highest-scoring class for each timestep.

        Args:
            outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))

        Returns:
            torch.Tensor: class labels per time step, i.e. ``outputs`` with the
            class dimension removed.
        """
        best = topk(outputs, k=1, dim=-1)
        return best.indices.squeeze(-1)


def zeros_like(m):
    """Zero-filled list of lists with ``len(m)`` rows and ``len(m[0])`` columns."""
    n_rows, n_cols = len(m), len(m[0])
    return zeros(n_rows, n_cols)


def zeros(d1, d2):
    """Return a ``d1`` x ``d2`` list of lists of 0; every row is a distinct list."""
    return [[0] * d2 for _ in range(d1)]


def apply_transpose(f, m):
    """Apply ``f`` to each column (as a tuple) of the 2-D sequence ``m``."""
    return [f(column) for column in zip(*m)]


def argmax(l):
    """Index of the largest element of ``l``; the first one wins on ties.

    Raises ValueError when ``l`` is empty.
    """
    return max(range(len(l)), key=l.__getitem__)


def add1d2d(m1, m2):
    """Add ``m1[r]`` to every entry of row ``r`` of ``m2`` (truncated to the shorter)."""
    result = []
    for offset, row in zip(m1, m2):
        result.append([value + offset for value in row])
    return result


def add1d1d(v1, v2):
    """Element-wise sum of two 1-D sequences, truncated to the shorter one."""
    return list(map(lambda a, b: a + b, v1, v2))


class ListViterbiDecoder:
    """Viterbi decoder over plain Python lists, with label-bigram count transitions.

    Transition scores are raw counts of consecutive label pairs seen in
    ``data_loader``; they are added to the emission scores unnormalized.

    NOTE(review): ``_build_transitions`` unpacks each n-gram as ``(k1, k2)``,
    so only ``n=2`` works.
    """

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        self._transitions = self._build_transitions(
            data_loader, vocab_size, n, progress_bar
        )

    def __call__(self, emissions):
        """Decode each element of ``emissions`` (iterated along its first dim).

        Each element is converted with ``.tolist()`` and treated as a
        (time, vocab_size) score table. Returns a tensor of the best paths.
        """
        return torch.tensor(
            [self._decode(e.tolist(), self._transitions)[0] for e in emissions]
        )

    @staticmethod
    def _build_transitions(data_loader, vocab_size, n=2, progress_bar=False):
        """Count label n-grams from ``(_, label)`` pairs into a vocab x vocab list."""

        # Count n-grams. Labels are converted to Python ints first: torch tensors
        # hash by identity, so tuples of 0-d tensors never matched each other and
        # every count came out as 1.
        count = Counter()
        for _, label in tqdm(data_loader, disable=not progress_bar):
            tokens = [int(token) for token in label]
            count.update(zip(*(tokens[i:] for i in range(n))))

        # Write as matrix
        transitions = zeros(vocab_size, vocab_size)
        for (k1, k2), v in count.items():
            transitions[k1][k2] = v

        return transitions

    @staticmethod
    def _decode(emissions, transitions):
        """Viterbi search over a (time, vocab) list of scores; returns (path, score)."""
        back_pointers = zeros_like(emissions)
        scores = emissions[0]

        # Generate most likely scores and paths for each step in sequence
        for i in range(1, len(emissions)):
            # score_with_transition[prev][cur] = scores[prev] + transitions[prev][cur]
            score_with_transition = add1d2d(scores, transitions)
            max_score_with_transition = apply_transpose(max, score_with_transition)
            scores = add1d1d(emissions[i], max_score_with_transition)
            # Best previous label for each current label.
            back_pointers[i] = apply_transpose(argmax, score_with_transition)

        # Generate the most likely path
        viterbi = [argmax(scores)]
        for bp in reversed(back_pointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = max(scores)

        return viterbi, viterbi_score


class ViterbiDecoder:
    """Viterbi decoder whose transition scores come from label n-gram counts.

    ``__init__`` counts consecutive label tuples over ``data_loader`` and turns
    them into a dense ``vocab_size`` x ``vocab_size`` matrix, scaling each row
    by its maximum. Decoding runs a Viterbi search keeping ``top_k`` (fixed to
    1) paths per batch element.

    NOTE(review): the sparse matrix has exactly two index rows, so only
    ``n=2`` appears to work.
    """

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        self.vocab_size = vocab_size
        self.n = n
        self.top_k = 1
        self.progress_bar = progress_bar

        self._build_transitions(data_loader)

    def _build_transitions(self, data_loader):
        """Set ``self.transitions`` from n-gram counts over ``(_, label)`` pairs.

        Label elements must support ``.item()``. If no n-gram is seen, building
        the index tensor presumably fails — NOTE(review): not handled.
        """

        # Count n-grams

        c = Counter()
        for _, label in tqdm(data_loader, disable=not self.progress_bar):
            # .item() turns 0-d tensors into hashable Python numbers.
            count = Counter(
                tuple(b.item() for b in a)
                for a in zip(*(label[i:] for i in range(self.n)))
            )
            c += count

        # Encode as transition matrix

        ind = torch.tensor([a for (a, _) in c.items()]).t()
        val = torch.tensor([b for (_, b) in c.items()], dtype=torch.float)

        transitions = (
            torch.sparse_coo_tensor(
                indices=ind, values=val, size=[self.vocab_size, self.vocab_size]
            )
            .coalesce()
            .to_dense()
        )
        # Divide each row by max(1, row max); all-zero rows are left unchanged.
        transitions = transitions / torch.max(
            torch.tensor(1.0), transitions.max(dim=1)[0]
        ).unsqueeze(1)

        self.transitions = transitions

    def _viterbi_decode(self, tag_sequence: torch.Tensor):
        """
        Perform Viterbi decoding over a sequence given a transition matrix
        specifying pairwise (transition) potentials between tags and a matrix of shape
        (sequence_length, num_tags) specifying unary potentials for possible tags per
        timestep.

        NOTE(review): potentials are summed, as in log space, but
        ``self.transitions`` holds scaled counts rather than log probabilities
        — confirm this is intended.

        Parameters
        ----------
        tag_sequence : torch.Tensor, required.
            A tensor of shape (sequence_length, num_tags) representing scores for
            a set of tags over a given sequence.

        Returns
        -------
        viterbi_paths : List[List[int]]
            Up to ``top_k`` tag index sequences, best first.
        viterbi_scores : torch.Tensor
            1-D tensor with the score of each returned path.
        """
        sequence_length, num_tags = tag_sequence.size()

        path_scores = []
        path_indices = []
        # At the beginning, the maximum number of permutations is 1; therefore, we unsqueeze(0)
        # to allow for 1 permutation.
        path_scores.append(tag_sequence[0, :].unsqueeze(0))
        # assert path_scores[0].size() == (n_permutations, num_tags)

        # Evaluate the scores for all possible paths.
        for timestep in range(1, sequence_length):
            # Add pairwise potentials to current scores.
            # assert path_scores[timestep - 1].size() == (n_permutations, num_tags)
            summed_potentials = (
                path_scores[timestep - 1].unsqueeze(2) + self.transitions
            )
            summed_potentials = summed_potentials.view(-1, num_tags)

            # Best pairwise potential path score from the previous timestep.
            max_k = min(summed_potentials.size()[0], self.top_k)
            scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
            # assert scores.size() == (n_permutations, num_tags)
            # assert paths.size() == (n_permutations, num_tags)

            scores = tag_sequence[timestep, :] + scores
            # assert scores.size() == (n_permutations, num_tags)
            path_scores.append(scores)
            path_indices.append(paths.squeeze())

        # Construct the most likely sequence backwards.
        path_scores = path_scores[-1].view(-1)
        max_k = min(path_scores.size()[0], self.top_k)
        viterbi_scores, best_paths = torch.topk(path_scores, k=max_k, dim=0)

        viterbi_paths = []
        for i in range(max_k):

            viterbi_path = [best_paths[i].item()]
            for backward_timestep in reversed(path_indices):
                viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))

            # Reverse the backward path.
            viterbi_path.reverse()

            # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
            viterbi_path = [j % num_tags for j in viterbi_path]
            viterbi_paths.append(viterbi_path)

        return viterbi_paths, viterbi_scores

    def __call__(self, tag_sequence: torch.Tensor):
        """Decode every batch element of ``tag_sequence``.

        ``tag_sequence[:, i, :]`` is decoded for each ``i`` along dim 1, so the
        input is treated as (time, batch, num_tags).

        NOTE(review): each per-item score is a 1-D tensor, so ``torch.cat`` gives
        a 1-D tensor and ``[:, 0, :]`` looks like it raises IndexError. Also,
        ``record_losses`` passes batch-first outputs and expects a single tensor
        back, not a tuple. Confirm before using this as ``decoder`` there.
        """

        outputs = []
        scores = []
        for i in range(tag_sequence.shape[1]):
            paths, score = self._viterbi_decode(tag_sequence[:, i, :])
            outputs.append(paths)
            scores.append(score)

        return torch.tensor(outputs).transpose(0, -1), torch.cat(scores)[:, 0, :]


In [None]:
from typing import List, Union


def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]) -> int:
    """
    Compute the edit distance between a reference and a hypothesis.

    Insertions, deletions and substitutions each cost 1. Works on strings
    (character level) and on lists of tokens (e.g. words). The result is not
    normalized; divide by ``len(r)`` for an error rate.

    Args:
        r (str or List[str]): the reference sequence.
        h (str or List[str]): the hypothesis (predicted) sequence.
    Returns:
        int: The minimum number of edits turning ``r`` into ``h``.
    """
    # Two rolling rows of the dynamic-programming table.
    previous = list(range(len(h) + 1))
    current = [0] * (len(h) + 1)

    for i, ref_item in enumerate(r, start=1):
        current[0] = i
        for j, hyp_item in enumerate(h, start=1):
            if ref_item == hyp_item:
                current[j] = previous[j - 1]
            else:
                current[j] = 1 + min(previous[j - 1], current[j - 1], previous[j])
        previous, current = current, previous

    return previous[-1]

In [None]:
import torch


class Normalize(torch.nn.Module):
    """Standardize along the last dimension to zero mean and unit (unbiased) std."""

    def forward(self, tensor):
        centered = tensor - tensor.mean(-1, keepdim=True)
        spread = tensor.std(-1, keepdim=True)
        return centered / spread


class UnsqueezeFirst(torch.nn.Module):
    """Add a new leading dimension of size 1."""

    def forward(self, tensor):
        return tensor[None]


class ToMono(torch.nn.Module):
    """Keep only the first entry along dimension 0 (e.g. the first channel)."""

    def forward(self, tensor):
        return tensor[0]

In [None]:
import logging
import os
import signal
import string

import torch
import torchaudio
from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.transforms import MFCC

from ctc_decoders import (
    GreedyDecoder,
    GreedyIterableDecoder,
    ListViterbiDecoder,
    ViterbiDecoder,
)
from datasets import (
    collate_factory,
    split_process_librispeech,
    split_process_speechcommands,
)
from languagemodels import LanguageModel
from metrics import levenshtein_distance
from transforms import Normalize, ToMono, UnsqueezeFirst
from utils import Logger, count_parameters, save_checkpoint

# from torchaudio.models.wav2letter import Wav2Letter
from wav2letter import Wav2Letter

# TODO Remove before merge pull request
# PID of the process that executed this cell; trigger_job_requeue() only
# resubmits the SLURM job when called from this same process.
MAIN_PID = os.getpid()
# Set to True by signal_handler() (installed for SIGUSR1 in main());
# train_one_epoch() and evaluate() check it after each batch and stop early.
SIGNAL_RECEIVED = False


# TODO Remove before merge pull request
def signal_handler(a, b):
    """Remember that a signal arrived so the training loops can stop early.

    Args:
        a: signal number passed by the ``signal`` module (unused).
        b: interrupted stack frame passed by the ``signal`` module (unused).
    """
    global SIGNAL_RECEIVED

    logging.warning("Signal received")
    SIGNAL_RECEIVED = True


# TODO Remove before merge pull request
def trigger_job_requeue():
    """Ask SLURM to requeue this job (so it can resume from checkpoint), then exit.

    Only SLURM rank 0, running in the process that defined ``MAIN_PID``, runs
    ``scontrol requeue <SLURM_JOB_ID>``. Every caller then exits with status 0,
    including processes that are not under SLURM at all.

    Raises:
        RuntimeError: if ``scontrol`` cannot be started or exits non-zero.
    """
    import subprocess
    import sys

    # Submit a new job to resume from checkpoint.
    if os.environ.get("SLURM_PROCID") == "0" and os.getpid() == MAIN_PID:
        logging.warning("PID: %s. PPID: %s.", os.getpid(), os.getppid())
        logging.warning("Resubmitting job")
        # Argument list, no shell: the job id is never interpreted by a shell.
        command = ["scontrol", "requeue", os.environ["SLURM_JOB_ID"]]
        logging.warning(" ".join(command))
        try:
            completed = subprocess.run(command, check=False)
        except OSError as err:
            raise RuntimeError("Fail to resubmit") from err
        if completed.returncode != 0:
            raise RuntimeError("Fail to resubmit")
        logging.warning("New job submitted to the queue")
    # sys.exit rather than the site-provided exit() helper, which is meant for
    # interactive use and may be absent (python -S).
    sys.exit(0)


def setup_distributed(rank, world_size, master_addr, master_port):
    """Initialise the default ``torch.distributed`` process group via ``env://``.

    Exports ``MASTER_ADDR`` / ``MASTER_PORT`` for the env:// rendezvous and
    picks NCCL when CUDA is available, Gloo otherwise.
    """
    os.environ.update(MASTER_ADDR=str(master_addr), MASTER_PORT=str(master_port))

    # See documentation for choice of backend
    # https://pytorch.org/docs/stable/distributed.html
    if torch.cuda.is_available():
        backend = "nccl"
    else:
        backend = "gloo"

    torch.distributed.init_process_group(
        backend, rank=rank, world_size=world_size, init_method="env://"
    )


def model_length_function_constructor(model_input_type):
    """Return a function giving the model's output length for one input tensor.

    The length is derived from the input's last dimension: for "waveform"
    inputs it is divided by 160 and then by 2; for "mfcc" inputs only by 2.
    One is added in both cases.

    Raises:
        NotImplementedError: for any other ``model_input_type``.
    """
    if model_input_type == "waveform":

        def waveform_length(tensor):
            return int(tensor.shape[-1]) // 160 // 2 + 1

        return waveform_length

    if model_input_type == "mfcc":

        def mfcc_length(tensor):
            return int(tensor.shape[-1]) // 2 + 1

        return mfcc_length

    raise NotImplementedError(
        f"Selected model input type {model_input_type} not supported"
    )


def record_losses(outputs, targets, decoder, language_model, loss_value, metric):
    """Accumulate loss, CER and WER statistics for one batch into ``metric``.

    Args:
        outputs: model output of shape (input length, batch size, number of
            classes (including blank)).
        targets: padded target indices, presumably (batch size, max target
            length) as noted in the training loop — confirm.
        decoder: callable applied to the batch-first CPU copy of ``outputs``;
            its result must support ``.tolist()``.
        language_model: provides ``decode`` and ``char_space``.
        loss_value (float): loss for the batch. It is divided by the batch size
            for "batch loss", so presumably it is summed over the batch.
        metric: ``Logger``-like mapping. The cumulative keys incremented here
            must already exist (presumably defaulting to 0 — confirm).

    Returns:
        float: the running "epoch loss" (cumulative loss / cumulative batch size).
    """

    def ratio(numerator, denominator):
        # Error rates are undefined when there is no reference to compare to.
        return numerator / denominator if denominator else float("nan")

    # outputs: input length, batch size, number of classes (including blank)
    metric["batch size"] = outputs.shape[1]
    metric["cumulative batch size"] += metric["batch size"]

    # Record loss

    metric["cumulative loss"] += loss_value
    metric["epoch loss"] = metric["cumulative loss"] / metric["cumulative batch size"]
    metric["batch loss"] = loss_value / metric["batch size"]

    # Decode output (the decoder receives batch-first, CPU data)

    output = decoder(outputs.transpose(0, 1).to("cpu"))

    # Compute CER

    output = language_model.decode(output.tolist())
    # NOTE(review): decode() also merges repeated characters, so reference
    # transcripts lose doubled letters (e.g. "ll" -> "l") — confirm intended.
    target = language_model.decode(targets.tolist())

    cers = sum(levenshtein_distance(t, o) for t, o in zip(target, output))
    n = sum(len(t) for t in target)

    metric["total chars"] += n
    metric["cumulative char errors"] += cers
    metric["batch cer"] = ratio(cers, n)
    metric["epoch cer"] = ratio(metric["cumulative char errors"], metric["total chars"])

    # Print a few output/target pairs (the batch may hold fewer than two)

    print_length = 20
    for o, t in zip(output[:2], target[:2]):
        output_print = o.ljust(print_length)[:print_length]
        target_print = t.ljust(print_length)[:print_length]
        logging.info("Target: %s    | Output: %s", target_print, output_print)

    # Compute WER

    output = [o.split(language_model.char_space) for o in output]
    target = [t.split(language_model.char_space) for t in target]

    wers = sum(levenshtein_distance(t, o) for t, o in zip(target, output))
    n = sum(len(t) for t in target)

    metric["total words"] += n
    metric["cumulative word errors"] += wers
    metric["batch wer"] = ratio(wers, n)
    metric["epoch wer"] = ratio(metric["cumulative word errors"], metric["total words"])

    return metric["epoch loss"]


def _get_optimizer(args, model):
    """Build the optimizer named by ``args.optimizer`` over ``model.parameters()``.

    Supported values and the ``args`` fields they read:
      - "adadelta": ``learning_rate``, ``weight_decay``, ``eps``, ``rho``
      - "sgd": ``learning_rate``, ``momentum``, ``weight_decay``
      - "adam" / "adamw": ``learning_rate``, ``weight_decay``

    Raises:
        NotImplementedError: for any other ``args.optimizer``.
    """
    if args.optimizer == "adadelta":
        return Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.eps,
            rho=args.rho,
        )
    elif args.optimizer == "sgd":
        return SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        # Adam/AdamW take no ``momentum`` keyword (their momentum terms are
        # ``betas``); passing it raised TypeError.
        return Adam(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adamw":
        return AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    raise NotImplementedError(f"Selected optimizer {args.optimizer} not supported")


def _get_scheduler(args, optimizer):
    """Build the learning-rate scheduler named by ``args.scheduler``.

    "reduceonplateau" uses fixed ``patience=10`` and ``threshold=1e-3``;
    "exponential" decays by ``args.gamma``.

    Raises:
        NotImplementedError: for any other ``args.scheduler``.
    """
    name = args.scheduler
    if name == "reduceonplateau":
        return ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
    if name == "exponential":
        return ExponentialLR(optimizer, gamma=args.gamma)
    raise NotImplementedError(f"Selected scheduler {args.scheduler} not supported")


def train_one_epoch(
    model,
    criterion,
    optimizer,
    scheduler,
    data_loader,
    decoder,
    language_model,
    device,
    epoch,
    clip_grad,
    disable_logger=False,
    reduce_lr_on_plateau=False,
):
    """Run one training pass over ``data_loader`` and then step ``scheduler``.

    Args:
        criterion: called as ``criterion(outputs, targets, input_lengths,
            target_lengths)`` (CTC layout, see comments in the loop).
        clip_grad: max gradient norm, or ``None`` to disable clipping.
        disable_logger: forwarded to ``Logger``.
        reduce_lr_on_plateau: when ``scheduler`` is a ``ReduceLROnPlateau``,
            step it with this epoch's running training loss. Otherwise a
            plateau scheduler is not stepped here. Other schedulers always step.

    Stops early when ``SIGNAL_RECEIVED`` is set.
    """

    model.train()

    metric = Logger("train", disable=disable_logger)
    metric["epoch"] = epoch

    # Stays None if the loader yields no batch (see the scheduler step below).
    avg_loss = None

    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
        data_loader, maxsize=2
    ):

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # keep batch first for data parallel
        outputs = model(inputs).transpose(-1, -2).transpose(0, 1)

        # CTC
        # outputs: input length, batch size, number of classes (including blank)
        # targets: batch size, max target length
        # input_lengths: batch size
        # target_lengths: batch size

        # Full tensor dumps are DEBUG-only; at INFO they were formatted and
        # printed on every training step.
        logging.debug("%s     outputs", outputs)
        logging.debug("%s     targets", targets)
        logging.debug("%s     input_lengths", tensors_lengths)
        logging.debug("%s     target_lengths", target_lengths)

        logging.debug(
            "%s     outputs: input length, batch size, number of classes (including blank)",
            outputs.shape,
        )
        logging.debug("%s     targets: batch size, max target length", targets.shape)
        logging.debug("%s     input_lengths: batch size", tensors_lengths.shape)
        logging.debug("%s     target_lengths: batch size", target_lengths.shape)

        loss = criterion(outputs, targets, tensors_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()

        if clip_grad is not None:
            metric["gradient"] = torch.nn.utils.clip_grad_norm_(
                model.parameters(), clip_grad
            )

        optimizer.step()

        # FIXME reduced summed loss value in distributed case?
        avg_loss = record_losses(
            outputs, targets, decoder, language_model, loss.item(), metric
        )

        metric["lr"] = optimizer.param_groups[0]["lr"]
        metric["channel size"] = inputs.shape[1]
        metric["time size"] = inputs.shape[-1]
        metric.flush()

        # TODO Remove before merge pull request
        if SIGNAL_RECEIVED:
            break

    if isinstance(scheduler, ReduceLROnPlateau):
        # A plateau scheduler needs a loss; skip it if no batch was processed.
        if reduce_lr_on_plateau and avg_loss is not None:
            scheduler.step(avg_loss)
    else:
        scheduler.step()


def evaluate(
    model,
    criterion,
    data_loader,
    decoder,
    language_model,
    device,
    epoch,
    disable_logger=False,
):
    """Evaluate ``model`` on ``data_loader`` without gradients.

    Loss, CER and WER are accumulated into a "validation" ``Logger``, which is
    flushed once at the end. Stops early when ``SIGNAL_RECEIVED`` is set.

    Returns:
        float: running loss over the batches processed (the ``record_losses``
        return value for the last batch).

    Raises:
        RuntimeError: if ``data_loader`` yields no batch (previously an
            UnboundLocalError at the return statement).
    """

    with torch.no_grad():

        model.eval()
        metric = Logger("validation", disable=disable_logger)
        metric["epoch"] = epoch

        avg_loss = None

        for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
            data_loader, maxsize=2
        ):

            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # keep batch first for data parallel
            outputs = model(inputs).transpose(-1, -2).transpose(0, 1)

            # CTC
            # outputs: input length, batch size, number of classes (including blank)
            # targets: batch size, max target length
            # input_lengths: batch size
            # target_lengths: batch size

            loss = criterion(outputs, targets, tensors_lengths, target_lengths)

            avg_loss = record_losses(
                outputs, targets, decoder, language_model, loss.item(), metric
            )

            # TODO Remove before merge pull request
            if SIGNAL_RECEIVED:
                break

        metric.flush()

        if avg_loss is None:
            raise RuntimeError("evaluate() received a data loader that yielded no batches")

        return avg_loss


def _build_input_transforms(args, sample_rate):
    """Build the feature pipeline applied to every dataset item.

    The pipeline starts with ``ToMono()``. It then appends either an ``MFCC``
    front-end (``model_input_type == "mfcc"``) or ``UnsqueezeFirst()``
    (``"waveform"``), and finally ``Normalize()`` when ``args.normalize`` is set.

    Side effect: for ``"waveform"`` input this forces ``args.bins = 1`` (with a
    warning). ``args.bins`` is later passed to the model as ``num_features``.

    Raises:
        NotImplementedError: for an unsupported ``args.model_input_type``.
    """
    # Built before ``args.bins`` may be overwritten below.
    melkwargs = {
        "n_fft": args.win_length,
        "n_mels": args.bins,
        "hop_length": args.hop_length,
    }

    transforms = torch.nn.Sequential(ToMono())

    if args.model_input_type == "mfcc":
        transforms = torch.nn.Sequential(
            transforms,
            MFCC(sample_rate=sample_rate, n_mfcc=args.bins, melkwargs=melkwargs),
        )
    elif args.model_input_type == "waveform":
        transforms = torch.nn.Sequential(transforms, UnsqueezeFirst())
        if args.bins != 1:
            logging.warning("waveform model input type only supports bins == 1")
            args.bins = 1
    else:
        raise NotImplementedError(
            f"Selected model input type {args.model_input_type} not supported"
        )

    if args.normalize:
        transforms = torch.nn.Sequential(transforms, Normalize())

    return transforms


def _build_augmentations(args):
    """Build the training-only masking pipeline.

    The result is empty unless ``args.freq_mask`` or ``args.time_mask`` is non-zero.
    """
    augmentations = torch.nn.Sequential()
    if args.freq_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
        )
    if args.time_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
        )
    return augmentations


def _checkpoint_state(epoch, model, optimizer, scheduler, best_loss):
    """Return the dict written by ``save_checkpoint`` and read back on resume."""
    return {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "best_loss": best_loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }


def main(rank, args):
    """Train and validate Wav2Letter, checkpointing after every epoch.

    Args:
        rank: process index. Rank 0 is the main rank: only it enables metric
            logging and writes checkpoints. In distributed mode the rank also
            selects which GPUs this process uses.
        args: configuration namespace, e.g. the ``Args`` class or parsed CLI
            arguments. ``args.start_epoch`` (and ``args.bins`` for waveform
            input) may be overwritten.
    """
    # Distributed setup

    if args.distributed:
        setup_distributed(
            rank,
            args.world_size,
            args.distributed_master_addr,
            args.distributed_master_port,
        )

    main_rank = rank == 0

    # Install signal handler
    # TODO Remove before merge pull request
    signal.signal(signal.SIGUSR1, signal_handler)

    logging.info("Start")

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Change backend for flac files
    torchaudio.set_audio_backend("soundfile")

    # Transforms

    sample_rate_original = 16000
    transforms = _build_input_transforms(args, sample_rate_original)
    augmentations = _build_augmentations(args)

    # Text preprocessing

    char_blank = "*"
    char_space = " "
    char_apostrophe = "'"
    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
    language_model = LanguageModel(labels, char_blank, char_space)

    # Dataset

    if args.speechcommands:
        # NOTE(review): root is a hard-coded developer path;
        # args.dataset_root / args.dataset_folder_in_archive are ignored here.
        training, validation = split_process_speechcommands(
            ["training", "validation"],
            [transforms, transforms],
            language_model,
            root="/private/home/vincentqb/audio-pytorch/examples/pipeline_wav2letter/",
        )
    else:
        training, validation = split_process_librispeech(
            [args.dataset_train, args.dataset_valid],
            [transforms, transforms],
            language_model,
            root=args.dataset_root,
            folder_in_archive=args.dataset_folder_in_archive,
        )

    # Decoder

    if args.decoder == "greedy":
        decoder = GreedyDecoder()
    elif args.decoder == "greedyiter":
        decoder = GreedyIterableDecoder()
    elif args.decoder == "viterbi":
        decoder = ListViterbiDecoder(
            training, len(language_model), progress_bar=args.progress_bar
        )
    else:
        raise ValueError("Selected decoder not supported")

    # Model

    model = Wav2Letter(
        num_classes=len(language_model),
        input_type=args.model_input_type,
        num_features=args.bins,
        num_hidden_channels=args.hidden_channels,
        dropout=args.dropout,
    )

    if args.distributed:
        # Each process owns a contiguous slice of the visible GPUs.
        gpus_per_process = torch.cuda.device_count() // args.world_size
        devices = list(range(rank * gpus_per_process, (rank + 1) * gpus_per_process))
        model = model.to(devices[0])
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
    else:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        model = model.to(devices[0], non_blocking=True)
        model = torch.nn.DataParallel(model)

    num_parameters = count_parameters(model)
    logging.info("Number of parameters: %s", num_parameters)

    # Optimizer
    optimizer = _get_optimizer(args, model)
    scheduler = _get_scheduler(args, optimizer)

    # Loss
    encoded_char_blank = language_model.encode(char_blank)[0]
    criterion = torch.nn.CTCLoss(
        blank=encoded_char_blank, zero_infinity=False, reduction=args.reduction
    )

    # Data Loader

    model_length_function = model_length_function_constructor(args.model_input_type)
    collate_fn_train = collate_factory(model_length_function, augmentations)
    collate_fn_valid = collate_factory(model_length_function)

    loader_training = DataLoader(
        training,
        batch_size=args.batch_size,
        collate_fn=collate_fn_train,
        num_workers=args.workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    loader_validation = DataLoader(
        validation,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        num_workers=args.workers,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    # Setup checkpoint

    # Start at +inf so the first validation loss always counts as the best.
    # With a finite start such as 1.0, ``is_best`` stays False until a loss
    # drops below it, which may never happen with reduction="sum".
    best_loss = float("inf")

    checkpoint_exists = bool(args.checkpoint) and os.path.isfile(args.checkpoint)

    if args.distributed:
        torch.distributed.barrier()

    if checkpoint_exists:
        logging.info("Checkpoint loading %s", args.checkpoint)
        # Load to CPU. load_state_dict copies into each parameter's own device,
        # so this works on CPU-only hosts and avoids every rank allocating on
        # the GPU the checkpoint was saved from.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        logging.info(
            "Checkpoint loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
        )
    elif args.checkpoint and main_rank:
        save_checkpoint(
            _checkpoint_state(args.start_epoch, model, optimizer, scheduler, best_loss),
            False,
            args.checkpoint,
        )

    if args.distributed:
        torch.distributed.barrier()

    torch.autograd.set_detect_anomaly(False)

    # The ``Args`` config defines ``epochs``; accept ``max_epoch`` when present.
    max_epoch = getattr(args, "max_epoch", None)
    if max_epoch is None:
        max_epoch = args.epochs

    for epoch in range(args.start_epoch, max_epoch):

        logging.info("Epoch: %s", epoch)

        train_one_epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            loader_training,
            decoder,
            language_model,
            devices[0],
            epoch,
            args.clip_grad,
            not main_rank,
            not args.reduce_lr_valid,
        )

        loss = evaluate(
            model,
            criterion,
            loader_validation,
            decoder,
            language_model,
            devices[0],
            epoch,
            not main_rank,
        )

        # When reduce_lr_valid is set, the plateau scheduler steps on the
        # validation loss here instead of inside train_one_epoch.
        if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if main_rank and args.checkpoint:
            save_checkpoint(
                _checkpoint_state(epoch + 1, model, optimizer, scheduler, best_loss),
                is_best,
                args.checkpoint,
            )

        # TODO Remove before merge pull request
        if SIGNAL_RECEIVED:
            if main_rank and args.checkpoint:
                save_checkpoint(
                    _checkpoint_state(epoch + 1, model, optimizer, scheduler, best_loss),
                    False,
                    args.checkpoint,
                )
            trigger_job_requeue()

    logging.info("End")

    if args.distributed:
        torch.distributed.destroy_process_group()


def spawn_main(args):
    """Run ``main`` once per process.

    When ``args.distributed`` is false, this calls ``main(0, args)`` in the
    current process. Otherwise it starts ``args.world_size`` processes with
    ``torch.multiprocessing.spawn``, which passes each process its index as the
    ``rank`` argument, and waits for all of them to finish.
    """
    if not args.distributed:
        main(0, args)
        return

    torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size, join=True)


In [None]:
# main(rank, args) needs a rank. spawn_main supplies it: rank 0 when not
# distributed, one spawned process per rank otherwise.
spawn_main(Args)