# Complete GPT Model

This notebook provides an interactive guide to how the individual components of GPT fit together into a complete model.


In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make the project importable: the root is taken to be two directory levels
# above the current working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    # Prepend so the project's packages win over same-named installed ones
    sys.path.insert(0, project_root)

# Import our GPT model and components
from src.model.gpt import GPTModel, TransformerBlock
from src.model.embeddings import TokenEmbedding, PositionalEmbedding
from src.config import GPTConfig
from src.data.tokenizer import get_tokenizer
import tiktoken

## Complete GPT Model

This notebook demonstrates how all components come together to form a complete GPT model: embeddings, transformer blocks, and output layers.

### 1. Token and Positional Embeddings

In [None]:
# Sizes used for the standalone embedding demo
vocab_size = 50257
embedding_dim = 128
context_length = 128

# Instantiate the token and positional embedding layers
token_embedding = TokenEmbedding(vocab_size, embedding_dim)
position_embedding = PositionalEmbedding(context_length, embedding_dim)

# Report the size of each layer; both use the same embedding dimension
embedding_summaries = [
    ("Token Embedding:", "Vocabulary size", vocab_size, token_embedding),
    ("\nPositional Embedding:", "Context length", context_length, position_embedding),
]
for heading, size_label, size_value, embedding_layer in embedding_summaries:
    parameter_count = sum(p.numel() for p in embedding_layer.parameters())
    print(heading)
    print(f"  {size_label}: {size_value}")
    print(f"  Embedding dimension: {embedding_dim}")
    print(f"  Parameters: {parameter_count:,}")

# Push a toy batch (one sequence of five token IDs) through both layers
token_ids = torch.tensor([[1, 2, 3, 4, 5]])
positions = torch.arange(token_ids.shape[1])  # one position index per token

token_embeds = token_embedding(token_ids)
pos_embeds = position_embedding(positions)

print(f"\nToken embeddings shape: {token_embeds.shape}")
print(f"Position embeddings shape: {pos_embeds.shape}")

# Add a leading axis to the positional embeddings, then sum with the token embeddings
combined = token_embeds + pos_embeds[None]
print(f"Combined embeddings shape: {combined.shape}")

### 2. Creating a GPT Model

In [None]:
# Seed torch's global RNG before the model is built so that a fresh
# "Restart Kernel -> Run All" reproduces the same initial weights (and hence
# the same untrained predictions in the cells below).
# NOTE(review): assumes GPTModel draws its initial weights from torch's
# global RNG -- confirm against src/model/gpt.py.
SEED = 42
torch.manual_seed(SEED)

# Create a small GPT model configuration
config = GPTConfig(
    vocab_size=50257,
    context_length=128,
    embedding_dimension=256,
    number_of_heads=4,
    number_of_layers=4,
    dropout_rate=0.1
)

print("GPT Model Configuration:")
print(f"  Vocabulary size: {config.vocab_size}")
print(f"  Context length: {config.context_length}")
print(f"  Embedding dimension: {config.embedding_dimension}")
print(f"  Number of heads: {config.number_of_heads}")
print(f"  Number of layers: {config.number_of_layers}")
print(f"  Dropout rate: {config.dropout_rate}")

# Create the model
model = GPTModel(config)

# Count all parameters, and separately those that receive gradients
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
# 4 bytes per FP32 parameter, converted from bytes to MB (two divisions by 1024)
print(f"  Model size (FP32): {total_params * 4 / 1024 / 1024:.2f} MB")

In [None]:
# Count parameters once per top-level component, then reuse the counts for
# both the printed breakdown and the consistency check below.
component_param_counts = {
    label: sum(p.numel() for p in component.parameters())
    for label, component in (
        ("Token embedding", model.token_embedding),
        ("Position embedding", model.position_embedding),
        ("Transformer blocks", model.transformer_blocks),
        ("Final layer norm", model.final_norm),
        ("Language model head", model.lm_head),
    )
}

print("Parameter breakdown:")
for label, count in component_param_counts.items():
    print(f"  {label}: {count:,}")

# Compare the sum of the per-component counts with the model-wide total
total_breakdown = sum(component_param_counts.values())
print(f"\nTotal from breakdown: {total_breakdown:,}")
print(f"Matches total: {total_breakdown == total_params}")

### 3. Forward Pass

In [None]:
# Tokenize a short prompt
tokenizer = get_tokenizer("gpt2")
text = "The quick brown fox"
token_ids = tokenizer.encode(text)

print(f"Text: '{text}'")
print(f"Token IDs: {token_ids}")
print(f"Number of tokens: {len(token_ids)}")

# Wrapping the ID list in an outer list gives the tensor a batch axis of size one
input_ids = torch.tensor([token_ids], dtype=torch.long)
print(f"Input shape: {input_ids.shape}")

# Run the model in evaluation mode with gradient tracking switched off
model.eval()
with torch.no_grad():
    logits = model(input_ids)

print(f"\nOutput logits shape: {logits.shape}")
for axis, axis_label in enumerate(("Batch size", "Sequence length", "Vocabulary size")):
    print(f"  {axis_label}: {logits.shape[axis]}")

# Turn logits into probabilities and pick the highest-scoring token per position
probs = logits.softmax(dim=-1)
predicted_token_ids = logits.argmax(dim=-1)

predicted_id_list = predicted_token_ids[0].tolist()
predicted_tokens = [tokenizer.decode([tid]) for tid in predicted_id_list]
print(f"\nPredicted token IDs: {predicted_id_list}")
print(f"Predicted tokens: {predicted_tokens}")

### 4. Understanding the Output

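The model outputs logits (unnormalized scores that softmax converts into probabilities) for each position in the sequence. The logits at a given position are the model's prediction for the token that follows that position, so the last position's logits predict the next token after the input.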

In [None]:
# Next-token distribution: softmax over the logits at the final position
# of the first (and only) sequence in the batch
last_position_logits = logits[0, -1]  # [vocab_size]
last_position_probs = last_position_logits.softmax(dim=-1)

# Keep the ten highest-probability candidates
top_k = 10
top_probs, top_indices = last_position_probs.topk(top_k)

print(f"Top {top_k} most likely next tokens:")
ranked_candidates = zip(top_probs.tolist(), top_indices.tolist())
for rank, (probability, token_id) in enumerate(ranked_candidates, start=1):
    token = tokenizer.decode([token_id])
    print(f"  {rank}. '{token}' (probability: {probability:.4f})")

In [None]:
# Rank the whole vocabulary by probability and keep the 20 most likely tokens
sorted_probs, sorted_indices = torch.sort(last_position_probs, descending=True)
top_20_probs = sorted_probs[:20]
top_20_indices = sorted_indices[:20]
top_20_tokens = [tokenizer.decode([token_id]) for token_id in top_20_indices.tolist()]

# Bar chart of those probabilities, one bar per token, labelled with the decoded text
bar_positions = range(len(top_20_probs))
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(bar_positions, top_20_probs.numpy())
ax.set_xticks(bar_positions)
ax.set_xticklabels(top_20_tokens, rotation=45, ha='right')
ax.set_xlabel('Token', fontsize=12)
ax.set_ylabel('Probability', fontsize=12)
ax.set_title('Top 20 Next Token Predictions (Untrained Model)', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()

print("Note: This is an untrained model, so predictions are essentially random.")

### 5. Model Architecture Visualization

In [None]:
# Print a stage-by-stage outline of the model, filling in the sizes from `config`.
separator = "=" * 60
architecture_lines = [
    "GPT Model Architecture:",
    separator,
    "\n1. Input: Token IDs [batch_size, sequence_length]",
    "\n2. Token Embedding:",
    f"   - Maps {config.vocab_size} tokens to {config.embedding_dimension}D vectors",
    "\n3. Positional Embedding:",
    f"   - Adds position information (max {config.context_length} positions)",
    "\n4. Combined Embeddings:",
    "   - token_embedding + positional_embedding",
    f"   - Shape: [batch_size, sequence_length, {config.embedding_dimension}]",
    f"\n5. Transformer Blocks (x{config.number_of_layers}):",
    "   - Multi-Head Attention",
    "   - Feed-Forward Network",
    "   - Layer Normalization (x2)",
    "   - Residual Connections (x2)",
    "\n6. Final Layer Normalization",
    "\n7. Language Model Head:",
    # \u2192 is the right-arrow character; written as an escape so the literal
    # cannot be corrupted again by a wrong file encoding (it was mojibake before).
    f"   - Linear layer: {config.embedding_dimension} \u2192 {config.vocab_size}",
    "\n8. Output: Logits [batch_size, sequence_length, vocab_size]",
    separator,
]
print("\n".join(architecture_lines))

### 6. Testing with Different Input Sizes

In [None]:
# Run the model on prompts of different lengths to show that the output
# sequence length follows the input length.
test_texts = [
    "Hello",
    "The quick brown fox jumps",
    "Once upon a time, in a land far away, there lived"
]

print("Testing with different input lengths:\n")
# Loop-local names are prefixed with `test_` so this cell does not overwrite
# `text`, `token_ids`, `input_ids` or `logits` from the forward-pass cell;
# the analysis cells that read `logits` stay valid if re-run afterwards.
for test_text in test_texts:
    test_token_ids = tokenizer.encode(test_text)
    test_input_ids = torch.tensor([test_token_ids], dtype=torch.long)

    with torch.no_grad():
        test_logits = model(test_input_ids)

    # Sanity check: one row of logits per input token, batch of one
    assert test_logits.shape[:2] == (1, len(test_token_ids)), (
        f"unexpected output shape {tuple(test_logits.shape)} "
        f"for {len(test_token_ids)} input tokens"
    )

    print(f"Text: '{test_text}'")
    print(f"  Input length: {len(test_token_ids)} tokens")
    print(f"  Output shape: {test_logits.shape}")
    print()

### 7. Model Components Summary

In [None]:
# Summarise each direct child module of the model with its parameter count
print("Model Components:")
print("-" * 60)
for child_name, child in model.named_children():
    child_param_count = sum(p.numel() for p in child.parameters())
    print(f"{child_name:20s} : {child_param_count:>12,} parameters")
    # Children that support len() also report how many sub-modules they hold
    if hasattr(child, '__len__'):
        print(" " * 23 + f"({len(child)} sub-modules)")

# Footer: blank line, rule, then the model-wide total computed earlier
print()
print("=" * 60)
print(f"Total: {total_params:>12,} parameters")