# Automatic Speech Recognition I

In [None]:
# %%capture pip_install_requirements_output
%pip install --quiet --upgrade -r requirements.txt

In [None]:
import os
import random
import urllib
from collections import defaultdict
from typing import List, Tuple, TypeVar, Optional, Iterable

import arpa
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import requests
import seaborn as sns
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchaudio
from g2p_en import G2p

# Local folder into which the seminar data archive is extracted; all data files are read from here.
data_directory = './week_04_data'

In [None]:
%load_ext autoreload
%autoreload 2

import utils

In [None]:
base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/d/KqloT1zKr_2VaA'
final_url = base_url + urllib.parse.urlencode(dict(public_key=public_key))
response = requests.get(final_url)
download_url = response.json()['href']
!wget -O week_04_data.tar.gz "{download_url}"
!mkdir -p week_04_data
!tar -xf week_04_data.tar.gz -C week_04_data

## Acoustic features

In [None]:
target_sample_rate = 16_000

waveform, sample_rate = torchaudio.load(os.path.join(data_directory, 'babenko.wav'))

resample_waveform = torchaudio.transforms.Resample(
    sample_rate,
    target_sample_rate,
)

wav_to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=target_sample_rate,
    n_mels=80,
)

# Resample, compute the mel spectrogram and keep the first channel
# (torchaudio returns [channel, n_mels, frames]).
melspec = wav_to_melspec(resample_waveform(waveform))[0]

# `waveform` itself is still at the original `sample_rate` (only its resampled
# copy is at `target_sample_rate`), so the clip duration must use the original rate.
duration_seconds = waveform.shape[1] / sample_rate

# plotting
plt.figure(figsize=(10, 5))
# Small epsilon so that silent bins do not produce log(0) = -inf.
melspec_log = (torch.log(melspec + 1e-9)).numpy()
ax = sns.heatmap(melspec_log, cmap='viridis', cbar_kws={'label': 'Log power'})
# Heatmaps draw row 0 at the top; flip so that low-frequency mel bands are at the bottom.
ax.invert_yaxis()

num_ticks = 20
time_ticks = np.linspace(0, melspec_log.shape[1], num_ticks)
time_labels = [f"{t:.1f}" for t in np.linspace(0, duration_seconds, num_ticks)]
ax.set_xticks(time_ticks)
ax.set_xticklabels(time_labels)

plt.xlabel('Time (seconds)')
plt.ylabel('Mel Frequency Bands')
plt.title('Mel Spectrogram')
plt.tight_layout()
plt.show()

## Speech units

In [None]:
text = 'you will not be forced to learn machine learning'

# Segment the same utterance into progressively different speech units.
words = text.split(' ')
graphemes = [character for character in text]

# NLTK resource downloaded for use by the G2P model below (presumably its POS tagger).
nltk.download('averaged_perceptron_tagger_eng', quiet=True)
grapheme_to_phoneme_model = G2p()
phonemes = grapheme_to_phoneme_model(text)

# BPE model from https://huggingface.co/facebook/s2t-small-librispeech-asr/blob/main/sentencepiece.bpe.model
subword_model_path = os.path.join(data_directory, 'sentencepiece.bpe.model')
subword_model = spm.SentencePieceProcessor(model_file=subword_model_path)
subwords = subword_model.EncodeAsPieces(text, enable_sampling=False)

print(f'{text = }\n{words = }\n{graphemes = }\n{phonemes = }\n{subwords = }')

## Metrics

Word, Character and Phoneme Error Rates (WER/CER/PER) are the most popular metrics; they try to approximate how we perceive errors in the speech we hear. We will:
- learn to calculate each of these error rates using the Levenshtein distance;
- implement both a naive recursive version of the Levenshtein algorithm and a more efficient dynamic-programming implementation;
- use your implementation of the Levenshtein distance to compute ASR quality metrics.

### Levenshtein distance

Consider an iterable sequence of elements, such as words or characters. Assume that the following operations can be performed on the sequence:
* **insertion**: cat → ca<font color='green'>s</font>t,
* **deletion**: ca<font color='red'>s</font>t → cat,
* **substitution**: c<font color='blue'>a</font>t → c<font color='blue'>u</font>t,

and suppose they have equal **costs**. These operations are enough to transform an arbitrary sequence into any other arbitrary sequence, and there are usually many ways to do so. Given two sequences A and B, our goal is to find the minimum number of edits needed to transform sequence A into sequence B. This is known as the **Levenshtein distance**.

<!-- ### Algorithm definition -->

**Levenshtein distance** – the minimum number of insertions, deletions, and substitutions required to transform sequence A into sequence B.


The Levenshtein distance can be computed using the following recursive algorithm, known as the **Levenshtein algorithm**:

$$
\mathrm{L}(a, b) =
\begin{cases}
    |a|,& \text{if } |b| = 0, ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ \text{\# second sequence is empty} \\
    |b|,& \text{if } |a| = 0, ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ \text{\# first sequence is empty} \\
    \mathrm{L}(\mathrm{tail}(a), \mathrm{tail}(b)),& \text{if } \mathrm{head}(a) = \mathrm{head}(b), ~ ~ \text{\# first elements of the two sequences are equal} \\
    1 + \min 
    \begin{cases} 
        \mathrm{L}(\mathrm{tail}(a), b), ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ \text{\# deletion from first sequence} \\ 
        \mathrm{L}(a, \mathrm{tail}(b)), ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ \text{\# insertion into first sequence} \\ 
        \mathrm{L}(\mathrm{tail}(a), \mathrm{tail}(b)); ~ ~ ~ ~ \text{\# substitution}
    \end{cases} & \text{, otherwise.}
\end{cases}
$$

<!-- %%
\text{lev}(a,b) = \left \{
  \begin{aligned}
    &\text{len}(a), & \text(del)& &&\text{if}\ \text{len}(b)=0, \\
    &\text{len}(b), & \text(ins)& &&\text{if}\ \text{len}(a)=0, \\
    &\text{lev}(a[1:],b[1:]) & \text(cor)& &&\text{if}\ a[0] = b[0], \\
    &1+\min\left \{
  \begin{aligned}
    &\text{lev}(a[1:],b) & \text(del)\\
    &\text{lev}(a,b[1:]) & \text(ins)\\
    &\text{lev}(a[1:],b[1:]) & \text(sub/cor)
  \end{aligned} \right. && &&\text{otherwise}
  \end{aligned} \right.
%% -->

The Levenshtein distance is a metric in the mathematical sense: it is non-negative and equals zero only for identical sequences (positive definiteness), it is symmetric, and it satisfies the triangle inequality.

**Question: what is the complexity of this algorithm?**

### Naive recursive implementation

Let's try to implement the recursive algorithm described above.

In [None]:
def levenshtein_naive(a: Iterable, b: Iterable) -> int:
    """Recursive implementation of Levenshtein distance

    Direct transcription of the recursive definition: no intermediate results
    are cached, so the running time grows exponentially with the lengths.

    :param a: Iterable, the first (original) sequence
    :param b: Iterable, the second sequence
    :return distance: int, minimum number of insertions, deletions and
        substitutions turning `a` into `b`
    """
    a = list(a)
    b = list(b)

    def _distance(i: int, j: int) -> int:
        # Distance between the suffixes a[i:] and b[j:]; indices replace
        # tail() so that no copies are made on every call.
        if j == len(b):  # second sequence is empty
            return len(a) - i
        if i == len(a):  # first sequence is empty
            return len(b) - j
        if a[i] == b[j]:  # heads are equal: match them for free
            return _distance(i + 1, j + 1)
        return 1 + min(
            _distance(i + 1, j),      # deletion from the first sequence
            _distance(i, j + 1),      # insertion into the first sequence
            _distance(i + 1, j + 1),  # substitution
        )

    return _distance(0, 0)

In [None]:
# Assess algorithm correctness on a few hand-checked pairs.
def run_tests(fn):
    known_distances = [
        ('kitten', 'sitten', 1),
        ('kitten', 'sit', 4),
        ('kitten', 'puppy', 6),
        ('bcabac', 'cabcab', 3),
    ]
    for source, target, expected in known_distances:
        assert fn(source, target) == expected

run_tests(levenshtein_naive)

###  Wagner–Fischer algorithm

The complexity of the naive recursive implementation of the Levenshtein distance algorithm is exponential. This is due to the fact that the distances for the same suffixes are recalculated more than once! This can be avoided if we cache the results of calculations in the form of a matrix of distances between suffixes (more conveniently, prefixes), and fill in this matrix iteratively. The resulting algorithm is named **Wagner–Fischer algorithm** and is an example of a __dynamic programming__ algorithm.

The Wagner–Fischer algorithm is defined as follows:


$$
  \mathrm{L}_{a,b}(i,j) = \left \{
  \begin{aligned}
    &\max(i,j), && &&\text{if}\ \min(i,j)=0, \\
    &\min\left \{
  \begin{aligned}
    &\mathrm{L}_{a,b}(i-1,j)+1 & \text{(del)}\\
    &\mathrm{L}_{a,b}(i,j-1)+1 & \text{(ins)}\\
    &\mathrm{L}_{a,b}(i-1,j-1)+\delta(a_i \neq b_j) & \text{(sub)}
  \end{aligned} \right. && &&\text{otherwise}
  \end{aligned} \right.
$$


**Implement** the `levenshtein_distance_matrix` function, which returns **the distance matrix between the prefixes of the two sequences**. The lower-right element of this matrix is the distance between the prefixes that are equal to the original sequences.

It is necessary to fill in this matrix line by line: for a new element of this matrix, it is enough to know only its neighbors to the left, top, and left-top.

We will also prepend an element denoting an **empty prefix** to the sequences – this is done in order to initialize the initial boundary values (initialize the first row and the first column of the matrix with the index values).



In [None]:
def levenshtein_distance_matrix(a: Iterable, b: Iterable) -> np.ndarray:
    """Matrix implementation of Levenshtein distance (Wagner–Fischer).

    Entry ``d[i, j]`` is the distance between the first ``i`` elements of ``a``
    and the first ``j`` elements of ``b``; row 0 and column 0 stand for the
    empty prefix.

    :param a: Iterable, the first (original) sequence
    :param b: Iterable, the second sequence
    :return distance matrix: np.ndarray of shape (len(a) + 1, len(b) + 1)
    """
    # Prepend an "empty prefix" element so that index i refers to the prefix
    # of length i; the placeholder itself is never compared.
    a = ['#'] + list(a)
    b = ['#'] + list(b)
    d = np.zeros((len(a), len(b)), dtype=int)

    # Distance from/to the empty prefix is the length of the other prefix.
    d[:, 0] = np.arange(len(a))
    d[0, :] = np.arange(len(b))

    # Fill row by row: each cell needs only its top, left and top-left neighbours.
    for i in range(1, len(a)):
        for j in range(1, len(b)):
            d[i, j] = min(
                d[i - 1, j] + 1,                      # deletion
                d[i, j - 1] + 1,                      # insertion
                d[i - 1, j - 1] + int(a[i] != b[j]),  # substitution (free on a match)
            )

    return d


def levenshtein_dp(a: Iterable, b: Iterable) -> int:
    """Levenshtein distance: the bottom-right cell of the prefix-distance matrix."""
    return int(levenshtein_distance_matrix(a, b)[-1, -1])

In [None]:
# Auxiliary function for drawing this matrix:
def plot_matrix(matrix, row_names, column_names, path=None, mods=None):
    """Draw a Levenshtein prefix-distance matrix, optionally outlining a path.

    :param matrix: array-like [n_rows, n_cols] Levenshtein prefix-distance
        matrix, including the leading empty-prefix row and column
    :param row_names: elements of the first sequence (label rows 1..n_rows-1)
    :param column_names: elements of the second sequence (label columns 1..n_cols-1)
    :param path: optional list of (row, col) cells to outline, e.g. from `backtrace`
    :param mods: edit label for each cell of `path`, one of
        'same', 'subst', 'del', 'insert'; required when `path` is given
    :return: None
    :raises ValueError: if `path` is given without `mods` of the same length
    """
    colors = {
        'same': '#888888',
        'subst': '#0000ff',
        'del': '#ff0000',
        'insert': '#00ff00',
    }
    if path is not None:
        path = list(path)
        mods = None if mods is None else list(mods)
        # zip() would silently drop cells (or crash on None) otherwise.
        if mods is None or len(mods) != len(path):
            raise ValueError('`mods` must be given and have the same length as `path`')

    row_names = ['#'] + list(row_names)
    column_names = ['#'] + list(column_names)
    matrix = np.array(matrix)

    plt.figure(figsize=(len(column_names) / 2, len(row_names) / 2))
    plt.imshow(matrix, interpolation='nearest', cmap=plt.get_cmap('Blues'))
    plt.title("Levenshtein prefix distances")

    # Rotate tick labels when tokens are long (e.g. whole words) so they do not overlap.
    r = 0 if max(len(str(name)) for name in row_names + column_names) < 3 else 45
    plt.gca().xaxis.tick_top()
    plt.xticks(range(len(column_names)), column_names, fontsize=12, rotation=r)
    plt.yticks(range(len(row_names)), row_names, fontsize=12, rotation=r)

    # Light text on dark cells, dark text on light cells.
    threshold = matrix.max() / 2
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            plt.text(
                j, i, "{:,}".format(matrix[i, j]),
                color="white" if matrix[i, j] > threshold else "black",
                horizontalalignment='center',
                verticalalignment='center',
            )

    if path is not None:
        for (i, j), mod in zip(path, mods):
            rect = patches.Rectangle(
                (j - 0.45, i - 0.45), 0.9, 0.9,
                edgecolor=colors[mod], facecolor='none', linewidth=2)
            plt.gca().add_patch(rect)

    plt.show()


plot_matrix([
    [0, 1, 2, 3, 4],
    [1, 0, 1, 2, 3],
    [2, 1, 0, 1, 2],
    [3, 2, 1, 1, 1]
], 'cat', 'cast')

In [None]:
# Alternative example: first, second = 'sunday', 'saturday'
first, second = 'elephant', 'relevant'
plot_matrix(levenshtein_distance_matrix(first, second), first, second)


def run_tests(fn, n_random=100, seed=None):
    """Check an edit-distance function on fixed pairs and on random sequences.

    Random pairs are compared against `levenshtein_naive`; a failing assertion
    reports the offending pair so it can be reproduced.

    :param fn: function (a, b) -> int computing the Levenshtein distance
    :param n_random: number of random sequence pairs to check
    :param seed: seed for a local random generator; None draws a fresh seed,
        pass an int to replay exactly the same pairs
    """
    known_distances = [
        ('kitten', 'sitten', 1),
        ('kitten', 'sit', 4),
        ('kitten', 'puppy', 6),
        ('bcabac', 'cabcab', 3),
    ]
    for a, b, expected in known_distances:
        assert fn(a, b) == expected, (a, b)

    # A local generator does not disturb (or depend on) the global `random` state.
    rng = random.Random(seed)
    for _ in range(n_random):
        a = "".join(rng.choice('abc') for _ in range(rng.randint(3, 9)))
        b = "".join(rng.choice('abc') for _ in range(rng.randint(3, 9)))
        assert fn(a, b) == levenshtein_naive(a, b), (a, b)

# lets check our implementation on random sequences
run_tests(levenshtein_dp)

### Backtrace

To understand what insertions, deletions and substitutions were made on the original sequence, you can do a backtrace on the resulting matrix. 

Let's treat the first sequence as the original one; deletions and insertions are named relative to it.

**Implement** the `backtrace` function, based on the construction logic `levenshtein_distance_matrix`:
* write the path to the variable `path` – the list of the coordinates of the matrix cells that lie on the optimal path through the matrix;
* and in the `mods` variable, write down the modifications that we make on the original sequence:
    * `same` - leaving the element unchanged
    * `subst` - replacing the element
    * `del` - deleting the element
    * `insert` - inserting the element

In [None]:
# Example: prefix distances for 'cat' -> 'cast' with an optimal path,
# in which 's' is inserted after 'ca'.
example_matrix = [
    [0, 1, 2, 3, 4],
    [1, 0, 1, 2, 3],
    [2, 1, 0, 1, 2],
    [3, 2, 1, 1, 1],
]
example_path = [(0, 0), (1, 1), (2, 2), (2, 3), (3, 4)]
example_mods = ['same', 'same', 'same', 'insert', 'same']
plot_matrix(example_matrix, 'cat', 'cast', example_path, example_mods)

In [None]:
def backtrace(d: np.ndarray):
    """Backtrace for Levenshtein Distance

    Walks from the bottom-right cell back to (0, 0), always stepping to the
    smallest of the top-left, top and left neighbours (top-left first on ties).
    This always stays on an optimal path:
      * a cell equal to its top-left neighbour, where that neighbour is the
        minimum, can only arise from matching elements -> 'same';
      * otherwise the cell is exactly 1 more than the minimum neighbour, so
        the step is a substitution (top-left), deletion (top) or insertion (left).

    :param d: Levenshtein prefix-distance matrix (np.ndarray), with the empty
        prefix in row 0 / column 0, as built by `levenshtein_distance_matrix`
    :return path: list of (row, col) cells on an optimal path, from (0, 0)
        to the bottom-right cell
    :return mods: edit for each cell of `path` ('same', 'subst', 'del',
        'insert'), relative to the first (row) sequence; the start cell (0, 0)
        is labelled 'same'
    """
    d = np.asarray(d)
    path = []
    mods = []

    i, j = d.shape[0] - 1, d.shape[1] - 1
    while i > 0 or j > 0:
        path.append((i, j))
        if i > 0 and j > 0:
            diagonal, up, left = d[i - 1, j - 1], d[i - 1, j], d[i, j - 1]
            best = min(diagonal, up, left)
            if diagonal == best:
                mods.append('same' if diagonal == d[i, j] else 'subst')
                i, j = i - 1, j - 1
            elif up == best:
                mods.append('del')
                i -= 1
            else:
                mods.append('insert')
                j -= 1
        elif i > 0:
            # First column: only deletions remain.
            mods.append('del')
            i -= 1
        else:
            # First row: only insertions remain.
            mods.append('insert')
            j -= 1

    path.append((0, 0))
    mods.append('same')

    # Collected from the end; return in forward order.
    path.reverse()
    mods.reverse()
    return path, mods

In [None]:
first, second = 'thursday tea', 'friday beer'

# Build the prefix-distance matrix once and reuse it for the backtrace and the plot.
distances = levenshtein_distance_matrix(first, second)
path, mods = backtrace(distances)
plot_matrix(distances, first, second, path, mods)

Let's try applying the Levenshtein distance to a sequence of words, not characters.

In [None]:
first = "вас не будут заставлять учить машинное обучение".split()
second = "вас не будут force to учить machine learning".split()

distances = levenshtein_distance_matrix(first, second)
path, mods = backtrace(distances)

# Count each kind of edit on the optimal path.
edit_list = list(mods)
S, I, D = (int(edit_list.count(kind)) for kind in ('subst', 'insert', 'del'))
print(f"{S = }\t{I = }\t{D = }")
assert (S, I, D) == (3, 1, 0)

plot_matrix(distances, first, second, path, mods)

### Error Rate

In this part you will use the Levenshtein distance you implemented in the first part to obtain a measure of __mismatch__, or __error__, between a reference and an ASR hypothesis at the word, character and phone level.

#### Theoretical Recap

Suppose we have our reference sequence, relative to which we want to calculate the recognition error.

Why is the raw Levenshtein distance not suitable for measuring the quality of an ASR system? Because the number of tokens in a sentence varies: the same number of edits means much more in a short sentence than in a long one. Therefore, we normalize the Levenshtein distance by the length of the reference. This yields the minimum edit distance **rate**.

Word (character, phoneme, morpheme, syllable) error rate can then be computed as:

$$
\mathrm{WER} = \frac{\mathrm{S} + \mathrm{I} + \mathrm{D}}{\mathrm{N}}, \text{where:} \\
\text{S is the number of substitutions,} \\
\text{I is the number of insertions,} \\
\text{D is the number of deletions,} \\
\text{N is the length of reference.}
$$

We can assess the error rate between two sequences at multiple levels: the word level (WER), the character (letter) level (CER) and the phoneme level (PER). WER is the most strict, as even a partially correct word is considered incorrect. Phone error rate is in some sense the most lenient, as it measures whether the reference and hypothesis "sound" the same. Note that errors can also be assessed at other levels, for example morphemes, lexemes and syllables. Which metric is appropriate depends on the language and on what is being measured.

#### WER vs CER vs PER

Implement the `error_rate` function, which will calculate the prediction error for a given sequence of tokens (words, characters or phonemes) by formula above.

In [None]:
def error_rate(reference: Iterable, predicted: Iterable) -> float:
    """Error rate (S + I + D) / N between a reference and a predicted sequence.

    Works for any token type: words (WER), characters (CER), phonemes (PER), ...

    :param reference: ground-truth tokens; must be non-empty
    :param predicted: hypothesis tokens of the same kind
    :return: Levenshtein distance divided by the reference length; 0.0 for a
        perfect hypothesis, and it can exceed 1.0 when there are many insertions
    :raises ValueError: if `reference` is empty
    """
    # Materialize so that arbitrary iterables (e.g. generators) have a length.
    reference = list(reference)
    predicted = list(predicted)
    if not reference:
        raise ValueError('reference must contain at least one token')

    return float(levenshtein_dp(reference, predicted)) / len(reference)

In [None]:
# Assess error rate function

# Word- and character-level comparison of the same two sentences.
first = "you will not be forced to learn machine learning"
second = "you'll not be forced to learn my sheen learning"

wer = np.round(error_rate(first.split(), second.split()), 4)
cer = np.round(error_rate(first, second), 4)

# Phoneme level: transcribe both sentences with the G2P model, then compare.
first, second = (grapheme_to_phoneme_model(sentence) for sentence in (first, second))
print(first)
print(second)
per = np.round(error_rate(first, second), 4)

for label, value in (
    ('Word Error Rate:', wer),
    ('Character Error Rate:', cer),
    ('Phone Error Rate:', per),
):
    print(label, value)

for value, expected in ((wer, 0.4444), (cer, 0.1875), (per, 0.1316)):
    assert np.allclose(value, expected, rtol=1e-5, atol=1e-5)

## Connectionist Temporal Classification (CTC)

### Lecture recap

#### Problem statement

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1FFGXZsCgy-uQfCBp7F4w1gaJbIGv6CV2" height="200px" width="700px">   -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/spectrogram_to_text.png" height="200px" width="700px">  

Define a modified label sequence $\omega'_{1:2L + 1}$:
- add blanks to the beginning and the end of the original label sequence $\omega_{1:L}$
- insert blanks between every pair of labels

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1CEhWtVYrSSkaRtEsJr5QwiH8lMaSQ_uN" height="150px" width="400px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/grapheme_eps.png" height="150px" width="400px">


Below we will define $\alpha_t(s)$ as the total probability of all paths of length $t$ that are in state $\omega_s'$ at time $t$ (see the Forward Algorithm section).

Denote a sequence of **acoustic features** or **observations** as

$$
    \mathbf{X}_{1:T} = \{x_1, \ldots, x_T\}
$$

Define a mapping $\mathcal{M}$ between words $\mathbf{w}$ and speech units $\omega_{1:L}$:

$$
    \{\omega^{(q)}_{1:L_q}\}^Q_{q = 1} = \mathcal{M}(\mathbf{w})
$$

$$
    \{\mathbf{w}^{(p)}\}^P_{p = 1} = \mathcal{M}^{-1}(\omega_{1:L})
$$

For some choices of speech units this mapping is not 1-to-1 ($Q > 1$, $P > 1$). A possible pair of text (green) and speech units (yellow):

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1HWD_SFZzids3Nz67BK_NQ5awkw6yUvLo" height="200px" width="600px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/text_to_letters.png" height="200px" width="600px">

Automated speech recognition (ASR) is a **discriminative** task $\rightarrow$ "Which sequence $\mathbf{\hat w}$ is likely given the audio?":

$$
    \mathbf{\hat w} = \mathcal{M}^{-1}(\hat \omega_{1:L}), \quad \hat \omega_{1:L} = \arg \max_{\hat \omega_{1: L}} P(\hat \omega_{1:L} | \mathbf{X}_{1: T}; \theta),
$$

where $\theta$ denotes the parameters of the model we are building to solve the problem.

#### Discriminative state-space models

How do feature vectors $\mathbf{X}_{1: T}$ and speech units $\omega_{1:L}$ relate, or **align**, to each other? There are two common approaches to constructing models that can align:
- state-space models
- neural attention mechanisms

State-space models represent the space of various alignments in the form of a table (called **trellis**), the rows of which correspond to phonemes, and the columns are observed variables. One alignment is the path in this table from the upper left corner to the lower right.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1npycuLvYq_-3p_xd6bouR21tfVeOvMUd" height="300px" width="600px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_trellis.png" height="300px" width="600px">

Denote a set of all paths in trellis that map onto the phoneme sequence $\omega_{1:L}$ as $\mathcal{A}(\omega_{1:L})$, and let $\pi_{1:T} \in \mathcal{A}(\omega_{1:L})$ be an element of this set. Then a discriminative state-space system models $P(\omega_{1:L} | \mathbf{X}_{1: T}; \theta)$ as 

$$
    P(\omega_{1:L} | \mathbf{X}_{1: T}; \theta) = \sum_{\pi_{1:T} \in \mathcal{A}(\omega_{1:L})} P(\pi_{1:T} | \mathbf{X}_{1:T}; \theta)
$$

Imagine that we have a recurrent neural network parametrized with $\theta$. The network outputs a distribution $P(z_t|x_t; \theta)$ over possible speech units $\omega$ for each frame $x_t$:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=153E-ailMiLPg3joPSx016lGv6S4vXVD2" height="300px" width="550px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_rnn.png" height="300px" width="550px">

CTC is a discriminative state-space model defined as:
    
$$
    P(\omega_{1:L} | \mathbf{X}_{1: T}; \theta) = \sum_{\pi_{1:T} \in \mathcal{A}(\omega_{1:L})} \prod_{t = 1}^T P(z_t = \pi_t| x_t; \theta)
$$
    
- CTC assumes that the states are conditionally independent given the input
- Alignment free -- does not need prior alignment for training

### CTC Forward-Backward Algorithm


#### Forward Algorithm

$$
    \alpha_t(s) = P(\omega_{1:s/2}, \pi_t = \omega_s' | \mathbf{X}_{1:T}, \theta) = \sum_{\pi_{1:t - 1} \in \mathcal{A}(\omega_{1:s/2}), \, \pi_t = \omega_s'}  P(\pi_{1:t} | \mathbf{X}_{1:T}, \theta)
$$

Note that although we have moved to the extended sequence $\omega'$, we are still interested in maximizing the probability of alignments to the original sequence. Step $s$ in the extended sequence corresponds to step $s/2$ in the original sequence (rounded down).

The CTC forward algorithm recursively computes the forward variable $\alpha_t(s)$.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1QaW0mJ9c3Z0KJVk3pUSyC_kS_pFC_QxS" height="400px" width="600px">   -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_trellis_alpha.png" height="400px" width="600px">  

**Initialization.** We allow all prefixes to start with either a blank ($\epsilon$) or the first symbol in $\omega_{1:L}$. Also note that $\alpha_t(s) = 0,\ \forall s < (2L + 1) - 2(T - t) - 1$, because these variables correspond to states for which there are not enough time-steps left to complete the sequence.

This gives us the following rules for initialization:

$$
  \begin{aligned}
    &\alpha_t(0) = 0, \forall t & \\
    &\alpha_1(1) = P(z_1 = \epsilon | \mathbf{X}_{1:T}), &\\
    &\alpha_1(2) = P(z_1 = \omega^{'}_2 | \mathbf{X}_{1:T}), &\\
    &\alpha_1(s) = 0,\ \forall s > 2 &\\
    &\alpha_t(s) = 0,\ \forall s < (2L + 1) - 2(T - t) - 1 &  \text{top right zeros}\\
  \end{aligned}
$$

**Recursion.** 

$$
  \begin{aligned}
    &\alpha_t(s) = \left \{
  \begin{aligned}
    &\big(\alpha_{t-1}(s) + \alpha_{t-1}(s-1) \big) P(z_t = \omega^{'}_s | \mathbf{X}_{1:T}) & \text{if}\ \omega_s^{'} = \epsilon\ \text{or}\
    \omega_s^{'} = \omega_{s-2}^{'} \\
    &\big(\alpha_{t-1}(s) + \alpha_{t-1}(s-1) + \alpha_{t-1}(s-2)\big) P(z_t = \omega^{'}_s | \mathbf{X}_{1:T}) & \text{otherwise}\\
  \end{aligned} \right. 
  \end{aligned}
$$


<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1Tre3oFHyjigpqG-GI1xVrOchZAMnRYBK" height="250px" width="650px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_alpha_step.png" height="250px" width="650px">

#### Backward Algorithm

Define $\beta_t(s)$ as the probability of all valid alignments of the suffix $\omega'_{s:2L+1}$ that are in state $\omega_s'$ at time $t$:

$$
    \beta_t(s) = P(\omega_{s/2:L}, \pi_t = \omega'_s | \mathbf{X}_{1:T}, \theta) = \sum_{\pi_{t + 1:T} \in \mathcal{A}(\omega_{s/2:L}), \, \pi_t = \omega_s'} P(\pi_{t + 1:T} | \mathbf{X}_{1:T}, \theta)
$$

The CTC backward algorithm recursively computes the backward variable $\beta_t(s)$:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=11x3TGAzL2LWfO0ZKpPHegOvv8Iw6ZC0X" height="400px" width="600px"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_trellis_beta.png" height="400px" width="600px">


The formulas for backward algorithm are as follows:

$$
  \begin{aligned}
    &\beta_T(2L+1) = 1 &\\
    &\beta_T(2L) = 1 & \\
    &\beta_T(s) = 0, \forall s < 2L &\\
    &\beta_t(s) = 0,\ \forall s > 2t &\\
    &\beta_t(2L+2) = 0,\ \forall t  & \text{bottom left zeros} \\
    &\beta_t(s) = \left \{
  \begin{aligned}
    &\big(\beta_{t+1}(s) + \beta_{t+1}(s+1) \big) P(z_t = \omega^{'}_s | \mathbf{X}_{1:T}) & \text{if}\ \omega_s^{'} = \epsilon\ \text{or}\
    \omega_s^{'} = \omega_{s+2}^{'} \\
    &\big(\beta_{t+1}(s) +\beta_{t+1}(s+1) + \beta_{t+1}(s+2)\big) P(z_t = \omega^{'}_s | \mathbf{X}_{1:T}) & \text{otherwise}\\
  \end{aligned} \right. 
  \end{aligned}
$$

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1h7OBZZ02dwZ1mDhRYh7yTy7-UW4NmbXm" height="250px" width="650px">  -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/ctc_beta_step.png" height="350px" width="600px"> 

#### Alignment and Loss Computation

Use your newfound knowledge of the CTC forward-backward algorithm to obtain a soft alignment.

Recall the forward variable $\alpha_t(s)$ and the backward variable $\beta_t(s)$ defined above.

The probability of all paths passing through a state $\pi_t = \omega_s'$ is the product of forward and backward variables:

$$
    \alpha_t(s) \beta_t(s) = \sum_{\pi_{1:T} \in \mathcal{A}(\omega_{1:L}), \,\pi_t=\omega_s'} P(\pi_{1:T} | \mathbf{X}_{1:T}, \theta)
$$

Then, for any $t$, sum of all such products yields total probability:

$$
     \sum_{s = 1}^{2 L + 1} \alpha_t(s) \beta_t(s) = P(\omega_{1:L} | \mathbf{X}_{1:T}, \theta)
$$

We can also use normalized $\alpha_t(s) \beta_t(s)$ as a measure of **soft-alignment**:

$$
    \text{align}_t(s) = \frac{\alpha_t(s) \beta_t(s)}{\sum_{s = 1}^{2 L + 1} \alpha_t(s) \beta_t(s)}
$$

You should get something like

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1HAIl9UPReiFQ7dNOZFGfvUWDurFDBZYM" height="300px" width="800px">  -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/soft_align.png" height="300px" width="800px"> 

$$
  \text{align}_t(s) = \frac{\alpha_t(s)\beta_t(s)}{\sum_{s}\alpha_t(s)\beta_t(s)}
$$


Doing the computation in probability space can be numerically unstable, so you should do it in Log-Space using the
provided logsumexp operation. Remember to return to prob space at the end. 

### Implement CTC Forward Algorithm

In [None]:
# Constants shared with the helper module, and the tokenizer that converts
# text to class indices and back for the CTC exercises.
NEG_INF, BLANK_SYMBOL = utils.NEG_INF, utils.BLANK_SYMBOL

tokenizer = utils.CTCTokenizer()

In [None]:
def forward_algorithm(sequence: List[int], matrix: np.ndarray) -> np.ndarray:
    """
    Compute the CTC forward variables alpha_t(s) in log space.

    :param sequence: a string converted to an index array by Tokenizer
    :param matrix: A matrix of shape (K, T) with probability distributions over phonemes at each moment of time.
    :return: the result of the forward pass of shape (2 * len(sequence) + 1, T);
        entry [s, t] is log alpha_t(s) for state s of the blank-augmented
        sequence (NEG_INF where alpha is zero)
    """
    # Turn probs into log-probs
    matrix = np.log(matrix)

    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    # NOTE(review): assumes modify_sequence returns [blank, l1, blank, ..., lL, blank]
    mod_sequence = utils.modify_sequence(sequence, blank)
    n_states = len(mod_sequence)
    n_frames = matrix.shape[1]

    # Initialze every state to "probability zero"
    alphas = np.full([n_states, n_frames], NEG_INF)

    for t in range(n_frames):
        for s in range(n_states):
            emission = matrix[mod_sequence[s], t]
            # First Step: a path may start with a blank (s = 0) or the first label (s = 1)
            if t == 0:
                if s < 2:
                    alphas[s, t] = emission
            # Upper diagonal zeros: too few frames left to reach the end of the sequence
            elif s < n_states - 2 * (n_frames - t):
                alphas[s, t] = NEG_INF
            else:
                # Sums of probabilities become logaddexp in log space (numerically stable)
                stay = alphas[s, t - 1]
                if s == 0:
                    alphas[s, t] = stay + emission
                elif s == 1:
                    alphas[s, t] = np.logaddexp(stay, alphas[s - 1, t - 1]) + emission
                else:
                    total = np.logaddexp(stay, alphas[s - 1, t - 1])
                    # Skipping the blank at s - 1 is allowed only when state s is a
                    # label that differs from the label at s - 2.
                    if mod_sequence[s] != blank and mod_sequence[s] != mod_sequence[s - 2]:
                        total = np.logaddexp(total, alphas[s - 2, t - 1])
                    alphas[s, t] = total + emission
    return alphas

### Implement The CTC Backward Algorithm

In [None]:
def backward_algorithm(sequence: List[int], matrix: np.ndarray) -> np.ndarray:
    """
    Compute the CTC backward variables beta_t(s) in log space.

    Follows the notebook's formulas: beta_T(2L+1) = beta_T(2L) = 1 and
    beta_t(s) = (sum of admissible beta_{t+1}) * P(z_t = w'_s).

    :param sequence: a string converted to an index array by Tokenizer
    :param matrix: A matrix of shape (K, T) with probability distributions over phonemes at each moment of time.
    :return: the result of the backward pass of shape (2 * len(sequence) + 1, T);
        entry [s, t] is log beta_t(s) (NEG_INF where beta is zero)
    """
    matrix = np.log(matrix)
    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    # NOTE(review): assumes modify_sequence returns [blank, l1, blank, ..., lL, blank]
    mod_sequence = utils.modify_sequence(sequence, blank)
    n_states = len(mod_sequence)
    n_frames = matrix.shape[1]
    betas = np.full([n_states, n_frames], NEG_INF)

    for t in reversed(range(n_frames)):
        for s in reversed(range(n_states)):
            emission = matrix[mod_sequence[s], t]
            # First Step: only the final blank and the last label may end a path;
            # beta = 1 there, i.e. log-prob 0.
            # NOTE(review): confirm this initialization matches soft_alignment.txt.
            if t == n_frames - 1:
                if s >= n_states - 2:
                    betas[s, t] = 0.0
            # Lower Diagonal Zeros: beta_t(s) = 0 for s > 2t in 1-based indexing
            elif s > 2 * t + 1:
                betas[s, t] = NEG_INF
            else:
                stay = betas[s, t + 1]
                if s == n_states - 1:
                    betas[s, t] = stay + emission
                elif s == n_states - 2:
                    betas[s, t] = np.logaddexp(stay, betas[s + 1, t + 1]) + emission
                else:
                    total = np.logaddexp(stay, betas[s + 1, t + 1])
                    # Jumping over the blank at s + 1 is allowed only when state s is a
                    # label that differs from the label at s + 2.
                    if mod_sequence[s] != blank and mod_sequence[s] != mod_sequence[s + 2]:
                        total = np.logaddexp(total, betas[s + 2, t + 1])
                    betas[s, t] = total + emission
    return betas

### Obtain Soft-Alignment


In [None]:
def soft_alignment(labels_indices: List[int], matrix: np.ndarray) -> np.ndarray:
    """
    Returns the alignment coefficients for the input sequence

    align[s, t] is proportional to alpha_t(s) * beta_t(s), and every column
    (time step) is normalized to sum to 1.

    :param labels_indices: label sequence converted to indices by the tokenizer
    :param matrix: (K, T) matrix of per-frame class probabilities
    :return: array of shape (2 * len(labels_indices) + 1, T)
    """
    alphas = forward_algorithm(labels_indices, matrix)
    betas = backward_algorithm(labels_indices, matrix)

    # Normalize in log space first: subtracting each column's maximum keeps the
    # largest term at exp(0) = 1, so long inputs do not underflow to an
    # all-zero column (which would turn into 0 / 0 = NaN below).
    log_align = alphas + betas
    log_align = log_align - np.max(log_align, axis=0, keepdims=True)

    # Move from log space back to prob space
    align = np.exp(log_align)

    # Normalize Alignment
    align = align / np.sum(align, axis=0, keepdims=True)

    return align

In [None]:
#!L
# Test your implementation

# Load the test matrix of shape [classes, time]
matrix = np.loadtxt(os.path.join(data_directory, 'test_matrix.txt'))

# Label sequence to align
labels_indices = tokenizer.text_to_indices('there se ms no good reason for believing that twillc ange')

align = soft_alignment(labels_indices, matrix)

# Show the alignment in linear and in log scale side by side.
f, ax = plt.subplots(1, 2, dpi=75, figsize=(15, 5))
panels = [
    (align, "Alignment"),
    (np.log(align), "Alignment in log scale"),
]
for axis, (image, title) in zip(ax, panels):
    im = axis.imshow(image, aspect='auto', cmap='viridis', interpolation='nearest')
    axis.set_title(title)
    axis.set_ylabel("Phonemes")
    axis.set_xlabel("Time")
    f.colorbar(im, ax=axis)

plt.tight_layout()

ref_align = np.loadtxt(os.path.join(data_directory, 'soft_alignment.txt'))
assert np.allclose(ref_align, align)

### Implementing a Decoder for CTC model (5 points)

Before you can start having fun with a CTC ASR model, you first need to make sure that you can correctly "decode", i.e. generate text, from a working model. This can be done in two ways: using a Greedy Decoder, which is simple and fast, or using a Prefix Beam Search decoder, which is slower but takes advantage of the fact that multiple paths through a CTC trellis can map to the same sentence. In the following exercise you will implement both decoders.

#### Greedy Best-Path Decoder (1 point)

After we have trained the model, we would like to use it to find a likely output for a given input. Your goal is to implement a Greedy Best-Path decoder. Remember that in CTC the joint distribution over states factors into a product of marginals:

$${\tt P}(\mathbf{z}_{1:T}|\mathbf{X}_{1:T},\mathbf{\theta}) = \prod_{t = 1}^T{\tt P}(z_t|\mathbf{X}_{1:T},\mathbf{\theta})$$

We can take the most likely output at each time-step, which gives us the alignment with the highest probability:

$$\mathbf{\pi}^*_{1:T} = \arg \max_{\mathbf{\pi}_{1:T} } \prod_{t=1}^T {\tt P}(z_t = \pi_t|\mathbf{X}_{1:T})$$

Then merge repeats and remove blanks.

In [None]:
def greedy_decoder(output: torch.Tensor, labels: List[torch.Tensor],
                   label_lengths: List[int], collapse_repeated: bool = True) -> Tuple[List[str], List[str]]:
    """
    Greedy best-path CTC decoding: take the most likely class at every frame,
    merge consecutive repeats, then remove blanks.

    :param output: torch.Tensor of Probs or Log-Probs of shape [batch, time, classes]
    :param labels: list of label indices converted to torch.Tensors
    :param label_lengths: list of label lengths (without padding)
    :param collapse_repeated: whether the repeated characters should be deduplicated
    :return: a tuple (decoded strings, target strings), one entry per batch item
    """
    blank_label = tokenizer.get_symbol_index(BLANK_SYMBOL)

    # Most likely class per frame; the argmax is the same for probs and log-probs
    arg_maxes = torch.argmax(output, dim=-1)

    decodes = []
    targets = []

    for i, args in enumerate(arg_maxes):
        true_labels = [int(index) for index in labels[i][:label_lengths[i]].tolist()]
        targets.append(tokenizer.indices_to_text(true_labels))

        # Merge repeats first, then drop blanks. `previous` is updated on every
        # frame (blanks included), so "a <blank> a" still yields two "a"s.
        decode = []
        previous = None
        for index in args.tolist():
            is_repeat = collapse_repeated and index == previous
            previous = index
            if is_repeat or index == blank_label:
                continue
            decode.append(index)

        decodes.append(tokenizer.indices_to_text(decode))
    return decodes, targets

Testing the greedy decoding

In [None]:
# The test matrix is stored as [classes, time]; add a batch axis -> [batch, classes, time]
matrix = np.loadtxt(os.path.join(data_directory, 'test_matrix.txt'))[np.newaxis, :, :]

# The decoder expects a float tensor of shape [batch, time, classes]
matrix = torch.Tensor(matrix).permute(0, 2, 1)

# Reference transcription as a tensor of token indices
labels_indices = torch.Tensor(
    tokenizer.text_to_indices('there seems no good reason for believing that it will change')
)

decodes, targets = greedy_decoder(
    matrix,
    [labels_indices],
    [len(labels_indices)],
)

assert decodes[0] == 'there se ms no good reason for believing that twillc ange'
assert targets[0] == 'there seems no good reason for believing that it will change'

#### Prefix (Beam Search) Decoding With LM (4 points)

The greedy decoder doesn't take into account the fact that a single output can have many alignments. For example, imagine that the true label for a phoneme sequence is $[a]$. Assume that alignments $[a, a, \epsilon]$ and $[a, a, a]$ individually have lower probability than $[b, b, b]$, but the sum of their probabilities is higher. In this case, the greedy decoder would choose the wrong alignment $[b, b, b]$ and propose a wrong hypothesis $[b]$ instead of $[a]$.

Prefix decoding considers probabilities of multiple paths and merges them. It can also incorporate an external language model.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1_X9NfoSe8HLKfAErDtr0rBsIxoejA1kq" height="500px" width="900px">  -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/beam_search.png" height="500px" width="900px"> 

Prefix decoding algorithm has 3 nested loops:
- over time - we extend prefixes up to T times
- over prefixes in the beam
- over possible extensions of a prefix

Each prefix can be extended in three possible ways:
- with a blank
- with a repeating character
- with a non-repeating character

We must keep track of two probabilities per prefix:
- The probability of prefix ending with blank $P_b(t, s)$. 
- The probability of prefix not ending with blank $P_{nb}(t, s)$

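Here $t$ denotes time step and $s$ denotes a prefix we got after $t$ time steps. Since several different prefixes at step $t - 1$ (and several extensions) can lead to the same prefix at step $t$, each update below is *added* to the probability already accumulated for that prefix at step $t$ (in the log domain this is a log-sum-exp).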

We start with an empty string prefix: 

$$
    P_b(0, \text{""}) = 1
$$
$$
    P_{nb}(0, \text{""}) = 0
$$

If we extend $s$ with a blank, update the probability of ending with a blank:

$$
    P_b(t, s) = P(\epsilon | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

The prefix $s$ is not updated because blanks are eliminated in the end.

If we extend with a repeat character $c$, there are two options:
1. The previous symbol is a blank, and now we extend the prefix
2. The previous symbol is not a blank, so we don't extend the prefix (repeats are merged)

In this case, the probability $P_{nb}$ is updated as follows:

$$
    P_{nb}(t, s + c) = P(c | x_t) \cdot P_b(t - 1, s)
$$
$$
    P_{nb}(t, s) = P(c | x_t) \cdot P_{nb}(t - 1, s)
$$

Finally, consider extending $s$ at time $t$ with a non-repeat character. It can follow both blank and non-blank characters, so the probability $P_{nb}$ is updated as follows:

$$
    P_{nb}(t, s + c) = P(c | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

We may also want to apply a language model during decoding, but only in the case we have a new complete word. This happens when the current symbol is a non-repeat space. As CTC is a discriminative model, LMs can only be integrated as a heuristic:

$$
    \mathbf{w}^* = \arg \max_\mathbf{w} \underbrace{P(\mathbf{w} | \mathbf{X}_{1:T})}_{\text{CTC prob}} \cdot \underbrace{P(\mathbf{w})^{\alpha}}_{\text{LM prob}} \cdot \underbrace{|\mathbf{w}|^\beta}_{\text{Length correction}}
$$

The formula for an update of $P_{nb}$ when LM is used and the current symbol is a non-repeat space:

$$
    P_{nb}(t, s + c) = P_{\text{LM}}(s)^\alpha \cdot |s|^\beta \cdot P(c | x_t) \cdot (P_b(t - 1, s) + P_{nb}(t - 1, s))
$$

In [None]:
from typing import Protocol, runtime_checkable


@runtime_checkable
class LanguageModel(Protocol):
    """Structural type of the n-gram language model used by the prefix decoder.

    The decoder only calls ``log_p(text)``; ``calculate_probability_score_with_lm``
    treats the result as a base-10 log-probability (ARPA convention).
    A ``TypeVar`` was used here before, but a TypeVar that appears once in a
    non-generic signature carries no type information.
    """

    def log_p(self, text: str) -> float:
        ...
# Helper class: the set of prefix hypotheses kept by the CTC prefix decoder

class Beam:
    """Candidate prefixes of the CTC prefix beam search with their scores.

    A prefix is a tuple of token indices. Each prefix maps to a pair
    ``(log P_blank, log P_not_blank)`` of log-domain probabilities, where
    ``NEG_INF`` stands for probability zero.
    """

    def __init__(self, beam_size: int) -> None:
        self.beam_size = beam_size

        # Unseen prefixes start with probability 0 for both endings, so updates
        # can always accumulate into whatever pair is stored for a prefix.
        self.candidates = defaultdict(self._zero_probs)
        # Initial beam: the empty prefix with P_b = 1 (log 1 = 0) and P_nb = 0.
        self.top_candidates_list = [(tuple(), (0.0, NEG_INF))]

    @staticmethod
    def _zero_probs() -> Tuple[float, float]:
        """Default (log P_blank, log P_not_blank) for a prefix not seen yet."""
        return NEG_INF, NEG_INF

    def get_probs_for_prefix(self, prefix: Tuple[int, ...]) -> Tuple[float, float]:
        """Return (log P_blank, log P_not_blank) accumulated for ``prefix``.

        Because ``candidates`` is a defaultdict, looking up an unseen prefix also
        inserts it with the zero-probability default.
        """
        p_blank, p_not_blank = self.candidates[prefix]
        return p_blank, p_not_blank

    def update_probs_for_prefix(self, prefix: Tuple[int, ...], next_p_blank: float,
                                next_p_not_blank: float) -> None:
        """Overwrite the stored (log P_blank, log P_not_blank) pair of ``prefix``."""
        self.candidates[prefix] = (next_p_blank, next_p_not_blank)

    def update_top_candidates_list(self) -> None:
        """Keep the ``beam_size`` prefixes with the highest total probability P_b + P_nb."""
        top_candidates = sorted(
            self.candidates.items(),
            key=lambda x: utils.logsumexp(*x[1]),
            reverse=True
        )
        self.top_candidates_list = top_candidates[:self.beam_size]


def calculate_probability_score_with_lm(lm: LanguageModel, prefix: Tuple[int, ...]) -> float:
    """Score a decoded prefix with the language model.

    :param lm: ARPA language model exposing ``log_p``
    :param prefix: a hypothesis of the prefix decoder, i.e. a tuple of token indices
    :return: the natural-log probability of the prefix text under ``lm``
    """
    # Use upper case for the LM; strip() removes the trailing space that
    # triggered this word-boundary scoring.
    text = tokenizer.indices_to_text(prefix).upper().strip()
    lm_prob = lm.log_p(text)
    # ARPA models report log10 probabilities: ln(p) = log10(p) / log10(e)
    score = lm_prob / np.log10(np.e)
    return score

In [None]:
def decode(probs: np.ndarray, beam_size: int = 5, lm: Optional[LanguageModel] = None,
           prune: float = 1e-5, alpha: float = 0.1, beta: float = 2):
    """
    CTC prefix beam search, optionally rescoring completed words with an n-gram LM.

    All scores are handled in the natural-log domain.

    :param probs: A matrix of shape (T, K) with probability distributions over phonemes at each moment of time
                  (probabilities, not log-probabilities).
    :param beam_size: the number of prefixes kept after every time step
    :param lm: arpa language model; if None, no LM rescoring is applied
    :param prune: the minimal probability for a symbol at which it can be added to a prefix (0.0 disables pruning)
    :param alpha: the exponent (weight) of the LM probability
    :param beta: the exponent (weight) of the length correction term
    :return: a tuple (best prefix as a tuple of token indices,
             negative log of P_blank + P_not_blank for that prefix)
    """
    T, S = probs.shape
    probs = np.log(probs)
    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    space = tokenizer.get_symbol_index(" ")
    prune = NEG_INF if prune == 0.0 else np.log(prune)

    beam = Beam(beam_size)
    for t in range(T):
        next_beam = Beam(beam_size)

        for s in range(S):
            p = probs[t, s]
            if p < prune:    # Prune the vocab
                continue

            for prefix, (p_blank, p_not_blank) in beam.top_candidates_list:
                # log(P_b(t-1, s) + P_nb(t-1, s)): mass of the prefix whatever its ending
                p_total = utils.logsumexp(p_blank, p_not_blank)

                if s == blank:
                    # A blank keeps the prefix unchanged and makes it end with a blank
                    p_b, p_nb = next_beam.get_probs_for_prefix(prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=prefix,
                        next_p_blank=utils.logsumexp(p_b, p + p_total),
                        next_p_not_blank=p_nb,
                    )
                    continue

                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)

                if s == end_t:
                    # Repeat separated by a blank: a new symbol, the prefix grows
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + p_blank),
                    )

                    # Repeat without a blank in between: merged, the prefix is unchanged
                    p_b, p_nb = next_beam.get_probs_for_prefix(prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + p_not_blank),
                    )
                elif s == space and end_t is not None and lm is not None:
                    # A non-repeat space completes a word:
                    # weight by P_LM(s)^alpha * |s|^beta (i.e. add alpha*ln P_LM + beta*ln|s|)
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    score = calculate_probability_score_with_lm(lm, n_prefix)
                    length = len(tokenizer.indices_to_text(prefix))

                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(
                            p_nb, alpha * score + beta * np.log(length) + p + p_total
                        ),
                    )
                else:
                    # A non-repeat symbol can follow both blank and non-blank endings
                    p_b, p_nb = next_beam.get_probs_for_prefix(n_prefix)
                    next_beam.update_probs_for_prefix(
                        prefix=n_prefix,
                        next_p_blank=p_b,
                        next_p_not_blank=utils.logsumexp(p_nb, p + p_total),
                    )

        next_beam.update_top_candidates_list()
        beam = next_beam

    best = beam.top_candidates_list[0]
    return best[0], -utils.logsumexp(*best[1])


def beam_search_decoder(probs: torch.Tensor, labels: List[torch.Tensor], label_lengths: List[int],
                        input_lengths: List[int], lm: Optional[LanguageModel], beam_size: int = 5,
                        prune: float = 1e-3, alpha: float = 0.1, beta: float = 0.1) -> Tuple[List[str], List[str]]:
    """
    Run CTC prefix beam search for every item of a batch.

    :param probs: per-frame class probabilities of shape [batch, time, classes]
                  (probabilities, not log-probabilities: ``decode`` takes their log)
    :param labels: target index tensors (possibly padded), one per batch item
    :param label_lengths: unpadded length of each target
    :param input_lengths: number of valid frames of each batch item
    :param lm: arpa language model, or None to decode without an LM
    :param beam_size: the number of prefixes kept after every time step
    :param prune: the minimal symbol probability considered for extending a prefix
    :param alpha: the exponent (weight) of the LM probability
    :param beta: the exponent (weight) of the length correction term
    :return: a tuple (decoded strings, target strings), one entry per batch item
    """
    probs = probs.detach().cpu().numpy()
    decodes, targets = [], []

    for i, prob in enumerate(probs):
        target_indices = [int(index) for index in labels[i][:label_lengths[i]].tolist()]
        targets.append(tokenizer.indices_to_text(target_indices))
        # Only the first input_lengths[i] frames are real; the rest is padding
        int_seq, _ = decode(prob[:input_lengths[i]], lm=lm, beam_size=beam_size, prune=prune, alpha=alpha, beta=beta)
        decodes.append(tokenizer.indices_to_text(int_seq))

    return decodes, targets

In [None]:
# Load the pruned 3-gram ARPA LM; arpa.loadf returns a sequence of models, take the first
arpa_models = arpa.loadf(os.path.join(data_directory, '3-gram.pruned.1e-7.arpa'))
alm = arpa_models[0]
# NOTE(review): presumably tells the model to use this LM's upper-case <UNK> entry
# for out-of-vocabulary words — confirm against the `arpa` package.
alm._unk = '<UNK>'

Testing prefix (beam search) decoding

In [None]:
# The test matrix is stored as [classes, time]; add a batch axis -> [batch, classes, time]
matrix = np.loadtxt(os.path.join(data_directory, 'test_matrix.txt'))[np.newaxis, :, :]

# The decoders expect a float tensor of shape [batch, time, classes]
matrix = torch.Tensor(matrix).permute(0, 2, 1)

reference_text = 'there seems no good reason for believing that it will change'
labels_indices = torch.Tensor(tokenizer.text_to_indices(reference_text))

decoder_kwargs = dict(beam_size=5, prune=1e-3, alpha=0.1, beta=0.3)
cases = [
    # Without an LM. For comparison, greedy decoding gives
    # 'there se ms no good reason for believing that twillc ange'
    (None, 'there se ms no good reason for believing that twil c ange'),
    # With the 3-gram LM
    (alm, 'there seems no good reason for believing that twil c ange'),
]

for language_model, expected_decode in cases:
    decodes, targets = beam_search_decoder(
        matrix, [labels_indices], [len(labels_indices)], [matrix.size(1)],
        lm=language_model, **decoder_kwargs
    )
    assert decodes[0] == expected_decode
    assert targets[0] == reference_text

### Examples

- Jasper https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/jasper.html
- DeepSpeech2 https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html
- VGGTransformer https://github.com/facebookresearch/fairseq/blob/main/examples/speech_recognition/models/vggtransformer.py

## RNN-Transducer

### Lecture recap

#### Alignment

Let $\mathbf{x} = (x_1, x_2, \ldots, x_T)$ be an input sequence of length $T$ belonging to the set $X^*$ of all sequences over some input space $X$. Let $\mathbf{y} = (y_1, \ldots, y_U)$ be an output sequence of length $U$ belonging to the set $Y^*$ of all sequences over some output space $Y$.

Define the *extended output space* $\overline Y$ as $Y \cup \{\emptyset\}$, where $\emptyset$ denotes the null output. The intuitive meaning of $\emptyset$ is 'output nothing'. The sequence $(y_1, \emptyset, \emptyset, y_2, \emptyset, y_3) \in \overline Y^*$ is therefore equivalent to $(y_1, y_2, y_3) \in Y^*$. We refer to the elements $\mathbf{a} \in \overline Y^*$ as *alignments*, since the location of the null symbols determines an alignment between the input and output sequences.

As we saw in CTC, various alignments can be represented in the form of a table called a trellis. An example of what an RNN-T trellis may look like:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1CfXfkePAESz2n20AABVUw9SaZ_xszxwf"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_1.png">
    
    
Possible alignments in that trellis:
    
<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1ipRlSrznwmoD5gCk7k6G06JeUtqPzDQq"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_2.png">
    
The final label can be determined by simply removing the blank symbols:
    
$$
    C \emptyset \emptyset A \emptyset T \emptyset \to CAT
$$
$$
    \emptyset \emptyset \emptyset C A T \emptyset \to CAT
$$
    
Given $\mathbf{x}$, the RNN transducer defines a conditional distribution $P(\mathbf{a} \in \overline Y^* | \mathbf{x})$. This distribution is then collapsed onto the following distribution over $Y^*$:
    
$$
    P(\mathbf y \in Y^* | \mathbf x) = \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x),
$$
    
where $\mathcal B: \overline Y^* \mapsto Y^*$ is a function that removes the null symbols from the alignments in $\overline Y^*$.


#### Architecture

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1P2aztCi9Z7ookMbHmWBcGtSmG_JHIiMj"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_arch.png">

The RNN-T model consists of three neural networks: Encoder, Predictor and Joiner. The Encoder converts the acoustic feature $x_t$ into a high-level representation $f_t$, where $t$ is time index:

$$
    f_t = \mathrm{Encoder}(x_t)
$$

The Predictor works like an RNN language model, which produces a high-level representation $g_u$ by conditioning on the previous non-blank target $y_{u - 1}$ predicted by the RNN-T model, where $u$ is output label index:

$$
    g_u = \mathrm{Predictor}(y_{u - 1})
$$

Note that the input sequence for the predictor **is prepended with the special symbol** $\langle s \rangle$ that defines the start of a sentence.

The Joiner is a feed forward network that combines the Encoder output $f_t$ and the Predictor output $g_u$ as

$$
    h_{t, u} = \mathrm{Joiner}(f_t, g_u) = \mathrm{FeedForward}(\mathrm{ReLU}(f_t + g_u))
$$

The final posterior for each output token $y$ is obtained after applying the softmax operation:

$$
    P(y | t, u) = \mathrm{softmax}(h_{t, u})
$$
    
where $P(y | t, u)$ is a distribution of probabilities to emit $y \in \overline Y$ at time step $t$ after $u$ previously generated characters, $t \in [1, T], u \in [0, U]$.

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1tn1wS3fCVFJGwrYumf5Im6gOFZsxRMV-"> -->
<p style="text-align:center;"><img src="./images/rnnt_probs.png">

We will further need to work with probabilities of individual tokens $y$ for different $t$ and $u$. Instead of writing each time something like $P(y = C | t = 1, u = 0)$, we will, for the sake of simplicity, write it as $P(C | 1, 0)$.

#### Training: forward-backward algorithm

The loss function of RNN-T is the negative log posterior of output label sequence $\mathbf y$ given acoustic feature $\mathbf x$:

$$
    \mathcal L = -\ln P(\mathbf y \in Y^* | \mathbf x) = -\ln \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x)
$$

To determine $P(\mathbf a | \mathbf x)$ for an arbitrary alignment $\mathbf a$, we need to multiply the probabilities $P(y | t, u)$ of each symbol across the path:

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1O-aykP5Wods7ZESCJDBsBw2MeBo5egW4"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_trellis_probs.png">

$$
    \mathbf a = C \emptyset \emptyset A \emptyset T \emptyset
$$
    
$$
    P(\mathbf a | \mathbf x) = P(C | 1, 0) \cdot P(\emptyset | 1, 1) \cdot P(\emptyset | 2, 1) \cdot P(A | 3, 1) \cdot P(\emptyset | 3, 2) \cdot P(T | 4, 2) \cdot P(\emptyset | 4, 3)
$$

There are usually too many possible alignments to compute the loss function by just adding them all up directly. We will use dynamic programming to make this computation feasible.

Define the *forward variable* $\alpha(t, u)$ as the probability of outputting $\mathbf y_{[1:u]}$ during $\mathbf f_{[1:t]}$. The forward variables for all $1 \le t \le T$ and $0 \le u \le U$ can be calculated recursively using

$$
    \alpha(t, u) = \alpha(t - 1, u) P(\emptyset | t - 1, u) + \alpha(t, u - 1) P(y_{u - 1} | t, u - 1)
$$

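with initial condition $\alpha(1, 0) = 1$. Here $y_{u - 1}$ denotes the label with zero-based index $u - 1$ in the ground-truth sequence $\mathbf y$ (i.e. its $u$-th label); the backward recursion below uses the same zero-based convention for $y_u$.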

The total output sequence probability is obtained from the forward variable at the last node $(T, U)$ by emitting the final blank:

$$
    P(\mathbf y | \mathbf x) = \alpha(T, U) P(\emptyset | T, U)
$$

Define the *backward variable* $\beta(t, u)$ as the probability of outputting $\mathbf y_{[u + 1: U]}$ during $\mathbf f_{[t:T]}$. Then

$$
    \beta(t, u) = \beta(t + 1, u) P(\emptyset | t, u) + \beta(t, u + 1) P(y_u | t, u)
$$

with initial condition $\beta(T, U) = P(\emptyset | T, U)$. The final value is $\beta(1, 0)$.

From the definition of the forward and backward variables it follows that their product $\alpha(t, u) \beta(t, u)$ at any point $(t, u)$ in the output lattice is equal to the probability of emitting the complete output sequence *if $y_u$ is emitted during transcription step $t$*.

### RNN-T Forward-Backward Algorithm (2 points)

Implement forward and backward passes.


#### Implementation tips

- Note that all indices in the arrays you will work with in your code start from zero. So, the initial condition for the forward algorithm will be $\alpha(0, 0) = 1$ (and $\log \alpha(0, 0) = 0$) and the output value of the backward algorithm will be $\beta(0, 0)$. The recurrent formulas stay the same. Also, don't be confused by the terminal node: you don't have to add it to the $\alpha$- and $\beta$-arrays. The dynamic programming starts in the upper left corner for forward variables and in the lower right corner for backward variables.
- You will need to do everything in the log domain for the calculations to be numerically stable. The function [np.logaddexp](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html) might help you with it.

In [None]:
def forward(log_probs: torch.FloatTensor, targets: torch.LongTensor,
            blank: int = -1) -> Tuple[np.ndarray, float]:
    """
    RNN-T forward algorithm in the log domain.

    :param log_probs: model outputs after applying log_softmax, shape (T, len(targets) + 1, D)
    :param targets: the target sequence of tokens, represented as integer indexes
    :param blank: the index of blank symbol
    :return: Tuple[ln alpha, -(ln alpha(T, U) + ln P(blank | T, U))]. The latter term is loss value, which is -ln P(y | x)
    """
    if torch.is_tensor(log_probs):
        log_p = log_probs.detach().cpu().numpy()
    else:
        log_p = np.asarray(log_probs)
    labels = [int(token) for token in targets]
    max_T, max_U, _ = log_p.shape

    # here the alpha variable contains logarithm of the alpha variable from the formulas above;
    # alpha[0, 0] = ln 1 = 0 is the initial condition
    alpha = np.zeros((max_T, max_U), dtype=np.float32)

    # First column (u = 0): no label emitted yet, only blanks advance t
    for t in range(1, max_T):
        alpha[t, 0] = alpha[t - 1, 0] + log_p[t - 1, 0, blank]

    # First row (t = 0): labels emitted one after another during the first frame
    for u in range(1, max_U):
        alpha[0, u] = alpha[0, u - 1] + log_p[0, u - 1, labels[u - 1]]

    for t in range(1, max_T):
        for u in range(1, max_U):
            no_emit = alpha[t - 1, u] + log_p[t - 1, u, blank]
            emit = alpha[t, u - 1] + log_p[t, u - 1, labels[u - 1]]
            alpha[t, u] = np.logaddexp(no_emit, emit)

    # The final blank at (T-1, U) leads to the terminal node
    cost = -(alpha[-1, -1] + log_p[-1, -1, blank])
    return alpha, cost


def backward(log_probs: torch.FloatTensor, targets: torch.LongTensor,
             blank: int = -1) -> Tuple[np.ndarray, float]:
    """
    RNN-T backward algorithm in the log domain.

    :param log_probs: model outputs after applying log_softmax, shape (T, len(targets) + 1, D)
    :param targets: the target sequence of tokens, represented as integer indexes
    :param blank: the index of blank symbol
    :return: Tuple[ln beta, -ln beta(0, 0)]. The latter term is loss value, which is -ln P(y | x)
    """
    if torch.is_tensor(log_probs):
        log_p = log_probs.detach().cpu().numpy()
    else:
        log_p = np.asarray(log_probs)
    labels = [int(token) for token in targets]
    max_T, max_U, _ = log_p.shape

    # here the beta variable contains logarithm of the beta variable from the formulas above
    beta = np.zeros((max_T, max_U), dtype=np.float32)
    # Initial condition: from the last node only the final blank remains
    beta[-1, -1] = log_p[-1, -1, blank]

    # Last column (u = U): all labels are emitted, only blanks advance t
    for t in reversed(range(max_T - 1)):
        beta[t, -1] = beta[t + 1, -1] + log_p[t, -1, blank]

    # Last row (t = T-1): the remaining labels are emitted during the last frame
    for u in reversed(range(max_U - 1)):
        beta[-1, u] = beta[-1, u + 1] + log_p[-1, u, labels[u]]

    for t in reversed(range(max_T - 1)):
        for u in reversed(range(max_U - 1)):
            no_emit = beta[t + 1, u] + log_p[t, u, blank]
            emit = beta[t, u + 1] + log_p[t, u, labels[u]]
            beta[t, u] = np.logaddexp(no_emit, emit)

    cost = -beta[0, 0]
    return beta, cost

In [None]:
def run_test(logits: torch.FloatTensor, targets: torch.LongTensor,
             ref_costs: torch.FloatTensor, blank: int = -1) -> None:
    """
    Check that the forward and backward RNN-T passes agree with each other and
    with the reference costs (both to 2 decimal places).

    :param logits: model outputs
    :param targets: the target sequence of tokens, represented as integer indexes
    :param ref_costs: the true values of RNN-T costs for test inputs
    :param blank: the index of blank symbol
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    batch_size = log_probs.shape[0]
    cost = np.zeros(batch_size)

    for idx in range(batch_size):
        sample_log_probs = log_probs[idx]
        sample_targets = targets[idx]
        _, loss_from_alpha = forward(sample_log_probs, sample_targets, blank=blank)
        _, loss_from_beta = backward(sample_log_probs, sample_targets, blank=blank)
        # Both directions compute -ln P(y | x), so they must coincide
        np.testing.assert_almost_equal(loss_from_alpha, loss_from_beta, decimal=2)
        cost[idx] = loss_from_beta

    np.testing.assert_almost_equal(cost, ref_costs, decimal=2)

In [None]:
# Tests

'''
All logits in tests have shapes in the form (B, T, U, D) where

B: batch size
T: maximum source sequence length in batch
U: maximum target sequence length in batch
D: feature dimension of each source sequence element
'''
# NOTE: the U axis of the logits actually has (target length + 1) entries, one for
# each u = 0..len(targets); e.g. targets of length 2 below come with 3 positions.

# test 1
# Single utterance: B=1, T=2, 3 target positions, D=5 classes; blank is the last class (blank=-1)
logits = torch.FloatTensor([
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.6, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.8, 0.1,
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.1, 0.1,
    0.7, 0.1, 0.2, 0.1, 0.1,
]).reshape(1, 2, 3, 5)

targets = torch.LongTensor([[1, 2]])
ref_costs = torch.FloatTensor([5.09566688538])

run_test(
    logits=logits,
    targets=targets,
    ref_costs=ref_costs,
    blank=-1
)

# test 2
# Two utterances: B=2, T=4, 3 target positions, D=3 classes; blank is class 0
logits = torch.FloatTensor([
    0.065357, 0.787530, 0.081592, 0.529716, 0.750675, 0.754135, 0.609764, 0.868140,
    0.622532, 0.668522, 0.858039, 0.164539, 0.989780, 0.944298, 0.603168, 0.946783,
    0.666203, 0.286882, 0.094184, 0.366674, 0.736168, 0.166680, 0.714154, 0.399400,
    0.535982, 0.291821, 0.612642, 0.324241, 0.800764, 0.524106, 0.779195, 0.183314,
    0.113745, 0.240222, 0.339470, 0.134160, 0.505562, 0.051597, 0.640290, 0.430733,
    0.829473, 0.177467, 0.320700, 0.042883, 0.302803, 0.675178, 0.569537, 0.558474,
    0.083132, 0.060165, 0.107958, 0.748615, 0.943918, 0.486356, 0.418199, 0.652408,
    0.024243, 0.134582, 0.366342, 0.295830, 0.923670, 0.689929, 0.741898, 0.250005,
    0.603430, 0.987289, 0.592606, 0.884672, 0.543450, 0.660770, 0.377128, 0.358021,
]).reshape(2, 4, 3, 3)

targets = torch.LongTensor([[1, 2], [1, 1]])
ref_costs = torch.FloatTensor([4.2806528590890736, 3.9384369822503591])

run_test(
    logits=logits,
    targets=targets,
    ref_costs=ref_costs,
    blank=0
)

#### Utilities

In [None]:
# RNN-T tokenizer: its vocabulary includes the <BOS> token used to prime the predictor
tokenizer = utils.RNNTTokenizer()
BOS = utils.BOS

In [None]:
# Download LibriSpeech test dataset

# exist_ok avoids the check-then-create race of isdir() + makedirs()
os.makedirs("./data", exist_ok=True)

test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)
# 80-band mel-spectrogram features of 16 kHz audio
test_transforms = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80)

In [None]:
def collator_fn(data, transforms) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor, torch.IntTensor]:
    """
    Collate LibriSpeech items into a padded batch.

    :param data: a list of LIBRISPEECH dataset items; element 0 is the waveform
                 and element 2 is the transcript
    :param transforms: a waveform -> spectrogram transform (e.g. ``MelSpectrogram``)
    :return: tuple of
        spectrograms, shape: (B, T, n_mels), zero-padded along T
        labels, shape: (B, U), int32 token indices zero-padded along U
        input_lengths -- the length of each spectrogram in the batch, shape: (B,)
        label_lengths -- the length of each text label in the batch, shape: (B,)
        where
        B: batch size
        T: maximum source sequence length in batch
        U: maximum target sequence length in batch
    """
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    for (waveform, _, utterance, _, _, _) in data:
        # Drop the leading channel axis (assumes mono audio) and put time first: (T, n_mels)
        spec = transforms(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        label = torch.IntTensor(tokenizer.text_to_indices(utterance.lower()))
        labels.append(label)
        input_lengths.append(spec.shape[0])
        label_lengths.append(len(label))

    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    # Labels are IntTensors already, so the padded batch is int32 without re-wrapping
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, torch.IntTensor(input_lengths), torch.IntTensor(label_lengths)

def test_collator_fn(data):
    """Collate a test batch using the mel-spectrogram ``test_transforms``."""
    return collator_fn(data, test_transforms)

### Implementing a greedy decoder (2 points)

<!-- <p style="text-align:center;"><img src="http://drive.google.com/uc?export=view&id=1tHsoq0ZH0tHSHYlYlw00y8ksF-wHmrmC"> -->
<p style="text-align:center;"><img src="https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_04_asr/images/rnnt_greedy.png">

Now we know how to train a Transducer, but how do we run inference with it? Our task is to generate an output sequence $\mathbf y$ given an input acoustic sequence $\mathbf x$.

Here we will index the encoder outputs $f_t$ starting from zero, because it is more convenient when describing an algorithm.

The greedy decoding procedure is as follows:
1. Compute $\{f_0, \ldots, f_{T-1}\}$ using $\mathbf x$.
2. Set $t = 0$, $u = 0$, $\mathbf y = []$, $\mathrm{iteration} = 0$.
3. If $u = 0$, set $g_0 = \mathrm{Predictor}(\langle s \rangle)$. If $u > 0$, compute $g_u$ using the last predicted token $\mathbf y[-1]$.
4. Compute $P(y | t, u)$ using $f_t$ and $g_u$.
5. If argmax of $P(y | t, u)$ is a label, set $u = u + 1$ and append the new label to $\mathbf y$. 
6. If argmax of $P(y | t, u)$ is $\emptyset$, set $t = t + 1$.
7. If $t = T$ or $\mathrm{iteration} = \mathrm{max\_iterations}$, we are done. Else, set $\mathrm{iteration} = \mathrm{iteration + 1}$ and go to step 3.

In [None]:
@torch.no_grad()
def greedy_decode(model: 'RNNTransducer', encoder_output: torch.Tensor, max_steps: int = 2000) -> torch.Tensor:
    """
    Greedy RNN-T decoding of a single utterance.

    :param model: an RNN-T model in eval mode
    :param encoder_output: the output of the encoder part of RNN-T, shape: (T, encoder_output_dim)
    :param max_steps: the maximum number of decoding steps (bounds the loop even if labels keep being emitted)
    :return: the predicted labels
    """
    # Emitted non-blank tokens and the predictor (model.decoder) recurrent state
    pred_tokens, hidden_state = [], None
    blank = tokenizer.get_symbol_index(BLANK_SYMBOL)
    max_time_steps = encoder_output.size(0)
    t = 0  # index of the current encoder frame

    # Prime the predictor with the <BOS> token: decoder_output plays the role of g_0
    decoder_input = encoder_output.new_tensor([[tokenizer.get_symbol_index(BOS)]], dtype=torch.long)
    decoder_output, hidden_state = model.decoder(decoder_input, hidden_states=hidden_state)

    for _ in range(max_steps):
        # <YOUR CODE>
        # NOTE(review): exercise placeholder. The body should combine encoder_output[t]
        # with decoder_output through the model's joint network, take the argmax, and
        # either advance t (blank) or append the token to pred_tokens and re-run
        # model.decoder on it. The joint network's attribute name is not visible here —
        # confirm against RNNTransducer. Until implemented, t never changes, so the loop
        # runs all max_steps iterations and an empty tensor is returned.

        if t == max_time_steps:
            break

    return torch.LongTensor(pred_tokens)


@torch.no_grad()
def recognize(model: 'RNNTransducer', inputs: torch.Tensor, input_lengths: torch.Tensor) -> List[torch.Tensor]:
    """
    Encode a batch of spectrograms and greedily decode every utterance.

    :param model: an RNN-T model in eval mode
    :param inputs: spectrograms, shape: (B, T, n_mels)
    :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
    :return: a list with the predicted labels, one tensor per utterance
    """
    encoder_outputs, _ = model.encoder(inputs, input_lengths)
    # NOTE(review): each per-utterance encoder output is passed whole, so any frames
    # corresponding to padding are decoded too — confirm whether the encoder masks them.
    return [greedy_decode(model, utterance_output) for utterance_output in encoder_outputs]


def get_transducer_predictions(
        transducer: 'RNNTransducer', inputs: torch.Tensor, input_lengths: torch.Tensor,
        targets: torch.Tensor, target_lengths: torch.Tensor
    ) -> pd.DataFrame:
    """
    Decode a batch and compare every hypothesis with its reference transcript.

    :param transducer: an RNN-T model in eval mode
    :param inputs: spectrograms, shape: (B, T, n_mels)
    :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
    :param targets: labels, shape: (B, U)
    :param target_lengths: the lengths of the text labels in the batch, shape: (B,)
    :return: a pd.DataFrame with columns ground_truth, prediction, cer and wer
    """
    def to_text(indices) -> str:
        return tokenizer.indices_to_text([int(index) for index in indices])

    rows = []
    for predicted, target, length in zip(recognize(transducer, inputs, input_lengths), targets, target_lengths):
        reference_text = to_text(target[:length])
        hypothesis_text = to_text(predicted)
        rows.append({
            "ground_truth": reference_text,
            "prediction": hypothesis_text,
            "cer": utils.cer(reference_text, hypothesis_text),
            "wer": utils.wer(reference_text, hypothesis_text)
        })
    return pd.DataFrame.from_records(rows)


In [None]:
# Load the TorchScript checkpoint onto the CPU: the DataLoader below yields CPU
# tensors, and without map_location a GPU-saved checkpoint fails on CPU-only machines.
model = torch.jit.load(os.path.join(data_directory, 'model_scripted_epoch_5.pt'), map_location="cpu")
model.eval()

In [None]:
# First 5 test-clean utterances, in dataset order
loader = data.DataLoader(
    test_dataset,
    batch_size=5,
    shuffle=False,
    collate_fn=test_collator_fn,
)
first_batch = next(iter(loader))
spectrograms, labels, input_lengths, label_lengths = first_batch

predictions = get_transducer_predictions(
    transducer=model,
    inputs=spectrograms,
    input_lengths=input_lengths,
    targets=labels,
    target_lengths=label_lengths,
)
predictions

In [None]:
# Expected ground truth and greedy-decoding output of the pretrained checkpoint for the
# first 5 test-clean utterances. The "prediction" strings intentionally contain the
# model's recognition errors and must be matched exactly.
reference_values = [
    {
        "gt": "he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce",
        "prediction": "he hoped there would be stew for dinner turnips and characts and bruised potatoes and fat much and pieces to be lateled out in the thick peppered flowerfacton sauce"
    },
    {
        "gt": "stuff it into you his belly counselled him",
        "prediction": "stuffed into you his belly counciled him"
    },
    {
        "gt": "after early nightfall the yellow lamps would light up here and there the squalid quarter of the brothels",
        "prediction": "after early night fall the yellow lamps would lie how peer and there the squalit quarter of the brothels"
    },
    {
        "gt": "hello bertie any good in your mind",
        "prediction": "her about he and he good in your mind"
    },
    {
        "gt": "number ten fresh nelly is waiting on you good night husband",
        "prediction": "none but den fresh now as waiting on you could night husband"
    }
]

# Row i of the predictions table must match reference i exactly
for index, reference in enumerate(reference_values):
    row = predictions.iloc[index]
    assert row.ground_truth == reference["gt"]
    assert row.prediction == reference["prediction"]

#### RNN-T module (1 point)

In [None]:
class EncoderRNNT(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, output_dim: int, n_layers: int,
                 dropout: float = 0.2, bidirectional: bool = True):
        """
        An RNN-based model that encodes input audio features into a hidden representation.
        The architecture is a stack of LSTM's followed by a fully-connected output layer.

        :param input_dim: the number of mel-spectrogram features
        :param hidden_size: the number of features in the hidden states in LSTM layers
        :param output_dim: the output dimension
        :param n_layers: the number of stacked LSTM layers
        :param dropout: the dropout probability between stacked LSTM layers
            (nn.LSTM applies it to every layer's output except the last, so it
            has no effect with a single layer and is disabled there to avoid a warning)
        :param bidirectional: If True, each LSTM layer becomes bidirectional
        """
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates forward and backward features.
        num_directions = 2 if bidirectional else 1
        self.output_proj = nn.Linear(hidden_size * num_directions, output_dim)

    def forward(self, inputs: torch.Tensor,
                input_lengths: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        :param inputs: spectrograms, shape: (B, T, n_mels)
        :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
        :return: outputs of the projection layer, shape (B, T, output_dim), and the
            final LSTM states (h_n, c_n), each of shape (n_layers * num_directions, B, hidden_size)
        """
        # Packing keeps padded frames out of the recurrence; without it the
        # backward direction would start its pass on padding. Lengths must be on CPU.
        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False,
        )
        packed_outputs, hidden = self.lstm(packed_inputs)
        # total_length keeps the time axis equal to the input's T even when every
        # length in the batch is shorter than the padded width.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=inputs.size(1),
        )
        logits = self.output_proj(outputs)

        return logits, hidden

In [None]:
def get_pseudo_batch():
    """
    Build a random two-utterance batch for shape checks.

    :return: (spectrograms, labels, input_lengths, label_lengths), where
        spectrograms are uniform noise zero-padded to (2, 835, 80), labels are
        zero-padded to (2, 158) with ids drawn from [2, len(tokenizer.char_map)),
        and both length tensors are int32.
    """
    frame_counts = [835, 800]
    token_counts = [158, 150]
    # Ids 0 and 1 are never sampled: draw from [0, vocab - 2) and shift by 2.
    sampled_range = len(tokenizer.char_map) - 2

    spectrograms = nn.utils.rnn.pad_sequence(
        [torch.rand((frames, 80)) for frames in frame_counts],
        batch_first=True,
    )
    labels = nn.utils.rnn.pad_sequence(
        [torch.randint(sampled_range, (tokens,)) + 2 for tokens in token_counts],
        batch_first=True,
    )
    return spectrograms, labels, torch.IntTensor(frame_counts), torch.IntTensor(token_counts)

In [None]:
encoder = EncoderRNNT(
    input_dim=80,
    hidden_size=320,
    output_dim=512,
    n_layers=4,
    dropout=0.2,
    bidirectional=True
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward` so any registered hooks run.
logits, hidden_states = encoder(spectrograms, input_lengths)

assert spectrograms.shape == torch.Size([2, 835, 80]), spectrograms.shape
assert logits.shape == torch.Size([2, 835, 512]), logits.shape
assert len(hidden_states) == 2, len(hidden_states)
# 4 layers x 2 directions = 8 stacked hidden states.
assert hidden_states[0].shape == torch.Size([8, 2, 320]), hidden_states[0].shape

In [None]:
class DecoderRNNT(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, output_dim: int, n_layers: int, dropout: float = 0.2):
        """
        A simple RNN-based autoregressive language model that takes as input previously generated text tokens
        and outputs a hidden representation of the next token

        :param hidden_size: the number of features in the hidden states in LSTM layers
            (also used as the token embedding size)
        :param vocab_size: the number of text tokens in the dictionary
        :param output_dim: the output dimension
        :param n_layers: the number of stacked LSTM layers
        :param dropout: the dropout probability between stacked LSTM layers
            (ignored by nn.LSTM for a single layer, so it is disabled there to avoid a warning)
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.output_proj = nn.Linear(hidden_size, output_dim)

    def forward(self, inputs: torch.Tensor, input_lengths: Optional[torch.Tensor] = None,
                hidden_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        :param inputs: labels, shape: (B, U)
        :param input_lengths: the lengths of the text labels in the batch, shape: (B,);
            when given, padded positions are packed out of the recurrence (training)
        :param hidden_states: optional initial LSTM states (h_0, c_0) to continue from
        :return: outputs of the projection layer, shape (B, U, output_dim), and the
            final LSTM states (h_n, c_n), each of shape (n_layers, B, hidden_size)
        """
        embed_inputs = self.embedding(inputs)

        if input_lengths is not None:
            # training phase: same packing scheme as the encoder so padding
            # does not feed into the recurrence. Lengths must be on CPU.
            packed_inputs = nn.utils.rnn.pack_padded_sequence(
                embed_inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False,
            )
            packed_outputs, hidden = self.lstm(packed_inputs, hidden_states)
            # total_length keeps the label axis at U even if all lengths are shorter.
            outputs, _ = nn.utils.rnn.pad_packed_sequence(
                packed_outputs, batch_first=True, total_length=inputs.size(1),
            )
        else:
            # testing phase
            outputs, hidden = self.lstm(embed_inputs, hidden_states)

        outputs = self.output_proj(outputs)
        return outputs, hidden

In [None]:
decoder = DecoderRNNT(
    hidden_size=512,
    vocab_size=len(tokenizer.char_map),
    output_dim=512,
    n_layers=1,
    dropout=0.2
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward` so any registered hooks run.
logits, hidden_states = decoder(labels, label_lengths)

assert labels.shape == torch.Size([2, 158]), labels.shape
assert logits.shape == torch.Size([2, 158, 512]), logits.shape
assert len(hidden_states) == 2, len(hidden_states)
assert hidden_states[0].shape == torch.Size([1, 2, 512]), hidden_states[0].shape

In [None]:
class Joiner(nn.Module):
    def __init__(self, joiner_dim: int, num_outputs: int):
        """
        Combine encoder and decoder representations into output logits,
        computed as Linear(ReLU(f_t + g_u)).

        :param joiner_dim: the feature size shared by the encoder and decoder outputs
        :param num_outputs: the number of text tokens in the dictionary
        """
        super().__init__()
        self.linear = nn.Linear(in_features=joiner_dim, out_features=num_outputs)

    def forward(self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor) -> torch.Tensor:
        """
        :param encoder_outputs: the encoder outputs (f_t), shape: (B, T, joiner_dim) or (joiner_dim,)
        :param decoder_outputs: the decoder outputs (g_u), shape: (B, U, joiner_dim) or (joiner_dim,)
        :return: logits of shape (B, T, U, num_outputs) when both inputs are 3-D;
            otherwise the broadcast shape of the inputs with the last axis
            replaced by num_outputs
        """
        if encoder_outputs.dim() != 3 or decoder_outputs.dim() != 3:
            # Any other ranks (e.g. a single (joiner_dim,) step): plain broadcasting.
            return self.linear(torch.relu(encoder_outputs + decoder_outputs))

        # Both batched: (B, T, D) -> (B, T, 1, D) and (B, U, D) -> (B, 1, U, D),
        # so the sum covers every (t, u) pair with shape (B, T, U, D).
        lattice = encoder_outputs[:, :, None, :] + decoder_outputs[:, None, :, :]
        return self.linear(torch.relu(lattice))

In [None]:
class RNNTransducer(torch.nn.Module):
    def __init__(self,
        num_classes: int,
        input_dim: int,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 1,
        encoder_hidden_state_dim: int = 320,
        decoder_hidden_state_dim: int = 512,
        output_dim: int = 512,
        encoder_is_bidirectional: bool = True,
        encoder_dropout_p: float = 0.2,
        decoder_dropout_p: float = 0.2,
        bos_id: int = 0,
    ):
        """
        :param num_classes: the number of text tokens in the dictionary
        :param input_dim: the number of mel-spectrogram features
        :param num_encoder_layers: the number of LSTM layers in the encoder
        :param num_decoder_layers: the number of LSTM layers in the decoder
        :param encoder_hidden_state_dim: the number of features in the hidden states for the encoder
        :param decoder_hidden_state_dim: the number of features in the hidden states for the decoder
        :param output_dim: the output dimension (shared by encoder, decoder and joiner input)
        :param encoder_is_bidirectional: whether to use bidirectional LSTM's in the encoder
        :param encoder_dropout_p: the dropout probability for the encoder
        :param decoder_dropout_p: the dropout probability for the decoder
        :param bos_id: the token id prepended to every target sequence as the decoder's
            start symbol; defaults to 0, the value F.pad fills with by default.
            NOTE(review): confirm it matches the blank/start index used by the loss and decoding code.
        """
        super().__init__()
        self.bos_id = bos_id
        self.encoder = EncoderRNNT(
            input_dim=input_dim,
            hidden_size=encoder_hidden_state_dim,
            output_dim=output_dim,
            n_layers=num_encoder_layers,
            dropout=encoder_dropout_p,
            bidirectional=encoder_is_bidirectional,
        )
        self.decoder = DecoderRNNT(
            hidden_size=decoder_hidden_state_dim,
            vocab_size=num_classes,
            output_dim=output_dim,
            n_layers=num_decoder_layers,
            dropout=decoder_dropout_p,
        )
        self.joiner = Joiner(output_dim, num_classes)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor,
                targets: torch.Tensor, target_lengths: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: spectrograms, shape: (B, T, n_mels)
        :param input_lengths: the lengths of the spectrograms in the batch, shape: (B,)
        :param targets: labels, shape: (B, U)
        :param target_lengths: the lengths of the text labels in the batch, shape: (B,)
        :return: the output logits, shape: (B, T, U + 1, n_tokens) — one extra label
            position because the decoder sees <BOS> + the original sequence
        """
        encoder_outputs, _ = self.encoder(inputs, input_lengths)
        # The decoder takes <BOS> + the original sequence: shift the labels right
        # by one position, (B, U) -> (B, U + 1), and grow each length accordingly.
        decoder_inputs = F.pad(targets, (1, 0), value=self.bos_id)
        decoder_outputs, _ = self.decoder(decoder_inputs, target_lengths + 1)
        joiner_out = self.joiner(encoder_outputs, decoder_outputs)
        return joiner_out


In [None]:
transducer = RNNTransducer(
    num_classes=len(tokenizer.char_map),
    input_dim=80,
    num_encoder_layers=4,
    num_decoder_layers=1,
    encoder_hidden_state_dim=320,
    decoder_hidden_state_dim=512,
    output_dim=512,
    encoder_is_bidirectional=True,
    encoder_dropout_p=0.2,
    decoder_dropout_p=0.2
)

spectrograms, labels, input_lengths, label_lengths = get_pseudo_batch()
# Call the module itself rather than `.forward` so any registered hooks run.
result = transducer(spectrograms, input_lengths, labels, label_lengths)

assert spectrograms.shape == torch.Size([2, 835, 80]), spectrograms.shape
assert labels.shape == torch.Size([2, 158]), labels.shape
# U + 1 = 159 label positions because the decoder input is <BOS> + labels.
assert result.shape == torch.Size([2, 835, 159, 30]), result.shape

### Examples

- NVIDIA Parakeet RNN-T 1.1B (pretrained model): https://huggingface.co/nvidia/parakeet-rnnt-1.1b
- Streaming (online) ASR with an Emformer RNN-T in torchaudio: https://pytorch.org/audio/main/tutorials/online_asr_tutorial.html