# Self-Attention Mechanism

The self-attention mechanism allows the model to weigh the importance of different words in a sequence. It helps the model focus on relevant words while encoding a particular word. This is a critical part of the Transformer architecture.

In the prior module, We left off with the position encoding embedding. To continue this process, we are pushing the positional embedding into the next compontent of the transformer, the attention mechanism. 

In [2]:
import torch
import torch.nn as nn
import math

# Positional Encoded Embeddings from the previous notebook
pos_encoded_embeddings = torch.tensor([[[ 0.0171,  1.0654,  0.4616,  0.9196,  0.7193, -0.7430,  0.7120,
           0.4198,  2.7427,  0.2844],
         [ 0.9161, -0.4999,  1.8302,  1.2182,  0.0651,  2.0396,  0.3780,
           0.4085,  1.3560, -1.3176],
         [-0.3461,  0.2526,  0.4120,  0.0398,  1.6146,  2.6475, -0.1887,
          -0.6907, -0.4066,  1.2899],
         [-0.8345, -0.4657, -0.5635,  1.1514,  0.1762,  1.8148, -1.4084,
           0.9153, -1.1734,  1.1989],
         [-0.3981, -0.7277, -1.2657,  1.9887, -0.2399,  0.0412,  1.9375,
           0.6083,  0.2095,  1.5739],
         [-1.5240,  1.2359,  0.5596, -0.1529, -0.4064, -0.0906, -1.6746,
           0.2480, -0.2364,  0.8417],
         [ 0.6992, -1.1193,  0.5642,  1.3861, -0.5185,  0.6701, -0.5353,
           1.8013,  0.7900, -0.9430],
         [-0.3550,  2.1032,  1.8690,  0.7174, -1.7692,  0.8200, -0.9620,
           1.8325,  0.3009,  1.0083],
         [ 1.5902, -0.8516,  3.2954, -0.5147,  0.1798,  1.8522, -1.0186,
           0.6484,  0.9932,  0.4030],
         [-0.5990, -0.4916,  0.5419, -0.0293, -0.3465,  1.1548, -1.1130,
          -0.8118,  0.3455,  0.9474]]], dtype=torch.float)

print("Positional Encoded Embeddings Shape:", pos_encoded_embeddings.shape)
print(pos_encoded_embeddings)

Positional Encoded Embeddings Shape: torch.Size([1, 10, 10])
tensor([[[ 0.0171,  1.0654,  0.4616,  0.9196,  0.7193, -0.7430,  0.7120,
           0.4198,  2.7427,  0.2844],
         [ 0.9161, -0.4999,  1.8302,  1.2182,  0.0651,  2.0396,  0.3780,
           0.4085,  1.3560, -1.3176],
         [-0.3461,  0.2526,  0.4120,  0.0398,  1.6146,  2.6475, -0.1887,
          -0.6907, -0.4066,  1.2899],
         [-0.8345, -0.4657, -0.5635,  1.1514,  0.1762,  1.8148, -1.4084,
           0.9153, -1.1734,  1.1989],
         [-0.3981, -0.7277, -1.2657,  1.9887, -0.2399,  0.0412,  1.9375,
           0.6083,  0.2095,  1.5739],
         [-1.5240,  1.2359,  0.5596, -0.1529, -0.4064, -0.0906, -1.6746,
           0.2480, -0.2364,  0.8417],
         [ 0.6992, -1.1193,  0.5642,  1.3861, -0.5185,  0.6701, -0.5353,
           1.8013,  0.7900, -0.9430],
         [-0.3550,  2.1032,  1.8690,  0.7174, -1.7692,  0.8200, -0.9620,
           1.8325,  0.3009,  1.0083],
         [ 1.5902, -0.8516,  3.2954, -0.5147,  0.17

In [9]:
# Define dimensions
embedding_size = pos_encoded_embeddings.size(2)
# embedding_size is the size of the embeddings vector for each word
heads = 2
# heads is the number of attention heads, allowing the model to focus on different parts of the sentence simultaneously
head_dim = embedding_size // heads
# head_dim is the size of each attention head, ensuring that the embedding size is evenly divisible by the number of heads

print("Embedding Size:", embedding_size)
print("Number of Heads:", heads)
print("Dimension per Head:", head_dim)

Embedding Size: 10
Number of Heads: 2
Dimension per Head: 5


In [10]:
# Define linear transformations for query, key, and value
values_linear = nn.Linear(head_dim, head_dim, bias=False)
keys_linear = nn.Linear(head_dim, head_dim, bias=False)
queries_linear = nn.Linear(head_dim, head_dim, bias=False)

# These linear layers project the input embeddings into different vector spaces (query, key, and value)
# This helps in calculating attention scores and obtaining the relevant context for each word

print("Linear layers for values, keys, and queries are defined.")

Linear layers for values, keys, and queries are defined.


In [12]:
# Reshape the input embeddings to split into heads
N = pos_encoded_embeddings.shape[0]
# N is the batch size, which is the number of sentences being processed at once
value_len, key_len, query_len = pos_encoded_embeddings.shape[1], pos_encoded_embeddings.shape[1], pos_encoded_embeddings.shape[1]
print(f"value len: {value_len}")
print(f"key len: {value_len}")
print(f"query len: {value_len}")
# value_len, key_len, and query_len are the lengths of the input sequences

values = pos_encoded_embeddings.reshape(N, value_len, heads, head_dim)
keys = pos_encoded_embeddings.reshape(N, key_len, heads, head_dim)
queries = pos_encoded_embeddings.reshape(N, query_len, heads, head_dim)
# The input embeddings are reshaped to separate the heads for multi-head attention

print("Values Shape after Reshape:", values.shape)
print("Keys Shape after Reshape:", keys.shape)
print("Queries Shape after Reshape:", queries.shape)

value len: 10
key len: 10
query len: 10
Values Shape after Reshape: torch.Size([1, 10, 2, 5])
Keys Shape after Reshape: torch.Size([1, 10, 2, 5])
Queries Shape after Reshape: torch.Size([1, 10, 2, 5])


In [13]:
# Apply linear transformations to the reshaped embeddings
values = values_linear(values)
keys = keys_linear(keys)
queries = queries_linear(queries)

# The reshaped embeddings are projected into different vector spaces (query, key, and value)
# This helps in calculating attention scores and obtaining the relevant context for each word

print("Values Shape after Linear Transformation:", values.shape)
print("Keys Shape after Linear Transformation:", keys.shape)
print("Queries Shape after Linear Transformation:", queries.shape)

Values Shape after Linear Transformation: torch.Size([1, 10, 2, 5])
Keys Shape after Linear Transformation: torch.Size([1, 10, 2, 5])
Queries Shape after Linear Transformation: torch.Size([1, 10, 2, 5])


In [15]:
# Scaled dot-product attention
d_k = queries.size(-1)
# d_k is the dimension of the keys, used for scaling the dot product to prevent large values

scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d_k)
# scores are calculated as the dot product of queries and keys, scaled by the square root of the key dimension

print("Scores Shape:", scores.shape)
print("Scores:", scores)

Scores Shape: torch.Size([1, 10, 2, 2])
Scores: tensor([[[[-0.1030, -0.2223],
          [-0.3210, -0.6317]],

         [[-0.2229, -0.1614],
          [ 0.2494,  0.2484]],

         [[ 0.1188,  0.0469],
          [ 0.5596, -0.2194]],

         [[-0.2212,  0.0894],
          [ 0.1758, -0.3274]],

         [[ 0.1100, -0.2067],
          [-0.1585,  0.0536]],

         [[-0.1201, -0.1736],
          [ 0.4840, -0.1881]],

         [[ 0.0022, -0.0563],
          [ 0.0427,  0.0560]],

         [[-0.0377, -0.1694],
          [ 0.3752, -0.2123]],

         [[-0.1618, -0.6680],
          [-0.3269, -0.1881]],

         [[ 0.0133, -0.1859],
          [-0.0267,  0.0188]]]], grad_fn=<DivBackward0>)


In [16]:
# Apply softmax to get attention weights
attention = torch.nn.functional.softmax(scores, dim=-1)
# The attention weights are calculated using the softmax function to normalize the scores

print("Attention Shape:", attention.shape)
print("Attention Weights:", attention)

Attention Shape: torch.Size([1, 10, 2, 2])
Attention Weights: tensor([[[[0.5298, 0.4702],
          [0.5771, 0.4229]],

         [[0.4846, 0.5154],
          [0.5003, 0.4997]],

         [[0.5180, 0.4820],
          [0.6855, 0.3145]],

         [[0.4230, 0.5770],
          [0.6232, 0.3768]],

         [[0.5785, 0.4215],
          [0.4472, 0.5528]],

         [[0.5134, 0.4866],
          [0.6620, 0.3380]],

         [[0.5146, 0.4854],
          [0.4967, 0.5033]],

         [[0.5329, 0.4671],
          [0.6428, 0.3572]],

         [[0.6239, 0.3761],
          [0.4653, 0.5347]],

         [[0.5496, 0.4504],
          [0.4886, 0.5114]]]], grad_fn=<SoftmaxBackward0>)


In [17]:
# Compute the output as a weighted sum of values
output = torch.matmul(attention, values)
# The output is calculated as the weighted sum of the values, where the weights are the attention weights

print("Output Shape:", output.shape)
print("Output:", output)

Output Shape: torch.Size([1, 10, 2, 5])
Output: tensor([[[[ 5.1275e-01, -2.3062e-02,  7.6786e-01,  3.5396e-01, -4.8948e-01],
          [ 4.7629e-01,  5.5857e-03,  7.1325e-01,  3.1193e-01, -4.6188e-01]],

         [[ 6.8100e-01,  2.7311e-01,  1.4495e-01, -1.7394e-01, -8.7010e-01],
          [ 6.8551e-01,  2.7291e-01,  1.5859e-01, -1.7169e-01, -8.5177e-01]],

         [[-4.1000e-01,  6.4581e-01, -5.4488e-01, -5.8096e-01,  5.3965e-02],
          [-3.2036e-01,  5.2982e-01, -2.9135e-01, -4.3456e-01,  2.2621e-01]],

         [[-2.7894e-02,  2.6303e-01, -7.1920e-02, -2.4139e-01,  2.3254e-01],
          [ 1.0286e-01, -4.0710e-04,  2.4272e-01,  9.5863e-02,  6.4063e-02]],

         [[ 1.7366e-01, -1.0080e-01,  4.3616e-01,  3.5210e-01, -4.7065e-01],
          [ 6.3986e-02,  7.8689e-02,  2.7882e-01,  1.4916e-01, -3.1616e-01]],

         [[ 9.2607e-02, -1.6513e-01,  3.7339e-01,  1.7588e-01,  4.9838e-01],
          [ 9.3334e-02, -1.9488e-01,  3.7261e-01,  1.9396e-01,  4.7630e-01]],

         [[ 8.20

In [18]:
# Combine heads and apply final linear transformation
output = output.transpose(1, 2).contiguous().view(N, -1, heads * head_dim)
# The heads are combined by transposing and reshaping the output to merge the heads

fc_out = nn.Linear(heads * head_dim, embedding_size)
# A final linear layer is applied to project the combined heads back to the original embedding size

output = fc_out(output)

print("Final Output Shape:", output.shape)
print("Final Output:", output)

Final Output Shape: torch.Size([1, 10, 10])
Final Output: tensor([[[ 4.3352e-01,  6.4582e-02, -2.0290e-01, -1.3944e-01, -4.8964e-01,
           6.4467e-03, -2.5833e-01, -4.8392e-02,  1.8122e-01, -4.5369e-01],
         [ 1.9231e-01, -2.8036e-01,  2.4081e-01, -1.3562e-03,  5.1849e-01,
           8.9808e-03,  3.7103e-01, -1.3844e-01, -2.2998e-01, -9.1597e-02],
         [ 2.3325e-01,  5.7716e-02,  3.2724e-01,  3.4075e-02, -1.0252e-01,
          -6.5169e-02,  1.2867e-01, -1.2793e-01, -2.2600e-01,  4.2445e-02],
         [ 5.0482e-01,  1.6524e-02,  5.1617e-02, -7.5409e-02, -3.3103e-01,
          -1.4219e-01, -1.0279e-01, -1.8948e-01,  2.8319e-02, -1.3019e-01],
         [ 4.6103e-01, -1.0390e-01,  3.3408e-01, -2.1866e-01,  7.6646e-02,
          -2.5859e-01,  1.1607e-01, -2.0052e-01, -1.4547e-01,  1.2353e-01],
         [ 4.3236e-01,  5.5618e-02, -1.9086e-01, -1.4108e-01, -4.5655e-01,
           3.0693e-03, -2.3673e-01, -4.7470e-02,  1.7033e-01, -4.4062e-01],
         [ 2.7380e-01, -1.1220e-01, 