### Mathematical trick that is used in self attention inside a transformer.

In [28]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [81]:
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C) # c is the information at each point in the sequence
x.shape

torch.Size([4, 8, 2])

##### Coupling the tokens so that they talk to each other

Token at location 5 should not talk to the one in 6th, 7th and 8th as they are the future tokens in the sequence but talk to the one in 4th, 3rd, 2nd and 1st in the sequence.

Information just flows from the previous context to the current time step and we can't get any info from the future because we are trying to predict them.

Easiest way to do this (to communicate with the past), one could do a average of all of the preceding elements. So for a 5th token, we could do an average of all the channels of the preceding elements. This would be a feature vector that summarizes me in the context of my history.

This kind of interaction is extremely lossy since we have lost a ton of information about the spatial arrangements of all the tokens.

In [87]:
# we want x[b, t] = mean_{i<=t} x[b, i]
xbow = torch.zeros((B, T, C)) # there is a word that is stored on every one of those eight locations, so we are just averaging them kind of like a bag of words.
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C), previous chunk of tokens from my current sequence.
        xbow[b, t] = torch.mean(xprev, 0) # averaging over the 0th dimension, time.

In [88]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [89]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

Vertical average = Averaging up all the tokens now give this outcome

[-0.0894, -0.4926] is the average of [ 0.1808, -0.0700] <br>
[ 0.1490, -0.3199] is the average of [ 0.1808, -0.0700], [-0.0894, -0.4926],<br>
 and goes on

This is very very inefficient.

#### Making it efficient via matrix multiplication

##### Understanding of matrix multiplication

In [33]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b 
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)
print('---')

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
---


a is a 3 x 3 of ones <br>
b is a 3 x 2 of random numbers within 0 to 10 range <br>
c is 3x3 multiplied by 3x2 outputting a 3 x 2 <br>


In [34]:
torch.tril(torch.ones(3,3)) # creates a lower triangular array of ones.

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [35]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0, 10, (3,2)).float()
c = a @ b 
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)
print('---')

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
---


What happens here is , 2 in c is [1, 0 , 0] times [2, 6, 6]. since the elements in the a [1., 0., 0.] has 0s, it is ignored and 2 in b is taken. so basically the c took out the b since others where 0.

But when we look at the 2nd row of c we see that we have the sum of elements in the column of b, 2 + 6 results in 8 which is present at the 2nd row of c.

So depending upon the ones and zeroes in the a, we have the elements summed up in the c.

In similar way we could do the average by normalising the rows in a such that it sums up to 1.

In [41]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a/ torch.sum(a, 1, keepdim= True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b 
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)
print('---')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
---


[0.5000, 0.5000, 0.0000] sums to 1.<br>
Now on doing a a multiplied by b, we get the first row same as the first row of a <br>

But on moving to the 2nd row, we get the average of the first 2 rows of b <br>
[2., 7.], [6., 4.] averages to [4.0000, 5.5000]<br>

Finally average of all of the elements of b are deposited at the bottom of c.

##### Using matrix multiplication

In [84]:
# version 2: using matrix multiplication
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)

In [85]:
wei
# how much every row, we want to average up and each row sums up to 1

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [86]:
xbow2 = wei @ x # (T,T) @ (B, T, C)
xbow2

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

since both doesn't have the same dimensions, torch will do a batched matrix multiplication by inserting a B, batch dimension to wei <br>
(T,T) @ (B, T, C) will be (B, T, T) @ (B, T, C)

and exactly as we done above, for every batch dimension, we will have a T @ C

Finally, we will get output as (B, T, C)

xbow2 will be identical to xbow

In [90]:
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [91]:
xbow2

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [92]:
torch.allclose(xbow, xbow2, rtol=1e-4, atol=1e-5)
# rtol=1e-4 sets the relative tolerance, which allows for small proportional differences.
# atol=1e-5 sets the absolute tolerance, allowing small absolute differences.


True

#### Using softmax

In [69]:
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [70]:
wei = torch.zeros((T, T))
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [71]:
wei = wei.masked_fill(tril == 0, float('-inf')) # all elements where tril is 0, we get -inf
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

Here, we would exponentiate every single one of them and divide by the sum.

exponentiating [0., -inf, -inf, -inf, -inf, -inf, -inf, -inf] will result in 1 for 0, and all 0s for -inf.

and on normalising by dividing by sum we get 1 itself.

In [73]:
# version 3: use softmax 
wei = F.softmax(wei, dim = -1) # softmax over every single row, 
# softmax is also a normalisation operation and we get the exact same matrix
wei

tensor([[0.2797, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029],
        [0.1773, 0.1773, 0.1076, 0.1076, 0.1076, 0.1076, 0.1076, 0.1076],
        [0.1519, 0.1519, 0.1519, 0.1089, 0.1089, 0.1089, 0.1089, 0.1089],
        [0.1405, 0.1405, 0.1405, 0.1405, 0.1095, 0.1095, 0.1095, 0.1095],
        [0.1341, 0.1341, 0.1341, 0.1341, 0.1341, 0.1098, 0.1098, 0.1098],
        [0.1300, 0.1300, 0.1300, 0.1300, 0.1300, 0.1300, 0.1100, 0.1100],
        [0.1271, 0.1271, 0.1271, 0.1271, 0.1271, 0.1271, 0.1271, 0.1102],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

wei = torch.zeros((T, T)) this wei starts off wth 0 and finally we get a value. We can look at it as the output of the interaction strength or affinity. This actually tells how many elements from paste we need to aggregate.

wei = wei.masked_fill(tril == 0, float('-inf')), tokens from the past cannot communicate. By setting to infinity, we say we dont aggregate anything from these tokens of the past.


wei = torch.zeros((T, T)) are currently just set by us to be zero. Affinity between the tokens wont be just constants at 0 but data dependent. Token will look into other token and some tokens would feel the other tokens more or less interesting. Depending upon their values, they are going to find each other interesting to different ammounts could be called as affinities.

wei = wei.masked_fill(tril == 0, float('-inf')), future cannot communicate with the past. We are going to clamp them and then when we normalise and sum, we are going to aggregate sort of their values depending upon how itneresting they find each other,

In [78]:
xbow3 = wei @ x
xbow3[0]

tensor([[ 0.0039,  0.0973],
        [-0.0418,  0.0459],
        [-0.0104,  0.0747],
        [ 0.0137,  0.0888],
        [ 0.0129,  0.1237],
        [-0.0217,  0.1125],
        [-0.0190,  0.1094],
        [-0.0341,  0.1332]])

In [97]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C) # c is the information at each point in the sequence
x.shape
# we want x[b, t] = mean_{i<=t} x[b, i]
xbow = torch.zeros((B, T, C)) # there is a word that is stored on every one of those eight locations, so we are just averaging them kind of like a bag of words.
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C), previous chunk of tokens from my current sequence.
        xbow[b, t] = torch.mean(xprev, 0) # averaging over the 0th dimension, time.

xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [98]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C) # c is the information at each point in the sequence
x.shape

# version 2: using matrix multiplication
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (T,T) @ (B, T, C)
xbow2


tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [100]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C) # c is the information at each point in the sequence
x.shape
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # all elements where tril is 0, we get -inf
# version 3: use softmax 
wei = F.softmax(wei, dim = -1) # softmax over every single row, 
# softmax is also a normalisation operation and we get the exact same matrix
xbow3 = wei @ x
xbow3


tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [109]:
torch.allclose(xbow2, xbow3, rtol=1e-4, atol=1e-5)
# rtol=1e-4 sets the relative tolerance, which allows for small proportional differences.
# atol=1e-5 sets the absolute tolerance, allowing small absolute differences.


True