In [45]:
import torch
import torch.nn.functional as F

In [63]:
torch.manual_seed(1337)


<torch._C.Generator at 0x10dca6a70>

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- meaning that in encoder, all tokens are allowed to communicate with each other, but in a decoder only tokens in past can communicate with the current token
- encoder may be used for tasks like sentiment analysis where its okay for all tokens to interact, but in tasks like language generation since we are predicting the future token we should use decoder
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

# Mathematical Trick used in Self-Attention

In [3]:
# consider this toy example

B,T,C = 4,8,2 # batch, time, channels
x= torch.randn(B,T,C)
x.shape

# right now these 8 tokens in a batch are independent but we would want them to talk to each other
# any given token can only communicate with its previous tokens coz the ones after that are in the future and the current token should not have access to that information
# so information only flows from previous context to current time step 


torch.Size([4, 8, 2])

In [4]:
# the simiplest way to communicate with the previous token is to just take the average till the current token
# so that kind of becomes like a feature vector that summarizes the current token in context of all the previous tokens
# but the problem is that sum/avergage is a very weak form of interation(very lossy), we lose a lot of information
# specifiacally we lose a ton of information about the spatial arrangement/order of the tokens but thats okay for now, we will come back to this later

In [5]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in  range(T):
        xprev = x[b,:t+1] # upto the current token # t,c 
        xbow[b,t] = torch.mean(xprev, dim=0)


In [6]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [7]:
xbow[0]
# here 1st row is same as x[0]
# 2nd row is the average of 1st and 2nd row of x[0]
# last row is the average of all the rows of x[0]


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

## The Trick - we can be very efficient about this calculations by doing matrix multiplications instead of loops


In [8]:
torch.manual_seed(42)
a =torch.ones(3,3)
b=torch.randint(0,10,(3,2)).float()
c =a @ b
print('a=')
print(a)
print(a.shape)
print('b=')
print(b)
print(b.shape)
print('c=')
print(c)
print(c.shape)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
torch.Size([3, 3])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
torch.Size([3, 2])
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
torch.Size([3, 2])


In [9]:
torch.tril(torch.ones(3,3)) # lower triangular matrix

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [10]:
torch.manual_seed(42)
a =torch.tril(torch.ones(3,3))
b=torch.randint(0,10,(3,2)).float()
c =a @ b
print('a=')
print(a)
print(a.shape)
print('b=')
print(b)
print(b.shape)
print('c=')
print(c)
print(c.shape)

print('Note that when we use lower triangular matrix as a, at each element we are getting the sum of all the previous elements')
print('Now we just need to make sure we take the mean instead of sum')

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
torch.Size([3, 3])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
torch.Size([3, 2])
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
torch.Size([3, 2])
Note that when we use lower triangular matrix as a, at each element we are getting the sum of all the previous elements
Now we just need to make sure we take the mean instead of sum


In [12]:
torch.manual_seed(42)
a =torch.tril(torch.ones(3,3))
a = a/torch.sum(a,dim=1,keepdim=True)

b=torch.randint(0,10,(3,2)).float()
c =a @ b
print('a=')
print(a)
print(a.shape)
print('b=')
print(b)
print(b.shape)
print('c=')
print(c)
print(c.shape)

print('Now we have are getting the desired output')
print('Lets make it more efficient')


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
torch.Size([3, 3])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
torch.Size([3, 2])
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
torch.Size([3, 2])
Now we have are getting the desired output
Lets make it more efficient


In [34]:
xbow2 = torch.zeros((B,T,C))
wei =  torch.tril(torch.ones(T,T)) # weights
wei= wei/torch.sum(wei,dim=1,keepdim=True)

#for b in range(B): # this can be just writren as following without the loop
#    xbow[b] = normalizer @ x[b]
    
xbow2  = wei @ x

torch.allclose(xbow,xbow2) # torch.all_close is used to compare two tensors element wise  and return True if they are equal within some tolerance, False otherwise,
# tolerance meaning that the difference between the two tensors should be less than the tolerance value, default is 1e-5

True

In [31]:
wei


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [27]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [35]:
xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [33]:
# The trick was that we were able to do the weighted sum of all the previous tokens in one go using matrix multiplication
# the weights were stored in the lower triangular matrix

### Anoter useful way to write the same
# version3  - using softmax

In [40]:
# version3 : use softmax
tril  = torch.tril (torch.ones(T,T))
wei = torch.zeros((T,T))



In [41]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [42]:
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [46]:
wei = wei.masked_fill(tril==0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [47]:
wei = F.softmax(wei, dim=1) # because e^-inf is 0, so the softmax will be 0 for all the elements that are masked, for rest the value will become average of non inf values
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [50]:
# so we can do 
tril = torch.tril(torch.ones((T,T)))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=1) # dim=1 because we want to normalize the rows
xbow3 = wei @ x

torch.allclose(xbow2, xbow3)

True

In [51]:
# you can thin of wei as the affinity matrix, it tells us how much each token should be weighted with respect to the current token,
# these were set to zero currently but we can learn these weights while training the model


### implementing a single head of attention

In [52]:
import torch.nn as nn

In [60]:
# v4  - using self attention
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x= torch.randn(B,T,C)

# lets see a single head perform self attention
head_size = 16 #hyperparameter
key = nn.Linear(C, head_size,bias=False)

query = nn.Linear(C, head_size,bias=False)
value = nn.Linear(C, head_size,bias=False)
k = key(x) # B,T,head_size
q = query(x) # B,T,head_size

wei = q @ k.transpose(-2,-1) # (B,T,head_size) @ (B,head_size,T)--> (B,T,T) # we only want to transpose the last two dimensions
tril = torch.tril(torch.ones((T,T)))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=1) # dim=1 because we want to normalize the rows
v = value(x) 
out = wei @ v 
#out = wei @ x

out.shape

torch.Size([4, 8, 32])

In [None]:
# now wei is not constant but data dependent 

In [61]:
wei[0]

tensor([[0.0248, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0052, 0.0091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0521, 0.0135, 0.2482, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3171, 0.0214, 0.1642, 0.1188, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0412, 0.0487, 0.1046, 0.0742, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1060, 0.5347, 0.2059, 0.1030, 0.7402, 0.0192, 0.0000, 0.0000],
        [0.4298, 0.3409, 0.1769, 0.2027, 0.0480, 0.8472, 0.2329, 0.0000],
        [0.0238, 0.0316, 0.1002, 0.5013, 0.0117, 0.1336, 0.7671, 1.0000]],
       grad_fn=<SelectBackward0>)

In [64]:
out[0][0]

tensor([ 0.0045, -0.0017, -0.0089, -0.0227,  0.0155,  0.0006,  0.0237,  0.0016,
         0.0090,  0.0290, -0.0335, -0.0127,  0.0059, -0.0060, -0.0229,  0.0383,
         0.0335, -0.0035,  0.0071,  0.0240, -0.0506,  0.0123,  0.0369,  0.0147,
         0.0031, -0.0388, -0.0288, -0.0083,  0.0111, -0.0199,  0.0379,  0.0623],
       grad_fn=<SelectBackward0>)

# scaled attention


In [68]:
k  = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1) 
k.var(), q.var(), wei.var() # notice that variance of wei is order of head_size, this is because we are doing a dot product of two random vectors

(tensor(0.9182), tensor(0.9487), tensor(13.5532))

In [88]:
k  = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1) * head_size**-0.5
k.var(), q.var(), wei.var() # notice that variance of wei is now near 1, this is because we are scaling the dot product by head_size**-0.5

(tensor(1.0104), tensor(1.0204), tensor(1.1053))

### Why is the scaling of wei important?
- Notice the wei will feed into the softmax
- So its important, especially at intilisation that wei be fairly diffused
- If wei takes on very positive or negative values, then softmax will converge towards one hot vectors
- If that happens, it would mean that every node is aggregating information from only one node, which is not what we want(atleast at initialisation)

In [87]:
a = torch.softmax(torch.tensor([-1.0,0.2,0.3,0.75]), dim=-1) # fairly diffused softmax
b = torch.softmax(torch.tensor([-1.0,0.2,0.3,0.75])*8, dim=-1,) # sharpening towards the max value
print(a)
print(b)

tensor([0.0728, 0.2416, 0.2670, 0.4187])
tensor([7.9985e-07, 1.1810e-02, 2.6283e-02, 9.6191e-01])


In [5]:
a = [1,2,3,4,5]
a[-2:]

[4, 5]

In [6]:
# multi headed attention is just multiple attentions done in parallel and then concatenating their outputs

## Residual/SKip connections
- As we use multiheaded self attention, our network starts becoming deep
- And these deep networks often run into optimisation problems
- One solution for this is using skip/residual connections