# Introduction

Conformers (https://arxiv.org/abs/2005.08100) are a powerful neural network architecture that combines the local reasoning of convolutional networks with the global connections that transformers provide. They have shown state-of-the-art results in automatic speech recognition (ASR).

In this notebook, we will use components that you've learned about in previous parts as well as some new ones to implement the convolution module of a Conformer, and then combine that with the other modules to form the entire Conformer block as described in the paper.

Then, in the next notebook, we will explore one powerful application of Conformer models: speech recognition, where local relations and global context are both key to transcribing audio into text.

In [None]:
# Setup libraries and helpers
!pip install einops

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
import random
import numpy as np

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def _set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


## Part 1: Convolution Module

We start by implementing our Convolution module, which contains several components that you had some practice with earlier.

For your reference, this is the illustration of the module provided by the paper:

![conv module](https://drive.google.com/uc?export=view&id=1aHxwDy0RjBEhNPeRzWwoQMrAUG4NZizA)

First, fill in the blanks to complete the pointwise and depthwise convolution modules.

HINT: For speech data, what would each of the dimensions represent? Therefore, would it make sense to use `nn.Conv1d` or `nn.Conv2d`?
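To make the hint concrete, the sketch below works with `nn.Conv1d` directly, outside the exercise classes. It assumes speech features laid out as (batch, channels, time), with an arbitrary 16 channels and 100 frames, and compares output shapes and parameter counts for a standard convolution, a depthwise convolution (`groups` equal to the number of input channels) and a pointwise convolution (`kernel_size=1`).

```python
import torch
from torch import nn

# Speech features for nn.Conv1d are laid out as (batch, channels, time):
# channels = feature dimension per frame, time = frame index.
x = torch.randn(8, 16, 100)  # 8 utterances, 16 channels, 100 frames

# Standard 1-D convolution: every output channel sees every input channel.
standard = nn.Conv1d(16, 16, kernel_size=3, padding=1)
# Depthwise: groups == in_channels, so each channel is filtered on its own.
depthwise = nn.Conv1d(16, 16, kernel_size=3, padding=1, groups=16)
# Pointwise: kernel_size == 1, so channels are mixed one frame at a time.
pointwise = nn.Conv1d(16, 16, kernel_size=1)

def n_params(module: nn.Module) -> int:
    """Total number of learnable parameters (weights and biases)."""
    return sum(p.numel() for p in module.parameters())

for name, conv in [("standard", standard), ("depthwise", depthwise), ("pointwise", pointwise)]:
    print(name, tuple(conv(x).shape), n_params(conv))
# standard (8, 16, 100) 784
# depthwise (8, 16, 100) 64
# pointwise (8, 16, 100) 272
```

Stacking the depthwise and pointwise layers (64 + 272 = 336 parameters here) covers both temporal filtering and channel mixing with fewer than half the parameters of the standard convolution. Note that `nn.Conv1d` requires both `in_channels` and `out_channels` to be divisible by `groups`, so a depthwise layer may produce a multiple of its input channels.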

In [None]:
class PointwiseConv(nn.Module):
    def __init__(
        self, 
        in_channels: int, 
        out_channels: int, 
        stride=1, 
        padding=0, 
        bias=True
    ):
        super(PointwiseConv, self).__init__()
        ### YOUR CODE HERE ###
        self.conv = None

    def forward(self, x):
        ### YOUR CODE HERE ###
        return None

class DepthwiseConv(nn.Module):
    def __init__(
        self, 
        in_channels: int, 
        out_channels: int,
        kernel_size: int, 
        stride=1, 
        padding=0, 
        bias=False
    ):
        super(DepthwiseConv, self).__init__()
        ### YOUR CODE HERE ###
        self.conv = None 

    def forward(self, x):
        ### YOUR CODE HERE ###
        return None


In [None]:
# Test your implementation
_set_seed(2023)
p_model = PointwiseConv(2, 3)
x = torch.tensor([[[1., 2.], [3., 4.]], 
                  [[5., 6.], [7., 8.]]])
output = p_model(x)
assert output.shape == torch.Size([2, 3, 2])
assert torch.allclose(
    output,
    torch.tensor(
        [
            [[0.7040, 0.9148],
            [0.1491, 0.7541],
            [1.6921, 2.4444]],

            [[1.5473, 1.7581],
            [2.5691, 3.1741],
            [4.7013, 5.4536]]
        ],
        dtype=torch.float
    ),
    rtol=1e-03
)

d_model = DepthwiseConv(2, 4, 2)
output = d_model(x)
assert output.shape == torch.Size([2, 4, 1])
assert torch.allclose(
    output,
    torch.tensor([
            [[ 0.4918], [-0.1206], [-0.3651], [-1.3516]],
            [[ 0.8677], [-1.4797], [-0.2528], [-2.6693]]
        ],
        dtype=torch.float
    ),
    rtol=1e-03
)


Now let's implement our activation functions. Recall that the swish function is
$$y=\frac{x}{1+e^{-x}}$$
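As a quick sanity check, this sketch evaluates the formula on a few arbitrary points and compares it with the equivalent form $x\cdot\sigma(x)$ and with PyTorch's built-in `F.silu` (SiLU is the same function as Swish with $\beta = 1$).

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)

swish_formula = x / (1 + torch.exp(-x))  # y = x / (1 + e^{-x})
swish_sigmoid = x * torch.sigmoid(x)     # equivalent form: x * sigma(x)
swish_builtin = F.silu(x)                # PyTorch's SiLU (Swish with beta = 1)

print(torch.allclose(swish_formula, swish_sigmoid, atol=1e-6))  # True
print(torch.allclose(swish_formula, swish_builtin, atol=1e-6))  # True
```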

GLU stands for Gated Linear Unit, and was first introduced in [Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083). 

The idea is loosely similar to self-attention: one view of the data (the gate, playing a role like the query) is used to weight another view of the same data (the output, playing a role like the value). However, instead of taking the product of queries and keys, GLU applies a sigmoid to generate "gates", values between 0 and 1 that determine how much of each output passes through.

$$h_l(\mathbf X)=(\mathbf X*\mathbf W+\mathbf b)\otimes \sigma(\mathbf X*\mathbf V+\mathbf c)$$

($\otimes$ is the element-wise product)

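The function defined in the paper (shown above) applies two learned transformations to the input $\mathbf X$ (convolutions, in the original paper) to produce the output $\mathbf X*\mathbf W+\mathbf b$ and the gate $\mathbf X*\mathbf V+\mathbf c$. In our implementation, we will instead split the input into two equally sized halves along dimension `dim`, using the first half as the output and the second half as the gate (this assumes the length along the split dimension is even). The layer that feeds the GLU is therefore responsible for producing twice as many features as the GLU outputs.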
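The sketch below illustrates why the split approach loses nothing: two separate affine maps for output and gate are equivalent to a single affine map with twice as many outputs followed by a split. It assumes arbitrary sizes (4 input features, 3 output features, 5 positions), uses `nn.Linear` in place of the paper's convolutions, and uses PyTorch's `F.glu`, which multiplies the first half of its input by the sigmoid of the second half.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out = 4, 3
x = torch.randn(5, d_in)  # 5 positions, 4 features each

# Paper form: separate affine maps for the output (X W + b) and the gate (X V + c)
W = nn.Linear(d_in, d_out)
V = nn.Linear(d_in, d_out)
paper = W(x) * torch.sigmoid(V(x))

# Split form: one affine map to 2 * d_out features, whose halves are [output | gate]
joint = nn.Linear(d_in, 2 * d_out)
with torch.no_grad():
    joint.weight.copy_(torch.cat([W.weight, V.weight], dim=0))  # (2 * d_out, d_in)
    joint.bias.copy_(torch.cat([W.bias, V.bias], dim=0))        # (2 * d_out,)
split = F.glu(joint(x), dim=-1)  # first half * sigmoid(second half)

print(torch.allclose(paper, split, atol=1e-5))  # True
```

This is the pattern the convolution module uses: a pointwise convolution doubles the channels, and a GLU over the channel dimension halves them again.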

In [None]:
class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        ### YOUR CODE HERE ###
        return None

class GLU(nn.Module):
    def __init__(self, dim: int):
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, x):
        ### START OF YOUR CODE ###
        outputs = None
        gate = None
        return None
        ### END OF YOUR CODE ###

In [None]:
# Test your implementation
x = torch.tensor([3., 1., 4., 1., 5., 9., 2., 6.])

swish = Swish()
output = swish(x)
assert output.shape == torch.Size([8])
assert torch.allclose(
    output,
    torch.tensor([2.8577, 0.7311, 3.9281, 0.7311, 4.9665, 8.9989, 1.7616, 5.9852]),
    rtol=1e-04
)

glu = GLU(0)
output = glu(x)
assert output.shape == torch.Size([4])
assert torch.allclose(
    output,
    torch.tensor([2.9799, 0.9999, 3.5232, 0.9975]),
    rtol=1e-04
)

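Now we connect them to form our Convolution and Feed-Forward modules.

The convolution module is already implemented for you. It takes input of shape (batch, time, channels): `nn.LayerNorm` normalizes over the last (channel) dimension, the `Transpose` helper then moves channels to dimension 1 as the convolutions and `nn.BatchNorm1d` expect, and `forward` transposes the result back. You need to fill in the blanks to complete the feed-forward module, which is depicted here.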

![ff module](https://drive.google.com/uc?export=view&id=1oAYXoUbKSh-Pp7-bIvu4rCpM4WqBXqz_)


In [None]:
class Transpose(nn.Module):
    """Transposes the time and channel dimensions for use in nn.Sequential"""
    def forward(self, x):
        return x.transpose(1, 2)


class ConvModule(nn.Module):
    def __init__(
        self, 
        in_channels: int, 
        kernel_size: int, 
        dropout: float
    ):
        super(ConvModule, self).__init__()

        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        padding = (kernel_size-1) // 2

        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(),
            PointwiseConv(in_channels, in_channels * 2),
            GLU(dim=1),
            DepthwiseConv(in_channels, in_channels, kernel_size, padding=padding),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv(in_channels, in_channels),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.sequential(x).transpose(1, 2)

class FeedForwardModule(nn.Module):
    def __init__(
        self, 
        in_dim: int, 
        hidden_dim: int, 
        out_dim: int, 
        dropout: float
    ):
        super(FeedForwardModule, self).__init__()
        self.sequential = nn.Sequential(
            ### YOUR CODE HERE ###
        )

    def forward(self, x):
        return self.sequential(x)
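
`ConvModule` asserts an odd `kernel_size` because it pads by `(kernel_size - 1) // 2` on each side. With stride 1, a 1-D convolution outputs `T + 2 * pad - kernel_size + 1` frames, which equals `T` only when the kernel size is odd. The sketch below checks this with a depthwise `nn.Conv1d` on an arbitrary input of 50 frames; the kernel sizes are illustrative choices.

```python
import torch
from torch import nn

x = torch.randn(2, 8, 50)  # (batch, channels, time)

lengths = {}
for k in (3, 4, 15, 31):
    pad = (k - 1) // 2  # the padding ConvModule computes
    conv = nn.Conv1d(8, 8, kernel_size=k, padding=pad, groups=8)
    # With stride 1, output length = T + 2 * pad - k + 1
    lengths[k] = conv(x).shape[-1]

print(lengths)  # {3: 50, 4: 49, 15: 50, 31: 50}
```

The even kernel (4) loses one frame, which would break the residual connection around the module in the Conformer block, since the output would no longer match the input length.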

In [None]:
# Test your implementation

_set_seed(4)
ff = FeedForwardModule(2, 3, 2, 0.2)
x = torch.tensor([[9., 8.], [7., 6.], [5., 4.], [3., 2.]])
output = ff(x)
assert output.shape == torch.Size([4, 2])
assert torch.allclose(
    output,
    torch.tensor(
        [[0.0000, 0.3803],
        [1.7393, 0.0000],
        [0.4592, 0.3812],
        [0.5688, 0.3707]]
    ),
    rtol=1e-03
)

## Part 2: Conformer Block

Now we combine everything we've built to make the entire Conformer block!


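Here is an implementation of the multi-head self-attention (MHSA) module, adapted from https://github.com/lucidrains/conformer. It applies LayerNorm to its input and adds Shaw-style relative positional embeddings to the attention scores. You only need to run this cell and use the module in your implementation.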


In [None]:
#@title Attention Module

# Source: https://github.com/lucidrains/conformer
class AttentionModule(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        max_pos_emb = 512
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads= heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.max_pos_emb = max_pos_emb
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, mask = None, context_mask = None):
        x = self.norm(x)

        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # shaw's relative positional embedding
        seq = torch.arange(n, device = device)
        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
        dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = torch.einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        if exists(mask) or exists(context_mask):
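            # `default` returns its fallback as-is (it does not call it), so build all-True boolean masks directly
            mask = default(mask, torch.ones(*x.shape[:2], device = device, dtype = torch.bool))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, torch.ones(*context.shape[:2], device = device, dtype = torch.bool))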
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return self.dropout(out)
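
The relative positional term in `AttentionModule.forward` looks up one learned vector per query/key offset. The sketch below reproduces those steps on a toy sequence, assuming small illustrative sizes (`n = 5`, `max_pos_emb = 2`, `dim_head = 4`) and plain indexing in place of `rearrange`. Offsets larger than `max_pos_emb` in either direction are clamped, so they share a single embedding.

```python
import torch
from torch import nn

n, max_pos_emb, dim_head = 5, 2, 4
seq = torch.arange(n)
# dist[i, j] = i - j: signed distance from query position i to key position j
dist = seq[:, None] - seq[None, :]
# Clamp to [-max_pos_emb, max_pos_emb], then shift into valid indices [0, 2 * max_pos_emb]
idx = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
print(idx)
# tensor([[2, 1, 0, 0, 0],
#         [3, 2, 1, 0, 0],
#         [4, 3, 2, 1, 0],
#         [4, 4, 3, 2, 1],
#         [4, 4, 4, 3, 2]])

rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
emb = rel_pos_emb(idx)              # (n, n, dim_head): one vector per (i, j) pair
q = torch.randn(1, 1, n, dim_head)  # (batch, heads, n, dim_head)
# Score of query i against the embedding of offset (i - j), as in the module
pos_attn = torch.einsum('b h n d, n r d -> b h n r', q, emb)
print(pos_attn.shape)  # torch.Size([1, 1, 5, 5])
```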

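Now, implement the Conformer block. Remember to add a residual connection around each module, and note that, as in the paper, the outputs of the two feed-forward modules are scaled by ½ before being added to the residual.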

![conv block](https://drive.google.com/uc?export=view&id=1a4L1UcxjRdzIrPxnotdWS0zuUts4UHfy)
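For reference, the Conformer paper (Eq. 1) writes the block's forward pass for an input $x_i$ as follows, where $\mathrm{FFN}$, $\mathrm{MHSA}$ and $\mathrm{Conv}$ denote the feed-forward, multi-head self-attention and convolution modules:

$$
\begin{aligned}
\tilde{x}_i &= x_i + \tfrac{1}{2}\,\mathrm{FFN}(x_i) \\
x'_i &= \tilde{x}_i + \mathrm{MHSA}(\tilde{x}_i) \\
x''_i &= x'_i + \mathrm{Conv}(x'_i) \\
y_i &= \mathrm{Layernorm}\big(x''_i + \tfrac{1}{2}\,\mathrm{FFN}(x''_i)\big)
\end{aligned}
$$

In the skeleton, these steps presumably map onto `self.ff1`, `self.attn`, `self.conv`, `self.ff2` and the final `self.norm`, in that order.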

In [None]:
class ConformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head,
        ff_dim,
        kernel_size,
        attn_dropout,
        ff_dropout,
        conv_dropout
    ):
        super(ConformerBlock, self).__init__()

        ### START OF YOUR CODE  ###
        self.ff1 = None
        self.attn = None
        self.conv = None
        self.ff2 = None
        self.norm = None
        ### END OF YOUR CODE ###

    def forward(self, x, mask=None):
        out = x
        ### YOUR CODE HERE  ###
        return out

In [None]:
# Test your implementation

_set_seed(420)
model = ConformerBlock(3, 4, 4, 4, 3, 0.2, 0.3, 0.4)
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]], 
                  [[-1., 0., 1.], [9., 8., 7.]]])
output = model(x)
assert output.shape == torch.Size([2, 2, 3])
assert torch.allclose(
    output,
    torch.tensor(
        [[[ 1.0219, -1.3575,  0.3357],
         [-1.3136,  0.2031,  1.1105]],
        [[-0.8080, -0.6011,  1.4092],
         [ 1.3563, -0.3312, -1.0251]]]
    ),
    rtol=1e-03
)