# Transformers in PyTorch

This notebook demonstrates how to implement transformers in PyTorch, building up from self-attention to full transformer models.

## Table of Contents
1. [Self-Attention from Scratch](#self-attention-from-scratch)
2. [Using PyTorch's Built-in Transformer](#using-pytorchs-built-in-transformer)
3. [Masking: Causal and Padding](#masking-causal-and-padding)
4. [Next-Token Prediction Demo](#next-token-prediction-demo)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math

# Pick the GPU when PyTorch can see one, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Self-Attention from Scratch

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention implemented from scratch.

    The input is projected to queries, keys and values. These are split into
    ``n_heads`` heads of size ``d_model // n_heads`` and attended per head with
    scaled dot-product attention. The heads are then concatenated and projected
    back to ``d_model``.

    Args:
        d_model: Size of the last dimension of the input and output.
        n_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            # Without this check the head reshape in forward() fails later
            # with an opaque view() error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Linear projections for Q, K, V and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q, K, V: Tensors of shape (batch_size, n_heads, seq_len, d_k).
            mask: Optional tensor broadcastable to the score shape
                (batch_size, n_heads, seq_len, seq_len). Positions where
                ``mask == 0`` are blocked. For example, reshape a
                (batch_size, seq_len) key mask to (batch_size, 1, 1, seq_len).

        Returns:
            ``(output, attention_weights)``. ``output`` has the shape of ``V``;
            ``attention_weights`` has shape (batch_size, n_heads, seq_len, seq_len).
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Use the dtype's most negative value instead of a hard-coded -1e9,
            # which overflows (and errors in masked_fill) for float16 scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).
            mask: Optional mask; see ``scaled_dot_product_attention``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape
            (batch_size, seq_len, d_model); ``attn_weights`` has shape
            (batch_size, n_heads, seq_len, seq_len).
        """
        batch_size, seq_len, _ = x.size()

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
            return t.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads back into (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        output = self.W_o(attn_output)
        return output, attn_weights

# Smoke-test the attention module on a random batch
d_model, n_heads = 64, 8
seq_len, batch_size = 10, 2

attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
x = torch.randn(batch_size, seq_len, d_model)
output, weights = attention(x)

# Report the shapes flowing through the layer
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

## Using PyTorch's Built-in Transformer

In [None]:
class SimpleTransformer(nn.Module):
    """Encoder-only transformer built on PyTorch's ``nn.TransformerEncoder``.

    The pipeline is token embedding, then sinusoidal positional encoding, then
    a stack of encoder layers, then a linear projection back to vocabulary logits.

    Args:
        vocab_size: Number of token ids covered by the embedding and output layers.
        d_model: Embedding/hidden size (PyTorch requires it to be divisible by
            ``n_heads``).
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        d_ff: Hidden size of each layer's feed-forward sub-block.
        max_seq_len: Longest input sequence the positional table supports.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # A non-persistent buffer follows model.to(device)/.half() like the
        # parameters do, but does not add a state_dict() entry.
        self.register_buffer(
            "pos_encoding",
            self.create_positional_encoding(max_seq_len, d_model),
            persistent=False,
        )

        # PyTorch's built-in transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)

        self.output_projection = nn.Linear(d_model, vocab_size)

    def create_positional_encoding(self, max_seq_len, d_model):
        """Return a (1, max_seq_len, d_model) sinusoidal positional-encoding table.

        Even columns hold sin(pos * freq) and odd columns hold cos(pos * freq).
        Odd ``d_model`` values are supported.
        """
        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        return pe.unsqueeze(0)  # Add batch dimension

    def forward(self, x, mask=None):
        """Map token ids to vocabulary logits.

        Args:
            x: Long tensor of token ids, shape (batch_size, seq_len), with
                ``seq_len <= max_seq_len``.
            mask: Optional ``src_key_padding_mask`` of shape (batch_size, seq_len)
                in PyTorch's convention, where True marks padding to ignore.
                ``create_padding_mask`` in this notebook returns the opposite
                (True = real token), so pass ``~padding_mask``.

        Returns:
            Logits of shape (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        # Embedding (scaled as in "Attention Is All You Need") + positions
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :]

        x = self.transformer(x, src_key_padding_mask=mask)

        return self.output_projection(x)

# Hyperparameters for a small demo model
vocab_size, d_model, n_heads = 1000, 128, 8
n_layers, d_ff, max_seq_len = 4, 512, 100

model = SimpleTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    n_layers=n_layers,
    d_ff=d_ff,
    max_seq_len=max_seq_len,
)

# Feed a batch of random token ids through the model
batch_size, seq_len = 4, 20
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
output = model(input_ids)

print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Masking: Causal and Padding

In [None]:
def create_causal_mask(seq_len):
    """Return a (seq_len, seq_len) float mask: 1 on/below the diagonal, 0 above.

    A 1 marks a position that may be attended to and a 0 marks a blocked
    (future) position. This matches ``MultiHeadAttention`` in this notebook,
    which blocks entries where ``mask == 0``. PyTorch's ``nn.Transformer*``
    ``*_mask`` arguments use a different convention (bool True = blocked, or
    an additive float mask), so this tensor cannot be passed to them directly.
    """
    return torch.ones(seq_len, seq_len).tril()

def create_padding_mask(sequences, pad_token_id=0):
    """Return a bool tensor shaped like ``sequences``: True at real tokens,
    False wherever the entry equals ``pad_token_id``.

    PyTorch's ``*_key_padding_mask`` arguments expect the opposite
    (True = ignore), so pass ``~mask`` to those.
    """
    return sequences.ne(pad_token_id)

# Causal mask: row i may attend to columns 0..i (autoregressive generation)
seq_len = 5
causal_mask = create_causal_mask(seq_len)
print("Causal mask:")
print(causal_mask.numpy())

# Padding mask: three rows with increasing amounts of trailing 0-padding
sequences = torch.tensor(
    [
        [1, 2, 3, 4, 5],
        [1, 2, 3, 0, 0],
        [1, 0, 0, 0, 0],
    ]
)
padding_mask = create_padding_mask(sequences, pad_token_id=0)

print("\nSequences:")
print(sequences.numpy())
print("Padding mask (True = real token, False = padding):")
print(padding_mask.numpy())

# Using masks in transformer
class MaskedTransformer(nn.Module):
    """Decoder-only stack built on ``nn.TransformerDecoder``.

    Token ids are embedded, decoded against ``memory``, and projected to
    vocabulary logits. No positional encoding is added here.

    Args:
        vocab_size: Number of token ids for the embedding and output layers.
        d_model: Embedding/hidden size.
        n_heads: Attention heads per decoder layer.
        n_layers: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads, batch_first=True),
            n_layers,
        )
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        """Return logits of shape (batch, tgt_len, vocab_size).

        ``tgt_mask`` and ``tgt_key_padding_mask`` are forwarded unchanged to
        ``nn.TransformerDecoder`` and follow PyTorch's conventions (bool True =
        blocked/padding). They are not the 1 = keep masks produced by
        ``create_causal_mask``/``create_padding_mask`` above.
        """
        hidden = self.transformer(
            self.embedding(tgt),
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.output_proj(hidden)

# Only the class is defined here; no instance is constructed.
print("\nMasked transformer created for autoregressive generation")

## Next-Token Prediction Demo

In [None]:
# Simple character-level next-token prediction
class CharacterLevelTransformer(nn.Module):
    """Small causal transformer for character-level next-token prediction.

    By default every position may attend only to itself and earlier
    positions. Without that mask a bidirectional encoder would see the very
    token it is trained to predict. Sinusoidal positions are added because
    attention alone is order-agnostic.

    Args:
        vocab_size: Number of distinct characters.
        d_model: Embedding/hidden size.
        n_heads: Attention heads per layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _sinusoidal_positions(seq_len, d_model, device, dtype):
        """Return a (seq_len, d_model) sinusoidal positional-encoding table."""
        position = torch.arange(seq_len, device=device, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(seq_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.to(dtype)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Bool (seq_len, seq_len) mask; True above the diagonal blocks future positions."""
        return torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self, x, mask=None, *, causal=True):
        """Map token ids of shape (batch, seq_len) to logits (batch, seq_len, vocab_size).

        Args:
            x: Long tensor of character ids.
            mask: Optional attention mask forwarded to ``nn.TransformerEncoder``
                (PyTorch convention: bool True = blocked). It overrides the
                default causal mask.
            causal: When ``mask`` is None, apply a causal mask (default True).
                Set False for bidirectional attention.
        """
        seq_len = x.size(1)
        h = self.embedding(x) * math.sqrt(self.d_model)
        h = h + self._sinusoidal_positions(seq_len, self.d_model, h.device, h.dtype)
        if mask is None and causal:
            mask = self._causal_mask(seq_len, x.device)
        h = self.transformer(h, mask=mask)
        return self.output_proj(h)

# Toy corpus and its sorted character vocabulary
text = "hello world this is a simple example"
chars = sorted(set(text))
idx_to_char = dict(enumerate(chars))
char_to_idx = {ch: i for i, ch in idx_to_char.items()}
vocab_size = len(chars)

print(f"Vocabulary: {chars}")
print(f"Vocab size: {vocab_size}")

# Map every character of the corpus to its integer id
encoded = [char_to_idx[c] for c in text]
print(f"Encoded text: {encoded[:10]}...")  # First 10 tokens

# Create training data (input: n chars, target: next char)
def create_sequences(encoded_text, seq_len):
    """Slice ``encoded_text`` into overlapping (input, target) windows.

    Each input is ``encoded_text[i:i+seq_len]`` and its target is the same
    window shifted one step right.

    Args:
        encoded_text: Sequence of integer token ids.
        seq_len: Window length; must be at least 1.

    Returns:
        ``(inputs, targets)``, each of shape (num_windows, seq_len), where
        ``num_windows = max(len(encoded_text) - seq_len, 0)``. When the text is
        too short, both are empty long tensors of shape (0, seq_len).

    Raises:
        ValueError: if ``seq_len`` is less than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    n_windows = max(len(encoded_text) - seq_len, 0)
    inputs = [encoded_text[i:i + seq_len] for i in range(n_windows)]
    targets = [encoded_text[i + 1:i + seq_len + 1] for i in range(n_windows)]  # Shifted by 1
    if not inputs:
        # torch.tensor([]) would give a 1-D float tensor, breaking callers that
        # expect (N, seq_len) token ids.
        empty = torch.empty((0, seq_len), dtype=torch.long)
        return empty, empty.clone()
    return torch.tensor(inputs), torch.tensor(targets)

# Slide an 8-character window over the encoded corpus
seq_len = 8
inputs, targets = create_sequences(encoded_text=encoded, seq_len=seq_len)
print(f"\nTraining data shapes: {inputs.shape}, {targets.shape}")
print(f"Example input: {inputs[0]} -> target: {targets[0]}")

# Build the model and train it full-batch on every window at once
model = CharacterLevelTransformer(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(100):
    optimizer.zero_grad()

    output = model(inputs)  # (batch, seq_len, vocab_size)
    # CrossEntropyLoss expects (N, C) logits against (N,) class ids
    loss = criterion(output.reshape(-1, vocab_size), targets.reshape(-1))

    loss.backward()
    optimizer.step()

    if not epoch % 20:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Generate text
def generate_text(model, start_text, max_len=20, temperature=1.0, *,
                  stoi=None, itos=None, context_len=None):
    """Sample up to ``max_len`` characters from ``model``, continuing ``start_text``.

    Args:
        model: Module mapping a (1, L) tensor of token ids to (1, L, vocab) logits.
        start_text: Non-empty prompt; every character must be in the vocabulary.
        max_len: Maximum number of new characters to sample.
        temperature: Positive softmax temperature; lower values sharpen sampling.
        stoi: char -> index mapping (defaults to the notebook's ``char_to_idx``).
        itos: index -> char mapping (defaults to the notebook's ``idx_to_char``).
        context_len: Number of trailing tokens fed to the model (defaults to the
            notebook's ``seq_len``).

    Returns:
        The prompt followed by the sampled characters. Sampling stops early
        once the last five tokens are all the same id.

    Raises:
        ValueError: if ``start_text`` is empty or ``temperature`` is not positive.
        KeyError: if ``start_text`` contains a character missing from ``stoi``.
    """
    if not start_text:
        raise ValueError("start_text must be non-empty")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    stoi = char_to_idx if stoi is None else stoi
    itos = idx_to_char if itos is None else itos
    context_len = seq_len if context_len is None else context_len

    # Build inputs on the same device as the model's weights
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    tokens = [stoi[ch] for ch in start_text]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                input_seq = torch.tensor(tokens[-context_len:], device=model_device).unsqueeze(0)
                next_token_logits = model(input_seq)[0, -1, :] / temperature
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                tokens.append(next_token)

                # Stop after 5 identical tokens in a row; the length check keeps
                # short prompts from stopping after only 2 equal tokens.
                if len(tokens) >= 5 and len(set(tokens[-5:])) == 1:
                    break
    finally:
        # Leave the model in the train/eval mode the caller had set
        model.train(was_training)

    return ''.join(itos[idx] for idx in tokens)

# Sample a short continuation of a familiar prompt
start = "hello"
generated = generate_text(model, start_text=start, max_len=15)

print(f"\nGenerated text starting with '{start}': '{generated}'")
print("\nNext-token prediction demo completed!")