# Transformer

### Imports

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from einops import rearrange

import pytorch_lightning as pl
pl.seed_everything(42)

Seed set to 42


42

### Multi-head attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_head``
    heads, attended per head, concatenated, and projected back to
    ``d_model``.

    Args:
        n_head: Number of attention heads.
        d_model: Size of the last dimension of the inputs and of the output.
        d_k: Per-head size of the query and key projections.
        d_v: Per-head size of the value projection.
    """

    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k

        self.wq = nn.Linear(d_model, n_head * d_k)
        self.wk = nn.Linear(d_model, n_head * d_k)
        self.wv = nn.Linear(d_model, n_head * d_v)

        # Maps the concatenated heads back to the model dimension.
        self.linear = nn.Linear(n_head * d_v, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape ``(batch, t, d_model)``.
            k: Key input of shape ``(batch, s, d_model)``.
            v: Value input of shape ``(batch, s, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, t, s)``. Positions where ``mask == 0``
                are excluded from attention.

        Returns:
            A tuple ``(out, attn)`` where ``out`` has shape
            ``(batch, t, d_model)`` and ``attn`` holds the attention weights
            with shape ``(batch, n_head, t, s)``.
        """
        q = rearrange(self.wq(q), 'b t (h k) -> b h t k', h=self.n_head)
        k = rearrange(self.wk(k), 'b s (h k) -> b h s k', h=self.n_head)
        v = rearrange(self.wv(v), 'b s (h v) -> b h s v', h=self.n_head)

        # Scale by sqrt(d_k) so the logits do not grow with the head size.
        attn = torch.einsum('bhtk, bhsk -> bhts', q, k) / np.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative *finite* value rather than -inf: masked
            # positions still receive exactly zero weight whenever at least
            # one key is visible, but a row whose keys are all masked yields
            # a uniform distribution instead of NaN (softmax of all -inf),
            # which would otherwise poison the output and the gradients.
            attn = attn.masked_fill(mask == 0, torch.finfo(attn.dtype).min)
        attn = F.softmax(attn, dim=3)

        out = torch.einsum('bhts, bhsv -> bhtv', attn, v)
        out = rearrange(out, 'b h t v -> b t (h v)')
        out = self.linear(out)
        return out, attn

### Encoder block

In [4]:
class EncoderBlock(nn.Module):
    """Single encoder layer: self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input (residual connection), and the sum is layer-normalised.

    Args:
        n_head: Number of attention heads.
        d_model: Feature size of the block input and output.
        d_k: Per-head query/key size passed to the attention module.
        d_v: Per-head value size passed to the attention module.
        d_ff: Hidden size of the feed-forward network.
        dropout_prob: Dropout probability used on both residual branches
            and inside the feed-forward network.
    """

    def __init__(self, n_head, d_model, d_k, d_v, d_ff, dropout_prob=0.0):
        super().__init__()
        self.attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_prob)

        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention."""
        # Self-attention sub-layer (the attention weights are discarded).
        attended = self.attn(x, x, x, mask)[0]
        hidden = self.ln1(x + self.dropout(attended))

        # Position-wise feed-forward sub-layer.
        transformed = self.mlp(hidden)
        return self.ln2(hidden + self.dropout(transformed))


### Transformer encoder

In [15]:
class TransformerEncoder(nn.Module):
    """Stack of ``n_enc_blocks`` identically configured ``EncoderBlock``s.

    Args:
        n_enc_blocks: Number of encoder blocks to stack.
        **enc_block_kwargs: Keyword arguments passed unchanged to every
            ``EncoderBlock`` constructor.
    """

    def __init__(self, n_enc_blocks, **enc_block_kwargs):
        super().__init__()
        layers = [
            EncoderBlock(**enc_block_kwargs) for _ in range(n_enc_blocks)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return hidden

    def get_attn_maps(self, x, mask=None):
        """Return a list with the attention weights of each block, in order.

        Each block's attention is evaluated on that block's input, after
        which the full block is applied to produce the next block's input.
        """
        collected = []
        hidden = x
        for layer in self.blocks:
            # Index 1 of the attention module's output is the weight map.
            collected.append(layer.attn(hidden, hidden, hidden, mask)[1])
            hidden = layer(hidden, mask)
        return collected

### Positional encoding

In [14]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of sequences.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    computed once for ``max_len`` positions and stored as a non-persistent
    buffer named ``pe`` of shape ``(1, max_len, d_model)``, so it follows the
    module across devices but is not written to the ``state_dict``.

    Args:
        d_model: Feature size of the inputs (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even
        # columns, so drop the last frequency for the cosine half. For even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose second axis is the sequence axis and whose last
                axis has size ``d_model``; the sequence length must not
                exceed ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]