In [1]:
import torch

In [2]:
# Toy input sequence: six tokens (x1..x6), one 3-dimensional embedding per row.
input_embeddings = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x1
        [0.55, 0.87, 0.66],  # x2
        [0.57, 0.85, 0.64],  # x3
        [0.22, 0.58, 0.33],  # x4
        [0.77, 0.25, 0.10],  # x5
        [0.05, 0.80, 0.55],  # x6
    ]
)

## **Self-Attention without trainable weights**

### Let's first apply self-attention to a single input embedding (x2)

In [3]:
# x2 is the second row (index 1); it plays the role of the query token below.
query = input_embeddings[1, :]
query

tensor([0.5500, 0.8700, 0.6600])

In [4]:
# Attention score of each token with respect to the query: the dot product of
# the token's embedding with the query embedding.
attention_scores = torch.empty(input_embeddings.shape[0])

# The loop variable is deliberately not named `input`: at notebook top level that
# would overwrite the builtin `input()` for every later cell.
for i, embedding in enumerate(input_embeddings):
  attention_scores[i] = torch.dot(embedding, query)

print(attention_scores)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [5]:
# Alternative way: one matrix-vector product yields all six dot products at once.
# `query` is 1-D, so no transpose is needed; `.T` on a tensor that is not 2-D is
# deprecated in PyTorch and emits a UserWarning.
input_embeddings @ query

  input_embeddings @ query.T


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [6]:
# Normalize the attention scores with softmax (for interpretability and stable
# training); the normalized values are the attention weights.
attention_weights = attention_scores.softmax(dim=-1)
attention_weights

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [7]:
# Context vector z2 for x2: the attention-weighted sum of all input embeddings.
context_vector_z2 = torch.matmul(attention_weights, input_embeddings)
context_vector_z2

tensor([0.4419, 0.6515, 0.5683])

In [8]:
# Equivalent to the matrix product in the previous cell: accumulate the weighted
# embeddings one token at a time.
context_vector_z2 = torch.zeros(query.shape)
# Pair each weight with its embedding directly; the loop variable is not named
# `input` so the builtin `input()` stays usable in later cells.
for weight, embedding in zip(attention_weights, input_embeddings):
  context_vector_z2 += embedding * weight

context_vector_z2

tensor([0.4419, 0.6515, 0.5683])

### Self-attention to all elements in the input sequence

In [9]:
# Pairwise dot products between all tokens: row i holds the attention scores of
# token i (as the query) against every token in the sequence.
attention_scores_all = torch.matmul(input_embeddings, input_embeddings.T)

attention_scores_all

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [10]:
# Softmax over the last dimension, so every row of weights sums to 1.
attention_weights_all = attention_scores_all.softmax(dim=-1)

attention_weights_all

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [11]:
# Compare the individually computed and the batched attention weights for x2.
print('Attention weights for x2 (individual calculation): ', attention_weights)
print('Attention weights for x2 (batch calculation): ', attention_weights_all[1])
# Check the equivalence programmatically instead of by eye; this also fails
# loudly if `attention_weights` was overwritten by a cell run out of order.
assert torch.allclose(attention_weights, attention_weights_all[1])

Attention weights for x2 (individual calculation):  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Attention weights for x2 (batch calculation):  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


As can be seen, both attention-weight tensors for x2 are the same.

In [12]:
# Context vectors for every token at once: each row is the attention-weighted
# sum of all input embeddings for that token.
context_vectors_all = torch.matmul(attention_weights_all, input_embeddings)
context_vectors_all

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [13]:
# Compare the individually computed and the batched context vector for x2.
print("Context vector z2 (individual calculation): ", context_vector_z2)
print("Context vector z2 (batch calculation): ", context_vectors_all[1])
# Check the equivalence programmatically instead of by eye; this also fails
# loudly if `context_vector_z2` was overwritten by a cell run out of order.
assert torch.allclose(context_vector_z2, context_vectors_all[1])

Context vector z2 (individual calculation):  tensor([0.4419, 0.6515, 0.5683])
Context vector z2 (batch calculation):  tensor([0.4419, 0.6515, 0.5683])


## **Self-Attention with trainable weights**

### For single input element x2

In [14]:
# Create the projection matrices for queries, keys and values.
x2 = input_embeddings[1]
dim_in, dim_out = input_embeddings.shape[-1], 2

torch.manual_seed(123)
# The generator is consumed left to right, so the three matrices are drawn from
# the RNG in the order query -> keys -> values. requires_grad=False because the
# weights are only used for illustration here, not trained.
w_query, w_keys, w_values = (
    torch.nn.Parameter(torch.rand(dim_in, dim_out), requires_grad=False)
    for _ in range(3)
)

In [15]:
# Project x2 with the query matrix, and every token with the key and value matrices.
query = torch.matmul(x2, w_query)
keys = torch.matmul(input_embeddings, w_keys)
values = torch.matmul(input_embeddings, w_values)

query

tensor([0.4306, 1.4551])

In [16]:
# Unnormalized attention scores of the x2 query against every key.
attention_scores = torch.matmul(query, keys.T)
attention_scores

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [17]:
# Scaled dot-product attention: divide the scores by the square root of the key
# dimension first, then normalize them with softmax.
attention_weights = (attention_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
attention_weights

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [18]:
# Context vector for x2: the attention-weighted sum of the value vectors.
context_vector_z2 = torch.matmul(attention_weights, values)
context_vector_z2

tensor([0.3061, 0.8210])

### Let's generalize this approach with a class implementation

In [19]:
class SelfAttention_v1(torch.nn.Module):
  """Single-head scaled dot-product self-attention with explicit weight matrices.

  The query, key and value projections are `torch.nn.Parameter` matrices of
  shape (dim_in, dim_out), initialised with `torch.rand`.
  """

  def __init__(self, dim_in, dim_out):
    super().__init__()
    self.w_query = torch.nn.Parameter(torch.rand(dim_in, dim_out))
    self.w_key = torch.nn.Parameter(torch.rand(dim_in, dim_out))
    self.w_value = torch.nn.Parameter(torch.rand(dim_in, dim_out))

  def forward(self, inputs):
    """Compute one context vector per input token.

    Args:
      inputs: tensor of shape (num_tokens, dim_in), optionally with leading
        batch dimensions, i.e. (..., num_tokens, dim_in).

    Returns:
      Tensor of shape (..., num_tokens, dim_out).
    """
    queries = inputs @ self.w_query
    keys = inputs @ self.w_key
    values = inputs @ self.w_value
    # Swap only the last two axes: identical to `keys.T` for 2-D input, but also
    # valid for batched input, where `.T` would reverse every dimension.
    attention_scores_all = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(key dimension) before the softmax.
    attention_weights_all = torch.softmax(attention_scores_all / keys.shape[-1] ** 0.5, dim=-1)
    context_vectors_all = attention_weights_all @ values

    return context_vectors_all

In [20]:
# Seed first so the randomly initialised projection matrices are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(dim_in=dim_in, dim_out=dim_out)
print(sa_v1(input_embeddings))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Let's use `torch.nn.Linear` layers for weight-matrix initialization and matrix multiplication.

In [21]:
class SelfAttention_v2(torch.nn.Module):
  """Single-head scaled dot-product self-attention built from `torch.nn.Linear` layers.

  Args:
    dim_in: size of each input embedding.
    dim_out: size of each query/key/value (and context) vector.
    qkv_bias: whether the three projection layers carry a bias term.
  """

  def __init__(self, dim_in, dim_out, qkv_bias=False):
    super().__init__()
    self.w_query = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_key = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_value = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)

  def forward(self, inputs):
    """Compute one context vector per input token.

    Args:
      inputs: tensor of shape (num_tokens, dim_in), optionally with leading
        batch dimensions, i.e. (..., num_tokens, dim_in).

    Returns:
      Tensor of shape (..., num_tokens, dim_out).
    """
    queries = self.w_query(inputs)
    keys = self.w_key(inputs)
    values = self.w_value(inputs)
    # Swap only the last two axes: identical to `keys.T` for 2-D input, but also
    # valid for batched input, where `.T` would reverse every dimension.
    attention_scores_all = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(key dimension) before the softmax.
    attention_weights_all = torch.softmax(attention_scores_all / keys.shape[-1] ** 0.5, dim=-1)
    context_vectors_all = attention_weights_all @ values

    return context_vectors_all

In [22]:
# Seed before constructing the module so the Linear-layer initialisation (and the
# printed context vectors) are reproducible even when this cell is re-run alone,
# matching the other model-construction cells in this notebook.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(dim_in, dim_out)
print(sa_v2(input_embeddings))

tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)


The two classes produce different context vectors because their weight parameters are initialized differently.

Let's copy the weights of the Linear layers into SelfAttention_v1.

In [23]:
# A Linear layer keeps its weight in transposed form relative to the
# (dim_in, dim_out) matrices used by SelfAttention_v1, so each weight is
# transposed back before being installed on sa_v1.
sa_v1.w_query = torch.nn.Parameter(sa_v2.w_query.weight.t())
sa_v1.w_key = torch.nn.Parameter(sa_v2.w_key.weight.t())
sa_v1.w_value = torch.nn.Parameter(sa_v2.w_value.weight.t())

print(sa_v1(input_embeddings))

tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)


## **Causal Attention**

In [24]:
# Number of tokens in the sequence (rows of the embedding matrix).
context_length = len(input_embeddings)
context_length

6

In [25]:
# Causal mask: ones strictly above the main diagonal mark the "future" positions
# that a token must not attend to.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [26]:
# Replace the score of every masked (future) position with -inf.
masked_attention_scores_all = attention_scores_all.masked_fill(mask.to(torch.bool), float("-inf"))
masked_attention_scores_all

tensor([[0.9995,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.9544, 1.4950,   -inf,   -inf,   -inf,   -inf],
        [0.9422, 1.4754, 1.4570,   -inf,   -inf,   -inf],
        [0.4753, 0.8434, 0.8296, 0.4937,   -inf,   -inf],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654,   -inf],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [27]:
# Scale by the square root of the embedding size, then apply softmax along each
# row; the -inf entries end up with a weight of 0.
masked_attention_weights_all = (
    masked_attention_scores_all / input_embeddings.shape[-1] ** 0.5
).softmax(dim=-1)
masked_attention_weights_all

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4226, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2698, 0.3670, 0.3632, 0.0000, 0.0000, 0.0000],
        [0.2235, 0.2764, 0.2742, 0.2259, 0.0000, 0.0000],
        [0.1858, 0.2146, 0.2157, 0.1744, 0.2095, 0.0000],
        [0.1511, 0.1965, 0.1936, 0.1533, 0.1243, 0.1811]])

In [28]:
# Seeded so the same entries are dropped on every run. With p=0.5 the entries
# that survive are scaled up by 1 / (1 - p) = 2.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
dropout(masked_attention_weights_all)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.1548, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7263, 0.0000, 0.0000, 0.0000],
        [0.4470, 0.5528, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3931, 0.0000, 0.0000, 0.0000, 0.0000]])

### Compact class implementation for batch processing

In [29]:
# Build a batch of two identical sequences by stacking the embeddings along a
# new leading (batch) dimension.
batch_inp = torch.stack([input_embeddings] * 2, dim=0)
batch_inp.shape

torch.Size([2, 6, 3])

In [30]:
class CausalAttention(torch.nn.Module):
  """Single-head causal (masked) scaled dot-product attention for batched input.

  Args:
    dim_in: size of each input embedding.
    dim_out: size of each query/key/value (and context) vector.
    context_length: side length of the square causal mask that is registered.
    dropout: dropout probability applied to the attention weights.
    qkv_bias: whether the three projection layers carry a bias term.
  """

  def __init__(self, dim_in, dim_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.w_query = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_key = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_value = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.dropout = torch.nn.Dropout(dropout)
    # Ones strictly above the diagonal mark the future positions. Registered as
    # a buffer so it belongs to the module state without being a parameter.
    future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
    self.register_buffer('mask', future_positions)

  def forward(self, x):
    """Return context vectors of shape (batch_size, num_tokens, dim_out).

    `x` must be 3-D: (batch_size, num_tokens, dim_in).
    """
    _, num_tokens, _ = x.shape
    queries, keys, values = self.w_query(x), self.w_key(x), self.w_value(x)

    # Crop the mask to the actual sequence length of this batch.
    causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
    scores = torch.matmul(queries, keys.transpose(1, 2))
    scores = scores.masked_fill(causal_mask, -torch.inf)

    # Scale by sqrt(key dimension), normalize, then drop out attention weights.
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    weights = self.dropout(weights)

    return weights @ values

In [31]:
# Seeded so both the layer initialisation and the dropout pattern are reproducible.
torch.manual_seed(123)
context_length = batch_inp.size(1)
ca = CausalAttention(dim_in=3, dim_out=2, context_length=context_length, dropout=0.5)
print(ca(batch_inp))

tensor([[[ 0.0000,  0.0000],
         [-0.4368,  0.2142],
         [-0.7751,  0.0077],
         [-0.9140, -0.2769],
         [ 0.0000,  0.0000],
         [-0.6906, -0.0974]],

        [[-0.9038,  0.4432],
         [ 0.0000,  0.0000],
         [-0.2883,  0.1414],
         [-0.9140, -0.2769],
         [-0.4416, -0.1410],
         [-0.5272, -0.1706]]], grad_fn=<UnsafeViewBackward0>)


## **Multi-head Attention**

### Multi-head wrapper around CausalAttention class

In [32]:
class MultiHeadWrapper(torch.nn.Module):
  """Multi-head attention built from `num_heads` independent CausalAttention modules.

  Each head maps the input to `dim_out` features; the head outputs are
  concatenated along the last dimension.
  """

  def __init__(self, dim_in, dim_out, context_length, num_heads, dropout, qkv_bias=False):
    super().__init__()
    heads = []
    for _ in range(num_heads):
      heads.append(CausalAttention(dim_in, dim_out, context_length, dropout, qkv_bias))
    self.heads = torch.nn.ModuleList(heads)

  def forward(self, x):
    """Run every head on `x` in turn and concatenate the results feature-wise."""
    per_head_outputs = [head(x) for head in self.heads]
    return torch.cat(per_head_outputs, dim=-1)

In [33]:
# Two heads with one output feature each -> concatenated context size of 2.
torch.manual_seed(123)
ma_wrapper = MultiHeadWrapper(dim_in=3, dim_out=1, context_length=6, num_heads=2, dropout=0)
context_vectors = ma_wrapper(batch_inp)

print(context_vectors.shape)
print(context_vectors)

torch.Size([2, 6, 2])
tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)


Here, the problem is that the heads are evaluated sequentially, one after another, which is inefficient.

### Parallel Multi-head Attention

In [34]:
class MultiHeadAttention(torch.nn.Module):
  """Causal multi-head attention that computes all heads in one batched matmul.

  Args:
    dim_in: size of each input embedding.
    dim_out: size of the merged context vector; must be divisible by `num_heads`.
    context_length: side length of the square causal mask that is registered.
    dropout: dropout probability applied to the attention weights.
    num_heads: number of attention heads.
    qkv_bias: whether the query/key/value projections carry a bias term.
  """

  def __init__(self, dim_in, dim_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert (dim_out % num_heads == 0), "dim_out must be divisible by num_heads"

    self.dim_out = dim_out              # size of the merged context vector
    self.num_heads = num_heads
    self.head_dim = dim_out // num_heads  # context-vector size within one head
    self.w_query = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_key = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    self.w_value = torch.nn.Linear(dim_in, dim_out, bias=qkv_bias)
    # Final linear layer applied to the merged heads; keeps the size at dim_out.
    self.out_proj = torch.nn.Linear(dim_out, dim_out)
    self.dropout = torch.nn.Dropout(dropout)
    # Ones strictly above the diagonal mark the future positions.
    future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
    self.register_buffer('mask', future_positions)

  def forward(self, x):
    """Return context vectors of shape (batch_size, num_tokens, dim_out).

    `x` must be 3-D: (batch_size, num_tokens, dim_in).
    """
    batch_size, num_tokens, _ = x.shape

    def split_heads(projected):
      # (batch_size, num_tokens, dim_out)
      #   -> (batch_size, num_tokens, num_heads, head_dim)
      #   -> (batch_size, num_heads, num_tokens, head_dim)
      per_head = projected.view(batch_size, num_tokens, self.num_heads, self.head_dim)
      return per_head.transpose(1, 2)

    queries = split_heads(self.w_query(x))
    keys = split_heads(self.w_key(x))
    values = split_heads(self.w_value(x))

    # Crop the mask to the actual sequence length; it broadcasts over the batch
    # and head dimensions.
    causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
    scores = torch.matmul(queries, keys.transpose(2, 3))
    scores = scores.masked_fill(causal_mask, -torch.inf)

    # After the split, the last dimension of the keys is head_dim.
    weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
    weights = self.dropout(weights)

    # Back to (batch_size, num_tokens, num_heads, head_dim) so that the heads of
    # each token sit next to each other, then flatten them into dim_out.
    merged = torch.matmul(weights, values).transpose(1, 2)
    merged = merged.contiguous().view(batch_size, num_tokens, self.dim_out)

    return self.out_proj(merged)

In [35]:
# Take the dimensions from the batch itself; two heads of size 1 give dim_out = 2.
torch.manual_seed(123)
batch_size, context_length, dim_in = batch_inp.shape
dim_out = 2
mha = MultiHeadAttention(dim_in, dim_out, context_length, dropout=0, num_heads=2)
context_vectors = mha(batch_inp)
print(context_vectors.shape)
print(context_vectors)

torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)


In [37]:
# Initialise a GPT-2-sized attention module: 768-dimensional embeddings split
# over 12 heads, with a context length of 1024 tokens.
# Seed first so the random input batch, the layer initialisation and the dropout
# pattern are reproducible, like the other stochastic cells in this notebook.
torch.manual_seed(123)
dim_in = 768
dim_out = 768
context_length = 1024

gpt_batch_inp = torch.rand(2,context_length, dim_in)

mha2 = MultiHeadAttention(dim_in, dim_out, context_length, 0.1, 12)
context_vectors = mha2(gpt_batch_inp)
print(context_vectors.shape)
print(context_vectors)

torch.Size([2, 1024, 768])
tensor([[[ 0.0170,  0.1816, -0.2170,  ...,  0.0639,  0.0964,  0.2717],
         [-0.0047,  0.0741, -0.1810,  ..., -0.0047,  0.2833,  0.2123],
         [ 0.0483,  0.0444, -0.1795,  ...,  0.0255,  0.2041,  0.1987],
         ...,
         [ 0.0416,  0.1521, -0.1774,  ...,  0.0179,  0.1017,  0.1228],
         [ 0.0436,  0.1570, -0.1777,  ...,  0.0166,  0.1021,  0.1215],
         [ 0.0483,  0.1519, -0.1812,  ...,  0.0172,  0.1026,  0.1172]],

        [[ 0.1762,  0.0427, -0.3035,  ..., -0.1063,  0.1445,  0.1308],
         [-0.0762,  0.0785, -0.2634,  ...,  0.0014,  0.0424,  0.1250],
         [ 0.0960,  0.1276, -0.2463,  ..., -0.0068,  0.1109,  0.1335],
         ...,
         [ 0.0451,  0.1483, -0.1869,  ...,  0.0137,  0.0928,  0.1143],
         [ 0.0488,  0.1542, -0.1884,  ...,  0.0187,  0.0982,  0.1175],
         [ 0.0434,  0.1553, -0.1887,  ...,  0.0109,  0.0900,  0.1159]]],
       grad_fn=<ViewBackward0>)
