# Simplified Attention

In [1]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],
    ]  # step     (x^6)
)

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
query = inputs[1]  # 1
attn_scores_2 = torch.empty(inputs.shape[0])
print(f"query: {query}")
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = x_i.dot(query)  # torch.dot(x_i, query) - both are same.
    print(f"x_{i} : {x_i}, dot: {attn_scores_2[i]}")
print(attn_scores_2)

query: tensor([0.5500, 0.8700, 0.6600])
x_0 : tensor([0.4300, 0.1500, 0.8900]), dot: 0.9544000625610352
x_1 : tensor([0.5500, 0.8700, 0.6600]), dot: 1.4950001239776611
x_2 : tensor([0.5700, 0.8500, 0.6400]), dot: 1.4754000902175903
x_3 : tensor([0.2200, 0.5800, 0.3300]), dot: 0.8434000015258789
x_4 : tensor([0.7700, 0.2500, 0.1000]), dot: 0.7070000171661377
x_5 : tensor([0.0500, 0.8000, 0.5500]), dot: 1.0865000486373901
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [4]:
from llm_from_scratch.c3 import softmax_naive

attn_weights_2 = softmax_naive(attn_scores_2)
print(attn_weights_2)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [5]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [6]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = x_i.dot(x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [7]:
# inputs: 6 x 3
# (6x3) @ (3x6) = (6x6)
attn_scores = inputs @ inputs.T
print(attn_scores)

# What if I use * instead of @
# This just does element by element *
# and because dimentions do not match, does not work.
# experiment = inputs * inputs.T
# print(experiment)

# What happens if we do this the other way around?
# (3x6) @ (6x3) = (3x3)
experiment = inputs.T @ inputs
print(experiment)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[1.4561, 1.3876, 1.2876],
        [1.3876, 2.5408, 1.9081],
        [1.2876, 1.9081, 2.0587]])


In [8]:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 7])
print("> a", a)
print("> b", a)

# Stack will add another dimension
c = torch.stack((a, b))
print("> torch.stack(a, b)", c)

# .T only well defined for 2-D tensor
print("> c.T", c.T)

# .T on other dimentions tensors is not defined and will be removed
# in later releases.
# You can, however, use torch.permuate(...)
# print("> a.T", a.T) -- in older torch just returns a

# To actually turn a 1 dim row matrix to a colum matrix, you can
# do this
print("> a.unsqueeze(0)", a.unsqueeze(0))
print("> a.unsqueeze(1)", a.unsqueeze(1))
print("> a.unsqueeze(0).T", a.unsqueeze(0).T)
print("> a.unsqueeze(0).permute(0, 1)", a.unsqueeze(0).permute(0, 1))
# print("> a.unsqueeze(2)", a.unsqueeze(2)) -- this fails.
print("> a[None, :]", a[None, :])
print("> a[:, None]", a[:, None])
print("> a[:, None, None]", a[None, :, None])


print("> a*b", a * b)

# What does a.T or b.T do?

# This gives s

# What does @ do exactly?
print("> a@b", a @ b)

> a tensor([1, 2, 3])
> b tensor([1, 2, 3])
> torch.stack(a, b) tensor([[1, 2, 3],
        [4, 5, 7]])
> c.T tensor([[1, 4],
        [2, 5],
        [3, 7]])
> a.unsqueeze(0) tensor([[1, 2, 3]])
> a.unsqueeze(1) tensor([[1],
        [2],
        [3]])
> a.unsqueeze(0).T tensor([[1],
        [2],
        [3]])
> a.unsqueeze(0).permute(0, 1) tensor([[1, 2, 3]])
> a[None, :] tensor([[1, 2, 3]])
> a[:, None] tensor([[1],
        [2],
        [3]])
> a[:, None, None] tensor([[[1],
         [2],
         [3]]])
> a*b tensor([ 4, 10, 21])
> a@b tensor(35)


In [9]:
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

# * In this simple case we did (X.X^T)X
# * Where normally you do ( (QX) (KX)^T ) V X
#
# Why don't we just optimize for QK^T directory, rather than introducing
# Q and K separately?

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


# Self Attention

In [12]:
x_2 = inputs[1]  # 1
d_in = inputs.shape[1]  # 2
d_out = 2  # 3

In [13]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [25]:
print("- Discovering the @ operator")
print(x_2.shape)
print(W_query.shape)
print(W_query.data)

print(torch.tensor([1.0, 0, 0]) @ W_query)
print(torch.tensor([0.0, 1, 0]) @ W_query)
print(torch.tensor([1.0, 0, 1]) @ W_query)

print("\n- Compute Query_2")
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

- Discovering the @ operator
torch.Size([3])
torch.Size([3, 2])
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
tensor([0.2961, 0.5166])
tensor([0.2517, 0.6886])
tensor([0.3701, 1.3831])

- Compute Query_2
tensor([0.4306, 1.4551])


In [47]:
keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
print("queries.shape:", queries.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])


In [48]:
keys_2 = keys[1]  # 1
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


In [52]:
attn_scores_2 = query_2 @ keys.T # (2) @ (2 * n)       #1
print(attn_scores_2)

print(keys @ query_2) # (n * 2) * 2 This gives the same result too

# what if I try for everyone - this gives weight_ij = q_i * v_j
# which is what we want, row i column j is (when sample i queries value of j) i.e. q_i * v_j
print("All weights", queries @ keys.T)

# This does not work, unlike the 1 sample case -- 6x2 * 2x6
# print("All weights", keys @ queries)
# This works, it is just transposed.
# i.e. weights_ij = v_i * q_j
print("All weights . T", keys @ queries.T)


tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
All weights tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])
All weights tensor([[0.9231, 1.2705, 1.2544, 0.6973, 0.6114, 0.8995],
        [1.3545, 1.8524, 1.8284, 1.0167, 0.8819, 1.3165],
        [1.3241, 1.8111, 1.7877, 0.9941, 0.8626, 1.2871],
        [0.7910, 1.0795, 1.0654, 0.5925, 0.5121, 0.7682],
        [0.4032, 0.5577, 0.5508, 0.3061, 0.2707, 0.3937],
        [1.1330, 1.5440, 1.5238, 0.8475, 0.7307, 1.0996]])


In [44]:
d_k = keys.shape[-1]
print(f"d_k={d_k}")
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("attn_weights_2=", attn_weights_2)

context_vec_2 = attn_weights_2 @ values
print("context_vec_2=", context_vec_2)

d_k=2
attn_weights_2= tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
context_vec_2= tensor([0.3061, 0.8210])


In [None]:
import torch.nn as nn
import torch


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec