<a href="https://colab.research.google.com/github/shivankgoel/TranformersFromScratch/blob/main/2_multihead_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

def display_attention_matrix(attention_weights):
  """Show a 2-D attention matrix as an annotated heatmap.

  The matrix is first echoed to stdout, then drawn with one cell per
  (query, key) pair: rows are labelled as query tokens, columns as key
  tokens, and every cell is overlaid with its value to two decimals.

  Args:
    attention_weights: 2-D array-like exposing ``.shape`` and ``[i, j]``
      indexing (e.g. a NumPy array).
  """
  print("Plotting {}".format(attention_weights))

  row_count = attention_weights.shape[0]

  # Base heatmap with axis labelling.
  plt.figure(figsize=(8, 6))
  plt.imshow(attention_weights, cmap='viridis', interpolation='nearest')
  plt.colorbar(label='Attention Weight')
  plt.title('Attention Heatmap')
  plt.xlabel('Key Tokens')
  plt.ylabel('Query Tokens')

  # Overlay each cell's numeric value; x is the column (key), y the row (query).
  for query_idx in range(row_count):
    for key_idx in range(attention_weights.shape[1]):
      cell_label = f'{attention_weights[query_idx, key_idx]:.2f}'
      plt.text(key_idx, query_idx, cell_label, ha='center', va='center', color='white')

  plt.tight_layout()
  plt.show()

def softmax(x, axis=1):
  """Return the softmax of ``x`` along ``axis``.

  The maximum along ``axis`` is subtracted before exponentiating. This
  leaves the result mathematically unchanged but keeps ``np.exp`` from
  overflowing to ``inf`` (and the ratio from becoming ``nan``) when the
  inputs are large.

  Args:
    x: Array-like of scores.
    axis: Axis along which the outputs sum to 1. Defaults to 1, i.e. each
      row of a 2-D matrix is normalised independently.

  Returns:
    A NumPy array with the same shape as ``x``.
  """
  x = np.asarray(x)
  shifted = x - np.max(x, axis=axis, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=axis, keepdims=True)

In [12]:
# Toy dimensions: one sequence of 4 tokens, 512 input features per token,
# projected to 512-wide queries, keys and values.
batch_sz = 1
voc_sz = 4
input_sz = 512
qkv_sz = 512

# The last dimension must be input_sz (not qkv_sz): it is what the linear
# layer below consumes. The two only coincide because both are 512 here.
input = torch.randn((batch_sz, voc_sz, input_sz))
print("Input size", input.size())

# One fused projection produces Q, K and V together, hence 3 * qkv_sz outputs.
qkv_layer = nn.Linear(input_sz, 3 * qkv_sz)
qkv = qkv_layer(input)
print("QKV size", qkv.size())

Input size torch.Size([1, 4, 512])
QKV size torch.Size([1, 4, 1536])


In [11]:
# Split the fused QKV projection into per-head query, key and value tensors.
num_heads = 8
qkv_sz_per_head = qkv_sz // num_heads

# Carve the feature axis into (heads, 3 * per-head width).
qkv = qkv.reshape(batch_sz, voc_sz, num_heads, 3 * qkv_sz_per_head)
print(qkv.shape)

# Bring the head axis in front of the token axis so each head is a batch entry.
qkv = qkv.transpose(1, 2)
print(qkv.shape)

# The last axis holds [q | k | v] for one head; cut it into three equal parts.
q, k, v = torch.chunk(qkv, 3, dim=-1)
print(q.shape, k.shape, v.shape)

torch.Size([1, 4, 8, 192])
torch.Size([1, 8, 4, 192])
torch.Size([1, 8, 4, 64]) torch.Size([1, 8, 4, 64]) torch.Size([1, 8, 4, 64])


In [25]:
# Scaled dot-product scores between every query and every key, per head.
attentions = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(qkv_sz_per_head)
print(attentions.size())

# Causal mask: -inf strictly above the diagonal, 0 elsewhere, so a query
# cannot attend to keys at later positions.
mask  = torch.full(attentions.size(), float('-inf'))
mask = torch.triu(mask, diagonal=1)
print(mask[0][0])

masked_attentions = attentions + mask
print(masked_attentions.size())

# Turn the masked scores into probabilities over the keys. Without this step
# the -inf entries would flow straight into the matmul and produce -inf/NaN
# outputs instead of a weighted average of the values.
attention_weights = F.softmax(masked_attentions, dim=-1)
attention_weighted_values =  torch.matmul(attention_weights, v)
print(attention_weighted_values.size())

# Move the head axis back next to the feature axis, then merge the two so
# each token ends up with one qkv_sz-wide vector again.
attention_weighted_values = attention_weighted_values.permute(0, 2, 1, 3)
print(attention_weighted_values.size())
attention_weighted_values = attention_weighted_values.reshape(batch_sz, voc_sz, qkv_sz)
print(attention_weighted_values.size())

torch.Size([1, 8, 4, 4])
tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])
torch.Size([1, 8, 4, 4])
torch.Size([1, 8, 4, 64])
torch.Size([1, 4, 8, 64])
torch.Size([1, 4, 512])


In [30]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  A single linear layer projects the input to fused queries, keys and
  values, which are split across ``num_heads`` heads. Each head computes
  softmax-normalised scaled dot-product attention; the heads are then
  concatenated and passed through an output linear layer.

  Args:
    input_sz: Size of the last dimension of the input.
    qkv_sz: Total width of the queries/keys/values across all heads.
      Must be divisible by ``num_heads``.
    num_heads: Number of attention heads.

  Raises:
    ValueError: If ``qkv_sz`` is not divisible by ``num_heads``.
  """

  def __init__(self, input_sz, qkv_sz, num_heads):
    super(MultiHeadAttention, self).__init__()
    if qkv_sz % num_heads != 0:
      raise ValueError(
        "qkv_sz ({}) must be divisible by num_heads ({})".format(qkv_sz, num_heads))
    self.input_sz = input_sz
    self.qkv_sz = qkv_sz
    self.num_heads = num_heads
    self.qkv_sz_per_head = qkv_sz // num_heads

    # One fused projection yields Q, K and V together.
    self.qkv_layer = nn.Linear(input_sz, 3 * qkv_sz)
    self.out_layer = nn.Linear(qkv_sz, qkv_sz)


  def forward(self, input, mask = False):
    """Apply self-attention to ``input``.

    Intermediate tensor sizes are printed at each step.

    Args:
      input: Tensor of size (batch, tokens, input_sz).
      mask: When true, apply a causal mask so that each position only
        attends to itself and earlier positions.

    Returns:
      Tensor of size (batch, tokens, qkv_sz).
    """
    batch_sz, voc_sz, _ = input.size()
    print("Input size {}".format(input.size()))
    qkv = self.qkv_layer(input)
    print("QKV size {}".format(qkv.size()))
    qkv = qkv.reshape(batch_sz, voc_sz, self.num_heads, 3 * self.qkv_sz_per_head)
    print("QKV size after heads division {}".format(qkv.size()))
    qkv = qkv.permute(0, 2, 1, 3)
    print("QKV size after permuting {}".format(qkv.size()))
    q, k, v = qkv.chunk(3, dim=-1)
    print("Q size {}, K size {}, V size {}".format(q.size(), k.size(), v.size()))
    attentions = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.qkv_sz_per_head)
    print("Attentions size {}".format(attentions.size()))
    if mask:
      # -inf strictly above the diagonal, 0 elsewhere. Built on the scores'
      # device/dtype and broadcast over the batch and head dimensions.
      causal_mask = torch.full(
        (voc_sz, voc_sz), float('-inf'),
        dtype=attentions.dtype, device=attentions.device)
      causal_mask = torch.triu(causal_mask, diagonal=1)
      attentions = attentions + causal_mask
    # Normalise the scores over the keys; masked (-inf) entries get weight 0.
    attention_weights = F.softmax(attentions, dim=-1)
    attention_weighted_values =  torch.matmul(attention_weights, v)
    print("Attention weighted values size {}".format(attention_weighted_values.size()))
    attention_weighted_values = attention_weighted_values.permute(0, 2, 1, 3)
    print("Attention weighted values size after permuting {}".format(attention_weighted_values.size()))
    attention_weighted_values = attention_weighted_values.reshape(batch_sz, voc_sz, self.qkv_sz)
    print("Attention weighted values size after reshaping {}".format(attention_weighted_values.size()))
    out = self.out_layer(attention_weighted_values)
    print("Output size after linear layer {}".format(out.size()))
    return out


# Build the layer from the dimensions chosen earlier in the notebook and run
# one forward pass with the causal mask enabled.
multihead_attention = MultiHeadAttention(
  input_sz=input_sz,
  qkv_sz=qkv_sz,
  num_heads=num_heads,
)
out = multihead_attention(input, mask=True)

Input size torch.Size([1, 4, 512])
QKV size torch.Size([1, 4, 1536])
QKV size after heads division torch.Size([1, 4, 8, 192])
QKV size after permuting torch.Size([1, 8, 4, 192])
Q size torch.Size([1, 8, 4, 64]), K size torch.Size([1, 8, 4, 64]), V size torch.Size([1, 8, 4, 64])
Attentions size torch.Size([1, 8, 4, 4])
Attention weighted values size torch.Size([1, 8, 4, 64])
Attention weighted values size after permuting torch.Size([1, 4, 8, 64])
Attention weighted values size after reshaping torch.Size([1, 4, 512])
Output size after linear layer torch.Size([1, 4, 512])
