# Conformer Model

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.init import xavier_normal_
from torchaudio.functional import rnnt_loss
import math

## Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding proposed in "Attention Is All You Need".

    Since the transformer contains no recurrence and no convolution, positional
    information has to be supplied explicitly so the model can make use of the
    order of the sequence. The encoding uses sine and cosine functions of
    different frequencies:
        PE_(pos, 2i)    =  sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1)  =  cos(pos / power(10000, 2i / d_model))

    The table is computed once for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape (1, max_len, input_dim).

    Args:
        input_dim: Size of each positional vector (even or odd).
        max_len: Number of positions to precompute.
    """
    def __init__(self, input_dim, max_len = 5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, input_dim, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, input_dim, 2).float() * -(math.log(10000.0) / input_dim))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd input_dim there is one fewer cosine column than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : input_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, length: int):
        """
        Return the encodings of the first ``length`` positions.

        The result has shape (1, min(length, max_len), input_dim); requests
        longer than the precomputed table are capped by the slice.
        """
        return self.pe[:, :length]

## Residual Connection Module

In [None]:
class ResidualConnectionModule(nn.Module):
    """
    Wrap a module with a weighted skip connection.

    outputs = module(inputs) * module_factor + inputs * input_factor

    Args:
        module: The wrapped callable; it receives ``inputs`` and, when a mask
            is supplied, the mask as a second positional argument.
        module_factor: Weight applied to the wrapped module's output.
        input_factor: Weight applied to the untouched inputs.
    """
    def __init__(
            self,
            module: nn.Module,
            module_factor: float = 1.0,
            input_factor: float = 1.0,
    ):
        super().__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs, mask = None):
        """Apply the wrapped module (forwarding ``mask`` only when given) and add the scaled inputs."""
        # The wrapped module is evaluated before the skip term, as in the
        # weighted-sum formula above.
        branch = self.module(inputs) if mask is None else self.module(inputs, mask)
        scaled_branch = branch * self.module_factor
        scaled_skip = inputs * self.input_factor
        return scaled_branch + scaled_skip

## Multi-Headed Attention Module


### Helper Functions

In [None]:
def relative_shift(pos_score):
    """
    Re-align a 4-D positional score tensor with the pad-and-reshape trick.

    A zero column is prepended to the last dimension, the two trailing
    dimensions are reinterpreted as (seq_length2 + 1, seq_length1), the first
    row is dropped, and the remainder is viewed back in the input's shape.

    This is a module-level function, so it takes no ``self`` argument; callers
    pass only the score tensor.

    Args:
        pos_score: Tensor of shape (batch, heads, seq_length1, seq_length2).

    Returns:
        Tensor with the same shape as ``pos_score`` holding the shifted scores.
    """
    batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
    zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
    padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

    # Reinterpreting the padded buffer with swapped trailing sizes offsets
    # every row by one element relative to the previous row.
    padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
    return padded_pos_score[:, :, 1:].view_as(pos_score)

def position_scaled_dot_product(q, k, v, pos_enc, mask=None):
    """
    Scaled dot-product attention with an additional positional score term.

    The content score (q . k^T) and the shifted positional score
    (q . pos_enc^T passed through ``relative_shift``) are summed, divided by
    sqrt(d_k) and soft-maxed over the last dimension.

    Args:
        q, k, v: Query, key and value tensors whose last two dimensions are
            presumably (seq_len, head_dim) - TODO confirm against the caller.
        pos_enc: Positional embeddings with the same trailing layout as ``k``.
        mask: Optional tensor; positions where ``mask == 0`` are filled with
            -9e15 before the softmax.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and the
        attention weights themselves.
    """
    scale = math.sqrt(q.size(-1))

    content_score = q @ k.transpose(-2, -1)
    position_score = relative_shift(q @ pos_enc.transpose(-2, -1))

    logits = (content_score + position_score) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -9e15)

    weights = F.softmax(logits, dim=-1)
    return weights @ v, weights

### Relative Multi-Head Attention

In [None]:
class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with a positional score term.

    Queries, keys and values come from one fused projection of the input; the
    positional encodings get their own projection. Attention itself is
    computed by ``position_scaled_dot_product``.

    Args:
        input_dim: Feature size of the input and output.
        heads: Number of attention heads; must divide ``input_dim``.
    """
    def __init__(self, input_dim, heads):
        super(RelativeMultiHeadAttention, self).__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
        assert (input_dim % heads == 0), "Input dims must be divisible by head num"

        self.qkv_projection = nn.Linear(input_dim, 3 * input_dim, bias=False)
        self.pos_projection = nn.Linear(input_dim, input_dim, bias=False)
        self.final_linear = nn.Linear(input_dim, input_dim, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the fused QKV and output projections."""
        nn.init.xavier_uniform_(self.qkv_projection.weight)
        nn.init.xavier_uniform_(self.final_linear.weight)

    def forward(self, x, positional, mask=None, return_attention=False):
        """
        Args:
            x: Input of shape (batch, seq_len, input_dim).
            positional: Positional encodings, reshaped here to
                (batch, seq_len, heads, head_dim) after projection.
            mask: Optional mask forwarded to the attention function.
            return_attention: When true, also return the attention weights.

        Returns:
            Tensor of shape (batch, seq_len, input_dim), or a tuple of that
            tensor and the attention weights when ``return_attention`` is set.
        """
        batch_size, seq_len, input_dim = x.size()
        qkv = self.qkv_projection(x)
        pos = self.pos_projection(positional)

        # Separate heads
        qkv = qkv.reshape(batch_size, seq_len, self.heads, 3 * self.head_dim)
        pos = pos.reshape(batch_size, seq_len, self.heads, self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        pos = pos.permute(0, 2, 1, 3)

        q, k, v = qkv.chunk(3, dim=-1)

        values, attention = position_scaled_dot_product(q, k, v, pos, mask=mask)
        # permute/reshape return new tensors; the results must be kept or the
        # heads are never merged back into the feature dimension.
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_len, input_dim)

        out = self.final_linear(values)

        if return_attention:
            return out, attention
        return out



### Module Definition

In [None]:
class MultiheadAttentionModule(nn.Module):
    """
    Pre-norm self-attention sub-block: LayerNorm -> relative multi-head
    attention (fed with sinusoidal positional encodings) -> dropout.

    Args:
        input_dim: Feature size of the input.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention output.
    """
    def __init__(self, input_dim, heads = 8, dropout = 0.1):
        super(MultiheadAttentionModule, self).__init__()

        # No trailing commas here: a trailing comma would wrap each submodule
        # in a tuple, leaving it unregistered and not callable.
        self.pe = PositionalEncoding(input_dim)
        self.norm = nn.LayerNorm(input_dim)
        self.attention = RelativeMultiHeadAttention(input_dim, heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        """
        Args:
            x: Input of shape (batch, seq_len, input_dim).
            mask: Optional attention mask passed through to the attention layer.

        Returns:
            The attention output after dropout.
        """
        batch_size, seq_len, _ = x.size()
        # The encoding table has a leading dimension of 1; broadcast it over
        # the batch without copying.
        pos_enc = self.pe(seq_len).expand(batch_size, -1, -1)

        normed = self.norm(x)
        attn = self.attention(normed, pos_enc, mask)

        return self.dropout(attn)

## Feed Forward Module

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network:
    Linear(input_dim -> input_dim * forward_expansion) -> Dropout -> ReLU ->
    Linear(input_dim * forward_expansion -> input_dim).

    Args:
        input_dim: Feature size of the input and output.
        forward_expansion: Multiplier for the hidden layer width.
        dropout: Dropout probability applied after the first linear layer.
    """
    def __init__(self, input_dim, forward_expansion, dropout):
        # nn.Module must be initialized before submodules are assigned,
        # otherwise attribute assignment raises AttributeError.
        super().__init__()
        self.linear1 = nn.Linear(input_dim, input_dim * forward_expansion)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(input_dim * forward_expansion, input_dim)

    def forward(self, x):
        """Map ``x`` (..., input_dim) to a tensor of the same shape."""
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.activation(x)
        return self.linear2(x)

## Convolution Module

### Swish Activation

In [None]:
class Swish(nn.Module):
    """Swish activation: the input multiplied element-wise by its own sigmoid."""
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs * sigmoid(inputs)``."""
        gate = torch.sigmoid(inputs)
        return inputs * gate

### GLU Activation

In [None]:
class GLU(nn.Module):
    """
    Gated linear unit: split the input into two halves along one dimension and
    multiply the first half by the sigmoid of the second.

    Args:
        activation_dim: Dimension along which the input is split in two.
    """
    def __init__(self, activation_dim):
        super().__init__()
        self.activation_dim = activation_dim

    def forward(self, inputs):
        """Return ``first_half * sigmoid(second_half)``."""
        content, gate = torch.chunk(inputs, 2, dim=self.activation_dim)
        return content * torch.sigmoid(gate)

### Module Definition

In [None]:
class ConvolutionModule(nn.Module):
    """
    Convolution sub-block:
    LayerNorm -> pointwise conv (channel expansion) -> GLU -> depthwise conv ->
    BatchNorm -> Swish -> pointwise conv -> Dropout.

    Args:
        input_dim: Number of input/output features.
        kernel_size: Kernel width of the depthwise convolution; must be odd so
            that symmetric padding preserves the sequence length.
        expansion_factor: Channel multiplier of the first pointwise
            convolution; only 2 is allowed because the GLU halves it again.
        dropout: Dropout probability applied to the block output.

    Inputs:
        x: Tensor of shape (batch, time, input_dim).

    Returns:
        Tensor of shape (batch, time, input_dim).
    """
    def __init__(
            self,
            input_dim,
            kernel_size = 31,
            expansion_factor = 2,
            dropout = 0.1,
    ):
        super(ConvolutionModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Only expansion_factor 2 allowed"

        self.layer_norm = nn.LayerNorm(input_dim)
        self.pointwise1 = nn.Conv1d(
            in_channels=input_dim,
            out_channels=input_dim * expansion_factor,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Splits the expanded channels (dim 1) back down to input_dim.
        self.glu = GLU(activation_dim=1)

        # Depthwise: one filter per channel (groups=input_dim) using the
        # configured kernel width; (kernel_size - 1) // 2 padding on each side
        # keeps the time dimension unchanged for an odd kernel.
        self.depthwise = nn.Conv1d(
            in_channels=input_dim,
            out_channels=input_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=input_dim,
            bias=True,
        )

        self.batch_norm = nn.BatchNorm1d(input_dim)

        self.swish = Swish()

        self.pointwise2 = nn.Conv1d(
            in_channels=input_dim,
            out_channels=input_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        # Conv1d expects (batch, channels, time).
        x = x.transpose(1, 2)
        x = self.pointwise1(x)
        x = self.glu(x)
        x = self.depthwise(x)
        x = self.batch_norm(x)
        x = self.swish(x)
        x = self.pointwise2(x)
        x = self.dropout(x)

        # Back to (batch, time, channels).
        return x.transpose(1, 2)

## Conformer Block

In [None]:
class ConformerBlock(nn.Module):
    """
    One Conformer block: feed-forward, self-attention, convolution and a
    second feed-forward module, each wrapped in a residual connection, followed
    by a final LayerNorm.

    Args:
        encoder_dim: Feature size flowing through the block.
        attention_heads: Number of attention heads.
        half_step_residual: When true, the two feed-forward outputs are scaled
            by 0.5 before being added to their inputs; otherwise by 1.
        ff_expansion: Hidden-width multiplier of the feed-forward modules.
        ff_dropout: Dropout probability of the feed-forward modules.
        attn_dropout: Dropout probability of the attention module.
        conv_expansion: Channel expansion factor of the convolution module.
        conv_dropout: Dropout probability of the convolution module.
        conv_kernel: Kernel size passed to the convolution module.
    """
    def __init__(
        self,
        encoder_dim,
        attention_heads = 8,
        half_step_residual = True,
        ff_expansion = 4,
        ff_dropout = 0.1,
        attn_dropout = 0.1,
        conv_expansion = 2,
        conv_dropout = 0.1,
        conv_kernel = 31
    ):
        super(ConformerBlock, self).__init__()
        self.ff_residual = 0.5 if half_step_residual else 1.0

        # Registered as an nn.Sequential so the submodules' parameters are
        # tracked and the block can be applied in one call.
        self.sequential = nn.Sequential(
            ResidualConnectionModule(
                module=FeedForward(
                    input_dim=encoder_dim,
                    forward_expansion=ff_expansion,
                    dropout=ff_dropout,
                ),
                module_factor=self.ff_residual,
            ),
            ResidualConnectionModule(
                module=MultiheadAttentionModule(
                    input_dim=encoder_dim,
                    heads=attention_heads,
                    dropout=attn_dropout,
                ),
            ),
            ResidualConnectionModule(
                module=ConvolutionModule(
                    input_dim=encoder_dim,
                    kernel_size=conv_kernel,
                    expansion_factor=conv_expansion,
                    dropout=conv_dropout,
                ),
            ),
            ResidualConnectionModule(
                module=FeedForward(
                    input_dim=encoder_dim,
                    forward_expansion=ff_expansion,
                    dropout=ff_dropout,
                ),
                module_factor=self.ff_residual,
            ),
            nn.LayerNorm(encoder_dim),
        )

    def forward(self, inputs):
        """Run ``inputs`` through the residual sub-blocks and the final LayerNorm."""
        return self.sequential(inputs)

## Conformer Encoder

### Convolutional Subsampling

Masked 2d Subsampling

In [None]:
class MaskConv2d(nn.Module):
    r"""
    Masking Convolutional Neural Network
    Zeroes the output of every wrapped layer beyond each sequence's valid length.
    This is to ensure that the results of the model do not change when batch sizes change during inference.
    Input needs to be in the shape of (batch_size, channel, hidden_dim, seq_len)
    Refer to https://github.com/SeanNaren/deepspeech.pytorch/blob/master/model.py
    Copyright (c) 2017 Sean Naren
    MIT License
    Args:
        sequential (torch.nn): sequential list of convolution layer
    Inputs: inputs, seq_lengths
        - **inputs** (torch.FloatTensor): The input of size BxCxHxT
        - **seq_lengths** (torch.IntTensor): The actual length of each sequence in the batch
    Returns: output, seq_lengths
        - **output**: Masked output from the sequential
        - **seq_lengths**: Sequence length of output from the sequential
    """
    def __init__(self, sequential: nn.Sequential) -> None:
        super(MaskConv2d, self).__init__()
        self.sequential = sequential

    def forward(self, inputs: Tensor, seq_lengths: Tensor):
        output = None

        for module in self.sequential:
            output = module(inputs)
            seq_lengths = self._get_sequence_lengths(module, seq_lengths)

            # Build the padding mask directly on the output's device: True
            # for every time step at or beyond the sequence's valid length.
            time_steps = torch.arange(output.size(-1), device=output.device)
            valid = seq_lengths.to(output.device).unsqueeze(1)
            padding_mask = (time_steps.unsqueeze(0) >= valid)[:, None, None, :]

            # Broadcasts over the channel and feature dimensions.
            output = output.masked_fill(padding_mask, 0)
            inputs = output

        return output, seq_lengths

    def _get_sequence_lengths(self, module: nn.Module, seq_lengths: Tensor) -> Tensor:
        r"""
        Compute the sequence lengths after ``module`` has been applied.
        The caller's tensor is never modified in place.
        Args:
            module (torch.nn.Module): module of CNN
            seq_lengths (torch.IntTensor): The actual length of each sequence in the batch
        Returns: seq_lengths
            - **seq_lengths**: Sequence length of output from the module
        """
        if isinstance(module, nn.Conv2d):
            # Convolution output-size formula along the time axis (index 1 of
            # the 2-D kernel parameters).
            numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1
            seq_lengths = numerator.float() / float(module.stride[1])
            seq_lengths = seq_lengths.int() + 1

        elif isinstance(module, nn.MaxPool2d):
            # Halve the lengths out of place so the input tensor stays intact.
            seq_lengths = seq_lengths >> 1

        return seq_lengths.int()

Convolutional 2d Extractor

In [None]:
class Conv2dExtractor(nn.Module):
    r"""
    Provides the interface of a convolutional extractor.
    Note:
        Do not use this class directly, use one of the sub classes.
        Sub classes must define ``self.conv`` (and ``self.out_channels``, which
        ``get_output_dim`` reads).
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing input vectors
        - **input_lengths**: Tensor containing sequence lengths
    Returns: outputs, output_lengths
        - **outputs**: Tensor produced by the convolution
        - **output_lengths**: Tensor containing sequence lengths produced by the convolution
    """
    supported_activations = {
        'hardtanh': nn.Hardtanh(0, 20, inplace=True),
        'relu': nn.ReLU(inplace=True),
        'elu': nn.ELU(inplace=True),
        'leaky_relu': nn.LeakyReLU(inplace=True),
        'gelu': nn.GELU(),
        'swish': Swish(),
    }

    def __init__(self, input_dim: int, activation: str = 'hardtanh') -> None:
        super(Conv2dExtractor, self).__init__()
        self.input_dim = input_dim
        self.activation = Conv2dExtractor.supported_activations[activation]
        self.conv = None

    def get_output_lengths(self, seq_lengths: torch.Tensor):
        r"""
        Compute the sequence lengths produced by the layers of ``self.conv``
        without modifying ``seq_lengths`` in place.
        """
        assert self.conv is not None, "self.conv should be defined"

        # self.conv may be a wrapper that keeps its layers in a `sequential`
        # attribute (such a wrapper is not iterable itself); otherwise iterate
        # self.conv directly.
        layers = getattr(self.conv, "sequential", self.conv)

        for module in layers:
            if isinstance(module, nn.Conv2d):
                numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1
                seq_lengths = numerator.float() / float(module.stride[1])
                seq_lengths = seq_lengths.int() + 1

            elif isinstance(module, nn.MaxPool2d):
                # Out-of-place shift: do not mutate the caller's tensor.
                seq_lengths = seq_lengths >> 1

        return seq_lengths.int()

    def get_output_dim(self):
        r"""
        Return ``out_channels * factor`` where ``factor`` is
        ``((input_dim - 1) // 2 - 1) // 2``.
        NOTE(review): the factor presumably corresponds to two kernel-3,
        stride-2 convolutions over the feature axis - confirm per sub class.
        """
        factor = ((self.input_dim - 1) // 2 - 1) // 2
        output_dim = self.out_channels * factor
        return output_dim

    def forward(self, inputs: Tensor, input_lengths: Tensor):
        r"""
        inputs: torch.FloatTensor (batch, time, dimension)
        input_lengths: torch.IntTensor (batch)
        """
        # (batch, time, dim) -> (batch, 1, dim, time) for the 2-D convolution.
        outputs, output_lengths = self.conv(inputs.unsqueeze(1).transpose(2, 3), input_lengths)

        batch_size, channels, dimension, seq_lengths = outputs.size()
        # (batch, channels, dim, time) -> (batch, time, channels * dim)
        outputs = outputs.permute(0, 3, 1, 2)
        outputs = outputs.reshape(batch_size, seq_lengths, channels * dimension)

        return outputs, output_lengths

Convolutional 2d Subsample Module

In [None]:
class Conv2dSubsampling(Conv2dExtractor):
    r"""
    Convolutional 2D subsampling built from two kernel-3, stride-2
    convolutions, each followed by the selected activation.
    Args:
        input_dim (int): Dimension of input vector
        in_channels (int): Number of channels in the input vector
        out_channels (int): Number of channels produced by the convolution
        activation (str): Activation function
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing sequence of inputs
        - **input_lengths** (batch): list of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, time, dim): Tensor produced by the convolution
        - **output_lengths** (batch): list of sequence output lengths
    """
    def __init__(
            self,
            input_dim: int,
            in_channels: int,
            out_channels: int,
            activation: str = 'relu',
    ) -> None:
        super().__init__(input_dim, activation)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The same activation instance follows both convolutions.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2)
        layers = nn.Sequential(first_conv, self.activation, second_conv, self.activation)
        self.conv = MaskConv2d(layers)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        """Delegate to the base-class forward and return its (outputs, output_lengths) pair."""
        return super().forward(inputs, input_lengths)

### Conformer Encoder Module

In [None]:
class ConformerEncoder(nn.Module):
    """
    Conformer encoder: convolutional subsampling, a linear input projection
    with dropout, and a stack of ``num_layers`` Conformer blocks.

    Args:
        num_layers: Number of Conformer blocks.
        input_dim: Feature size of the raw inputs.
        input_dropout: Dropout probability after the input projection.
        encoder_dim: Feature size used inside the encoder.
        attention_heads: Number of attention heads per block.
        half_step_residual: Forwarded to every block.
        ff_expansion: Feed-forward hidden-width multiplier.
        ff_dropout: Feed-forward dropout probability.
        attn_dropout: Attention dropout probability.
        conv_expansion: Convolution-module channel expansion factor.
        conv_dropout: Convolution-module dropout probability.
        conv_kernel: Convolution-module kernel size. Must be odd: the
            convolution module asserts this, so the default is 31.

    Inputs:
        inputs: Tensor of shape (batch, time, input_dim).
        input_lengths: Tensor with the valid length of each sequence.

    Returns:
        ``(outputs, output_lengths)``: the encoded sequence and the lengths
        reported by the subsampling layer.
    """
    def __init__(self,
        num_layers=16,
        input_dim=80,
        input_dropout=0.1,
        encoder_dim=512,
        attention_heads=8,
        half_step_residual = True,
        ff_expansion = 4,
        ff_dropout = 0.1,
        attn_dropout = 0.1,
        conv_expansion = 2,
        conv_dropout = 0.1,
        conv_kernel = 31
    ):
        super(ConformerEncoder, self).__init__()
        self.conv_subsample = Conv2dSubsampling(input_dim, in_channels=1, out_channels=encoder_dim)
        self.input_projection = nn.Sequential(
            nn.Linear(self.conv_subsample.get_output_dim(), encoder_dim),
            nn.Dropout(p=input_dropout),
        )
        self.layers = nn.ModuleList([
            ConformerBlock(
                encoder_dim=encoder_dim,
                attention_heads=attention_heads,
                half_step_residual=half_step_residual,
                ff_expansion=ff_expansion,
                ff_dropout=ff_dropout,
                attn_dropout=attn_dropout,
                conv_expansion=conv_expansion,
                conv_dropout=conv_dropout,
                conv_kernel=conv_kernel,
            )
            for _ in range(num_layers)
        ])

    def forward(self, inputs, input_lengths):
        outputs, output_lengths = self.conv_subsample(inputs, input_lengths)
        outputs = self.input_projection(outputs)

        for layer in self.layers:
            outputs = layer(outputs)

        return outputs, output_lengths

## Conformer Decoder

In [7]:
class RNNDecoder(nn.Module):
  supported_rnns = {
    'lstm': nn.LSTM,
    'gru': nn.GRU,
    'rnn': nn.RNN,
  }
  def __init__(self,
               rnn_type = 'lstm',
               num_classes=29,
               hidden_state_dim=320,
               num_layers=1,

               ):
    super()
    self.embedding = nn.Embedding(num_classes, hidden_state_dim)
    rnn_cell = self.supported_rnns[rnn_type.lower()]
    self.rnn = rnn_cell(
            input_size=hidden_state_dim,
            hidden_size=hidden_state_dim,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_p,
            bidirectional=False,
    )
    self.out_proj = nn.Linear(hidden_state_dim, output_dim)