In [8]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# This is the length of input sequence
sequence_length = 4

batch_size = 1

# This is the dimension of the input vector for every single word or token 
input_dimension = 512

# This is the output of attention block for every single word
d_model = 512

# This is the randomly sampled input data since no positional encoding is used here
x = torch.randn((batch_size, sequence_length, input_dimension))

In [3]:
x.size()

torch.Size([1, 4, 512])

In [4]:
# Creating the KEY, QUERY and VALUE vectors
qkv_layer = nn.Linear(input_dimension , 3 * d_model)
qkv = qkv_layer(x)
print(qkv.shape)

torch.Size([1, 4, 1536])


Now if number of heads is 8, it means simply 8 different sets of learnable parameters operate in parallel all of which get same input but produce different context aware vectors. Each head may learn to attend to different parts of the input, capturing various patterns or information.

In [5]:
# The multi head attention will have 8 heads. That means 8 different outputs will be generated from the attention block
# 8 different linear blocks will run simultaneously that do not share wieghts

num_heads = 8
head_dim = d_model // num_heads
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
print(qkv.shape)

torch.Size([1, 4, 8, 192])


In [6]:
qkv = qkv.permute(0, 2, 1, 3)
print(qkv.shape)

torch.Size([1, 8, 4, 192])


In [7]:
q, k, v = qkv.chunk(3, dim=-1)
print(f"QUERY Vector : {q.shape}")
print(f"KEY Vector : {k.shape}")
print(f"VALUE Vector : {v.shape}")

QUERY Vector : torch.Size([1, 8, 4, 64])
KEY Vector : torch.Size([1, 8, 4, 64])
VALUE Vector : torch.Size([1, 8, 4, 64])


In [9]:
# Performing dot product between query and key vectors and then scaling it with sqaure root of dimension of KEY vector

d_k = q.size()[-1]
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
print(scaled.shape)

torch.Size([1, 8, 4, 4])


In [11]:
mask = torch.full(scaled.size() , float('-inf'))
mask = torch.triu(mask, diagonal=1)

# mask for input to a single head
print(mask[0][1])

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])


In [12]:
scaled += mask

In [13]:
attention = F.softmax(scaled, dim=-1)

In [14]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5102, 0.4898, 0.0000, 0.0000],
        [0.4090, 0.3680, 0.2230, 0.0000],
        [0.2990, 0.2103, 0.3047, 0.1860]], grad_fn=<SelectBackward0>)

As we can see after applying sofmax, all negative infinity gets converted to zeroes. So, in decoder, prediction only happens by looking at previous words.

In [15]:
attention.shape

torch.Size([1, 8, 4, 4])

In [16]:
# This is the dot product between the output and VALUE vectors to get the final values which are now more context aware

values = torch.matmul(attention, v)
values.shape

torch.Size([1, 8, 4, 64])

From the shape of tensors, we can see for a single batch, we have 8 heads for input sequence length 4 and 64 size vectors for every single word

Now we will combine all the attention heads since at last we are gonna be needing the same shape as input vector. This concatenation can be said that all attetnion heads will communicate with each other to produce final context aware vector.

In [17]:
values = values.reshape(batch_size, sequence_length, num_heads * head_dim)
print(values.size())

torch.Size([1, 4, 512])


In [18]:
linear_layer = nn.Linear(d_model, d_model)

In [19]:
output = linear_layer(values)

In [20]:
output.shape

torch.Size([1, 4, 512])

In [21]:
output

tensor([[[-0.0051,  0.0204, -0.1722,  ...,  0.0487,  0.1819,  0.1497],
         [ 0.0182, -0.4404,  0.0071,  ...,  0.0053,  0.2802, -0.6434],
         [-0.1348,  0.1406, -0.1483,  ...,  0.1058, -0.4081,  0.1045],
         [ 0.1432, -0.0169,  0.0990,  ...,  0.1788,  0.0016,  0.2969]]],
       grad_fn=<ViewBackward0>)

This is the final vector which is now more context aware than what it originally was.