**This exercise requires computing power. The assignment is solvable using a JupyterHub CPU. If implemented efficiently, training takes approximately 15 minutes. However, we recommend using a GPU if you have access to one.**

In [None]:
import heapq
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm

def set_seeds(seed=42):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU,
    the current CUDA device and all CUDA devices), then disables cuDNN
    autotuning so the same kernels are selected on every run.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA: current device first, then every device (multi-GPU setups).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seeds(42)

In [None]:
skip_training = False  # Set this flag to True before validation and submission

# Prefer the GPU when PyTorch can see one; otherwise train on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
# During evaluation, this cell sets skip_training to True
# skip_training = True

# Exercise 4 - Autoregressive Language Modeling (4p)

In the previous exercise, we explored n-gram models, which, despite their simplicity, quickly reach their limits in modeling long-range language dependencies. In this final assignment, your goal is to implement and train a Transformer-based autoregressive language model to predict the next token given a sequence of previous tokens. Your model should achieve low perplexity on a held-out test set from the Shakespeare dataset.

## Tokenization and Data Preparation

We tokenize the dataset with the SentencePiece tokenizer trained in the previous exercise. The cell below retrains it if its model file is not found.

In [None]:
# Download the Tiny Shakespeare corpus unless a local copy already exists.
# NOTE: `!wget` is IPython shell-escape syntax, so this cell only runs inside
# Jupyter/IPython, and it assumes `wget` is installed on the host.
if not os.path.exists("input.txt"):
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt
else:
    print("input.txt already exists. Skipping download.\n")

# Tokenizer parameters. The model file name encodes the vocabulary size, so
# tokenizers with different vocabularies can coexist on disk.
model_prefix = "shakespeare_tokenizer"
vocab_size = 512
filename = "input.txt"
model_file = f"{model_prefix}_{vocab_size}.model"

if not os.path.exists(model_file):
    # Train a BPE SentencePiece tokenizer on the raw corpus. This is only needed
    # once; later runs reuse the saved model file.
    spm.SentencePieceTrainer.train(
        input=filename,
        model_prefix=f"{model_prefix}_{vocab_size}",
        vocab_size=vocab_size,
        model_type="bpe",
        max_sentence_length=4096,
        minloglevel=2,
        # Treat vocab_size as a soft upper bound instead of failing when the
        # corpus cannot support that many pieces.
        hard_vocab_limit=False,
        normalization_rule_name="nmt_nfkc",
        remove_extra_whitespaces=True,
        character_coverage=1.0,
        num_threads=1,
    )
    print(f"Tokenizer creation complete. Model files generated: {model_prefix}_{vocab_size}.model, {model_prefix}_{vocab_size}.vocab")

# Load the tokenizer on every run, including right after training. This
# guarantees `sp_model` is defined before it is used to tokenize the corpus.
sp_model = spm.SentencePieceProcessor()
sp_model.load(model_file)
print(f"Tokenizer loaded successfully from {model_file}\n")

def tokenize_text(text, sp_model):
    """Encode *text* into a list of integer token ids using *sp_model*."""
    token_ids = sp_model.encode(text, out_type=int)
    return token_ids

def load_dataset(file_path):
    """Return the complete contents of *file_path*, decoded as UTF-8 text."""
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents

text = load_dataset("input.txt")
tokens = tokenize_text(text, sp_model)
num_tokens = len(tokens)
print(f"Total tokens in dataset: {num_tokens}")

# Each training example has `block_size` input tokens (plus one shifted
# target token); examples are grouped into mini-batches of `batch_size`.
block_size = 128  # context window for each input sequence
batch_size = 32

## Creating the Dataset

We divide the token stream into fixed-length sequences (chunks) for training. Multi-epoch training benefits from variability. At the start of each epoch we therefore shift the chunk boundaries by a random offset, so the mini-batches differ from one epoch to the next.

In [None]:
class ShakespeareDataset(Dataset):
    """
    Creates fixed-length sequences from a continuous token stream for autoregressive training.

    Item ``i`` is a pair ``(x, y)`` of ``torch.long`` tensors of length
    ``block_size``. ``y`` is ``x`` shifted one token to the right, so ``y[t]``
    is the next-token target for ``x[:t + 1]``. ``set_epoch`` shifts all chunk
    boundaries by a random offset, so successive epochs see different
    mini-batches.
    """
    def __init__(self, tokens, block_size):
        self.tokens = tokens
        self.block_size = block_size
        self.offset = 0
        self._update_chunks()

    def _update_chunks(self):
        """Recompute the chunk start indices for the current ``self.offset``."""
        # Number of tokens left after skipping the first `offset` ones. It is
        # computed from the length directly, which avoids copying the whole
        # stream with a slice every epoch.
        total_len = max(0, len(self.tokens) - self.offset)
        # Each chunk needs block_size + 1 tokens (inputs plus one extra target).
        # Consecutive chunks share their boundary token, hence the `- 1`.
        # Clamp at 0: for an empty or very short stream (or an offset past the
        # end) the formula would go negative and len(dataset) would raise.
        self.num_chunks = max(0, (total_len - 1) // self.block_size)
        self.chunks_start = [self.offset + i * self.block_size for i in range(self.num_chunks)]

    def set_epoch(self, epoch=None):
        """Draw a new random offset in ``[0, block_size - 1]`` and rebuild the chunks.

        ``epoch`` is accepted for API compatibility but is not used. The offset
        comes from Python's global ``random`` state.
        """
        self.offset = random.randint(0, self.block_size - 1)
        self._update_chunks()

    def __len__(self):
        """Number of complete ``(x, y)`` chunks at the current offset."""
        return self.num_chunks

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair for chunk *idx*."""
        start = self.chunks_start[idx]
        end = start + self.block_size + 1  # one extra token for the shifted target
        chunk = self.tokens[start:end]

        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

We split the dataset into training and test sets:

In [None]:
# Hold out the final 2% of the token stream as a contiguous test passage.
split_idx = int(0.98 * len(tokens))
train_data, test_data = tokens[:split_idx], tokens[split_idx:]

train_dataset, test_dataset = (
    ShakespeareDataset(part, block_size) for part in (train_data, test_data)
)

# train_loader created later
# The test loader is not shuffled, so evaluation always visits the chunks in order.
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## Transformer Architecture

You will implement a decoder-only transformer model suitable for autoregressive language modeling. Below we provide the original sinusoidal positional encoding, which you may use but are not required to.

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    Even feature indices ``2i`` hold ``sin(pos / 10000^(2i / d_model))``. Odd
    indices ``2i + 1`` hold the matching cosine. The table is precomputed for
    ``max_len`` positions and registered as a (non-trainable) buffer.
    """
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) values.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. Trimming the
        # frequencies lets odd d_model work instead of raising a shape error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encodings for positions ``0 .. x.size(1) - 1`` to *x*.

        *x* is indexed as ``(batch, seq_len, d_model)`` and requires
        ``seq_len <= max_len``.
        """
        return x + self.pe[:, :x.size(1)]

### TransformerBlock

Your task is to implement a TransformerBlock. A transformer block takes input vectors with shape (batch_size, seq_len, embed_dim) and processes them in two main stages:

Causal Self-Attention:
Apply layer normalization, followed by multi-head self-attention with a causal mask to ensure each token attends only to itself and previous tokens. The attention output is added back to the input (residual connection).

Feedforward Network:
Again apply layer normalization, followed by a feedforward network (two linear layers with a ReLU activation in between). The output is added back via a second residual connection.

We provide this typical transformer description as guidance. The exact implementation details are flexible. We only test correctness of input/output shapes and causal masking.

In [None]:
class TransformerBlock(nn.Module):
    """
    A single decoder-only (pre-norm) transformer block consisting of:
    - LayerNorm followed by causal multi-head self-attention with residual connection.
    - LayerNorm followed by a feedforward network (Linear -> ReLU -> Linear) with residual connection.

    Parameters:
    -----------
    embed_dim : int
        Dimension of the input embeddings and the hidden representations.
        Must be divisible by ``num_heads``.
    num_heads : int
        Number of attention heads in the multi-head attention layer.
    ff_hidden_dim : int
        Hidden layer size for the feedforward network.
    """
    def __init__(self, embed_dim, num_heads, ff_hidden_dim):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        # batch_first=True: inputs are (batch_size, seq_len, embed_dim).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff_norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
        )

    def forward(self, x, attn_mask):
        """
        Forward pass through the transformer block.

        Parameters:
        -----------
        x : Tensor
            Input tensor of shape (batch_size, seq_len, embed_dim).
        attn_mask : Tensor
            Boolean tensor of shape (seq_len, seq_len) used to mask future tokens
            (True = position may NOT be attended to).

        Returns:
        --------
        Tensor
            Output tensor of the same shape as input (batch_size, seq_len, embed_dim).
        """
        h = self.attn_norm(x)
        # nn.MultiheadAttention treats True entries of a boolean mask as
        # "not allowed to attend", matching the upper-triangular causal mask.
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out
        x = x + self.ff(self.ff_norm(x))
        return x

### TransformerDecoderOnlyLM

Next, implement the full language model using your TransformerBlock. The model is a decoder-only transformer for autoregressive language modeling, mapping token indices (batch_size, seq_len) to logits (batch_size, seq_len, vocab_size).

Typically, this model performs the following steps:

Embedding + Positional Encoding: Embed tokens into vectors and add positional information.

Stacked Transformer Blocks: Pass embeddings through several transformer blocks, each using causal self-attention.

Final Projection: Apply a final layer normalization and project the output vectors to logits over the vocabulary.

The exact architecture (number of layers, dimensions) is not tested. We only require correct handling of tensor shapes, causal masking, and final output dimensions.

In [None]:
class TransformerDecoderOnlyLM(nn.Module):
    """
    A full decoder-only Transformer language model.

    This model maps a sequence of token indices to a sequence of output logits 
    over the vocabulary using causal self-attention, suitable for autoregressive 
    language modeling (i.e., predicting the next token given previous ones).

    Parameters:
    -----------
    vocab_size : int
        Size of the vocabulary (number of unique tokens).
    embed_dim : int
        Dimension of the embedding vectors and model hidden states.
    num_heads : int
        Number of attention heads in each Transformer block.
    num_layers : int
        Number of stacked Transformer blocks.
    ff_hidden_dim : int
        Hidden layer size inside the feedforward network of each block.
    max_len : int
        Maximum sequence length supported by the positional encoding.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_hidden_dim, max_len=1024):
        super().__init__()
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_len=max_len)
        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)]
        )
        self.final_norm = nn.LayerNorm(embed_dim)
        # Keep this attribute name stable: external code may read
        # `model.output_projection` (e.g. its `out_features` as the vocab size).
        self.output_projection = nn.Linear(embed_dim, vocab_size)

    def generate_causal_mask(self, seq_len, device):
        """
        Creates a causal mask to prevent attention to future tokens.

        Returns:
        --------
        attn_mask : Tensor
            Boolean tensor of shape (seq_len, seq_len) where True values are masked.
        """
        return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

    def forward(self, x):
        """
        Forward pass through the Transformer language model.

        Parameters:
        -----------
        x : Tensor
            Input tensor of token indices with shape (batch_size, seq_len).

        Returns:
        --------
        logits : Tensor
            Output tensor of shape (batch_size, seq_len, vocab_size), containing
            unnormalized scores for each token in the vocabulary.

        Raises:
        -------
        ValueError
            If seq_len exceeds max_len (the positional encoding has no entries
            for those positions).
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")

        # One mask shared by all blocks; True above the diagonal hides future tokens.
        attn_mask = self.generate_causal_mask(seq_len, x.device)
        h = self.positional_encoding(self.token_embedding(x))  # (batch, seq_len, embed_dim)
        for block in self.blocks:
            h = block(h, attn_mask)
        return self.output_projection(self.final_norm(h))

In [None]:
# Test Config
embed_dim = 16
num_heads = 4
ff_hidden_dim = 64
seq_len = 10
batch_size = 2
num_layers = 2

block = TransformerBlock(embed_dim, num_heads, ff_hidden_dim)
model = TransformerDecoderOnlyLM(vocab_size, embed_dim, num_heads, num_layers, ff_hidden_dim, max_len=seq_len)
# Evaluation mode keeps forward passes deterministic even if the modules use
# dropout. Otherwise two passes over the same prefix could differ, and the
# causal-mask check below would fail for the wrong reason.
block.eval()
model.eval()

print("Running test: TransformerBlock output shape...")
x = torch.randn(batch_size, seq_len, embed_dim)
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
out = block(x, attn_mask)
assert out.shape == (batch_size, seq_len, embed_dim), "TransformerBlock output shape mismatch"

print("Running test: TransformerDecoderOnlyLM output shape...")
# Use a separate name so the tokenized corpus stored in `tokens` is not overwritten.
test_token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = model(test_token_ids)
assert logits.shape == (batch_size, seq_len, vocab_size), "TransformerDecoderOnlyLM output shape mismatch"

print("Running test: Causal mask effectiveness...")
# Change only the second half of the sequence. With a correct causal mask, the
# logits for the untouched first half must not change.
x1 = torch.randint(0, vocab_size, (1, seq_len))
x2 = x1.clone()
change_start = seq_len // 2
x2[0, change_start:] = (x2[0, change_start:] + 1) % vocab_size

with torch.no_grad():
    logits1 = model(x1)
    logits2 = model(x2)

tolerance = 1e-4
diff = (logits1[:, :change_start] - logits2[:, :change_start]).abs().max().item()
assert diff < tolerance, f"Causal mask failed — output changed by {diff:.4f}"

print("Tests pass - success!")

# Reset after tests
block_size = 128 
batch_size = 32
vocab_size = 512

## Training Instructions

You will train the TransformerDecoderOnlyLM model to achieve a test loss below 3.6, corresponding to a good language model on this dataset.

Hints for training:
- Gradient clipping significantly stabilizes training (nn.utils.clip_grad_norm_).
- AdamW optimizer usually outperforms plain Adam.
- For better results, you can consider using learning rate warm-up and cosine decay schedules.

In practice, training a moderately sized Transformer (e.g., embed_dim=512) for around 3 epochs takes approximately 15 minutes on a JupyterHub CPU and under one minute on a Colab GPU. This is enough to pass the tests. Training for longer can improve the results significantly. Modify the cells below to define the hyperparameters, the optimizer, the training duration and any other optional components. Do not modify the model-creation cell; our tests assume it is unchanged.

In [None]:
# Model size, following the suggested "moderately sized" setup.
# embed_dim must be divisible by num_heads (each head gets embed_dim // num_heads
# dimensions). If CPU training is too slow, shrink these values. If the test
# loss stays above the 3.6 target, train for more epochs.
embed_dim = 512
ff_hidden_dim = 2048  # the conventional 4 * embed_dim
num_heads = 8
num_layers = 4

In [None]:
# This cell sets some stuff for TA use
# Builds the model from the hyperparameters chosen above. max_len=block_size
# sets the maximum sequence length supported by the positional encoding to
# the training context window. Keep this cell unchanged: the grading tests
# assume this construction.
model = TransformerDecoderOnlyLM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_hidden_dim=ff_hidden_dim,
    max_len=block_size,
).to(device)
# Next-token cross-entropy (default reduction: mean over all scored positions).
criterion = nn.CrossEntropyLoss()

In [None]:
# AdamW applies decoupled weight decay, as recommended in the training hints.
# If training is unstable, lower the learning rate; if the loss plateaus above
# the target, raise `epochs`.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
epochs = 5

In [None]:
# This cell sets some stuff for TA use

In [None]:
def count_learnable_params(model):
    """Return the number of scalar parameters in *model* that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the model size (trainable parameters only) with thousands separators.
print(f"Learnable parameters: {count_learnable_params(model):,}")

In [None]:
if not skip_training:
    for epoch in range(epochs):
        train_dataset.set_epoch(epoch)
        # re-created before each epoch because we shift the dataset epoch to get an offset and some variance in the batches
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
        
        model.train()
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            output = model(x)  # logits: (batch, block_size, vocab_size)
            # Every position is its own next-token prediction, so flatten the
            # batch and time axes before computing the cross-entropy.
            loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
            loss.backward()
            # Clip the global gradient L2 norm to 1.0 to stabilize training.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}, Train Loss: {avg_loss:.4f}")

        # --- Evaluation ---
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                output = model(x)
                loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
                test_loss += loss.item()

        avg_test_loss = test_loss / len(test_loader)
        print(f"Epoch {epoch + 1}, Test Loss: {avg_test_loss:.4f}")

In [None]:
def save_model(model, filename):
    """Ask for confirmation and, if the answer is ``yes``, save ``model.state_dict()`` to *filename*.

    Raises:
        Exception: if the confirmation prompt cannot be read (e.g. the notebook
            is run non-interactively); the message tells the user to use
            ``skip_training=True``. The original error is chained as the cause.
    """
    try:
        do_save = input('Do you want to save the model (type yes to confirm)? ').lower()
    except Exception as err:
        # Only the prompt is guarded. Errors from torch.save itself (bad path,
        # full disk) propagate with their real message instead of being
        # reported as a skip_training problem, and Ctrl-C is not swallowed.
        raise Exception('The notebook should be run or validated with skip_training=True.') from err

    if do_save == 'yes':
        torch.save(model.state_dict(), filename)
        print('Model saved to %s.' % (filename))
    else:
        print('Model not saved.')


def load_model(model, filename, device):
    """Restore a ``state_dict`` from *filename* into *model*, move it to *device* and set eval mode.

    Tensors are deserialized onto the CPU first, so a checkpoint written on a
    GPU machine also loads on a CPU-only one.
    """
    state_dict = torch.load(filename, map_location="cpu")
    model.load_state_dict(state_dict)
    print('Model loaded from %s.' % filename)
    model.to(device)
    model.eval()

In [None]:
# Save the model to disk (the pth-files will be submitted automatically together with your notebook)
if skip_training:
    # Validation/grading run: rebuild the architecture from the hyperparameters
    # above and restore the trained weights instead of training.
    model = TransformerDecoderOnlyLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        ff_hidden_dim=ff_hidden_dim,
        max_len=block_size,
    )
    load_model(model, 'shakespeare_decoder.pth', device)
else:
    save_model(model, 'shakespeare_decoder.pth')

## Evaluation and Testing

We evaluate the model by its average test loss; the test perplexity is the exponential of this loss. We also compute the loss of the prediction at the final position for inputs of several lengths. This checks that causal attention masking is implemented correctly.

Passing criteria:
- Causal mask correctness: Final-token average loss must be below 4.4. (1 point)
- Acceptable model: Average test loss must be below 4.0. (1 additional point)
- Good model (target): Average test loss must be below 3.6. (1 additional point)

In [None]:
def evaluate_perplexity(model, test_loader, device):
    """Compute the per-token average cross-entropy of *model* on *test_loader*.

    Prints the average loss and the perplexity ``exp(avg_loss)``, and returns
    the average loss. The average is weighted by token count, so a smaller
    final batch counts proportionally less; a plain mean of per-batch means
    would over-weight it.

    Raises:
        ValueError: if *test_loader* yields no tokens.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    # Sum (not mean) within each batch so batches of different sizes combine exactly.
    criterion = nn.CrossEntropyLoss(reduction="sum")

    # ---- Compute Perplexity ----
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            # Take the vocabulary size from the logits themselves, so any model
            # returning (batch, seq_len, vocab) works, not only one that
            # exposes an `output_projection` layer.
            loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
            total_loss += loss.item()
            total_tokens += y.numel()

    if total_tokens == 0:
        raise ValueError("test_loader yielded no tokens; cannot compute perplexity")
    avg_loss = total_loss / total_tokens
    print(f"Avg loss: {avg_loss:.2f}")
    perplexity = math.exp(avg_loss)
    print(f"Perplexity on test set: {perplexity:.2f}")
    return avg_loss

def evaluate_perplexity_at_token(model, test_loader, token_idx, device, verbose=False):
    """Average loss of the next-token prediction made from the first *token_idx* inputs.

    Every test sequence is cut to ``x[:, :token_idx]``. Only the logits at the
    last kept position are scored, against ``y[:, token_idx - 1]`` (the token
    that follows the cut). The notebook uses this to check causal masking.

    The average is weighted by the number of sequences per batch, so a smaller
    final batch is not over-weighted.

    Parameters
    ----------
    token_idx : int
        Number of input tokens to keep; must satisfy ``1 <= token_idx <= seq_len``.
    verbose : bool
        If True, print the average loss and the corresponding perplexity.

    Returns
    -------
    float
        Average cross-entropy at that position.

    Raises
    ------
    ValueError
        If *token_idx* is out of range or the loader yields no sequences.
    """
    # token_idx=0 would otherwise silently score the last target y[:, -1].
    if token_idx < 1:
        raise ValueError(f"token_idx must be at least 1, got {token_idx}")

    model.eval()
    # Sum per batch and divide once at the end for an exact per-sequence mean.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    num_targets = 0

    with torch.no_grad():
        for x, y in test_loader:
            if token_idx > x.size(1):
                raise ValueError(
                    f"token_idx={token_idx} exceeds the sequence length {x.size(1)}"
                )
            x, y = x.to(device), y.to(device)

            # Cut all sequences at token_idx
            input_cut = x[:, :token_idx]  # Shape: [batch_size, token_idx]
            target_tokens = y[:, token_idx - 1]  # Shape: [batch_size]

            output = model(input_cut)  # Shape: [batch_size, token_idx, vocab_size]
            logits = output[:, -1, :]  # Shape: [batch_size, vocab_size]

            total_loss += criterion(logits, target_tokens).item()
            num_targets += target_tokens.numel()

    if num_targets == 0:
        raise ValueError("test_loader yielded no sequences")
    avg_loss = total_loss / num_targets
    if verbose:
        print(f"Avg loss at token {token_idx}: {avg_loss:.2f}")
    perplexity = math.exp(avg_loss)
    if verbose:
        print(f"Per-token perplexity at token {token_idx}: {perplexity:.2f}")
    return avg_loss

In [None]:
test_loss = evaluate_perplexity(model, test_loader, device)

# Loss of the prediction at the last position, for prompts of several lengths.
final_token_lengths = (32, 64, 96, 128)
losses_at_final_token = [
    evaluate_perplexity_at_token(model, test_loader, token_idx=length, device=device, verbose=False)
    for length in final_token_lengths
]
losses_at_final_token = sum(losses_at_final_token) / len(losses_at_final_token)
print(f"Average loss at the final token for different length inputs: {losses_at_final_token:.2f}")

In [None]:
# Grading check: final-token loss below 4.4 (causal mask correctness).
assert losses_at_final_token < 4.4, "If the test loss during training was low but this test fails, most likely the causal mask has not been applied correctly"

In [None]:
# Grading check: causal mask plus an acceptable model (average test loss below 4.0).
assert losses_at_final_token < 4.4, "If the test loss during training was low but this test fails, most likely the causal mask has not been applied correctly"
assert test_loss < 4, "The model does not perform very well"

In [None]:
# Grading check: causal mask plus a good model (average test loss below 3.6).
assert losses_at_final_token < 4.4, "If the test loss during training was low but this test fails, most likely the causal mask has not been applied correctly"
assert test_loss < 3.6, "The model performance could still be improved"
print("Tests pass - success!")

## Text Generation and Sampling

Finally, we sample from the trained model. You can experiment with various sampling strategies:
- Temperature controls the randomness: lower values (e.g., 0.7 or lower) yield more deterministic and coherent text; values close to 1.0 produce more diverse but potentially noisy outputs.
- Top-k sampling restricts the choices to the k most likely tokens at each step, improving coherence.

You should observe that combining moderately low temperatures (around 0.7–0.8) with top-k sampling typically yields high-quality, coherent text samples with perplexity well below 10.

In [None]:
def sample_next_token(logits, temperature=1.0, top_k=0):
    """Sample one token id from each row of *logits*.

    Parameters
    ----------
    logits : Tensor
        Unnormalized scores of shape ``(batch, vocab)``; a 1-D ``(vocab,)``
        tensor also works.
    temperature : float
        Logits are divided by this value before the softmax; values below 1
        sharpen the distribution. ``temperature <= 0`` returns the argmax
        (greedy decoding) instead of dividing by zero.
    top_k : int
        If positive, only the ``top_k`` highest-scoring tokens can be sampled.

    Returns
    -------
    Tensor
        Sampled ids with a trailing axis of size 1, e.g. ``(batch, 1)``.
    """
    if temperature <= 0:
        # The temperature -> 0 limit puts all probability mass on the argmax.
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 0:
        # Top-k sampling: keep only top k tokens with highest logits
        top_k = min(top_k, logits.size(-1))  # safety
        values, indices = torch.topk(logits, top_k)
        logits_filtered = torch.full_like(logits, float('-inf'))
        # Scatter along the vocabulary (last) axis so 1-D and 2-D inputs both work.
        logits_filtered.scatter_(-1, indices, values)
        logits = logits_filtered

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def sample(model, test_loader, sp_model, device, num_samples=5, max_new_tokens=100, temperature=1.0, top_k=0):
    """Generate continuations of test-set prompts and report their perplexity.

    The first *num_samples* sequences from *test_loader* are used as prompts.
    Each prompt is extended by *max_new_tokens* tokens drawn with
    ``sample_next_token`` (using *temperature* and *top_k*). The prompt and its
    continuation are printed as text decoded with *sp_model*.

    Prompts are generated together in groups of the module-level
    ``batch_size``. At each step the model only sees the last ``block_size``
    tokens (also module-level).

    Returns
    -------
    float
        Perplexity of the sampled tokens under the model's untempered
        distribution, i.e. ``exp`` of their mean cross-entropy.

    Raises
    ------
    ValueError
        If ``num_samples`` or ``max_new_tokens`` is not positive, or the loader
        yields no prompts. Otherwise the perplexity would divide by zero.
    """
    if num_samples < 1 or max_new_tokens < 1:
        raise ValueError("num_samples and max_new_tokens must both be positive")

    # ---- Collect num_samples prompts ----
    print(f"=== Sampling {num_samples} test examples ===\n")
    collected_prompts = []
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss

    for x, _ in test_loader:
        # Named `prompt` rather than `sample` so it does not shadow this function.
        for prompt in x:
            if len(collected_prompts) >= num_samples:
                break
            collected_prompts.append(prompt[None].to(device))  # keep a batch axis: (1, seq_len)
        if len(collected_prompts) >= num_samples:
            break

    if not collected_prompts:
        raise ValueError("test_loader yielded no prompts to sample from")

    # ---- Batched Sampling ----
    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_tokens = 0
        for i in range(0, len(collected_prompts), batch_size):
            batch_prompts = collected_prompts[i:i + batch_size]
            generated = torch.cat(batch_prompts, dim=0)  # (B, prompt_len)

            for _ in range(max_new_tokens):
                # Feed at most the last block_size tokens (the context window).
                input_chunk = generated[:, -block_size:]
                output = model(input_chunk)
                next_token_logits = output[:, -1, :]  # shape: [B, vocab_size]
                next_tokens = sample_next_token(next_token_logits, temperature=temperature, top_k=top_k)
                generated = torch.cat((generated, next_tokens), dim=1)

                # Score each sampled token under the untempered model distribution.
                loss = criterion(next_token_logits, next_tokens.squeeze(-1))  # [batch_size]
                total_loss += loss.sum().item()
                total_tokens += next_tokens.size(0)

            for j, (prompt_tensor, full_output) in enumerate(zip(batch_prompts, generated)):
                # prompt_tensor is (1, prompt_len), so decode returns a one-element list.
                prompt_tokens = prompt_tensor.tolist()
                generated_tokens = full_output.tolist()

                prompt_text = sp_model.decode(prompt_tokens)[0]
                generated_text = sp_model.decode(generated_tokens)

                print(f"[Sample {i + j + 1}]")
                print("Prompt:")
                print(prompt_text)
                print("\nGenerated continuation:")
                print(generated_text[len(prompt_text):])
                print("=" * 60)

        avg_loss = total_loss / total_tokens
        perplexity = math.exp(avg_loss)
        print(f"\n=== Perplexity of generated text: {perplexity:.2f} ===")
        return perplexity

In [None]:
if not skip_training:
    # (banner, sampling settings) for each demo run, executed in order.
    sampling_runs = [
        ("Sampling with default settings: no top_k, temperature=1.0", {"temperature": 1}),
        ("\n\nClose-to-argmax sampling: temperature = 0.01", {"temperature": 0.01}),
        ("\n\nSampling with top_k: temperature = 0.75, top_k = 30", {"temperature": 0.75, "top_k": 30}),
    ]
    for banner, settings in sampling_runs:
        print(banner)
        sample(model, test_loader, sp_model, device, num_samples=3, max_new_tokens=100, **settings)