<a href="https://colab.research.google.com/github/sobit-nep/Transformer-Neural-Network-from-scratch/blob/main/Multihead_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Multihead attention**

In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
# Fix the global RNG seed so the random input below (and the weights of any
# layer initialised in later cells) are identical on every Restart & Run All.
torch.manual_seed(0)

sequence_length = 4   # number of tokens in the toy sequence
batch_size = 1
input_dim = 512       # feature size of each input token
d_model = 512         # model width; divided across the attention heads later on
x = torch.randn((batch_size, sequence_length, input_dim))

In [4]:
x.shape  # (batch_size, sequence_length, input_dim)

torch.Size([1, 4, 512])

In [5]:
# A single linear layer whose output is three times the model width, so one
# projection yields the query, key and value features together.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [6]:
# Project every token through the combined layer; the last axis now holds
# 3 * d_model features per token.
qkv = qkv_layer(x)

In [8]:
qkv.shape  # (batch_size, sequence_length, 3 * d_model)

torch.Size([1, 4, 1536])

In [9]:
num_heads = 8
head_dim = d_model // num_heads  # features per head
# Reshape the fresh projection output rather than rebinding `qkv` from itself:
# once a later cell has permuted `qkv`, re-running a self-referential reshape
# would silently reinterpret the permuted layout and mix heads with positions.
qkv = qkv_layer(x).reshape(batch_size, sequence_length, num_heads, 3 * head_dim)

In [10]:
qkv.shape  # (batch_size, sequence_length, num_heads, 3 * head_dim)

torch.Size([1, 4, 8, 192])

In [11]:
# Move the head axis ahead of the sequence axis: (batch, heads, seq, 3*head_dim).
# Built from the projection output so the cell is idempotent; permuting `qkv`
# itself would swap the axes back if the cell were run a second time.
qkv = (
    qkv_layer(x)
    .reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
    .permute(0, 2, 1, 3)
)
qkv.shape

torch.Size([1, 8, 4, 192])

In [13]:
# Cut the last axis into three equal slices: query, key and value per head.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
tuple(part.shape for part in (q, k, v))

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

# **SELF ATTENTION FOR MULTIPLE HEADS**

In [16]:
# Query-key dot products for every head, divided by sqrt(d_k) where d_k is the
# size of the last axis of q.
d_k = q.shape[-1]
key_t = k.transpose(-1, -2)
scaled = (q @ key_t) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [17]:
# Additive causal mask with the same shape as the scores: -inf strictly above
# the diagonal of the last two axes, 0 on and below it.
mask = torch.triu(torch.full(scaled.shape, float('-inf')), diagonal=1)
mask[0, 1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [19]:
# Preview the masked scores of the first head without modifying `scaled`.
(scaled + mask)[0, 0]

tensor([[-0.0077,    -inf,    -inf,    -inf],
        [ 0.3348, -0.0940,    -inf,    -inf],
        [ 0.1097, -0.0081, -0.3745,    -inf],
        [-0.2475,  0.0282, -0.0200,  0.0883]], grad_fn=<SelectBackward0>)

In [20]:
# Out-of-place add: rebinding `scaled` to a new tensor avoids mutating, in
# place, a tensor that is part of the autograd graph and that earlier cells
# have already displayed.
scaled = scaled + mask

In [21]:
# Normalise each row of scores over the last axis; positions holding -inf
# receive zero weight.
attention = torch.softmax(scaled, dim=-1)

In [22]:
attention.shape  # same shape as `scaled`: one weight matrix per head

torch.Size([1, 8, 4, 4])

In [23]:
# Attention weights of the first head of the first batch element.
attention[0, 0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.6056, 0.3944, 0.0000, 0.0000],
        [0.3992, 0.3548, 0.2460, 0.0000],
        [0.2011, 0.2650, 0.2525, 0.2814]], grad_fn=<SelectBackward0>)

In [25]:
# Weighted sum of the value vectors using the attention weights.
values = attention @ v
values.shape

torch.Size([1, 8, 4, 64])

# FUNCTION
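The attention steps from the cells above (scaling the query-key scores, adding the optional mask, applying softmax, and weighting the values) are combined into a single function.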

In [26]:
import math
def scaled_dot_product(q, k, v, mask=None):
  """Compute scaled dot-product attention.

  Args:
    q: query tensor; its last axis is the per-head feature size d_k.
    k: key tensor; its last two axes are transposed before the matmul with q.
    v: value tensor, multiplied by the attention weights.
    mask: optional additive mask that must broadcast against the score
      tensor produced by q @ k^T. Use -inf where attention is forbidden and
      0 elsewhere. Masking is needed for the decoder only, so the default is
      None (no masking).

  Returns:
    A (values, attention) tuple: the attention-weighted values and the
    softmax-normalised attention weights.
  """
  d_k = q.size()[-1]
  scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
  if mask is not None:
    # Out-of-place add so that a smaller mask (e.g. seq x seq) can broadcast
    # and no tensor in the autograd graph is modified in place.
    scaled = scaled + mask
  attention = F.softmax(scaled, dim=-1)
  values = torch.matmul(attention, v)
  return values, attention

In [30]:
# Recompute the attention through the function, passing the causal mask
# explicitly.
values, attention = scaled_dot_product(q, k, v, mask=mask)

In [31]:
attention.shape  # one attention-weight matrix per head

torch.Size([1, 8, 4, 4])

In [32]:
# Attention weights of the first head, as returned by the function.
attention[0, 0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.6056, 0.3944, 0.0000, 0.0000],
        [0.3992, 0.3548, 0.2460, 0.0000],
        [0.2011, 0.2650, 0.2525, 0.2814]], grad_fn=<SelectBackward0>)