Today we will implement chunked streaming for a CTC Conformer model.

Recap:

* [Conformer](https://arxiv.org/pdf/2005.08100) 
* [Chunked streaming](https://arxiv.org/pdf/2312.17279)
* [NeMo repository](https://github.com/NVIDIA/NeMo) -- the source of the weights and of the basic idea for our seminar

In [None]:
# !pip install librosa
# !pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cpu

In [None]:
import json
import librosa
import math
import numpy as np
import os
import pickle
import queue
import requests
import torch
import torch.nn
import torch.nn.functional as F
import wave

from IPython.display import Audio
from urllib.parse import urlencode

In [None]:
def download_file(public_link, filename='archieve.tgz'):
    """Download a publicly shared Yandex.Disk resource into the current directory.

    The public link is first resolved to a direct download URL through the
    Yandex.Disk REST API, then the payload is streamed to disk.

    Args:
        public_link (str): public share link of the resource.
        filename (str): name of the file to create in the current working directory.

    Raises:
        requests.HTTPError: if either the API call or the download itself
            answers with an error status.
    """
    base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
    final_url = base_url + urlencode(dict(public_key=public_link))
    response = requests.get(final_url, timeout=60)
    # Fail loudly here instead of with a confusing KeyError on 'href' below.
    response.raise_for_status()
    download_href = response.json()['href']

    final_link = os.path.join(os.getcwd(), filename)
    # Stream the body in chunks so the whole archive never sits in memory.
    with requests.get(download_href, stream=True, timeout=60) as download:
        # Do not save an HTML/JSON error page under the archive name.
        download.raise_for_status()
        print(final_link)
        with open(final_link, 'wb') as ff:
            for chunk in download.iter_content(chunk_size=1 << 20):
                ff.write(chunk)

In [None]:
# link_to_archive = "https://disk.yandex.ru/d/Omgg4HryF5AWLQ"
# download_file(link_to_archive, filename='archieve.tgz')
# !mkdir -p ../data
# !mv archieve.tgz ../data/
# !tar xzvf ../data/archieve.tgz -C ../data


In [None]:
class FilterbankFeatures(torch.nn.Module):
    """Log-mel spectrogram front end for a single raw waveform.

    The signal is pre-emphasised, passed through a centred STFT with a Hann
    window, turned into a power spectrum, projected onto a mel filterbank and
    finally log-compressed.

    Args:
        sample_rate: sampling rate in Hz used to build the mel filterbank.
        n_window_size: analysis window length in samples; also used as the FFT size.
        n_window_stride: hop between consecutive frames in samples.
        preemph: pre-emphasis coefficient.
        nfilt: number of mel filters.
        lowfreq: lower edge of the filterbank in Hz.
        highfreq: upper edge of the filterbank in Hz; a falsy value means
            ``sample_rate / 2``.
        log: accepted but not used; the logarithm is always applied.
        log_zero_guard_value: constant added to the mel energies before the log.
        pad_value: stored on the instance; not used by ``forward``.
        nb_max_freq: accepted but not used.
        mel_norm: forwarded to ``librosa.filters.mel`` as ``norm``.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=400,
        n_window_stride=160,
        preemph=0.97,
        nfilt=80,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_value=2 ** -24,
        pad_value=0,
        nb_max_freq=4000,
        mel_norm="slaney",
    ):
        super().__init__()
        self.log_zero_guard_value = log_zero_guard_value

        # The FFT size is tied to the window length.
        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_window_size

        self.register_buffer("window", torch.hann_window(self.win_length, periodic=False))

        def _stft(signal):
            # Reads the window buffer at call time so device moves are respected.
            return torch.stft(
                signal,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=True,
                window=self.window.to(dtype=torch.float),
                return_complex=True,
            )

        self.stft = _stft

        self.nfilt = nfilt
        self.preemph = preemph
        if not highfreq:
            highfreq = sample_rate / 2

        mel_matrix = librosa.filters.mel(
            sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
        )
        # Leading singleton axis: the buffer has shape (1, nfilt, n_fft // 2 + 1).
        self.register_buffer("fb", torch.tensor(mel_matrix, dtype=torch.float).unsqueeze(0))

        self.pad_value = pad_value

        # Feature extraction never needs gradients.
        self.forward = torch.no_grad()(self.forward)

    @property
    def filter_banks(self):
        """The mel filterbank buffer."""
        return self.fb

    def forward(self, x):
        """Compute log-mel features.

        Args:
            x (torch.Tensor): waveform; indexed along its first axis, so a 1-D
                tensor of samples is expected.
        Returns:
            torch.Tensor: log-mel energies, (nfilt, frames) for a 1-D input.
        """
        # Pre-emphasis: the first sample is kept, then y[t] = x[t] - preemph * x[t - 1].
        emphasised = torch.cat((x[0].unsqueeze(0), x[1:] - self.preemph * x[:-1]), dim=0)
        # Keep the STFT out of autocast so it runs at full precision.
        with torch.cuda.amp.autocast(enabled=False):
            spectrum = self.stft(emphasised)
        # |X|^2 from the (real, imag) pairs.
        power = torch.view_as_real(spectrum).pow(2).sum(-1)
        mel_energy = torch.matmul(self.fb.to(power.dtype), power).squeeze(0)
        return torch.log(mel_energy + self.log_zero_guard_value)

In [None]:
class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding for TransformerXL's layers
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): embedding dim
        dropout_rate (float): dropout rate
        max_len (int): maximum input length
        xscale (bool): whether to scale the input by sqrt(d_model)
        dropout_rate_emb (float): dropout rate for the positional embeddings
    """
    def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rate_emb=0.0):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = xscale
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len
        if dropout_rate_emb > 0:
            self.dropout_emb = torch.nn.Dropout(dropout_rate_emb)
        else:
            self.dropout_emb = None

    def create_pe(self, positions):
        """Build the sinusoidal table for `positions` and store it as `self.pe`.

        Args:
            positions (torch.Tensor): (P, 1) float tensor of relative positions.
        The resulting buffer has shape (1, P, d_model); even feature indices
        hold sines, odd ones cosines.
        """
        pos_length = positions.size(0)
        pe = torch.zeros(pos_length, self.d_model, device=positions.device)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32, device=positions.device)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)
        pe = pe.unsqueeze(0)
        if hasattr(self, 'pe'):
            self.pe = pe
        else:
            # Non-persistent: the table is recomputed, never stored in checkpoints.
            self.register_buffer('pe', pe, persistent=False)

    def extend_pe(self, length, device=None):
        """Reset and extend the positional encodings if needed."""
        needed_size = 2 * length - 1
        if hasattr(self, 'pe') and self.pe.size(1) >= needed_size:
            return
        # positions would be from negative numbers to positive
        # positive positions would be used for left positions and negative for right positions
        positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
        self.create_pe(positions=positions)

    def forward(self, x):
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
        Returns:
            x (torch.Tensor): Its shape is (batch, time, feature_size)
            pos_emb (torch.Tensor): Its shape is (1, 2 * time - 1, feature_size)
        """

        if self.xscale:
            x = x * self.xscale

        input_len = x.size(1)
        # Make sure the table exists and covers this input. Without this, a
        # missing `extend_pe` call raises AttributeError and an input longer than
        # the table yields a negative slice start, i.e. a silently wrong pos_emb.
        self.extend_pe(input_len, device=x.device)

        # center_pos would be the index of position 0
        # negative positions would be used for right and positive for left tokens
        # for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1)
        center_pos = self.pe.size(1) // 2 + 1
        start_pos = center_pos - input_len
        end_pos = center_pos + input_len - 1
        pos_emb = self.pe[:, start_pos:end_pos]
        if self.dropout_emb is not None:
            pos_emb = self.dropout_emb(pos_emb)
        return self.dropout(x), pos_emb

    def get_initial_state(self, chunk_size, left_context_chunks):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()
    
    def streaming_forward(self, x, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

In [None]:
def create_attn_mask(chunk_size: int, left_chunks_num: int, max_length: int, device):
    """Build the chunk-wise attention visibility mask.

    Frames are grouped into consecutive chunks of `chunk_size`. A query frame
    may see a key frame when the key lies in the query's own chunk or in one of
    the `left_chunks_num` chunks immediately before it.

    Returns:
        torch.Tensor (1, max_length, max_length) [bool]: True means value should be used.
    """
    frame_ids = torch.arange(0, max_length, dtype=torch.int, device=device)
    # [t]: index of the chunk each frame belongs to
    chunk_of_frame = torch.div(frame_ids, chunk_size, rounding_mode="trunc")

    # [t, t]: how many chunks the key (column) lies behind the query (row)
    chunks_behind = chunk_of_frame[:, None] - chunk_of_frame[None, :]
    visible = (chunks_behind >= 0) & (chunks_behind <= left_chunks_num)
    return visible[None]


class RelPositionMultiHeadAttention(torch.nn.Module):
    """Transformer-XL style multi-head attention with relative positional encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        dropout_rate (float): dropout rate applied to the attention weights
    """

    def __init__(
        self,
        n_head: int,
        n_feat: int,
        dropout_rate: float,
    ):
        super().__init__()
        assert n_feat % n_head == 0
        self.h = n_head
        # Keys, queries and values all share the same per-head width.
        self.d_k = n_feat // n_head
        self.s_d_k = math.sqrt(self.d_k)

        self.linear_q = torch.nn.Linear(n_feat, n_feat)
        self.linear_k = torch.nn.Linear(n_feat, n_feat)
        self.linear_v = torch.nn.Linear(n_feat, n_feat)
        self.linear_out = torch.nn.Linear(n_feat, n_feat)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # projection applied to the positional embeddings
        self.linear_pos = torch.nn.Linear(n_feat, n_feat, bias=False)

        # Per-head biases added to the query for the content (u) and position (v)
        # terms; both start at zero.
        self.pos_bias_u = torch.nn.Parameter(torch.zeros(self.h, self.d_k))
        self.pos_bias_v = torch.nn.Parameter(torch.zeros(self.h, self.d_k))

    def forward_qkv(self, query, key, value):
        """Project the inputs and split them into heads.
        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
        returns:
            q (torch.Tensor): (batch, head, time1, d_k)
            k (torch.Tensor): (batch, head, time2, d_k)
            v (torch.Tensor): (batch, head, time2, d_k)
        """
        batch = query.size(0)

        def split_heads(projected):
            # (batch, time, size) -> (batch, head, time, d_k)
            return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        return (
            split_heads(self.linear_q(query)),
            split_heads(self.linear_k(key)),
            split_heads(self.linear_v(value)),
        )

    def forward_attention(self, value, scores, mask):
        """Turn scores into attention weights and apply them to the values.
        Args:
            value (torch.Tensor): (batch, head, time2, d_k)
            scores (torch.Tensor): (batch, head, time1, time2)
            mask (torch.Tensor): (batch, time1, time2), True marks positions to
                exclude; may be None
        returns:
            torch.Tensor: (batch, time1, d_model) context after the output projection
        """
        batch = value.size(0)
        if mask is None:
            weights = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        else:
            blocked = mask.unsqueeze(1)  # (batch, 1, time1, time2)
            weights = torch.softmax(scores.masked_fill(blocked, -10000.0), dim=-1)
            # Zero the excluded entries explicitly, so fully masked rows give zeros.
            weights = weights.masked_fill(blocked, 0.0)

        context = torch.matmul(self.dropout(weights), value)  # (batch, head, time1, d_k)
        # merge the heads back: (batch, time1, d_model)
        context = context.transpose(1, 2).reshape(batch, -1, self.h * self.d_k)
        return self.linear_out(context)
    
    def rel_shift(self, x):
        """Shift the positional scores so each query row is aligned to its own position.
        Args:
            x (torch.Tensor): (batch, nheads, time1, 2*time2-1)
        Returns:
            torch.Tensor: same shape as `x`
        """
        batch, heads, q_len, n_pos = x.size()
        # prepend one zero column: (b, h, t1, n_pos + 1)
        padded = torch.nn.functional.pad(x, pad=(1, 0))
        # reinterpret the same memory with the two last axes swapped in size
        folded = padded.view(batch, heads, -1, q_len)
        # dropping the first row realises the per-row shift
        return folded[:, :, 1:].view(batch, heads, q_len, n_pos)

    def forward(self, query, key, value, mask, pos_emb):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value(torch.Tensor): (batch, time2, size)
            mask (torch.Tensor): (batch, time1, time2), True marks excluded positions
            pos_emb (torch.Tensor) : (batch, 2 * time1 - 1, size)

        Returns:
            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k) so the (head, d_k) biases broadcast
        q_time_major = q.transpose(1, 2)

        pos_batch = pos_emb.size(0)
        # (batch, head, 2 * time1 - 1, d_k)
        pos_proj = self.linear_pos(pos_emb).view(pos_batch, -1, self.h, self.d_k).transpose(1, 2)

        # both back to (batch, head, time1, d_k)
        q_content = (q_time_major + self.pos_bias_u).transpose(1, 2)
        q_position = (q_time_major + self.pos_bias_v).transpose(1, 2)

        # terms (a) + (c) of https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        content_scores = torch.matmul(q_content, k.transpose(-2, -1))

        # terms (b) + (d): scores against relative positions, then aligned
        position_scores = self.rel_shift(torch.matmul(q_position, pos_proj.transpose(-2, -1)))
        # keep only as many columns as there are keys
        position_scores = position_scores[:, :, :, : content_scores.size(-1)]

        scores = (content_scores + position_scores) / self.s_d_k  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)

    def get_initial_state(self, chunk_size, left_context_chunks):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, pos_emb, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise).
        Args:
            x: torch.Tensor (1, T, d)
            pos_emb: ??
            state: ???
        Returns:
            x, state
        """
        raise NotImplementedError()

In [None]:
def calc_length(lengths: torch.Tensor, all_paddings: int, kernel_size: int, stride: int, repeat_num=1):
    """Output length after `repeat_num` identical convolution / pooling layers.

    Each layer maps a length L to floor((L + all_paddings - kernel_size) / stride + 1).

    Args:
        lengths: tensor of input lengths.
        all_paddings: total padding (both sides together) along the axis.
        kernel_size: kernel size along the axis.
        stride: stride along the axis.
        repeat_num: number of stacked layers.
    Returns:
        torch.Tensor: lengths as 32-bit integers.
    """
    pad_minus_kernel: float = all_paddings - kernel_size
    for _ in range(repeat_num):
        scaled = torch.div(lengths.to(dtype=torch.float) + pad_minus_kernel, stride)
        lengths = torch.floor(scaled + 1.0)
    return lengths.to(dtype=torch.int)


class CausalConv2D(torch.nn.Conv2d):
    """
    Conv2d that pads its input explicitly instead of through the `padding` argument.

    For an input laid out as (batch, channels, time, features) the time axis gets
    `kernel_size - 1` zeros in front and none behind, so an output frame never
    depends on later frames. The feature axis gets `kernel_size - 1` zeros in
    front and `stride - 1` behind.

    Args:
        in_feats: size of the input feature axis (only stored).
        in_channels, out_channels, kernel_size, stride, groups: as in nn.Conv2d.
    """

    def __init__(
        self,
        in_feats: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            groups=groups,
        )
        self._in_feats = in_feats
        self._in_channels = in_channels

        # NOTE: NeMo originally uses right_padding = bottom_padding = stride - 1.
        # Here the right (future, time-axis) padding is 0 for better streaming
        # consistency, while the bottom (feature-axis) padding stays at stride - 1
        # so that the weight shapes keep matching.
        self._left_padding = kernel_size - 1
        self._right_padding = 0
        self._top_padding = kernel_size - 1
        self._bottom_padding = stride - 1

    def forward(self, x):
        """Pad and convolve `x`."""
        # F.pad lists the last axis first: (features front/back, time front/back).
        padding = (self._top_padding, self._bottom_padding, self._left_padding, self._right_padding)
        return super().forward(F.pad(x, pad=padding))

    def get_initial_state(self):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()
        

class ConvSubsampling(torch.nn.Module):
    """Strided-convolution front end that subsamples time and projects to `feat_out`.

    The stack is one full `CausalConv2D` followed by `log2(subsampling_factor) - 1`
    depthwise `CausalConv2D` + pointwise `Conv2d` pairs, each followed by
    `activation`. Every `CausalConv2D` uses kernel 3 and stride 2. The flattened
    (channels x features) output is mapped to `feat_out` by a linear layer.

    Args:
        subsampling_factor: overall time reduction; expected to be a power of two.
        feat_in: number of input features per frame.
        feat_out: number of output features per frame.
        conv_channels: number of channels of the convolutions.
        activation: module inserted after each convolution stage (the same
            instance is reused for every stage).
    """

    def __init__(
        self,
        subsampling_factor: int,
        feat_in: int,
        feat_out: int,
        conv_channels: int,
        activation,
    ):
        super().__init__()
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out
        # math.log2 is usually more accurate than math.log(x, 2), whose rounding
        # error could make int() truncate to one stage too few.
        self._sampling_num = int(math.log2(subsampling_factor))
        self._subsampling_factor = subsampling_factor

        in_channels = 1
        layers = []

        self._stride = 2
        self._kernel_size = 3

        # left/right pad the time axis, top/bottom pad the feature axis
        self._left_padding = self._kernel_size - 1
        # self._right_padding = self._stride - 1
        self._right_padding = 0
        self._top_padding = self._kernel_size - 1
        self._bottom_padding = self._stride - 1

        # Layer 1
        layers.append(
            CausalConv2D(
                in_feats=self._feat_in,
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
            )
        )
        in_channels = conv_channels
        # feature-axis size after the first convolution
        in_feats = int(
            calc_length(
                torch.tensor(self._feat_in, dtype=torch.float),
                all_paddings=self._top_padding + self._bottom_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                repeat_num=1
            )
        )

        layers.append(activation)

        for _ in range(self._sampling_num - 1):
            # depthwise strided convolution ...
            layers.append(
                CausalConv2D(
                    in_feats=in_feats,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    groups=in_channels,
                )
            )

            # ... followed by a pointwise convolution mixing the channels
            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                )
            )
            layers.append(activation)
            in_channels = conv_channels
            in_feats = int(
                calc_length(
                    torch.tensor(in_feats, dtype=torch.float),
                    all_paddings=self._top_padding + self._bottom_padding,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    repeat_num=1
                )
            )

        # Feature-axis size after the whole stack, as a plain int so that
        # Linear.in_features is an int rather than a 0-dim tensor.
        out_length = int(
            calc_length(
                lengths=torch.tensor(self._feat_in, dtype=torch.float),
                all_paddings=self._top_padding + self._bottom_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                repeat_num=self._sampling_num,
            )
        )

        # b, c, t, f -> b, t, feat_out
        self.out = torch.nn.Linear(conv_channels * out_length, self._feat_out)
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x, lengths):
        """Subsample a batch of feature sequences.

        Args:
            x (torch.Tensor): (b, t, f) input features.
            lengths (torch.Tensor): (b,) number of valid frames per sequence.
        Returns:
            tuple: subsampled features (b, t', feat_out) and the matching
            lengths (int32) computed from the time-axis padding.
        """
        lengths = calc_length(
            lengths,
            all_paddings=self._left_padding + self._right_padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            repeat_num=self._sampling_num,
        )
        # b, t, f -> b, c, t, f
        x = x.unsqueeze(1)
        x = self.conv(x)

        b, _, t, _ = x.size()
        # b, c, t, f -> b, t, c * f
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.out(x)
        return x, lengths

    def get_initial_state(self):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise).

            x: torch.Tensor of shape [1, T, F]
            state: dict
        """
        raise NotImplementedError()
        

In [None]:
class ConformerFeedForward(torch.nn.Module):
    """Position-wise feed-forward module: Linear -> activation -> dropout -> Linear.

    Args:
        d_model: input and output feature size.
        d_ff: hidden feature size.
        dropout: dropout probability applied after the activation.
        activation: activation module; a fresh ``torch.nn.SiLU()`` is created
            when omitted (or None), so instances never share a default module
            object.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float,
        activation: torch.nn.Module = None
    ):
        super().__init__()
        self._d_model = d_model
        self._d_ff = d_ff
        self._dropout = dropout
        
        self.linear1 = torch.nn.Linear(self._d_model, self._d_ff)
        # A module instance used as a default argument is created once at
        # definition time and shared by every layer; build it per instance.
        self.activation = torch.nn.SiLU() if activation is None else activation
        self.dropout = torch.nn.Dropout(p=self._dropout)
        self.linear2 = torch.nn.Linear(self._d_ff, self._d_model)

    def forward(self, x):
        """Apply the feed-forward module to `x` (last axis of size d_model)."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class CausalConv1D(torch.nn.Conv1d):
    """Conv1d that pads `kernel_size - 1` zeros in front of the last axis and none behind.

    With stride 1 the output has the same length as the input and an output
    step never depends on later input steps.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            groups=groups,
            bias=bias,
        )
        # all padding goes to the past side
        self._left_padding = kernel_size - 1
        self._right_padding = 0

    def forward(self, x):
        """Pad the last axis of `x` on the left, then convolve."""
        padded = F.pad(x, pad=(self._left_padding, self._right_padding))
        return super().forward(padded)

    def get_initial_state(self):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()


class ConformerConvolution(torch.nn.Module):
    """Conformer convolution module.

    Pipeline: pointwise conv (d_model -> 2 * d_model) -> GLU over channels ->
    causal depthwise conv -> LayerNorm -> SiLU -> pointwise conv.

    Args:
        d_model: number of input and output features.
        kernel_size: kernel size of the depthwise convolution; must be odd.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        self._d_model = d_model
        self._kernel_size = kernel_size
        # GLU halves the channel axis (dim 1) again after pointwise_conv1 doubled it
        self.pointwise_activation = lambda x: torch.nn.functional.glu(x, dim=1)

        self.pointwise_conv1 = torch.nn.Conv1d(
            in_channels=self._d_model,
            out_channels=self._d_model * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.depthwise_conv = CausalConv1D(
            in_channels=self._d_model,
            out_channels=self._d_model,
            kernel_size=self._kernel_size,
            stride=1,
            groups=self._d_model,
            bias=True,
        )

        # named batch_norm, but it is a LayerNorm over the feature axis
        self.batch_norm = torch.nn.LayerNorm(self._d_model)
        self.activation = torch.nn.SiLU()
        self.pointwise_conv2 = torch.nn.Conv1d(
            in_channels=self._d_model,
            out_channels=d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x, pad_mask=None):
        """Apply the convolution module.

        Args:
            x (torch.Tensor): [B, T, F] input.
            pad_mask (torch.Tensor): optional [B, T] bool mask; True entries are
                zeroed before the depthwise convolution.
        Returns:
            torch.Tensor: [B, T, F]
        """
        # the convolutions work channel-first: [B, F, T]
        hidden = self.pointwise_activation(self.pointwise_conv1(x.transpose(1, 2)))

        if pad_mask is not None:
            hidden = hidden.float().masked_fill(pad_mask.unsqueeze(1), 0.0)

        hidden = self.depthwise_conv(hidden)

        # LayerNorm normalises the last axis, so hop to [B, T, F] and back
        hidden = self.batch_norm(hidden.transpose(1, 2)).transpose(1, 2)

        hidden = self.pointwise_conv2(self.activation(hidden))
        return hidden.transpose(1, 2)

    def get_initial_state(self):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()


class ConformerLayer(torch.nn.Module):
    """One Conformer encoder block.

    Order of the sub-modules in `forward`: half-step feed-forward, self-attention
    with relative positions, convolution module, half-step feed-forward, final
    LayerNorm. Every sub-module is pre-normalised and wrapped in a residual
    connection.

    Args:
        d_model (int): feature size of the block
        d_ff (int): hidden size of the feed-forward modules
        n_heads (int): number of attention heads
        conv_kernel_size (int): kernel size of the depthwise convolution
        dropout (float): dropout probability on the residual branches and inside
            the feed-forward modules
        dropout_att (float): dropout probability on the attention weights
        att_context_size: attention context description; only stored here
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_heads: int,
        conv_kernel_size: int,
        dropout: float,
        dropout_att: float,
        att_context_size: tuple[int, int],
    ):
        super().__init__()
        self._d_model = d_model
        self._d_ff = d_ff
        self._n_heads = n_heads
        self._conv_kernel_size = conv_kernel_size
        self._dropout = dropout
        self._dropout_att = dropout_att
        self._att_context_size = att_context_size

        # each of the two feed-forward modules contributes with half weight
        self._fc_factor = 0.5

        # first feed-forward module
        self.norm_feed_forward1 = torch.nn.LayerNorm(self._d_model)
        self.feed_forward1 = ConformerFeedForward(
            d_model=self._d_model, d_ff=self._d_ff, dropout=self._dropout
        )

        # convolution module
        self.norm_conv = torch.nn.LayerNorm(self._d_model)
        self.conv = ConformerConvolution(d_model=self._d_model, kernel_size=self._conv_kernel_size)

        # multi-headed self-attention module
        self.norm_self_att = torch.nn.LayerNorm(self._d_model)
        self.self_attn = RelPositionMultiHeadAttention(
            n_head=self._n_heads,
            n_feat=self._d_model,
            dropout_rate=self._dropout_att,
        )

        # second feed-forward module
        self.norm_feed_forward2 = torch.nn.LayerNorm(self._d_model)
        self.feed_forward2 = ConformerFeedForward(
            d_model=self._d_model, d_ff=self._d_ff, dropout=self._dropout
        )

        self.dropout = torch.nn.Dropout(self._dropout)
        self.norm_out = torch.nn.LayerNorm(self._d_model)

    def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None):
        """
        Args:
            x (torch.Tensor): input signals (B, T, d_model)
            att_mask (torch.Tensor): attention mask (B, T, T), handed to the attention module
            pos_emb (torch.Tensor): relative positional embeddings, handed to the
                attention module (presumably (1, 2 * T - 1, d_model) -- confirm
                against the positional encoder)
            pad_mask (torch.tensor): padding mask (B, T), handed to the convolution module
        Returns:
            x (torch.Tensor): (B, T, d_model)
        """
        # feed-forward, half step
        branch = self.feed_forward1(self.norm_feed_forward1(x))
        stream = x + self.dropout(branch) * self._fc_factor

        # self-attention
        normed = self.norm_self_att(stream)
        branch = self.self_attn(query=normed, key=normed, value=normed, mask=att_mask, pos_emb=pos_emb)
        stream = stream + self.dropout(branch)

        # convolution
        branch = self.conv(self.norm_conv(stream), pad_mask=pad_mask)
        stream = stream + self.dropout(branch)

        # feed-forward, half step
        branch = self.feed_forward2(self.norm_feed_forward2(stream))
        stream = stream + self.dropout(branch) * self._fc_factor

        return self.norm_out(stream)

    def get_initial_state(self, chunk_size, left_context_chunks):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()

    def streaming_forward(self, x, pos_emb, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()


class ConformerEncoder(torch.nn.Module):
    """Conformer encoder with chunk-limited attention.

    Pipeline: convolutional subsampling -> input scaling and relative positional
    encoding -> `n_layers` Conformer blocks.

    Args:
        feat_in: number of input features per frame.
        n_layers: number of Conformer blocks.
        d_model: model feature size.
        ff_expansion_factor: feed-forward hidden size is d_model times this.
        n_heads: number of attention heads.
        subsampling_factor: time reduction of the subsampling front end.
        subsampling_conv_channels: channels of the subsampling convolutions.
        att_context_size: (left_context, right_context) in frames; the chunk
            size is right_context + 1.
        conv_kernel_size: depthwise kernel size of the convolution modules.
        pos_emb_max_len: length for which the positional table is precomputed.
        dropout: dropout inside the Conformer blocks.
        dropout_pre_encoder: dropout applied to the scaled input.
        dropout_emb: dropout applied to the positional embeddings.
        dropout_att: dropout applied to the attention weights.
    """

    def __init__(
        self,
        feat_in: int,
        n_layers: int,
        d_model: int,
        ff_expansion_factor: int,
        n_heads: int,
        subsampling_factor: int,
        subsampling_conv_channels: int,
        att_context_size: tuple[int, int],
        conv_kernel_size: int,
        pos_emb_max_len: int = 5000,
        dropout: float = 0.1,
        dropout_pre_encoder: float = 0.1,
        dropout_emb: float = 0.1,
        dropout_att: float = 0.0,
    ):
        super().__init__()
        # architecture
        self._feat_in = feat_in
        self._n_layers = n_layers
        self._d_model = d_model
        self._ff_expansion_factor = ff_expansion_factor
        self._n_heads = n_heads
        self._subsampling_factor = subsampling_factor
        self._subsampling_conv_channels = subsampling_conv_channels

        # inputs are multiplied by sqrt(d_model) before the blocks
        self._x_scale = math.sqrt(self._d_model)

        # context
        self._att_context_size = att_context_size
        self._conv_kernel_size = conv_kernel_size
        self._conv_context_size = [conv_kernel_size - 1, 0]
        self._pos_emb_max_len = pos_emb_max_len

        # regularisation
        self._dropout = dropout
        self._dropout_pre_encoder = dropout_pre_encoder
        self._dropout_emb = dropout_emb
        self._dropout_att = dropout_att

        self.pre_encode = ConvSubsampling(
            subsampling_factor=self._subsampling_factor,
            feat_in=self._feat_in,
            feat_out=self._d_model,
            conv_channels=self._subsampling_conv_channels,
            activation=torch.nn.ReLU(),
        )

        self._feat_out = d_model
        self._d_head = self._d_model // self._n_heads

        self.pos_enc = RelPositionalEncoding(
            d_model=d_model,
            dropout_rate=dropout_pre_encoder,
            max_len=pos_emb_max_len,
            xscale=self._x_scale,
            dropout_rate_emb=dropout_emb,
        )
        # precompute the positional table once
        self.pos_enc.extend_pe(pos_emb_max_len)

        self.layers = torch.nn.ModuleList(
            ConformerLayer(
                d_model=self._d_model,
                d_ff=self._d_model * self._ff_expansion_factor,
                n_heads=self._n_heads,
                conv_kernel_size=self._conv_kernel_size,
                dropout=self._dropout,
                dropout_att=self._dropout_att,
                att_context_size=self._att_context_size,
            )
            for _ in range(n_layers)
        )

    def _create_masks(self, lengths, max_length, device):
        """
        Returns:
            Tuple of (pad_mask, att_mask)
            pad_mask: torch.Tensor, (B, T), bool. True means value should not be used
            att_mask: torch.Tensor, (B, T, T), bool. True means value should not be used
        """
        chunk_size = self._att_context_size[1] + 1
        # number of whole chunks to the left that every chunk may look at
        left_chunks_num = self._att_context_size[0] // chunk_size
        chunk_mask = create_attn_mask(chunk_size, left_chunks_num, max_length, device)

        # [b, t]: is_valid[i, j] = j < lengths[i]
        frame_ids = torch.arange(0, max_length, device=device).expand(lengths.size(0), -1)
        is_valid = frame_ids < lengths.unsqueeze(-1)

        # [b, t, t]: a (query, key) pair counts only when both frames are real
        valid_keys = is_valid.unsqueeze(1).repeat([1, max_length, 1])
        valid_pairs = torch.logical_and(valid_keys, valid_keys.transpose(1, 2))

        # just in case
        assert chunk_mask.shape[1] == max_length
        assert chunk_mask.shape[2] == max_length
        assert chunk_mask.device == valid_pairs.device

        # combine, then flip both masks to the "True means ignore" convention
        allowed = torch.logical_and(valid_pairs, chunk_mask)
        return ~is_valid, ~allowed

    def forward(self, features, lengths):
        """Encode a batch of feature sequences.

            features: [B, F, T]
            lengths: [B]
        Returns:
            tuple: encoded features with the feature axis at dim 1 and the time
            axis last, and the subsampled lengths as int64.
        """
        hidden, lengths = self.pre_encode(x=torch.transpose(features, 1, 2), lengths=lengths)
        lengths = lengths.to(torch.int64)
        hidden, pos_emb = self.pos_enc(x=hidden)
        pad_mask, att_mask = self._create_masks(
            lengths=lengths,
            max_length=hidden.size(1),
            device=hidden.device,
        )

        for block in self.layers:
            hidden = block(x=hidden, att_mask=att_mask, pos_emb=pos_emb, pad_mask=pad_mask)

        return torch.transpose(hidden, 1, 2), lengths

    def get_initial_state(self):
        """Streaming state factory; left unimplemented (seminar exercise)."""
        raise NotImplementedError()
    
    def streaming_forward(self, features, state):
        """Streaming counterpart of `forward`; left unimplemented (seminar exercise)."""
        raise NotImplementedError()



In [None]:
class CtcDecoder(torch.nn.Module):
    """Linear CTC head plus greedy decoding.

    Args:
        enc_output_size: number of encoder output channels.
        tokenizer_settings: dict with 'token_to_piece' (token id as string ->
            text piece), 'blank_idx' (id of the CTC blank) and 'special_symbol'
            (marker replaced by a space in the decoded text).
    """

    def __init__(self, enc_output_size, tokenizer_settings):
        super().__init__()
        self._tokenizer_settings = tokenizer_settings
        # one output per vocabulary piece plus one more class
        n_classes = len(tokenizer_settings['token_to_piece']) + 1
        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=enc_output_size,
                out_channels=n_classes,
                kernel_size=1,
                stride=1
            )
        )

    def forward(self, enc_output, enc_lengths):
        """Map encoder output (B, C, T) to logits (B, classes, T); lengths pass through."""
        logits = self.decoder_layers(enc_output)
        return logits, enc_lengths

    def decode(self, logits):
        """Greedy CTC decoding.

        Takes the best class per frame, collapses repeats, drops blanks, joins
        the pieces and replaces the special symbol with a space.

        Args:
            logits (torch.Tensor): (B, classes, T)
        Returns:
            list[str]: one string per batch element; every frame is decoded,
            lengths are not taken into account.
        """
        texts = []
        for frame_scores in logits.transpose(2, 1):
            best_ids = frame_scores.max(dim=-1)[1].tolist()
            pieces = []
            previous = None
            for token in best_ids:
                repeated = token == previous
                previous = token
                if repeated or token == self._tokenizer_settings['blank_idx']:
                    continue
                pieces.append(self._tokenizer_settings['token_to_piece'][str(token)])
            text = ''.join(pieces)
            texts.append(text.replace(self._tokenizer_settings['special_symbol'], ' '))
        return texts


In [None]:
# Build the encoder. With att_context_size=[70, 1] the attention chunk is
# 1 + 1 = 2 frames and each chunk sees 70 // 2 = 35 chunks to its left.
# NOTE(review): these hyper-parameters presumably have to match the checkpoint
# loaded further down -- confirm before changing any of them.
encoder = ConformerEncoder(
    feat_in=80,
    n_layers=17,
    d_model=512,
    ff_expansion_factor=4,
    n_heads=8,
    subsampling_factor=8,
    subsampling_conv_channels=256,
    att_context_size=[70, 1],
    conv_kernel_size=9,
)

In [None]:
# Read the tokenizer description (token pieces, blank index, special symbol).
# JSON text is UTF-8; state it explicitly so that non-ASCII pieces do not depend
# on the platform's default encoding.
with open('../data/week12_data/token.json', encoding='utf-8') as fp:
    tokenizer_settings = json.load(fp)

In [None]:
# CTC head on top of the encoder; 512 is the encoder's d_model.
decoder = CtcDecoder(512, tokenizer_settings)
# inference only: CPU and eval mode
decoder = decoder.cpu().eval()

In [None]:
# Load the encoder weights.
# NOTE: pickle.load can execute arbitrary code -- only use it on files from a
# trusted source such as the seminar archive.
with open('../data/week12_data/encoder_state.pkl', 'rb') as fp:
    encoder.load_state_dict(pickle.load(fp))

In [None]:
# Load the decoder weights.
# NOTE: pickle.load can execute arbitrary code -- only use it on files from a
# trusted source such as the seminar archive.
with open('../data/week12_data/decoder_state.pkl', 'rb') as fp:
    decoder.load_state_dict(pickle.load(fp))

In [None]:
# Inference only: keep the encoder on CPU and in eval mode (disables dropout).
encoder = encoder.cpu().eval()

In [None]:

# Read every frame of the example recording as raw PCM bytes.
with open('../data/week12_data/audio.wav', 'rb') as fp, wave.open(fp, 'r') as wfp:
    pcm_data = wfp.readframes(wfp.getnframes())

# The bytes are interpreted as 16-bit integers and scaled into [-1, 1).
# NOTE(review): this assumes 16-bit PCM; the sample width is not checked.
signal = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32)
signal = (signal / 2. ** 15).astype(np.float32)

In [None]:
# Show an inline player for the example recording.
Audio('../data/week12_data/audio.wav')

In [None]:
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

In [None]:
# Log-mel front end with its default settings (16 kHz, 400-sample window,
# 160-sample hop, 80 mel filters).
featurizer = FilterbankFeatures()

In [None]:
# Offline feature extraction for the whole recording.
with torch.no_grad():
    features = featurizer(torch.tensor(signal))
    # for a 1-D signal the result is (mel filters, frames)
    print(features.shape)

In [None]:
# Offline (full-context) encoder pass; serves as the reference for streaming.
with torch.no_grad():
    # add the batch axis -> [1, F, T]; the length is the number of feature frames
    encoded, encoded_len = encoder(features.unsqueeze(0), torch.tensor([features.size(1)]))
    print(encoded.shape)

In [None]:
# Project the encoder output to per-frame CTC logits; lengths pass through unchanged.
with torch.no_grad():
    logits, logits_len = decoder(encoded, encoded_len)

In [None]:
# Greedy CTC decoding of the offline logits.
with torch.no_grad():
    # A bare expression nested inside a `with` block is not echoed by the
    # notebook, so print the transcripts explicitly.
    print(decoder.decode(logits))