<a href="https://colab.research.google.com/github/teelch0/Data-Mining/blob/main/Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
import torch
import torch.nn as nn

In [26]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from several causal-attention heads.

    Each head is an independent ``CausalAttention`` module (defined outside
    this cell) constructed with identical arguments. Every head receives the
    same input, and the per-head results are concatenated along the last
    dimension.

    NOTE(review): a later cell defines another class with this same name,
    which replaces this one once that cell has run.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # One CausalAttention per head, registered through nn.ModuleList so
        # that the heads' parameters belong to this module.
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        # Run every head on the same input, then join along the last dimension.
        per_head_outputs = []
        for head in self.heads:
            per_head_outputs.append(head(x))
        return torch.cat(per_head_outputs, dim=-1)

In [27]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q/K/V.

    Instead of keeping a separate module per head, the query, key and value
    projections each produce ``d_out`` features that are then viewed as
    ``num_heads`` chunks of ``head_dim = d_out // num_heads`` features.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the square causal mask buffer.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Each head works on an equal share of the d_out features.
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Final linear layer that mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal flag the "future" positions.
        # Stored as a buffer, so it is not a trainable parameter.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.

        Note:
            ``num_tokens`` must not exceed ``context_length``; longer inputs
            make the mask too small and the masking step fails.
        """
        batch, seq_len, _ = x.shape
        per_head_shape = (batch, seq_len, self.num_heads, self.head_dim)

        # Project, split the last dimension into heads, then move the head
        # axis forward:
        # (b, T, d_out) -> (b, T, heads, head_dim) -> (b, heads, T, head_dim)
        k = self.W_key(x).view(per_head_shape).transpose(1, 2)
        q = self.W_query(x).view(per_head_shape).transpose(1, 2)
        v = self.W_value(x).view(per_head_shape).transpose(1, 2)

        # Per-head dot products between every query and every key: (b, heads, T, T)
        scores = q @ k.transpose(-2, -1)

        # Crop the stored mask to the actual sequence length and blank out
        # the future positions so softmax assigns them zero weight.
        is_future = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(is_future, -torch.inf)

        # Scale by sqrt(head_dim) before normalising.
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (b, heads, T, head_dim) -> (b, T, heads, head_dim) -> (b, T, d_out)
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, self.d_out)

        return self.out_proj(merged)

In [28]:
# Toy dimensions used by the step-by-step walkthrough in the cells below.
context_length = 5  # number of token positions (also the side length of the mask)
d_in = 10  # input feature size fed to the linear projections

In [29]:
# Build toy inputs: use the weight matrix of a freshly created embedding
# layer (5 rows x 10 features) as the input data.
embedding = torch.nn.Embedding(5, 10)
inputs = embedding.weight.data  # underlying tensor, without autograd tracking
print(inputs.size())
inputs

torch.Size([5, 10])


tensor([[-0.3500,  0.7753, -1.1204,  1.4635, -0.1660, -2.0281,  0.4279, -1.2608,
          0.8427, -0.6945],
        [ 0.6027,  0.5035,  2.2409, -0.6299, -0.7292,  1.2694, -0.5501, -0.0203,
          1.5254,  0.1577],
        [-1.7250,  1.2829,  0.9096,  0.5401,  0.0137,  0.1198, -0.6908, -1.0769,
         -0.4205, -0.8605],
        [-0.0182, -0.0219, -0.2350,  0.5616, -0.2769, -0.1936,  0.6561, -0.5074,
          1.5030,  0.7483],
        [ 0.5468,  0.2581,  0.6384,  0.4816,  1.2872, -0.1513, -0.0416, -0.0716,
          1.1854,  1.0654]])

In [30]:
# Form a batch of two identical sequences by stacking `inputs` with itself
# along a new leading dimension.
batches = torch.stack([inputs, inputs])
print(batches.size())
batches

torch.Size([2, 5, 10])


tensor([[[-0.3500,  0.7753, -1.1204,  1.4635, -0.1660, -2.0281,  0.4279,
          -1.2608,  0.8427, -0.6945],
         [ 0.6027,  0.5035,  2.2409, -0.6299, -0.7292,  1.2694, -0.5501,
          -0.0203,  1.5254,  0.1577],
         [-1.7250,  1.2829,  0.9096,  0.5401,  0.0137,  0.1198, -0.6908,
          -1.0769, -0.4205, -0.8605],
         [-0.0182, -0.0219, -0.2350,  0.5616, -0.2769, -0.1936,  0.6561,
          -0.5074,  1.5030,  0.7483],
         [ 0.5468,  0.2581,  0.6384,  0.4816,  1.2872, -0.1513, -0.0416,
          -0.0716,  1.1854,  1.0654]],

        [[-0.3500,  0.7753, -1.1204,  1.4635, -0.1660, -2.0281,  0.4279,
          -1.2608,  0.8427, -0.6945],
         [ 0.6027,  0.5035,  2.2409, -0.6299, -0.7292,  1.2694, -0.5501,
          -0.0203,  1.5254,  0.1577],
         [-1.7250,  1.2829,  0.9096,  0.5401,  0.0137,  0.1198, -0.6908,
          -1.0769, -0.4205, -0.8605],
         [-0.0182, -0.0219, -0.2350,  0.5616, -0.2769, -0.1936,  0.6561,
          -0.5074,  1.5030,  0.7483],

In [31]:
# Hyperparameters for the manual multi-head walkthrough in the cells below.
d_out = 6  # total projection width, shared across all heads
num_heads = 3  # d_out must be evenly divisible by num_heads
if d_out % num_heads != 0:
    raise ValueError("d_out must be divisible by num_heads")
# Derived rather than hard-coded so it cannot drift out of sync with
# d_out / num_heads (evaluates to 2 for the values above).
head_dim = d_out // num_heads
b = 2  # batch size; must match the first dimension of `batches`

In [32]:
# Bias-free linear layers that project each d_in-sized input vector to
# d_out query / key features.
W_query = nn.Linear(in_features=d_in, out_features=d_out, bias=False)
W_key = nn.Linear(in_features=d_in, out_features=d_out, bias=False)

In [33]:
# Project the whole batch through the key layer: (b, tokens, d_in) -> (b, tokens, d_out)
keys = W_key(batches)
print(keys.size())
keys

torch.Size([2, 5, 6])


tensor([[[ 0.1691,  0.1188,  0.7480,  0.6117,  0.9782, -0.4419],
         [-1.3669,  0.1143, -0.2911,  0.0660, -1.1708, -0.4260],
         [-0.7013,  0.4421,  1.1728,  0.4512, -0.0262,  0.4880],
         [-0.2076,  0.1596,  0.0519,  0.3650, -0.0022, -0.4062],
         [ 0.3004,  0.2110,  0.2574, -0.4897,  0.0420,  0.0453]],

        [[ 0.1691,  0.1188,  0.7480,  0.6117,  0.9782, -0.4419],
         [-1.3669,  0.1143, -0.2911,  0.0660, -1.1708, -0.4260],
         [-0.7013,  0.4421,  1.1728,  0.4512, -0.0262,  0.4880],
         [-0.2076,  0.1596,  0.0519,  0.3650, -0.0022, -0.4062],
         [ 0.3004,  0.2110,  0.2574, -0.4897,  0.0420,  0.0453]]],
       grad_fn=<UnsafeViewBackward0>)

In [34]:
# Project the whole batch through the query layer: (b, tokens, d_in) -> (b, tokens, d_out)
queries = W_query(batches)
print(queries.size())
queries

torch.Size([2, 5, 6])


tensor([[[-0.5409, -0.3635, -0.0140, -1.2837, -0.0203,  0.7939],
         [ 0.3729,  0.7044, -0.2444,  0.8769, -0.2299,  0.1697],
         [-1.1598, -0.5548, -0.5169, -0.1720,  0.6657, -0.2489],
         [ 0.3508,  0.5328, -0.0522,  0.2179, -0.5714,  0.6954],
         [-0.0082,  0.2895,  0.1432,  0.3992,  0.1078,  0.2411]],

        [[-0.5409, -0.3635, -0.0140, -1.2837, -0.0203,  0.7939],
         [ 0.3729,  0.7044, -0.2444,  0.8769, -0.2299,  0.1697],
         [-1.1598, -0.5548, -0.5169, -0.1720,  0.6657, -0.2489],
         [ 0.3508,  0.5328, -0.0522,  0.2179, -0.5714,  0.6954],
         [-0.0082,  0.2895,  0.1432,  0.3992,  0.1078,  0.2411]]],
       grad_fn=<UnsafeViewBackward0>)

In [35]:
# Split the last dimension of the keys into (num_heads, head_dim) so that
# every head gets its own slice of the projected features.
keys_split_shape = (b, context_length, num_heads, head_dim)
keys = keys.view(keys_split_shape)
print(keys.size())
keys

torch.Size([2, 5, 3, 2])


tensor([[[[ 0.1691,  0.1188],
          [ 0.7480,  0.6117],
          [ 0.9782, -0.4419]],

         [[-1.3669,  0.1143],
          [-0.2911,  0.0660],
          [-1.1708, -0.4260]],

         [[-0.7013,  0.4421],
          [ 1.1728,  0.4512],
          [-0.0262,  0.4880]],

         [[-0.2076,  0.1596],
          [ 0.0519,  0.3650],
          [-0.0022, -0.4062]],

         [[ 0.3004,  0.2110],
          [ 0.2574, -0.4897],
          [ 0.0420,  0.0453]]],


        [[[ 0.1691,  0.1188],
          [ 0.7480,  0.6117],
          [ 0.9782, -0.4419]],

         [[-1.3669,  0.1143],
          [-0.2911,  0.0660],
          [-1.1708, -0.4260]],

         [[-0.7013,  0.4421],
          [ 1.1728,  0.4512],
          [-0.0262,  0.4880]],

         [[-0.2076,  0.1596],
          [ 0.0519,  0.3650],
          [-0.0022, -0.4062]],

         [[ 0.3004,  0.2110],
          [ 0.2574, -0.4897],
          [ 0.0420,  0.0453]]]], grad_fn=<ViewBackward0>)

In [36]:
# Split the last dimension of the queries into (num_heads, head_dim),
# mirroring what was done for the keys.
queries_split_shape = (b, context_length, num_heads, head_dim)
queries = queries.view(queries_split_shape)
print(queries.size())
queries

torch.Size([2, 5, 3, 2])


tensor([[[[-0.5409, -0.3635],
          [-0.0140, -1.2837],
          [-0.0203,  0.7939]],

         [[ 0.3729,  0.7044],
          [-0.2444,  0.8769],
          [-0.2299,  0.1697]],

         [[-1.1598, -0.5548],
          [-0.5169, -0.1720],
          [ 0.6657, -0.2489]],

         [[ 0.3508,  0.5328],
          [-0.0522,  0.2179],
          [-0.5714,  0.6954]],

         [[-0.0082,  0.2895],
          [ 0.1432,  0.3992],
          [ 0.1078,  0.2411]]],


        [[[-0.5409, -0.3635],
          [-0.0140, -1.2837],
          [-0.0203,  0.7939]],

         [[ 0.3729,  0.7044],
          [-0.2444,  0.8769],
          [-0.2299,  0.1697]],

         [[-1.1598, -0.5548],
          [-0.5169, -0.1720],
          [ 0.6657, -0.2489]],

         [[ 0.3508,  0.5328],
          [-0.0522,  0.2179],
          [-0.5714,  0.6954]],

         [[-0.0082,  0.2895],
          [ 0.1432,  0.3992],
          [ 0.1078,  0.2411]]]], grad_fn=<ViewBackward0>)

In [37]:
# Swap the token and head dimensions of the keys so that each head holds its
# own (tokens, head_dim) matrix: (b, tokens, heads, head_dim) -> (b, heads, tokens, head_dim)
keys_T = torch.transpose(keys, 1, 2)
print(keys_T.size())
keys_T

torch.Size([2, 3, 5, 2])


tensor([[[[ 0.1691,  0.1188],
          [-1.3669,  0.1143],
          [-0.7013,  0.4421],
          [-0.2076,  0.1596],
          [ 0.3004,  0.2110]],

         [[ 0.7480,  0.6117],
          [-0.2911,  0.0660],
          [ 1.1728,  0.4512],
          [ 0.0519,  0.3650],
          [ 0.2574, -0.4897]],

         [[ 0.9782, -0.4419],
          [-1.1708, -0.4260],
          [-0.0262,  0.4880],
          [-0.0022, -0.4062],
          [ 0.0420,  0.0453]]],


        [[[ 0.1691,  0.1188],
          [-1.3669,  0.1143],
          [-0.7013,  0.4421],
          [-0.2076,  0.1596],
          [ 0.3004,  0.2110]],

         [[ 0.7480,  0.6117],
          [-0.2911,  0.0660],
          [ 1.1728,  0.4512],
          [ 0.0519,  0.3650],
          [ 0.2574, -0.4897]],

         [[ 0.9782, -0.4419],
          [-1.1708, -0.4260],
          [-0.0262,  0.4880],
          [-0.0022, -0.4062],
          [ 0.0420,  0.0453]]]], grad_fn=<TransposeBackward0>)

In [38]:
# Swap the token and head dimensions of the queries:
# (b, tokens, heads, head_dim) -> (b, heads, tokens, head_dim)
queries_T = torch.transpose(queries, 1, 2)
print(queries_T.size())
queries_T

torch.Size([2, 3, 5, 2])


tensor([[[[-0.5409, -0.3635],
          [ 0.3729,  0.7044],
          [-1.1598, -0.5548],
          [ 0.3508,  0.5328],
          [-0.0082,  0.2895]],

         [[-0.0140, -1.2837],
          [-0.2444,  0.8769],
          [-0.5169, -0.1720],
          [-0.0522,  0.2179],
          [ 0.1432,  0.3992]],

         [[-0.0203,  0.7939],
          [-0.2299,  0.1697],
          [ 0.6657, -0.2489],
          [-0.5714,  0.6954],
          [ 0.1078,  0.2411]]],


        [[[-0.5409, -0.3635],
          [ 0.3729,  0.7044],
          [-1.1598, -0.5548],
          [ 0.3508,  0.5328],
          [-0.0082,  0.2895]],

         [[-0.0140, -1.2837],
          [-0.2444,  0.8769],
          [-0.5169, -0.1720],
          [-0.0522,  0.2179],
          [ 0.1432,  0.3992]],

         [[-0.0203,  0.7939],
          [-0.2299,  0.1697],
          [ 0.6657, -0.2489],
          [-0.5714,  0.6954],
          [ 0.1078,  0.2411]]]], grad_fn=<TransposeBackward0>)

In [39]:
# Attention scores: within every head, take the dot product of each query
# with each key. The last two key dimensions are swapped so the matrix
# product yields a (tokens, tokens) score matrix per head.
attn_scores = torch.matmul(queries_T, keys_T.transpose(-2, -1))
print(attn_scores.size())
attn_scores

torch.Size([2, 3, 5, 5])


tensor([[[[-1.3464e-01,  6.9784e-01,  2.1856e-01,  5.4249e-02, -2.3919e-01],
          [ 1.4673e-01, -4.2929e-01,  4.9914e-02,  3.5004e-02,  2.6066e-01],
          [-2.6200e-01,  1.5220e+00,  5.6800e-01,  1.5218e-01, -4.6546e-01],
          [ 1.2262e-01, -4.1868e-01, -1.0446e-02,  1.2211e-02,  2.1782e-01],
          [ 3.3008e-02,  4.4287e-02,  1.3374e-01,  4.7900e-02,  5.8627e-02]],

         [[-7.9567e-01, -8.0605e-02, -5.9559e-01, -4.6923e-01,  6.2499e-01],
          [ 3.5357e-01,  1.2900e-01,  1.0897e-01,  3.0735e-01, -4.9232e-01],
          [-4.9187e-01,  1.3913e-01, -6.8389e-01, -8.9642e-02, -4.8789e-02],
          [ 9.4263e-02,  2.9577e-02,  3.7099e-02,  7.6834e-02, -1.2016e-01],
          [ 3.5126e-01, -1.5351e-02,  3.4803e-01,  1.5313e-01, -1.5862e-01]],

         [[-3.7064e-01, -3.1444e-01,  3.8796e-01, -3.2246e-01,  3.5115e-02],
          [-2.9993e-01,  1.9692e-01,  8.8852e-02, -6.8451e-02, -1.9695e-03],
          [ 7.6120e-01, -6.7346e-01, -1.3887e-01,  9.9641e-02,  1.6692e-

In [40]:
# Causal mask: ones strictly above the diagonal mark the later ("future")
# positions that a token must not attend to.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask.size())
mask

torch.Size([5, 5])


tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

In [41]:
# Boolean version of the mask (True = position to be masked out), cropped to
# the sequence length in use.
mask_bool = mask[:context_length, :context_length].bool()
print(mask_bool.size())
mask_bool

torch.Size([5, 5])


tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [42]:
# Overwrite the masked (future) positions of the attention scores with -inf,
# in place; the 2-D mask is broadcast over the batch and head dimensions.
attn_scores.masked_fill_(mask_bool, float("-inf"))

tensor([[[[-1.3464e-01,        -inf,        -inf,        -inf,        -inf],
          [ 1.4673e-01, -4.2929e-01,        -inf,        -inf,        -inf],
          [-2.6200e-01,  1.5220e+00,  5.6800e-01,        -inf,        -inf],
          [ 1.2262e-01, -4.1868e-01, -1.0446e-02,  1.2211e-02,        -inf],
          [ 3.3008e-02,  4.4287e-02,  1.3374e-01,  4.7900e-02,  5.8627e-02]],

         [[-7.9567e-01,        -inf,        -inf,        -inf,        -inf],
          [ 3.5357e-01,  1.2900e-01,        -inf,        -inf,        -inf],
          [-4.9187e-01,  1.3913e-01, -6.8389e-01,        -inf,        -inf],
          [ 9.4263e-02,  2.9577e-02,  3.7099e-02,  7.6834e-02,        -inf],
          [ 3.5126e-01, -1.5351e-02,  3.4803e-01,  1.5313e-01, -1.5862e-01]],

         [[-3.7064e-01,        -inf,        -inf,        -inf,        -inf],
          [-2.9993e-01,  1.9692e-01,        -inf,        -inf,        -inf],
          [ 7.6120e-01, -6.7346e-01, -1.3887e-01,        -inf,        -i

In [43]:
# Final shape check: print the mask shape, then the attention-score shape.
for checked in (mask_bool, attn_scores):
    print(checked.size())

torch.Size([5, 5])
torch.Size([2, 3, 5, 5])
