Practical Session - Transformers
---

### Goal  
The goal of this session is to implement a standard Transformer in PyTorch from scratch.

### Task  
As a sample problem, we will focus on sorting a list of integers between 1 and 19 (repeated values are allowed).

Example:  
Source Sequence: `[19, 7, 2, 9, 18]`  
Target Output:   `[2, 7, 9, 18, 19]`

### Outline
1. Embed the tokens/numbers into vectors
2. The `Transformer` layer
   - 2.1. Implement the self-attention layer, as seen in class
   - 2.2. Integrate the self-attention layer in a transformer layer
3. Implement a `Transformer` network
4. Train the network!
5. Implement multi-headed attention






---
#### 0. Generating data

In [2]:
import torch
import numpy as np

VOCAB_SIZE = 20  # Tokens are integers from 1 to VOCAB_SIZE - 1
SEQ_LENGTH = 5   # Sequence length

def generate_data(batch_size, seq_length, vocab_size):
    """
    Generates random sequences of integers and their sorted counterparts.

    Args:
        batch_size (int): Number of sequences to generate.
        seq_length (int): Length of each sequence.
        vocab_size (int): Maximum integer value (exclusive) for the sequence elements.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: 
            - src (torch.Tensor): A tensor of shape (batch_size, seq_length) containing 
              random integers in the range [1, vocab_size).
            - tgt (torch.Tensor): A tensor of shape (batch_size, seq_length) containing 
              the sorted version of each sequence in `src`.
    """
    src = torch.randint(1, vocab_size, (batch_size, seq_length))
    tgt = torch.sort(src, dim=1)[0] 
    return src, tgt

source, target = generate_data(1, SEQ_LENGTH, VOCAB_SIZE)
print("Example source sequence and target:")
print("-" * 50)
print(f"Input Sequence:         {source.tolist()[0]}")
print(f"Expected Sorted Output: {target.tolist()[0]}")

Example source sequence and target:
--------------------------------------------------
Input Sequence:         [2, 1, 19, 17, 16]
Expected Sorted Output: [1, 2, 16, 17, 19]


---
### Step 1: Embed the tokens into vectors

The first step is to map each input integer to a vector of fixed dimension `d` (`embed_dim` in the code).
##### How to do this:  
1. **Token Embeddings**: Each input token (an integer index) is mapped to a learnable vector of dimension `d` using [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html).  
2. **Positional Encoding**: Self-attention by itself ignores the order of the tokens, so position information must be added. Instead of the classical sine-cosine positional encodings, we simply learn one vector for each position in the sequence, again using `torch.nn.Embedding`.
3. **Summation**: The final embedding of each token is the sum of its token embedding and its positional encoding.  
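For comparison with the learnable positions used in this session, the classical sine-cosine encoding mentioned in point 2 can be sketched as below. This is a minimal sketch of the formula from the original Transformer paper. It assumes an even `embed_dim`, and the helper name `sinusoidal_positional_encoding` is hypothetical (ours, not part of the session code). Unlike `nn.Embedding`, it has no trainable parameters. Its output would simply be added to the token embeddings.

```python
import math
import torch

def sinusoidal_positional_encoding(seq_length, embed_dim):
    """Hypothetical helper: classical (non-learned) sine-cosine positional encodings.

    Returns a tensor of shape (seq_length, embed_dim). Assumes embed_dim is even.
    """
    positions = torch.arange(seq_length, dtype=torch.float32).unsqueeze(1)  # (seq_length, 1)
    # Frequencies 1 / 10000^(2i / embed_dim) for i = 0, ..., embed_dim / 2 - 1
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_dim)
    )
    pe = torch.zeros(seq_length, embed_dim)
    pe[:, 0::2] = torch.sin(positions * div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(positions * div_term)  # odd dimensions use cosine
    return pe

pe = sinusoidal_positional_encoding(seq_length=5, embed_dim=4)
print(pe.shape)  # torch.Size([5, 4])
print(pe[0])     # position 0: sin(0) = 0 and cos(0) = 1 alternate
```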


In [5]:
torch.arange(SEQ_LENGTH).view(1, SEQ_LENGTH)

tensor([[0, 1, 2, 3, 4]])

In [6]:
import torch.nn as nn

class IntegerSequenceEmbedding(nn.Module):
    """
    Embedding module that combines token embeddings with positional encodings.

    Args:
        vocab_size (int): Size of the vocabulary (number of unique tokens).
        embed_dim (int): Dimension of the embeddings.
        seq_length (int): Sequence length.
    """
    def __init__(self, vocab_size=20, embed_dim=16, seq_length=5):
        super().__init__()
        # embedding layer for the tokens (numbers):
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        # embedding layer for the positions:
        self.positional_embedding = nn.Embedding(seq_length, embed_dim)

    def forward(self, x):
        """
        Forward pass of the embedding module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length), 
                              containing integer token indices.
        Returns:
            torch.Tensor: Embedded tensor of shape (batch_size, seq_length, embed_dim).
        """
        # Token embedding
        x = self.token_embedding(x)  # Shape: (batch_size, seq_length, embed_dim)
        # Positional encoding
        positions = torch.arange(x.shape[1], device=x.device).unsqueeze(0)  # Shape: (1, seq_length)
        x = x + self.positional_embedding(positions)

        return x
    
embedding_layer = IntegerSequenceEmbedding(vocab_size=VOCAB_SIZE, embed_dim=3, seq_length=SEQ_LENGTH)
src, _ = generate_data(batch_size=1, seq_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE)
embedded_src = embedding_layer(src)
print(f"Embedded batch of sequences {src.tolist()} into:")
print(f"Tensor of shape  {embedded_src.shape}")
embedded_src

Embedded batch of sequences [[7, 2, 13, 4, 18]] into:
Tensor of shape  torch.Size([1, 5, 3])


tensor([[[-1.3827,  0.0612,  0.3681],
         [-0.7669,  1.0268,  0.7466],
         [-1.2059,  0.1586, -1.3335],
         [ 1.3124, -3.4395,  0.5323],
         [ 0.5295,  0.5688, -1.6422]]], grad_fn=<AddBackward0>)
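The addition `x + self.positional_embedding(positions)` relies on broadcasting. `positions` has shape `(1, seq_length)`, so the positional embeddings have shape `(1, seq_length, embed_dim)`, and the same position vectors are added to every sequence of the batch. Below is a small standalone illustration in which random tensors stand in for the two embeddings:

```python
import torch

batch_size, seq_length, embed_dim = 4, 5, 3
token_vectors = torch.randn(batch_size, seq_length, embed_dim)  # stands in for token_embedding(x)
position_vectors = torch.randn(1, seq_length, embed_dim)        # stands in for positional_embedding(positions)

# The leading dimension of size 1 is broadcast over the batch
summed = token_vectors + position_vectors
print(summed.shape)  # torch.Size([4, 5, 3])

# Every sequence in the batch receives the same vector at position 2
same_offset = torch.allclose(summed[:, 2] - token_vectors[:, 2],
                             position_vectors[0, 2].expand(batch_size, -1))
print(same_offset)  # True
```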

---
### Step 2: The `Transformer` layer

#### 2.1. Implement the self-attention layer, as seen in class
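For reference, the layer below computes the scaled dot-product attention seen in class. Take one sequence $X \in \mathbb{R}^{n \times d}$ ($n$ tokens of dimension $d$). The bias-free linear layers `W_q`, `W_k`, `W_v` compute the queries, keys and values, each with a $d \times d$ matrix, and the output is:

$$
Q = X W_Q, \quad K = X W_K, \quad V = X W_V, \qquad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

Here $d_k = d$ is the key dimension. The softmax is taken over each row of $Q K^\top$, i.e. over the keys, which corresponds to `dim=-1` in the code.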

In [12]:
class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()

        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)  # Value projection

    def forward(self, x):
        # Input x: (batch_size, seq_length, embed_dim)
        # TODO: Complete this function
        Q = self.W_q(x)  # (batch_size, seq_length, embed_dim)
        K = self.W_k(x)  # (batch_size, seq_length, embed_dim)
        V = self.W_v(x)  # (batch_size, seq_length, embed_dim)
        d_k = K.shape[-1] # Key dimension


        # Dot-product similarities
        scores = Q @ K.transpose(1, 2)
        # Scale by dimension
        scores /= d_k ** 0.5            
        # Transform the scores into probabilities with the softmax function
        scores = torch.softmax(scores, dim=-1)
    
        # Optional: store the attention weights for visualization
        self.attention_weights = scores

        # Update the vectors x
        x = scores @ V

        return x

# Testing
attn = SingleHeadAttention(embed_dim=128)  
x = torch.randn(32, 10, 128)  # Batch of 32 sequences, each of length 10 with 128-d embeddings  
output = attn(x)  
print(output.shape)  # Should be (32, 10, 128)

torch.Size([32, 10, 128])


In [None]:
# Checking the softmax dimension
A = torch.arange(9).view(1,3,3).float()
print(A[0])
print(torch.softmax(A, dim=2)[0])

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])
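As an additional check, the computation in `SingleHeadAttention.forward` can be compared with PyTorch's built-in implementation. This assumes PyTorch 2.0 or later, which provides `torch.nn.functional.scaled_dot_product_attention`. The snippet is standalone: it repeats the same steps on random bias-free projections rather than reusing the class. It also checks that each row of attention weights sums to 1, which confirms that `dim=-1` normalizes over the keys.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_length, embed_dim = 2, 4, 8
x = torch.randn(batch_size, seq_length, embed_dim)

# Bias-free projections, as in SingleHeadAttention
W_q = torch.nn.Linear(embed_dim, embed_dim, bias=False)
W_k = torch.nn.Linear(embed_dim, embed_dim, bias=False)
W_v = torch.nn.Linear(embed_dim, embed_dim, bias=False)
Q, K, V = W_q(x), W_k(x), W_v(x)

# Manual computation, following the steps of SingleHeadAttention.forward
weights = torch.softmax(Q @ K.transpose(1, 2) / embed_dim ** 0.5, dim=-1)
manual = weights @ V

# PyTorch's implementation (default scale is 1 / sqrt of the last dimension of Q)
reference = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, reference, atol=1e-5))                             # True
print(torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_length)))  # True
```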


#### 2.2 Integrate the self-attention layer in a transformer layer

A **Transformer Encoder Layer** consists of:  
- A *self-attention mechanism*, which lets every token gather information from all other tokens, including distant ones.  
- A *feedforward network* (two fully connected layers), applied to each position independently, to transform the representations.  
- *Layer normalization* to stabilize training.  
- *Residual connections* to improve gradient flow and mitigate vanishing gradients.
<div style="max-width:400px">
<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F5024bcc5-33c9-4d53-9bd7-56cbcf9c4627_874x1108.png" alt="Transformer Layer" />
</div>
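Where the normalization sits relative to the residual connection is a design choice. Let $\mathrm{LN}_1, \mathrm{LN}_2$ be the two layer normalizations and $\mathrm{FFN}$ the feedforward network. The two common arrangements are:

$$
\begin{aligned}
\text{Pre-LN:}\quad  & x \leftarrow x + \mathrm{Attention}(\mathrm{LN}_1(x)), \qquad x \leftarrow x + \mathrm{FFN}(\mathrm{LN}_2(x)) \\
\text{Post-LN:}\quad & x \leftarrow \mathrm{LN}_1\big(x + \mathrm{Attention}(x)\big), \qquad x \leftarrow \mathrm{LN}_2\big(x + \mathrm{FFN}(x)\big)
\end{aligned}
$$

The original Transformer paper uses the post-LN arrangement. The `forward` method below implements the pre-LN arrangement, which is widely used because it tends to make training more stable.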


In [14]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.self_attn = SingleHeadAttention(embed_dim)
        # Normalization layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(),
            nn.Linear(embed_dim * 2, embed_dim)
        )
        

    def forward(self, x):
        # Input x: (batch_size, seq_length, embed_dim)
        # TODO: Implement encoder block, with residual connections!
        
        x = x + self.self_attn(self.norm1(x))
        x = x + self.fc_layers(self.norm2(x))

        return x
    
# Testing
layer = TransformerEncoderLayer(embed_dim=128)
x = torch.randn(32, 10, 128)  # Batch of 32 sequences, each of length 10 with 128-d embeddings  
output = layer(x)  
print(output.shape)  # Should be (32, 10, 128)

torch.Size([32, 10, 128])
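PyTorch also ships a ready-made `nn.TransformerEncoderLayer`. The sketch below configures it to roughly resemble the layer above and uses it as a cross-check of shapes. It is not numerically identical to our layer: the built-in attention adds biases and an output projection, and it applies dropout by default (disabled here).

```python
import torch
import torch.nn as nn

embed_dim = 128
# Built-in encoder layer, configured to resemble TransformerEncoderLayer above:
# one attention head, a feedforward width of 2 * embed_dim, ReLU, pre-LN, and no dropout.
builtin_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=1,
    dim_feedforward=2 * embed_dim,
    dropout=0.0,
    activation="relu",
    batch_first=True,  # inputs are (batch_size, seq_length, embed_dim)
    norm_first=True,   # LayerNorm before each sub-layer (pre-LN)
)

x = torch.randn(32, 10, embed_dim)
output = builtin_layer(x)
print(output.shape)  # torch.Size([32, 10, 128])
```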


---
### 3. Implement a `Transformer` network

#### 3.1. General architecture:
The full Transformer network consists of:  
1. **Embedding Module**: Converts input tokens into dense vectors and adds positional encodings.  
2. **Transformer Layers**: A stack of self-attention layers with feedforward networks and normalization.  
3. **Classification Head**: Processes the output of the Transformer layers to produce predictions.

<div style="max-width:600px">
<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Ff6133c18-bfaf-4578-8c5a-e5ac7809f65b_1632x784.png" alt="Transformer Architecture, with zoom on transformer layer" />
</div>

#### 3.2. Predictions for our task

The task is to **sort a list of integers**. What should the model output, and what is the shape of that output?
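One possible answer is sketched below with random numbers in place of a real model. For each of the `seq_length` output positions, the model predicts which vocabulary token belongs there, so it solves one `vocab_size`-way classification problem per position. Its raw output is then a tensor of logits of shape `(batch_size, seq_length, vocab_size)`. A softmax over the last dimension gives probabilities, and an `argmax` gives the predicted sequence.

```python
import torch

batch_size, seq_length, vocab_size = 2, 5, 20
logits = torch.randn(batch_size, seq_length, vocab_size)  # stand-in for the model output

probs = torch.softmax(logits, dim=-1)  # one distribution over the vocabulary per position
predictions = logits.argmax(dim=-1)    # (batch_size, seq_length): one token per position

print(probs.shape)        # torch.Size([2, 5, 20])
print(predictions.shape)  # torch.Size([2, 5])
```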

In [None]:
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, seq_length=5, num_layers=2):
        """
        Transformer Encoder for sequence processing.

        Args:
            vocab_size (int): Number of unique tokens in the input vocabulary.
            embed_dim (int): Dimension of the token embeddings.
            seq_length (int): Length of the input sequences (one learned positional embedding per position).
            num_layers (int): Number of Transformer encoder layers.
        """
        super().__init__()

        # Token embedding layer
        self.embedding = IntegerSequenceEmbedding(vocab_size, embed_dim, seq_length)

        # Stack of Transformer Encoder Layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim) for i in range(num_layers)
        ])

        # Final classification head: a simple linear layer
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """
        Forward pass of the Transformer Encoder.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_length).

        Returns:
            Tensor: Output tensor of shape (batch_size, seq_length, vocab_size) containing unnormalized scores (logits) over the vocabulary for each position.
        """
        # Convert input sequence to embeddings
        x = self.embedding(x)

        # Pass through Transformer Encoder Layers
        for layer in self.encoder_layers:
            x = layer(x)

        # Apply final linear layer to get logits
        outputs = self.fc_out(x)

        return outputs

embed_dim = 32
batch_size = 16
transformer = Transformer(VOCAB_SIZE, embed_dim=embed_dim, seq_length=SEQ_LENGTH)
# Generate source and target data
source, target = generate_data(batch_size, SEQ_LENGTH, VOCAB_SIZE)

# Pass the source data through the transformer and check the output shape
outputs = transformer(source)
predictions = outputs.argmax(dim=-1)  # Shape (batch_size, seq_length): one predicted token per input position
print(f"Input Sequence:         {source.tolist()[0]}")
print(f"Expected Sorted Output: {target.tolist()[0]}")
print(f"Model Prediction:       {predictions.tolist()[0]}")
print("-" * 50)

Input Sequence:         [18, 13, 15, 10, 16]
Expected Sorted Output: [10, 13, 15, 16, 18]
Model Prediction:       [10, 7, 2, 9, 13]
--------------------------------------------------


---
### 4. Train the network!

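As with other neural networks, the Transformer's parameters are learned by gradient descent on a loss function. Here the loss is the cross-entropy between the predicted and the correct token at each position, optimized with Adam. Since the data are synthetic, each training step uses a freshly generated mini-batch instead of a fixed training dataset.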

In [None]:
# Data
vocab_size = 20 
seq_length = 10
# Network hyperparameters
embed_dim = 32
num_layers = 2
# Training hyperparameters
batch_size = 32
num_epochs = 500  # each "epoch" is a single step on a freshly generated mini-batch

# Model, Loss, Optimizer
model = Transformer(vocab_size, embed_dim=embed_dim, seq_length=seq_length, num_layers=num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(num_epochs):
    # reset gradients
    optimizer.zero_grad()
    # Generate a mini-batch for training
    src, tgt = generate_data(batch_size, seq_length, vocab_size)
    # Forward pass
    output = model(src)
    loss = criterion(output.flatten(0,1), tgt.flatten()) # Why flatten the output?
    # Backward pass
    loss.backward()
    # Parameter updates
    optimizer.step()

    # Every 50 epochs (and at epoch 10), show the prediction on a fresh test batch
    if epoch % 50 == 0 or epoch == 10:
        test_src, test_tgt = generate_data(batch_size, seq_length, vocab_size)
        # test_pred has shape (batch_size, seq_length): one predicted token per position
        with torch.no_grad():  # no gradients are needed for evaluation
            test_pred = model(test_src).argmax(dim=-1)

        print(f"Epoch {epoch}")
        print(f"Input Sequence:         {test_src.tolist()[0]}")
        print(f"Expected Sorted Output: {test_tgt.tolist()[0]}")
        print(f"Model Prediction:       {test_pred.tolist()[0]}")
        print(f"Loss: {loss.item():.4f}")
        print("-" * 50)

print("Training complete!")


Epoch 0
Input Sequence:         [1, 11, 16, 10, 2, 5, 9, 3, 1, 7]
Expected Sorted Output: [1, 1, 2, 3, 5, 7, 9, 10, 11, 16]
Model Prediction:       [6, 1, 6, 15, 17, 15, 13, 16, 17, 8]
Loss: 3.2640
--------------------------------------------------
Epoch 10
Input Sequence:         [14, 7, 6, 10, 5, 8, 14, 12, 9, 17]
Expected Sorted Output: [5, 6, 7, 8, 9, 10, 12, 14, 14, 17]
Model Prediction:       [1, 3, 5, 6, 7, 8, 13, 13, 17, 19]
Loss: 2.0396
--------------------------------------------------
Epoch 50
Input Sequence:         [16, 17, 2, 7, 12, 18, 6, 3, 15, 19]
Expected Sorted Output: [2, 3, 6, 7, 12, 15, 16, 17, 18, 19]
Model Prediction:       [2, 3, 7, 8, 12, 16, 17, 17, 18, 18]
Loss: 1.1139
--------------------------------------------------
Epoch 100
Input Sequence:         [10, 12, 12, 19, 4, 9, 12, 11, 17, 2]
Expected Sorted Output: [2, 4, 9, 10, 11, 12, 12, 12, 17, 19]
Model Prediction:       [2, 4, 7, 9, 10, 12, 12, 12, 17, 19]
Loss: 0.6212
-----------------------------------
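This addresses the question in the training loop (*Why flatten the output?*). `nn.CrossEntropyLoss` expects the class scores in dimension 1. Its inputs can have shape `(N, C)` with targets of shape `(N,)`, or `(N, C, d1)` with targets of shape `(N, d1)`. The model outputs `(batch_size, seq_length, vocab_size)`, so `flatten(0, 1)` turns each position of each sequence into its own classification sample. The standalone sketch below uses random logits to check that moving the class dimension with `transpose(1, 2)` gives the same loss under the default mean reduction.

```python
import torch
import torch.nn as nn

batch_size, seq_length, vocab_size = 4, 10, 20
output = torch.randn(batch_size, seq_length, vocab_size)      # stand-in for model logits
tgt = torch.randint(1, vocab_size, (batch_size, seq_length))  # target tokens

criterion = nn.CrossEntropyLoss()

# Option 1: every (sequence, position) pair is a separate sample: (40, 20) and (40,)
loss_flat = criterion(output.flatten(0, 1), tgt.flatten())

# Option 2: keep the sequence dimension, classes in dimension 1: (4, 20, 10) and (4, 10)
loss_transposed = criterion(output.transpose(1, 2), tgt)

print(torch.allclose(loss_flat, loss_transposed))  # True
```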

In [27]:
# Show the first 5 sequences of the last test batch evaluated during training
for i in range(5):
    print(f"Input Sequence {i + 1}:         {test_src.tolist()[i]}")
    print(f"Expected Sorted Output {i + 1}: {test_tgt.tolist()[i]}")
    print(f"Model Prediction {i + 1}:       {test_pred.tolist()[i]}")
    print("-" * 50)

print(f"Loss: {loss.item():.4f}")
print("-" * 50)

Input Sequence 1:         [16, 10, 12, 3, 17, 6, 11, 13, 8, 7]
Expected Sorted Output 1: [3, 6, 7, 8, 10, 11, 12, 13, 16, 17]
Model Prediction 1:       [3, 6, 7, 8, 10, 11, 12, 13, 16, 17]
--------------------------------------------------
Input Sequence 2:         [8, 5, 16, 13, 8, 9, 10, 5, 15, 9]
Expected Sorted Output 2: [5, 5, 8, 8, 9, 9, 10, 13, 15, 16]
Model Prediction 2:       [5, 5, 8, 8, 9, 9, 10, 13, 15, 16]
--------------------------------------------------
Input Sequence 3:         [18, 19, 17, 11, 14, 8, 9, 18, 19, 3]
Expected Sorted Output 3: [3, 8, 9, 11, 14, 17, 18, 18, 19, 19]
Model Prediction 3:       [3, 8, 9, 11, 14, 17, 18, 18, 19, 19]
--------------------------------------------------
Input Sequence 4:         [10, 9, 3, 10, 13, 10, 6, 1, 11, 9]
Expected Sorted Output 4: [1, 3, 6, 9, 9, 10, 10, 10, 11, 13]
Model Prediction 4:       [1, 3, 6, 9, 9, 10, 10, 10, 11, 13]
--------------------------------------------------
Input Sequence 5:         [19, 18, 4, 12, 17, 
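Inspecting a few examples is informative, but a single number is easier to track during training. The sketch below computes two metrics one might use: the fraction of correctly predicted positions, and the stricter fraction of sequences that are sorted entirely correctly. The helper `sorting_accuracy` is hypothetical and not part of the session code. Applied to `test_pred` and `test_tgt`, it would summarize a whole evaluation batch.

```python
import torch

def sorting_accuracy(pred, tgt):
    """Hypothetical helper: token-level and sequence-level accuracy.

    pred, tgt: integer tensors of shape (batch_size, seq_length).
    """
    correct = pred == tgt                                    # (batch_size, seq_length) booleans
    token_acc = correct.float().mean().item()                # fraction of correct positions
    sequence_acc = correct.all(dim=1).float().mean().item()  # fraction of fully correct sequences
    return token_acc, sequence_acc

pred = torch.tensor([[1, 2, 3], [4, 6, 5]])
tgt = torch.tensor([[1, 2, 3], [4, 5, 6]])
token_acc, sequence_acc = sorting_accuracy(pred, tgt)
print(token_acc, sequence_acc)  # token accuracy 4/6, sequence accuracy 1/2
```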

---
### 5. Implement multi-headed attention

<div style="max-width:400px">
<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F65c156ae-5cc5-4f7f-8652-dd5311b19beb_544x724.png" alt="Transformer Architecture, with zoom on transformer layer" />
</div>
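In formulas, with $h$ heads (`num_heads`) of dimension $d_h = d / h$ (`head_dim`), and ignoring biases:

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_h}}\right) V_i, \qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O
$$

Here $Q_i, K_i, V_i \in \mathbb{R}^{n \times d_h}$ are the $i$-th blocks of $d_h$ columns of $Q = X W_Q$, $K = X W_K$, $V = X W_V$, and $W_O \in \mathbb{R}^{d \times d}$ is the output projection (`fc_out`). The `view(...).transpose(1, 2)` calls split one $d \times d$ projection into $h$ column blocks. This is equivalent to giving each head its own $d \times d_h$ projection.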


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert embed_dim % num_heads == 0, "Embedding size must be divisible by num_heads"

        self.W_q = nn.Linear(embed_dim, embed_dim)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim)  # Value projection
        self.fc_out = nn.Linear(embed_dim, embed_dim)  # Output layer

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.shape

        # TODO: Compute Queries, Keys, Values
        Q = self.W_q(x)  # (batch_size, seq_length, embed_dim)
        K = self.W_k(x)  # (batch_size, seq_length, embed_dim)
        V = self.W_v(x)  # (batch_size, seq_length, embed_dim)

        # Reshape for multiple heads
        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # TODO: Apply scaled dot-product attention
        # Dot-product similarities
        scores = Q @ K.transpose(-2, -1) 
        # Scale by key dimension
        scores /=  K.shape[-1] ** 0.5            
        # Transform the scores into probabilities with the softmax function
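        scores = torch.softmax(scores, dim=-1)
        # Weighted sum of the values, computed separately for each head
        attn_output = scores @ V  # (batch_size, num_heads, seq_length, head_dim)

        # Concatenate the heads back into (batch_size, seq_length, embed_dim) and apply the final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)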
        output = self.fc_out(attn_output)

        return output
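A quick way to validate the multi-head logic is to compare it with PyTorch's `nn.MultiheadAttention`. The standalone sketch below reuses that module's weights. Its `in_proj_weight` stacks the query, key and value projections. The sketch recomputes the output with the same reshape, scaling, softmax and concatenation steps as `MultiHeadAttention.forward`. The helper `split_heads` is hypothetical. For inputs of shape `(batch_size, seq_length, embed_dim)`, `MultiHeadAttention` should likewise return a tensor of the same shape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_length, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads
x = torch.randn(batch_size, seq_length, embed_dim)

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
reference, _ = mha(x, x, x, need_weights=False)

# Recompute the same output by hand, reusing the built-in layer's weights.
# in_proj_weight has shape (3 * embed_dim, embed_dim): query, key and value projections stacked.
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
Q = x @ W_q.T + b_q
K = x @ W_k.T + b_k
V = x @ W_v.T + b_v

def split_heads(t):
    # Hypothetical helper: (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
    return t.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
weights = torch.softmax(Q @ K.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
heads = weights @ V  # (batch, num_heads, seq, head_dim)
concat = heads.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)
manual = mha.out_proj(concat)  # final linear layer

print(torch.allclose(manual, reference, atol=1e-5))  # True
```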
