A from-scratch implementation of the Transformer architecture from the original paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

import math

In [2]:
def scaled_dot_product(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: "torch.Tensor | None" = None,
):
    r"""Compute scaled dot-product attention as defined in the original
    transformer paper: ``softmax(Q K^T / sqrt(d_k)) V``.

    Parameters
    ----------
    query : torch.Tensor
        A tensor of shape (..., num_queries, d_k).
    key : torch.Tensor
        A tensor of shape (..., num_keys, d_k).
    value : torch.Tensor
        A tensor of shape (..., num_keys, d_v).
    mask : torch.Tensor, optional
        A tensor broadcastable to (..., num_queries, num_keys). Positions where
        ``mask == 0`` receive zero attention weight. Defaults to no masking.
        A query row whose keys are all masked yields NaN.

    Returns
    -------
    torch.Tensor
        A tensor of shape (..., num_queries, d_v).

    Notes
    -----
    num_queries may differ from num_keys (e.g. cross-attention), but key and
    value must contain the same number of positions. Leading dimensions
    (batch, heads, ...) are broadcast by ``torch.matmul``, so both 3-D and
    4-D inputs are accepted.
    """
    # The scores are scaled by sqrt(d_k) so that their variance does not grow
    # with the feature size, which would push softmax into saturation.
    d_k = query.size(-1)

    # (..., num_queries, d_k) @ (..., d_k, num_keys) -> (..., num_queries, num_keys)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # -inf becomes an exact zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = F.softmax(scores, dim=-1)

    return torch.matmul(weights, value)

In [3]:
class AttentionHead(nn.Module):
    r"""One self-attention head.

    Three independent linear layers project a hidden state of shape
    [batch_size, seq_len, embed_dim] into query, key and value tensors of
    shape [batch_size, seq_len, head_dim]. These are then combined with
    ``scaled_dot_product``, giving an output of shape
    [batch_size, seq_len, head_dim].
    """

    def __init__(self, embed_dim: int, head_dim: int):
        super().__init__()

        # Learned projections for queries, keys and values (in that order).
        self.w_q = nn.Linear(embed_dim, head_dim)
        self.w_k = nn.Linear(embed_dim, head_dim)
        self.w_v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state: torch.Tensor):
        # Self-attention: queries, keys and values all come from the same input.
        q = self.w_q(hidden_state)
        k = self.w_k(hidden_state)
        v = self.w_v(hidden_state)
        return scaled_dot_product(query=q, key=k, value=v)

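`head_dim` does not have to be smaller than `embed_dim`, but in practice it is chosen so that `embed_dim` is a multiple of it (`head_dim = embed_dim // num_heads`).
The concatenated output of all heads then has exactly `embed_dim` features, and the total computation stays constant regardless of the number of heads.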

In [9]:
class MultiHeadAttention(nn.Module):
    r"""Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each of the ``config.num_attention_heads`` heads projects the input to
    ``head_dim = config.hidden_size // config.num_attention_heads`` features.
    The head outputs are concatenated back to ``hidden_size`` features and
    mixed by a final linear layer.

    Raises
    ------
    ValueError
        If ``num_attention_heads`` is not positive, or does not evenly divide
        ``hidden_size``. In that case the concatenated heads would not match
        the input size of ``output_linear``.
    """

    def __init__(self, config):
        super().__init__()

        embed_dim: int = config.hidden_size
        num_heads: int = config.num_attention_heads
        # Fail early with a clear message instead of a shape mismatch
        # (or ZeroDivisionError) deep inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be divisible by a positive "
                f"num_attention_heads ({num_heads})"
            )
        head_dim: int = embed_dim // num_heads

        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state: torch.Tensor):
        # Each head produces a tensor of shape [batch_size, seq_len, head_dim].
        # Because num_heads * head_dim == embed_dim (checked in __init__),
        # concatenating on the last dim yields [batch_size, seq_len, embed_dim],
        # which matches the input size of self.output_linear.
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)

        # Mix information across heads; the output keeps the embed_dim size.
        x = self.output_linear(x)

        return x

In [24]:
class FeedForward(nn.Module):
    r"""Position-wise feed-forward sub-layer.

    The layer works on the last dimension of its input:

    1. Expand it from ``config.hidden_size`` to ``config.intermediate_size``.
    2. Apply GELU.
    3. Project it back to ``config.hidden_size``.
    4. Apply dropout with probability ``config.hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()

        hidden, inner = config.hidden_size, config.intermediate_size
        # Kept as a single Sequential named ``linear`` so parameter names
        # (linear.0.*, linear.2.*) stay stable for saved checkpoints.
        self.linear = nn.Sequential(
            *[nn.Linear(hidden, inner), nn.GELU(), nn.Linear(inner, hidden)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor):
        return self.dropout(self.linear(x))

In [28]:
class TransformerEncoderLayer(nn.Module):
    r"""A single encoder layer using pre-layer normalization.

    Each sub-layer reads a layer-normalized copy of its input. Its output is
    then added back to the un-normalized input through a residual connection::

        x = x + attention(layer_norm_1(x))
        x = x + feed_forward(layer_norm_2(x))

    The original paper instead normalizes after the residual addition
    (post-LN).
    """

    def __init__(self, config):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x: torch.Tensor):
        # Self-attention sub-layer with residual connection.
        attended = self.attention(self.layer_norm_1(x))
        x = x + attended

        # Feed-forward sub-layer with residual connection.
        transformed = self.feed_forward(self.layer_norm_2(x))
        return x + transformed

In [29]:
class TransformerEmbedding(nn.Module):
    r"""Token embeddings plus learned absolute position embeddings.

    The processing steps are:

    1. Sum the token embedding and the position embedding.
    2. Apply layer normalization.
    3. Apply dropout.

    Positions are ``0 .. seq_len - 1``, so ``seq_len`` must not exceed
    ``config.max_position_embeddings``.
    """

    def __init__(self, config):
        super().__init__()

        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size)
        # Use the same dropout rate as the rest of the model. Configs without
        # ``hidden_dropout_prob`` keep nn.Dropout's default of 0.5.
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.5))

    def forward(self, input_ids: torch.Tensor):
        # input_ids is presumably (batch_size, seq_len) token ids; dim 1 is
        # treated as the sequence dimension.
        seq_len = input_ids.size(1)
        # Create the position ids on the input's device so this also works
        # when the module and inputs live on a GPU.
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)

        # token + position embeddings; the (1, seq_len, hidden) position term
        # broadcasts over the batch dimension.
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Final embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

In [32]:
class TransformerEncoder(nn.Module):
    r"""Embedding layer followed by a stack of encoder layers.

    Input ids are embedded with ``TransformerEmbedding``. The result is then
    passed through ``config.num_hidden_layers`` ``TransformerEncoderLayer``
    modules in order, and the last layer's output is returned.
    """

    def __init__(self, config):
        super().__init__()

        self.embeddings = TransformerEmbedding(config)
        stack = [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor):
        hidden = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden