In [1]:
pip install -q torch

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn

# Toy sentence "Your journey starts with one step": one 3-dim embedding per token (x^1 .. x^6).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

d_in = inputs.shape[1]  # input embedding size (3 here)
d_out = 2               # output embedding size

# A batch of two identical sequences, stacked along a new leading axis -> (2, 6, 3)
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [3]:


class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Every token may attend only to itself and to earlier tokens: scores for
    later positions are set to -inf before the softmax, so they get zero weight.

    Args:
        d_in: Size of each input token embedding (last dim of ``x``).
        d_out: Size of the query/key/value projections and of each output vector.
        context_length: Maximum supported number of tokens (side of the causal mask).
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Created in this order so that a fixed seed reproduces the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark the "future" positions to hide.
        # Kept as a buffer (not a trainable parameter) so it moves with the module.
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future)

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: Tensor of shape (b, num_tokens, d_in). num_tokens must not exceed
               the ``context_length`` given at construction.

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        _, num_tokens, _ = x.shape
        # Sequences longer than `context_length` are unsupported. The sliced mask
        # would be smaller than the score matrix, and masked_fill_ would fail.
        # Callers are expected to truncate inputs before reaching this method.
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Pairwise query/key scores per batch item: (b, num_tokens, num_tokens)
        scores = q @ k.transpose(1, 2)
        # Slice the mask to the current length (inputs may be shorter than context_length).
        is_future = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(is_future, -torch.inf)

        # Scale by sqrt(key dim), then normalize over the key axis.
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)
        return weights @ v

In [4]:
# Fixed seed so the Linear layers (and therefore the printed numbers) are reproducible.
torch.manual_seed(123)

context_length = batch.size(1)  # one causal-mask row/column per token
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)  # dropout off for a deterministic demo

context_vecs = ca(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [5]:
class MultiHeadAttentionWrapper(nn.Module):
    """Naive multi-head attention: independent CausalAttention heads, outputs concatenated.

    Each head has its own query/key/value projections to ``d_out`` features.
    The head outputs are joined along the last axis, so the result has
    ``num_heads * d_out`` features per token.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on ``x`` and concatenate their outputs feature-wise."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=-1)


# Same seed as the single-head demo, so the first head reproduces its numbers.
torch.manual_seed(123)

context_length = batch.size(1)  # number of tokens per example
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)

context_vecs = mha(batch)  # heads concatenated: last dim is num_heads * d_out
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [6]:
# Efficient multi-head attention: instead of stacking independent CausalAttention
# heads (as MultiHeadAttentionWrapper does above), the next cell projects Q/K/V
# once and splits each projection into `num_heads` slices of size `head_dim`.

In [21]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with shared projections split into heads.

    The input is projected once to ``d_out`` query/key/value features. Each
    projection is reshaped into ``num_heads`` heads of size
    ``head_dim = d_out // num_heads``, and every head runs masked scaled
    dot-product attention independently. The per-head context vectors are
    merged back to ``d_out`` features and passed through ``out_proj``.

    Args:
        d_in: Size of each input token embedding (last dim of ``x``).
        d_out: Total output size; must be divisible by ``num_heads``.
        context_length: Maximum supported number of tokens (side of the causal mask).
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Each head works on an equal slice of the d_out projection.
        self.head_dim = d_out // num_heads

        # Created in this order so that a fixed seed reproduces the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark the future positions a token must not see.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape (b, num_tokens, d_in). num_tokens must not exceed
               the ``context_length`` given at construction.

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape
        # As in `CausalAttention`, inputs longer than `context_length` are not
        # supported. The sliced mask below would not match the score shape.
        # Callers are expected to truncate inputs before reaching this method.

        # Project once, then split the last dim into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim)
        values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim)

        # Put heads before tokens so each head attends independently:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Causal mask truncated to the current length; True marks future positions.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) and normalize over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads back: (b, num_tokens, num_heads * head_dim) == (b, num_tokens, d_out).
        # contiguous() is needed because view() cannot act on the transposed layout.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

torch.manual_seed(123)

# Reuse the batch's dimensions: 2 examples, 6 tokens, 3-dim embeddings.
batch_size, context_length, d_in = batch.shape
d_out = 2  # split across num_heads=2 heads, i.e. head_dim=1
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)

print("out_proj:\n", mha.out_proj.weight)
context_vecs = mha(batch)
print("context_vecs:", context_vecs.shape, "\n", context_vecs)

out_proj:
 Parameter containing:
tensor([[-0.1668,  0.2270],
        [ 0.5000,  0.1317]], requires_grad=True)
..........view.............
..........attn_scores.............
..........attn_scores after masking.............
..........softmax.............
context_vecs: torch.Size([2, 6, 2]) 
 tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
