# Transformers From Scratch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import math
import scipy
import numpy as np
from collections import Counter, defaultdict
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns



## 1. Byte Pair Encoding (BPE)

Byte Pair Encoding (BPE) is a widely used method for constructing efficient vocabularies in natural language processing models. It addresses several challenges, such as the need for compact vocabularies, handling rare or unknown words, and maintaining semantic coherence.

The key idea behind BPE is to break words down into characters initially and then iteratively merge the most frequently occurring pairs of characters or subwords. This process continues until a predefined vocabulary size is reached. The result is a set of subword units that can effectively represent both common and rare words.

For example, instead of storing "unhappiness" as a single word, BPE might break it into "un," "happi," and "ness." This allows the model to recombine these subword units to understand or generate related words like "happiness" or "unhappy."

BPE has been instrumental in enabling models like GPT to train efficiently on massive datasets while maintaining a vocabulary size (e.g., 50,257 tokens for GPT-2) that balances computational efficiency and linguistic expressiveness.

In [None]:
import re
from collections import Counter

# Build vocabulary by separating characters in words and adding an end-of-token marker.
def build_vocab(corpus):
    """
    Build the initial BPE vocabulary from a whitespace-separated corpus.

    Each word becomes a key made of its characters separated by single
    spaces and terminated by the end-of-word marker "</w>"; the value is
    the number of times the word occurs in the corpus.

    Args:
        corpus (str): Raw text; words are split on whitespace.

    Returns:
        Counter: Maps space-separated symbol sequences to word frequencies.
    """
    vocab = Counter()
    for word in corpus.split():
        # Joining a string with " " puts a space between its characters.
        vocab[" ".join(word) + " </w>"] += 1
    return vocab

# Get counts of pairs of consecutive symbols in the vocabulary.
def get_stats(vocab):
    """
    Count how often each pair of adjacent symbols occurs in the vocabulary.

    Args:
        vocab (Mapping[str, int]): Space-separated symbol sequences mapped
            to the frequency of the word they spell.

    Returns:
        Counter: Maps (left_symbol, right_symbol) tuples to the total
        frequency of that adjacent pair, weighted by word frequency.
    """
    pair_counts = Counter()
    for entry, count in vocab.items():
        parts = entry.split()
        # Pair every symbol with its right-hand neighbour.
        for left, right in zip(parts, parts[1:]):
            pair_counts[left, right] += count
    return pair_counts

# Merge the most frequent pairs in the vocabulary.
def merge_vocab(pair, vocab):
    """
    Apply one BPE merge: fuse every occurrence of `pair` into a single symbol.

    Args:
        pair (tuple[str, str]): The two adjacent symbols to merge.
        vocab (Mapping[str, int]): Space-separated symbol sequences mapped
            to word frequencies.

    Returns:
        Counter: A new vocabulary in which the pair appears as one symbol;
        frequencies are carried over unchanged.
    """
    separated = re.escape(" ".join(pair))
    # The look-arounds require whitespace (or a string edge) on both sides,
    # so only whole symbols are merged, never the tail/head of longer ones.
    whole_bigram = re.compile(rf"(?<!\S){separated}(?!\S)")
    fused = "".join(pair)

    merged_vocab = Counter()
    for entry, count in vocab.items():
        merged_vocab[whole_bigram.sub(fused, entry)] = count
    return merged_vocab

# Tokenize text into subword units based on sorted tokens and input mappings.
def tokenize(
    text,
    sorted_tokens,
    input_id_map,
    return_strings=False,
    max_length=32,
    unknown_token="</u>",
    pad_token="</p>",
):
    """
    Tokenize text into subword units using a priority-ordered token list.

    Every whitespace-separated word gets the end-of-word marker "</w>"
    appended. The text is then segmented greedily: the first token in
    `sorted_tokens` that occurs in a fragment is cut out at each of its
    (left-to-right, non-overlapping) occurrences, and the pieces in between
    are segmented with the tokens that follow it in the list.

    Text that no token can cover is reported as `unknown_token` (one per
    whitespace-separated chunk) instead of being silently dropped, and the
    whitespace between words never produces a token.

    Args:
        text (str): Input text.
        sorted_tokens (Sequence[str]): Tokens in matching priority order
            (highest priority first).
        input_id_map (Mapping[str, int]): Token string to integer ID.
        return_strings (bool): If True, return the list of token strings.
        max_length (int): Fixed length of the returned ID/mask lists;
            longer token sequences are truncated.
        unknown_token (str): Token emitted for text no token matches.
        pad_token (str): Token whose ID fills unused positions. Falls back
            to ID 1 when it is missing from `input_id_map`.

    Returns:
        list[str]: Token strings, when `return_strings` is True.
        tuple[list[int], list[int]]: Otherwise `(input_ids, attention_mask)`,
        both of length `max_length`; the mask is 1 on real tokens and 0 on
        padding.
    """
    # Append the end-of-word marker "</w>" to each word in the input text.
    text = " ".join(word + "</w>" for word in text.split())

    def segment(fragment, first):
        """Segment `fragment` using only sorted_tokens[first:]."""
        if not fragment:
            return []

        # Walk the list by index rather than slicing it on every call.
        for index in range(first, len(sorted_tokens)):
            token = sorted_tokens[index]
            if not token or token not in fragment:
                continue

            # str.split cuts at the same leftmost, non-overlapping
            # occurrences a literal regex search would find.
            pieces = fragment.split(token)
            tokens = segment(pieces[0], index + 1)
            for piece in pieces[1:]:
                tokens.append(token)
                tokens += segment(piece, index + 1)
            return tokens

        # Nothing matched: one unknown token per chunk of visible text.
        # Pure whitespace (the separator between words) yields no token.
        return [unknown_token] * len(fragment.split())

    tokenized = segment(text, 0)

    if return_strings:
        return tokenized

    pad_id = input_id_map.get(pad_token, 1)
    unknown_id = input_id_map.get(unknown_token, 0)

    input_ids = [pad_id] * max_length
    attention_mask = [0] * max_length
    for position, token in enumerate(tokenized[:max_length]):
        input_ids[position] = input_id_map.get(token, unknown_id)
        attention_mask[position] = 1

    return input_ids, attention_mask


# corpus = "this is a sample corpus"
# vocab = build_vocab(corpus)
# print(vocab)
# sorted_tokens = ["a", "sample", "is", "this", " </w>"]
# input_id_map = {token: i for i, token in enumerate(sorted_tokens)}
# print(tokenize("this is a sample", sorted_tokens, input_id_map))




In [None]:
def get_tokens_from_vocab(vocab):
    """
    Collect token frequencies and per-word tokenizations from a vocabulary.

    Args:
        vocab (Mapping[str, int]): Space-separated token sequences mapped to
            the frequency of the word they spell.

    Returns:
        tuple:
            - Counter: Frequency of each individual token, weighted by the
              frequency of the words it appears in.
            - dict: Maps each word (its tokens concatenated without spaces)
              to the list of tokens it is split into.
    """
    token_counts = Counter()
    segmentation = {}
    for entry, count in vocab.items():
        pieces = entry.split()
        segmentation["".join(pieces)] = pieces
        for piece in pieces:
            token_counts[piece] += count
    return token_counts, segmentation

# Helper used to rank tokens by length when building the token/ID mappings below.
def measure_token_length(token):
    """
    Return the length of a token, counting a trailing "</w>" as one symbol.

    Args:
        token (str): Input token.

    Returns:
        int: Number of characters in the token, where a terminating
        end-of-word marker "</w>" contributes 1 instead of 4.
    """
    if token.endswith('</w>'):
        # Four marker characters collapse into a single symbol.
        return len(token) - 3
    return len(token)

# Sort tokens by their length (adjusted for end markers) and frequency in descending order.
# Rank tokens longest first (length adjusted for the end marker); among
# tokens of equal length the more frequent one comes first.
sorted_tokens = sorted(
    tokens_frequencies,
    key=lambda token: (measure_token_length(token), tokens_frequencies[token]),
    reverse=True,
)

# Special tokens for unknown text and for padding.
UNK_TOKEN = '</u>'
PAD_TOKEN = '</p>'

# IDs 0 and 1 are reserved for the special tokens; regular tokens follow
# in ranking order starting at 2.
token2id = {UNK_TOKEN: 0, PAD_TOKEN: 1}
for i, token in enumerate(sorted_tokens):
    token2id[token] = i + 2

# Inverse lookup: integer ID back to token string.
id2token = dict(zip(token2id.values(), token2id.keys()))

# Embedding initialization:
# To enable models to learn semantic meaning dynamically, we map tokens to random vectors.
# The dimension of these vectors can vary based on the application's requirements.
# Larger dimensions encode more information but require greater computational resources.

# Initialize embeddings as random vectors.
# embedding_dim = 128
# embeddings = {token: np.random.rand(embedding_dim) for token in token2id}


In [None]:
emb_dim = 32
# One learnable emb_dim-sized vector per vocabulary entry.
token_embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
token_embeddings.weight.shape  # vocab_size x embedding_dim

torch.Size([3815, 32])

## 2. Positional Encoding


# Positional Encoding

Positional encoding is a method used in Transformer models to encode the position of tokens in a sequence, enabling the model to understand their order. Unlike RNNs, which inherently capture sequential information through recurrence, Transformers rely on positional encoding since they lack this structural feature.

---

## Why Positional Encoding?
- Transformers process tokens in parallel and do not have built-in sequential mechanisms.
- Positional encoding adds information about the order of tokens directly into the input embeddings.

---

## Key Properties of Positional Encoding
1. **Uniqueness**: Each position has a distinct encoding.
2. **Consistency**: Relative distances between positions remain consistent across sequences of varying lengths.
3. **Generality**: The encoding generalizes well to longer sequences.
4. **Determinism**: The process is fixed and reproducible.

---

## Mathematical Definition

The positional encoding is represented as a $d$-dimensional vector added to the token embeddings. It is defined as:

$$
PE_{(pos, i)} =
\begin{cases}
\sin\left(\frac{pos}{10000^{2\lfloor i/2 \rfloor / d_{model}}}\right) & \text{if } i \bmod 2 = 0 \\
\cos\left(\frac{pos}{10000^{2\lfloor i/2 \rfloor / d_{model}}}\right) & \text{if } i \bmod 2 \neq 0
\end{cases}
$$

Where:
- $pos$: Position of the token in the sequence.
- $i$: Dimension index of the embedding.
- $d_{model}$: Dimensionality of the model.

---

## Why Sine and Cosine?
- The periodic nature allows encodings to generalize to unseen positions.
- Both sine and cosine functions create unique yet consistent encodings.
- These functions enable the model to compute relative distances between positions efficiently.

---

## How It's Used
- The positional encoding vector is added to the token embedding vector for each position.
- This enriched representation is fed into the Transformer layers, allowing the model to process both content and positional information.

By using positional encoding, Transformers effectively handle the sequential nature of language, enabling them to excel in tasks like translation, summarization, and more.


Let's create a matrix of shape `[SeqLen, HiddenDim]` holding the positional encodings for inputs of up to `max_len` tokens.

In [None]:
def position_encode(max_len, emb_dim):
    """
    Generate sinusoidal positional encodings for sequences.

    Even columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k),
    where w_k = 10000 ** (-2k / emb_dim) is shared by columns 2k and 2k + 1.
    Both even and odd values of `emb_dim` are supported; with an odd
    `emb_dim` the last (even) column has a sine but no matching cosine.

    Args:
        max_len (int): Maximum sequence length.
        emb_dim (int): Embedding dimension.

    Returns:
        torch.Tensor: A tensor of size (max_len, emb_dim) containing positional encodings.
    """
    # Column vector of position indices [0, 1, ..., max_len-1].
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)

    # One frequency per sine/cosine pair: ceil(emb_dim / 2) entries.
    div_term = torch.exp(-math.log(10000.0) * torch.arange(0, emb_dim, 2).float() / emb_dim)
    angles = position * div_term  # (max_len, ceil(emb_dim / 2))

    pe = torch.zeros(max_len, emb_dim)
    pe[:, 0::2] = torch.sin(angles)
    # There are only emb_dim // 2 odd columns, so drop the surplus
    # frequency when emb_dim is odd (a no-op for even emb_dim).
    pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])

    return pe


# Example: encodings for 64 positions with a 32-dimensional embedding.
encoded_positions = position_encode(max_len=64, emb_dim=32)
print(encoded_positions.shape)


torch.Size([64, 32])

In [None]:
class PositionalEncoding(nn.Module):
    """Add position information (fixed sinusoidal or learned) to embeddings."""

    def __init__(self, emb_dim, max_len, sinusoidal=True):
        """
        Set up either a fixed sinusoidal table or a learned position embedding.

        Args:
            emb_dim (int): Hidden dimensionality of the input.
            max_len (int): Maximum length of a sequence to expect.
            sinusoidal (bool): True for fixed sinusoidal encodings, False for
                a learned `nn.Embedding` over position indices.
        """
        super().__init__()
        self.sinusoidal = sinusoidal

        if not sinusoidal:
            # Learned variant: position indices are a buffer, the vectors
            # themselves are trainable embedding weights.
            indices = torch.arange(0, max_len, dtype=torch.long)
            self.register_buffer('pos', indices.unsqueeze(0))
            self.pe = nn.Embedding(max_len, emb_dim)
            return

        # Fixed variant: a (1, max_len, emb_dim) table kept as a
        # non-persistent buffer, so it is neither trained nor saved in the
        # state dict.
        table = position_encode(max_len, emb_dim)
        self.register_buffer('pe', table.unsqueeze(0), persistent=False)

    def forward(self, x):
        """
        Add positional encodings to the input.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, emb_dim).

        Returns:
            torch.Tensor: `x` plus the encodings of its first seq_len positions.
        """
        seq_len = x.size(1)
        if not self.sinusoidal:
            return x + self.pe(self.pos[:, :seq_len])
        return x + self.pe[:, :seq_len]

## 3. Self Attention

# Attention Mechanism

The attention mechanism answers the question:  
**How relevant is each element in a sequence to the others?**

It allows models to dynamically focus on the most relevant parts of the input when generating output, rather than relying on a single global representation.

---

## Key Components

Each element in the sequence is represented by three vectors: **Query (Q)**, **Key (K)**, and **Value (V)**.

- **Query (Q):** Describes what a token is looking for in other tokens (what it wants to attend to).  
- **Key (K):** Represents what a token offers and when it might be important with respect to the query.  
- **Value (V):** Contains the actual information of the token, used to compute the output.

---

## Attention Computation

The attention mechanism computes a weighted combination of the values $V$, where the weights are based on the similarity between the query $Q$ and the keys $K$. This is defined as:

$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$

### Breakdown:
1. $Q K^T$: Calculates the similarity between queries and keys.  
2. $\sqrt{d_k}$: Scales the dot product to avoid large values that hinder training.  
3. $\operatorname{softmax}$: Converts similarity scores into probabilities.  
4. Weighted sum of $V$: Produces the output representation for each token.

---

## Why Attention Matters
- **Dynamic Focus:** The model identifies and emphasizes the most relevant parts of the input.  
- **Context Sensitivity:** Enables better understanding of relationships within the sequence.  
- **Scalability:** Forms the foundation of modern architectures like Transformers, facilitating parallel processing and high efficiency.

The attention mechanism is key to enabling models to process sequences effectively and has revolutionized tasks in NLP, vision, and beyond.


In [None]:


# Scaled dot-product attention function
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute the scaled dot-product attention.

    Args:
        q (torch.Tensor): Query tensor of shape (..., seq_len_q, depth).
        k (torch.Tensor): Key tensor of shape (..., seq_len_k, depth).
        v (torch.Tensor): Value tensor of shape (..., seq_len_k, depth_v).
        mask (torch.Tensor, optional): Mask broadcastable to
            (..., seq_len_q, seq_len_k). Positions where the mask equals 0
            are excluded from attention. Defaults to None.

    Returns:
        torch.Tensor: Output values after applying attention.
        torch.Tensor: Attention weights (each row sums to 1). A row whose
        positions are all masked receives uniform weights.
    """
    # Similarity of every query with every key, scaled by sqrt(depth).
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    if mask is not None:
        # Use the most negative value the logits' dtype can hold. A fixed
        # constant such as -9e15 is not representable in float16 and makes
        # masked_fill fail under half precision.
        fill_value = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)

    # Normalise over the key dimension, then mix the values.
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)

    return values, attention


In [None]:
# Run attention on example tensors; entries where the mask is 0 are excluded from attention.
# NOTE(review): q, k, v and attention_mask are not defined in any visible cell —
# presumably they were created in a cell that is missing from this export; confirm.
values, attention = scaled_dot_product_attention(q, k, v, mask=torch.Tensor(attention_mask).long())
values.size(), attention.size()

(torch.Size([64, 64]), torch.Size([64, 64]))

### Multi-head Attention

# Multi-Head Attention

**Multi-Head Attention** is an extension of the attention mechanism that allows a model to focus on different parts of the sequence simultaneously. It is a critical component of Transformer architectures.

---

## Why Multi-Head Attention?
1. **Diverse Representations:** A single attention mechanism might focus on only one aspect of the relationships in a sequence. Multi-head attention enables the model to capture different types of relationships in parallel.
2. **Improved Learning Capacity:** By using multiple "heads," the model can learn to attend to different positions and features in the input sequence.
3. **Efficient Representation:** Combines outputs from multiple attention mechanisms into a richer and more expressive representation.

---

## Formula
The output of multi-head attention is:
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) W^O
$$

Where:
- $\text{head}_i = \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
- $W_i^Q, W_i^K, W_i^V$: Projection matrices for the $i$-th head.
- $W^O$: Projection matrix to combine all heads.

---

## Key Benefits
1. **Parallelism:** Multiple heads work in parallel, improving efficiency.
2. **Enhanced Context:** Each head focuses on different relationships in the sequence.
3. **Flexibility:** Allows the model to capture nuanced patterns in data.

---

## Applications
- Multi-head attention is a core part of the **Transformer architecture** and is widely used in tasks like:
  - Machine Translation
  - Text Summarization
  - Question Answering
  - Vision Transformers (ViT) in Computer Vision

Multi-head attention enhances the model's ability to process complex sequences and is a cornerstone of modern deep learning architectures.


The following is the full implementation of multi-head attention as a PyTorch module.

In [None]:

# Multi-head attention module
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a built-in causal mask fallback."""

    def __init__(self, input_dim, emb_dim, num_heads, max_len):
        """
        Create the projections and the causal mask buffer.

        Args:
            input_dim (int): Dimensionality of input features.
            emb_dim (int): Total dimensionality across all heads; must be
                divisible by `num_heads`.
            num_heads (int): Number of attention heads.
            max_len (int): Maximum sequence length (size of the causal mask).
        """
        super().__init__()
        assert emb_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        # A single linear layer produces Q, K and V for all heads at once.
        self.qkv_proj = nn.Linear(input_dim, 3 * emb_dim)
        self.o_proj = nn.Linear(emb_dim, emb_dim)

        # Lower-triangular mask: position t may attend to positions <= t.
        lower_triangle = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("causal_mask", lower_triangle.view(1, 1, max_len, max_len))
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize both projection weights and zero their biases."""
        for layer in (self.qkv_proj, self.o_proj):
            nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        """
        Apply multi-head self-attention to `x`.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, input_dim).
            mask (torch.Tensor, optional): Per-key mask holding
                batch_size * seq_length entries; it is viewed as
                (batch_size, 1, 1, seq_length) and shared by all heads and
                query positions. When omitted, the causal mask is used
                instead. Defaults to None.
            return_attention (bool, optional): Whether to also return the
                attention weights. Defaults to False.

        Returns:
            torch.Tensor: Output of shape (batch_size, seq_length, emb_dim).
            torch.Tensor (optional): Attention weights (if return_attention is True).
        """
        batch_size, seq_length, _ = x.size()

        # Project, then split the last axis per head: each head owns a
        # contiguous [q | k | v] slice of width 3 * head_dim.
        per_head = self.qkv_proj(x).reshape(
            batch_size, seq_length, self.num_heads, 3 * self.head_dim
        )
        q, k, v = per_head.permute(0, 2, 1, 3).chunk(3, dim=-1)  # each [Batch, Head, SeqLen, HeadDim]

        if mask is not None:
            # Broadcast the caller's mask over heads and query positions.
            mask = mask.view(batch_size, 1, 1, seq_length).expand(-1, self.num_heads, seq_length, -1)
        else:
            # No mask given: fall back to causal (left-to-right) attention.
            mask = self.causal_mask[:, :, :seq_length, :seq_length].expand(
                batch_size, self.num_heads, -1, -1
            )

        values, attention = scaled_dot_product_attention(q, k, v, mask=mask)

        # [Batch, Head, SeqLen, HeadDim] -> [Batch, SeqLen, Head * HeadDim]
        merged = values.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.emb_dim)
        o = self.o_proj(merged)

        return (o, attention) if return_attention else o


## 4. Transformer Block

# Transformer Block

The **Transformer Block** is the core unit of the Transformer architecture, designed to process sequences by focusing on token relationships without recurrence or convolution.

---

## Components

1. **Multi-Head Attention**  
   - Allows the model to focus on different parts of the sequence simultaneously.  
   - Captures diverse relationships and patterns across tokens.

2. **Feed-Forward Network**  
   - Enhances representations through non-linear transformations.  
   - Applied independently to each token.

3. **Residual Connections and Layer Normalization**  
   - Stabilize training and preserve original information.  
   - Residual connections add inputs back to outputs; normalization improves convergence.

---

## Workflow
1. **Input Embeddings with Positional Encoding**  
2. **Multi-Head Attention**  
3. **Add & Normalize**  
4. **Feed-Forward Network**  
5. **Add & Normalize**  
6. **Output**: Refined token representations.

---

## Advantages
- **Parallel Processing:** Faster than RNNs.  
- **Long-Range Dependencies:** Captures relationships between distant tokens.  
- **Scalability:** Adapts to various tasks by stacking multiple blocks.

---

## Applications
Used in models like **BERT**, **GPT**, and **Vision Transformers**, enabling state-of-the-art performance across NLP and vision tasks.


### Layer Normalization

In [None]:

# Layer normalization function
def layer_normalization(x, gamma, beta):
    """
    Normalize the last dimension of `x`, then scale and shift it.

    Each vector along the last axis is standardized with its own mean and
    (biased) variance before `gamma` and `beta` are applied.

    Args:
        x (torch.Tensor): Input tensor of shape (..., dim).
        gamma (torch.Tensor): Scale parameter, broadcastable to `x`.
        beta (torch.Tensor): Shift parameter, broadcastable to `x`.

    Returns:
        torch.Tensor: Normalized tensor with the same shape as `x`.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    variance = centered.pow(2).mean(dim=-1, keepdim=True)
    # The small epsilon keeps the division finite for constant vectors.
    normalized = centered / torch.sqrt(variance + 1e-9)
    return normalized * gamma + beta

# Transformer block
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention and an MLP, each with a residual."""

    def __init__(self, input_dim, num_heads, model_dim, max_len, dropout=0.0):
        """
        Build the attention layer, the MLP and the two layer norms.

        Args:
            input_dim (int): Dimensionality of the input (and of the output).
            num_heads (int): Number of heads to use in the attention.
            model_dim (int): Dimensionality of the hidden layer in the MLP.
            max_len (int): Maximum length of the input sequence.
            dropout (float): Dropout probability applied to the MLP output.
        """
        super().__init__()

        # Self-attention keeps the feature size at input_dim.
        self.self_attn = MultiheadAttention(
            input_dim=input_dim,
            emb_dim=input_dim,
            num_heads=num_heads,
            max_len=max_len,
        )

        # Position-wise MLP: expand to model_dim, GELU, project back.
        expand = nn.Linear(input_dim, model_dim)
        contract = nn.Linear(model_dim, input_dim)
        self.mlp = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

        # One layer norm in front of each sub-layer.
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """
        Run the block on `x`.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim).
            mask (torch.Tensor, optional): Attention mask forwarded to the
                attention layer. Defaults to None.

        Returns:
            torch.Tensor: Output tensor with the same shape as `x`.
        """
        # Residual connection around (norm -> attention).
        attended = self.self_attn(self.norm1(x), mask=mask)
        x = x + attended
        # Residual connection around (norm -> MLP).
        return x + self.mlp(self.norm2(x))


## 5. Transformer Architectures

We can now combine all of the ingredients i.e. BPE, token embeddings, positional encoding, self-attention and Transformer blocks to create the Transformer architecture. We can stack multiple Transformer blocks to scale the expressive power of our models.

Generally, we can allow the attention layers to attend to the entire sequence i.e. both the previous and the next words. This setting can be considered as an "encoder", as the model is learning to encode the entire sequence into something meaningful.

If we restrict the attention to one direction with masking, the setting is called a "decoder": the model can only look at the past and learns to generate (decode) the most likely next token.

BERT like models use encoder architecture and language models like GPT use decoder architecture. Both of these settings can be combined as well. Machine Translation models generally use encoder-decoder architecture (as in the original 2017 paper).

![image.png](attachment:6e762df2-93b3-4814-bf9b-363760af964a.png)

In [None]:

# Transformer model
class Transformer(nn.Module):
    """Token embedding, positional encoding, a stack of Transformer blocks and a vocabulary head."""

    def __init__(
        self,
        vocab_size,
        max_len,
        emb_dim,
        model_dim,
        num_layers,
        num_heads,
        dropout=0.1,
        position_sinusoidal=True,
    ):
        """
        Build the model.

        Args:
            vocab_size (int): Size of the vocabulary.
            max_len (int): Maximum sequence length.
            emb_dim (int): Dimensionality of token embeddings.
            model_dim (int): Width of the Transformer blocks; token
                embeddings are linearly projected to this size.
            num_layers (int): Number of Transformer blocks.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            position_sinusoidal (bool): Use sinusoidal positional encodings
                (True) or learned position embeddings (False).
        """
        super().__init__()

        # Tokens -> emb_dim vectors -> model_dim vectors.
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.input_projection = nn.Linear(emb_dim, model_dim)
        self.positional_encoding = PositionalEncoding(
            model_dim, max_len, sinusoidal=position_sinusoidal
        )

        # The blocks use model_dim both as feature size and as MLP width.
        self.transformer_blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.transformer_blocks.append(
                TransformerBlock(model_dim, num_heads, model_dim, max_len, dropout)
            )

        # Maps hidden states to one logit per vocabulary entry.
        self.output_projection = nn.Linear(model_dim, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, return_logits=True):
        """
        Run the model on a batch of token IDs.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask passed unchanged
                to every block. Defaults to None.
            return_logits (bool, optional): Return vocabulary logits if
                True, else the final hidden states. Defaults to True.

        Returns:
            torch.Tensor: Logits of shape (batch_size, seq_len, vocab_size)
            or hidden states of shape (batch_size, seq_len, model_dim).
        """
        x = self.positional_encoding(self.input_projection(self.embedding(input_ids)))

        for block in self.transformer_blocks:
            x = block(x, mask=attention_mask)

        if not return_logits:
            return x
        return self.output_projection(x)


In [None]:
# With causal attention, and by adjusting `num_layers`, `num_heads` and `model_dim`,
# most GPT models can be implemented with our code.
# `model_dim` must be divisible by `num_heads` (each head gets model_dim / num_heads features).
gpt2 = dict(num_layers=12, num_heads=12, model_dim=768)  # 120M params
# GPT-3 175B uses a hidden size of 12288 (128 per head); 2048 is its context length.
gpt3 = dict(num_layers=96, num_heads=96, model_dim=12288)  # 175B params
gpt_mini = dict(num_layers=6, num_heads=6, model_dim=192) # 1.2M params

In [None]:
# Build a small GPT-style model: 32-dimensional token embeddings projected up to
# the gpt_mini width (model_dim=192), with 6 layers and 6 heads.
# NOTE(review): vocab_size and max_len are not assigned in any earlier visible cell —
# presumably they come from a cell missing from this export; confirm before running.
model = Transformer(
    vocab_size=vocab_size,
    max_len=max_len,
    emb_dim=32,
    **gpt_mini
)

In [None]:
input_text = 'The fox jumps over the fence'

# Convert the sentence to fixed-length ID and mask lists (padded/truncated to max_len).
input_ids, attention_mask = tokenize(
    input_text,
    sorted_tokens,
    token2id,
    return_strings=False,
    max_length=max_len,
    unknown_token='</u>',
)

# Wrap each list in an outer list to add a batch dimension of 1.
input_ids = torch.tensor([input_ids], dtype=torch.long)
attention_mask = torch.tensor([attention_mask], dtype=torch.long)

In [None]:
# Forward pass; with the default return_logits=True the model returns one logit per
# vocabulary entry for every position, i.e. shape (batch_size, seq_len, vocab_size).
model(input_ids, attention_mask).shape

torch.Size([1, 64, 3815])

## 6. Transformer Training - Language Modeling

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Dataset preparation helper and example dataset class
def prepare_dataset(tokenizer, texts, max_len):
    """
    Tokenize texts and build fixed-length ID and attention-mask tensors.

    Every text is tokenized, truncated to `max_len` tokens and right-padded
    with ID 0. The mask is 1 on real tokens and 0 on padding. The sequence
    returned by `tokenizer` is copied, never modified, so tokenizers that
    return a shared or cached list are safe to use.

    Args:
        tokenizer (Callable[[str], Sequence[int]]): Maps a text to token IDs.
        texts (Iterable[str]): Texts to encode.
        max_len (int): Length of every output row.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: `(input_ids, attention_masks)`,
        both integer tensors of shape (len(texts), max_len).
    """
    tokenized_inputs = []
    attention_masks = []
    for text in texts:
        # Copy before truncating/padding so the tokenizer's own list is untouched.
        tokens = list(tokenizer(text))[:max_len]
        padding = max_len - len(tokens)
        tokenized_inputs.append(tokens + [0] * padding)
        attention_masks.append([1] * len(tokens) + [0] * padding)

    # The explicit dtype and shape keep the result 2-D and integer-typed
    # even when `texts` is empty.
    num_rows = len(tokenized_inputs)
    input_tensor = torch.tensor(tokenized_inputs, dtype=torch.long).reshape(num_rows, max_len)
    mask_tensor = torch.tensor(attention_masks, dtype=torch.long).reshape(num_rows, max_len)
    return input_tensor, mask_tensor

class ExampleDataset(Dataset):
    """Map-style dataset over parallel collections of inputs, masks and labels."""

    def __init__(self, inputs, masks, labels):
        """
        Store the three parallel, index-aligned collections.

        Args:
            inputs: Indexable collection of input samples.
            masks: Indexable collection of attention masks.
            labels: Indexable collection of labels; its length defines the
                size of the dataset.
        """
        self.inputs, self.masks, self.labels = inputs, masks, labels

    def __getitem__(self, idx):
        """Return the `(input, mask, label)` triple stored at position `idx`."""
        sample = self.inputs[idx]
        mask = self.masks[idx]
        label = self.labels[idx]
        return sample, mask, label

    def __len__(self):
        """Return the number of examples (the number of labels)."""
        return len(self.labels)

# Training function
def train_model(model, dataloader, criterion, optimizer, device):
    """
    Train `model` for one pass over `dataloader`.

    Args:
        model (nn.Module): Called as `model(inputs, masks)`; must return
            logits whose last dimension is the number of classes.
        dataloader (Iterable): Yields `(inputs, masks, labels)` batches and
            supports `len()`.
        criterion (Callable): Loss taking flattened logits and flattened labels.
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.
        device (torch.device): Device the batches are moved to.

    Returns:
        float: Mean of the per-batch losses.
    """
    model.train()
    running_loss = 0
    for batch in dataloader:
        inputs, masks, labels = (tensor.to(device) for tensor in batch)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Flatten (batch, seq, classes) logits and (batch, seq) labels so
        # every position counts as one classification example.
        logits = model(inputs, masks)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

# Evaluation function
def evaluate_model(model, dataloader, criterion, device):
    """
    Compute the mean loss of `model` over `dataloader` without updating it.

    Args:
        model (nn.Module): Called as `model(inputs, masks)`; must return
            logits whose last dimension is the number of classes.
        dataloader (Iterable): Yields `(inputs, masks, labels)` batches and
            supports `len()`.
        criterion (Callable): Loss taking flattened logits and flattened labels.
        device (torch.device): Device the batches are moved to.

    Returns:
        float: Mean of the per-batch losses.
    """
    model.eval()
    running_loss = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            inputs, masks, labels = (tensor.to(device) for tensor in batch)
            logits = model(inputs, masks)
            flat_logits = logits.view(-1, logits.size(-1))
            running_loss += criterion(flat_logits, labels.view(-1)).item()

    return running_loss / len(dataloader)

# Hyperparameters and setup
# Hyperparameters and setup.
# NOTE: these assignments rebind the notebook-wide names vocab_size, max_len,
# emb_dim and model used by earlier cells.
vocab_size = 30522
max_len = 128
emb_dim = 256
model_dim = 256
num_layers = 4
num_heads = 8
dropout = 0.1
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy tokenizer and model.
# The tokenizer maps each character to its Unicode code point, which is then used
# directly as an embedding index.
tokenizer = lambda text: [ord(c) for c in text]  # Dummy tokenizer
model = Transformer(
    vocab_size=vocab_size,
    max_len=max_len,
    emb_dim=emb_dim,
    model_dim=model_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dropout=dropout
).to(device)

# Toy data: two sentences repeated 100 times each (200 examples).
texts = ["hello world", "transformers are powerful"] * 100
# Placeholder targets: class 1 at position 0 and class 0 at every other position.
# These are not next-token labels; they only exercise the training loop.
labels = [[1] + [0] * (max_len - 1)] * len(texts)  # Dummy labels
tokenized_inputs, attention_masks = prepare_dataset(tokenizer, texts, max_len)
dataset = ExampleDataset(tokenized_inputs, attention_masks, torch.tensor(labels))
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Define loss function and optimizer.
# The loss is computed over every position of every sequence, padded positions included.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

# Training loop: one full pass over the data per epoch, reporting the mean batch loss.
epochs = 10
for epoch in range(epochs):
    train_loss = train_model(model, dataloader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}")

# Save only the model's parameters (state dict), not the model object itself.
torch.save(model.state_dict(), "transformer_model.pth")


In [None]:
# Inspect one training example.
# NOTE(review): train_dataset is not defined in any visible cell. Judging by the output
# below, y looks like x shifted left by one token (next-token targets) — confirm
# against the cell that builds train_dataset.
x, y, mask = train_dataset[100]
x, y, mask

(tensor([   12,   437, 12782,  ...,  3873, 46664,    11]),
 tensor([  437, 12782,   815,  ..., 46664,    11,   284]),
 tensor([1, 1, 1,  ..., 1, 1, 1]))

## References
[1] https://huggingface.co/blog/transformers
[2] https://paperswithcode.com/method/transformer
[3] https://jalammar.github.io/illustrated-transformer/