In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import time
import numpy as np

## Set the device to CUDA (GPU) if available, otherwise use CPU

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Parameters used for data generation

In [9]:
# Maximum possible token value for random generation
MAX_TOKEN_VALUE = 100
# Size of the dataset to generate
DATASET_SIZE = 100000

## Model Hyperparameters

In [10]:
# The dimensionality of the token embeddings. This is the size of the vector
# that will represent each token.
EMBEDDING_DIM = 256

# The number of attention heads in the multi-head attention mechanism.
# The embedding dimension must be divisible by this number.
NUM_HEADS = 8

# The number of Transformer encoder layers to stack.
NUM_ENCODER_LAYERS = 4

# The dimension of the feed-forward network within each Transformer layer.
FF_DIM = 1024

# The dropout probability to be applied in the model.
DROPOUT = 0.1

## Training Hyperparameters

In [11]:
# The number of independent sequences to process in parallel.
BATCH_SIZE = 32

# The length of the subsequences to be used for training. This is also known
# as the context window or "backpropagation through time" (BPTT) length.
SEQUENCE_LENGTH = 64

# The number of epochs to train the model for.
EPOCHS = 5

# The learning rate for the optimizer.
LEARNING_RATE = 0.001

# How often to log training progress (in batches).
LOG_INTERVAL = 200

## Generate random sequence for training

In [12]:
def generate_random_sequence(length, max_token_value):
    """Generates a long sequence of random discrete tokens."""
    # Create a tensor of random integers between 0 and max_token_value
    # Note: We use a non-uniform distribution to make it more realistic
    # Some tokens will be more frequent than others
    tokens = []
    for _ in range(length):
        # Use a power law distribution to make some tokens more frequent
        # This simulates real-world token distributions (Zipf's law)
        if np.random.random() < 0.7:  # 70% chance for common tokens
            token = np.random.randint(0, max_token_value // 4)
        elif np.random.random() < 0.9:  # 20% chance for medium tokens
            token = np.random.randint(max_token_value // 4, max_token_value // 2)
        else:  # 10% chance for rare tokens
            token = np.random.randint(max_token_value // 2, max_token_value)
        tokens.append(token)
    
    return torch.tensor(tokens, dtype=torch.long)

## Analyze Vocabulary

In [13]:
def analyze_vocabulary(data):
    """
    Analyzes the dataset to determine vocabulary characteristics.
    Returns vocabulary size and other statistics.
    """
    print("Analyzing vocabulary from the dataset...")
    
    # Get unique tokens and their frequencies
    unique_tokens, counts = torch.unique(data, return_counts=True)
    vocab_size = len(unique_tokens)
    
    # Calculate statistics
    min_token = unique_tokens.min().item()
    max_token = unique_tokens.max().item()
    total_tokens = len(data)
    
    # Sort by frequency for analysis
    sorted_indices = torch.argsort(counts, descending=True)
    sorted_tokens = unique_tokens[sorted_indices]
    sorted_counts = counts[sorted_indices]
    
    print(f"Vocabulary Analysis Results:")
    print(f"  - Total unique tokens (vocab size): {vocab_size}")
    print(f"  - Token range: {min_token} to {max_token}")
    print(f"  - Total tokens in dataset: {total_tokens}")
    print(f"  - Average token frequency: {total_tokens / vocab_size:.2f}")
    
    # Show top 10 most frequent tokens
    print(f"  - Top 10 most frequent tokens:")
    for i in range(min(10, len(sorted_tokens))):
        token = sorted_tokens[i].item()
        count = sorted_counts[i].item()
        percentage = (count / total_tokens) * 100
        print(f"    Token {token}: {count} occurrences ({percentage:.2f}%)")
    
    # Show some rare tokens
    print(f"  - Some rare tokens (bottom 5):")
    for i in range(max(0, len(sorted_tokens) - 5), len(sorted_tokens)):
        token = sorted_tokens[i].item()
        count = sorted_counts[i].item()
        percentage = (count / total_tokens) * 100
        print(f"    Token {token}: {count} occurrences ({percentage:.2f}%)")
    
    return vocab_size, {
        'unique_tokens': unique_tokens,
        'counts': counts,
        'min_token': min_token,
        'max_token': max_token,
        'total_tokens': total_tokens
    }

## Create token mappings

In [14]:
def create_token_mapping(unique_tokens):
    """
    Creates a mapping from original tokens to contiguous indices.
    This is important for efficiency - we want our embedding layer to have
    size equal to vocab_size, not max_token_value.
    """
    print("Creating token mapping for efficient embedding...")
    
    # Create mapping from original token to index
    token_to_idx = {token.item(): idx for idx, token in enumerate(unique_tokens)}
    idx_to_token = {idx: token.item() for idx, token in enumerate(unique_tokens)}
    
    print(f"  - Mapped {len(unique_tokens)} tokens to indices 0-{len(unique_tokens)-1}")
    
    return token_to_idx, idx_to_token

In [15]:
def remap_dataset(data, token_to_idx):
    """
    Remaps the dataset to use contiguous indices instead of original token values.
    """
    print("Remapping dataset to use contiguous indices...")
    
    # Create a new tensor with remapped values
    remapped_data = torch.zeros_like(data)
    for i, token in enumerate(data):
        remapped_data[i] = token_to_idx[token.item()]
    
    print(f"  - Remapped {len(data)} tokens")
    print(f"  - Original token range: {data.min().item()} to {data.max().item()}")
    print(f"  - Remapped token range: {remapped_data.min().item()} to {remapped_data.max().item()}")
    
    return remapped_data

## Creating raw dataset

In [16]:
# Generate the raw dataset
print(f"Generating random dataset of {DATASET_SIZE} tokens with max value {MAX_TOKEN_VALUE}...")
raw_train_data = generate_random_sequence(DATASET_SIZE, MAX_TOKEN_VALUE)
print(f"Sample of raw data: {raw_train_data[:20]}...")

# Analyze vocabulary and get dynamic vocab size
VOCAB_SIZE, vocab_stats = analyze_vocabulary(raw_train_data)

# Create token mapping for efficient embedding
token_to_idx, idx_to_token = create_token_mapping(vocab_stats['unique_tokens'])

# Remap the dataset to use contiguous indices
train_data = remap_dataset(raw_train_data, token_to_idx)
print(f"Sample of remapped data: {train_data[:20]}...")

# Move the dataset to the selected device
train_data = train_data.to(device)

print(f"\n--- DYNAMIC VOCABULARY SIZE: {VOCAB_SIZE} ---")
print(f"This will be used for embedding and output layer dimensions.")

Generating random dataset of 100000 tokens with max value 100...
Sample of raw data: tensor([23,  4,  4, 57, 57, 37, 48,  5,  3, 24, 42,  3, 17, 19,  4, 34, 38, 12,
         1, 44])...
Analyzing vocabulary from the dataset...
Vocabulary Analysis Results:
  - Total unique tokens (vocab size): 100
  - Token range: 0 to 99
  - Total tokens in dataset: 100000
  - Average token frequency: 1000.00
  - Top 10 most frequent tokens:
    Token 19: 2896 occurrences (2.90%)
    Token 7: 2884 occurrences (2.88%)
    Token 10: 2880 occurrences (2.88%)
    Token 0: 2867 occurrences (2.87%)
    Token 21: 2864 occurrences (2.86%)
    Token 8: 2850 occurrences (2.85%)
    Token 13: 2833 occurrences (2.83%)
    Token 1: 2826 occurrences (2.83%)
    Token 24: 2821 occurrences (2.82%)
    Token 17: 2819 occurrences (2.82%)
  - Some rare tokens (bottom 5):
    Token 58: 50 occurrences (0.05%)
    Token 79: 45 occurrences (0.04%)
    Token 97: 45 occurrences (0.04%)
    Token 83: 41 occurrences (0.04%)
    T

## Function to get batch of training data

In [18]:
def get_batch(source_data, seq_length, batch_size):
    """
    Generates a batch of source and target sequences for training.
    This is the core of how we set up the "next token prediction" task.
    """
    # Get the total length of the dataset
    num_tokens = len(source_data)
    # Generate random starting points for our sequences within the dataset
    # We subtract seq_length + 1 to ensure we have a valid target for each sequence
    start_indices = torch.randint(0, num_tokens - seq_length - 1, (batch_size,))

    # Create the source sequences (input to the model)
    # torch.stack builds a new tensor from a list of tensors
    x = torch.stack([source_data[i : i + seq_length] for i in start_indices])

    # Create the target sequences (what the model should predict)
    # The target for each token in the input is the very next token in the sequence.
    y = torch.stack([source_data[i + 1 : i + seq_length + 1] for i in start_indices])

    return x, y

print(f"\nLet's see an example of a single batch with batch_size=1 and seq_length=5:")
x_sample, y_sample = get_batch(train_data, 5, 1)
print(f"Source (x): {x_sample.squeeze().tolist()}")
print(f"Target (y): {y_sample.squeeze().tolist()}")
print("Notice that the target is the source sequence shifted one position to the right.")


Let's see an example of a single batch with batch_size=1 and seq_length=5:
Source (x): [6, 10, 40, 4, 23]
Target (y): [10, 40, 4, 23, 19]
Notice that the target is the source sequence shifted one position to the right.


## Create positional encoding since transformers see all permutations as same. Need this to make sure sequence is learned.

In [19]:
class PositionalEncoding(nn.Module):
    """
    Injects position information into the token embeddings.
    Since the Transformer architecture itself doesn't have a notion of order,
    we add these positional encodings to the input embeddings.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a matrix for positional encodings of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        # Create a position tensor [0, 1, 2, ..., max_len-1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Calculate the division term for the sine and cosine functions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Apply sine to even indices in the array; 2i
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices in the array; 2i+1
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension to the positional encoding matrix
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Register 'pe' as a buffer. Buffers are part of the model's state,
        # but they are not considered model parameters to be trained.
        self.register_buffer('pe', pe)
        print("Initialized PositionalEncoding module.")

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        # Add the positional encoding to the input tensor
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

## The transformer model

In [20]:
class TransformerModel(nn.Module):
    """
    A Transformer model for sequence-to-sequence tasks.
    In our case, it's used for next-token prediction.
    """
    def __init__(self, vocab_size, d_model, nhead, d_hid, nlayers, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.vocab_size = vocab_size

        # 1. Token Embedding Layer: Maps input token indices to dense vectors.
        self.encoder = nn.Embedding(vocab_size, d_model)
        print(f"Initialized nn.Embedding: maps {vocab_size} tokens to {d_model}-dim vectors.")

        # 2. Positional Encoding: Adds positional information.
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # 3. Transformer Encoder Layers: The core of the model.
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=False)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        print(f"Initialized nn.TransformerEncoder with {nlayers} layers.")

        # 4. Final Linear Layer (Decoder): Maps the Transformer output back to the vocabulary space.
        self.decoder = nn.Linear(d_model, vocab_size)
        print(f"Initialized final nn.Linear decoder: maps {d_model}-dim vectors to {vocab_size} (vocab size) logits.")

        self.init_weights()

    def init_weights(self):
        """Initializes weights for the embedding and linear layers."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        """
        Forward pass of the model.
        Args:
            src: the sequence to the encoder (required). Shape: [seq_len, batch_size].
            src_mask: the mask for the src sequence (required).
        """
        print("\n--- Inside Model Forward Pass ---")
        print(f"Input `src` shape: {src.shape} [Sequence Length, Batch Size]")

        # 1. Embed the tokens and scale by sqrt(d_model)
        src = self.encoder(src) * math.sqrt(self.d_model)
        print(f"Shape after Embedding and Scaling: {src.shape} [Seq Len, Batch, Embedding Dim]")

        # 2. Add positional encoding
        src = self.pos_encoder(src)
        print(f"Shape after Positional Encoding: {src.shape} [Seq Len, Batch, Embedding Dim]")

        # 3. Pass through the Transformer encoder layers
        output = self.transformer_encoder(src, src_mask)
        print(f"Shape after Transformer Encoder: {output.shape} [Seq Len, Batch, Embedding Dim]")

        # 4. Decode the output to get logits for each token in the vocabulary
        output = self.decoder(output)
        print(f"Shape after Final Decoder Layer: {output.shape} [Seq Len, Batch, Vocab Size]")
        print("--- End of Model Forward Pass ---\n")
        return output

## Creating mask to prevent the model from seeing future tokens

In [21]:
def generate_square_subsequent_mask(sz):
    """
    Generates a square causal mask for the sequence.
    The masked positions are filled with -inf.
    Unmasked positions are 0. This prevents the model from "cheating" by
    looking at future tokens during training.
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

## Initialize Model, loss and Optimizer

In [24]:
# Instantiate the model with the dynamically calculated vocab size
model = TransformerModel(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, FF_DIM, NUM_ENCODER_LAYERS, DROPOUT
).to(device)

# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Model initialized with dynamic vocabulary size: {VOCAB_SIZE}")
print(f"Total model parameters: {sum(p.numel() for p in model.parameters()):,}")

Initialized nn.Embedding: maps 100 tokens to 256-dim vectors.
Initialized PositionalEncoding module.
Initialized nn.TransformerEncoder with 4 layers.
Initialized final nn.Linear decoder: maps 256-dim vectors to 100 (vocab size) logits.
Model initialized with dynamic vocabulary size: 100
Total model parameters: 3,210,340


## Training the model

In [25]:
def train(epoch):
    """Defines the training loop for one epoch."""
    model.train()  # Set the model to training mode
    total_loss = 0.
    start_time = time.time()
    # Generate the causal mask. It's the same for all sequences of the same length.
    src_mask = generate_square_subsequent_mask(SEQUENCE_LENGTH).to(device)

    # Calculate the number of batches in one epoch
    num_batches = len(train_data) // (SEQUENCE_LENGTH * BATCH_SIZE)

    print(f"\n--- Starting Epoch {epoch} ---")
    for batch, i in enumerate(range(0, train_data.size(0) - 1 - SEQUENCE_LENGTH, SEQUENCE_LENGTH)):
        # Get a batch of data
        data, targets = get_batch(train_data, SEQUENCE_LENGTH, BATCH_SIZE)
        
        # The model expects inputs of shape [sequence_length, batch_size]
        # Our get_batch function returns [batch_size, sequence_length], so we permute it.
        data = data.permute(1, 0)
        targets = targets.permute(1, 0)
        
        # The first time through, print shapes to be extra clear
        if batch == 0:
            print(f"Shape of data batch fed to model: {data.shape}")
            print(f"Shape of target batch for loss: {targets.shape}")
            print(f"Shape of causal mask: {src_mask.shape}")
            print("Starting batch iterations...")

        optimizer.zero_grad() # Reset gradients

        # This is where we stop printing the forward pass details to avoid clutter
        # We'll only do it once during the prediction phase.
        if batch == 0 and epoch == 1:
            output = model(data, src_mask) # Run the forward pass
        else:
            # Temporarily disable the print statements in the forward pass
            # for cleaner training logs.
            _print = __builtins__.print
            __builtins__.print = lambda *args, **kwargs: None
            output = model(data, src_mask)
            __builtins__.print = _print

        # Reshape the output and targets for the loss function
        # The loss function expects [Batch * SeqLen, VocabSize] and [Batch * SeqLen]
        loss = criterion(output.view(-1, VOCAB_SIZE), targets.reshape(-1))

        loss.backward() # Compute gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # Prevent exploding gradients
        optimizer.step() # Update weights

        total_loss += loss.item()

        # Log progress
        if batch % LOG_INTERVAL == 0 and batch > 0:
            lr = optimizer.param_groups[0]['lr']
            ms_per_batch = (time.time() - start_time) * 1000 / LOG_INTERVAL
            cur_loss = total_loss / LOG_INTERVAL
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches * BATCH_SIZE:5d} batches | '
                  f'lr {lr:02.5f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}')
            total_loss = 0
            start_time = time.time()

# Run the training loop
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(epoch)
    print(f'--- End of Epoch {epoch} | Time: {(time.time() - epoch_start_time):.2f}s ---')


--- Starting Epoch 1 ---
Shape of data batch fed to model: torch.Size([64, 32])
Shape of target batch for loss: torch.Size([64, 32])
Shape of causal mask: torch.Size([64, 64])
Starting batch iterations...

--- Inside Model Forward Pass ---
Input `src` shape: torch.Size([64, 32]) [Sequence Length, Batch Size]
Shape after Embedding and Scaling: torch.Size([64, 32, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Positional Encoding: torch.Size([64, 32, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Transformer Encoder: torch.Size([64, 32, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Final Decoder Layer: torch.Size([64, 32, 100]) [Seq Len, Batch, Vocab Size]
--- End of Model Forward Pass ---

| epoch   1 |   200/ 1536 batches | lr 0.00100 | ms/batch 14.24 | loss  4.00 | ppl    54.75
| epoch   1 |   400/ 1536 batches | lr 0.00100 | ms/batch 12.95 | loss  3.96 | ppl    52.43
| epoch   1 |   600/ 1536 batches | lr 0.00100 | ms/batch 12.93 | loss  3.95 | ppl    51.85
| epoch   1 |  

## Generating new sequence

In [26]:
def predict(model, seed_sequence, max_len=50, idx_to_token=None):
    """
    Generates a sequence token by token based on a seed.
    """
    model.eval() # Set the model to evaluation mode
    print(f"Seed sequence (indices): {seed_sequence}")
    if idx_to_token:
        original_tokens = [idx_to_token[idx] for idx in seed_sequence]
        print(f"Seed sequence (original tokens): {original_tokens}")
    
    # Convert the seed sequence (list of ints) to a tensor
    input_tensor = torch.tensor(seed_sequence, dtype=torch.long).unsqueeze(1).to(device)
    
    generated_sequence = seed_sequence.copy()

    with torch.no_grad():
        for step in range(max_len):
            print(f"\n--- Prediction step {step + 1} ---")
            
            # Create the causal mask for the current sequence length
            current_seq_len = input_tensor.size(0)
            mask = generate_square_subsequent_mask(current_seq_len).to(device)
            
            # Get the model's output. We will re-enable printing for this one pass.
            if step == 0:  # Only print details for the first step
                output = model(input_tensor, mask)
            else:
                # Disable printing for subsequent steps
                _print = __builtins__.print
                __builtins__.print = lambda *args, **kwargs: None
                output = model(input_tensor, mask)
                __builtins__.print = _print
            
            # We only care about the prediction for the VERY LAST token in the input sequence.
            # The output shape is [seq_len, batch_size, vocab_size].
            # We take the last token's output: output[-1, 0, :]
            last_token_logits = output[-1, 0, :]
            
            # Apply softmax to get probabilities
            probabilities = torch.softmax(last_token_logits, dim=-1)
            
            # Choose the next token. We can use argmax for the most likely token,
            # or sample from the distribution for more variety.
            # next_token = torch.argmax(probabilities).item()
            next_token = torch.multinomial(probabilities, 1).item()
            
            predicted_original = idx_to_token[next_token] if idx_to_token else next_token
            print(f"Model predicted next token: {next_token} (original: {predicted_original})")
            
            # Append the predicted token to our sequence
            generated_sequence.append(next_token)
            
            # Create the new input for the next iteration by appending the predicted token
            input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=device)], dim=0)

    return generated_sequence

## Generating new sequence

In [30]:
# Let's generate a new sequence using some tokens from our vocabulary
# We'll use the first few tokens from our remapped vocabulary
seed_indices = [0, 1, 2, 3, 4]  # These are indices in our remapped space
predicted_sequence = predict(model, seed_indices, max_len=20, idx_to_token=idx_to_token)

print("\n\n--- Final Result ---")
print(f"Original Seed (indices): {seed_indices}")
print(f"Original Seed (original tokens): {[idx_to_token[idx] for idx in seed_indices]}")
print(f"Generated Sequence (indices): {predicted_sequence}")
print(f"Generated Sequence (original tokens): {[idx_to_token[idx] for idx in predicted_sequence]}")
print("\nTraining and prediction complete!")
print(f"Final vocabulary size used: {VOCAB_SIZE}")

Seed sequence (indices): [0, 1, 2, 3, 4]
Seed sequence (original tokens): [0, 1, 2, 3, 4]

--- Prediction step 1 ---

--- Inside Model Forward Pass ---
Input `src` shape: torch.Size([5, 1]) [Sequence Length, Batch Size]
Shape after Embedding and Scaling: torch.Size([5, 1, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Positional Encoding: torch.Size([5, 1, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Transformer Encoder: torch.Size([5, 1, 256]) [Seq Len, Batch, Embedding Dim]
Shape after Final Decoder Layer: torch.Size([5, 1, 100]) [Seq Len, Batch, Vocab Size]
--- End of Model Forward Pass ---

Model predicted next token: 18 (original: 18)

--- Prediction step 2 ---
Model predicted next token: 7 (original: 7)

--- Prediction step 3 ---
Model predicted next token: 3 (original: 3)

--- Prediction step 4 ---
Model predicted next token: 30 (original: 30)

--- Prediction step 5 ---
Model predicted next token: 14 (original: 14)

--- Prediction step 6 ---
Model predicted next token: 6