## Self-attention: basic implementation

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
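wei = wei.masked_fill(tril == 0, float('-inf')) # hide future positions from each token
wei = F.softmax(wei, dim=-1) # each row becomes weights that sum to 1
v = value(x) # (B, T, 16)
out = wei @ v # (B, T, T) @ (B, T, 16) ---> (B, T, 16)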
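
As a quick sanity check on the cell above, the sketch below rebuilds the same head and tests three properties: every row of `wei` sums to 1, every entry above the diagonal is exactly 0 (no token attends to a future position), and the head output has shape `(B, T, head_size)`. The names `out`, `rows_sum_to_one`, and `future_is_masked` are introduced here for illustration only.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
head_size = 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k, q, v = key(x), query(x), value(x)   # each (B, T, head_size)
wei = q @ k.transpose(-2, -1)          # (B, T, T) raw affinities
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v                          # (B, T, head_size), illustrative name

# Property checks (hypothetical helper names)
rows_sum_to_one = torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
future_is_masked = bool((wei.triu(diagonal=1) == 0).all())
print(rows_sum_to_one, future_is_masked, out.shape)
```

If the mask and softmax are applied in this order, both checks should hold for any input, since `exp(-inf)` is exactly 0 and softmax normalizes each row.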

### Matrix A (Lower Triangular, Shape: 3x3)
Matrix A is a lower triangular matrix, and each row sums to 1:

$$
A = \begin{bmatrix}
1 & 0 & 0 \\
\frac{1}{2} & \frac{1}{2} & 0 \\
\frac{1}{3} & \frac{1}{3} & \frac{1}{3} \\
\end{bmatrix}
$$
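
A matrix like A can be built in at least two ways, and the sketch below checks that they agree. The first normalizes the rows of a lower-triangular matrix of ones. The second applies the same `masked_fill` plus `softmax` steps used in the attention code to an all-zero score matrix. The names `A_norm`, `A_soft`, and `scores` are illustrative.

```python
import torch
from torch.nn import functional as F

T = 3
tril = torch.tril(torch.ones(T, T))

# Way 1: divide each row of the lower-triangular ones matrix by its row sum.
A_norm = tril / tril.sum(dim=1, keepdim=True)

# Way 2: start from equal (zero) scores, set future positions to -inf,
# then softmax each row. Equal scores give equal weights.
scores = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
A_soft = F.softmax(scores, dim=-1)

same = torch.allclose(A_norm, A_soft)
print(A_norm)
print(same)
```

The second construction is the one that generalizes: once the scores are no longer all zero (for example `q @ k.transpose(-2, -1)`), the rows still sum to 1 but the weights stop being uniform.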

### Matrix B (Shape: 3x2)
$$
B = \begin{bmatrix}
b_{11} & b_{12} \\
b_{21} & b_{22} \\
b_{31} & b_{32} \\
\end{bmatrix}
$$

### Matrix Multiplication (Resulting Shape: 3x2)
The resulting matrix C is calculated as follows:

$$
C = A \times B = \begin{bmatrix}
c_{11} & c_{12} \\
c_{21} & c_{22} \\
c_{31} & c_{32} \\
\end{bmatrix}
$$

Where:

$$
c_{11} = 1 \cdot b_{11} + 0 \cdot b_{21} + 0 \cdot b_{31} = b_{11}
$$

$$
c_{12} = 1 \cdot b_{12} + 0 \cdot b_{22} + 0 \cdot b_{32} = b_{12}
$$

$$
c_{21} = \frac{1}{2} \cdot b_{11} + \frac{1}{2} \cdot b_{21} + 0 \cdot b_{31} = \frac{1}{2}(b_{11} + b_{21})
$$

$$
c_{22} = \frac{1}{2} \cdot b_{12} + \frac{1}{2} \cdot b_{22} + 0 \cdot b_{32} = \frac{1}{2}(b_{12} + b_{22})
$$

$$
c_{31} = \frac{1}{3} \cdot b_{11} + \frac{1}{3} \cdot b_{21} + \frac{1}{3} \cdot b_{31} = \frac{1}{3}(b_{11} + b_{21} + b_{31})
$$

$$
c_{32} = \frac{1}{3} \cdot b_{12} + \frac{1}{3} \cdot b_{22} + \frac{1}{3} \cdot b_{32} = \frac{1}{3}(b_{12} + b_{22} + b_{32})
$$


Each row of A says how much of each token (row) of B goes into the corresponding output row.  
The first row takes 100% of the first token and 0% of the others.  
Each output entry is computed by multiplying one channel (embedding dimension) of every token by the weights in that row of A and summing the products.  
So the product works channel by channel (dimension by dimension): it takes the required amount from each token and sums those amounts within that dimension.  
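
To make the channel-by-channel picture concrete, here is a small sketch with made-up numbers for B (3 tokens, 2 channels). It compares `A @ B` against an explicit running mean over the tokens seen so far. The values in `B` and the name `running_mean` are assumptions chosen for illustration.

```python
import torch

# A: lower-triangular averaging weights, each row sums to 1
A = torch.tril(torch.ones(3, 3))
A = A / A.sum(dim=1, keepdim=True)

# B: 3 tokens (rows) with 2 channels (columns); values are arbitrary
B = torch.tensor([[2.0, 7.0],
                  [6.0, 4.0],
                  [6.0, 5.0]])

C = A @ B  # (3, 3) @ (3, 2) -> (3, 2)

# The same result computed explicitly: row t is the mean of tokens 0..t,
# taken separately in each channel.
running_mean = torch.stack([B[: t + 1].mean(dim=0) for t in range(3)])

print(C)
print(torch.allclose(C, running_mean))
```

Row 1 of `C`, for instance, is `[(2 + 6) / 2, (7 + 4) / 2] = [4.0, 5.5]`, which matches the formulas for $c_{21}$ and $c_{22}$ above.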