## Input sequence: "Dream big and work for it"

In [1]:
import torch

inputs = torch.tensor(
    [[0.72, 0.45, 0.31], # Dream    (x^1)
     [0.75, 0.20, 0.55], # big      (x^2)
     [0.30, 0.80, 0.40], # and      (x^3)
     [0.85, 0.35, 0.60], # work     (x^4)
     [0.55, 0.15, 0.75], # for      (x^5)
     [0.25, 0.20, 0.85]] # it       (x^6)
)

# Corresponding words
words = ['Dream', 'big', 'and', 'work', 'for', 'it']

## Class for implementing causal attention

In [2]:
from torch import nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ values.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [3]:
d_in = inputs.shape[-1]
d_out = 2
print(d_in, d_out)

3 2


In [4]:
batch = torch.stack((inputs,), dim=0)
print(batch.shape)

torch.Size([1, 6, 3])


## Class for implementing multi-head attention

In [5]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)
        ])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [6]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)

In [7]:
context_vecs = mha(batch)
print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

tensor([[[-0.5762, -0.1627,  0.5569,  0.3635],
         [-0.5650, -0.0622,  0.5600,  0.2991],
         [-0.5474, -0.1223,  0.5293,  0.3414],
         [-0.5786, -0.0938,  0.5637,  0.3375],
         [-0.5586, -0.0418,  0.5521,  0.3015],
         [-0.5271, -0.0006,  0.5282,  0.2698]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([1, 6, 4])


## Implementing multi-head attention with weight splits

Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can combine both of these concepts into a single MultiHeadAttention class.

* Step 1: Reduce the projection dim to match desired output dim
* Step 2: Use a linear layer to combine head outputs
* Step 3: Tensor shape: (b, num_tokens, d_out)
* Step 4: We implicitly split the matrix by adding a num_heads dimension. Then we unroll last dim: (b, num_tokens, head_dim)
* Step 5: Transpose from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
* Step 6: Compute dot product of each head
* Step 7: Mask truncated to the number of tokens
* Step 8: Use the mask to fill attention scores
* Step 9: Tensor shape: (b, num_tokens, n_heads, head_dim)
* Step 10: Combine heads, where self.d_out = self.num_heads * self.head_dim
* Step 11: Add an optional linear projection 

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_in) # Linear layer to combine
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(
                context_length, context_length
            ), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a 'num_heads' dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the maks to fill attention scores
        attn_scores.masked_fill(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [10]:
torch.manual_seed(123)

# Define the tensor with 3 rows and 6 columns
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66], # Row 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33], # Row 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]  # Row 3
])

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
print(batch)

batch_size, context_length, d_in = batch.shape
d_out = 6
context_length = inputs.shape[0]
dropout = 0.0
num_heads = 2

mha = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads, True)

context_vecs = mha(batch)
print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

torch.Size([2, 3, 6])
tensor([[[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]]])
tensor([[[-0.3464,  0.3428,  0.3846, -0.1505,  0.1512,  0.0491],
         [-0.3462,  0.3426,  0.3844, -0.1505,  0.1516,  0.0493],
         [-0.3465,  0.3427,  0.3842, -0.1503,  0.1512,  0.0489]],

        [[-0.3464,  0.3428,  0.3846, -0.1505,  0.1512,  0.0491],
         [-0.3462,  0.3426,  0.3844, -0.1505,  0.1516,  0.0493],
         [-0.3465,  0.3427,  0.3842, -0.1503,  0.1512,  0.0489]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 3, 6])


## Stepwise implementation of multi-head attention

In [11]:
import torch

# Input: (batch=1, seq_len=3, d_model=6)
x = torch.tensor([[
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], # The 
    [6.0, 5.0, 4.0, 3.0, 2.0, 1.0], # Kid
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # Smiles
]])

batch_size, seq_len, d_model = x.shape

In [12]:
# Define 6x6 projection matrices (d_model x d_model)
torch.manual_seed(0) # for reprducibility
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)

# Compute Q, K, V
# Shape logic: (B, T, d_model) @ (d_model, d_model) -> (B, T, d_model)
Q = x @ Wq
K = x @ Wk
V = x @ Wv

# Print Q, K, and V
print("Q:\n", Q)
print("K:\n", K)
print("V:\n", V)

# Print dimensionalities
print('X shape:', x.shape)      # (1, 3, 6)
print('Wq shape:', Wq.shape)    # (6, 6)
print('Wk shape:', Wk.shape)    # (6, 6)
print('Wv shape:', Wv.shape)    # (6, 6)
print('Q shape:', Q.shape)      # (1, 3, 6)
print('K shape:', K.shape)      # (1, 3, 6)
print('V shape:', V.shape)      # (1, 3, 6)

Q:
 tensor([[[ -9.0244, -11.7287,  15.5360,  -1.4474,  -4.5326,   9.4674],
         [ -8.0564, -13.2309,   8.2228,  -8.9680,   3.1995,   4.8321],
         [ -2.4401,  -3.5657,   3.3941,  -1.4879,  -0.1904,   2.0428]]])
K:
 tensor([[[  8.2602,  14.1116,  -5.0345, -16.4865,  -2.9948,   8.3139],
         [ -6.1188,  -0.1587,  -5.0885, -14.3014,   4.9540,   5.6093],
         [  0.3059,   1.9933,  -1.4461,  -4.3983,   0.2799,   1.9890]]])
V:
 tensor([[[ 0.5076, -3.4353,  1.8576,  2.8041,  8.9427, 13.1841],
         [-1.9113, -3.6934,  1.8502,  1.7622,  1.6981,  3.0978],
         [-0.2005, -1.0184,  0.5297,  0.6523,  1.5201,  2.3260]]])
X shape: torch.Size([1, 3, 6])
Wq shape: torch.Size([6, 6])
Wk shape: torch.Size([6, 6])
Wv shape: torch.Size([6, 6])
Q shape: torch.Size([1, 3, 6])
K shape: torch.Size([1, 3, 6])
V shape: torch.Size([1, 3, 6])


In [13]:
# Print values
torch.set_printoptions(precision=2)
print("\nWq:\n", Wq)
print("\nWk:\n", Wk)
print("\nWv:\n", Wv)

print("\nQ:\n", Q)
print("\nK:\n", K)
print("\nV:\n", V)


Wq:
 tensor([[-1.13, -1.15, -0.25, -0.43,  0.85,  0.69],
        [-0.32, -2.12,  0.32, -1.26,  0.35,  0.31],
        [ 0.12,  1.24,  1.12, -0.25, -1.35, -1.70],
        [ 0.57,  0.79,  0.44,  0.11,  0.64,  0.44],
        [-0.22, -0.74,  0.56,  0.26,  0.52,  2.30],
        [-1.47, -1.59,  1.20,  0.08, -1.20, -0.00]])

Wk:
 tensor([[-0.23, -0.39,  0.54, -0.40,  0.21, -0.45],
        [-0.57, -0.56, -1.53, -1.23,  1.82, -0.55],
        [-1.33,  0.19, -0.07, -0.49, -1.48,  2.57],
        [-0.47,  0.34, -0.00, -0.53,  1.17,  0.39],
        [ 1.94,  0.79, -0.02, -0.44, -1.54, -0.41],
        [ 0.97,  1.62, -0.37, -1.30,  0.10,  0.44]])

Wv:
 tensor([[ 0.07,  1.11,  0.28,  0.43, -0.80, -1.30],
        [-0.75, -1.31,  0.21, -0.33, -0.43,  0.23],
        [ 0.80, -0.18, -0.37, -1.21, -0.70,  1.04],
        [-0.60, -1.28, -0.03,  1.37,  2.66,  0.99],
        [-0.26,  0.12,  0.24,  1.16,  2.70,  1.24],
        [ 0.54,  0.53,  0.19, -0.77, -1.90,  0.13]])

Q:
 tensor([[[ -9.02, -11.73,  15.54,  -1.

In [14]:
num_heads = 2
head_dim = 3

Q = Q.view(1, 3, num_heads, head_dim)
K = K.view(1, 3, num_heads, head_dim)
V = V.view(1, 3, num_heads, head_dim)

print('Q after unrolling:', Q)
print('K after unrolling:', K)
print('V after unrolling:', V)

Q after unrolling: tensor([[[[ -9.02, -11.73,  15.54],
          [ -1.45,  -4.53,   9.47]],

         [[ -8.06, -13.23,   8.22],
          [ -8.97,   3.20,   4.83]],

         [[ -2.44,  -3.57,   3.39],
          [ -1.49,  -0.19,   2.04]]]])
K after unrolling: tensor([[[[  8.26,  14.11,  -5.03],
          [-16.49,  -2.99,   8.31]],

         [[ -6.12,  -0.16,  -5.09],
          [-14.30,   4.95,   5.61]],

         [[  0.31,   1.99,  -1.45],
          [ -4.40,   0.28,   1.99]]]])
V after unrolling: tensor([[[[ 0.51, -3.44,  1.86],
          [ 2.80,  8.94, 13.18]],

         [[-1.91, -3.69,  1.85],
          [ 1.76,  1.70,  3.10]],

         [[-0.20, -1.02,  0.53],
          [ 0.65,  1.52,  2.33]]]])


In [15]:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)

print('Q after grouping by heads:', Q)
print('K after grouping by heads:', K)
print('V after grouping by heads:', V)

Q after grouping by heads: tensor([[[[ -9.02, -11.73,  15.54],
          [ -8.06, -13.23,   8.22],
          [ -2.44,  -3.57,   3.39]],

         [[ -1.45,  -4.53,   9.47],
          [ -8.97,   3.20,   4.83],
          [ -1.49,  -0.19,   2.04]]]])
K after grouping by heads: tensor([[[[  8.26,  14.11,  -5.03],
          [ -6.12,  -0.16,  -5.09],
          [  0.31,   1.99,  -1.45]],

         [[-16.49,  -2.99,   8.31],
          [-14.30,   4.95,   5.61],
          [ -4.40,   0.28,   1.99]]]])
V after grouping by heads: tensor([[[[ 0.51, -3.44,  1.86],
          [-1.91, -3.69,  1.85],
          [-0.20, -1.02,  0.53]],

         [[ 2.80,  8.94, 13.18],
          [ 1.76,  1.70,  3.10],
          [ 0.65,  1.52,  2.33]]]])


In [16]:
K_T = K.transpose(2, 3)
print('K_T shape:', K_T)

K_T shape: tensor([[[[  8.26,  -6.12,   0.31],
          [ 14.11,  -0.16,   1.99],
          [ -5.03,  -5.09,  -1.45]],

         [[-16.49, -14.30,  -4.40],
          [ -2.99,   4.95,   0.28],
          [  8.31,   5.61,   1.99]]]])


In [17]:
attn_scores = Q @ K_T
print('Attention scores shape:', attn_scores.shape)
print('Attention scores:\n', attn_scores)

Attention scores shape: torch.Size([1, 2, 3, 3])
Attention scores:
 tensor([[[[-318.27,  -21.97,  -48.61],
          [-294.65,    9.55,  -40.73],
          [ -87.56,   -1.77,  -12.76]],

         [[ 116.15,   51.35,   23.93],
          [ 178.44,  171.21,   49.95],
          [  42.08,   31.79,   10.55]]]])


In [18]:
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print('Causal mask:\n', mask)

attn_scores.masked_fill_(mask, -torch.inf)
print('Attention scores after masking:\n', attn_scores)

Causal mask:
 tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
Attention scores after masking:
 tensor([[[[-318.27,    -inf,    -inf],
          [-294.65,    9.55,    -inf],
          [ -87.56,   -1.77,  -12.76]],

         [[ 116.15,    -inf,    -inf],
          [ 178.44,  171.21,    -inf],
          [  42.08,   31.79,   10.55]]]])


In [19]:
torch.set_printoptions(precision=3, sci_mode=False)
head_dim = 3
attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)
print('Attention weights shape:', attn_weights.shape)
print('Attention weights:\n', attn_weights)

Attention weights shape: torch.Size([1, 2, 3, 3])
Attention weights:
 tensor([[[[    1.000,     0.000,     0.000],
          [    0.000,     1.000,     0.000],
          [    0.000,     0.998,     0.002]],

         [[    1.000,     0.000,     0.000],
          [    0.985,     0.015,     0.000],
          [    0.997,     0.003,     0.000]]]])


In [20]:
dropout = torch.nn.Dropout(0.1)
attn_weights = dropout(attn_weights)
print('Attention weights after dropout:\n', attn_weights)

Attention weights after dropout:
 tensor([[[[    1.111,     0.000,     0.000],
          [    0.000,     1.111,     0.000],
          [    0.000,     1.109,     0.000]],

         [[    1.111,     0.000,     0.000],
          [    1.094,     0.000,     0.000],
          [    1.108,     0.003,     0.000]]]])


In [21]:
context_vec = attn_weights @ V
print('Context vector shape:', context_vec.shape)
print('Context vec:\n', context_vec)

Context vector shape: torch.Size([1, 2, 3, 3])
Context vec:
 tensor([[[[ 0.564, -3.817,  2.064],
          [-2.124, -4.104,  2.056],
          [-2.120, -4.097,  2.052]],

         [[ 3.116,  9.936, 14.649],
          [ 3.068,  9.786, 14.427],
          [ 3.113,  9.915, 14.620]]]])


### Step 11: Reformat and concatenate

In [22]:
context_vec = context_vec.transpose(1, 2)
print('Context vector after swapping dimensions 1 and 2:', context_vec.shape)
print('Context vector:\n', context_vec)

Context vector after swapping dimensions 1 and 2: torch.Size([1, 3, 2, 3])
Context vector:
 tensor([[[[ 0.564, -3.817,  2.064],
          [ 3.116,  9.936, 14.649]],

         [[-2.124, -4.104,  2.056],
          [ 3.068,  9.786, 14.427]],

         [[-2.120, -4.097,  2.052],
          [ 3.113,  9.915, 14.620]]]])


In [23]:
context_vec = context_vec.reshape(batch_size, seq_len, num_heads * head_dim)
print('Context vector after concatenating heads:', context_vec.shape)
print('Context vector:\n', context_vec)

Context vector after concatenating heads: torch.Size([1, 3, 6])
Context vector:
 tensor([[[ 0.564, -3.817,  2.064,  3.116,  9.936, 14.649],
         [-2.124, -4.104,  2.056,  3.068,  9.786, 14.427],
         [-2.120, -4.097,  2.052,  3.113,  9.915, 14.620]]])
