# Simplified Attention

In [1]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],
    ]  # step     (x^6)
)

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
query = inputs[1]  # 1
attn_scores_2 = torch.empty(inputs.shape[0])
print(f"query: {query}")
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = x_i.dot(query)  # torch.dot(x_i, query) - both are same.
    print(f"x_{i} : {x_i}, dot: {attn_scores_2[i]}")
print(attn_scores_2)

query: tensor([0.5500, 0.8700, 0.6600])
x_0 : tensor([0.4300, 0.1500, 0.8900]), dot: 0.9544000625610352
x_1 : tensor([0.5500, 0.8700, 0.6600]), dot: 1.4950001239776611
x_2 : tensor([0.5700, 0.8500, 0.6400]), dot: 1.4754000902175903
x_3 : tensor([0.2200, 0.5800, 0.3300]), dot: 0.8434000015258789
x_4 : tensor([0.7700, 0.2500, 0.1000]), dot: 0.7070000171661377
x_5 : tensor([0.0500, 0.8000, 0.5500]), dot: 1.0865000486373901
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [4]:
from llm_from_scratch.c3 import softmax_naive

attn_weights_2 = softmax_naive(attn_scores_2)
print(attn_weights_2)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [5]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [6]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = x_i.dot(x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [7]:
# inputs: 6 x 3
# (6x3) @ (3x6) = (6x6)
attn_scores = inputs @ inputs.T
print(attn_scores)

# What if I use * instead of @
# This just does element by element *
# and because dimentions do not match, does not work.
# experiment = inputs * inputs.T
# print(experiment)

# What happens if we do this the other way around?
# (3x6) @ (6x3) = (3x3)
experiment = inputs.T @ inputs
print(experiment)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[1.4561, 1.3876, 1.2876],
        [1.3876, 2.5408, 1.9081],
        [1.2876, 1.9081, 2.0587]])


In [8]:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 7])
print("> a", a)
print("> b", a)

# Stack will add another dimension
c = torch.stack((a, b))
print("> torch.stack(a, b)", c)

# .T only well defined for 2-D tensor
print("> c.T", c.T)

# .T on other dimentions tensors is not defined and will be removed
# in later releases.
# You can, however, use torch.permuate(...)
# print("> a.T", a.T) -- in older torch just returns a

# To actually turn a 1 dim row matrix to a colum matrix, you can
# do this
print("> a.unsqueeze(0)", a.unsqueeze(0))
print("> a.unsqueeze(1)", a.unsqueeze(1))
print("> a.unsqueeze(0).T", a.unsqueeze(0).T)
print("> a.unsqueeze(0).permute(0, 1)", a.unsqueeze(0).permute(0, 1))
# print("> a.unsqueeze(2)", a.unsqueeze(2)) -- this fails.
print("> a[None, :]", a[None, :])
print("> a[:, None]", a[:, None])
print("> a[:, None, None]", a[None, :, None])


print("> a*b", a * b)

# What does a.T or b.T do?

# This gives s

# What does @ do exactly?
print("> a@b", a @ b)

> a tensor([1, 2, 3])
> b tensor([1, 2, 3])
> torch.stack(a, b) tensor([[1, 2, 3],
        [4, 5, 7]])
> c.T tensor([[1, 4],
        [2, 5],
        [3, 7]])
> a.unsqueeze(0) tensor([[1, 2, 3]])
> a.unsqueeze(1) tensor([[1],
        [2],
        [3]])
> a.unsqueeze(0).T tensor([[1],
        [2],
        [3]])
> a.unsqueeze(0).permute(0, 1) tensor([[1, 2, 3]])
> a[None, :] tensor([[1, 2, 3]])
> a[:, None] tensor([[1],
        [2],
        [3]])
> a[:, None, None] tensor([[[1],
         [2],
         [3]]])
> a*b tensor([ 4, 10, 21])
> a@b tensor(35)


In [9]:
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

# * In this simple case we did (X.X^T)X
# * Where normally you do ( (QX) (KX)^T ) V X
#
# Why don't we just optimize for QK^T directory, rather than introducing
# Q and K separately?

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


# Self Attention

In [11]:
x_2 = inputs[1]  # 1
d_in = inputs.shape[1]  # 2
d_out = 2  # 3

In [12]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [13]:
print("- Discovering the @ operator")
print(x_2.shape)
print(W_query.shape)
print(W_query.data)

print(torch.tensor([1.0, 0, 0]) @ W_query)
print(torch.tensor([0.0, 1, 0]) @ W_query)
print(torch.tensor([1.0, 0, 1]) @ W_query)

print("\n- Compute Query_2")
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

- Discovering the @ operator
torch.Size([3])
torch.Size([3, 2])
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
tensor([0.2961, 0.5166])
tensor([0.2517, 0.6886])
tensor([0.3701, 1.3831])

- Compute Query_2
tensor([0.4306, 1.4551])


In [14]:
keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
print("queries.shape:", queries.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])


In [15]:
keys_2 = keys[1]  # 1
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


In [16]:
attn_scores_2 = query_2 @ keys.T  # (2) @ (2 * n)       #1
print(attn_scores_2)

print(keys @ query_2)  # (n * 2) * 2 This gives the same result too

# what if I try for everyone - this gives weight_ij = q_i * v_j
# which is what we want, row i column j is (when sample i queries value of j) i.e. q_i * v_j
print("All weights", queries @ keys.T)

# This does not work, unlike the 1 sample case -- 6x2 * 2x6
# print("All weights", keys @ queries)
# This works, it is just transposed.
# i.e. weights_ij = v_i * q_j
print("All weights . T", keys @ queries.T)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
All weights tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])
All weights . T tensor([[0.9231, 1.2705, 1.2544, 0.6973, 0.6114, 0.8995],
        [1.3545, 1.8524, 1.8284, 1.0167, 0.8819, 1.3165],
        [1.3241, 1.8111, 1.7877, 0.9941, 0.8626, 1.2871],
        [0.7910, 1.0795, 1.0654, 0.5925, 0.5121, 0.7682],
        [0.4032, 0.5577, 0.5508, 0.3061, 0.2707, 0.3937],
        [1.1330, 1.5440, 1.5238, 0.8475, 0.7307, 1.0996]])


In [17]:
d_k = keys.shape[-1]
print(f"d_k={d_k}")
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("attn_weights_2=", attn_weights_2)

context_vec_2 = attn_weights_2 @ values
print("context_vec_2=", context_vec_2)

d_k=2
attn_weights_2= tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
context_vec_2= tensor([0.3061, 0.8210])


In [18]:
import torch.nn as nn
import torch


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [19]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [20]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False) -> None:
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [21]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Masked (causal) attention

In [22]:
queries = sa_v2.W_query(inputs)  # 1
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [24]:
context_length = attn_scores.shape[0]
ones = torch.ones(context_length, context_length)
mask_simple = torch.tril(ones)
print(ones)
print(mask_simple)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [30]:
masked_simple = attn_weights * mask_simple  # elementwise mult.
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [42]:
# keepdim = True is going to keep row_sums n * 1 tensor instead of n
# if you don't do that masked_simple / row_sums will divide each col
# by the sum instead of row. That is weird (?)

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [47]:
print(torch.triu(torch.ones(context_length, context_length), diagonal=0))
print(torch.triu(torch.ones(context_length, context_length), diagonal=1))
print(torch.triu(torch.ones(context_length, context_length), diagonal=2))

tensor([[1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.]])
tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])


In [48]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [49]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


**A. Dropout**

In [56]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # 1
example = torch.ones(6, 6)  # 2
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


**B. Final formula**

In [57]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [85]:
import torch

# Create a 3D tensor of shape [2, 3, 4]
t = torch.ones(2, 3, 4)
print("Original tensor shape:", t.shape)

# Create a 2D mask of shape [3, 4]
# This is smaller than the tensor but can be broadcast
mask = torch.zeros(3, 4, dtype=torch.bool)
mask[1, :] = True  # Make the middle row True
print("Mask shape:", mask.shape)
print("Mask:", mask)

# Apply the mask - it will be broadcast across the first dimension
result = t.masked_fill(mask, -2)
print("\nResult tensor:")
print(result)

Original tensor shape: torch.Size([2, 3, 4])
Mask shape: torch.Size([3, 4])
Mask: tensor([[False, False, False, False],
        [ True,  True,  True,  True],
        [False, False, False, False]])

Result tensor:
tensor([[[ 1.,  1.,  1.,  1.],
         [-2., -2., -2., -2.],
         [ 1.,  1.,  1.,  1.]],

        [[ 1.,  1.,  1.,  1.],
         [-2., -2., -2., -2.],
         [ 1.,  1.,  1.,  1.]]])


In [72]:
class CausalAttention(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # 1
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )  # 2
        self.mask: torch.Tensor  # typehint for register_buffer(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, d_in = x.shape  # 3
        _, context_length = self.mask.shape
        assert num_tokens <= context_length
        keys: torch.Tensor = self.W_key(x)
        queries: torch.Tensor = self.W_query(x)
        values: torch.Tensor = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        # num_tokens is used because the input might be smaller
        # than total context_len
        # Note: masked fill does broadcast, so dims don't have to match.
        # print("Shape of attn_scores:", attn_scores.shape)
        attn_scores.masked_fill_(  # 4
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [73]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

Shape of attn_scores: torch.Size([2, 6, 6])
context_vecs.shape: torch.Size([2, 6, 2])


# Multi-head attention

**A. Basic implementation with for**

In [66]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [67]:
torch.manual_seed(123)
context_length = batch.shape[1]  # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)

context_vecs.shape: torch.Size([2, 6, 4])
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


**B. torch.view**

In [91]:
import torch

# Create a tensor
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print("Original tensor:", x)
print("Original shape:", x.shape)

Original tensor: tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Original shape: torch.Size([12])


In [92]:
# Reshape to 3x4 matrix
x_reshaped = x.view(3, 4)
print("\nReshaped to 3x4:")
print(x_reshaped)
print("New shape:", x_reshaped.shape)


Reshaped to 3x4:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
New shape: torch.Size([3, 4])


In [88]:
# Reshape to 2x2x3
x_3d = x.view(2, 2, 3)
print("\nReshaped to 2x2x3:")
print(x_3d)
print("New shape:", x_3d.shape)


Reshaped to 2x2x3:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
New shape: torch.Size([2, 2, 3])


In [101]:
# Use -1 to automatically infer one dimension
x_auto = x.view(2, -1, 3)  # PyTorch will calculate the first dimension
print("\nReshaped with automatic dimension:")
print(x_auto)
print("New shape:", x_auto.shape)


Reshaped with automatic dimension:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
New shape: torch.Size([2, 2, 3])


In [90]:
# Flatten a tensor
x_flat = x_3d.view(-1)  # Flatten to 1D
print("\nFlattened tensor:")
print(x_flat)
print("New shape:", x_flat.shape)


Flattened tensor:
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
New shape: torch.Size([12])


In [106]:
# contiguous a tensor
x_contig = x_3d.contiguous()
print("contiguous tensor:")
print(x_contig)
print("New shape:", x_flat.shape)

contiguous tensor:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
New shape: torch.Size([12])


**C. How nn.Linear changes the dimension**

In [102]:
import torch
import torch.nn as nn

# Create a simple linear layer
in_features = 3
out_features = 5
linear = nn.Linear(in_features, out_features)

# Print the shapes of the weights and bias
print(
    f"Weight shape: {linear.weight.shape}"
)  # Should be (out_features, in_features) = (5, 3)
print(f"Bias shape: {linear.bias.shape}")  # Should be (out_features) = (5)

# Example 1: Single vector input
x1 = torch.randn(in_features)  # Shape: (3)
print(f"\nInput x1 shape: {x1.shape}")

y1 = linear(x1)  # Apply linear transformation
print(f"Output y1 shape: {y1.shape}")  # Should be (5)

# Example 2: Batch of vectors
x2 = torch.randn(10, in_features)  # Shape: (10, 3) - batch of 10 vectors
print(f"\nInput x2 shape: {x2.shape}")

y2 = linear(x2)  # Apply linear transformation
print(f"Output y2 shape: {y2.shape}")  # Should be (10, 5)

# Example 3: Multi-dimensional input
x3 = torch.randn(4, 6, in_features)  # Shape: (4, 6, 3) - 4 batches of 6 vectors each
print(f"\nInput x3 shape: {x3.shape}")

y3 = linear(x3)  # Apply linear transformation
print(f"Output y3 shape: {y3.shape}")  # Should be (4, 6, 5)

# Let's manually verify the transformation for a single vector
manual_result = x1 @ linear.weight.T + linear.bias
print(f"\nManual calculation shape: {manual_result.shape}")  # Should be (5)
print(f"Are results equal? {torch.allclose(y1, manual_result)}")

Weight shape: torch.Size([5, 3])
Bias shape: torch.Size([5])

Input x1 shape: torch.Size([3])
Output y1 shape: torch.Size([5])

Input x2 shape: torch.Size([10, 3])
Output y2 shape: torch.Size([10, 5])

Input x3 shape: torch.Size([4, 6, 3])
Output y3 shape: torch.Size([4, 6, 5])

Manual calculation shape: torch.Size([5])
Are results equal? True


**D. Vectorized implementation**

In [108]:
# Same as Andrei-Karpathi's video
class MultiHeadAttention(nn.Module):
    """
    #1 Reduces the projection dim to match the desired output dim
    #2 Uses a Linear layer to combine head outputs
    #3 Tensor shape: (b, num_tokens, d_out)
    #4 We implicitly split the matrix by adding a num_heads dimension. Then we unroll the last dim: (b, num_tokens, d_out) -&gt; (b, num_tokens, num_heads, head_dim).
    #5 Transposes from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
    #6 Computes dot product for each head
    #7 Masks truncated to the number of tokens
    #8 Uses the mask to fill attention scores
    #9 Tensor shape: (b, num_tokens, n_heads, head_dim)
    #10 Combines heads, where self.d_out = self.num_heads * self.head_dim
    #11 Adds an optional linear projection”
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # 1
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # 2
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)  # 3
        queries = self.W_query(x)  # 3
        values = self.W_value(x)  # 3

        # b * num_tokens * d_out -> b * num_tokens * n_heads * head_dim
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)  # 4
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # b * num_tokens * n_heads * head_dim -> b * n_heads * num_tokens * head_dim
        keys = keys.transpose(1, 2)  # 5
        queries = queries.transpose(1, 2)  # 5
        values = values.transpose(1, 2)  # 5

        # queries: b * n_heads * num_tokens * head_dim
        # keys.T :   b * n_heads * head_dim * num_tokens
        # =>
        # attn_scores: b * n_heads * num_tokens * num_tokens
        attn_scores = queries @ keys.transpose(2, 3)  # 6

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]  # 7
        attn_scores.masked_fill_(mask_bool, -torch.inf)  # 8

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # attn_weights: b * n_heads * num_tokens * num_tokens
        # values      : b * n_heads * num_tokens * head_dim
        # =>
        # context_vec  : b * n_heads * num_tokens * head_dim
        # =>
        # context_vec  : b * num_tokens * n_heads * head_dim
        context_vec = (attn_weights @ values).transpose(1, 2)  # 9

        # contiguous arranges the memory in a good shape again
        # required by view.
        # before: b * num_tokens * n_heads * head_dim
        # after:  b * num_tokens * d_out
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)  # 10
        # in karpathy's video, he also added a dropout here.
        context_vec = self.out_proj(context_vec)  # 11
        return context_vec

In [123]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
num_heads = 2
d_out = 2  # must be divisible by num_heads
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=num_heads)
context_vecs = mha(batch)
print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)

context_vecs.shape: torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
