In [7]:
# Common imports and settings
import torch
import torch.nn as nn
from importlib.metadata import version

print("torch version:", version("torch"))
# Fix the global RNG so every module constructed later in this cell gets
# reproducible random weights.
torch.manual_seed(123)

# Toy input shared by Exercises 3.1 and 3.2: six tokens, one row per token,
# each embedded as a 3-dimensional vector.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
d_in = 3   # embedding size of each input token (columns of `inputs`)
d_out = 2  # size of the query/key/value projections in Exercise 3.1

# Exercise 3.1: Compare SelfAttention_v1 and SelfAttention_v2

# SelfAttention_v1 uses nn.Parameter for weight matrices
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query/key/value projections are stored as ``nn.Parameter`` tensors of
    shape ``(d_in, d_out)`` and applied as ``x @ W``.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # torch.rand draws uniform [0, 1) values from the global RNG.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or batched
                ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)`` matching the leading
            dimensions of ``x``.
        """
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        # transpose(-2, -1) swaps only the token/feature axes. `.T` reverses
        # *every* axis, which is wrong (and deprecated) for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

# SelfAttention_v2 uses nn.Linear layers
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Each projection is a bias-free ``nn.Linear(d_in, d_out)``. Its ``weight`` has
    shape ``(d_out, d_in)``, and the layer computes ``x @ weight.T``.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Return context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or batched
                ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)`` matching the leading
            dimensions of ``x``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # transpose(-2, -1) swaps only the token/feature axes. `.T` reverses
        # *every* axis, which is wrong (and deprecated) for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

# Instantiate both modules. Both constructors draw their initial weights from
# the global RNG, so keep this order for reproducible results.
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Transfer weights from sa_v2 (nn.Linear) to sa_v1 (transpose required):
# nn.Linear stores weight as (d_out, d_in) and computes x @ weight.T, while
# SelfAttention_v1 computes x @ W with W of shape (d_in, d_out).
# no_grad() makes the in-place copy_ legal on leaf parameters, so `.data` is
# not needed.
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

# Compute outputs to verify they match
output_v1 = sa_v1(inputs)
output_v2 = sa_v2(inputs)
print("Exercise 3.1 outputs:")
print("SelfAttention_v1:", output_v1)
print("SelfAttention_v2:", output_v2)
# Check numerically instead of by eye. The two paths run different matmul
# kernels, so compare with a tolerance rather than exact equality.
print("Outputs match:", torch.allclose(output_v1, output_v2))

# Exercise 3.2: Returning 2-dimensional embedding vectors using multi-head attention

# Define a simple causal attention module (with masking)
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Token ``i`` may only attend to tokens ``0..i``. Later positions are masked
    with ``-inf`` before the softmax.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum sequence length; sizes the causal mask buffer.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1.0 strictly above the diagonal marks the "future" positions. A
        # buffer is saved in the state_dict and moves with the module on .to().
        self.register_buffer('mask', torch.ones(context_length, context_length).triu(diagonal=1))

    def forward(self, x):
        """Return causally-masked context vectors.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        _, num_tokens, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        scores = q @ k.transpose(1, 2)
        # Crop the precomputed mask to the actual sequence length. The -inf
        # entries become exactly 0 after the softmax.
        future = self.mask.bool()[:num_tokens, :num_tokens]
        scores = scores.masked_fill(future, -float('inf'))

        weights = torch.softmax(scores / k.shape[-1]**0.5, dim=-1)
        return self.dropout(weights) @ v

# MultiHeadAttentionWrapper stacks multiple single-head attention modules
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention assembled from independent ``CausalAttention`` heads.

    Every head projects to ``d_out`` features, so the concatenated output has
    ``num_heads * d_out`` features in its last dimension.

    Args:
        d_in: size of each input embedding.
        d_out: output size of each individual head.
        context_length: maximum sequence length supported by each head's mask.
        dropout: dropout probability used inside each head.
        num_heads: number of heads to create.
        qkv_bias: whether each head's query/key/value projections have a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList()
        # Heads are created one after another, so each draws fresh weights.
        for _ in range(num_heads):
            self.heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

    def forward(self, x):
        """Run every head on ``x`` and join their outputs along the last axis."""
        per_head = tuple(head(x) for head in self.heads)
        return torch.cat(per_head, dim=-1)

# Create a dummy batch input (batch size = 2) for Exercise 3.2
# Two identical copies of `inputs` stacked along a new leading batch axis.
batch = torch.stack([inputs, inputs], dim=0)  # Shape: (2, 6, 3)
# Size each head's causal mask to exactly the 6 tokens in this batch.
context_length = batch.shape[1]
# To obtain a final output dimension of 2 with 2 heads, set d_out per head to 1 (1*2 = 2)
d_out_ex2 = 1
num_heads_ex2 = 2
dropout_rate = 0.0  # dropout disabled, so the forward pass is deterministic

# Head weights come from the global RNG, so the values (not the shape) depend
# on how many random draws earlier cells made after torch.manual_seed.
mha_wrapper = MultiHeadAttentionWrapper(d_in, d_out_ex2, context_length, dropout_rate, num_heads_ex2)
context_vecs_ex2 = mha_wrapper(batch)
print("\nExercise 3.2 output:")
print("Context vectors shape:", context_vecs_ex2.shape)  # Expected: (2, 6, 2)

# Exercise 3.3: GPT-2 style Multi-Head Attention Module

# Efficient multi-head attention with weight splits and causal masking
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q/K/V.

    A single ``d_in -> d_out`` projection is computed for each of queries, keys
    and values. It is then reshaped into ``num_heads`` heads of
    ``d_out // num_heads`` features, so separate per-head modules are not
    needed. The per-head context vectors are concatenated and passed through
    ``out_proj``.

    Args:
        d_in: size of each input embedding.
        d_out: total output size; must be divisible by ``num_heads``.
        context_length: maximum sequence length; sizes the causal mask buffer.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_out // num_heads
        self.context_length = context_length
        # Layer creation order fixes both RNG consumption and parameter order.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark future positions to hide.
        self.register_buffer(
            'mask',
            torch.ones(context_length, context_length).triu(diagonal=1),
        )

    def forward(self, x):
        """Return attention outputs for ``x``.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq_len, d_out) -> (batch, num_heads, seq_len, d_head)
            return t.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.W_query(x))
        k = split_heads(self.W_key(x))
        v = split_heads(self.W_value(x))

        # Scaled dot-product scores per head: (batch, num_heads, seq_len, seq_len)
        scores = (q @ k.transpose(-2, -1)) / self.d_head ** 0.5
        blocked = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(blocked, -float('inf'))
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # (batch, num_heads, seq_len, d_head) -> (batch, seq_len, d_out)
        merged = (weights @ v).transpose(1, 2).flatten(start_dim=2)
        return self.out_proj(merged)

# GPT-2 style parameters: d_in = d_out = 768, num_heads = 12, context_length = 1024
# (768 / 12 = 64 features per head.)
batch_size = 2
context_length_gpt2 = 1024
d_in_gpt2 = 768
d_out_gpt2 = 768
num_heads_gpt2 = 12
dropout_rate_gpt2 = 0.0  # dropout disabled for a deterministic forward pass

# Create a dummy input for GPT-2 style attention (e.g., two sequences of 1024 tokens)
# Uniform [0, 1) values; only the output shape is inspected below.
dummy_input = torch.rand(batch_size, context_length_gpt2, d_in_gpt2)
mha_gpt2 = MultiHeadAttention(d_in_gpt2, d_out_gpt2, context_length_gpt2, dropout_rate_gpt2, num_heads_gpt2)
# NOTE: the score tensor alone is (2, 12, 1024, 1024), about 25M elements
# (~100 MB at the default float32), so this cell is memory-heavy on small machines.
output_gpt2 = mha_gpt2(dummy_input)
print("\nExercise 3.3 output:")
print("Output shape:", output_gpt2.shape)  # Expected: (2, 1024, 768)

torch version: 2.6.0
Exercise 3.1 outputs:
SelfAttention_v1: tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)
SelfAttention_v2: tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)

Exercise 3.2 output:
Context vectors shape: torch.Size([2, 6, 2])

Exercise 3.3 output:
Output shape: torch.Size([2, 1024, 768])
