## Seminar 4

In seminar 4 you will implement the forward and backward algorithms for calculating the RNN-T loss.

# Setup - Install package, download files, etc...

In [8]:
# TODO: change link to a link from repository
!wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=14vgOVBayQGYv9B1P3hYo3JM56rS6ap3U' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=14vgOVBayQGYv9B1P3hYo3JM56rS6ap3U" -O model_scripted_epoch_5.pt && rm -rf /tmp/cookies.txt


--2024-02-29 16:18:51--  https://drive.google.com/uc?export=download&confirm=&id=14vgOVBayQGYv9B1P3hYo3JM56rS6ap3U
Resolving drive.google.com (drive.google.com)... 74.125.205.194
Connecting to drive.google.com (drive.google.com)|74.125.205.194|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=14vgOVBayQGYv9B1P3hYo3JM56rS6ap3U&export=download [following]
--2024-02-29 16:18:51--  https://drive.usercontent.google.com/download?id=14vgOVBayQGYv9B1P3hYo3JM56rS6ap3U&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.131.132
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|74.125.131.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2436 (2.4K) [text/html]
Saving to: 'model_scripted_epoch_5.pt'


2024-02-29 16:18:51 (8.64 MB/s) - 'model_scripted_epoch_5.pt' saved [2436/2436]



In [9]:
import os
import string
from typing import Tuple, List, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import wandb
import ipywidgets as widgets
import itertools
from torch import optim
from torchaudio.transforms import RNNTLoss
from tqdm import tqdm_notebook, tqdm
from IPython.display import display, clear_output

In [10]:
import utils as utils 

# Lecture recap

Link to paper: https://arxiv.org/abs/1211.3711

## Alignment

Let $\mathbf{x} = (x_1, x_2, \ldots, x_T)$ be an input sequence of length $T$ belonging to the set $X^*$ of all sequences over some input space $X$. Let $\mathbf{y} = (y_1, \ldots, y_U)$ be an output sequence of length $U$ belonging to the set $Y^*$ of all sequences over some output space $Y$.

Define the *extended output space* $\overline Y$ as $Y \cup \emptyset$, where $\emptyset$ denotes the null output. The intuitive meaning of $\emptyset$ is 'output nothing'. The sequence $(y_1, \emptyset, \emptyset, y_2, \emptyset, y_3) \in \overline Y^*$ is therefore equivalent to $(y_1, y_2, y_3) \in Y^*$. We refer to the elements $\mathbf{a} \in \overline Y^*$ as *alignments*, since the location of the null symbols determines an alignment between the input and output sequences.

As we saw with CTC, the possible alignments can be laid out in a table called a trellis. An example of what an RNN-T trellis may look like:

<p style="text-align:center;"><img src="images/rnnt_trellis_1.png">
    
    
Possible alignments in that trellis:
    
<p style="text-align:center;"><img src="images/rnnt_trellis_2.png">
    
The final label sequence is obtained by simply removing the blank characters:
    
$$
    C \emptyset \emptyset A \emptyset T \emptyset \to CAT
$$
$$
    \emptyset \emptyset \emptyset C A T \emptyset \to CAT
$$
    
Given $\mathbf{x}$, the RNN transducer defines a conditional distribution $P(\mathbf{a} \in \overline Y^* | \mathbf{x})$. This distribution is then collapsed onto the following distribution over $Y^*$:
    
$$
    P(\mathbf y \in Y^* | \mathbf x) = \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x),
$$
    
where $\mathcal B: \overline Y^* \mapsto Y^*$ is a function that removes the null symbols from the alignments in $\overline Y^*$.
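
The collapsing function $\mathcal B$ can be sketched in a few lines of Python. This is only an illustration: the name `collapse` and the use of the string `"∅"` as the null symbol are assumptions made for this example, not part of the seminar code. Note that, unlike in CTC, repeated labels are not merged; only the null symbols are dropped.

```python
from typing import List, Sequence

BLANK = "∅"  # assumed placeholder for the null symbol


def collapse(alignment: Sequence[str], blank: str = BLANK) -> List[str]:
    """Apply B: drop every null symbol and keep the remaining labels in order."""
    return [symbol for symbol in alignment if symbol != blank]


# The two alignments from the example above
a1 = ["C", BLANK, BLANK, "A", BLANK, "T", BLANK]
a2 = [BLANK, BLANK, BLANK, "C", "A", "T", BLANK]

print("".join(collapse(a1)))  # CAT
print("".join(collapse(a2)))  # CAT
```

Both alignments belong to $\mathcal{B}^{-1}(CAT)$, so both contribute to $P(CAT | \mathbf x)$.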


## Architecture

<p style="text-align:center;"><img src="images/architecture.png">

The RNN-T model consists of three neural networks: Encoder, Predictor and Joiner. The Encoder converts the acoustic feature $x_t$ into a high-level representation $f_t$, where $t$ is time index:

$$
    f_t = \mathrm{Encoder}(x_t)
$$

The Predictor works like an RNN language model, which produces a high-level representation $g_u$ by conditioning on the previous non-blank target $y_{u - 1}$ predicted by the RNN-T model, where $u$ is output label index:

$$
    g_u = \mathrm{Predictor}(y_{u - 1})
$$

Note that the input sequence for the predictor **is prepended with the special symbol** $\langle s \rangle$ that defines the start of a sentence.

The Joiner is a feed forward network that combines the Encoder output $f_t$ and the Predictor output $g_u$ as

$$
    h_{t, u} = \mathrm{Joiner}(f_t, g_u) = \mathrm{FeedForward}(\mathrm{ReLU}(f_t + g_u))
$$

The final posterior for each output token $y$ is obtained after applying the softmax operation:

$$
    P(y | t, u) = \mathrm{softmax}(h_{t, u})
$$
    
where $P(y | t, u)$ is a distribution of probabilities to emit $y \in \overline Y$ at time step $t$ after $u$ previously generated characters, $t \in [1, T], u \in [0, U]$.
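
To make the shapes concrete, here is a minimal NumPy sketch of the Joiner step. Everything in it is an assumption for illustration: the sizes `T`, `U`, `H`, `D`, the random stand-ins for the Encoder and Predictor outputs, and a single linear layer `W` playing the role of $\mathrm{FeedForward}$. It only shows how combining every $f_t$ with every $g_u$ yields one distribution per lattice node.

```python
import numpy as np

rng = np.random.default_rng(0)
T, U, H, D = 4, 3, 8, 5  # hypothetical sizes: frames, labels, hidden size, classes incl. blank

f = rng.normal(size=(T, H))      # stand-in for encoder outputs f_1..f_T
g = rng.normal(size=(U + 1, H))  # stand-in for predictor outputs g_0..g_U
W = rng.normal(size=(H, D))      # hypothetical linear layer in place of FeedForward
b = np.zeros(D)

# Broadcast-add every f_t with every g_u: (T, 1, H) + (1, U + 1, H) -> (T, U + 1, H)
hidden = np.maximum(f[:, None, :] + g[None, :, :], 0.0)  # ReLU
h = hidden @ W + b                                       # (T, U + 1, D)

# Numerically stable log-softmax over the class axis
shifted = h - h.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

print(log_probs.shape)  # (4, 4, 5)
```

The second axis has $U + 1$ entries because $u$ runs from $0$ to $U$.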

<p style="text-align:center;"><img src="images/distribution.png">

We will further need to work with probabilities of individual tokens $y$ for different $t$ and $u$. Instead of writing each time something like $P(y = C | t = 1, u = 0)$, we will, for the sake of simplicity, write it as $P(C | 1, 0)$.

## Training: forward-backward algorithm

The loss function of RNN-T is the negative log posterior of output label sequence $\mathbf y$ given acoustic feature $\mathbf x$:

$$
    \mathcal L = -\ln P(\mathbf y \in Y^* | \mathbf x) = -\ln \sum_{\mathbf a \in \mathcal{B}^{-1}(\mathbf y)} P(\mathbf a | \mathbf x)
$$

To determine $P(\mathbf a | \mathbf x)$ for an arbitrary alignment $\mathbf a$, we need to multiply the probabilities $P(y | t, u)$ of each symbol across the path:

<p style="text-align:center;"><img src="images/rnnt_trellis_3.png">

$$
    \mathbf a = C \emptyset \emptyset A \emptyset T \emptyset
$$
    
$$
    P(\mathbf a | \mathbf x) = P(C | 1, 0) \cdot P(\emptyset | 1, 1) \cdot P(\emptyset | 2, 1) \cdot P(A | 3, 1) \cdot P(\emptyset | 3, 2) \cdot P(T | 4, 2) \cdot P(\emptyset | 4, 3)
$$
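
The walk along a path can be sketched as follows, assuming the usual lattice rule that a blank advances $t$ and a label advances $u$. The function name `alignment_log_prob`, the toy vocabulary and the random log-probabilities are all made up for this illustration; the array is zero-based in $t$, and the reported nodes are converted back to the one-based $t$ used in the formulas.

```python
import numpy as np
from typing import List, Sequence, Tuple


def alignment_log_prob(log_probs: np.ndarray, alignment: Sequence[int],
                       blank: int) -> Tuple[float, List[Tuple[int, int]]]:
    """Sum ln P(symbol | t, u) along one alignment; return the total and the visited nodes."""
    t, u = 0, 0                   # zero-based start node, i.e. (t, u) = (1, 0) in the text
    total, nodes = 0.0, []
    for symbol in alignment:
        nodes.append((t + 1, u))  # report t one-based to match the formulas
        total += float(log_probs[t, u, symbol])
        if symbol == blank:
            t += 1                # a blank moves to the next frame
        else:
            u += 1                # a label moves to the next output position
    return total, nodes


# Hypothetical toy setup: classes C=0, A=1, T=2, blank=3; 4 frames, 3 labels
C, A, T_, BLANK = 0, 1, 2, 3
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, 4))
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

alignment = [C, BLANK, BLANK, A, BLANK, T_, BLANK]
log_p, nodes = alignment_log_prob(log_probs, alignment, BLANK)
print(nodes)  # [(1, 0), (1, 1), (2, 1), (3, 1), (3, 2), (4, 2), (4, 3)]
```

Working with sums of log-probabilities instead of products of probabilities avoids underflow on long paths.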

There are usually too many possible alignments to compute the loss function by just adding them all up directly. We will use dynamic programming to make this computation feasible.
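
To get a feeling for how quickly the number of alignments grows, the sketch below counts and enumerates them. It assumes that every alignment consists of $T - 1$ blanks and $U$ labels in any interleaving, followed by the forced terminal blank; the helper names `count_alignments` and `enumerate_alignments` are hypothetical.

```python
import itertools
import math


def count_alignments(T: int, U: int) -> int:
    """Number of lattice paths with T - 1 blanks and U labels (the final blank is forced)."""
    return math.comb(T - 1 + U, U)


def enumerate_alignments(labels: str, T: int, blank: str = "∅"):
    """Yield every alignment of `labels` over T frames by choosing the label positions."""
    U = len(labels)
    steps = T - 1 + U
    for label_slots in itertools.combinations(range(steps), U):
        next_label = iter(labels)
        path = [next(next_label) if i in label_slots else blank for i in range(steps)]
        yield "".join(path) + blank  # every alignment ends with the terminal blank


paths = list(enumerate_alignments("CAT", T=4))
print(len(paths), count_alignments(4, 3))  # 20 20
print(count_alignments(T=100, U=20))       # already far too many to enumerate
```

Even the toy example has 20 alignments, and the count grows combinatorially with $T$ and $U$, which is why the sum is computed with dynamic programming instead.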

Define the *forward variable* $\alpha(t, u)$ as the probability of outputting $\mathbf y_{[1:u]}$ during $\mathbf f_{[1:t]}$. The forward variables for all $1 \le t \le T$ and $0 \le u \le U$ can be calculated recursively using

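$$
    \alpha(t, u) = \alpha(t - 1, u) P(\emptyset | t - 1, u) + \alpha(t, u - 1) P(y_u | t, u - 1)
$$

with initial condition $\alpha(1, 0) = 1$; terms that refer to nodes outside the lattice ($t = 0$ or $u = -1$) are taken to be zero. Here $y_u$ is the $u$-th symbol of the ground truth label $\mathbf y = (y_1, \ldots, y_U)$, that is, the symbol emitted when moving from node $(t, u - 1)$ to node $(t, u)$. In a zero-based target array it is stored at index $u - 1$.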

The total output sequence probability is equal to the forward variable at the terminal node multiplied by the probability of the final blank:

$$
    P(\mathbf y | \mathbf x) = \alpha(T, U) P(\emptyset | T, U)
$$

Define the *backward variable* $\beta(t, u)$ as the probability of outputting $\mathbf y_{[u + 1: U]}$ during $\mathbf f_{[t:T]}$. Then

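$$
    \beta(t, u) = \beta(t + 1, u) P(\emptyset | t, u) + \beta(t, u + 1) P(y_{u + 1} | t, u)
$$

with initial condition $\beta(T, U) = P(\emptyset | T, U)$; terms that refer to nodes outside the lattice ($t = T + 1$ or $u = U + 1$) are taken to be zero. Here $y_{u + 1}$ is the symbol emitted when moving from node $(t, u)$ to node $(t, u + 1)$; in a zero-based target array it is stored at index $u$. The final value is $\beta(1, 0) = P(\mathbf y | \mathbf x)$.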

From the definition of the forward and backward variables it follows that their product $\alpha(t, u) \beta(t, u)$ at any point $(t, u)$ in the output lattice is equal to the probability of emitting the complete output sequence *if $y_u$ is emitted during transcription step $t$*.
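
One consequence that is handy for debugging: every alignment crosses each anti-diagonal $t + u = \mathrm{const}$ of the lattice exactly once, so summing $\alpha(t, u) \beta(t, u)$ along any such diagonal should give $P(\mathbf y | \mathbf x)$. The sketch below checks this numerically on random data. The helpers `log_alpha` and `log_beta` are compact illustrative versions written for this check (zero-based indices, log domain); the sizes and targets are made up.

```python
import numpy as np


def log_alpha(lp: np.ndarray, y: np.ndarray, blank: int) -> np.ndarray:
    """ln alpha on a zero-based (T, U + 1) lattice."""
    T, U1, _ = lp.shape
    a = np.full((T, U1), -np.inf)
    a[0, 0] = 0.0
    for t in range(T):
        for u in range(U1):
            if t > 0:  # arrive by a blank from (t - 1, u)
                a[t, u] = np.logaddexp(a[t, u], a[t - 1, u] + lp[t - 1, u, blank])
            if u > 0:  # arrive by a label from (t, u - 1)
                a[t, u] = np.logaddexp(a[t, u], a[t, u - 1] + lp[t, u - 1, y[u - 1]])
    return a


def log_beta(lp: np.ndarray, y: np.ndarray, blank: int) -> np.ndarray:
    """ln beta on a zero-based (T, U + 1) lattice."""
    T, U1, _ = lp.shape
    b = np.full((T, U1), -np.inf)
    b[T - 1, U1 - 1] = lp[T - 1, U1 - 1, blank]  # terminal blank
    for t in reversed(range(T)):
        for u in reversed(range(U1)):
            if t < T - 1:   # leave by a blank to (t + 1, u)
                b[t, u] = np.logaddexp(b[t, u], b[t + 1, u] + lp[t, u, blank])
            if u < U1 - 1:  # leave by a label to (t, u + 1)
                b[t, u] = np.logaddexp(b[t, u], b[t, u + 1] + lp[t, u, y[u]])
    return b


rng = np.random.default_rng(0)
T, U, D, blank = 4, 3, 5, 0  # hypothetical sizes
logits = rng.normal(size=(T, U + 1, D))
lp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
y = np.array([1, 2, 3])

a, b = log_alpha(lp, y, blank), log_beta(lp, y, blank)
total = b[0, 0]  # ln P(y | x)

# Log-sum of alpha * beta over every anti-diagonal t + u = n
diag_sums = []
for n in range(T + U):
    terms = [a[t, n - t] + b[t, n - t] for t in range(T) if 0 <= n - t <= U]
    diag_sums.append(np.logaddexp.reduce(terms))

print(np.allclose(diag_sums, total))
```

If the check fails for your own implementation, the mismatch usually points to an off-by-one error in either the $t$ or the $u$ index.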

# Seminar 4: RNN-T Forward-Backward Algorithm

* Implement a Forward Pass
* Implement a Backward Pass

Implement forward and backward passes.


### Implementation tips

- Note that all array indices in your code start at zero. So the initial condition for the forward algorithm becomes $\alpha(0, 0) = 1$ (that is, $\log \alpha(0, 0) = 0$), and the output value of the backward algorithm becomes $\beta(0, 0)$. The recurrent formulas stay the same. Also, don't be confused by the terminal node: you don't have to add it to the $\alpha$- and $\beta$-arrays. The dynamic programming starts in the upper left corner for the forward variables and in the lower right corner for the backward variables.
- You will need to do everything in the log domain for the calculations to be numerically stable. The function [np.logaddexp](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html) might help you with it.
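
A small illustration of why the log domain matters, with made-up values: two log-probabilities that are far below what `float64` can represent after exponentiation. The naive route underflows to zero, while `np.logaddexp` computes the same quantity without leaving the log domain.

```python
import numpy as np

log_p, log_q = -1000.0, -1001.0  # hypothetical log-probabilities of two partial paths

# Naive: exp() underflows to 0.0, so the logarithm of the sum is -inf
with np.errstate(divide="ignore"):
    naive = np.log(np.exp(log_p) + np.exp(log_q))

# Stable: ln(exp(log_p) + exp(log_q)) evaluated in the log domain
stable = np.logaddexp(log_p, log_q)

print(naive)   # -inf
print(stable)  # about -999.6867
```

The stable result equals $-1000 + \ln(1 + e^{-1})$, which is what the recursion needs when it adds the two incoming path probabilities at a node.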

In [10]:
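def forward(log_probs: torch.Tensor, targets: torch.Tensor,
            blank: int = -1) -> Tuple[np.ndarray, float]:
    """
    :param log_probs: model outputs after applying log_softmax, shape (T, U + 1, D)
    :param targets: the target sequence of tokens, represented as integer indexes, shape (U,)
    :param blank: the index of blank symbol
    :return: Tuple[ln alpha, -(ln alpha(T, U) + ln P(blank | T, U))].
        The latter term is loss value, which is -ln P(y | x)
    """
    # convert to NumPy so that np.logaddexp operates on plain arrays
    log_probs = log_probs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    max_T, max_U, D = log_probs.shape

    # here the alpha variable contains logarithm of the alpha variable from the formulas above;
    # alpha[0, 0] = 0 encodes the initial condition ln alpha(0, 0) = ln 1
    alpha = np.zeros((max_T, max_U), dtype=np.float32)

    # first column (u = 0): reachable only by emitting blanks
    for t in range(1, max_T):
        alpha[t, 0] = alpha[t-1, 0] + log_probs[t-1, 0, blank]

    # first row (t = 0): reachable only by emitting target labels;
    # the min(...) guard keeps the label index inside targets
    for u in range(1, max_U):
        alpha[0, u] = alpha[0, u-1] + log_probs[0, u-1, targets[min(u-1, len(targets)-1)]]

    # inner nodes: arrive by a blank from (t-1, u) or by a label from (t, u-1)
    for t in range(1, max_T):
        for u in range(1, max_U):
            alpha[t, u] = np.logaddexp(
                alpha[t-1, u] + log_probs[t-1, u, blank],
                alpha[t, u-1] + log_probs[t, u-1, targets[min(u-1, len(targets)-1)]]
            )

    # the terminal blank at the last node completes every alignment
    cost = -(log_probs[-1, -1, blank] + alpha[-1, -1])
    return alpha, float(cost)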


def backward(log_probs: torch.Tensor, targets: torch.Tensor,
             blank: int = -1) -> Tuple[np.ndarray, float]:
    """
    :param log_probs: model outputs after applying log_softmax, shape (T, U + 1, D)
    :param targets: the target sequence of tokens, represented as integer indexes, shape (U,)
    :param blank: the index of blank symbol
    :return: Tuple[ln beta, -ln beta(0, 0)]. The latter term is loss value, which is -ln P(y | x)
    """
    # convert to NumPy so that np.logaddexp operates on plain arrays
    log_probs = log_probs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    max_T, max_U, D = log_probs.shape

    # here the beta variable contains logarithm of the beta variable from the formulas above
    beta = np.zeros((max_T, max_U), dtype=np.float32)
    # initial condition: the terminal blank emitted at the last node
    beta[-1, -1] = log_probs[-1, -1, blank]

    # last column (u = U): only blanks remain to be emitted
    for t in reversed(range(max_T - 1)):
        beta[t, max_U-1] = beta[t+1, max_U-1] + log_probs[t, max_U-1, blank]

    # last row (t = T): only target labels remain to be emitted;
    # the min(...) guard keeps the label index inside targets
    for u in reversed(range(max_U - 1)):
        beta[max_T-1, u] = beta[max_T-1, u+1] + log_probs[max_T-1, u, targets[min(u, len(targets)-1)]]

    # inner nodes: leave by a blank to (t+1, u) or by a label to (t, u+1)
    for t in reversed(range(max_T - 1)):
        for u in reversed(range(max_U - 1)):
            beta[t, u] = np.logaddexp(
                beta[t+1, u] + log_probs[t, u, blank],
                beta[t, u+1] + log_probs[t, u, targets[min(u, len(targets)-1)]]
            )

    cost = -beta[0, 0]
    return beta, float(cost)

In [11]:
def run_test(logits: torch.FloatTensor, targets: torch.LongTensor, 
             ref_costs: torch.FloatTensor, blank: int = -1) -> None:
    """
    :param logits: model outputs
    :param targets: the target sequence of tokens, represented as integer indexes
    :param ref_costs: the true values of RNN-T costs for test inputs
    :param blank: the index of blank symbol
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    cost = np.zeros(log_probs.shape[0])
    
    for batch_id in range(log_probs.shape[0]):        
        alphas, cost_alpha = forward(log_probs[batch_id], targets[batch_id], blank=blank)
        betas, cost_beta = backward(log_probs[batch_id], targets[batch_id], blank=blank)

        np.testing.assert_almost_equal(cost_alpha, cost_beta, decimal=2)
        cost[batch_id] = cost_beta
    
    np.testing.assert_almost_equal(cost, ref_costs, decimal=2)

In [12]:
# Tests

'''
All logits in tests have shapes in the form (B, T, U + 1, D) where

B: batch size
T: maximum source sequence length in batch
U: maximum target sequence length in batch
D: number of output classes (vocabulary size including the blank symbol)
'''

# test 1
logits = torch.FloatTensor([
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.6, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.8, 0.1,
    0.1, 0.6, 0.1, 0.1, 0.1,
    0.1, 0.1, 0.2, 0.1, 0.1,
    0.7, 0.1, 0.2, 0.1, 0.1,
]).reshape(1, 2, 3, 5)

targets = torch.LongTensor([[1, 2]])
ref_costs = torch.FloatTensor([5.09566688538])

run_test(
    logits=logits, 
    targets=targets, 
    ref_costs=ref_costs, 
    blank=-1
)

# test 2
logits = torch.FloatTensor([
    0.065357, 0.787530, 0.081592, 0.529716, 0.750675, 0.754135, 0.609764, 0.868140,
    0.622532, 0.668522, 0.858039, 0.164539, 0.989780, 0.944298, 0.603168, 0.946783,
    0.666203, 0.286882, 0.094184, 0.366674, 0.736168, 0.166680, 0.714154, 0.399400,
    0.535982, 0.291821, 0.612642, 0.324241, 0.800764, 0.524106, 0.779195, 0.183314,
    0.113745, 0.240222, 0.339470, 0.134160, 0.505562, 0.051597, 0.640290, 0.430733,
    0.829473, 0.177467, 0.320700, 0.042883, 0.302803, 0.675178, 0.569537, 0.558474,
    0.083132, 0.060165, 0.107958, 0.748615, 0.943918, 0.486356, 0.418199, 0.652408,
    0.024243, 0.134582, 0.366342, 0.295830, 0.923670, 0.689929, 0.741898, 0.250005,
    0.603430, 0.987289, 0.592606, 0.884672, 0.543450, 0.660770, 0.377128, 0.358021,
]).reshape(2, 4, 3, 3)

targets = torch.LongTensor([[1, 2], [1, 1]])
ref_costs = torch.FloatTensor([4.2806528590890736, 3.9384369822503591])

run_test(
    logits=logits, 
    targets=targets, 
    ref_costs=ref_costs, 
    blank=0
)