# Transformer using PyTorch

Replicating the paper "Attention Is All You Need" from scratch using PyTorch.

Available here: https://arxiv.org/abs/1706.03762

In [2]:
import torch
import torch.nn as nn

import math


## Input Embeddings
In NLP, words/tokens are represented as integers (token IDs). Because these IDs are arbitrary indices with no inherent numeric meaning, a neural network cannot learn from them directly, so they are first converted into continuous dense vectors via embeddings.

The embedding layer:
- Maps each token ID to a learnable vector of dimension `d_model`
- Allows the model to learn semantic meaning - similar words/tokens get similar embeddings during training

Refer to section 3.4 of the paper.

In [None]:
class InputEmbeddings(nn.Module):
    """Token-embedding layer that looks up and rescales learnable vectors.

    Each token ID is mapped to a learnable vector of size ``d_model`` and the
    result is multiplied by ``sqrt(d_model)`` (section 3.4 of the paper).
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Build the lookup table.

        Args:
            d_model: Size of each embedding vector.
            vocab_size: Number of distinct token IDs in the vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Lookup table with one learnable row of size d_model per token ID.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token IDs and scale the vectors by ``sqrt(d_model)``.

        Args:
            x: Tensor of token indices, shape (batch_size, seq_len).

        Returns:
            Scaled embeddings, shape (batch_size, seq_len, d_model).
        """
        # Section 3.4 of the paper: embeddings are multiplied by sqrt(d_model).
        # This is primarily done to balance the scale of the embeddings against
        # the positional encodings, and helps stabilize training early on.
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(x)
        return token_vectors * scale


## Positional Encoding
Transformers process tokens in parallel, with no built-in sense of order. But language is sequential:

"The cat sat on the mat" =/= "The mat cat on the sat"

Positional encodings inject information about each token's position into the input, making the model aware of token order. They have the same dimension `d_model` as the input embeddings, so the two can be added together.

Refer to section 3.5 of the paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        #Create matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        