# PyTorch Refresher: Essential Concepts for Transformers

This notebook reviews PyTorch fundamentals you'll need for building transformers. If you're already comfortable with PyTorch, you can skim or skip this.

## Topics Covered

1. Tensor creation and operations
2. Broadcasting and shapes
3. Matrix operations (crucial for attention!)
4. Autograd and gradients
5. Building with `nn.Module`
6. Forward and backward passes
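7. Essential operations for transformers (softmax, masking, dropout, layer normalization)
8. A minimal attention implementation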

**Estimated time:** 45 minutes

Let's begin!

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

print(f"PyTorch version: {torch.__version__}")

# Set device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

## 1. Tensor Creation and Basic Operations

Tensors are the fundamental data structure in PyTorch - think of them as multi-dimensional arrays.

In [None]:
# Creating tensors
a = torch.tensor([1, 2, 3])  # From list
b = torch.zeros(3, 4)         # All zeros
c = torch.ones(2, 3)          # All ones
d = torch.randn(2, 3)         # Random normal distribution
e = torch.arange(0, 10, 2)    # Range with step

print("Tensor from list:", a)
print("\nZeros (3x4):\n", b)
print("\nOnes (2x3):\n", c)
print("\nRandom normal (2x3):\n", d)
print("\nRange [0, 10) step 2:", e)

# Tensor properties
print(f"\nShape: {d.shape}")
print(f"Dimensions: {d.dim()}")
print(f"Data type: {d.dtype}")
print(f"Device: {d.device}")
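
The `dtype` and `device` properties printed above matter in practice, because most operations expect both operands to agree on them. Here is a minimal sketch, assuming PyTorch's default dtype has not been changed; the names `token_ids` and `weights` are illustrative only.

```python
import torch

# Integer Python lists become int64 tensors; randn gives float32 by default.
token_ids = torch.tensor([1, 2, 3])
weights = torch.randn(2, 3)
print(token_ids.dtype, weights.dtype)

# Convert the dtype explicitly when an operation needs floats.
token_ids_float = token_ids.to(torch.float32)
print(token_ids_float.dtype)

# Move a tensor to a device. "cpu" always exists, so this runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights_on_device = weights.to(device)
print(weights_on_device.device)
```

Note that `.to()` returns a new tensor; the original stays where it was.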

### Creating Tensors for Transformers

In transformers, we typically work with 3D tensors: `(batch_size, sequence_length, dimension)`

In [None]:
# Typical transformer tensor shapes
batch_size = 8
seq_len = 128
d_model = 512

# Random embeddings
embeddings = torch.randn(batch_size, seq_len, d_model)

print(f"Embeddings shape: {embeddings.shape}")
print(f"  Batch size: {embeddings.size(0)}")
print(f"  Sequence length: {embeddings.size(1)}")
print(f"  Model dimension: {embeddings.size(2)}")

# Alternative indexing (negative indices)
print(f"\nUsing negative indexing:")
print(f"  Last dimension (d_model): {embeddings.size(-1)}")
print(f"  Second to last (seq_len): {embeddings.size(-2)}")

## 2. Broadcasting and Shape Manipulation

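Broadcasting lets PyTorch combine tensors of different but compatible shapes: shapes are aligned from the last dimension, and each pair of dimensions must either match or have one of them equal to 1. Transformers rely on this constantly, for example when adding a bias or applying a mask.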

In [None]:
# Broadcasting examples
x = torch.randn(3, 4)
y = torch.randn(4)  # Note: different shape!

# Broadcasting: (3, 4) + (4,) → (3, 4)
z = x + y
print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")
print(f"x + y shape: {z.shape}")
print(f"\nBroadcasting worked! y was expanded to (3, 4)\n")

# Common reshape operations
a = torch.randn(2, 3, 4)
print(f"Original shape: {a.shape}")

# View (reshape without copying; needs contiguous memory, otherwise use .reshape())
b = a.view(2, 12)  # Flatten last two dimensions
print(f"After view(2, 12): {b.shape}")

# Unsqueeze (add dimension)
c = a.unsqueeze(0)  # Add batch dimension at position 0
print(f"After unsqueeze(0): {c.shape}")

# Squeeze (remove dimension of size 1)
d = c.squeeze(0)  # Remove dimension 0 (size 1)
print(f"After squeeze(0): {d.shape}")

# Transpose
e = a.transpose(1, 2)  # Swap dimensions 1 and 2
print(f"After transpose(1, 2): {e.shape}")
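
To make the broadcasting rule concrete, here is a small sketch with transformer-style shapes. The tensors `positions` and `per_sequence` are made up for illustration. It shows one case that broadcasts, one that fails, and how explicit size-1 dimensions fix the failure.

```python
import torch

batch_size, seq_len, d_model = 2, 5, 8
embeddings = torch.randn(batch_size, seq_len, d_model)
positions = torch.randn(seq_len, d_model)  # no batch dimension

# Shapes are aligned from the right: (2, 5, 8) + (5, 8) -> (2, 5, 8).
combined = embeddings + positions
print(combined.shape)

# A tensor of shape (2,) does not line up with the trailing dimension (8).
per_sequence = torch.randn(batch_size)
try:
    embeddings + per_sequence
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
print(f"Broadcast failed: {broadcast_failed}")

# Adding explicit size-1 dimensions makes it compatible: (2, 1, 1).
shifted = embeddings + per_sequence.view(batch_size, 1, 1)
print(shifted.shape)
```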

### Critical: Transpose for Attention

Attention scores come from `Q @ K^T`, so the key tensor must have its last two dimensions swapped before the matrix multiplication.

In [None]:
# Simulating Q @ K^T in attention
batch_size = 2
seq_len = 4
d_k = 8

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)

print(f"Q shape: {Q.shape}  (batch, seq_len, d_k)")
print(f"K shape: {K.shape}  (batch, seq_len, d_k)")

# Transpose last two dimensions of K
K_T = K.transpose(-2, -1)  # Or K.transpose(1, 2)
print(f"K^T shape: {K_T.shape}  (batch, d_k, seq_len)")

# Compute Q @ K^T
scores = Q @ K_T  # @ is matrix multiplication
print(f"Q @ K^T shape: {scores.shape}  (batch, seq_len, seq_len)")

print("\n✓ This creates the attention score matrix!")
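
For 3D tensors, `transpose(-2, -1)` and `transpose(1, 2)` do the same thing. They differ once a head dimension is added. The sketch below assumes a 4D layout `(batch, heads, seq_len, d_k)`, which this notebook does not use elsewhere, to show why negative indices are the safer habit.

```python
import torch

batch_size, num_heads, seq_len, d_k = 2, 4, 6, 8
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)

# Negative indices always target the last two dimensions.
scores = Q @ K.transpose(-2, -1)
print(scores.shape)  # torch.Size([2, 4, 6, 6])

# transpose(1, 2) would swap heads and sequence positions instead.
wrong = K.transpose(1, 2)
print(wrong.shape)  # torch.Size([2, 6, 4, 8])
```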

## 3. Matrix Operations

Matrix multiplication is the core of transformers. Let's master it!

In [None]:
# Matrix multiplication operators
A = torch.randn(3, 4)
B = torch.randn(4, 5)

# Three ways to do matrix multiplication
C1 = torch.matmul(A, B)
C2 = A @ B  # Preferred!
C3 = torch.mm(A, B)  # Only works for 2D

print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")
print(f"A @ B shape: {C2.shape}")

# Batched matrix multiplication
A_batch = torch.randn(8, 3, 4)  # 8 batches of (3x4) matrices
B_batch = torch.randn(8, 4, 5)  # 8 batches of (4x5) matrices

C_batch = A_batch @ B_batch  # Batched: (8, 3, 4) @ (8, 4, 5) = (8, 3, 5)
print(f"\nBatched matmul: {A_batch.shape} @ {B_batch.shape} = {C_batch.shape}")

# Element-wise multiplication (NOT matrix multiplication!)
D = torch.randn(3, 4)
E = torch.randn(3, 4)
# Note: don't name the result F, or it will shadow torch.nn.functional (imported as F)
H = D * E  # Element-wise (Hadamard product)
print(f"\nElement-wise: {D.shape} * {E.shape} = {H.shape}")
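
If batched matrix multiplication feels opaque, it helps to check it against an explicit loop. This sketch compares `@` on 3D tensors with one `torch.mm` per batch element, and also shows a 2D right operand being shared across the batch. The name `W` is illustrative.

```python
import torch

A_batch = torch.randn(8, 3, 4)
B_batch = torch.randn(8, 4, 5)

# Batched matmul multiplies matching batch entries independently.
batched = A_batch @ B_batch
looped = torch.stack([torch.mm(A_batch[i], B_batch[i]) for i in range(8)])
print(torch.allclose(batched, looped, atol=1e-5))

# A 2D right operand is broadcast across the batch: (8, 3, 4) @ (4, 5) -> (8, 3, 5).
W = torch.randn(4, 5)
shared = A_batch @ W
print(shared.shape)
```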

## 4. Autograd: Automatic Differentiation

PyTorch automatically computes gradients for backpropagation. This is what makes training possible!

In [None]:
# Enable gradient tracking
x = torch.tensor([2.0], requires_grad=True)
print(f"x: {x}")
print(f"Requires grad: {x.requires_grad}")

# Forward pass: compute y = x^2
y = x ** 2
print(f"\ny = x^2: {y}")

# Backward pass: compute dy/dx
y.backward()
print(f"\nGradient dy/dx: {x.grad}")
print(f"Expected: 2*x = 2*2 = 4 ✓")
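
One detail worth seeing early: `backward()` adds to `.grad` rather than overwriting it. The sketch below repeats the same computation to show the accumulation, then resets the gradient by hand. This is why training loops clear gradients before each backward pass.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

(x ** 2).backward()
first = x.grad.clone()   # tensor([4.])

(x ** 2).backward()
second = x.grad.clone()  # tensor([8.]) because 4 + 4 accumulates

x.grad.zero_()           # reset in place
(x ** 2).backward()
third = x.grad.clone()   # tensor([4.])

print(first, second, third)
```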

### More Complex Example: Simple Neural Network

In [None]:
# Create parameters
W = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
x = torch.randn(2, 3)  # Input (batch_size=2, input_dim=3)

# Forward pass: y = x @ W + b
y = x @ W + b  # (2, 3) @ (3, 4) = (2, 4)
print(f"Output shape: {y.shape}")

# Compute a simple loss (mean)
loss = y.mean()
print(f"Loss: {loss.item()}")

# Backward pass
loss.backward()

# Check gradients
print(f"\nW.grad shape: {W.grad.shape}")
print(f"b.grad shape: {b.grad.shape}")
print("\n✓ Gradients computed automatically!")
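
The gradients in this example can be derived by hand, which makes a good sanity check. The loss is the mean of the 8 output elements, so each output has gradient 1/8; the chain rule then gives `x^T @ grad_y` for `W` and the column sums of `grad_y` for `b`. Here is a sketch of that check, with a fixed seed added for repeatability.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
x = torch.randn(2, 3)

loss = (x @ W + b).mean()
loss.backward()

# d(loss)/dy is 1/8 for each of the 2 * 4 output elements.
grad_y = torch.full((2, 4), 1.0 / 8)
expected_W_grad = x.T @ grad_y        # (3, 2) @ (2, 4) -> (3, 4)
expected_b_grad = grad_y.sum(dim=0)   # summed over the batch -> (4,)

print(torch.allclose(W.grad, expected_W_grad, atol=1e-6))
print(torch.allclose(b.grad, expected_b_grad, atol=1e-6))
```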

## 5. Building with `nn.Module`

`nn.Module` is the base class for all neural network modules. You'll use this constantly!

In [None]:
# Simple linear layer example
class SimpleLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Define parameters
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))
    
    def forward(self, x):
        # x: (batch_size, input_dim)
        return x @ self.weight + self.bias  # (batch_size, output_dim)

# Create and test
layer = SimpleLinear(input_dim=10, output_dim=5)
x = torch.randn(32, 10)  # Batch of 32 examples
output = layer(x)  # Calls forward() automatically

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nParameters:")
for name, param in layer.named_parameters():
    print(f"  {name}: {param.shape}")

### Using Built-in Layers

In [None]:
# PyTorch provides optimized implementations
linear = nn.Linear(10, 5)  # Same computation as SimpleLinear; weight is stored as (out, in) and initialized differently

x = torch.randn(32, 10)
output = linear(x)

print(f"Output shape: {output.shape}")
print(f"\nBuilt-in Linear layer parameters:")
for name, param in linear.named_parameters():
    print(f"  {name}: {param.shape}")
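
The printed shapes reveal a difference from a hand-written layer that stores its weight as `(input_dim, output_dim)`: `nn.Linear` stores its weight as `(out_features, in_features)` and multiplies by the transpose. A short sketch confirming this:

```python
import torch
import torch.nn as nn

linear = nn.Linear(10, 5)
x = torch.randn(32, 10)

print(linear.weight.shape)  # torch.Size([5, 10])

# nn.Linear computes x @ W^T + b.
manual = x @ linear.weight.T + linear.bias
print(torch.allclose(linear(x), manual, atol=1e-5))
```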

### Building a Multi-Layer Module

In [None]:
class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        # x: (batch_size, input_dim)
        x = self.layer1(x)  # (batch_size, hidden_dim)
        x = F.relu(x)       # Activation function
        x = self.layer2(x)  # (batch_size, output_dim)
        return x

# Create and test
model = SimpleNet(input_dim=20, hidden_dim=64, output_dim=10)
x = torch.randn(16, 20)
output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nAll parameters:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")
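
Listing parameters leads naturally to counting and freezing them. The sketch below uses an `nn.Sequential` stand-in with the same layer sizes as the model above (an assumption made to keep the block self-contained). It counts parameters with `numel()` and freezes the first layer with `requires_grad_(False)`.

```python
import torch.nn as nn

# Stand-in with the same layer sizes: 20 -> 64 -> 10.
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))

# 20*64 + 64 + 64*10 + 10 = 1994
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total}")

# Freeze the first linear layer so it receives no gradient updates.
for p in model[0].parameters():
    p.requires_grad_(False)

# Only the second layer remains trainable: 64*10 + 10 = 650
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")
```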

## 6. Forward and Backward Passes

The training loop consists of:
1. Forward pass (compute predictions)
2. Compute loss
3. Backward pass (compute gradients)
4. Update parameters

In [None]:
# Create a simple model
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Dummy data
x = torch.randn(32, 10)
y_true = torch.randn(32, 1)

print("Training for 5 iterations:\n")

for i in range(5):
    # 1. Forward pass
    y_pred = model(x)
    
    # 2. Compute loss
    loss = F.mse_loss(y_pred, y_true)
    
    # 3. Backward pass
    optimizer.zero_grad()  # Clear previous gradients
    loss.backward()        # Compute gradients
    
    # 4. Update parameters
    optimizer.step()
    
    print(f"Iteration {i+1}: Loss = {loss.item():.6f}")

print("\n✓ Loss should decrease over iterations")
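
To see what `optimizer.step()` does for plain SGD (no momentum or weight decay), the update can be written by hand: subtract the learning rate times the gradient from each parameter, inside `torch.no_grad()` so the update itself is not tracked. This is a sketch of the same loop without an optimizer object.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(10, 1)
x = torch.randn(32, 10)
y_true = torch.randn(32, 1)
lr = 0.01

losses = []
for _ in range(5):
    loss = F.mse_loss(model(x), y_true)
    model.zero_grad()          # clear gradients from the previous step
    loss.backward()
    with torch.no_grad():      # parameter updates must not be recorded by autograd
        for p in model.parameters():
            p -= lr * p.grad   # plain gradient descent step
    losses.append(loss.item())

print(losses)
```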

## 7. Essential Operations for Transformers

Let's put it all together with operations you'll use constantly in transformer code.

### Softmax

Critical for attention weights!

In [None]:
# Softmax converts scores to probabilities
scores = torch.randn(2, 4, 4)  # (batch, seq_len, seq_len)

# Apply softmax over last dimension (keys)
attention = F.softmax(scores, dim=-1)

print(f"Scores shape: {scores.shape}")
print(f"Attention shape: {attention.shape}")

# Verify: sum over last dimension should be 1
sums = attention.sum(dim=-1)
print(f"\nSum over last dim (should be all 1.0):")
print(sums[0])  # First batch

# All values should be between 0 and 1
print(f"\nMin value: {attention.min().item():.6f}")
print(f"Max value: {attention.max().item():.6f}")
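
Softmax is short enough to write by hand: exponentiate, then divide by the sum along the chosen dimension. The sketch below subtracts the row maximum first, a common trick to avoid overflow, and compares the result with `F.softmax`. It also shows that the choice of `dim` decides which direction sums to 1.

```python
import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 4)

# Subtracting the row max leaves the result unchanged but keeps exp() small.
shifted = scores - scores.max(dim=-1, keepdim=True).values
exp_scores = shifted.exp()
manual = exp_scores / exp_scores.sum(dim=-1, keepdim=True)
print(torch.allclose(manual, F.softmax(scores, dim=-1), atol=1e-6))

# With dim=-2 the columns sum to 1 instead of the rows.
over_queries = F.softmax(scores, dim=-2)
print(over_queries.sum(dim=-2)[0])
```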

### Masking

Causal masking prevents attending to future tokens.

In [None]:
# Create causal mask (lower triangular)
seq_len = 5
mask = torch.tril(torch.ones(seq_len, seq_len))

print("Causal mask (1 = allowed, 0 = blocked):")
print(mask)

# Apply mask to attention scores
scores = torch.randn(1, seq_len, seq_len)
print(f"\nOriginal scores:\n{scores[0]}")

# Mask out future positions by setting to -inf
scores_masked = scores.masked_fill(mask == 0, float('-inf'))
print(f"\nMasked scores (-inf prevents attention):\n{scores_masked[0]}")

# After softmax, -inf becomes 0
attention = F.softmax(scores_masked, dim=-1)
print(f"\nAttention weights (future positions are 0):\n{attention[0]}")
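
The same masking idea can be checked programmatically instead of by eye. This sketch uses a boolean mask (one possible convention, where `True` means allowed) on a small batch, verifies that no weight lands on future positions, and shows a pitfall: a row in which every position is masked turns into NaN after softmax.

```python
import torch
import torch.nn.functional as F

seq_len = 5
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = torch.randn(3, seq_len, seq_len)
masked = scores.masked_fill(~allowed, float("-inf"))  # ~ inverts the boolean mask
attention = F.softmax(masked, dim=-1)

# Everything strictly above the diagonal should be exactly zero.
future_weight = attention.triu(diagonal=1).abs().sum()
row_sums = attention.sum(dim=-1)
print(future_weight.item(), row_sums[0])

# Pitfall: if every position in a row is masked, softmax returns NaN.
fully_masked = F.softmax(torch.full((seq_len,), float("-inf")), dim=-1)
print(fully_masked)
```

A causal mask never triggers the NaN case, since each position can always attend to itself.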

### Dropout

Regularization technique used throughout transformers.

In [None]:
# Dropout randomly zeros elements during training
x = torch.ones(4, 4)
dropout = nn.Dropout(p=0.5)  # Drop 50% of elements

# Training mode
dropout.train()
x_dropped = dropout(x)

print("Original tensor (all ones):")
print(x)
print("\nAfter dropout (training mode):")
print(x_dropped)
print("Note: Surviving values are scaled by 1 / (1 - p) = 2 so each element's expected value is unchanged")

# Evaluation mode (no dropout)
dropout.eval()
x_eval = dropout(x)
print("\nAfter dropout (eval mode):")
print(x_eval)
print("No dropout in eval mode ✓")
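
The scaling rule is easier to trust with a different drop probability and a larger tensor. In this sketch, `p = 0.2` is an arbitrary choice: every surviving element should equal `1 / (1 - p) = 1.25`, and the overall mean should stay close to 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.2
dropout = nn.Dropout(p=p)
dropout.train()

x = torch.ones(1000, 100)
out = dropout(x)

# Surviving elements are scaled by 1 / (1 - p).
kept = out[out != 0]
print(kept.unique())

# The mean stays close to the original value of 1.
print(out.mean().item())
```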

### Layer Normalization

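Layer normalization rescales each position's feature vector to zero mean and unit variance, then applies a learnable scale and shift. It helps stabilize training in transformers.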

In [None]:
# Layer normalization normalizes across features
batch_size = 2
seq_len = 3
d_model = 4

x = torch.randn(batch_size, seq_len, d_model)
layer_norm = nn.LayerNorm(d_model)

x_normalized = layer_norm(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {x_normalized.shape}")

# Check normalization (mean ≈ 0, std ≈ 1 over last dimension)
print(f"\nPer-position statistics:")
print(f"Mean (should be ~0): {x_normalized[0, 0].mean().item():.6f}")
print(f"Std (should be ~1): {x_normalized[0, 0].std(unbiased=False).item():.6f}")
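
Layer normalization can also be reproduced by hand, which shows where the biased standard deviation comes from. This sketch normalizes over the last dimension with the biased variance and the layer's own `eps`, applies the learnable `weight` and `bias`, and compares against `nn.LayerNorm`.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4)
layer_norm = nn.LayerNorm(4)

# Statistics are computed per position, over the feature dimension only.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)

manual = (x - mean) / torch.sqrt(var + layer_norm.eps)
manual = manual * layer_norm.weight + layer_norm.bias  # initially ones and zeros

print(torch.allclose(manual, layer_norm(x), atol=1e-5))
```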

## 8. Putting It All Together: Mini Attention

Let's implement a minimal attention mechanism to tie everything together!

In [None]:
def simple_attention(Q, K, V):
    """
    Simplified attention mechanism.
    
    Args:
        Q: Query (batch, seq_len, d_k)
        K: Key (batch, seq_len, d_k)
        V: Value (batch, seq_len, d_v)
    
    Returns:
        output: (batch, seq_len, d_v)
        attention: (batch, seq_len, seq_len)
    """
    # 1. Compute attention scores: Q @ K^T
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
    
    # 2. Scale
    scores = scores / math.sqrt(d_k)
    
    # 3. Softmax to get attention weights
    attention = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
    
    # 4. Apply attention to values
    output = attention @ V  # (batch, seq_len, d_v)
    
    return output, attention

# Test it!
batch_size = 2
seq_len = 4
d_k = 8

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

output, attention = simple_attention(Q, K, V)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"\nOutput shape: {output.shape}")
print(f"Attention shape: {attention.shape}")

# Verify attention weights sum to 1
print(f"\nAttention weights sum: {attention[0, 0].sum().item():.6f}")
print("\n✓ You just implemented attention! This is the core of transformers!")
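
A quick way to sanity-check an attention function is to feed it an input with a known answer. In the sketch below (which redefines a compact version of the function so the block runs on its own), all queries are zero, so every score is zero, the weights are uniform, and each output row is the average of the value vectors. It also uses `d_v != d_k` to confirm that the output takes its last dimension from `V`.

```python
import math
import torch
import torch.nn.functional as F

def simple_attention(Q, K, V):
    # Compact restatement: scaled dot-product scores, softmax, weighted sum.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    attention = F.softmax(scores, dim=-1)
    return attention @ V, attention

batch_size, seq_len, d_k, d_v = 2, 4, 8, 16
Q = torch.zeros(batch_size, seq_len, d_k)  # zero queries -> all scores are 0
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, attention = simple_attention(Q, K, V)
print(output.shape)     # last dimension is d_v, not d_k
print(attention[0, 0])  # uniform weights of 1 / seq_len

# With uniform weights, every output row equals the mean of V over the sequence.
expected = V.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
print(torch.allclose(output, expected, atol=1e-6))
```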

## Quick Quiz

Test your understanding:

In [None]:
# Question 1: What's the output shape?
A = torch.randn(32, 128, 512)  # (batch, seq_len, d_model)
W = torch.randn(512, 256)       # (d_model, d_out)
result = A @ W
print(f"Q1: Shape of A @ W is: {result.shape}")
print("Expected: (32, 128, 256)\n")

# Question 2: What does transpose do?
B = torch.randn(32, 128, 64)
B_T = B.transpose(-2, -1)
print(f"Q2: Shape after transpose(-2, -1): {B_T.shape}")
print("Expected: (32, 64, 128)\n")

# Question 3: Softmax output properties
scores = torch.randn(32, 128, 128)
probs = F.softmax(scores, dim=-1)
print(f"Q3: Sum of probabilities (dim=-1): {probs[0, 0].sum().item():.6f}")
print("Expected: 1.000000\n")

print("✓ If all answers match, you're ready for Module 01!")

## Summary

You now know:

✓ Creating and manipulating tensors  
✓ Broadcasting and shape operations  
✓ Matrix multiplication (`@`)  
✓ Autograd for automatic differentiation  
✓ Building modules with `nn.Module`  
✓ Forward/backward passes  
✓ Essential operations: softmax, masking, dropout, layer norm  
✓ **You implemented attention from scratch!**  

## Next Steps

1. Review any sections you found challenging
2. Try modifying the simple attention code
3. Proceed to Module 01 where we'll build proper attention with all the bells and whistles!

**You're ready to build transformers!**