# Lecture 5 Notebook: Grok Transformer

**Goal**: Implement the Grok transformer building blocks in detail.

You'll learn:
- RMSNorm (replaces LayerNorm)
- Rotary Embedding (RoPE)
- Multi-Head Attention with GQA, RoPE, and soft capping
- SwiGLU FFN (DenseBlock)
- DecoderLayer with residual connections
- Transformer stack
- make_recsys_attn_mask for candidate isolation

---

## Setup

Import PyTorch and set up configuration.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configuration
EMB_SIZE = 64
NUM_LAYERS = 2
NUM_Q_HEADS = 4
NUM_KV_HEADS = 2
KEY_SIZE = 16

print(f"Embedding dimension: {EMB_SIZE}")
print(f"Number of layers: {NUM_LAYERS}")
print(f"Number of Q heads: {NUM_Q_HEADS}")
print(f"Number of KV heads: {NUM_KV_HEADS}")
print(f"Key size: {KEY_SIZE}")

## RMSNorm

Root Mean Square Layer Normalization.

**Key difference from LayerNorm**:
- LayerNorm: Centers the data first (subtract mean), then scales
- RMSNorm: Only scales, no centering

**Why RMSNorm?**
- Simpler (no mean computation)
- Typically matches LayerNorm quality in practice
- Faster (one less pass over the data)
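
To make the difference concrete, here is a minimal sketch in plain Python (no PyTorch, no learned scale) that compares the two normalizations on a tiny vector. The helper names `rms_norm` and `layer_norm` are illustrative and not part of the notebook's code. On zero-mean input the two agree; once the input is shifted, LayerNorm removes the offset while RMSNorm keeps it.

```python
import math

def rms_norm(xs, eps=1e-5):
    """Scale xs so its root mean square is ~1 (no centering, no learned scale)."""
    mean_sq = sum(v * v for v in xs) / len(xs)
    return [v / math.sqrt(mean_sq + eps) for v in xs]

def layer_norm(xs, eps=1e-5):
    """Subtract the mean, then scale to unit variance (no learned scale or bias)."""
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / math.sqrt(var + eps) for v in xs]

centered = [-3.0, -1.0, 1.0, 3.0]       # mean is exactly 0
shifted = [v + 10.0 for v in centered]  # same spread, mean 10

# On zero-mean input the two normalizations coincide.
same_on_centered = all(
    abs(a - b) < 1e-9 for a, b in zip(rms_norm(centered), layer_norm(centered))
)

# On shifted input they differ: LayerNorm removes the offset, RMSNorm keeps it.
rms_shifted = rms_norm(shifted)
ln_shifted = layer_norm(shifted)

print(same_on_centered)
print([round(v, 3) for v in rms_shifted])
print([round(v, 3) for v in ln_shifted])
```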

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    
    RMSNorm(x) = x * rsqrt(mean(x^2) + eps) * scale
    
    Args:
        dim: Feature dimension
        eps: Numerical stability constant
    """
    
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        
        # Learnable scale parameter
        self.scale = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        """
        Args:
            x: [B, ..., D] input tensor of any shape
        
        Returns:
            Normalized tensor of same shape
        """
        # Compute RMS: sqrt(mean(x^2))
        rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + self.eps)
        
        # Normalize and apply scale
        return x / rms * self.scale

# Test RMSNorm
print("Testing RMSNorm...")
norm = RMSNorm(EMB_SIZE)

# Test with different shapes
x1 = torch.randn(4, 10, EMB_SIZE)  # [B, T, D]
x2 = torch.randn(4, EMB_SIZE)      # [B, D]
x3 = torch.randn(EMB_SIZE)         # [D]

y1 = norm(x1)
y2 = norm(x2)
y3 = norm(x3)

print(f"Input shape 1: {x1.shape} -> Output shape: {y1.shape}")
print(f"Input shape 2: {x2.shape} -> Output shape: {y2.shape}")
print(f"Input shape 3: {x3.shape} -> Output shape: {y3.shape}")

# Verify RMS is ~1.0 after normalization
rms_after = torch.sqrt(torch.mean(y1.float() ** 2, dim=-1))
print(f"\nRMS after normalization (should be ~1.0): {rms_after.mean():.6f}")

assert y1.shape == x1.shape
assert y2.shape == x2.shape
print("✓ RMSNorm working correctly!")

## Rotary Embedding (RoPE)

**What is RoPE?**
- Encodes positions using rotation matrices
- Allows the model to understand **relative positions**
- No absolute position embeddings needed

**How it works:**
1. Split embedding into pairs of dimensions
2. Rotate each pair by an angle proportional to position
3. The rotation encodes relative distance

Reference: https://arxiv.org/abs/2104.09864
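
The relative-position property can be checked on a single pair of dimensions. The sketch below uses plain Python with one hypothetical frequency (`theta=0.5`); `rotate_pair` and `dot` are illustrative helpers, not part of the notebook's classes. It rotates a query and a key by their positions and shows that the dot product depends only on the offset between them.

```python
import math

def rotate_pair(pair, position, theta=0.5):
    """Rotate a 2-D pair by angle position * theta (a single RoPE frequency)."""
    x0, x1 = pair
    angle = position * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

def dot(a, b):
    """Dot product of two 2-D pairs."""
    return a[0] * b[0] + a[1] * b[1]

q, k = (1.0, 2.0), (0.5, -1.5)

# Query at position 3, key at position 1 (offset 2).
score_a = dot(rotate_pair(q, 3), rotate_pair(k, 1))
# Both positions shifted by 10: the offset is still 2.
score_b = dot(rotate_pair(q, 13), rotate_pair(k, 11))
# A different offset (1) gives a different score.
score_c = dot(rotate_pair(q, 3), rotate_pair(k, 2))

print(abs(score_a - score_b) < 1e-9)  # same offset, same score
print(abs(score_a - score_c) > 1e-3)  # different offset, different score
```

Because attention scores are dot products of rotated queries and keys, this is why no separate absolute position embedding is needed.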

In [None]:
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).
    
    Encodes positions using rotation matrices.
    
    For position p and pair index i (0 <= i < dim / 2):
    - Embedding dimensions [2i, 2i+1] are rotated by angle = p * theta_i
    - where theta_i = base^(-2*i / dim)
    
    Args:
        dim: Embedding dimension (must be even)
        base: Base for frequency computation (default: 10000)
    """
    
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        
        # Compute frequency for each dimension pair
        # theta_i = base^(-2*i / dim) for pair index i
        assert dim % 2 == 0, "RoPE dim must be even"
        freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('freqs', freqs)
    
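    def rotate_half(self, x):
        """
        Rotate each adjacent pair of dimensions by 90 degrees.
        
        x = [x_0, x_1, x_2, x_3, ...]
        output = [-x_1, x_0, -x_3, x_2, ...]
        """
        x1, x2 = x[..., ::2], x[..., 1::2]
        # Stack then flatten so the result stays interleaved: (-x_1, x_0), (-x_3, x_2), ...
        return torch.stack([-x2, x1], dim=-1).flatten(-2)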
    
    def forward(self, x, start_pos=0):
        """
        Apply rotary embedding to input.
        
        Args:
            x: [B, T, D] or [B, T, H, D] input tensor (sequence axis is dim 1)
            start_pos: Starting position offset
        
        Returns:
            Rotated embeddings of same shape
        """
        T = x.shape[1]
        
        # Compute positions: [T]
        positions = torch.arange(start_pos, start_pos + T, device=x.device).float()
        
        # Compute phase for each position and dimension pair: [T, D/2]
        # phase = position * frequency
        phase = torch.einsum('p,f->pf', positions, self.freqs)
        
        # Expand cos/sin to the full embedding dimension: [T, D]
        # Each pair (2i, 2i+1) shares the same angle
        cos = self._interleave(torch.cos(phase), x.shape[-1])
        sin = self._interleave(torch.sin(phase), x.shape[-1])
        
        # For [B, T, H, D] inputs, add a head axis so cos/sin broadcast: [T, 1, D]
        if x.dim() == 4:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
        
        # Apply rotation using the formula:
        # RoPE(x_pos) = x * cos(theta) + rotate_half(x) * sin(theta)
        return x * cos + self.rotate_half(x) * sin
    
    def _interleave(self, half, target_dim):
        """Interleave half-dimension to match full dimension."""
        result = torch.zeros(*half.shape[:-1], target_dim, device=half.device)
        result[..., ::2] = half
        result[..., 1::2] = half
        return result

# Test RotaryEmbedding
print("Testing RotaryEmbedding...")
rope = RotaryEmbedding(EMB_SIZE)

x = torch.randn(2, 10, EMB_SIZE)
y = rope(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

# Show that the same vector is rotated differently at different positions
print("\n--- Verifying Position-Dependent Rotation ---")
x_test = torch.zeros(1, 3, EMB_SIZE)  # Simple test
x_test[0, 1, 0] = 1.0  # Only the token at index 1 is non-zero

y_at_pos0 = rope(x_test.clone(), start_pos=0)  # token sits at absolute position 1
y_at_pos5 = rope(x_test.clone(), start_pos=5)  # token sits at absolute position 6

print("Input: Only the token at index 1 is activated")
print(f"With start_pos=0 (position 1): {y_at_pos0[0, 1, :4].tolist()}")
print(f"With start_pos=5 (position 6): {y_at_pos5[0, 1, :4].tolist()}")
print("Note: Different positions produce different embeddings!")

assert y.shape == x.shape
print("\n✓ RotaryEmbedding working correctly!")

## Multi-Head Attention with GQA

**Grouped Query Attention (GQA)**:
- Multiple **query** heads (Q)
- Fewer **key/value** heads (K, V)
- Each K/V head is shared by multiple Q heads

**Benefits**:
- Memory efficient (fewer K/V projection parameters and a smaller KV cache; there is still one attention map per Q head)
- Quality similar to full attention

**Example**:
- Q heads: 8, KV heads: 2
- Each KV head is shared by 4 Q heads
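
The sharing pattern in this example can be written out with a few lines of plain Python. This is a sketch of the head-to-head mapping only, under the assumption that KV heads are repeated in contiguous groups (as `repeat_interleave` does along the head axis); `key_size = 16` is a hypothetical per-head dimension chosen for the arithmetic.

```python
num_q_heads, num_kv_heads = 8, 2
group_size = num_q_heads // num_kv_heads  # query heads per KV head

# Query head i reads from KV head i // group_size, which is the same grouping
# that repeating each KV head group_size times in place produces.
kv_head_for_q = [q // group_size for q in range(num_q_heads)]
print(kv_head_for_q)  # [0, 0, 0, 0, 1, 1, 1, 1]

# K and V values stored per token scale with the number of KV heads.
key_size = 16  # hypothetical per-head dimension
kv_values_mha = 2 * num_q_heads * key_size   # full multi-head attention
kv_values_gqa = 2 * num_kv_heads * key_size  # grouped query attention
print(kv_values_mha, kv_values_gqa)  # 256 64
```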

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention with Grouped Query Attention (GQA).
    
    Features:
    - Grouped Query Attention: Fewer KV heads than Q heads
    - Rotary Embedding (RoPE) for position encoding
    - Soft attention capping: tanh(x/c) * c
    
    Args:
        dim: Model dimension
        num_q_heads: Number of query heads
        num_kv_heads: Number of key/value heads (usually < num_q_heads)
        key_size: Dimension per head (default: dim // num_q_heads)
        cap: Attention logits cap (soft capping)
    """
    
    def __init__(self, dim, num_q_heads, num_kv_heads, key_size=None, cap=30.0):
        super().__init__()
        self.dim = dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size or (dim // num_q_heads)
        self.cap = cap

        # Compute value size
        self.value_size = self.key_size

        # Q projection: dim -> num_q_heads * key_size
        self.q_proj = nn.Linear(dim, num_q_heads * self.key_size, bias=False)
        
        # KV projection: dim -> num_kv_heads * key_size (and value_size)
        self.k_proj = nn.Linear(dim, num_kv_heads * self.key_size, bias=False)
        self.v_proj = nn.Linear(dim, num_kv_heads * self.value_size, bias=False)
        
        # Output projection: num_q_heads * value_size -> dim
        self.o_proj = nn.Linear(num_q_heads * self.value_size, dim, bias=False)
        
        # RoPE
        self.rope = RotaryEmbedding(self.key_size)

    def forward(self, x, attention_mask=None, key_padding_mask=None):
        """
        Args:
            x: [B, T, D] input embeddings
            attention_mask: [T, T] or [B, 1, T, T] mask (True=attend)
            key_padding_mask: [B, T] mask for padding (True=mask)
        
        Returns:
            Output embeddings [B, T, D]
        """
        B, T, D = x.shape

        # Project to Q, K, V
        q = self.q_proj(x).view(B, T, self.num_q_heads, self.key_size)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.key_size)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.value_size)

        # Apply RoPE to Q and K
        q = self.rope(q)
        k = self.rope(k)

        # GQA: Repeat KV heads to match Q heads
        # [B, T, num_kv_heads, D] -> [B, T, num_q_heads, D]
        if self.num_kv_heads < self.num_q_heads:
            repeats = self.num_q_heads // self.num_kv_heads
            k = k.repeat_interleave(repeats, dim=2)
            v = v.repeat_interleave(repeats, dim=2)

        # Transpose for attention: [B, H, T, D]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores
        # Q @ K^T / sqrt(key_size)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.key_size ** 0.5)

        # Soft attention capping: tanh(x/c) * c
        # Bounds the logits to [-cap, cap] so no single score dominates the softmax
        scores = self.cap * torch.tanh(scores / self.cap)

        # Apply masks
        if key_padding_mask is not None:
            # key_padding_mask is True at padded keys: [B, T] -> [B, 1, 1, T]
            scores = scores.masked_fill(key_padding_mask[:, None, None, :], float('-inf'))

        if attention_mask is not None:
            # attention_mask is True where attention is allowed
            scores = scores.masked_fill(~attention_mask.bool(), float('-inf'))

        # Softmax
        attn_weights = F.softmax(scores.float(), dim=-1)

        # Weighted sum
        output = torch.matmul(attn_weights, v)

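        # Concatenate heads and project
        # The concatenated width is num_q_heads * value_size, which need not equal D
        output = output.transpose(1, 2).contiguous().view(B, T, self.num_q_heads * self.value_size)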
        output = self.o_proj(output)

        return output

# Test MultiHeadAttention
print("Testing MultiHeadAttention with GQA...")
mha = MultiHeadAttention(
    dim=EMB_SIZE,
    num_q_heads=NUM_Q_HEADS,
    num_kv_heads=NUM_KV_HEADS,
    key_size=KEY_SIZE
)

B, T = 2, 8
x = torch.randn(B, T, EMB_SIZE)
padding_mask = torch.ones(B, T, dtype=torch.bool)
padding_mask[:, 6:] = False  # Last 2 positions are padding

y = mha(x, key_padding_mask=~padding_mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"GQA ratio: {NUM_Q_HEADS} Q heads / {NUM_KV_HEADS} KV heads = {NUM_Q_HEADS // NUM_KV_HEADS}x sharing")

assert y.shape == x.shape
print("✓ MultiHeadAttention working correctly!")

## SwiGLU Feed-Forward Network

**SwiGLU** = Swish activation + Gated Linear Unit.

**Architecture**:
```
FFN(x) = (SiLU(x @ W_1) * (x @ W_2)) @ W_3
```

**Components**:
- **SiLU** (Swish): x * sigmoid(x)
- **Gate**: Element-wise multiplication creates a "veto" mechanism
- **Widening factor**: With the default factor of 4, the hidden dimension is 4 × 2/3 = 8/3 of the input, rounded up to a multiple of 8 (e.g. 64 -> 176)
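
As a worked example of the hidden-size arithmetic, the sketch below reproduces the rule in plain Python. The function name `swiglu_hidden_size` and its `multiple` parameter are illustrative; the formula (scale by the widening factor, take two thirds, round up to a multiple of 8) is the one this notebook describes.

```python
def swiglu_hidden_size(emb_size, widening_factor=4.0, multiple=8):
    """Hidden width: widening_factor * emb_size * 2/3, rounded up to a multiple."""
    hidden = int(widening_factor * emb_size) * 2 // 3
    return -(-hidden // multiple) * multiple  # ceiling to the next multiple

print(swiglu_hidden_size(64))  # 256 * 2 // 3 = 170, rounded up to 176
print(swiglu_hidden_size(96))  # 384 * 2 // 3 = 256, already a multiple of 8
```

The 2/3 factor compensates for SwiGLU having three weight matrices instead of two, so the parameter count stays close to that of a standard FFN with a 4x hidden layer.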

In [None]:
def ffn_size(emb_size, widening_factor=4.0):
    """
    Compute feed-forward hidden size with widening factor.
    
    Grok uses: hidden = int(widening_factor * emb_size) * 2 // 3
    Then rounds to multiple of 8.
    """
    hidden = int(widening_factor * emb_size) * 2 // 3
    hidden = hidden + (8 - hidden) % 8  # Round to multiple of 8
    return hidden

class SwiGLUFFN(nn.Module):
    """
    SwiGLU Feed-Forward Network (DenseBlock).
    
    FFN(x) = (SiLU(x @ W_1) * (x @ W_2)) @ W_3
    
    Args:
        dim: Input dimension
        hidden_dim: Hidden dimension (default: computed from widening factor)
        widening_factor: How much to widen (default: 4.0)
    """
    
    def __init__(self, dim, hidden_dim=None, widening_factor=4.0):
        super().__init__()
        
        if hidden_dim is None:
            hidden_dim = ffn_size(dim, widening_factor)
        
        self.dim = dim
        self.hidden_dim = hidden_dim
        
        # W_1: dim -> hidden
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        
        # W_2: dim -> hidden (gate projection)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        
        # W_3: hidden -> dim (output projection)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)
    
    def forward(self, x):
        """
        Args:
            x: [B, T, D] input
        
        Returns:
            Output [B, T, D]
        """
        # SiLU(x @ W_1)
        gate = F.silu(self.w1(x))
        
        # x @ W_2
        up = self.w2(x)
        
        # Element-wise multiply (gate)
        gated = gate * up
        
        # @ W_3
        return self.w3(gated)

# Test SwiGLUFFN
print("Testing SwiGLUFFN...")

ffn = SwiGLUFFN(EMB_SIZE)

print(f"Input dim: {EMB_SIZE}")
print(f"Hidden dim: {ffn.hidden_dim}")
print(f"Widening factor: {(ffn.hidden_dim * 3) / (2 * EMB_SIZE):.1f}x")

x = torch.randn(2, 8, EMB_SIZE)
y = ffn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

assert y.shape == x.shape
print("✓ SwiGLUFFN working correctly!")

## DecoderLayer

Complete transformer decoder layer with:
1. RMSNorm
2. Multi-Head Attention with residual
3. RMSNorm
4. SwiGLU FFN with residual
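
The wiring of these four steps can be traced with stand-in functions. The sketch below is plain Python on scalars, assuming the usual pre-norm arrangement in which the second residual starts from the attention block's output; `decoder_layer` and the lambda sub-blocks are hypothetical placeholders, not the notebook's modules.

```python
def decoder_layer(x, attn_norm, attn, ffn_norm, ffn):
    """Pre-norm wiring: each sub-block reads a normalized copy, then adds back."""
    h = x + attn(attn_norm(x))  # first residual: around attention
    y = h + ffn(ffn_norm(h))    # second residual: around the FFN, starting from h
    return y

# Stand-in sub-blocks on plain floats, just to trace the data flow.
identity = lambda v: v
double = lambda v: 2.0 * v
negate = lambda v: -v

y = decoder_layer(1.0, attn_norm=identity, attn=double, ffn_norm=identity, ffn=negate)
print(y)  # h = 1 + 2 = 3, then y = 3 + (-3) = 0.0
```

If the second residual were taken from `x` instead of `h`, the same inputs would give 1 + (-3) = -2.0, so the attention output would be dropped from the skip path.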

In [None]:
class DecoderLayer(nn.Module):
    """
    Transformer Decoder Layer.
    
    Architecture:
    1. x' = RMSNorm(x)
    2. h = MHA(x') + x (residual from layer input)
    3. h' = RMSNorm(h)
    4. y = SwiGLU(h') + h (residual from attention block output)
    
    Note: The second residual is taken from h, not x.
    
    Args:
        dim: Model dimension
        num_q_heads: Number of query heads
        num_kv_heads: Number of key/value heads
        key_size: Dimension per head
    """
    
    def __init__(self, dim, num_q_heads, num_kv_heads, key_size=None):
        super().__init__()
        
        # Pre-attention norm
        self.attn_norm = RMSNorm(dim)
        
        # Multi-head attention
        self.attn = MultiHeadAttention(
            dim=dim,
            num_q_heads=num_q_heads,
            num_kv_heads=num_kv_heads,
            key_size=key_size
        )
        
        # Pre-FFN norm
        self.ffn_norm = RMSNorm(dim)
        
        # SwiGLU FFN
        self.ffn = SwiGLUFFN(dim)
    
    def forward(self, x, attention_mask=None, key_padding_mask=None):
        """
        Args:
            x: [B, T, D] input
            attention_mask: [T, T] or [B, 1, T, T] mask (True=attend)
            key_padding_mask: [B, T] padding mask (True=mask)
        
        Returns:
            Output [B, T, D]
        """
        # Attention block
        h = self.attn_norm(x)
        h = self.attn(h, attention_mask=attention_mask, key_padding_mask=key_padding_mask)
        h = h + x  # Residual from layer input
        
        # FFN block
        out = self.ffn_norm(h)
        out = self.ffn(out)
        out = out + h  # Residual from attention block output
        
        return out

# Test DecoderLayer
print("Testing DecoderLayer...")
layer = DecoderLayer(
    dim=EMB_SIZE,
    num_q_heads=NUM_Q_HEADS,
    num_kv_heads=NUM_KV_HEADS,
    key_size=KEY_SIZE
)

B, T = 2, 8
x = torch.randn(B, T, EMB_SIZE)
padding_mask = torch.ones(B, T, dtype=torch.bool)
padding_mask[:, 6:] = False

y = layer(x, key_padding_mask=~padding_mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

assert y.shape == x.shape
print("✓ DecoderLayer working correctly!")

## make_recsys_attn_mask

Create attention mask for recommendation system inference.

**Key feature**: Candidate isolation!

**Mask structure**:
- User + History: Can attend to all previous positions (causal)
- Candidates: Can attend to User + History AND themselves
- **CANNOT** attend to other candidates
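
The three rules above reduce to a small predicate on (query position, key position). The sketch below builds the mask with plain Python lists for a hypothetical layout of 1 user, 3 history items, and 2 candidates; `can_attend` is an illustrative helper, not the notebook's `make_recsys_attn_mask`.

```python
def can_attend(query_pos, key_pos, candidate_start):
    """True if query_pos may attend to key_pos under candidate isolation."""
    if key_pos > query_pos:
        return False                 # causal: never look ahead
    if key_pos >= candidate_start:
        return key_pos == query_pos  # a candidate key is visible only to itself
    return True                      # user/history keys are visible to later positions

seq_len, candidate_start = 6, 4  # hypothetical: 1 user + 3 history + 2 candidates
mask = [[int(can_attend(q, k, candidate_start)) for k in range(seq_len)]
        for q in range(seq_len)]
for row in mask:
    print(row)
# [1, 0, 0, 0, 0, 0]
# [1, 1, 0, 0, 0, 0]
# [1, 1, 1, 0, 0, 0]
# [1, 1, 1, 1, 0, 0]
# [1, 1, 1, 1, 1, 0]
# [1, 1, 1, 1, 0, 1]
```

The last row shows the isolation: the second candidate sees the user, the history, and itself, but not the first candidate.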

In [None]:
def make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=torch.float32):
    """
    Create attention mask for recommendation system inference.
    
    The mask ensures:
    1. User + History: Causal attention (can attend to previous positions)
    2. Candidates: Can attend to User + History and themselves
    3. Candidates: CANNOT attend to other candidates (candidate isolation)
    
    Sequence layout (S = candidate_start_offset):
    [0] = User
    [1:S] = History (S - 1 positions)
    [S:S+C] = Candidates (C positions)
    
    Args:
        seq_len: Total sequence length
        candidate_start_offset: Position where candidates start
        dtype: Output dtype
    
    Returns:
        Attention mask [1, 1, seq_len, seq_len] of the given dtype,
        where 1 (True) = can attend and 0 (False) = blocked
    """
    # Start with causal mask (lower triangle)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=dtype))
    
    # Create the mask
    attn_mask = causal_mask.clone()
    
    # Zero out candidate-to-candidate attention
    # This is the bottom-right block
    attn_mask[candidate_start_offset:, candidate_start_offset:] = 0
    
    # Add back self-attention for candidates (diagonal of candidate block)
    for i in range(candidate_start_offset, seq_len):
        attn_mask[i, i] = 1
    
    # Add batch and head dimensions
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
    
    return attn_mask

# Test make_recsys_attn_mask
print("Testing make_recsys_attn_mask...")

USER_LEN = 1
HISTORY_LEN = 5
CANDIDATE_LEN = 3
seq_len = USER_LEN + HISTORY_LEN + CANDIDATE_LEN
candidate_start_offset = USER_LEN + HISTORY_LEN

mask = make_recsys_attn_mask(seq_len, candidate_start_offset)

print(f"Sequence layout: [User] + [History x{HISTORY_LEN}] + [Candidates x{CANDIDATE_LEN}]")
print(f"Total seq_len: {seq_len}")
print(f"Candidate start offset: {candidate_start_offset}")
print(f"Mask shape: {mask.shape}")

# Visualize
fig, ax = plt.subplots(figsize=(8, 7))
mask_np = mask.squeeze().numpy()
im = ax.imshow(mask_np, cmap='Blues', interpolation='nearest')
ax.axvline(x=candidate_start_offset - 0.5, color='red', linestyle='--', linewidth=2)
ax.axhline(y=candidate_start_offset - 0.5, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('Position (can attend to)')
ax.set_ylabel('Position (query)')
ax.set_title('Attention Mask\n(Blue=Can Attend, White=Cannot)')
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(['U'] + [f'H{i}' for i in range(HISTORY_LEN)] + [f'C{i}' for i in range(CANDIDATE_LEN)])
ax.set_yticklabels(['U'] + [f'H{i}' for i in range(HISTORY_LEN)] + [f'C{i}' for i in range(CANDIDATE_LEN)])
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

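print("\nMask regions:")
print("  UL: User + History -> User + History (causal)")
print("  UR: User + History -> Candidates (BLOCKED by causality)")
print("  LL: Candidates -> User + History (allowed)")
print("  LR diagonal: Candidates -> Themselves (allowed)")
print("  LR off-diagonal: Candidates -> Other candidates (BLOCKED)")

# Check the candidate block: ones on the diagonal, zeros elsewhere
cand_block = mask[0, 0, candidate_start_offset:, candidate_start_offset:]
assert torch.equal(cand_block, torch.eye(CANDIDATE_LEN, dtype=mask.dtype))
print("\n✓ Candidate isolation verified!")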

## Transformer

Complete transformer stack with optional candidate isolation.

In [None]:
class Transformer(nn.Module):
    """
    Decoder-style Transformer stack for recommendation systems.
    
    Features:
    - Stacks of DecoderLayers
    - Optional candidate isolation via custom attention mask
    - RMSNorm throughout
    
    Args:
        dim: Model dimension
        num_layers: Number of decoder layers
        num_q_heads: Number of query heads per layer
        num_kv_heads: Number of key/value heads per layer
        key_size: Dimension per head
    """
    
    def __init__(self, dim, num_layers, num_q_heads, num_kv_heads, key_size=None):
        super().__init__()
        self.dim = dim
        self.num_layers = num_layers
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(dim, num_q_heads, num_kv_heads, key_size)
            for _ in range(num_layers)
        ])
        
        # Final norm
        self.norm = RMSNorm(dim)
    
    def forward(self, x, padding_mask=None, candidate_start_offset=None):
        """
        Args:
            x: [B, T, D] input embeddings
            padding_mask: [B, T] True for valid positions
            candidate_start_offset: Position where candidates start (None for standard causal)
        
        Returns:
            Output embeddings [B, T, D]
        """
        B, T, D = x.shape
        
        # Create boolean attention mask (True = can attend)
        if candidate_start_offset is not None:
            # Use candidate isolation mask: [1, 1, T, T]
            attn_mask = make_recsys_attn_mask(T, candidate_start_offset)
            attn_mask = attn_mask.to(device=x.device, dtype=torch.bool)
        else:
            # Standard causal mask: [1, 1, T, T]
            causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            attn_mask = causal.unsqueeze(0).unsqueeze(0)
        
        # Combine with padding mask so padded keys are never attended to
        if padding_mask is not None:
            # padding_mask: [B, T] -> [B, 1, 1, T]
            valid_keys = padding_mask.unsqueeze(1).unsqueeze(1)
            attn_mask = attn_mask & valid_keys  # [B, 1, T, T], True = can attend
        
        # Pass through all layers
        h = x
        for layer in self.layers:
            h = layer(h, attention_mask=attn_mask)
        
        # Final norm
        h = self.norm(h)
        
        return h

# Test Transformer
print("Testing Transformer...")

transformer = Transformer(
    dim=EMB_SIZE,
    num_layers=NUM_LAYERS,
    num_q_heads=NUM_Q_HEADS,
    num_kv_heads=NUM_KV_HEADS,
    key_size=KEY_SIZE
)

B, T = 2, 10
x = torch.randn(B, T, EMB_SIZE)
padding_mask = torch.ones(B, T, dtype=torch.bool)
padding_mask[:, 8:] = False

# Test with candidate isolation
candidate_start = 5
y = transformer(x, padding_mask=padding_mask, candidate_start_offset=candidate_start)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"With candidate isolation starting at position: {candidate_start}")

assert y.shape == x.shape
print("✓ Transformer working correctly!")

## Demo: With vs Without Candidate Isolation

Compare transformer outputs with and without candidate isolation.

In [None]:
print("=" * 60)
print("DEMO: With vs Without Candidate Isolation")
print("=" * 60)

transformer = Transformer(
    dim=EMB_SIZE,
    num_layers=NUM_LAYERS,
    num_q_heads=NUM_Q_HEADS,
    num_kv_heads=NUM_KV_HEADS,
    key_size=KEY_SIZE
)

B = 2
USER_LEN = 1
HISTORY_LEN = 5
CANDIDATE_LEN = 3
T = USER_LEN + HISTORY_LEN + CANDIDATE_LEN
candidate_start = USER_LEN + HISTORY_LEN

x = torch.randn(B, T, EMB_SIZE, requires_grad=True)
padding_mask = torch.ones(B, T, dtype=torch.bool)

# Forward pass WITHOUT candidate isolation
y_no_iso = transformer(x, padding_mask=padding_mask, candidate_start_offset=None)

# Forward pass WITH candidate isolation
y_iso = transformer(x, padding_mask=padding_mask, candidate_start_offset=candidate_start)

print(f"Input: {B} users, {T} positions")
print(f"Positions: [1 user] + [{HISTORY_LEN} history] + [{CANDIDATE_LEN} candidates]")
print(f"\nOutput shapes (should be same):")
print(f"  Without isolation: {y_no_iso.shape}")
print(f"  With isolation: {y_iso.shape}")

# Check if outputs differ
diff = torch.abs(y_no_iso - y_iso)
print(f"\nMax difference between outputs: {diff.max().item():.6f}")
print(f"Mean difference between outputs: {diff.mean().item():.6f}")

# Verify candidate isolation property
print("\n" + "-" * 60)
print("Verifying Candidate Isolation Property")
print("-" * 60)

# Modify only candidate 1 (0-indexed)
x_modified = x.detach().clone()
x_modified[:, candidate_start + 1, :] += 100.0  # Big change to candidate 1

# Get outputs with isolation
y_iso_orig = transformer(x, padding_mask=padding_mask, candidate_start_offset=candidate_start)
y_iso_mod = transformer(x_modified, padding_mask=padding_mask, candidate_start_offset=candidate_start)

# Check effect on each candidate
print(f"\nAfter modifying candidate 1 (index {candidate_start + 1}):")
for i in range(CANDIDATE_LEN):
    cand_idx = candidate_start + i
    diff = torch.abs(y_iso_orig[0, cand_idx] - y_iso_mod[0, cand_idx]).max().item()
    if i == 1:  # Modified candidate
        print(f"  Candidate {i} (modified): output changed by {diff:.4f} ✓")
    else:  # Other candidates
        print(f"  Candidate {i} (unchanged): output changed by {diff:.8f}")

print("\n" + "=" * 60)
print("Key Insight:")
print("Without candidate isolation, the causal mask lets each candidate attend to earlier candidates.")
print("With candidate isolation, candidates cannot see other candidates.")
print("This ensures fair scoring and no information leakage.")
print("=" * 60)

## Visualization: Attention Patterns

Show how attention differs with and without candidate isolation.

In [None]:
# Visualize the attention masks directly (no model is needed for this)
USER_LEN = 1
HISTORY_LEN = 4
CANDIDATE_LEN = 3
T = USER_LEN + HISTORY_LEN + CANDIDATE_LEN
candidate_start = USER_LEN + HISTORY_LEN

# Create masks
causal_mask = torch.tril(torch.ones(T, T))
recsys_mask = make_recsys_attn_mask(T, candidate_start)

# Convert to numpy for visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Standard causal mask
ax1 = axes[0]
im1 = ax1.imshow(causal_mask.numpy(), cmap='Blues')
ax1.axvline(x=candidate_start - 0.5, color='red', linestyle='--', linewidth=2)
ax1.axhline(y=candidate_start - 0.5, color='red', linestyle='--', linewidth=2)
ax1.set_title('Standard Causal Mask\n(Candidates can see earlier candidates)')
ax1.set_xlabel('Position')
ax1.set_ylabel('Position')
plt.colorbar(im1, ax=ax1)

# Recsys attention mask
ax2 = axes[1]
im2 = ax2.imshow(recsys_mask.squeeze().numpy(), cmap='Blues')
ax2.axvline(x=candidate_start - 0.5, color='red', linestyle='--', linewidth=2)
ax2.axhline(y=candidate_start - 0.5, color='red', linestyle='--', linewidth=2)
ax2.set_title('RecSys Attention Mask\n(Candidates CANNOT see each other)')
ax2.set_xlabel('Position')
ax2.set_ylabel('Position')
plt.colorbar(im2, ax=ax2)

# Add labels
labels = ['U'] + [f'H{i}' for i in range(HISTORY_LEN)] + [f'C{i}' for i in range(CANDIDATE_LEN)]
for ax in axes:
    ax.set_xticks(range(T))
    ax.set_yticks(range(T))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

plt.tight_layout()
plt.show()

print("Key differences:")
print("1. Standard: Lower-right block is lower-triangular (each candidate sees earlier candidates)")
print("2. RecSys: Lower-right block only has diagonal (candidates see only themselves)")

## Summary

In this notebook, you implemented the Grok transformer building blocks:

1. **RMSNorm**: Root Mean Square normalization (simpler than LayerNorm)
2. **RotaryEmbedding**: RoPE for relative position encoding
3. **MultiHeadAttention**: GQA attention with RoPE and soft capping
4. **SwiGLUFFN**: Swish-gated feed-forward network
5. **DecoderLayer**: Complete transformer layer with residuals
6. **make_recsys_attn_mask**: Candidate isolation attention mask
7. **Transformer**: Complete transformer stack

### Key Takeaways

- **GQA**: Memory-efficient attention (fewer KV heads)
- **RoPE**: Better relative position encoding (no learned positions)
- **Soft capping**: Bounds attention logits so no single score dominates the softmax
- **SwiGLU**: Gated FFN that typically outperforms a plain ReLU FFN in transformers
- **Candidate isolation**: Critical for fair ranking
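
As a quick numeric reminder of what soft capping does, here is a small sketch in plain Python using the same `tanh(x/c) * c` form with the notebook's default cap of 30. The helper name `soft_cap` is illustrative. Small logits pass through almost unchanged, while large ones saturate near the cap.

```python
import math

def soft_cap(logit, cap=30.0):
    """Squash a logit smoothly so its magnitude never exceeds cap."""
    return cap * math.tanh(logit / cap)

for raw in [0.5, 10.0, 100.0, -1000.0]:
    print(raw, round(soft_cap(raw), 3))
```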

### Next Steps

- **Lecture 6**: Scoring Pipeline - see how the transformer fits into the full system
- **Lecture 7**: Replication Guide - build your own recommender!