<a href="https://colab.research.google.com/github/rileyburns707/Shakespeare_GPT/blob/main/math_trick__for__self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

The mathematical trick in self-attention

---
This is a quick detour from the 'building_GPT' code, but it will help us understand the bigger idea moving forward.

In [None]:
# consider the following toy example:

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
print(x.shape)

# we would like the 8 tokens (up to 8 tokens per batch element) to talk to each other (couple them)
# we want to couple them in a specific way: the 5th token should not talk to future tokens like 6, 7, 8
# the easiest way for tokens to communicate is to average the current token with all the preceding ones
# that method loses a lot of info about the spatial arrangement of the tokens, but we will worry about that later

torch.Size([4, 8, 2])


In [None]:
x[0] # 0'th batch element

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [None]:
# version 1

# we want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bow = bag of words
for b in range(B): # iterating over the batch elements independently
  for t in range(T): # iterating over time
    xprev = x[b, :t+1] # (t+1,C). xprev is the chunk of tokens up to and including position t
    xbow[b,t] = torch.mean(xprev, 0) # averages over time and gives a 1D vector of size C, which you store in xbow

In [None]:
xbow[0]

# the first row is the same as in x, since only the first row was averaged
# the second row is the average of the first two rows in x
# each row is a vertical (column-wise) running average

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
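To see where the second and third rows of xbow[0] come from, here is a small check by hand. It is only a sketch: the three rows of `x0` are copied from the printed x[0] output above (rounded to four decimals), and the names `x0`, `row1` and `row2` are made up for this example.

```python
import torch

# first three rows of x[0], copied from the printed output above
x0 = torch.tensor([[ 0.1808, -0.0700],
                   [-0.3596, -0.9152],
                   [ 0.6258,  0.0255]])

row1 = x0[:2].mean(dim=0)  # average of rows 0 and 1
row2 = x0[:3].mean(dim=0)  # average of rows 0, 1 and 2

print(row1)  # roughly [-0.0894, -0.4926], matching xbow[0][1]
print(row2)  # roughly [ 0.1490, -0.3199], matching xbow[0][2]
```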

In [None]:
# version 2: using matrix multiply for a weighted aggregation

# for explanation look below for commented version
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x
print(torch.allclose(xbow, xbow2)) # expected True since the values match; it prints False here, most likely because of tiny floating-point differences and allclose's strict default tolerances

False
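The `False` above is most likely a floating-point tolerance issue rather than a real mismatch. A minimal sketch to check this, assuming the same seed and shapes as the toy example: look at the largest absolute difference, and compare again with a looser `atol` (the value `1e-5` is just a choice made for this sketch).

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# version 1: explicit loop over batch and time
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

# version 2: row-normalized lower triangular matrix multiply
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x

# largest element-wise difference between the two versions
max_abs_diff = (xbow - xbow2).abs().max().item()
print(max_abs_diff)

# same comparison with a looser absolute tolerance (assumed value)
close = torch.allclose(xbow, xbow2, atol=1e-5)
print(close)
```

If the largest difference is on the order of float32 rounding error, the two versions agree for all practical purposes.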


In [None]:
xbow, xbow2 # compare the two: the printed values are the same

(tensor([[[ 0.1808, -0.0700],
          [-0.0894, -0.4926],
          [ 0.1490, -0.3199],
          [ 0.3504, -0.2238],
          [ 0.3525,  0.0545],
          [ 0.0688, -0.0396],
          [ 0.0927, -0.0682],
          [-0.0341,  0.1332]],
 
         [[ 1.3488, -0.1396],
          [ 0.8173,  0.4127],
          [-0.1342,  0.4395],
          [ 0.2711,  0.4774],
          [ 0.2421,  0.0694],
          [ 0.0084,  0.0020],
          [ 0.0712, -0.1128],
          [ 0.2527,  0.2149]],
 
         [[-0.6631, -0.2513],
          [ 0.1735, -0.0649],
          [ 0.1685,  0.3348],
          [-0.1621,  0.1765],
          [-0.2312, -0.0436],
          [-0.1015, -0.2855],
          [-0.2593, -0.1630],
          [-0.3015, -0.2293]],
 
         [[ 1.6455, -0.8030],
          [ 1.4985, -0.5395],
          [ 0.4954,  0.3420],
          [ 1.0623, -0.1802],
          [ 1.1401, -0.4462],
          [ 1.0870, -0.4071],
          [ 1.0430, -0.1299],
          [ 1.1138, -0.1641]]]),
 tensor([[[ 0.1808, -0.0700]

This all works, but the loop is very inefficient. The trick is to use matrix multiplication.

In [None]:
# using matrix multiplication

torch.manual_seed(42)
a = torch.ones(3,3) # 3x3 matrix of all 1's
b = torch.randint(0,10, (3,2)).float() # creates a 3x2 matrix of random integers from 0 to 9, as floats
c = a @ b # matrix multiplication: c = a times b
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c=')
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [None]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3)) # creates a lower triangular 3x3 matrix of all 1's
b = torch.randint(0,10, (3,2)).float() # creates a 3x2 matrix of random integers from 0 to 9, as floats
c = a @ b # matrix multiplication: c = a times b
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


Looking at how the matrix multiplication works, each row of c is a sum of rows of b. Because a is lower triangular, row i of c sums only the first i+1 rows of b, so the number of rows being summed varies from row to row.

These are sums, but we can get averages of the rows of b instead. If you normalize the rows of a so that each row sums to 1, then each row of c becomes an average.

In [None]:
torch.manual_seed(42)

a = torch.tril(torch.ones(3,3))

a = a / torch.sum(a, 1, keepdim=True)
# torch.sum(a, 1, keepdim=True) computes the sum of the elements in each row of a (dim=1 sums across the columns of each row).
# keepdim=True keeps the summed dimension with size 1, so the result is a (3, 1) column vector that broadcasts across each row in the division.

b = torch.randint(0,10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Now each row of matrix a sums to 1, so each row of c is the average of the first few rows of b, taken column by column. For example, the entry in row 2, column 1 of c (4.0) is the average of 2 and 6. The same holds for every entry of c.
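One way to double check the averaging claim is to compare against a running sum divided by the number of rows seen so far. This is only a sketch: `torch.cumsum` is a swapped-in check that the notebook itself does not use, `b` is copied from the printed output above, and `counts` and `c_check` are names made up for this example.

```python
import torch

# b copied from the printed output above
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])

a = torch.tril(torch.ones(3, 3))
a = a / a.sum(1, keepdim=True)  # each row of a now sums to 1
c = a @ b

# alternative check: running sum of the rows of b, divided by the row count
counts = torch.arange(1, 4).unsqueeze(1)  # shape (3, 1): 1, 2, 3
c_check = torch.cumsum(b, dim=0) / counts

print(torch.allclose(c, c_check))  # True
```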

In [None]:
# This is the explanation of version 2

wei = torch.tril(torch.ones(T, T)) # plays the role of a in the example above
wei = wei / wei.sum(1, keepdim=True) # normalizes each row of the lower triangular matrix so it sums to 1
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [None]:
# explanation for version 2

xbow2 = wei @ x # (T, T) @ (B, T, C): pytorch sees there is no batch dimension in wei, so it broadcasts it to (B, T, T) @ (B, T, C)
# ----> (B, T, C). So xbow2 should match xbow
torch.allclose(xbow, xbow2) # checks that they are the same. Expected True; it returns False here, most likely due to tiny floating-point differences

False

The trick is that we can use a batched matrix multiply to do a weighted aggregation. The weights are specified by the T x T array wei. Because wei is lower triangular, the weighted sums ensure that ***each token only gets information from itself and the tokens preceding it.***
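A quick way to convince yourself that no information leaks from the future is to change only the later tokens and confirm the earlier outputs do not move. This is a sketch under assumptions: the cut at position 5 and the shift of 100.0 are arbitrary choices for illustration, and `x_changed`, `past_same` and `later_differ` are made-up names.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)

out = wei @ x  # (T, T) broadcasts against (B, T, C) -> (B, T, C)

# perturb only positions 5, 6 and 7 in every batch element
x_changed = x.clone()
x_changed[:, 5:] += 100.0
out_changed = wei @ x_changed

# outputs at positions 0..4 should not see the change
past_same = torch.allclose(out[:, :5], out_changed[:, :5])
# outputs at positions 5..7 include the changed tokens, so they should differ
later_differ = not torch.allclose(out[:, 5:], out_changed[:, 5:])

print(past_same, later_differ)  # True True
```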

In [None]:
# Version 3
# Use Softmax

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T)) # all zeros to begin with
wei = wei.masked_fill(tril == 0, float('-inf')) # for all the elements where tril == 0, set them to negative infinity
wei = F.softmax(wei, dim=-1) # softmax normalizes each row so it sums to 1
xbow3 = wei @ x
torch.allclose(xbow, xbow3) # expected True; it returns False here, most likely due to tiny floating-point differences

False

In [None]:
xbow, xbow3 # compare the two: the printed values are the same

(tensor([[[ 0.1808, -0.0700],
          [-0.0894, -0.4926],
          [ 0.1490, -0.3199],
          [ 0.3504, -0.2238],
          [ 0.3525,  0.0545],
          [ 0.0688, -0.0396],
          [ 0.0927, -0.0682],
          [-0.0341,  0.1332]],
 
         [[ 1.3488, -0.1396],
          [ 0.8173,  0.4127],
          [-0.1342,  0.4395],
          [ 0.2711,  0.4774],
          [ 0.2421,  0.0694],
          [ 0.0084,  0.0020],
          [ 0.0712, -0.1128],
          [ 0.2527,  0.2149]],
 
         [[-0.6631, -0.2513],
          [ 0.1735, -0.0649],
          [ 0.1685,  0.3348],
          [-0.1621,  0.1765],
          [-0.2312, -0.0436],
          [-0.1015, -0.2855],
          [-0.2593, -0.1630],
          [-0.3015, -0.2293]],
 
         [[ 1.6455, -0.8030],
          [ 1.4985, -0.5395],
          [ 0.4954,  0.3420],
          [ 1.0623, -0.1802],
          [ 1.1401, -0.4462],
          [ 1.0870, -0.4071],
          [ 1.0430, -0.1299],
          [ 1.1138, -0.1641]]]),
 tensor([[[ 0.1808, -0.0700]

In [None]:
# next few lines explain the code in version 3 in a more spelled out way

tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [None]:
wei = torch.zeros((T,T))
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

The weights begin at 0. They can be thought of as interaction strengths (affinities): they tell you how much of each past token you want to aggregate and average.

The lower triangular mask ensures that tokens from the future cannot communicate with the current token, by setting their weights to negative infinity.

Then we normalize with softmax and sum.
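To see why zeros and negative infinities turn into an even average, here is the softmax of a single row worked out by hand. It is a sketch only: the 5-wide `row` is a made-up example (shorter than T = 8 to keep it readable), and `exp_row`, `probs` and `reference` are names introduced here.

```python
import torch
from torch.nn import functional as F

# hypothetical row: three allowed positions (0) and two masked ones (-inf)
row = torch.tensor([0.0, 0.0, 0.0, float('-inf'), float('-inf')])

exp_row = row.exp()              # exp(0) = 1, exp(-inf) = 0
probs = exp_row / exp_row.sum()  # divide by the sum so the row adds up to 1

reference = F.softmax(row, dim=-1)  # the built-in softmax for comparison

print(probs)  # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])
print(torch.allclose(probs, reference))  # True
```

The masked positions contribute exp(-inf) = 0, so they drop out, and the allowed positions share the weight equally because they all start from the same value.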

The weights will not always be zero; they will be data dependent. The model will learn from the data how much of each past token to use, so a token can have a stronger affinity towards one past token than towards another. When we normalize and sum, we aggregate their values depending on how interesting the tokens find each other.
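As a rough illustration of non-zero affinities, the sketch below swaps the zeros for random numbers. This is an assumption made purely for demonstration: `torch.randn(T, T)` stands in for whatever data-dependent values would eventually go there, and nothing here is learned.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))

# stand-in affinities: random numbers instead of zeros (illustration only)
wei = torch.randn(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))  # future positions still masked
wei = F.softmax(wei, dim=-1)                     # each row still sums to 1

out = wei @ x  # weighted aggregation of current and past tokens, (B, T, C)

print(wei[2])     # three unequal weights followed by zeros
print(out.shape)  # torch.Size([4, 8, 2])
```

The mask and the softmax are unchanged from version 3. Only the starting values differ, so each row is now an uneven weighted average instead of a plain mean.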



General summary of this detour:
You can do weighted aggregations of your past elements by multiplying with a lower triangular matrix. The elements in the lower triangular part tell you how much of each past element is used.