<a target="_blank" href="https://colab.research.google.com/github/MENA-ML/tutorials2025-tasks/blob/main/transformers_llms/Intro_to_transformers_Task.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Building a Transformer from Scratch


Welcome to this hands-on tutorial where you'll dive into the world of PyTorch and learn to build a [Transformer](https://arxiv.org/abs/1706.03762) model!

**What is PyTorch?**

PyTorch is a widely used, open-source machine learning framework known for its flexibility, ease of use, and strong support for GPU acceleration. It's favored by researchers and developers alike for building and deploying deep learning models, thanks to its dynamic computational graph, intuitive Pythonic syntax, and extensive ecosystem of tools and libraries.

**Why PyTorch?**

- **Flexibility and Control:** PyTorch's dynamic computation graph allows for more intuitive model building and debugging, especially for complex architectures.
- **Strong GPU Acceleration:** PyTorch seamlessly integrates with GPUs, significantly speeding up training and inference.
- **Pythonic and Easy to Learn:** PyTorch's API is designed to be intuitive for Python developers, making it relatively easy to learn and use.
- **Large and Active Community:** PyTorch benefits from a vast and active community, providing ample resources, support, and pre-trained models.

In this walk-through tutorial, you'll learn:

- **Part I** : The fundamentals of PyTorch, including its tensor operations and automatic differentiation capabilities, and how to build a basic Multilayer Perceptron (MLP) model using PyTorch.
- **Part II** : How to build a text tokenizer.
- **Part III** : How to implement a Transformer model from scratch, for classification tasks.
- **Part IV** : How to implement a Transformer model from scratch, for sequence generation tasks.

Let's get started!

In [10]:
import functools  # Used for creating partial functions
import numpy as np  # Used for numerical computation in plain NumPy to compare with PyTorch
import tqdm.notebook  as tqdm # Used for displaying progress bars
import matplotlib.pyplot as plt  # Used for plotting graphs

In [12]:
import math
import random
from typing import Any, Mapping, Tuple
import pandas as pd
import seaborn as sns  # Visualization library

In [13]:
import tensorflow as tf  # We only import it for the tokenizer
import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [14]:
# Set global torch defaults
torch.set_default_dtype(torch.float32)
torch.set_float32_matmul_precision('high')
torch._dynamo.config.suppress_errors = True

In [15]:
sns.set_style(
    "whitegrid"
)  # See example in https://seaborn.pydata.org/generated/seaborn.set_style.html

In [16]:
seed = 2024
np.random.seed(seed)
generator = torch.Generator()
generator.manual_seed(seed)

<torch._C.Generator at 0x7d7cf3e550d0>

In PyTorch, neural networks are typically defined as classes that inherit from `torch.nn.Module`. This object-oriented approach provides a structured way to define and manage your network's layers and parameters. For optimization, PyTorch offers a variety of optimizers within the `torch.optim` package, such as Adam, SGD, and others.

In [17]:
import torch.nn as nn  # PyTorch's neural network module
import torch.optim as optim  # PyTorch's optimizer library
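
As a minimal sketch of the pattern described above, the following defines a small module and takes a single optimizer step. The class name `TinyMLP`, its layer sizes, and the random data are illustrative assumptions and are not part of this tutorial's model.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class TinyMLP(nn.Module):
  """Hypothetical two-layer MLP, used only to illustrate the nn.Module pattern."""

  def __init__(self, in_dim: int = 4, hidden_dim: int = 8, out_dim: int = 1):
    super().__init__()
    # Layers assigned as attributes are registered as sub-modules, so their
    # weights show up in `model.parameters()`.
    self.hidden = nn.Linear(in_dim, hidden_dim)
    self.out = nn.Linear(hidden_dim, out_dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.out(torch.relu(self.hidden(x)))


torch.manual_seed(0)
model = TinyMLP()
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 4)  # made-up inputs
y = torch.randn(16, 1)  # made-up targets

loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()  # clear any stale gradients
loss.backward()        # populate .grad on every parameter
optimizer.step()       # apply one gradient-descent update

num_params = sum(p.numel() for p in model.parameters())
print(num_params)
```

Swapping `optim.SGD` for `optim.Adam` only changes the line that constructs the optimizer; the rest of the step stays the same.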

In PyTorch, you typically need to explicitly specify the device you want to use for computation (e.g., CPU or GPU). PyTorch provides `torch.device` to represent devices, and `tensor.to(device)` or `module.to(device)` to move data and models to the desired device. You can check if a GPU is available and which device your tensors are on as shown below.

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device.type)

cuda
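
To make the device handling concrete, here is a small sketch that moves a tensor and a module to the selected device and inspects where each one lives. The tensor and layer shapes are arbitrary choices for illustration.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.ones(2, 3)                      # created on the CPU by default
x_on_device = x.to(device)                # a tensor on `device` (x itself is unchanged)
layer = torch.nn.Linear(3, 1).to(device)  # moves the module's parameters in place

y = layer(x_on_device)  # inputs and parameters must be on the same device
print(x.device.type, x_on_device.device.type, y.device.type)
```

Mixing a CPU tensor with a module on the GPU raises a runtime error, which is why both the data and the model are moved explicitly.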


# Part I: PyTorch
After completing this part, you will be able to:

1. Understand the basics of PyTorch.
2. Convert NumPy code to PyTorch code.
3. Utilize PyTorch's compilation to speed up your numerical computations.
4. Calculate gradients using PyTorch's automatic differentiation capabilities.
5. Implement a basic training loop in PyTorch to optimize a simple model.



---


Things to try:
1. Implement weight decay by adding a regularization term $||W||$ to the loss function.
2. What happens to the training and test loss when we use more layers?
3. What happens when the learning rate is increased/decreased?


## simple model
Before diving into PyTorch, let's first understand the simple Multilayer Perceptron (MLP) architecture we'll be implementing. An MLP consists of multiple layers of interconnected nodes (neurons). Each connection between nodes has an associated weight, and each node applies an activation function to its input.

In this Colab, we'll be working with a simple MLP with one hidden layer. The input data is first multiplied by a weight matrix W, before a bias term b is added. Then, a ReLU (Rectified Linear Unit) activation function is applied. This introduces non-linearity, allowing the model to learn more complex patterns in the data. The output of the hidden layer is then multiplied by another weight matrix and a bias term is added to produce the final output.
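
In equation form, for a batch of inputs $x$ stored as rows, the forward pass described above can be written as follows. The symbols $W_1, b_1$ (hidden layer), $W_2, b_2$ (output layer), $h$, and $\hat{y}$ are notation introduced here for illustration:

$$
h = \mathrm{ReLU}(x W_1 + b_1), \qquad \hat{y} = h W_2 + b_2, \qquad \mathrm{ReLU}(z) = \max(z, 0)
$$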

We'll implement this MLP first using NumPy and then using PyTorch, showcasing how easy it is to transition between the two libraries.

## fundamentals

### numpy

In [19]:
#  Use the example code below to implement your own `predict` and `loss` functions.
def predict(params, inputs):
  """Numpy implementation of the MLP architecture.

  Implement a function that loops through all layers and unpack them into
  weights W and bias b.
  Then, it does matrix multiplication followed by relu activation.

  Args:
    params: a list [(W0, b0), (W1, b1), ...] containing the neural network
      weights.
    inputs: the inputs to the neural network,
      of shape (batch_size, input_dimension).

  Returns:
    outputs: the outputs of the neural network,
      of shape (batch_size, output_dimension)
  """
  outputs = []  # in case params = []
  for W, b in params:

    # --- YOUR CODE START ---
    # outputs =... @... +...  # matrix multiplication on inputs
    # inputs =...(...)  # relu activation
    # --- YOUR CODE END ---
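    outputs = inputs @ W + b  # affine map; inputs has shape (batch_size, dim)
    inputs = np.maximum(outputs, 0.0)  # relu activation feeds the next layer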
  return outputs

def loss(params, batch):
  """Numpy implementation of the square loss function.

  Args:
    params: a list [(W0, b0), (W1, b1), ...] containing the neural network
      weights.
    batch: a tuple (inputs, targets).

  Returns:
    loss: the loss value.
  """
  loss_value = 0
  inputs, targets = batch  # Unpack the batch into inputs and targets
  # --- YOUR CODE START ---
  # preds =...(...)  # get the predictions
  # loss_value = ...  # calculate the mean squared error
  # --- YOUR CODE END ---

  return loss_value
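
For reference, one possible way to write the forward pass and the squared loss in NumPy is sketched below, checked on a tiny hand-computable example. The names `mlp_forward` and `mse_loss` and the toy values are hypothetical; they are one reading of the docstrings above, not the only valid solution.

```python
import numpy as np


def mlp_forward(params, inputs):
  """Hypothetical helper: affine map per layer, with ReLU between layers."""
  outputs = inputs
  for W, b in params:
    outputs = inputs @ W + b           # affine transformation
    inputs = np.maximum(outputs, 0.0)  # ReLU output feeds the next layer
  return outputs                       # output of the last affine map


def mse_loss(params, batch):
  """Hypothetical helper: mean squared error over the batch."""
  inputs, targets = batch
  preds = mlp_forward(params, inputs)
  return np.mean((preds - targets) ** 2)


# Tiny check: identity first layer, second layer sums the two hidden units.
toy_params = [(np.identity(2), 0.0), (np.ones(2), 0.0)]
toy_inputs = np.array([[1.0, 2.0], [3.0, -4.0]])
toy_targets = np.array([3.0, 0.0])

toy_preds = mlp_forward(toy_params, toy_inputs)
toy_loss = mse_loss(toy_params, (toy_inputs, toy_targets))
print(toy_preds, toy_loss)
```

In the second row the ReLU zeroes out the negative entry, so both rows predict 3 and only the second one contributes to the loss.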

In [None]:
# prepare data
# here: target (y_train) is a linear function of input (x_train) plus some noise

num_examples = 10_000
dim = 100
x_train = np.random.randn(num_examples, dim)
w = np.random.randn(dim,)
y_train = np.dot(x_train, w) + 0.2 * np.random.randn(num_examples,)

x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)

batch = (x_train, y_train)

In [None]:
# initialize model parameters
W1 = np.identity(dim)  # identity matrix
b1 = 0.

W2 = np.random.randn(dim,)
b2 = 0.

params = [(W1, b1), (W2, b2)]  # two layers

In [None]:
loss(params, batch)

In [None]:
# This magic command measures the execution time of the loss function.
%timeit loss(params, batch)

### pytorch

In [None]:
# now implement the same functions in pytorch

def predict(params, inputs):
  """PyTorch implementation of the model."""
  outputs = None  # in case params is empty
  for W, b in params:
    # As in NumPy, `inputs @ W` needs no transpose here: inputs has shape (batch_size, dim)
    pass
    # --- YOUR CODE START ---
    # outputs =... @... +...  # matrix multiplication on inputs
    # inputs =...(...)  # relu activation
    # --- YOUR CODE END ---
  return outputs

def loss(params, batch):
  """PyTorch implementation of the loss function."""
  loss_value = 0
  inputs, targets = batch  # Unpack the batch into inputs and targets
  # --- YOUR CODE START ---
  # preds =...(...)  # get the predictions
  # loss_value = ...  # calculate the mean squared error
  # --- YOUR CODE END ---

  return loss_value

In [None]:
# Use torch.tensor() to load data into PyTorch

x_train_tensor = torch.tensor(x_train)
y_train_tensor = torch.tensor(y_train)

W1_tensor = torch.tensor(W1).float()
b1_tensor = torch.tensor(b1).float()
W2_tensor = torch.tensor(W2).float()
b2_tensor = torch.tensor(b2).float()

batch = (x_train_tensor, y_train_tensor)
params = [(W1_tensor, b1_tensor), (W2_tensor, b2_tensor)]
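
The `.float()` calls above matter because NumPy creates `float64` arrays by default, while the inputs were cast to `float32`. The sketch below, on made-up arrays, shows how `torch.tensor` preserves the NumPy dtype and how it differs from `torch.from_numpy`, which shares memory instead of copying.

```python
import numpy as np
import torch

a64 = np.identity(3)                # NumPy defaults to float64
a32 = np.ones(3, dtype=np.float32)

t64 = torch.tensor(a64)             # copies the data, keeps float64
t32 = torch.tensor(a32)             # copies the data, keeps float32
t64_as_32 = t64.float()             # casts to float32

shared = torch.from_numpy(a32)      # shares memory with `a32` (no copy)
a32[0] = 5.0                        # visible through `shared`, not through `t32`

print(t64.dtype, t32.dtype, t64_as_32.dtype)
print(shared[0].item(), t32[0].item())
```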

In [None]:
# Warm up. The result of `loss(params, batch)` is a PyTorch tensor.
# The tensors above were created on the CPU, so the result stays on the CPU
# unless you move `batch` and `params` to the GPU with `.to(device)`.

loss(params, batch)

We haven't applied any compilation or explicit parallelization yet. Even so, you may already observe a speedup over the NumPy version.

In [None]:
# PyTorch also benefits from just-in-time compilation and GPU acceleration.
# The first run might be slower due to initial overhead, but subsequent runs are usually faster.
# The `%timeit` magic command in IPython/Jupyter will automatically run the function multiple times
# to give you an average execution time.

%timeit loss(params, batch)

### jit
When using just-in-time (JIT) compilation, the code is compiled on the first call, and the compiled code is then reused in subsequent calls.

In [None]:
# Compiling using torch.compile

jit_loss = torch.compile(loss)

jit_loss(params, batch)  # warmup

In [None]:
# now compare the time with before. We may get a bigger speedup without having
# to change much in our code
%timeit jit_loss(params, batch)

### auto-differentiation

Automatic differentiation is a technique for automatically calculating the gradients of a function. This is crucial for training neural networks, where we need to know the gradient of the loss function with respect to the model parameters in order to update the parameters and improve the model.

In PyTorch, automatic differentiation is handled by the torch.autograd package. It works by dynamically building a computational graph as you perform operations on tensors that have `requires_grad=True`.  Gradients are then calculated using the chain rule during the backward pass, which is initiated by calling `.backward()` on a scalar tensor (usually the loss).

Think of it like this: PyTorch's autograd keeps track of operations on tensors that require gradients. When you call `.backward()`, it automatically traverses this recorded history to compute the gradients for you. You can then access these gradients through the `.grad` attribute of your tensors.
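
A small worked example may help here. The sketch below differentiates the scalar $(w \cdot x - 1)^2$ with respect to `w` and compares the result with the analytic gradient $2 (w \cdot x - 1)\, x$. The values of `w` and `x` are made up for illustration.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)  # leaf tensor tracked by autograd
x = torch.tensor([3.0, 4.0])                      # plain data, not tracked

loss = ((w * x).sum() - 1.0) ** 2  # scalar: (w . x - 1)^2
loss.backward()                    # fills w.grad with d(loss)/dw

# Analytic gradient for comparison: 2 * (w . x - 1) * x
expected = 2.0 * ((w.detach() * x).sum() - 1.0) * x
print(w.grad, expected)
```

Only `w` receives a gradient. `x.grad` stays `None` because `x` was created without `requires_grad=True`.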

In [None]:
def grad_loss(params, batch):
  """Calculates the gradient of the loss function w.r.t. the parameters."""
  for W, b in params:
    pass
    # --- YOUR CODE START ---
    #... # set requires_grad=True for the parameters
    # --- YOUR CODE END ---

  loss_value = jit_loss(params, batch) # Use the traced loss for efficiency
  # --- YOUR CODE START ---
  # ... # calculate gradients using autograd
  # --- YOUR CODE END ---

  # Collect gradients
  grads = []
  for W, b in params:
    pass
    # --- YOUR CODE START ---
    # grads.append(...)  # collect the gradients
    #... # Set requires_grad back to False and reset
    # --- YOUR CODE END ---
  return grads

# Because params is a list of tuples, the gradient will also be a list of tuples
print(grad_loss(params, batch))

## Training Loop: putting it all together
This training loop iteratively updates the model parameters to minimize the loss function. Here's a breakdown of the steps involved:

1. **Calculate the loss:** The `jit_loss` function calculates the loss between the model's predictions and the actual targets.
2. **Calculate the gradient:** The `grad_loss` function calculates the gradient of the loss function with respect to the model parameters. This tells us how to adjust the parameters to reduce the loss.
3. **Update the parameters:** The parameters are updated by subtracting a fraction of the gradient (determined by the learning rate `lr`) from the current parameter values. This moves the parameters in the direction that reduces the loss.

This process is repeated for a specified number of steps (`num_steps`), gradually improving the model's performance.

In [None]:
num_steps = 500  # total number of steps
lr = 0.001  # learning rate

modelparams = params  # initialize the parameters
history = []

# Use a progress bar from tqdm
for i in tqdm.tqdm(range(num_steps), desc="Training Progress"):
  # calculate the loss
  new_loss = jit_loss(modelparams, batch)
  history.append(new_loss.item()) # Append the loss value as a Python scalar

  # calculate the gradient
  grads = grad_loss(modelparams, batch)

  # update the parameters
  with torch.no_grad():  # Disable gradient tracking during updates
      for lyr in range(len(modelparams)):
          modelparams[lyr] = (
              modelparams[lyr][0] - lr * grads[lyr][0],
              modelparams[lyr][1] - lr * grads[lyr][1],
          )
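
The manual update above is plain gradient descent. As a rough sketch of the same idea expressed with `torch.optim.SGD`, here is a toy linear-regression fit. The data, the learning rate, and the names `w_true` and `w` are made up for illustration and are unrelated to the MLP parameters in this notebook.

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])  # hypothetical ground-truth weights
y = x @ w_true

w = torch.zeros(3, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

history = []
for _ in range(200):
  loss = torch.mean((x @ w - y) ** 2)
  optimizer.zero_grad()  # gradients accumulate unless cleared
  loss.backward()
  optimizer.step()       # w <- w - lr * w.grad
  history.append(loss.item())

print(history[0], history[-1])
```

The optimizer takes care of the `torch.no_grad()` bookkeeping and updates the tensors in place, so there is no need to rebuild the parameter list at each step.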

In [None]:
plt.figure(figsize=(8, 5))  # Adjust figure size for better readability
plt.plot(history, color='blue', linewidth=2)  # Customize line color and width
plt.title('Training Loss', fontsize=14)  # Add a title with increased font size
plt.xlabel('Training Step', fontsize=12)  # Add x-axis label with increased font size
plt.ylabel('Loss (Log Scale)', fontsize=12)  # Add y-axis label with increased font size
plt.yscale('log')
plt.grid(True, linestyle='--', alpha=0.7)  # Add a grid for better readability
plt.show()

# Part II: Text Tokenization

Part II provides a gentle introduction to tokenization, a fundamental step in Natural Language Processing (NLP). Tokenization is the process of breaking down text into smaller units, called tokens, which can be words, subwords, or characters. These tokens are then converted into numerical representations that can be processed by machine learning models.

Our "difference dataset" is of the form `<number1 - number2>`, where `number1` and `number2` are some integers and the binary label y is 1 if and only if `number1 > number2`.
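
To make the format concrete, the sketch below builds a few (expression, label) pairs by hand. Note that the dataset-construction cell writes each number with its digits reversed (least-significant digit first). The helper name `make_example` is hypothetical and only mirrors that logic for fixed inputs.

```python
def make_example(n: int, m: int) -> tuple[str, int]:
  """Hypothetical helper mirroring how one (expression, label) pair is formed."""
  # Digits are written least-significant first, as in the dataset cell.
  expression = f"{str(n)[::-1]}-{str(m)[::-1]}"
  label = int(n > m)  # 1 if and only if the first number is larger
  return expression, label


examples = [make_example(123, 4), make_example(7, 52), make_example(9, 9)]
for expression, label in examples:
  print(expression, label)
```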

In [None]:
# @title dataset
# @markdown <font color='blue'>Double click to see how the dataset is constructed</font>
def sample_difference_dataset(
    dataset_size: int,
    lengths: list[int],
    k: int,
):
  """Generates a dataset of expressions representing the difference between two
  numbers.

  Args:
    dataset_size: The number of expressions to generate.
    lengths: A list of integers representing the desired lengths (number of
      digits) of the expressions. For example, the expression 123 - 4 has a
      length of five.
    k: The base of the numbers in the expressions.

  Returns:
    A list of tuples, where each tuple contains an expression (str) and
    its corresponding sign (1 if the difference is positive, 0 otherwise).
  """
  data_all = []
  for length in lengths:
    if length <= 2:
      raise ValueError(
          f"The length of the expression must be greater than 2, got {length}."
      )

    # We only use `length - 1` tokens for the two values to account for the `-`.
    length_n = np.random.randint(1, length - 1, size=(dataset_size,))
    length_m = length - 1 - length_n

    integer_n = [random.randint(1, k ** int(len_n) - 1) for len_n in length_n]
    integer_m = [random.randint(1, k ** int(len_m) - 1) for len_m in length_m]
    diff_sign = [int(x > y) for x, y in zip(integer_n, integer_m)]

    integer_n = [str(x)[::-1] for x in integer_n]
    integer_m = [str(x)[::-1] for x in integer_m]
    expressions = [f"{a}-{b}" for a, b in zip(integer_n, integer_m)]

    data = [(x, y) for x, y in zip(expressions, diff_sign)]
    data_all.extend(data)
  data_all = list(set(data_all))
  random.shuffle(data_all)
  return data_all

In [None]:
MAX_TRAIN_LENGTH = 10  # the maximum length allowed in the training split
MAX_TEST_LENGTH = 15  # the maximum length allowed in the test split

train_ds = sample_difference_dataset(
    dataset_size=2500,
    lengths=list(range(3, MAX_TRAIN_LENGTH + 1)),
    k=10,
)
test_ds = sample_difference_dataset(
    dataset_size=1000,
    lengths=list(range(MAX_TRAIN_LENGTH + 1, MAX_TEST_LENGTH + 1)),
    k=10,
)

print(f"Train dataset size {len(train_ds)}")
print(f"Test dataset size {len(test_ds)}")

First, let's see what a single example in the dataset looks like.

In [None]:
it = iter(train_ds)

In [None]:
# let's see a few examples
for _ in range(5):
  text, label = next(it)
  print("text: ", text)
  print("label: ", label)
  print()

## building the vocabulary

In order to build a tokenizer, we will need to build a vocabulary of tokens (think of those as 'words' or any other useful chunks of text). To construct a good vocabulary, we will learn it from our dataset itself. So, let's collect a bunch of examples.
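
As a rough illustration of what "learning a vocabulary from the data" means at the character level, the pure-Python sketch below counts characters in a toy corpus and assigns ids by decreasing frequency. It is a simplified stand-in written for this explanation, not the implementation of the tokenizer used later; `toy_corpus`, `encode`, and `decode` are hypothetical names.

```python
from collections import Counter

toy_corpus = ["321-4", "7-25", "9-9"]  # made-up expressions in the dataset's format

# Count every character, then assign ids by decreasing frequency.
# Id 0 is left unused so that it can serve as the padding value.
counts = Counter(ch for text in toy_corpus for ch in text)
char_to_id = {ch: i + 1 for i, (ch, _) in enumerate(counts.most_common())}
id_to_char = {i: ch for ch, i in char_to_id.items()}


def encode(text: str) -> list[int]:
  """Maps each character to its integer id."""
  return [char_to_id[ch] for ch in text]


def decode(ids: list[int]) -> str:
  """Maps integer ids back to the original characters."""
  return "".join(id_to_char[i] for i in ids)


print(char_to_id)
print(encode("9-2"), decode(encode("9-2")))
```

With ten digits and the `-` sign, a character-level vocabulary for this dataset needs only a handful of entries.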

In [None]:
# size of corpus to build the tokenizer
corpus_size = 5_000  # @param = 'int'

# size of the vocabulary
vocab_size = 12  # @param = 'int'

# maximum length of examples in tokens
max_len = MAX_TEST_LENGTH + 1  # @param = 'int'

# pad value
pad_value = 0  # @param = 'int'

In [None]:
# keep only the text of each example (drop the labels)
corpus = [text for text, _ in train_ds[:corpus_size]]

In [None]:
corpus[:10]

## creating the tokenizer


In [None]:
# now, we build the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
    num_words=vocab_size,
    oov_token=None,
    char_level=True,
)
tokenizer.fit_on_texts(corpus)

In [None]:
# because `char_level=True`, each character (a digit or `-`) becomes its own token
tokenizer.index_word

The tokenizer is now trained. Let's see how it works.

In [None]:
# Example usage:
print("original text: ", text)

# tokenize text
tokens = tokenizer.texts_to_sequences([text])
print("tokens: ", tokens)
print("number of tokens: ", len(tokens[0]))

In [None]:
# we can see the actual tokens by converting each token individually to text
print(tokenizer.sequences_to_texts(np.array(tokens).reshape((-1, 1))))

Next, let's get some information about the distribution of tokens in our corpus.

In [None]:
# Let's examine the distribution of tokens in the corpus:
print("Token frequency:")
dict(list(tokenizer.word_counts.items()))

## preprocessing the data

The last step is to make sure that all examples have the same shape so that they can be batched together. First, we pad short examples with a special padding token. Second, we use the last token as the `cls` token for classification. We could give it a special value or reuse the padding value; here we use a value of 0 for both.

**Why?**

Neural networks typically process data in batches for efficiency. Batching allows the network to perform computations on multiple examples simultaneously, which speeds up training and inference. However, for batching to work, all examples in a batch must have the same shape (i.e., the same number of tokens).

Since text sequences have variable lengths, we use padding to make them uniform. Padding involves adding special padding tokens to shorter sequences to make them the same length as the longest sequence in the batch. This ensures that all examples in a batch have a consistent shape, enabling efficient batch processing.
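
The padding step can be sketched as follows, assuming token ids are already available as Python lists. The helper `pad_to_max_len`, the value `max_len = 8`, and the token ids are illustrative assumptions. Because every sequence is shorter than `max_len`, the last position always holds the padding value, which is what lets it double as the `cls` position.

```python
import torch
import torch.nn.functional as F

pad_value = 0
max_len = 8  # illustrative; the tutorial derives its own `max_len`


def pad_to_max_len(token_ids: list[int]) -> torch.Tensor:
  """Hypothetical helper: right-pad a token list to `max_len` with `pad_value`."""
  tokens = torch.tensor(token_ids, dtype=torch.long)
  return F.pad(tokens, (0, max_len - len(token_ids)), value=pad_value)


# Two sequences of different lengths become one rectangular batch.
padded_batch = torch.stack(
    [pad_to_max_len([3, 2, 1, 5, 4]), pad_to_max_len([7, 5, 2])]
)
print(padded_batch)
print(padded_batch.shape)
```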

In [None]:
def preprocess_function(text, label):
  pass
  # --- YOUR CODE START ---
  # tokens =...  # tokenize the text
  # tokens =...  # pad the sequences
  # tokens =...  # convert to tensor
  # label =...  # convert to tensor
  # --- YOUR CODE END ---
  return tokens, label

In [None]:
# Apply the preprocessing function to the training and test datasets
print("preprocessing training examples ... ")
x_train = []
y_train = []
for text, label in tqdm.tqdm(train_ds):
  tokens, label = preprocess_function(text, label)
  x_train.append(tokens)
  y_train.append(label)

print("preprocessing test examples ... ")
x_test = []
y_test = []
for text, label in tqdm.tqdm(test_ds):
  tokens, label = preprocess_function(text, label)
  x_test.append(tokens)
  y_test.append(label)

# convert to PyTorch tensors
x_train = torch.stack(x_train)
y_train = torch.cat(y_train)
x_test = torch.stack(x_test)
y_test = torch.cat(y_test)

In [None]:
print("x_train.shape: ", x_train.shape)
print("y_train.shape: ", y_train.shape)
print("x_test.shape: ", x_test.shape)
print("y_test.shape: ", y_test.shape)

In [None]:
# let's see what it looks like
x_train[0], y_train[0]

# Part III: Transformer Architecture - Classification task


## Training loop

First, let's define and fix our loss function, training step and training loop. We won't need to change these later. Make sure you execute the cell below and double click to see the actual code.

In [None]:
#@markdown <font color="blue">Double click here to see the training loop</font>
class TrainState:
    def __init__(self, model: nn.Module, optimizer: Optimizer, step: int = 0, **kwargs: Any):
        self.model = model
        self.optimizer = optimizer
        self.step = step
        self.metadata = kwargs

    def state_dict(self):
        return {
            "step": self.step,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "metadata": self.metadata,
        }

def train(Model, epochs=10, batch_size=128, lr=3e-4, wd=1e-5, **kwargs):
    # Initialize the model
    model = Model(**kwargs)  # by default, 2 output classes (sign of the difference)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Create a TrainState equivalent (using a class as defined previously)
    state = TrainState(model=model, optimizer=optimizer)

    # Define the loss function
    def loss_fn(params, x, y):
        # params are automatically handled by PyTorch within the model
        model.load_state_dict(params) # Load the parameters into the model

        logits = model(x)

        loss = nn.CrossEntropyLoss()(logits, y)

        return loss

    # Gradient function (using autograd)
    def train_step(state, x, y):
        # Enable gradient calculation
        state.model.train()

        loss = loss_fn(state.model.state_dict(), x, y)

        # Calculate gradients
        state.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(state.model.parameters(), max_norm=1.0)

        # Update parameters
        state.optimizer.step()

        return state, loss.item()

    # Report accuracy and loss
    def report(state, x, y):
        x = x.to(device)
        y = y.to(device)

        state.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            logits = state.model(x)
            predictions = torch.argmax(logits, dim=1)
            acc = torch.sum(predictions == y).item()
        return acc

    train_step_jit = torch.compile(train_step)

    # Create DataLoaders
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = TensorDataset(x_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Training accuracy (on a small subset only)
    num_eval_examples = min(3200, len(x_train))
    train_accuracy = 0
    model.eval()
    with torch.no_grad():
        for i in range(0, num_eval_examples, batch_size):
            x_batch = x_train[i : i + batch_size].to(device)
            y_batch = y_train[i : i + batch_size].to(device)
            train_accuracy += report(state, x_batch, y_batch)
    train_accuracy /= num_eval_examples

    # Test accuracy
    num_eval_examples = min(3200, len(x_test))
    test_accuracy = 0
    model.eval()
    with torch.no_grad():
        for i in range(0, num_eval_examples, batch_size):
            x_batch = x_test[i : i + batch_size].to(device)
            y_batch = y_test[i : i + batch_size].to(device)
            test_accuracy += report(state, x_batch, y_batch)
    test_accuracy /= num_eval_examples

    print("Before training:")
    print(f"train accuracy: {train_accuracy}, test accuracy: {test_accuracy}")

    # Begin the training loop
    print(f"epochs {epochs}")
    for epoch in tqdm.tqdm(range(epochs), desc="Epochs"):
      model.train()
      for x_batch, y_batch in tqdm.tqdm(train_loader, desc="Batches", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        state, loss = train_step_jit(state, x_batch, y_batch)

      # Training accuracy (computed once per epoch, after all batches)
      train_accuracy = 0
      model.eval()
      with torch.no_grad():
        for i in range(0, len(x_train), batch_size):
          x_batch = x_train[i : i + batch_size].to(device)
          y_batch = y_train[i : i + batch_size].to(device)
          train_accuracy += report(state, x_batch, y_batch)
      train_accuracy /= len(x_train)

      # Test accuracy
      test_accuracy = 0
      model.eval()
      with torch.no_grad():
        for i in range(0, len(x_test), batch_size):
          x_batch = x_test[i : i + batch_size].to(device)
          y_batch = y_test[i : i + batch_size].to(device)
          test_accuracy += report(state, x_batch, y_batch)
      test_accuracy /= len(x_test)

      print(f"Epoch: {epoch + 1}")
      print(f"train accuracy: {train_accuracy}, test accuracy: {test_accuracy}")

## initial model
Initially, we will only use a single self-attention layer and then apply a linear classifier on the CLS token. Later, we will improve the model step-by-step until we arrive at the standard transformer architecture.
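
For orientation, the sketch below implements scaled dot-product attention as a standalone function, following the formulation $\mathrm{softmax}(QK^\top / \sqrt{d})\,V$ from the Transformer paper linked at the top. Whether the exercise below intends the $\sqrt{d}$ scaling is an assumption, and the function name and random inputs are illustrative only.

```python
import math
import torch


def scaled_dot_product_attention(query, key, value):
  """softmax(Q K^T / sqrt(d)) V for inputs of shape (batch, seq_len, d)."""
  d = query.shape[-1]
  scores = query @ key.transpose(-2, -1) / math.sqrt(d)  # (batch, seq, seq)
  weights = torch.softmax(scores, dim=-1)                # each row sums to 1
  return weights @ value, weights


torch.manual_seed(0)
q = torch.randn(2, 5, 16)
k = torch.randn(2, 5, 16)
v = torch.randn(2, 5, 16)
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)
```

Each output position is a weighted average of the value vectors, with weights determined by how well that position's query matches every key.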

In [None]:
# Define the self-attention layer
class SelfAttention(nn.Module):

  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim

    # Calculate query, key, and value matrices using linear layers
    self.query = nn.Linear(embed_dim, embed_dim)
    self.key = nn.Linear(embed_dim, embed_dim)
    self.value = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    # --- YOUR CODE START ---
    query = self.query(x)
    key = self.key(x)
    value = self.value(x)
    # attention_scores =...  # calculate the attention scores
    # attention_weights =...  # apply the softmax function
    # output =...  # apply the attention weights
    # --- YOUR CODE END ---
    return output


# Define the model
class SimpleTransformer(nn.Module):

  def __init__(
      self, vocab_size, embed_dim=128, num_classes=2, max_seq_length=None
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length  # Not used in this specific code, but included for consistency

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Self-attention layer
    x = self.self_attention(x)

    # Extract the CLS token (the last token)
    cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

    # Linear classifier
    logits = self.classifier(cls_token)
    return logits

Feel free to change the embedding dimension and observe how it impacts the test accuracy. Note how the model cannot do better than random guessing. Can you think of why?

In [None]:
kwg = dict(
    embed_dim=64,
    vocab_size=vocab_size,
)
train(SimpleTransformer, **kwg)

## MLP layers

In [None]:
class SimpleTransformer(nn.Module):

  def __init__(
      self,
      vocab_size,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
      max_seq_length=None,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length  # Not used in this specific code, but included for consistency

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

  def forward(self, x):
    logits = None
    # --- YOUR CODE START ---
    # x =...  # embedding layer
    # x =...  # self-attention layer
    # x =...  # mlp layer
    # cls_token =...  # extract the CLS token
    # logits =...  # linear classifier
    # --- YOUR CODE END ---
    return logits

In [None]:
kwg = dict(
    embed_dim=64,
    mlp_dim=64 * 4,
    vocab_size=vocab_size,
)
train(SimpleTransformer, **kwg)

## positional embeddings
One major issue with our model above is that it cannot see the position of each token. For example, the following two expressions contain the same tokens in a different order, so they look alike to the model:
- "123 - 4"
- "12 - 34"

We will use learned position embeddings below.
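
The effect can be checked numerically. The sketch below uses a stripped-down attention with no learned projections (a simplification made for this demonstration) and reads out the last position. Without position embeddings, reordering the other tokens leaves the readout unchanged; with a learned position embedding added, it changes. All names, sizes, and token ids are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim, seq_len = 12, 8, 6
tok_emb = nn.Embedding(vocab_size, embed_dim)
pos_emb = nn.Embedding(seq_len, embed_dim)  # one learned vector per position


def last_token_readout(x):
  """Projection-free self-attention, read out at the last (CLS) position."""
  scores = x @ x.transpose(-2, -1) / embed_dim ** 0.5
  return (torch.softmax(scores, dim=-1) @ x)[:, -1, :]


# Same tokens, different order; the final 0 plays the role of the CLS/pad token.
a = torch.tensor([[1, 2, 3, 4, 5, 0]])
b = torch.tensor([[3, 1, 2, 5, 4, 0]])
positions = torch.arange(seq_len)

with torch.no_grad():
  same_without_pos = torch.allclose(
      last_token_readout(tok_emb(a)), last_token_readout(tok_emb(b)), atol=1e-5
  )
  same_with_pos = torch.allclose(
      last_token_readout(tok_emb(a) + pos_emb(positions)),
      last_token_readout(tok_emb(b) + pos_emb(positions)),
      atol=1e-5,
  )

print(same_without_pos, same_with_pos)
```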

In [None]:
class SimpleTransformer(nn.Module):

  def __init__(
      self,
      vocab_size,
      max_seq_length,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    # --- YOUR CODE START ---
    # positions =...  # generate position indices
    # pos_embeddings =...  # get position embeddings
    # x =...  # add scaled position embeddings to input embeddings
    # --- YOUR CODE END ---

    # Self-attention layer
    x = self.self_attention(x)

    # MLP layer
    x = self.mlp(x)

    # Extract the CLS token (the last token)
    cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

    # Linear classifier
    logits = self.classifier(cls_token)
    return logits

In [None]:
kwg = dict(
    embed_dim=64,
    mlp_dim=64 * 4,
    vocab_size=vocab_size,
    max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

## normalization layer

In [None]:
class SimpleTransformer(nn.Module):

  def __init__(
      self,
      vocab_size,
      max_seq_length,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

    # Layer Normalization (Specify normalized_shape)
    self.layer_norm1 = nn.LayerNorm(embed_dim)
    self.layer_norm2 = nn.LayerNorm(embed_dim)

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(
        self.embed_dim
    )  # Scale positional embeddings

    # Self-attention layer
    x = self.self_attention(x)

    # We add layer norm
    # --- YOUR CODE START ---
    # x =...  # apply layer normalization
    # --- YOUR CODE END ---

    # MLP layer
    x = self.mlp(x)

    # We add layer norm
    # --- YOUR CODE START ---
    # x =...  # apply layer normalization
    # --- YOUR CODE END ---

    # Extract the CLS token (the last token)
    cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

    # Linear classifier
    logits = self.classifier(cls_token)
    return logits

In [None]:
kwg = dict(
    embed_dim=64,
    mlp_dim=64 * 4,
    vocab_size=vocab_size,
    max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

## deeper architectures

In order to make the code more concise and readable, we define the following TransformerEncoderBlock.

In [None]:
class TransformerEncoderBlock(nn.Module):

  def __init__(self, embed_dim, mlp_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Layer normalization
    self.layer_norm1 = nn.LayerNorm(embed_dim)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

    # Layer normalization
    self.layer_norm2 = nn.LayerNorm(embed_dim)

  def forward(self, x):
    x = self.self_attention(x)
    x = self.layer_norm1(x)
    # MLP layer
    x = self.mlp(x)
    # Layer norm
    x = self.layer_norm2(x)

    return x

In [None]:
class SimpleTransformer(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

    # Transformer Encoder Blocks
    self.transformer_blocks = nn.ModuleList(
        [TransformerEncoderBlock(embed_dim, mlp_dim) for _ in range(num_layers)]
    )

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(
        self.embed_dim
    )  # Scale positional embeddings

    # Stack multiple transformer encoder blocks
    # --- YOUR CODE START ---
    # for... in...:  # iterate over transformer blocks
    #   x =...  # apply the block
    # --- YOUR CODE END ---

    # Extract the CLS token (the last token)
    cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

    # Linear classifier
    logits = self.classifier(cls_token)
    return logits

In [None]:
kwg = dict(
    embed_dim=64,
    mlp_dim=64 * 4,
    vocab_size=vocab_size,
    max_seq_length=max_len,
    num_layers=3,
)
train(SimpleTransformer, **kwg)

## skip connections

In [None]:
class TransformerEncoderBlock(nn.Module):

  def __init__(self, embed_dim, mlp_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Layer normalization
    self.layer_norm1 = nn.LayerNorm(embed_dim)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

    # Layer normalization
    self.layer_norm2 = nn.LayerNorm(embed_dim)

  def forward(self, x):
    x = self.self_attention(x)
    x = self.layer_norm1(x)
    y = self.mlp(x)

    # We introduce a skip connection
    # --- YOUR CODE START ---
    #   x =...  # skip connections
    # --- YOUR CODE END ---

    # Layer norm
    x = self.layer_norm2(x)

    return x

The `SimpleTransformer` is unchanged; we simply redefine it so that it picks up the new `TransformerEncoderBlock`.

In [None]:
class SimpleTransformer(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Linear classifier
    self.classifier = nn.Linear(embed_dim, num_classes)

    # Transformer Encoder Blocks
    self.transformer_blocks = nn.ModuleList(
        [TransformerEncoderBlock(embed_dim, mlp_dim) for _ in range(num_layers)]
    )

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(
        self.embed_dim
    )  # Scale positional embeddings

    # Stack multiple transformer encoder blocks
    for block in self.transformer_blocks:
      x = block(x)

    # Extract the CLS token (the last token)
    cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

    # Linear classifier
    logits = self.classifier(cls_token)
    return logits

In [None]:
kwg = dict(
    embed_dim=64,
    mlp_dim=64 * 4,
    vocab_size=vocab_size,
    max_seq_length=max_len,
    num_layers=3,
)
train(SimpleTransformer, **kwg)

# Part IV: Transformer Architecture - Sequence generation task

In the previous part, the task was to predict a label.
We added a classification head on top of the last output of the transformer sequence to predict it.

In this part, we will train a transformer on a sequence generation task.
In order to do so, we will apply several key changes:
- Cross attention module
- Causal masking
- Encoder-Decoder architecture
- Autoregressive sampling
- A per-token cross-entropy loss averaged over the sequence (its exponential is the perplexity), instead of a single-label cross-entropy loss

We'll go through each point in the following sections :)
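
On the last point: perplexity is the exponential of the average per-token cross-entropy, so minimizing one also minimizes the other. Here is a tiny numeric sketch, using made-up probabilities that a model might assign to the correct token at three positions:

```python
import math

# Hypothetical probabilities assigned to the correct token at each position.
token_probs = [0.5, 0.25, 0.125]

# Average negative log-likelihood per token (cross-entropy, in nats).
cross_entropy = -sum(math.log(p) for p in token_probs) / len(token_probs)

# Perplexity is the exponential of that average.
perplexity = math.exp(cross_entropy)
print(cross_entropy, perplexity)
```

A perplexity of `k` can be read as the model being, on average, as uncertain as if it were choosing uniformly among `k` tokens.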

## Dataset - integer addition

We'll tackle the problem of **integer addition** using a Transformer model.  Instead of calculating the entire sum at once, we'll train our model to generate the answer **auto-regressively**, predicting one digit at a time, similar to how we learn to do addition manually.

We will represent the numbers in a **reversed** manner. For instance, the number 123 will be represented as "321".  We'll also pad our input and output sequences with zeros to ensure they have a uniform length and add a special "#" character to the output to signal the end of the sum.

The following `AdditionTask` class is responsible for generating our training and eval data, which consists of pairs of reversed addition problems and their corresponding non-reversed solutions.
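
To make the format concrete, here is a minimal sketch for a single problem. The helper `format_problem` is hypothetical (it is not part of the notebook) and only mirrors the string formatting described above:

```python
def format_problem(n: int, m: int, length: int) -> tuple[str, str]:
    """Formats one addition problem (hypothetical helper, for illustration)."""
    # Reverse both operands and join them with "+".
    expression = f"{str(n)[::-1]}+{str(m)[::-1]}"
    # Pad the expression with zeros up to `length` characters.
    expression = expression + "0" * (length - len(expression))

    # The result is not reversed; it is followed by "#" and zero padding.
    result = str(n + m)
    result = result + "#" + "0" * (length - len(result))
    return expression, result


expression, result = format_problem(123, 45, length=9)
print(expression, result)  # 321+54000 168#000000
```

Note that the output is always one character longer than the input, because of the "#" token.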

Let's dive into the code!

In [None]:
# @markdown <font color='blue'>Double click here to see how the dataset is constructed</font>


class AdditionTask:
  """Returns a batch of additions and their results.

  This function generates a batch of addition problems where each problem
  consists of two randomly generated numbers and their sum. The numbers are
  represented as strings in reversed order (e.g., '321' for 123) and are
  padded with zeros to ensure uniform length. The result also includes a special
  "#" character as a separator.

  Args:
    batch_size: The number of addition problems to generate.
    length: The maximum length of the input sequence (including the "+" symbol
      and padding).

  Returns:
    A dictionary containing two keys:
      - "input": A list of strings, where each string is a padded addition
      problem in reversed order (e.g., with length=9, "123+45" becomes
      "321+54000").
      - "output": A list of strings, where each string is the (non-reversed)
      result of the corresponding addition problem, followed by "#" and zero
      padding (e.g., "168#000000").

  Raises:
    ValueError: If the provided length is less than or equal to 2.
  """

  def sample_batch(self, batch_size: int, length: int):
    """Returns a batch of additions and their results."""
    if length <= 2:
      raise ValueError("Length must be greater than 2.")

    # We only use `length - 1` tokens for the two values to account for the `+`.
    # Generate random lengths for the two numbers in each addition problem.
    length_n = np.random.randint(1, length - 1, size=(batch_size,))
    length_m = length - 1 - length_n

    # Generate random integers based on the calculated lengths.
    integer_n = [random.randint(1, 10 ** int(len_n) - 1) for len_n in length_n]
    integer_m = [random.randint(1, 10 ** int(len_m) - 1) for len_m in length_m]
    # Calculate the sum of the generated integers.
    integer_sum = list(map(sum, zip(integer_n, integer_m)))

    # Convert integers to reversed strings (e.g., 123 becomes "321").
    knary_n = [str(x)[::-1] for x in integer_n]
    knary_m = [str(x)[::-1] for x in integer_m]

    # Create the addition expressions by concatenating the reversed strings with a "+".
    expressions = [f"{a}+{b}" for a, b in zip(knary_n, knary_m)]

    # Pad the expressions with zeros to reach the desired length.
    expressions = [a + "".join(["0"] * (length - len(a))) for a in expressions]

    # Convert the sums to strings.
    results = list(map(str, integer_sum))
    # Append "#" to the results and pad with zeros.
    results = [
        res + "#" + "".join(["0"] * (length - len(res))) for res in results
    ]
    return {
        "input": expressions,
        "output": results,
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models."""
    return 12

  @property
  def output_size(self) -> int:
    """Returns the output size for the models."""
    return 12

  @property
  def vocab_size(self) -> int:
    """Returns the vocabulary size (digits 0-9, "+" and "#")."""
    return 12

  def output_length(self, input_length: int) -> int:
    return input_length + 1

In [None]:
# Instantiate an AdditionTask object. This object will handle data generation for our addition task.
task = AdditionTask()

# Define the maximum length of the addition sequence (including digits, '+', and padding).
# We can control the difficulty of the task by changing this value: larger values correspond to more complex addition problems, potentially requiring more steps to solve.
# Note that the length must be greater than 2 to accommodate at least one digit for each number plus the '+' symbol.
MAX_TRAIN_LENGTH = 10
MAX_TEST_LENGTH = 20

# Generate a sample batch of addition problems.
data = task.sample_batch(batch_size=16, length=MAX_TRAIN_LENGTH)

# Let's visualize the generated data using a Pandas DataFrame for better readability.
# Each row represents a single addition problem.
# The 'input' column shows the addition problem in reversed and padded format (e.g., "6+54058925").
# The 'output' column shows the corresponding result, not reversed, padded, and with the '#' symbol (e.g., "52985051#00").
pd.DataFrame(data)

The token "#" is the EOS (end-of-sequence) token. All the 0 characters after it are padding, added for shape consistency during training and eval.

We also fit a tokenizer and check the tokens it assigned.

In [None]:
# Define and fit the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
    num_words=None,
    oov_token=None,
    char_level=True,
)
tokenizer.fit_on_texts(data["input"] + data["output"])
tokenizer.word_index

We also define a preprocessing function, that we'll be using in the train and eval loops.

In [None]:
def preprocess_data(batch, tokenizer):
  """Tokenizes and pads the input and output sequences for the model.

  Args:
    batch: A dictionary containing the input and output sequences as lists of
      strings.
    tokenizer: A fitted character-level tokenizer exposing
      `texts_to_sequences` (here, a Keras `Tokenizer`).

  Returns:
    A dictionary containing the processed input and output sequences as PyTorch
    tensors.
  """
  # Tokenize the input sequences using the provided tokenizer.
  # This converts each string into a sequence of integer indices. Keras indices
  # start at 1, so we subtract 1 to obtain 0-based ids.
  tokens_input = (
      torch.tensor(
          tokenizer.texts_to_sequences(batch["input"]), dtype=torch.long
      )
      - 1
  )

  # Tokenize the output sequences.
  tokens_output = (
      torch.tensor(
          tokenizer.texts_to_sequences(batch["output"]), dtype=torch.long
      )
      - 1
  )

  # Pad the sequences to the maximum length within the batch for consistent tensor shapes.
  tokens_input = torch.nn.utils.rnn.pad_sequence(tokens_input, batch_first=True)
  tokens_output = torch.nn.utils.rnn.pad_sequence(
      tokens_output, batch_first=True
  )

  # Return the processed data as a dictionary.
  return dict(input=tokens_input, output=tokens_output)

In [None]:
preprocess_data(data, tokenizer)
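
The `- 1` in `preprocess_data` is there because the Keras tokenizer numbers its tokens starting from 1, while `nn.Embedding` expects ids starting from 0. The sketch below illustrates the shift with a hand-written, hypothetical vocabulary; the fitted tokenizer learns its own mapping from the data, so the actual ids will differ.

```python
# Hypothetical 1-based mapping, standing in for `tokenizer.word_index`.
char_to_keras_index = {ch: i + 1 for i, ch in enumerate("0123456789+#")}


def encode(text: str) -> list[int]:
  """Maps each character to a 0-based id by shifting the 1-based index."""
  return [char_to_keras_index[ch] - 1 for ch in text]


ids = encode("321+54000")
print(ids)  # [3, 2, 1, 10, 5, 4, 0, 0, 0]
```

With 12 distinct characters, the shifted ids cover exactly the range 0 to 11, which matches a vocabulary size of 12.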

## Loss and accuracy functions

**Connecting to Previous Concepts and Explaining the Loss Function**

Recall that in Part III, we dealt with a single-label classification problem.  For that scenario, we used a standard classification loss, as our goal was to predict a single, correct label.

In this task, however, we are generating an entire sequence of tokens.  We can think of this as a series of individual classification problems, one for each position in the sequence. Each element in the output sequence requires us to predict the correct token from our vocabulary (digits 0-9, "+", and "#"). Therefore for each position, the model outputs a probability distribution over all possible tokens, and we need to compare it with the true distribution (one-hot encoded).

This is where the `loss_fn` defined below comes in. It computes the cross-entropy loss between the predicted token distribution and the true one-hot encoded token *for each position* in the sequence: the pointwise loss is summed over the vocabulary dimension (`dim=-1`), then averaged over the batch and sequence dimensions to get a single scalar loss.
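
Written out, and assuming the model outputs logits, the quantity computed by `loss_fn` can be sketched as follows. The symbols are notation introduced here for illustration: $B$ is the batch size, $T$ the sequence length, $C$ the number of classes, $z_{b,t}$ the logits at position $t$ of sequence $b$, and $y_{b,t}$ the one-hot target.

$$
\mathcal{L} = \frac{1}{B\,T} \sum_{b=1}^{B} \sum_{t=1}^{T} \left( - \sum_{c=1}^{C} y_{b,t,c} \, \log \operatorname{softmax}(z_{b,t})_c \right)
$$

The inner sum over $c$ is the per-token cross-entropy; because $y_{b,t}$ is one-hot, it reduces to the negative log-probability of the correct token.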

### Loss definitions

In [None]:
def _pointwise_loss_fn(
    output: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
  """Calculates the pointwise cross-entropy loss between predicted probabilities and the true target values.

  This function computes the loss for each token in the sequence individually.
  """
  pass
  # --- YOUR CODE START ---
  # target_one_hot =...  # convert target to one-hot encoding
  # return...  # calculate the cross-entropy loss
  # --- YOUR CODE END ---


# Create the loss and accuracy based on the pointwise ones.
def loss_fn(output, target):
  """Returns the loss between an output and a target, averaged across tokens."""
  # In the previous part we didn't need to average over tokens, since we
  # predicted a single label per sequence. Here we sum over the last dimension
  # (num_classes) and then average over all other dimensions (batch_size and
  # sequence_length).
  return torch.mean(torch.sum(_pointwise_loss_fn(output, target), dim=-1))

### Toy example

Let's define a small toy example to see how the function works.

In [None]:
batch_size = 2
sequence_length = 3
num_classes = 4  # Let's say we have 4 possible tokens

# 1. Dummy Model Output (logits - before softmax):
#    Imagine the model predicts the following logits for each token in each sequence:
output = torch.tensor(
    [
        [
            [1.0, 2.0, 3.0, 0.5],
            [0.1, 0.5, 1.5, 2.0],
            [2.5, 1.0, 0.2, 0.1],
        ],  # Sequence 1
        [
            [0.2, 0.3, 0.5, 0.1],
            [1.0, 1.5, 2.0, 0.5],
            [0.1, 0.2, 0.5, 2.5],
        ],  # Sequence 2
    ],
    dtype=torch.float32,
)

# 2. Dummy Target :
target = torch.tensor(
    [
        [0, 3, 2],
        [1, 0, 0],
    ],
    dtype=torch.long,
)

# 3. We calculate pointwise loss
pointwise_loss = _pointwise_loss_fn(output, target)
print("Pointwise Loss:\n", pointwise_loss)

# Note: with the one-hot formulation, pointwise_loss has one entry per
# (batch, position, class). loss_fn sums over the class dimension and then
# averages over batch and sequence positions, as opposed to Part III where
# there was a single prediction per sequence.
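
Once your own `_pointwise_loss_fn` is filled in, one way to sanity-check it is to compare `loss_fn(output, target)` against PyTorch's built-in cross-entropy. The sketch below only computes the reference value; it assumes logits of shape `(batch, sequence, classes)` and uses random values rather than the toy tensors above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

output = torch.randn(2, 3, 4)  # (batch, sequence, classes) logits
target = torch.tensor([[0, 3, 2], [1, 0, 0]])  # (batch, sequence) token ids

# F.cross_entropy expects the class dimension in second position, so we
# move it there: (batch, classes, sequence).
per_token_loss = F.cross_entropy(
    output.transpose(1, 2), target, reduction="none"
)  # shape: (batch, sequence)

# Averaging over batch and sequence gives a single scalar.
reference_loss = per_token_loss.mean()
print(per_token_loss.shape, reference_loss.item())
```

If the two values agree up to floating-point error, the summation and averaging dimensions in your implementation are most likely correct.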

**Defining Accuracy for Sequence Prediction**

Building on our understanding of loss calculation for sequence generation, we now define an accuracy function tailored for this task. As we're predicting a sequence of tokens, we need to consider the correctness of each predicted token within the sequence.

The `accuracy_fn` defined in the cell below does precisely this. It checks, for each token position, whether the most likely predicted token matches the target, and then averages over all positions and sequences. Note that this simple version applies no mask: the zero-padding tokens after the "#" EOS token count towards the accuracy like any other token.

In essence, we are evaluating the model's performance on a token-by-token basis, which gives us a more granular measure of how well the model is learning to generate the correct addition results than a single per-sequence score.

In [None]:
def _accuracy_fn(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  """Returns the accuracy between an output and a target."""
  return (torch.argmax(output, dim=-1) == target).float()


def accuracy_fn(output, target):
  acc = _accuracy_fn(output, target)
  return torch.mean(acc)
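
The functions above average over every position, padding included. If you want the padding to be ignored, one possible variant is sketched below. It is not part of the notebook: `masked_accuracy` and `eos_id` are hypothetical names, and the sketch assumes that everything after the first EOS token in the target is padding.

```python
import torch


def masked_accuracy(output, target, eos_id):
  """Accuracy over positions up to and including the first EOS (sketch)."""
  correct = (torch.argmax(output, dim=-1) == target).float()

  # Count how many EOS tokens appear strictly before each position.
  is_eos = (target == eos_id).long()
  eos_seen_before = torch.cumsum(is_eos, dim=-1) - is_eos

  # Keep only the positions that come before or at the first EOS.
  mask = (eos_seen_before == 0).float()
  return (correct * mask).sum() / mask.sum()


eos_id = 3  # hypothetical id of the "#" token
target = torch.tensor([[1, 2, 3, 0, 0]])
predicted = torch.tensor([[1, 2, 3, 1, 1]])  # wrong only on the padding
output = torch.nn.functional.one_hot(predicted, num_classes=4).float()

plain = (torch.argmax(output, dim=-1) == target).float().mean()
masked = masked_accuracy(output, target, eos_id)
print(plain.item(), masked.item())
```

In this toy case the unmasked accuracy is penalized for mistakes on the padding, while the masked version only scores the digits of the answer and the EOS token.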

Similarly to Part III, we define an update function using the losses we just defined.

In [None]:
def _apply_loss_and_metrics_fn(
    params: dict,
    batch: dict[str, torch.Tensor],
    model: torch.nn.Module,
):
  """Computes the model output and applies the loss function.

  Args:
    params: The model parameters (typically the state_dict of a
      torch.nn.Module).
    batch: The data (consists of both inputs and outputs).
    model: The PyTorch model.

  Returns:
    The loss of the model for the batch of data, extra loss metrics and the
    accuracy.
  """
  # Load parameters into the model
  model.load_state_dict(params)
  model.train()  # Put the model into training mode

  outputs = model(inputs=batch["input"], targets=batch["output"], sample=False)

  loss = loss_fn(outputs, batch["output"])
  accuracy = accuracy_fn(outputs, batch["output"])
  return loss, (accuracy)


def _update_parameters(
    params: dict,
    batch: dict[str, torch.Tensor],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    opt_state: Any,
) -> tuple[Any, Any, tuple[float, float]]:
  """Applies a single optimizer update step to the model parameters.

  Args:
    params: The model parameters (typically the state_dict of a
      torch.nn.Module).
    batch: The data (consists of both inputs and outputs).
    model: The PyTorch model.
    optimizer: The optimizer that computes the updates from the gradients of the
      `loss_fn` with respect to the `params` and the previous `opt_state`.
    opt_state: The optimizer state, e.g., momentum for each variable when using
      Adam. Not used in this case, but kept for consistency with original code.

  Returns:
    The updated parameters, the new optimizer state, and the loss, loss metrics
    and accuracy.
  """

  # Reset gradients
  optimizer.zero_grad()

  # Compute loss and gradients
  loss, (accuracy) = _apply_loss_and_metrics_fn(params, batch, model)
  loss.backward()

  # Update parameters
  optimizer.step()

  # Get updated model parameters
  new_params = model.state_dict()

  # Return the optimizer state for API consistency. PyTorch optimizers keep their
  # state internally, so `opt_state` is not needed to perform the update.
  new_opt_state = optimizer.state_dict()

  return new_params, new_opt_state, (loss.item(), accuracy.item())

## Training loop

In order to focus on the architectural changes in the next part, we define a standard training loop, which we will be using across the cells.

In [None]:
# @markdown <font color="blue">Double click here to see the full training loop</font>
def run_training(
    *,
    task: Any,
    model: torch.nn.Module,
    max_sequence_length: int,
    train_steps: int = 10_000,
    seed: int = 0,  # Used to sample during forward pass (e.g. from final logits).
    model_init_seed: int = 0,  # Used to initialize model parameters.
    log_frequency: int = 50,
    batch_size: int = 128,
    learning_rate: float = 1e-3,
    max_grad_norm: float = 1.0,
) -> Tuple[pd.DataFrame, dict, Any]:
  """Trains the model with the provided config."""

  # Fix the seeds
  random.seed(seed)
  np.random.seed(seed)
  # torch.manual_seed(seed)
  # torch.cuda.manual_seed_all(seed)

  # Use GPU if available
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)

  # Define the optimizer
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  # Sample a batch to fit the tokenizer
  dummy_batch = task.sample_batch(
      length=max_sequence_length,
      batch_size=256,
  )

  # Define and fit the tokenizer
  tokenizer = tf.keras.preprocessing.text.Tokenizer(
      num_words=None,
      oov_token=None,
      char_level=True,
  )
  tokenizer.fit_on_texts(dummy_batch["input"] + dummy_batch["output"])

  print(f"The tokenizer index is: {tokenizer.word_index}")

  params = model.state_dict()
  opt_state = optimizer.state_dict()

  results = []
  for step in tqdm.tqdm(range(train_steps + 1)):
    # Randomness handled by either python.random or numpy.
    length = random.choice(list(range(3, max_sequence_length + 1)))

    # Randomness handled by either torch, python.random or numpy.
    train_batch = task.sample_batch(length=length, batch_size=batch_size)
    train_batch = preprocess_data(train_batch, tokenizer)

    # Move batch to device
    train_batch["input"] = train_batch["input"].to(device)
    train_batch["output"] = train_batch["output"].to(device)

    # Update the parameters.
    params, opt_state, (train_loss, train_accuracy) = _update_parameters(
        params=params,
        batch=train_batch,
        model=model,
        optimizer=optimizer,
        opt_state=opt_state,
    )

    # Log the training metrics
    if (log_frequency > 0) and (step % log_frequency == 0):
      log_data = {
          "step": step,
          "train_loss": float(train_loss),
          "train_accuracy": float(train_accuracy),
      }
      print(log_data)
      results.append(log_data)

  df_results = pd.DataFrame(results)
  return df_results, params, tokenizer
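
Note that `run_training` accepts a `max_grad_norm` argument, but the loop as written never uses it. If you want to experiment with gradient clipping, one possible way is to clip between `loss.backward()` and `optimizer.step()`. The snippet below is a standalone sketch on a stand-in linear model, not the notebook's Transformer.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

model = nn.Linear(4, 2)  # stand-in model, only for illustration
optimizer = optim.Adam(model.parameters(), lr=1e-3)
max_grad_norm = 1.0

inputs = torch.randn(8, 4) * 10.0
labels = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(inputs), labels)
loss.backward()

# Rescale the gradients in place so that their global norm is at most
# `max_grad_norm`. The returned value is the norm before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=max_grad_norm
)
norm_after = torch.sqrt(sum((p.grad**2).sum() for p in model.parameters()))

optimizer.step()
print(norm_before.item(), norm_after.item())
```

Clipping leaves small gradients untouched and only rescales them when their global norm exceeds the threshold.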

## Eval function

Next, we define the eval loop function. We'll run this function after each model change; it stays unchanged across the rest of the colab.

In [None]:
# @markdown <font color="blue">Double click here to see the full eval code</font>


def run_evaluation(
    *,
    model: torch.nn.Module,
    params: Any,
    tokenizer: Any,
    task: Any,
    max_test_length: int = 20,  # The largest sequence length to evaluate on
    total_batch_size: int = 512,
    sub_batch_size: int = 64,  # We use this to avoid memory overflow.
    seed: int = 1,
    is_autoregressive: bool = False,
) -> pd.DataFrame:
  """Evaluates the model on addition problems of various lengths and logs the results.

  This function tests the model's ability to generalize to sequences longer than
  those seen during training.
  It generates batches of addition problems with increasing lengths, evaluates
  the model's accuracy on each batch,
  and returns a Pandas DataFrame containing the accuracies for each length.

  Args:
    model: The PyTorch model to evaluate.
    params: The trained model parameters (a state_dict).
    tokenizer: The Tokenizer used to convert between text and token IDs.
    task: An instance of the AdditionTask class, used for generating data.
    max_test_length: The maximum length of sequences to evaluate on.
    total_batch_size: The total number of examples to evaluate for each length.
    sub_batch_size: The size of each sub-batch used during evaluation (to avoid
      memory issues).
    seed: The random seed for reproducibility.
    is_autoregressive: A boolean indicating whether the model is autoregressive
      (if True, it expects an additional 'sample' argument during inference).

  Returns:
    A pandas DataFrame with columns 'length' and 'accuracy', where each row
    represents the model's accuracy on sequences of a given length.
  """

  # Fix the random seed for reproducibility.
  random.seed(seed)
  np.random.seed(seed)
  # torch.manual_seed(seed)
  # torch.cuda.manual_seed_all(seed)

  # Use GPU if available
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Load parameters into the model
  model.load_state_dict(params)
  model.to(device)
  model.eval()  # Set the model to evaluation mode

  results = []
  lengths = range(3, max_test_length + 1)

  # Iterate over different sequence lengths
  for length in tqdm.tqdm(lengths, desc="Lengths"):
    sub_accuracies = []

    # Evaluate on multiple sub-batches to avoid memory overflow.
    for _ in range(total_batch_size // sub_batch_size):
      # Generate a batch of addition problems with the current length.
      batch = task.sample_batch(sub_batch_size, length)
      batch = preprocess_data(batch, tokenizer)

      # Move batch to device
      batch["input"] = batch["input"].to(device)
      batch["output"] = batch["output"].to(device)

      # Run the model to get predictions.
      with torch.no_grad():
        outputs = model(
            inputs=batch["input"],
            targets=batch["output"],
            sample=is_autoregressive,
        )

        # Calculate the accuracy for the current sub-batch.
        sub_accuracies.append(
            float(accuracy_fn(outputs, batch["output"]).cpu().numpy())
        )

    # Calculate the average accuracy for the current length.
    log_data = {
        "length": length,
        "accuracy": np.mean(sub_accuracies),
    }
    print(log_data)
    results.append(log_data)

  # Return the results as a pandas DataFrame.
  return pd.DataFrame(results)

## Architecture - Encoder

In Part III, we explored an encoder-only Transformer architecture. Recall that such a configuration is particularly well-suited for tasks like classification, where the primary goal is to map an input sequence to a single output label or category.

The implementation below is slightly different than the one above, but the core elements are identical.

We redefine the Transformer Encoder module, similarly to Part III,
with the following differences:
- We don't add positional embeddings immediately
- We remove the output classifier and return the full sequence of embeddings.

In [None]:
class BaseTransformerEncoder(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Transformer Encoder Blocks
    self.transformer_blocks = nn.ModuleList(
        [TransformerEncoderBlock(embed_dim, mlp_dim) for _ in range(num_layers)]
    )

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(
        self.embed_dim
    )  # Scale positional embeddings

    # Stack multiple transformer encoder blocks
    for block in self.transformer_blocks:
      x = block(x)
    return x

## Architecture - Cross attention

**Introducing Cross-Attention: Bridging the Gap Between Encoder and Decoder**

In the previous sections, we explored the self-attention mechanism, a powerful tool that allows a model to weigh the importance of different parts of an input sequence when processing it.  Self-attention operates within a single sequence, where the query, key, and value vectors are all derived from the same input.

Now, we introduce a generalization of this concept: **cross-attention**. While conceptually similar to self-attention, cross-attention enables the model to attend to information from a *different* source sequence. This opens up exciting possibilities for tasks that require relating information between two distinct sequences, such as the sequence-to-sequence task that the transformer was originally designed for.

In the context of the Transformer architecture, cross-attention serves as the crucial bridge between the encoder and the decoder. The decoder uses cross-attention to focus on relevant parts of the encoded input sequence (produced by the encoder) while generating the output sequence. This allows the model to effectively translate, summarize, or otherwise transform the input based on the learned relationships between the two sequences. In the following code, we implement a `CrossAttention` class. The main difference from self-attention is that the keys and values are computed from a different sequence than the queries.
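
As a compact summary, the computation can be written as follows. The notation is introduced here for illustration: $X_q$ stands for `inputs_q` (length $T_q$), $X_{kv}$ for `inputs_kv` (length $T_{kv}$), $d$ for `embed_dim`, and $W_Q$, $W_K$, $W_V$ for the weights of the three linear layers (bias terms omitted).

$$
Q = X_q W_Q, \qquad K = X_{kv} W_K, \qquad V = X_{kv} W_V
$$

$$
\operatorname{CrossAttention}(X_q, X_{kv}) = \operatorname{softmax}\!\left( \frac{Q K^\top}{\sqrt{d}} \right) V
$$

The softmax is taken over the key dimension, so the attention weights have shape $(T_q, T_{kv})$ per example and the output has the same length as the query sequence. Self-attention is the special case $X_q = X_{kv}$.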

In [None]:
# Define the cross-attention layer
class CrossAttention(nn.Module):

  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim

    # Define the linear layers that produce the query, key, and value matrices
    # --- YOUR CODE START ---
    # self.query = ...
    # self.key = ...
    # self.value = ...
    # --- YOUR CODE END ---

  def forward(
      self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor
  ) -> torch.Tensor:
    # Calculate query, key, and value matrices

    # --- YOUR CODE START ---
    # query = self.query(...)
    # key = self.key(...)
    # value = self.value(...)

    # Calculate attention scores (scaled dot-product attention)

    # attention_scores = torch.matmul(..., ...) / math.sqrt(self.embed_dim)
    # attention_weights = nn.functional.softmax(attention_scores, dim=...)

    # Apply attention weights to values
    # output = torch.matmul(..., ...)
    # --- YOUR CODE END ---

    return output

## Architecture - Decoder

### Motivation

**Introducing the Decoder: Generating Output Sequences**

In Part III, our focus was on classification, where the model's objective was to predict a single label for a given input sequence.  In this section, we'll be tackling a more complex task: **sequence generation**. This requires us to introduce the **Decoder** component of the Transformer architecture.

The Decoder's role is to take the encoded representation of the input sequence (generated by the Encoder) and generate an output sequence, one token at a time. While its architecture shares many similarities with the Encoder, there are some crucial distinctions that enable it to perform this generative task:

**Key Differences Between the Decoder and Encoder:**

1.  **Cross-Attention:**  Perhaps the most significant difference is the inclusion of **cross-attention** layers. In each decoder layer, cross-attention allows the decoder to attend to the *encoder's* output. This enables the decoder to incorporate information from the input sequence when generating the output. The decoder's embeddings are used as queries, while the encoder's embeddings are used as keys and values.
2.  **Shifted Right Input:** During training, the decoder's input is the target sequence, *shifted one position to the right* and prepended with a special "start" token. This is done to facilitate **teacher forcing**, where the model is trained to predict the next token given the *previous* ground-truth tokens. At inference time the model generates the sequence auto-regressively, so training it on inputs of the same form (a prefix of earlier tokens) keeps the training and inference input distributions similar.
3.  **Causal Masking:** To prevent the decoder from "cheating" during training and peeking at future tokens, a **causal mask** (also known as a look-ahead mask) is applied during self-attention within the decoder. This mask ensures that the model can only attend to tokens that have already been generated, effectively simulating the auto-regressive nature of sequence generation during inference.
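
To make the query/key/value roles in cross-attention concrete, here is a minimal shape-only sketch. It uses random tensors in place of real encoder and decoder embeddings and deliberately omits the learned linear projections, so it is not the `CrossAttention` module itself; the tensor sizes are illustrative assumptions.

```python
import math

import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the tutorial).
batch_size, input_len, output_len, embed_dim = 2, 5, 3, 8

# Stand-ins for the encoder output and the decoder's current embeddings.
enc_emb = torch.randn(batch_size, input_len, embed_dim)
dec_emb = torch.randn(batch_size, output_len, embed_dim)

# Queries come from the decoder; keys and values come from the encoder.
scores = torch.matmul(dec_emb, enc_emb.transpose(1, 2)) / math.sqrt(embed_dim)
weights = torch.softmax(scores, dim=-1)     # (batch, output_len, input_len)
cross_out = torch.matmul(weights, enc_emb)  # (batch, output_len, embed_dim)

print(weights.shape)    # torch.Size([2, 3, 5])
print(cross_out.shape)  # torch.Size([2, 3, 8])
```

Each output position gets one row of weights over the input positions, which is why the output keeps the decoder's sequence length while mixing encoder information.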

**Detailed Exploration in the Following Cells:**

We'll delve into each of these key features – cross-attention, shifted right input, and causal masking – in more detail in the following cells. We'll examine their implementation and understand their crucial role in enabling the decoder to effectively generate sequences. By understanding these components, we'll gain a deeper appreciation for the Transformer's ability to tackle complex sequence-to-sequence tasks.

### Shift the decoder input to the right

In [None]:
def shift_right(x: torch.Tensor, vocab_size: int) -> torch.Tensor:
  """Shifts the target sequence one step to the right and pads with a special start token.

  This function prepares the target sequence for teacher forcing during
  training.
  By shifting the target sequence to the right and adding a special start token
  at the beginning,
  we provide the model with the previous token as input when predicting the
  current token.

  Returns:
    The shifted and padded target sequence, one-hot encoded. Expected shape:
    (batch_size, sequence_length)
  """
  # Add a time dimension for the single-output case (i.e., when x has shape (batch_size,), with one target token per example).
  if x.ndim == 1:
    x = x.unsqueeze(1)

  # Pad the sequence at the beginning with the special start token (represented by vocab_size).
  # Remove the last element to maintain the original sequence length.
  # The sequence is now shifted to the right, and the first token is the start token.

  # --- YOUR CODE START ---
  # return F.pad(..., (1, 0), mode='constant', value=...)
  # --- YOUR CODE END ---

Let's walk through the code step by step.

In [None]:
# Define the output size (vocabulary size).
output_size = 4  # Tokens will be 0, 1, 2, 3

# Create a one-hot encoded toy input sequence.
# Represents the sequence [2, 3, 0]
x = torch.tensor(
    [[
        [0, 0, 1, 0],  # Token 2
        [0, 0, 0, 1],  # Token 3
        [1, 0, 0, 0],  # Token 0
    ]],
    dtype=torch.float32,
)
print(f"input : \n{x}\n")

# Compute the argmax
x = torch.argmax(x, dim=-1)
print(f"argmax : {x}\n")

padded = F.pad(x[:, :-1], (1, 0), mode="constant", value=output_size)
print(f"padded : \n{padded}\n")
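
As a complement to the walk-through above, the following sketch lines up the shifted decoder input with the token the model should predict at each step. It assumes the targets are already integer token indices and reuses the toy convention that the start token is the index equal to the vocabulary size; the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

start_token = 4                      # Toy convention: start token index == vocabulary size.
targets = torch.tensor([[2, 3, 0]])  # Shape (batch_size=1, sequence_length=3).

# Drop the last token and prepend the start token.
decoder_inputs = F.pad(
    targets[:, :-1], (1, 0), mode="constant", value=start_token
)

pairs = zip(decoder_inputs[0].tolist(), targets[0].tolist())
for step, (seen, expected) in enumerate(pairs):
  print(f"step {step}: decoder sees {seen} -> should predict {expected}")
```

At step 0 the decoder only sees the start token, and at every later step it sees the ground-truth token from the previous position.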

### Decoder architecture

In [None]:
class TransformerDecoderBlock(nn.Module):

  def __init__(self, embed_dim, mlp_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)
    self.cross_attention = CrossAttention(embed_dim)

    # Layer normalization
    self.layer_norm1 = nn.LayerNorm(embed_dim)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

    # Layer normalization
    self.layer_norm2 = nn.LayerNorm(embed_dim)

  def forward(self, enc_emb, dec_emb):
    # --- YOUR CODE START ---
    # x = self.self_attention(...)
    # x = self.layer_norm1(x)

    # y = self.cross_attention(..., ...)
    y = self.layer_norm1(y)

    # MLP layer
    y = self.mlp(y)
    # Layer norm
    y = self.layer_norm2(y)

    # --- YOUR CODE END ---
    return y

In [None]:
class BaseTransformerDecoder(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers
    self.vocab_size = vocab_size

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size + 1, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Transformer decoder blocks
    # --- YOUR CODE START ---
    # self.transformer_blocks = nn.ModuleList(
    # [TransformerDecoderBlock(..., ...) for _ in range(...)]
    # )
    # --- YOUR CODE END ---

  def forward(
      self, encoded: torch.Tensor, targets: torch.Tensor
  ) -> torch.Tensor:
    # Prepare the target sequence for teacher forcing by shifting it to the right.
    x = shift_right(targets, self.vocab_size)

    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(
        self.embed_dim
    )  # Scale positional embeddings

    # Stack multiple transformer decoder blocks
    # --- YOUR CODE START ---
    # for block in self.transformer_blocks:
    # x = block(..., ...)
    # --- YOUR CODE END ---
    return x

## Architecture - Encoder-Decoder

#### Combining the encoder and decoder

**Connecting Encoder and Decoder: Passing Information for Sequence Generation**

The core strength of the Transformer architecture lies in its ability to effectively process and generate sequences.  A crucial element in this process is the seamless flow of information from the encoder to the decoder.  As we've seen, the encoder's role is to create a rich, contextualized representation of the input sequence.  Now, it's the decoder's job to take that representation and generate the corresponding output sequence.

To achieve this, we **pass the encoder's output as an input to the decoder**. This encoded representation serves as the decoder's "memory" of the input sequence.  By using cross-attention layers, the decoder can then selectively attend to different parts of this encoded input while generating each token of the output sequence. This allows the decoder to make informed decisions based on the entire context of the input, leading to more accurate and coherent output generation. During training, the `targets` play two roles: shifted right, they are the decoder's input (teacher forcing), and unshifted, they are the labels against which the prediction at each step is compared to compute the loss.
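
As a rough illustration of comparing per-step predictions with the targets, the sketch below computes a cross-entropy loss and a token-level accuracy from random logits. It assumes integer token targets and is not necessarily how the notebook's `run_training` helper computes its loss; all sizes and names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the tutorial).
batch_size, target_len, vocab_size = 2, 4, 10

# Stand-ins for the model output and the ground-truth token indices.
logits = torch.randn(batch_size, target_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, target_len))

# `F.cross_entropy` expects (N, C) logits and (N,) class indices, so the
# batch and time dimensions are flattened together.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

# Token-level accuracy of the argmax predictions.
accuracy = (logits.argmax(dim=-1) == targets).float().mean()

print(f"loss={loss.item():.3f}, accuracy={accuracy.item():.3f}")
```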

In [None]:
class BaseTransformer(nn.Module):
  """Transformer model for sequence-to-sequence tasks.

  This class combines the encoder and decoder modules to create a complete
  Transformer model.
  """

  def __init__(
      self,
      num_layers: int,
      max_seq_length: int,
      vocab_size: int,
      embed_dim: int = 32,
  ):
    super().__init__()
    self.vocab_size = vocab_size
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.max_seq_length = max_seq_length

    # Create an instance of the TransformerEncoder.
    # --- YOUR CODE START ---

    # self.encoder = ...

    # Create an instance of the TransformerDecoder.
    # self.decoder = ...

    # Create a dense layer to project the decoder output to the vocabulary space.
    # self.linear_output = ...

    # --- YOUR CODE END ---

  def forward(
      self, inputs: torch.Tensor, targets: torch.Tensor, sample: bool
  ) -> torch.Tensor:
    """Applies the Transformer model to the given input and target sequences."""

    del sample  # For now it's a dummy variable, we'll use it later.
    logits = None
    # --- YOUR CODE START ---

    # Encode the input sequence.
    # encoder_output = ...

    # Decode the encoded input, using the target sequence for teacher forcing.
    # decoder_output = ...

    # Project the decoder output to the vocabulary space to get logits.
    # logits = ...

    # --- YOUR CODE END ---

    return logits

#### Training

Let's train the model above.

In [None]:
# Define the model
model = BaseTransformer(
    num_layers=1,
    vocab_size=task.vocab_size,
    max_seq_length=MAX_TEST_LENGTH + 2,
)

# Run the training loop
df_train, params, tokenizer = run_training(
    max_sequence_length=MAX_TRAIN_LENGTH,
    task=task,
    model=model,
    batch_size=128,
    train_steps=2_500,
)

# Plot the training accuracy over training
sns.lineplot(data=df_train, x="step", y="train_accuracy")

#### Eval

Let's evaluate the model at test time.

In [None]:
# Evaluate the model
df_eval = run_evaluation(
    model=model,
    params=params,
    tokenizer=tokenizer,
    task=task,
    max_test_length=MAX_TEST_LENGTH,
)

# Save the eval data for later comparisons across experiments
df_eval_dict = dict(base=df_eval)

# Plot the test accuracy for each length
sns.lineplot(data=df_eval, x="length", y="accuracy", marker="o").set_ylim(
    -0.05, 1.05
)

#### Analysis

- As a reminder, lengths in [3, 10] are in distribution with respect to the training dataset, while lengths in [11, 20] are out of distribution. Note also that the in-distribution eval dataset may contain some training samples.

- The average accuracy is better than random guessing (<10% accuracy) for all lengths.
- The test accuracy also decays with the sequence length. This can be explained both by longer sequences being further out of distribution and by the task becoming more complex as the sequence grows.

## Autoregressive sampling

#### Motivation

**Autoregressive Sampling: Generating Sequences Token by Token**

Now that we have a trained autoregressive Transformer, we can use it to generate new sequences.  Unlike training, where we had access to the entire target sequence and could use teacher forcing, at inference time, we need to generate the output sequence **one token at a time**, using the model's own predictions as input for the subsequent tokens. This process is called **autoregressive sampling**.

During autoregressive sampling, the model starts with an initial input (usually just the input sequence for the encoder) and a special "start" token for the decoder. It then predicts the probability distribution over the output vocabulary for the first token. We sample a token from this distribution, append it to the decoder's input, and repeat the process. This continues until a special "end" token is generated or a maximum sequence length is reached. Each prediction is conditioned on the input sequence and the previously generated tokens.
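
The loop described above can be sketched with a toy stand-in for the decoder. Everything here is hypothetical and only meant to show the control flow: `toy_next_token_logits` replaces a real model, the start and end token ids are made up, and the next token is picked greedily (argmax) rather than sampled.

```python
import torch

VOCAB_SIZE = 5
START_TOKEN = VOCAB_SIZE  # Hypothetical: start token index == vocabulary size.
END_TOKEN = 0             # Hypothetical end-of-sequence token.
MAX_LENGTH = 10


def toy_next_token_logits(prefix: torch.Tensor) -> torch.Tensor:
  """Hypothetical stand-in for a decoder: favors (last token + 1) % VOCAB_SIZE."""
  next_token = (int(prefix[-1]) + 1) % VOCAB_SIZE
  logits = torch.zeros(VOCAB_SIZE)
  logits[next_token] = 1.0
  return logits


def greedy_decode() -> list[int]:
  generated = [START_TOKEN]
  while len(generated) - 1 < MAX_LENGTH:
    # Condition on everything generated so far.
    logits = toy_next_token_logits(torch.tensor(generated))
    next_token = int(torch.argmax(logits))
    generated.append(next_token)
    if next_token == END_TOKEN:
      break
  return generated[1:]  # Drop the start token.


print(greedy_decode())  # [1, 2, 3, 4, 0]
```

Replacing the `argmax` with a draw from `torch.distributions.Categorical(logits=logits)` would turn this greedy loop into actual sampling.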

#### Code : Add greedy autoregressive sampling

In [None]:
def make_transformer_autoregressive(TransformerModel):
  """Wraps a Transformer model to make it autoregressive.

  This function modifies a standard Transformer model to generate outputs
  autoregressively,
  one token at a time.
  """

  class AutoregressiveTransformer(nn.Module):
    """Autoregressive Transformer model for sequence generation.

    This class wraps a standard Transformer model to enable autoregressive
    generation.
    In autoregressive mode, the model generates the output sequence one token at
    a time,
    using its previous predictions as input for subsequent tokens.
    """

    def __init__(self, num_layers: int, max_seq_length: int, vocab_size: int):
      super().__init__()
      self.vocab_size = vocab_size
      self.num_layers = num_layers
      self.max_seq_length = max_seq_length
      self.model = TransformerModel(
          vocab_size=vocab_size,
          num_layers=num_layers,
          max_seq_length=max_seq_length,
      )

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor, sample: bool
    ) -> torch.Tensor:
      """Applies the autoregressive Transformer model."""
      # Determine the output length based on the targets shape. output_length will be target_sequence_length, or 1 in the case of a single prediction.
      output_length = 1 if len(targets.shape) == 2 else targets.shape[1]
      # Get the output size from the targets.
      output_size = targets.shape[-1]

      if not sample or output_length == 1:
        output = self.model(inputs, targets, sample)
      else:

        def evaluate_model_autoregressively(idx, predictions):
          """Iteratively evaluates the model based on the previous predictions.

          This function performs a single step of autoregressive generation.
          It takes the current predictions, predicts the next token, and updates
          the predictions array.

          Args:
              idx: The index of the target sequence that should be evaluated.
              predictions: The logits for the predictions up to but not
                including the index `idx`.

          Returns:
              The `predictions` array modified only at position `idx` where the
              logits for index `idx` have been inserted.
          """

          # --- YOUR CODE START ---

          # Apply the model to get the logits for the next token.
          logits = self.model(inputs, ..., sample)

          # Update the predictions array with the new logits at the current index.
          # return predictions.index_put(
          # indices=...,
          # values=logits[..., ...]
          # )

          # --- YOUR CODE END ---

        # PyTorch has no `torch.for_loop`, so a plain Python loop fills in the
        # predictions one position at a time. `zeros_like` avoids feeding
        # uninitialized memory to the model.
        output = torch.zeros_like(targets)
        for idx in range(output_length):
          output = evaluate_model_autoregressively(idx, output)

      return output

  return AutoregressiveTransformer

#### Training

In [None]:
# Define the model
AutoregressiveTransformer = make_transformer_autoregressive(BaseTransformer)
model = AutoregressiveTransformer(
    num_layers=1,
    vocab_size=task.vocab_size,
    max_seq_length=MAX_TEST_LENGTH + 2,
)

# Run the training loop.
df_train, params, tokenizer = run_training(
    max_sequence_length=MAX_TRAIN_LENGTH,
    task=task,
    model=model,
    batch_size=128,
    train_steps=2_500,
)

# Visualize the training accuracy.
sns.lineplot(data=df_train, x="step", y="train_accuracy")

#### Eval

In [None]:
# Evaluate the model

df_eval = run_evaluation(
    model=model,
    params=params,
    tokenizer=tokenizer,
    task=task,
    max_test_length=MAX_TEST_LENGTH,
    is_autoregressive=True,
)
# Save the eval data for later comparisons across experiments
df_eval_dict["base_autoregressive"] = df_eval

# Plot the test accuracy for each length
sns.lineplot(data=df_eval, x="length", y="accuracy", marker="o").set_ylim(
    -0.05, 1.05
)

#### Analysis

In [None]:
df_eval_all = pd.concat(
    df_eval_dict, keys=df_eval_dict.keys(), names=["experiment", None]
)
sns.lineplot(
    data=df_eval_all, x="length", y="accuracy", hue="experiment", marker="o"
).set_ylim(-0.05, 1.05)

- We note that the autoregressive model has a lower accuracy than the non-autoregressive model. The reason is that the non-autoregressive model has access to the ground-truth sequence as an input, while the autoregressive model only sees the tokens it has generated so far.

## Add causal masks

### Motivation

**Causal Masking: Enforcing Autoregressive Generation in the Decoder**

In the decoder, we need to ensure that the model generates the output sequence **autoregressively**, meaning that each token is predicted based only on the previously generated tokens and the encoded input sequence.  To achieve this, we employ a technique called **causal masking** (also known as **look-ahead masking**) within the self-attention mechanism.

The core idea behind causal masking is simple: **prevent the model from "peeking" at future tokens during training.** We want the model to learn to predict the next token based solely on the information available up to the current position in the sequence.  This is accomplished by masking the attention scores before the softmax: the scores for future positions are set to a very large negative value (such as negative infinity), so their attention weights become zero after the softmax.

This masking is essential because, during inference, the model generates the sequence one token at a time. It doesn't have access to future tokens, so it must rely only on its past predictions and the encoded input. By enforcing this constraint during training with causal masking, we ensure that the model learns to generate sequences in a truly autoregressive manner, making it suitable for realistic inference scenarios where future information is unavailable. The following class `MaskedSelfAttention` implements this masking mechanism.

In [None]:
# Illustrating the causal mask

sequence_length = 4  # You can change this to any desired sequence length

# Create the causal mask using torch.tril
causal_mask = torch.tril(
    torch.ones((sequence_length, sequence_length), dtype=torch.bool)
)

print(causal_mask)
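
To see what such a mask does to the attention weights, here is a small sketch that applies it to uniform scores. Using `masked_fill` with negative infinity is one common way to apply the mask and is an assumption here, not necessarily the approach expected in the exercise below.

```python
import torch

sequence_length = 4
causal_mask = torch.tril(
    torch.ones((sequence_length, sequence_length), dtype=torch.bool)
)

# Uniform scores make the effect of the mask easy to read off.
scores = torch.zeros(sequence_length, sequence_length)

# Positions where the mask is False (future tokens) get -inf before the softmax.
masked_scores = scores.masked_fill(~causal_mask, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights)
```

Row `i` spreads its weight evenly over positions `0..i` and assigns exactly zero to every later position.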

### Add causal masks to the model

#### MaskedSelfAttention

In [None]:
class MaskedSelfAttention(nn.Module):

  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim

    # Calculate query, key, and value matrices using linear layers
    self.query = nn.Linear(embed_dim, embed_dim)
    self.key = nn.Linear(embed_dim, embed_dim)
    self.value = nn.Linear(embed_dim, embed_dim)

  def forward(self, x, mask):
    # Calculate query, key, and value matrices
    query = self.query(x)
    key = self.key(x)
    value = self.value(x)

    # Calculate attention scores (scaled dot-product attention)
    attention_scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(
        self.embed_dim
    )

    # New  - we add the mask
    if mask is not None:
      pass
      # --- YOUR CODE START ---
      # attention_scores = ...
      # --- YOUR CODE END ---

    attention_weights = nn.functional.softmax(attention_scores, dim=-1)

    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)
    return output

#### MaskedTransformerDecoderBlock

In [None]:
class MaskedTransformerDecoderBlock(nn.Module):

  def __init__(self, embed_dim, mlp_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim

    # Self-attention layer
    self.self_attention = MaskedSelfAttention(embed_dim)  # Change
    self.cross_attention = CrossAttention(embed_dim)

    # Layer normalization
    self.layer_norm1 = nn.LayerNorm(embed_dim)

    # MLP layers
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim)
    )

    # Layer normalization
    self.layer_norm2 = nn.LayerNorm(embed_dim)

  def forward(self, enc_emb, dec_emb, mask):  # Change
    x = self.self_attention(dec_emb, mask)  # Change
    x = self.layer_norm1(x)

    # Queries come from the decoder; keys and values come from the encoder.
    y = self.cross_attention(x, enc_emb)
    y = self.layer_norm1(y)

    # MLP layer
    y = self.mlp(y)
    # Layer norm
    y = self.layer_norm2(y)

    return y

#### MaskedTransformerDecoder

In [None]:
class MaskedTransformerDecoder(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers
    self.vocab_size = vocab_size

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size + 1, embed_dim)

    # Positional embedding layer
    self.pos_embedding = nn.Embedding(max_seq_length, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Transformer decoder blocks
    self.transformer_blocks = nn.ModuleList([
        MaskedTransformerDecoderBlock(embed_dim, mlp_dim)
        for _ in range(num_layers)
    ])  # Change

  def forward(
      self, encoded: torch.Tensor, targets: torch.Tensor
  ) -> torch.Tensor:
    x = shift_right(targets, self.vocab_size)

    # Embedding layer
    x = self.embedding(x)

    # Positional embeddings
    positions = torch.arange(x.shape[1], device=x.device)
    pos_embeddings = self.pos_embedding(positions)
    x = x + pos_embeddings / math.sqrt(self.embed_dim)

    # Change : Add a causal mask
    batch_size, output_sequence_length, embedding_size = x.shape

    # --- YOUR CODE START ---
    # causal_mask = ...
    # causal_mask = causal_mask.to(...)
    # --- YOUR CODE END ---

    # Stack multiple transformer decoder blocks
    for block in self.transformer_blocks:
      x = block(encoded, x, causal_mask)  # Change
    return x

#### MaskedTransformer

In [None]:
class MaskedTransformer(nn.Module):
  """Transformer model for sequence-to-sequence tasks.

  This class combines the encoder and decoder modules to create a complete
  Transformer model.
  """

  def __init__(
      self,
      num_layers: int,
      max_seq_length: int,
      vocab_size: int,
      embed_dim: int = 32,
  ):
    super().__init__()
    self.vocab_size = vocab_size
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.max_seq_length = max_seq_length

    # Create an instance of the TransformerEncoder.
    self.encoder = BaseTransformerEncoder(
        max_seq_length=self.max_seq_length,
        num_layers=self.num_layers,
        embed_dim=self.embed_dim,
        vocab_size=self.vocab_size,
    )
    # Create an instance of the MaskedTransformerDecoder.
    self.decoder = MaskedTransformerDecoder(  # Change
        vocab_size=self.vocab_size,
        num_layers=self.num_layers,
        embed_dim=self.embed_dim,
        max_seq_length=self.max_seq_length,
    )
    # Create a dense layer to project the decoder output to the vocabulary space.
    self.linear_output = nn.Linear(self.embed_dim, self.vocab_size)

  def forward(
      self, inputs: torch.Tensor, targets: torch.Tensor, sample: bool
  ) -> torch.Tensor:
    del sample  # For now it's a dummy variable, we'll use it later.

    # Encode the input sequence.
    encoder_output = self.encoder(inputs)

    # Decode the encoded input, using the target sequence for teacher forcing.
    decoder_output = self.decoder(encoder_output, targets)

    # Project the decoder output to the vocabulary space to get logits.
    logits = self.linear_output(decoder_output)

    return logits

### Training

In [None]:
# Define the model
AutoregressiveTransformer = make_transformer_autoregressive(MaskedTransformer)
model = AutoregressiveTransformer(
    num_layers=1,
    vocab_size=task.vocab_size,
    max_seq_length=MAX_TEST_LENGTH + 2,
)

# Run the training loop.
df_train, params, tokenizer = run_training(
    max_sequence_length=MAX_TRAIN_LENGTH,
    task=task,
    model=model,
    batch_size=128,
    train_steps=2_500,
)

# Visualize the training accuracy.
sns.lineplot(data=df_train, x="step", y="train_accuracy")

### Eval

In [None]:
# Evaluate the model
df_eval = run_evaluation(
    model=model,
    params=params,
    tokenizer=tokenizer,
    task=task,
    max_test_length=20,
    is_autoregressive=True,
)

# Save the eval data for later comparisons across experiments
df_eval_dict["masked_autoregressive"] = df_eval

# Plot the test accuracy for each length
sns.lineplot(data=df_eval, x="length", y="accuracy", marker="o").set_ylim(
    -0.05, 1.05
)

### Analysis

In [None]:
df_eval_all = pd.concat(
    df_eval_dict, keys=df_eval_dict.keys(), names=["experiment", None]
)
sns.lineplot(
    data=df_eval_all, x="length", y="accuracy", hue="experiment", marker="o"
).set_ylim(-0.05, 1.05)

- Because the model was trained with a causal mask, it never learned to rely on future tokens. Training now matches the autoregressive setting used at test time, which leads to a significant improvement in accuracy for the autoregressive eval.

## Positional embeddings

### Motivation

**Improving Generalization with Sinusoidal Positional Encodings**

In Part III, we explored the use of positional embeddings and saw how they significantly impacted the model's ability to generalize to longer sequences. Recall that these embeddings provide the Transformer model with crucial information about the order of tokens within a sequence, as the architecture itself doesn't inherently capture positional relationships.  In our previous implementation, we used randomly initialized, learnable positional embeddings.
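
The claim that attention alone does not capture token order can be checked numerically. The sketch below uses a bare single-head self-attention with random projection matrices (stand-ins for learned layers, with no positional information and no mask) and shows that shuffling the input tokens simply shuffles the outputs in the same way.

```python
import math

import torch

torch.manual_seed(0)

seq_len, embed_dim = 5, 8
x = torch.randn(1, seq_len, embed_dim)

# Random projection matrices standing in for learned query/key/value layers.
w_q, w_k, w_v = (torch.randn(embed_dim, embed_dim) for _ in range(3))


def self_attention(inputs: torch.Tensor) -> torch.Tensor:
  q, k, v = inputs @ w_q, inputs @ w_k, inputs @ w_v
  scores = q @ k.transpose(1, 2) / math.sqrt(embed_dim)
  return torch.softmax(scores, dim=-1) @ v


perm = torch.randperm(seq_len)

# Shuffling before attention matches shuffling the outputs afterwards.
out_then_permute = self_attention(x)[:, perm]
permute_then_out = self_attention(x[:, perm])

print(torch.allclose(out_then_permute, permute_then_out, atol=1e-4))
```

Because each token's output is the same wherever it sits in the sequence, the model needs positional embeddings to tell different orderings apart.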

In this section, we introduce a different type of positional encoding: **sinusoidal positional encodings**. While the motivation remains the same – to inform the model about token positions – sinusoidal encodings offer a deterministic and potentially more effective way to achieve this. Instead of learning a unique embedding for each position, sinusoidal encodings are generated using a predefined mathematical function based on sine and cosine waves with varying frequencies.
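
For reference, the encoding computed by the `sinusoid_position_encoding` helper in this section can be written out as follows, assuming an even `hidden_size` $D$ and the default `min_timescale = 2`. Here $p$ is a position value and $k = 0, \dots, D/2 - 1$ indexes the frequencies:

$$
\omega_k = \texttt{max\_timescale}^{-2k/D}, \qquad
\mathrm{PE}[p, k] = \sin(p\,\omega_k), \qquad
\mathrm{PE}[p, D/2 + k] = \cos(p\,\omega_k).
$$

Note that this helper concatenates the sine and cosine halves instead of interleaving them as in the original Transformer paper, and it enumerates the position values in decreasing order.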

In [None]:
def sinusoid_position_encoding(
    sequence_length: int,
    hidden_size: int,
    memory_length: int = 0,
    max_timescale: float = 1e4,
    min_timescale: float = 2.0,
    clamp_length: int = 0,
    causal: bool = False,
):
  """Creates sinusoidal positional encodings.

  These encodings are used to provide positional information to the Transformer
  model,
  as it doesn't have inherent mechanisms to understand the order of tokens in a
  sequence.

  Args:
    sequence_length: `int` sequence length (L).
    hidden_size: `int` dimension of the positional encoding vectors (D). This is
      usually the same as the embedding dimension of the model.
    memory_length: `int` size of the memory (M). This is used for models like
      Transformer-XL.
    max_timescale: `float` maximum timescale for the frequency.
    min_timescale: `float` minimum timescale for the frequency.
    clamp_length: If greater than 0, any positions further apart than
      `clamp_length` are clamped to this value. This is used to limit the
      distance between tokens.
    causal: If true, then generates a smaller set (L vs 2 * L) of time-encodings
      for the use-case of causal attention (e.g., in the decoder).

  Returns:
    A NumPy array of shape [L + M, D] for causal and [2 * L + M, D] otherwise,
    representing the positional encodings.
  """
  # Calculate frequencies for each dimension.
  freqs = np.arange(0, hidden_size, min_timescale)
  inv_freq = max_timescale ** (-freqs / hidden_size)

  # Create the position sequence.
  if causal:
    # For causal models, we only need encodings for positions up to sequence_length + memory_length.
    pos_seq = np.arange(sequence_length + memory_length, 0, -1.0)
  else:
    # For non-causal models, we need encodings for relative positions from sequence_length + memory_length down to -sequence_length + 1.
    pos_seq = np.arange(sequence_length + memory_length, -sequence_length, -1.0)

  # Clamp positions if clamp_length is specified.
  if clamp_length:
    pos_seq = np.clip(pos_seq, a_min=-clamp_length, a_max=clamp_length)

  # Calculate the sinusoidal inputs.
  sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq)

  # Create the positional encodings by concatenating sine and cosine values.
  pos_emb = np.concatenate(
      [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1
  )

  return pos_emb

### Visualize the positional embeddings

In [None]:
# Generate the positional encodings
sequence_length = 50
hidden_size = 64
positional_encodings = sinusoid_position_encoding(sequence_length, hidden_size)

# Visualize the positional encodings
plt.figure(figsize=(10, 8))
plt.pcolormesh(positional_encodings, cmap='RdBu')
plt.xlabel('Hidden Dimension')
plt.xlim((0, hidden_size))
plt.ylabel('Sequence Position')
plt.title('Sinusoidal Positional Encodings')
plt.colorbar()
plt.show()

In [None]:
# Visualize individual dimensions to see the waves
plt.figure(figsize=(10, 4))
for i in range(4):  # Visualize the first 4 dimensions
  plt.plot(positional_encodings[:, i], label=f'Dimension {i}')
plt.xlabel('Sequence Position')
plt.ylabel('Encoding Value')
plt.title('Individual Dimensions of Sinusoidal Positional Encodings')
plt.legend()
plt.show()

### Add the positional embeddings to the model

#### FullTransformerEncoder

In [None]:
class FullTransformerEncoder(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Transformer Encoder Blocks
    self.transformer_blocks = nn.ModuleList(
        [TransformerEncoderBlock(embed_dim, mlp_dim) for _ in range(num_layers)]
    )

  def forward(self, x):
    # Embedding layer
    x = self.embedding(x)

    ### Change to sinusoidal positional embeddings
    batch_size, sequence_length, embedding_size = x.shape
    pos_encodings = sinusoid_position_encoding(
        sequence_length=sequence_length,
        hidden_size=self.embed_dim,
        memory_length=0,
        max_timescale=10_000,
        min_timescale=2,
        clamp_length=0,
        causal=True,
    )
    pos_encodings = torch.tensor(
        pos_encodings, dtype=x.dtype, device=x.device
    )
    # Add the encodings to `x` itself so the blocks below actually see them.
    x = x + pos_encodings
    ###

    # Stack multiple transformer encoder blocks
    for block in self.transformer_blocks:
      x = block(x)
    return x

#### FullTransformerDecoder

In [None]:
class FullTransformerDecoder(nn.Module):

  def __init__(
      self,
      *,
      vocab_size,
      max_seq_length,
      num_layers,
      mlp_dim=256,
      embed_dim=128,
      num_classes=2,
  ):
    super().__init__()
    self.embed_dim = embed_dim
    self.mlp_dim = mlp_dim
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.num_layers = num_layers
    self.vocab_size = vocab_size

    # Embedding layer
    self.embedding = nn.Embedding(vocab_size + 1, embed_dim)

    # Self-attention layer
    self.self_attention = SelfAttention(embed_dim)

    # Transformer decoder blocks
    self.transformer_blocks = nn.ModuleList([
        MaskedTransformerDecoderBlock(embed_dim, mlp_dim)
        for _ in range(num_layers)
    ])

  def forward(
      self, encoded: torch.Tensor, targets: torch.Tensor
  ) -> torch.Tensor:
    x = shift_right(targets, self.vocab_size)

    # Embedding layer
    x = self.embedding(x)

    ### Change to sinusoidal positional embeddings
    batch_size, sequence_length, embedding_size = x.shape
    pos_encodings = sinusoid_position_encoding(
        sequence_length=sequence_length,
        hidden_size=self.embed_dim,
        memory_length=0,
        max_timescale=10_000,
        min_timescale=2,
        clamp_length=0,
        causal=True,
    )
    pos_encodings = torch.tensor(
        pos_encodings, dtype=x.dtype, device=x.device
    )
    # Add the encodings to `x` itself so the blocks below actually see them.
    x = x + pos_encodings
    ###

    batch_size, output_sequence_length, embedding_size = x.shape
    # The attention scores are 3-D (batch, query, key), so a 2-D mask
    # broadcasts over the batch dimension.
    causal_mask = torch.tril(
        torch.ones(
            (output_sequence_length, output_sequence_length),
            dtype=torch.bool,
        )
    )
    causal_mask = causal_mask.to(encoded.device)

    # Stack multiple transformer decoder blocks
    for block in self.transformer_blocks:
      x = block(encoded, x, causal_mask)
    return x

#### FullTransformer

In [None]:
class FullTransformer(nn.Module):
  """Transformer model for sequence-to-sequence tasks.

  This class combines the encoder and decoder modules to create a complete
  Transformer model.
  """

  def __init__(
      self,
      num_layers: int,
      max_seq_length: int,
      vocab_size: int,
      embed_dim: int = 32,
  ):
    super().__init__()
    self.vocab_size = vocab_size
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.max_seq_length = max_seq_length

    # Create an instance of the FullTransformerEncoder.
    self.encoder = FullTransformerEncoder(
        max_seq_length=self.max_seq_length,
        num_layers=self.num_layers,
        embed_dim=self.embed_dim,
        vocab_size=self.vocab_size,
    )
    # Create an instance of the FullTransformerDecoder.
    self.decoder = FullTransformerDecoder(
        vocab_size=self.vocab_size,
        num_layers=self.num_layers,
        embed_dim=self.embed_dim,
        max_seq_length=self.max_seq_length,
    )
    # Create a dense layer to project the decoder output to the vocabulary space.
    self.linear_output = nn.Linear(self.embed_dim, self.vocab_size)

  def forward(
      self, inputs: torch.Tensor, targets: torch.Tensor, sample: bool
  ) -> torch.Tensor:
    del sample  # For now it's a dummy variable, we'll use it later.

    # Encode the input sequence.
    encoder_output = self.encoder(inputs)

    # Decode the encoded input, using the target sequence for teacher forcing.
    decoder_output = self.decoder(encoder_output, targets)

    # Project the decoder output to the vocabulary space to get logits.
    logits = self.linear_output(decoder_output)

    return logits
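
The decoder relies on teacher forcing: during training it receives the ground-truth targets shifted one step to the right, so the output at position t is predicted only from the tokens before t. The sketch below illustrates the shift with a hypothetical `shift_right_sketch` helper. It assumes the start-of-sequence token uses index `vocab_size` (consistent with the embedding table having `vocab_size + 1` rows), which may differ from what the notebook's own `shift_right` does.

```python
import torch


def shift_right_sketch(targets: torch.Tensor, bos_token: int) -> torch.Tensor:
  """Prepends `bos_token` and drops the last token of each sequence."""
  bos = torch.full_like(targets[:, :1], bos_token)
  return torch.cat([bos, targets[:, :-1]], dim=1)


vocab_size = 5  # Toy value for illustration only.
targets = torch.tensor([[1, 2, 3], [4, 0, 2]])
decoder_inputs = shift_right_sketch(targets, bos_token=vocab_size)
print(decoder_inputs)
```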

### Training

In [None]:
# Define the model
AutoregressiveTransformer = make_transformer_autoregressive(FullTransformer)
model = AutoregressiveTransformer(
    num_layers=1,
    vocab_size=task.vocab_size,
    max_seq_length=MAX_TEST_LENGTH + 2,
)

# Run the training loop.
df_train, params, tokenizer = run_training(
    max_sequence_length=MAX_TRAIN_LENGTH,
    task=task,
    model=model,
    batch_size=128,
    train_steps=2_500,
)

# Visualize the training accuracy.
sns.lineplot(data=df_train, x="step", y="train_accuracy")

### Eval

In [None]:
# Evaluate the model
df_eval = run_evaluation(
    model=model,
    params=params,
    tokenizer=tokenizer,
    task=task,
    max_test_length=20,
    is_autoregressive=True,
)

# Save the eval data for later comparisons across experiments
df_eval_dict["full_sinusoidal_autoregressive"] = df_eval

# Plot the test accuracy for each length
sns.lineplot(data=df_eval, x="length", y="accuracy", marker="o").set_ylim(
    -0.05, 1.05
)

### Analysis

In [None]:
df_eval_all = pd.concat(
    df_eval_dict, keys=df_eval_dict.keys(), names=["experiment", None]
)
sns.lineplot(
    data=df_eval_all, x="length", y="accuracy", hue="experiment", marker="o"
).set_ylim(-0.05, 1.05)

- In this experiment, sinusoidal positional embeddings lead to lower accuracy than learned positional embeddings.
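
For reference, the two kinds of positional embeddings being compared differ in whether the position vectors are trained. The sketch below uses the classic sinusoidal formulation from the Transformer paper, which is not necessarily identical to the `sinusoid_position_encoding` helper used in this notebook (that helper takes extra arguments such as `min_timescale`). The name `sinusoidal_sketch` is hypothetical.

```python
import math

import torch
from torch import nn


def sinusoidal_sketch(seq_len: int, embed_dim: int) -> torch.Tensor:
  """Fixed encodings: sine on even channels, cosine on odd channels."""
  positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
  # Inverse frequencies 1 / 10000^(2i / embed_dim), one per channel pair.
  inv_freq = torch.exp(
      -math.log(10_000.0)
      * torch.arange(0, embed_dim, 2, dtype=torch.float32)
      / embed_dim
  )
  encodings = torch.zeros(seq_len, embed_dim)
  encodings[:, 0::2] = torch.sin(positions * inv_freq)
  encodings[:, 1::2] = torch.cos(positions * inv_freq)
  return encodings


seq_len, embed_dim = 6, 32  # embed_dim is assumed to be even.
fixed = sinusoidal_sketch(seq_len, embed_dim)  # No trainable parameters.
learned = nn.Embedding(seq_len, embed_dim)  # One trainable vector per position.
print(fixed.shape, learned.weight.shape)
```

The sinusoidal table can be computed for any sequence length, whereas the learned table only covers as many positions as it has rows.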

## Ideas to try

- Play with the hyperparameters and the ordering of modules.
- Check to what extent the eval dataset is contaminated with training samples in the in-distribution length range. Then fix the issue, so that you have a clean in-distribution test metric.
- Implement other autoregressive sampling methods such as nucleus sampling [1] or beam search [2].
- Implement best-of-n sampling and analyze how the top@n accuracy improves with n.
- Implement other positional embeddings such as RoPE [3].
- Train the model on real-world text datasets.
- Implement multi-head attention.
- Look into a lower numerical precision than float32 and analyze how it impacts memory, time, and accuracy.
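
As a starting point for the sampling idea above, here is a minimal sketch of a single step of nucleus (top-p) sampling [1]. `nucleus_sample` is a hypothetical helper, not part of this notebook, and it assumes `logits` has shape `(batch_size, vocab_size)` for the position currently being generated.

```python
import torch


def nucleus_sample(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
  """Samples one token id per row from the smallest set with mass >= top_p."""
  probs = torch.softmax(logits, dim=-1)
  sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
  cumulative = torch.cumsum(sorted_probs, dim=-1)
  # Drop a token if the tokens ranked before it already reach mass top_p.
  drop = (cumulative - sorted_probs) >= top_p
  sorted_probs = sorted_probs.masked_fill(drop, 0.0)
  sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
  # Sample in sorted order, then map back to vocabulary ids.
  choice = torch.multinomial(sorted_probs, num_samples=1)
  return sorted_ids.gather(-1, choice).squeeze(-1)


torch.manual_seed(0)
# Toy distribution over 4 tokens with probabilities 0.5, 0.3, 0.15, 0.05.
logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
samples = torch.stack([nucleus_sample(logits, top_p=0.7) for _ in range(100)])
print(samples.unique())
```

With `top_p=0.7`, the nucleus of this toy distribution contains only token ids 0 and 1, so the two least likely tokens can never be drawn.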

**References**
- [1] Holtzman, A., Buys, J., Du, L., Forbes, M., & Choi, Y. (2019). The Curious Case of Neural Text Degeneration. ArXiv, abs/1904.09751.
- [2] Freitag, M., & Al-Onaizan, Y. (2017). Beam Search Strategies for Neural Machine Translation. NMT@ACL.
- [3] Su, J., Lu, Y., Pan, S., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. ArXiv, abs/2104.09864.