# Self-Attention: A Step-by-Step Walkthrough

## Basic Self-attention without trainable weights

In [3]:
import torch

In [4]:
# Toy embeddings for "Your journey starts with one step": one 3-dim row per token.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])
# Banner showing the (T, embed_size) shape of the input matrix.
print(*['-' * 100, 'T x Embed_Size', '-' * 100, inputs.shape, '-' * 100], sep='\n')

----------------------------------------------------------------------------------------------------
T x Embed_Size
----------------------------------------------------------------------------------------------------
torch.Size([6, 3])
----------------------------------------------------------------------------------------------------


### Compute attention scores

In [5]:
# Unnormalized scores: entry (i, j) is the dot product of token i with token j.
attention_scores = torch.matmul(inputs, inputs.transpose(0, 1))
attention_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [6]:
# Banner showing that the score matrix is (T, T): one row per query token.
print(*['-' * 100, 'T x T', '-' * 100, attention_scores.shape, '-' * 100], sep='\n')

----------------------------------------------------------------------------------------------------
T x T
----------------------------------------------------------------------------------------------------
torch.Size([6, 6])
----------------------------------------------------------------------------------------------------


### Normalize

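Apply softmax to each row of attention scores (for example, the scores $\omega_{21}, \dots, \omega_{2T}$ of query $x^{(2)}$) to obtain attention weights that are non-negative and sum to 1.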

In [7]:
def softmax(x):
    """Naive softmax, applied here to one row of scores at a time.

    Softmax is invariant to shifting every input by the same constant, so the
    max is subtracted before exponentiating. This keeps ``torch.exp`` from
    overflowing to inf (which would turn the result into nan) when scores are
    large. The exponentials are computed once and reused for the sum.
    """
    exp_x = torch.exp(x - x.max())
    return exp_x / exp_x.sum()


# Normalize each row independently: row i holds the weights for query i.
attention_weights = torch.empty(attention_scores.shape)
for i in range(attention_scores.shape[0]):
    attention_weights[i] = softmax(attention_scores[i])
attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [8]:
torch.sum(attention_weights, dim=1)  # every row of weights should sum to 1

tensor([1., 1., 1., 1., 1., 1.])

In [9]:
# Same row-wise normalization using the built-in softmax (last dim == dim 1 for 2-D).
attention_weights = attention_scores.softmax(dim=-1)
attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [10]:
attention_weights.size()  # (T, T)

torch.Size([6, 6])

In [11]:
attention_weights.sum(dim=-1)  # row sums over the key axis; expect all ones

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

### Compute the context vector

$$z^{(2)} = \alpha_{21} x^{(1)} + \alpha_{22} x^{(2)} + \dots + \alpha_{2T} x^{(T)} = \sum_{i=1}^{T} \alpha_{2i}\, x^{(i)}$$


In [12]:
# Context vector for query x^2 ("journey", row 1 with 0-based indexing):
# accumulate every input token weighted by its attention weight.
query_input_idx = 1
query = inputs[query_input_idx]
z_2 = torch.zeros_like(query)
for i, x_i in enumerate(inputs):
    z_2 += attention_weights[query_input_idx, i] * x_i
z_2

tensor([0.4419, 0.6515, 0.5683])

In [13]:
# All context vectors at once: row i is the weighted sum for query i.
Z = torch.matmul(attention_weights, inputs)
Z

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## Self-attention with trainable weights

This self-attention mechanism is also known as scaled dot-product attention.

#### 1. Compute query(q), key(k), value(v) vectors for input elements x

This requires three trainable weight matrices $W_q$, $W_k$, and $W_v$, which project each embedded input token into a query vector, a key vector, and a value vector.

In [14]:
# Walk through a single query. NOTE: `inputs[2]` is the third token ("starts",
# labelled x^3 in the `inputs` cell); the `_2` suffix here follows the 0-based
# row index, unlike the first section where x^2 ("journey") was `inputs[1]`.
x_2 = inputs[2]
d_in = d_out = inputs.shape[1]  # input and projection sizes are equal here

In [15]:
x_2.size()  # a single token embedding: (d_in,)

torch.Size([3])

In [22]:
# Three random projection matrices of shape (d_in, d_out), frozen with
# requires_grad=False (nothing in this section trains them). The seed is set
# first and the matrices are drawn in q, k, v order, so the values match the
# outputs shown below.
torch.manual_seed(123)
Wq, Wk, Wv = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [23]:
x_2.size()  # unchanged by the weight initialization above

torch.Size([3])

In [24]:
# Project the single token x_2 into its query, key and value vectors.
# Row-vector convention: x (d_in,) @ W (d_in, d_out) -> (d_out,), the same as
# the batched `inputs @ Wq` in the next cell, so `q` equals `queries[2]` there.
# (The former `Wq @ x_2` only ran because d_in == d_out, and it actually
# computed x_2 @ Wq.T, which is a different projection.)
q = x_2 @ Wq
k = x_2 @ Wk
v = x_2 @ Wv

In [25]:
# Project every token at once: (T, d_in) @ (d_in, d_out) -> (T, d_out).
queries, keys, values = (inputs @ W for W in (Wq, Wk, Wv))

tuple(t.shape for t in (queries, keys, values))

(torch.Size([6, 3]), torch.Size([6, 3]), torch.Size([6, 3]))

#### 2. Compute attention scores

Each attention score is the dot product of a query vector with a key vector.

For example, for query_2 the attention scores are \
attn_score_2_1 = dot(query_2, key_1) \
attn_score_2_2 = dot(query_2, key_2) \
... \
attn_score_2_T = dot(query_2, key_T)


In [26]:
# Score matrix: entry (i, j) = dot(query_i, key_j).
attention_scores = torch.matmul(queries, keys.transpose(0, 1))
attention_scores.size()

torch.Size([6, 6])

#### 3. Compute attention weights

Scale the scores by $1/\sqrt{d_k}$, where $d_k$ is the key dimension, then normalize each row with softmax.

In [27]:
# Divide by sqrt(d_k) before the row-wise softmax (scaled dot-product attention).
d_k = keys.size(-1)
attention_weights = (attention_scores / d_k ** 0.5).softmax(dim=-1)
attention_weights

tensor([[0.1747, 0.1866, 0.1864, 0.1446, 0.1586, 0.1491],
        [0.1862, 0.2123, 0.2117, 0.1179, 0.1450, 0.1269],
        [0.1859, 0.2118, 0.2112, 0.1184, 0.1454, 0.1273],
        [0.1798, 0.1936, 0.1932, 0.1365, 0.1542, 0.1427],
        [0.1751, 0.1895, 0.1893, 0.1418, 0.1579, 0.1465],
        [0.1837, 0.2003, 0.1998, 0.1293, 0.1501, 0.1369]])

In [28]:
attention_weights.size()  # (T, T)

torch.Size([6, 6])

In [29]:
#### 4. Compute the context vectors (context matrix Z)

In [32]:
# Value vectors (one row per token, from `inputs @ Wv`) that the attention weights will mix.
values

tensor([[0.4976, 0.9655, 0.7614],
        [0.9074, 1.3518, 1.5075],
        [0.8976, 1.3391, 1.4994],
        [0.5187, 0.7319, 0.8493],
        [0.4699, 0.7336, 0.9307],
        [0.6446, 0.9045, 0.9814]])

In [48]:
# Context vector for row 2 (0-based) of attention_weights, built as an explicit
# weighted sum of the value vectors; the final value should equal Z[2] below.
alpha_2 = attention_weights[2]
z_2 = torch.zeros(values.shape[-1])  # sized from d_out instead of a hard-coded 3

for i, v_i in enumerate(values):
    z_2 += alpha_2[i] * v_i
    print(z_2)  # running partial sum after adding token i

tensor([0.0925, 0.1795, 0.1416])
tensor([0.2847, 0.4658, 0.4608])
tensor([0.4742, 0.7485, 0.7775])
tensor([0.5356, 0.8352, 0.8780])
tensor([0.6040, 0.9419, 1.0133])
tensor([0.6860, 1.0570, 1.1383])


In [33]:
# All context vectors at once: row i mixes the value vectors with query i's weights.
Z = torch.matmul(attention_weights, values)
Z

tensor([[0.6692, 1.0276, 1.1106],
        [0.6864, 1.0577, 1.1389],
        [0.6860, 1.0570, 1.1383],
        [0.6738, 1.0361, 1.1180],
        [0.6711, 1.0307, 1.1139],
        [0.6783, 1.0441, 1.1252]])

In [49]:
# Banner showing the (T, d_model) shape of the context matrix.
print(*['-' * 100, 'T x d_model', '-' * 100, Z.shape, '-' * 100], sep='\n')

----------------------------------------------------------------------------------------------------
T x d_model
----------------------------------------------------------------------------------------------------
torch.Size([6, 3])
----------------------------------------------------------------------------------------------------


We add a scaling step to the dot-product calculation in self-attention to aid training with large embedding sizes. The scores are divided by the square root of the key dimension $d_k$, which is why the mechanism is called scaled dot-product attention.

Here's what this means:

* In natural language processing, we often use large embedding sizes (dimensions) to represent words.
* When these dimensions are high (like in large language models), the dot product between query and key vectors can become very large.
* Large dot products create problems during training because they make the softmax function act almost like an on/off switch. This, in turn, leads to very small gradients during backpropagation.
* Small gradients make it difficult for the model to learn effectively.
* By dividing the dot products by $\sqrt{d_k}$ (the square root of the key dimension), we prevent these issues and ensure the gradients stay usable during training.


**Why use query, key, and value?**

These terms, borrowed from information retrieval and databases, help organize and process information in attention mechanisms.

A "query" is like a search term, representing the current focus of the model.
A "key" indexes input items, aiding in matching with the query.
The "value" holds the actual content or representation of input items, retrieved based on the matched keys.

## Compact self-attention class

In [69]:
import torch.nn as nn
torch.manual_seed(143)

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections and of the output.

    The projections are ``nn.Parameter`` matrices of shape ``(d_in, d_out)``
    initialized with ``torch.rand``.
    """

    def __init__(self, d_in, d_out):
        # Zero-arg super(): `super(SelfAttention, self)` would look up the global
        # name at call time, and this notebook later rebinds `SelfAttention`.
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Creation order (q, v, k) is kept so seeded initial values are unchanged.
        self.Wq = nn.Parameter(torch.rand(d_in, d_out))
        self.Wv = nn.Parameter(torch.rand(d_in, d_out))
        self.Wk = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: token embeddings of shape (T, d_in), or (..., T, d_in) for a
               batch of sequences; each sequence attends only within itself.

        Returns:
            Tensor of shape (..., T, d_out).
        """
        queries = x @ self.Wq  # (..., T, d_out)
        keys = x @ self.Wk     # (..., T, d_out)
        values = x @ self.Wv   # (..., T, d_out)

        # Attention scores. Swap only the last two dims: `.T` reverses every
        # dim, which is wrong for batched (3-D) inputs.
        attention_scores = queries @ keys.transpose(-2, -1)  # (..., T, T)

        # Scale by sqrt(d_k) and normalize over the key axis (always the last dim).
        attention_weights = torch.softmax(attention_scores / self.d_out ** 0.5, dim=-1)  # (..., T, T)

        # Context vectors: Z_2 = (a2_1 * v_1) + (a2_2 * v_2) + .. (a2_T * v_T)
        Z = attention_weights @ values  # (..., T, d_out)
        return Z

In [72]:
# Smoke test at large sizes: 4096 tokens with 4608-dim embeddings.
selfattn = SelfAttention(4608, 4608)
X = torch.rand((4096, 4608))
Z = selfattn(X)
print(Z.size())

torch.Size([4096, 4608])


In [73]:
import torch.nn as nn
torch.manual_seed(143)

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention using ``nn.Linear`` projections.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections and of the output.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        # Zero-arg super(): `super(SelfAttention, self)` would look up the global
        # name at call time, which is fragile when the class is redefined.
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Creation order (q, v, k) is kept so seeded initial weights are unchanged.
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: token embeddings of shape (T, d_in), or (..., T, d_in) for a
               batch of sequences; each sequence attends only within itself.

        Returns:
            Tensor of shape (..., T, d_out).
        """
        queries = self.Wq(x)  # (..., T, d_out)
        keys = self.Wk(x)     # (..., T, d_out)
        values = self.Wv(x)   # (..., T, d_out)

        # Attention scores. Swap only the last two dims: `.T` reverses every
        # dim, which is wrong for batched (3-D) inputs.
        attention_scores = queries @ keys.transpose(-2, -1)  # (..., T, T)

        # Scale by sqrt(d_k) and normalize over the key axis (always the last dim).
        attention_weights = torch.softmax(attention_scores / self.d_out ** 0.5, dim=-1)  # (..., T, T)

        # Context vectors: Z_2 = (a2_1 * v_1) + (a2_2 * v_2) + .. (a2_T * v_T)
        Z = attention_weights @ values  # (..., T, d_out)
        return Z

# Same large-size smoke test with the nn.Linear version (no q/k/v bias).
selfattn = SelfAttention(4608, 4608, qkv_bias=False)
X = torch.rand((4096, 4608))
Z = selfattn(X)
print(Z.size())

torch.Size([4096, 4608])
