In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

# 1. Multi-head Attention Block

A crucial building block of transformers is the Attention mechanism. This step computes the attention between each pair of positions in a sequence and consists of multiple "attention heads", each mapping to a different feature space to capture different aspects of the input sequence.

## 1.1 Scaled Dot Product Attention
Each head of multi-head attention computes the scaled dot-product attention, which can be described as:

$$ \text{SoftMax} \left( \frac{Q}{T_e}K^T\right) V, \quad T_e = \sqrt{d_k}$$

In [13]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention: SoftMax(Q K^T / sqrt(dk)) V

    The module has no learnable parameters; dk is read from the last
    dimension of k at call time.
    """
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

        # Normalise over the key positions (last dimension of the scores)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, q, k, v, mask = None, e = 1e-12):
        """
        q, k, v: [..., sequence length, dk], typically
                 [batch size, number heads, sequence length, dk]
        mask:    optional tensor broadcastable to the score shape
                 [..., query length, key length]; positions where
                 mask == 0 receive no attention
        e:       unused, kept so existing callers keep working

        returns: (attention output, attention weights)
        """
        d_tensor = k.size(-1)

        # 1. First step is to compute QK^T so we need to transpose K.
        #    Negative indices keep this valid for any number of leading dims.
        k_t = k.transpose(-2, -1)
        score = torch.matmul(q, k_t) / math.sqrt(d_tensor)

        # 2. Optional masking. masked_fill is out-of-place, so the result
        #    must be assigned back. The fill value is the most negative
        #    finite number of the dtype: the softmax maps it to 0, and a
        #    fully masked row stays finite (uniform) instead of becoming NaN.
        if mask is not None:
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        # 3. SoftMax the scores to obtain probabilities
        score = self.softmax(score)

        # 4. Weight the values by the attention probabilities
        attention = score @ v

        return attention, score # Score is merely for visualization

## 1.2 Multi-Head Attention

In [14]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: project the inputs, attend independently per
    head, then merge the heads and apply an output projection.

    d_model -> dimensionality of the embedding space
    n_heads -> number of attention heads; each head works on
               dk = d_model // n_heads features (and dv = dk)
    """
    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()

        # split() relies on integer division; fail here with a clear
        # message rather than later inside view()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.attention = ScaledDotProductAttention()

        # Since this model implements dk as d_model // n_heads
        # the output of each projection is dk*n_heads = d_model
        self.wQ = nn.Linear(d_model, d_model, bias = False)
        self.wK = nn.Linear(d_model, d_model, bias = False)
        self.wV = nn.Linear(d_model, d_model, bias = False)
        self.wO = nn.Linear(d_model, d_model, bias = False)

    def split(self, tensor):
        """
        This function is used to split the Q, K, V
        matrices by their respective number of heads

        tensor: [batch_size, sequence_length, dk*n_heads]
        returns: [batch_size, n_heads, sequence_length, dk]
        """
        batch_size, seq_length, d_model = tensor.size()
        dk = d_model // self.n_heads

        tensor = tensor.view(batch_size, seq_length, self.n_heads, dk)
        tensor = tensor.transpose(1, 2)

        return tensor

    def concat(self, tensor):
        """
        Perform the inverse operation of the split function

        tensor: [batch_size, n_heads, sequence_length, dk]
        returns: [batch_size, sequence_length, dk*n_heads]
        """
        batch_size, n_heads, seq_length, dk = tensor.size()
        d_model = dk*n_heads

        # transpose() only swaps strides, leaving the memory layout
        # non-contiguous; view() needs contiguous memory, so the data
        # has to be copied into the new layout first.
        tensor = tensor.transpose(1, 2).contiguous()
        tensor = tensor.view(batch_size, seq_length, d_model)

        return tensor

    def forward(self, q, k, v, mask = None):
        """
        q, k, v: [batch_size, sequence_length, d_model]
        mask:    optional mask forwarded unchanged to the attention module

        returns: [batch_size, sequence_length, d_model]
        """
        # 1. Project the inputs with the weight matrices
        q, k, v = self.wQ(q), self.wK(k), self.wV(v)

        # 2. Split each projection by the number of heads. This must be the
        #    module's own split(), not torch.Tensor.split.
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. Calculate Attention
        out, attn = self.attention(q, k, v, mask = mask)

        # 4. Concatenate and pass to a linear layer
        out = self.concat(out)
        out = self.wO(out)

        # 5. TODO: expose `attn` here to visualize the attention map

        return out

# 2. Feed Forward Layer

This is a straightforward implementation of a two-layer multilayer perceptron (MLP). It is placed right after the multi-head attention block, and its input is the output of the attention block added to a residual connection from the original embeddings.

In [16]:
class FFN(nn.Module):
    """
    Position-wise feed forward network: Linear -> ReLU -> Linear -> Dropout

    d_model  -> dimensionality of the input and output features
    d_hidden -> dimensionality of the hidden layer
    dropout  -> dropout probability applied to the output
    """
    def __init__(self, d_model, d_hidden, dropout = 0.1):
        super(FFN, self).__init__()
        self.layer1 = nn.Linear(d_model, d_hidden)
        self.layer2 = nn.Linear(d_hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: [..., d_model]
        returns: [..., d_model]

        Dropout is applied to the output of the sub-layer, i.e. before the
        caller adds the residual connection. Some implementations place it
        after the activation instead; the placement is unchanged here.
        """
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = self.dropout(x)

        return x

## 2.1 Implement Layer Normalization

In [None]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learnable
    per-feature scale (gamma) and shift (beta).

    d_model -> size of the last dimension of the input
    eps     -> small constant added to the variance for numerical stability
    """
    def __init__(self, d_model, eps = 1e-12):
        super(LayerNorm, self).__init__()
        # Scale starts at one and shift at zero, so the layer begins as a
        # plain normalization and learns any rescaling during training.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """
        x: [..., d_model]
        returns: tensor of the same shape, normalized over the last dimension
        """
        # Normalize last dimension. The biased (population) variance is
        # used, as in torch.nn.LayerNorm; it also stays finite when
        # d_model == 1, where the unbiased estimate would be NaN.
        mean = x.mean(-1, keepdim = True)
        var = x.var(-1, unbiased = False, keepdim = True)

        out = (x - mean) / torch.sqrt(var + self.eps)

        # Identity at initialization (gamma = 1, beta = 0); both are
        # learnable parameters and change during training.
        out = self.gamma * out + self.beta

        return out

# 3. Encoder/Decoders

All the layers necessary for the encoder and decoder blocks have been implemented. It is now a matter of assembling these blocks.

## 3.1 Encoder

Starting with the encoder, this piece takes the embedded sequence as input and passes it through multi-head attention, followed by a feed forward layer. Each of these two layers is wrapped in a residual connection.