In [1]:
# Self Attention block

In [2]:
# Mathematical trick in self attention
# at the heart of an efficient implementation of self attention

In [3]:
import torch
import torch.nn.functional as F

In [4]:
torch.manual_seed(1337)
# B = batch size, T = time steps (positions in the sequence), C = channels (features per token).
# "Time" because, given the tokens seen so far, the next token is predicted as we move along the sequence.
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [5]:
# we'd like these 8 tokens in the time dimension to talk to each other
# in particular, we have a specific config in mind
# token at Nth place should only talk to tokens < Nth place, cannot talk to future tokens since they're supposed to be predicted
# 5th token should only talk to 4th, 3rd, 2nd and 1st
# information only flows from previous context to the current one, cannot get information from future context

In [6]:
# the simplest way for tokens to communicate is to average the current token with all of the preceding tokens
# the average could be the feature vector of me in the context of my history
# just sum or average is extremely weak form of interaction - extremely lossy - but okay for now

In [7]:
# for every single batch element independently
# for every Tth token, we'd like to calculate the average of all the vectors in all the previous tokens

In [8]:
# version 1 : nested for loop select previous and mean

In [9]:
# Goal: xbow[b, t] = mean of x[b, i] over all i <= t (the token itself plus its history)
xbow = torch.zeros(B, T, C)  # "bow" = bag of words: the term used when word vectors are simply averaged
for batch_idx in range(B):
    for step in range(T):
        history = x[batch_idx, : step + 1]  # (step + 1, C): rows 0..step inclusive
        xbow[batch_idx, step] = history.mean(dim=0)  # average over the time axis -> (C,)

In [10]:
xbow.shape  # allocated as (B, T, C), the same shape as x

torch.Size([4, 8, 2])

In [11]:
x[0]  # first batch element: one row per time step, one column per channel

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [12]:
xbow[0]  # running averages for the first batch element (row t averages rows 0..t of x[0])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [13]:
# Notice that each row of xbow[0] is the average of the corresponding row of x[0] and all the rows above it

In [14]:
# spot check: row 4 of xbow[0] is the mean of the first five rows of x[0]
# (exact == is fine here because both sides come from the same mean over the same slice)
x[0, :5].mean(dim=0) == xbow[0, 4]

tensor([True, True])

In [15]:
# version 2 : lower triangular normalised over rows matrix multiply

In [16]:
# we can be very efficient about this using matrix multiplication
# lets look at a toy example

In [17]:
torch.manual_seed(42)
a = torch.ones((3, 3))  # all-ones weights: every output row is the sum of ALL rows of b
b = torch.randint(0, 10, (3, 2)).to(torch.float32)  # small random integers, stored as floats
c = torch.matmul(a, b)
print(f'a=\n{a}\n---')
print(f'b=\n{b}\n---')
print(f'c=\n{c}\n---')

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
---


In [18]:
# the idea: mask out the upper triangle of the matrix
# torch has a function called tril that returns the lower triangular matrix

In [19]:
torch.tril(torch.ones(3, 3))  # keeps the diagonal and everything below it, zeros above

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [20]:
# lets check what happens when we do that

In [21]:
torch.manual_seed(42)
a = torch.ones((3, 3)).tril()  # lower-triangular ones: row i selects rows 0..i of b
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)  # row i of c is the running sum of b up to and including row i
print(f'a=\n{a}\n---')
print(f'b=\n{b}\n---')
print(f'c=\n{c}\n---')

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
---


In [22]:
# notice how each row of c is now the sum of the rows of b up to and including that row
# currently we are doing sums, but we could do an average as well

In [23]:
# if we normalize the elements of the lower triangular matrix to sum to 1.0, then the matrix product will be an average of previous rows

In [24]:
torch.tril(torch.ones(3, 3))  # un-normalised mask: row i contains i + 1 ones

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [25]:
# per-row count of ones; keepdim=True keeps a (3, 1) column so it can divide the (3, 3) mask row by row
torch.tril(torch.ones(3, 3)).sum(dim=1, keepdim=True)

tensor([[1.],
        [2.],
        [3.]])

In [26]:
a = torch.ones((3, 3)).tril()
a = a.div(a.sum(dim=1, keepdim=True))  # divide each row by its number of ones so every row sums to 1
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [27]:
torch.manual_seed(42)
a = torch.ones((3, 3)).tril()
a = a.div(a.sum(dim=1, keepdim=True))  # rows now hold uniform weights over positions 0..i
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)  # row i of c is the average of rows 0..i of b
print(f'a=\n{a}\n---')
print(f'b=\n{b}\n---')
print(f'c=\n{c}\n---')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
---


In [28]:
# by manipulating the elements of the multiplying matrix we can compute the averages in an incremental fashion

In [29]:
# lets take the same nested for loop example above and vectorize it to make it more efficient

In [30]:
# (T, T) averaging weights: row t holds 1/(t+1) for positions 0..t and 0 for positions after t
wei = torch.ones((T, T)).tril()
wei = wei.div(wei.sum(dim=1, keepdim=True))
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [31]:
# (T, T) @ (B, T, C): wei is broadcast across the batch dimension, i.e. (B, T, T) @ (B, T, C),
# and the batched matrix multiply yields (B, T, C)
xbow2 = torch.matmul(wei, x)

In [32]:
torch.allclose(xbow, xbow2)  # the loop version and the matrix-multiply version agree within floating-point tolerance

True

In [33]:
xbow[0], xbow2[0]  # both results for the first batch element, for a visual comparison

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [34]:
# version 3 : using softmax

In [35]:
tril = torch.ones((T, T)).tril()
# start from all-zero scores; torch.ones(T, T) would work just as well, because
# equal elements always softmax to equal weights
wei = torch.zeros((T, T))

In [36]:
# put -inf wherever tril is 0, i.e. at every position above the diagonal
wei = wei.masked_fill(tril.eq(0), float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [37]:
# interesting thing with softmax is that it effectively drops the positions filled with -inf (exp(-inf) = 0)
# and normalises the remaining elements;
# since the lower-triangular entries are all equal (0 here, but any constant would give the same result),
# they get soft-maxed to equal weights that sum to 1 in each row

In [38]:
# normalise each row (dim=1) of the masked scores into weights that sum to 1.
# NOTE: this overwrites wei, so re-running this cell on its own would apply softmax
# to the already-normalised weights; re-run the cells that build wei first.
wei = wei.softmax(dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [39]:
# (T, T) weights broadcast over the (B, T, C) batch -> (B, T, C)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow3, xbow)

True

In [40]:
# in softmax we exponentiate every element, and then divide by their sum
# exp(0) = 1, exp(-inf) = 0, then we normalise the sums

In [41]:
# The reason the softmax version is more interesting (and why it will be used later):
# the weights begin with zero, think of it as an affinity
# how much of token from the past do we want to aggregate for the future
# tokens from future cannot participate for the aggregation
# the weights are going to be data dependent and learnable even if we start out at zero
#
# token will start interacting, find other tokens more or less interesting
# depending on the values of the weights, the interest will be proportional, let's call that affinities
# with a normalise and matrix mult of the affinities we will find the resulting interest values
# that's the preview for self attention

In [42]:
# Long story short
# You can do weighted aggregation of past elements
# by using matrix multiplication
# with a lower triangular weight matrix,
# whose elements tell you how much each past element contributes to a given position