<a href="https://colab.research.google.com/github/teelch0/Data-Mining/blob/main/Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
import torch
import torch.nn as nn

In [86]:
# Multi-head attention, wrapper variant:
# several independent causal-attention heads evaluated on the same input.
# NOTE: a later cell defines another class with the same name, which replaces
# this one once that cell has run.
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` ``CausalAttention`` modules and concatenate their outputs.

    Every head is constructed with the same
    ``(d_in, d_out, context_length, dropout, qkv_bias)`` arguments.
    ``CausalAttention`` is not defined in this part of the notebook; it must
    already be in scope when this class is instantiated.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        # ModuleList registers every head as a submodule of this module.
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Apply every head to ``x`` and join the results along the last dimension."""
        per_head = [attend(x) for attend in self.heads]
        return torch.cat(per_head, dim=-1)

In [87]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one shared projection per Q/K/V.

    The ``d_out``-wide query/key/value projections are split into
    ``num_heads`` heads of ``head_dim = d_out // num_heads`` features each.
    Attention is computed per head under a causal (upper-triangular) mask,
    then the heads are concatenated and passed through ``out_proj``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the causal mask buffer, i.e. the
            longest sequence ``forward`` accepts.
        dropout: Probability handed to ``nn.Dropout``; applied to the
            attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections use a bias term.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a bad configuration through.
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions that a
        # query must not attend to. Registered as a buffer so it follows the
        # module across devices and is stored in the state dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Compute causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape

        # Without this check a too-long input only fails later, inside
        # `masked_fill_`, with a shape-broadcast error that hides the cause.
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {max_tokens}"
            )

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Truncate the stored mask to the current sequence length first, then
        # convert only that slice to boolean.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()

        # Masked positions become -inf so softmax assigns them zero weight.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim.
        # `.contiguous()` is needed because `view` cannot reinterpret the
        # transposed (non-contiguous) tensor directly.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

In [88]:
# Toy sizes for the step-by-step walkthrough in the cells below:
# context_length sizes the token dimension and the causal mask,
# d_in is the input feature width of the linear projections.
context_length, d_in = 5, 10

In [89]:
# Start with creating batches.
# A freshly constructed embedding table (5 rows x 10 features) serves as toy
# input; its weights are randomly initialised, so values differ between runs.
# `.detach()` is used instead of `.data`: both give a tensor without gradient
# history, but `.detach()` keeps autograd's in-place modification checks.
inputs = torch.nn.Embedding(5, 10).weight.detach()
inputs

tensor([[ 0.5884, -0.4747, -0.1274, -1.2693,  1.0388, -0.0463,  0.4765,  0.3781,
          0.3754,  1.2586],
        [-0.5737,  0.5784,  0.1409,  0.1539, -2.0708, -0.8233,  0.1262,  0.2651,
         -0.2559,  0.1415],
        [-0.7087,  0.2665, -0.1157,  0.7769, -0.3563, -0.4336, -0.9801,  0.0788,
          0.0539,  0.1857],
        [-0.7771,  0.0708, -0.4521, -0.8356,  0.1321, -0.9660, -0.2906,  0.5261,
         -1.4621,  0.1175],
        [ 1.5643,  0.9243, -0.7974, -1.3218, -0.0785, -0.7778,  0.1165,  0.3465,
          0.0206, -0.6479]])

In [90]:
# Form a batch of two identical sequences by stacking `inputs` with itself
# along a new leading dimension (torch.stack inserts at dim 0 by default).
batches = torch.stack([inputs, inputs])
print(batches.size())
batches

torch.Size([2, 5, 10])


tensor([[[ 0.5884, -0.4747, -0.1274, -1.2693,  1.0388, -0.0463,  0.4765,
           0.3781,  0.3754,  1.2586],
         [-0.5737,  0.5784,  0.1409,  0.1539, -2.0708, -0.8233,  0.1262,
           0.2651, -0.2559,  0.1415],
         [-0.7087,  0.2665, -0.1157,  0.7769, -0.3563, -0.4336, -0.9801,
           0.0788,  0.0539,  0.1857],
         [-0.7771,  0.0708, -0.4521, -0.8356,  0.1321, -0.9660, -0.2906,
           0.5261, -1.4621,  0.1175],
         [ 1.5643,  0.9243, -0.7974, -1.3218, -0.0785, -0.7778,  0.1165,
           0.3465,  0.0206, -0.6479]],

        [[ 0.5884, -0.4747, -0.1274, -1.2693,  1.0388, -0.0463,  0.4765,
           0.3781,  0.3754,  1.2586],
         [-0.5737,  0.5784,  0.1409,  0.1539, -2.0708, -0.8233,  0.1262,
           0.2651, -0.2559,  0.1415],
         [-0.7087,  0.2665, -0.1157,  0.7769, -0.3563, -0.4336, -0.9801,
           0.0788,  0.0539,  0.1857],
         [-0.7771,  0.0708, -0.4521, -0.8356,  0.1321, -0.9660, -0.2906,
           0.5261, -1.4621,  0.1175],

In [91]:
d_out = 6       # total projection width across all heads
num_heads = 3
# Derived rather than hard-coded so the three values cannot drift apart;
# the later `.view(b, context_length, num_heads, head_dim)` calls need
# num_heads * head_dim == d_out.
head_dim = d_out // num_heads
b = 2           # batch size used in the `.view` calls below

In [92]:
# Bias-free linear projections from d_in input features to d_out features,
# one for the queries and one for the keys.
W_query = nn.Linear(in_features=d_in, out_features=d_out, bias=False)
W_key = nn.Linear(in_features=d_in, out_features=d_out, bias=False)

In [93]:
# Project the batch into key space; the last dimension becomes d_out.
keys = W_key(batches)
print(keys.size())
keys

torch.Size([2, 5, 6])


tensor([[[ 0.1925,  0.5577,  0.7312,  0.4568, -0.1857, -0.0493],
         [-0.2258, -0.0152, -0.5871, -0.2813, -0.2708,  0.9944],
         [ 0.1086, -0.1215, -0.2574, -0.0257, -0.0733,  0.1241],
         [-0.3641,  0.7743, -0.7566,  0.1291, -0.5361,  0.9804],
         [-0.0889, -0.0881, -0.2959,  0.3891, -0.2442,  0.3458]],

        [[ 0.1925,  0.5577,  0.7312,  0.4568, -0.1857, -0.0493],
         [-0.2258, -0.0152, -0.5871, -0.2813, -0.2708,  0.9944],
         [ 0.1086, -0.1215, -0.2574, -0.0257, -0.0733,  0.1241],
         [-0.3641,  0.7743, -0.7566,  0.1291, -0.5361,  0.9804],
         [-0.0889, -0.0881, -0.2959,  0.3891, -0.2442,  0.3458]]],
       grad_fn=<UnsafeViewBackward0>)

In [94]:
# Project the batch into query space; the last dimension becomes d_out.
queries = W_query(batches)
print(queries.size())
queries

torch.Size([2, 5, 6])


tensor([[[-0.2630,  0.3940,  0.2773,  0.1226, -0.4435,  0.1881],
         [ 0.4408, -0.0268,  0.4715,  0.5891, -0.3121, -0.7911],
         [ 0.2078,  0.0412,  0.2971, -0.0481,  0.2002, -0.0091],
         [ 0.5877,  0.0107,  0.0603,  0.6680,  0.2824, -0.7447],
         [-0.2170,  0.1072, -0.1667,  0.0892, -0.2095, -0.8082]],

        [[-0.2630,  0.3940,  0.2773,  0.1226, -0.4435,  0.1881],
         [ 0.4408, -0.0268,  0.4715,  0.5891, -0.3121, -0.7911],
         [ 0.2078,  0.0412,  0.2971, -0.0481,  0.2002, -0.0091],
         [ 0.5877,  0.0107,  0.0603,  0.6680,  0.2824, -0.7447],
         [-0.2170,  0.1072, -0.1667,  0.0892, -0.2095, -0.8082]]],
       grad_fn=<UnsafeViewBackward0>)

In [95]:
# Split the last dimension into heads:
# (b, context_length, d_out) -> (b, context_length, num_heads, head_dim)
keys = keys.view((b, context_length, num_heads, head_dim))
print(keys.size())
keys

torch.Size([2, 5, 3, 2])


tensor([[[[ 0.1925,  0.5577],
          [ 0.7312,  0.4568],
          [-0.1857, -0.0493]],

         [[-0.2258, -0.0152],
          [-0.5871, -0.2813],
          [-0.2708,  0.9944]],

         [[ 0.1086, -0.1215],
          [-0.2574, -0.0257],
          [-0.0733,  0.1241]],

         [[-0.3641,  0.7743],
          [-0.7566,  0.1291],
          [-0.5361,  0.9804]],

         [[-0.0889, -0.0881],
          [-0.2959,  0.3891],
          [-0.2442,  0.3458]]],


        [[[ 0.1925,  0.5577],
          [ 0.7312,  0.4568],
          [-0.1857, -0.0493]],

         [[-0.2258, -0.0152],
          [-0.5871, -0.2813],
          [-0.2708,  0.9944]],

         [[ 0.1086, -0.1215],
          [-0.2574, -0.0257],
          [-0.0733,  0.1241]],

         [[-0.3641,  0.7743],
          [-0.7566,  0.1291],
          [-0.5361,  0.9804]],

         [[-0.0889, -0.0881],
          [-0.2959,  0.3891],
          [-0.2442,  0.3458]]]], grad_fn=<ViewBackward0>)

In [96]:
# Split the last dimension into heads:
# (b, context_length, d_out) -> (b, context_length, num_heads, head_dim)
queries = queries.view((b, context_length, num_heads, head_dim))
print(queries.size())
queries

torch.Size([2, 5, 3, 2])


tensor([[[[-0.2630,  0.3940],
          [ 0.2773,  0.1226],
          [-0.4435,  0.1881]],

         [[ 0.4408, -0.0268],
          [ 0.4715,  0.5891],
          [-0.3121, -0.7911]],

         [[ 0.2078,  0.0412],
          [ 0.2971, -0.0481],
          [ 0.2002, -0.0091]],

         [[ 0.5877,  0.0107],
          [ 0.0603,  0.6680],
          [ 0.2824, -0.7447]],

         [[-0.2170,  0.1072],
          [-0.1667,  0.0892],
          [-0.2095, -0.8082]]],


        [[[-0.2630,  0.3940],
          [ 0.2773,  0.1226],
          [-0.4435,  0.1881]],

         [[ 0.4408, -0.0268],
          [ 0.4715,  0.5891],
          [-0.3121, -0.7911]],

         [[ 0.2078,  0.0412],
          [ 0.2971, -0.0481],
          [ 0.2002, -0.0091]],

         [[ 0.5877,  0.0107],
          [ 0.0603,  0.6680],
          [ 0.2824, -0.7447]],

         [[-0.2170,  0.1072],
          [-0.1667,  0.0892],
          [-0.2095, -0.8082]]]], grad_fn=<ViewBackward0>)

In [97]:
# Swap the token and head axes so each head holds its own token sequence:
# (b, context_length, num_heads, head_dim) -> (b, num_heads, context_length, head_dim)
keys_T = torch.transpose(keys, 1, 2)
print(keys_T.size())
keys_T

torch.Size([2, 3, 5, 2])


tensor([[[[ 0.1925,  0.5577],
          [-0.2258, -0.0152],
          [ 0.1086, -0.1215],
          [-0.3641,  0.7743],
          [-0.0889, -0.0881]],

         [[ 0.7312,  0.4568],
          [-0.5871, -0.2813],
          [-0.2574, -0.0257],
          [-0.7566,  0.1291],
          [-0.2959,  0.3891]],

         [[-0.1857, -0.0493],
          [-0.2708,  0.9944],
          [-0.0733,  0.1241],
          [-0.5361,  0.9804],
          [-0.2442,  0.3458]]],


        [[[ 0.1925,  0.5577],
          [-0.2258, -0.0152],
          [ 0.1086, -0.1215],
          [-0.3641,  0.7743],
          [-0.0889, -0.0881]],

         [[ 0.7312,  0.4568],
          [-0.5871, -0.2813],
          [-0.2574, -0.0257],
          [-0.7566,  0.1291],
          [-0.2959,  0.3891]],

         [[-0.1857, -0.0493],
          [-0.2708,  0.9944],
          [-0.0733,  0.1241],
          [-0.5361,  0.9804],
          [-0.2442,  0.3458]]]], grad_fn=<TransposeBackward0>)

In [98]:
# Swap the token and head axes so each head holds its own token sequence:
# (b, context_length, num_heads, head_dim) -> (b, num_heads, context_length, head_dim)
queries_T = torch.transpose(queries, 1, 2)
print(queries_T.size())
queries_T

torch.Size([2, 3, 5, 2])


tensor([[[[-0.2630,  0.3940],
          [ 0.4408, -0.0268],
          [ 0.2078,  0.0412],
          [ 0.5877,  0.0107],
          [-0.2170,  0.1072]],

         [[ 0.2773,  0.1226],
          [ 0.4715,  0.5891],
          [ 0.2971, -0.0481],
          [ 0.0603,  0.6680],
          [-0.1667,  0.0892]],

         [[-0.4435,  0.1881],
          [-0.3121, -0.7911],
          [ 0.2002, -0.0091],
          [ 0.2824, -0.7447],
          [-0.2095, -0.8082]]],


        [[[-0.2630,  0.3940],
          [ 0.4408, -0.0268],
          [ 0.2078,  0.0412],
          [ 0.5877,  0.0107],
          [-0.2170,  0.1072]],

         [[ 0.2773,  0.1226],
          [ 0.4715,  0.5891],
          [ 0.2971, -0.0481],
          [ 0.0603,  0.6680],
          [-0.1667,  0.0892]],

         [[-0.4435,  0.1881],
          [-0.3121, -0.7911],
          [ 0.2002, -0.0091],
          [ 0.2824, -0.7447],
          [-0.2095, -0.8082]]]], grad_fn=<TransposeBackward0>)

In [99]:
# Per-head dot products between every query and every key: transposing the
# last two axes of the keys yields a (context_length x context_length) score
# matrix for each batch element and head.
attn_scores = torch.matmul(queries_T, keys_T.transpose(-2, -1))
print(attn_scores.size())
attn_scores

torch.Size([2, 3, 5, 5])


tensor([[[[ 0.1691,  0.0534, -0.0764,  0.4009, -0.0113],
          [ 0.0699, -0.0991,  0.0511, -0.1812, -0.0368],
          [ 0.0630, -0.0476,  0.0176, -0.0438, -0.0221],
          [ 0.1191, -0.1329,  0.0625, -0.2058, -0.0532],
          [ 0.0180,  0.0474, -0.0366,  0.1620,  0.0098]],

         [[ 0.2588, -0.1973, -0.0745, -0.1940, -0.0344],
          [ 0.6138, -0.4425, -0.1365, -0.2806,  0.0897],
          [ 0.1953, -0.1609, -0.0752, -0.2310, -0.1066],
          [ 0.3492, -0.2233, -0.0327,  0.0407,  0.2421],
          [-0.0812,  0.0728,  0.0406,  0.1377,  0.0841]],

         [[ 0.0731,  0.3071,  0.0559,  0.4222,  0.1734],
          [ 0.0969, -0.7022, -0.0753, -0.6083, -0.1973],
          [-0.0367, -0.0633, -0.0158, -0.1163, -0.0521],
          [-0.0158, -0.8170, -0.1131, -0.8815, -0.3265],
          [ 0.0787, -0.7470, -0.0849, -0.6801, -0.2283]]],


        [[[ 0.1691,  0.0534, -0.0764,  0.4009, -0.0113],
          [ 0.0699, -0.0991,  0.0511, -0.1812, -0.0368],
          [ 0.0630, -0.

In [100]:
# Causal mask: ones strictly above the main diagonal, zeros elsewhere.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask.size())
mask

torch.Size([5, 5])


tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

In [101]:
# Boolean version of the mask, cut to the sequence length in use
# (True marks the positions that will be hidden).
mask_bool = mask[:context_length, :context_length].bool()
print(mask_bool.size())
mask_bool

torch.Size([5, 5])


tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [102]:
# In-place: overwrite every masked score with -inf. The 2-D mask is
# broadcast across the leading batch and head dimensions of attn_scores.
attn_scores.masked_fill_(mask_bool, float("-inf"))

tensor([[[[ 0.1691,    -inf,    -inf,    -inf,    -inf],
          [ 0.0699, -0.0991,    -inf,    -inf,    -inf],
          [ 0.0630, -0.0476,  0.0176,    -inf,    -inf],
          [ 0.1191, -0.1329,  0.0625, -0.2058,    -inf],
          [ 0.0180,  0.0474, -0.0366,  0.1620,  0.0098]],

         [[ 0.2588,    -inf,    -inf,    -inf,    -inf],
          [ 0.6138, -0.4425,    -inf,    -inf,    -inf],
          [ 0.1953, -0.1609, -0.0752,    -inf,    -inf],
          [ 0.3492, -0.2233, -0.0327,  0.0407,    -inf],
          [-0.0812,  0.0728,  0.0406,  0.1377,  0.0841]],

         [[ 0.0731,    -inf,    -inf,    -inf,    -inf],
          [ 0.0969, -0.7022,    -inf,    -inf,    -inf],
          [-0.0367, -0.0633, -0.0158,    -inf,    -inf],
          [-0.0158, -0.8170, -0.1131, -0.8815,    -inf],
          [ 0.0787, -0.7470, -0.0849, -0.6801, -0.2283]]],


        [[[ 0.1691,    -inf,    -inf,    -inf,    -inf],
          [ 0.0699, -0.0991,    -inf,    -inf,    -inf],
          [ 0.0630, -0.

In [103]:
# Show the mask shape next to the score shape, one per line.
print(mask_bool.shape, attn_scores.shape, sep="\n")

torch.Size([5, 5])
torch.Size([2, 3, 5, 5])
