In [2]:
import torch

### Self-Attention

In [3]:
# Seed the global RNG so the random toy input below is reproducible.
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # B = batch size, T = time steps (sequence length), C = channels per token
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 32])

Language Modelling Principle (Token Communication):

For example, the token in the 5th position should not be able to communicate with the tokens in the 6th, 7th or 8th positions, because they are future tokens.

It should only be able to communicate with itself and the tokens before it.

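We can do this efficiently with a matrix multiplication by a lower-triangular matrix: row $i$ of `tril(ones) @ b` is the sum of rows $0..i$ of `b`, so each position only aggregates itself and earlier positions (see the example below).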

In [4]:
torch.manual_seed(1337)
# Lower-triangular ones: row i of `a @ b` is the sum of rows 0..i of `b`.
# (Replace with torch.ones(3, 3) to see every row sum over all of `b` instead.)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, size=(3, 2), dtype=torch.float)
c = torch.matmul(a, b)

for label, tensor in (('a=', a), ('b=', b), ('c=', c)):
    print(label)
    print(tensor)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])
c=
tensor([[ 5.,  7.],
        [ 7.,  7.],
        [12., 10.]])


In [5]:
# Version 1: causal averaging with a row-normalized lower-triangular matrix.
wei = torch.tril(torch.ones(T, T))
wei = wei.div(wei.sum(dim=1, keepdim=True))  # row t now holds 1/(t+1) in columns 0..t
# matmul broadcasts the (T, T) matrix over the batch: (T,T) @ (B,T,C) -> (B,T,T) @ (B,T,C) -> (B,T,C)
out_1 = torch.matmul(wei, x)
out_1.shape

torch.Size([4, 8, 32])

In [6]:
# Version 2: the same causal average, built with softmax over masked scores.
# Starting from all-zero scores, future positions are set to -inf; softmax then
# gives exp(0) = 1 on allowed positions and exp(-inf) = 0 on future ones, so
# row t normalizes to 1/(t+1) over columns 0..t -- exactly Version 1's weights.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
out_2 = wei @ x
# Sanity check: both formulations must produce the same output.
assert torch.allclose(out_1, out_2), "Version 2 should reproduce Version 1"
out_2.shape

torch.Size([4, 8, 32])

- Every token at every position now emits three vectors: a Query, a Key and a Value.
- Query: "What am I looking for?"
- Key: "What do I contain?"
- Value: "What will I communicate?"
- The dot product of a token's query with the keys of the tokens it may attend to gives the attention scores (affinities): how strongly that token should attend to each of them. The scores are scaled by $1/\sqrt{d_k}$ (here `head_size`), future positions are masked, and softmax turns each row into weights that sum to 1.
- Finally, each token's output is the attention-weighted sum of the values: $\mathrm{out} = \mathrm{softmax}\!\left(QK^\top / \sqrt{d_k}\right) V$.

In [None]:
# Let's now implement a single head of self-attention.
# Seed before creating the Linear layers: their weights are randomly initialized,
# so without this the attention weights change on every run.
torch.manual_seed(1337)

head_size = 16
key = torch.nn.Linear(C, head_size, bias = False)
query = torch.nn.Linear(C, head_size, bias = False)
value = torch.nn.Linear(C, head_size, bias = False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# Scaled dot-product scores. If q and k have roughly unit variance, q.k has variance
# ~head_size; dividing by sqrt(head_size) brings it back to ~1 so softmax does not
# saturate into near one-hot rows.
wei = q @ k.transpose(-2,-1) / head_size**0.5  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)

# Causal mask: position t may only attend to positions 0..t (the (T, T) mask broadcasts over B).
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

In [11]:
# Attention weights of the first sequence in the batch: each row sums to 1 and the
# masked (future) positions above the diagonal are exactly 0.
first_example_weights = wei[0]
first_example_weights

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.1821e-01, 2.8179e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [6.2292e-01, 2.6785e-01, 1.0923e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.7394e-02, 3.5937e-02, 8.8566e-02, 8.4810e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [9.2227e-01, 1.0381e-02, 4.4360e-02, 7.6919e-04, 2.2219e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.4951e-02, 1.1837e-01, 2.4863e-01, 3.7787e-01, 6.3060e-03, 1.7387e-01,
         0.0000e+00, 0.0000e+00],
        [1.8730e-01, 5.8106e-02, 6.1382e-02, 3.6453e-03, 6.4791e-01, 1.8433e-02,
         2.3230e-02, 0.0000e+00],
        [4.1080e-01, 6.0570e-02, 2.1063e-02, 1.6063e-03, 1.6883e-01, 1.5380e-02,
         4.0297e-03, 3.1772e-01]], grad_fn=<SelectBackward0>)