## The mathematical trick in self-attention

In [62]:
import torch
import torch.nn.functional as F 

In [3]:
# Toy example.
# B = batch size : 4 independent input sequences
# T = time steps : 8 tokens per sequence
# C = channels   : each token is a 2-dimensional vector
B, T, C = (4, 8, 2)

x = torch.randn((B, T, C))  # random standard-normal input of shape (B, T, C)
x.shape


torch.Size([4, 8, 2])

In [13]:
# Batch element 0: its 8 tokens (rows) in sequence order, each a 2-dim vector.
x[0, :, :]

tensor([[-0.8017,  1.1943],
        [ 1.0161,  1.0499],
        [ 0.7962,  0.5086],
        [-0.6768, -1.9990],
        [-0.2106, -0.1355],
        [ 0.7507,  0.9637],
        [-0.3740,  0.0230],
        [ 1.7420, -0.2458]])

- So we have 8 tokens, and we want them to talk to each other.
- Here we want each token to learn only from the past: the 8 tokens (time steps) in an input are sequential.
- In other words, the input is like a sentence with 8 words.
- What we want is for the next word to be a function of its past.
- So if we are at the 5th word/token, we want it to be able to see / talk to / learn from all 4 tokens before it (as well as itself),
- but we don't want it to see the future: the 6th, 7th, and 8th tokens.

In [22]:
# The easiest way for a token to "communicate" with the past: replace it by the
# average of itself and all previous tokens. This is an extremely lossy way to
# aggregate past information (averaging ignores token order), but it's a start.

xbow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]                 # (t+1, C): tokens 0..t, inclusive
        xbow[b, t] = torch.mean(xprev, 0)  # mean over the time axis -> (C,)

In [23]:
x[0, ...]  # batch element 0 again, to compare side by side with xbow[0]

tensor([[-0.8017,  1.1943],
        [ 1.0161,  1.0499],
        [ 0.7962,  0.5086],
        [-0.6768, -1.9990],
        [-0.2106, -0.1355],
        [ 0.7507,  0.9637],
        [-0.3740,  0.0230],
        [ 1.7420, -0.2458]])

In [25]:
xbow[0]

# Row t of xbow[0] is the mean of rows 0..t of x[0] (the row itself included),
# i.e. a cumulative (running) average -- not a fixed-window moving average.
# E.g. row 1 = (x[0, 0] + x[0, 1]) / 2, and the 3rd row (3rd token/word) now
# carries information from every word up to and including itself.

tensor([[-0.8017,  1.1943],
        [ 0.1072,  1.1221],
        [ 0.3369,  0.9176],
        [ 0.0834,  0.1884],
        [ 0.0246,  0.1237],
        [ 0.1456,  0.2637],
        [ 0.0714,  0.2293],
        [ 0.2802,  0.1699]])

## Math trick to do this efficiently

- Basically, create a matrix (a weight matrix) such that when we matrix-multiply it with the `x` matrix, the result is the `xbow` matrix.
- Python for loops are slow and inefficient; a single matrix multiplication does the same work in one vectorized operation.

In [26]:
torch.ones(3, 3).tril()  # lower-triangular: ones on/below the diagonal, zeros above

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [35]:
a = torch.ones(3, 3).tril()         # (3, 3) lower-triangular ones
a_sum = a.sum(dim=1, keepdim=True)  # (3, 1): one sum per row, kept as a column
print(a.shape)
print(a_sum.shape)

torch.Size([3, 3])
torch.Size([3, 1])


In [40]:
# Why we keep the summed dimension in a_sum (keepdim=True):
# - a has shape (3, 3); a_sum has shape (3, 1), i.e. one row-sum per row.
# - When PyTorch divides (3, 3) by (3, 1) it broadcasts a_sum across the
#   columns (conceptually repeating that column 3 times) and then divides
#   element-wise, so every element of row i is divided by row i's sum.
# - Without keepdim, a_sum would have shape (3,), which broadcasts as a ROW
#   (1, 3): element (i, j) would be divided by the sum of row j instead,
#   silently giving the wrong normalisation.

print(f"a = {a}")
print(f"a_sum = {a_sum}")
print(f"a/a_sum = {a/a_sum}")


a = tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a_sum = tensor([[1.],
        [2.],
        [3.]])
a/a_sum = tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [45]:
# Example: multiplying by a row-normalised lower-triangular matrix turns a
# matmul into a running (cumulative) average over the rows of the other operand.

torch.manual_seed(42)

# a: (3, 3) weight matrix; row i puts equal weight 1/(i+1) on rows 0..i
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)

# b: (3, 2) stand-in for x -- random integers in [0, 10) cast to float32
b = torch.randint(0, 10, (3, 2)).to(torch.float32)

# matrix multiplication: row i of c = sum_j a[i, j] * b[j] = mean of b[0..i]
c = torch.matmul(a, b)

print(f"b = {b}")
print("below c metrix is a cumulative avg of all the rows above of mat b")
print(f"c = {c}")

b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
below c metrix is a cumulative avg of all the rows above of mat b
c = tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [49]:
## Putting it all together

# x holds the input: a (B, T, C) = (4, 8, 2) tensor.

# (T, T) averaging weights: row t spreads weight 1/(t+1) over tokens 0..t
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)

# (T, T) @ (B, T, C) -> (B, T, C): matmul broadcasts wei over the batch
# dimension, i.e. it performs one (T, T) @ (T, C) product per batch element,
# all in a single vectorised call instead of Python loops.
xbow2 = torch.matmul(wei, x)

torch.allclose(xbow, xbow2)  # same result as the explicit double loop

True

In [54]:
# A reminder of what xbow (and therefore xbow2, which matches it) contains:
# each row is the mean of that row and all previous rows of x.

print(f"x's first input : \n {x[0]}")
print("xbow is essentially a cumulative avg of x for each row")
print(f"xbow's first input : \n {xbow[0]}")


x's first input : 
 tensor([[-0.8017,  1.1943],
        [ 1.0161,  1.0499],
        [ 0.7962,  0.5086],
        [-0.6768, -1.9990],
        [-0.2106, -0.1355],
        [ 0.7507,  0.9637],
        [-0.3740,  0.0230],
        [ 1.7420, -0.2458]])
xbow is essentially a cumulative avg of x for each row
xbow's first input : 
 tensor([[-0.8017,  1.1943],
        [ 0.1072,  1.1221],
        [ 0.3369,  0.9176],
        [ 0.0834,  0.1884],
        [ 0.0246,  0.1237],
        [ 0.1456,  0.2637],
        [ 0.0714,  0.2293],
        [ 0.2802,  0.1699]])


In [78]:
## Version 3: use softmax

# Build the same averaging weights via a masked softmax. Start from all-zero
# "affinities" between tokens, set the future positions to -inf, then softmax:
# exp(0) = 1 for allowed tokens and exp(-inf) = 0 for masked ones, so row t
# becomes a uniform 1/(t+1) distribution over tokens 0..t. The intent is that
# these zeros later become data-dependent affinities, letting each token weight
# some past tokens more than others.
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))  # don't look into the future
wei = wei.softmax(dim=1)  # normalise each row (across its columns) to sum to 1
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)


True

In [76]:
# a quick detour on how to use dim 
''' 
- dim 0 means across the rows, and 1 means across the columns 
- so here, we think we want the softmax "across the rows", meaning take the first row, and then perform softmax on each of it's element
- now, that is actually across the columns in pytorch language 
- what is happening is that you want a softmax for the element of each of the columns in the first row
- hence, this is actually a operation column wise / across the columns, and so the dim would be 1 not 0 
'''

' \n- dim 0 means across the rows, and 1 means across the columns \n- so here, we think we want the softmax "across the rows", meaning take the first row, and then perform softmax on each of it\'s element\n- now, that is actually across the columns in pytorch language \n- what is happening is that you want a softmax for the element of each of the columns in the first row\n- hence, this is actually a operation column wise / across the columns, and so the dim would be 1 not 0 \n'