In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's random number generator once, up front, so that the random
# tensors (torch.rand / torch.randint) and the nn.Linear weight initialisations
# further down produce the same values on every "Restart Kernel -> Run All".
SEED = 1337
torch.manual_seed(SEED)

print(device)

cuda


#### Weight Aggregation

In [2]:
# Toy input tensor of shape (B, T, C):
#   B - batch size,
#   T - time, i.e. the number of tokens in the context,
#   C - channels, i.e. the size of the embedding of each token.
# Goal: let every token "talk" to the tokens that come before it, but never to
# later ones (token 2 may use tokens 0 and 1, but not token 3 and so forth).
B, T, C = 2, 5, 3
xb = torch.rand(B, T, C)

In [3]:
# Causal running mean, computed the slow and explicit way:
# xbow[b, t] is the average of xb[b, 0], ..., xb[b, t].
xbow = torch.zeros(B, T, C)
for batch_idx in range(B):
    for last in range(T):
        history = xb[batch_idx, : last + 1]            # (last + 1, C): tokens 0..last
        xbow[batch_idx, last] = history.mean(dim=0)    # average over time -> (C,)

In [4]:
xb

tensor([[[0.9192, 0.3751, 0.0884],
         [0.5822, 0.2083, 0.2330],
         [0.6818, 0.9236, 0.1016],
         [0.6826, 0.6645, 0.8907],
         [0.0498, 0.1298, 0.0327]],

        [[0.4716, 0.9773, 0.8686],
         [0.0369, 0.8572, 0.6969],
         [0.9450, 0.1676, 0.3098],
         [0.2177, 0.3863, 0.9981],
         [0.6362, 0.6365, 0.7730]]])

In [5]:
xbow

tensor([[[0.9192, 0.3751, 0.0884],
         [0.7507, 0.2917, 0.1607],
         [0.7278, 0.5023, 0.1410],
         [0.7165, 0.5429, 0.3284],
         [0.5831, 0.4603, 0.2693]],

        [[0.4716, 0.9773, 0.8686],
         [0.2543, 0.9172, 0.7828],
         [0.4845, 0.6674, 0.6251],
         [0.4178, 0.5971, 0.7184],
         [0.4615, 0.6050, 0.7293]]])

#### Demo Matrix Multiplication for Comms

In [6]:
# Full sum: multiplying by an all-ones matrix makes every row of c equal to the
# column-wise sum of b.
a = torch.ones((3, 3))
b = torch.randint(low=0, high=10, size=(3, 2)).float()   # integers in [0, 10), stored as floats
c = torch.matmul(a, b)

In [7]:
a

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [8]:
b

tensor([[2., 7.],
        [4., 8.],
        [0., 9.]])

In [9]:
c

tensor([[ 6., 24.],
        [ 6., 24.],
        [ 6., 24.]])

In [10]:
# Masking to prevent access to future tokens: with a lower-triangular matrix of
# ones, row i of c only sums rows 0..i of b (a running sum over the rows of b).
a = torch.ones(3, 3).tril()
c = torch.matmul(a, b)

In [11]:
a

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [12]:
b

tensor([[2., 7.],
        [4., 8.],
        [0., 9.]])

In [13]:
c

tensor([[ 2.,  7.],
        [ 6., 15.],
        [ 6., 24.]])

In [14]:
# Average: rescale every row of the lower-triangular mask so that it sums to 1;
# row i of c is then the mean of rows 0..i of b.
a = torch.ones(3, 3).tril()
row_totals = a.sum(dim=1, keepdim=True)   # (3, 1): number of visible rows per position
a = a / row_totals
c = torch.matmul(a, b)

In [15]:
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [16]:
b

tensor([[2., 7.],
        [4., 8.],
        [0., 9.]])

In [17]:
c

tensor([[2.0000, 7.0000],
        [3.0000, 7.5000],
        [2.0000, 8.0000]])

#### Change1

In [24]:
# The running mean of the explicit loop, written as a single matrix product.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)   # each row sums to 1: uniform weights over positions 0..t
# wei is (T, T); the matrix product broadcasts it over the batch dimension of xb:
# (T, T) @ (B, T, C)  ->  (B, T, T) @ (B, T, C)  ->  (B, T, C)
xbow2 = torch.matmul(wei, xb)
torch.allclose(xbow, xbow2)

True

In [22]:
xb

tensor([[[0.9192, 0.3751, 0.0884],
         [0.5822, 0.2083, 0.2330],
         [0.6818, 0.9236, 0.1016],
         [0.6826, 0.6645, 0.8907],
         [0.0498, 0.1298, 0.0327]],

        [[0.4716, 0.9773, 0.8686],
         [0.0369, 0.8572, 0.6969],
         [0.9450, 0.1676, 0.3098],
         [0.2177, 0.3863, 0.9981],
         [0.6362, 0.6365, 0.7730]]])

In [23]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

In [21]:
xbow2

tensor([[[0.9192, 0.3751, 0.0884],
         [0.7507, 0.2917, 0.1607],
         [0.7278, 0.5023, 0.1410],
         [0.7165, 0.5429, 0.3284],
         [0.5831, 0.4603, 0.2693]],

        [[0.4716, 0.9773, 0.8686],
         [0.2543, 0.9172, 0.7828],
         [0.4845, 0.6674, 0.6251],
         [0.4178, 0.5971, 0.7184],
         [0.4615, 0.6050, 0.7293]]])

#### Change2

In [35]:
# The same averaging weights, this time produced by softmax: positions that must
# stay hidden get a score of -inf (softmax maps them to 0), while the visible
# positions all share the score 0 and therefore receive equal weight.
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
xbow3 = torch.matmul(wei, xb)
torch.allclose(xbow, xbow3)

True

In [28]:
tril

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [36]:
# Raw scores before softmax: 0 where a position is visible, -inf where it is masked.
hidden = tril == 0
torch.zeros(T, T).masked_fill(hidden, float('-inf'))

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [37]:
# Softmax over each row of the masked scores: -inf entries become 0 and the
# remaining entries of a row share the weight equally.
masked_scores = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
torch.softmax(masked_scores, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

#### Single Head of Attention

In [None]:
# Input for the attention head: B sequences of T tokens, C channels per token.
B, T, C = 2, 5, 3
xb = torch.rand(size=(B, T, C))
xb.shape

torch.Size([2, 5, 3])

In [None]:
# Every token position emits two vectors of length head_size:
#   query - what this token is looking for,
#   key   - what this token contains.
# The dot product of one token's query with the keys of all tokens gives that
# token's row of the wei matrix.
head_size = 16

# Both projections are linear maps from C channels to head_size, without a bias term.
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)

In [61]:
# Project every token to its key and its query vector; each result is (B, T, 16).
k, q = key(xb), query(xb)

In [62]:
# Shapes of the input and of its key / query projections.
tuple(tensor.shape for tensor in (xb, k, q))

(torch.Size([2, 5, 3]), torch.Size([2, 5, 16]), torch.Size([2, 5, 16]))

In [63]:
# Attention scores: the dot product of every query with every key.
# k.transpose(-2, -1) swaps the last two dimensions, (B, T, 16) -> (B, 16, T), so
# (B, T, 16) @ (B, 16, T) -> (B, T, T).
#
# The scores are multiplied by 1 / sqrt(head_size) (scaled dot-product attention).
# Without this factor the magnitude of the scores grows with the head size, which
# pushes the softmax applied afterwards towards one-hot weights.
# head_size is read from k's last dimension, so this cell does not depend on the
# head_size variable still holding the value the projections were built with.
scale = k.shape[-1] ** -0.5
wei = (q @ k.transpose(-2, -1)) * scale
wei.shape       # (B, T, T)

torch.Size([2, 5, 5])

In [64]:
# Swapping the last two dimensions of k: (B, T, 16) -> (B, 16, T).
k.transpose(-1, -2).shape

torch.Size([2, 16, 5])

In [65]:
# Causal masking followed by normalisation: scores of future positions are set to
# -inf so that softmax assigns them weight 0, and every row of wei sums to 1.
# NOTE: this cell overwrites wei. Running it a second time without recomputing the
# raw scores first applies the mask and the softmax to already-normalised weights.
tril = torch.ones(T, T).tril()
wei = torch.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

In [66]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5280, 0.4720, 0.0000, 0.0000, 0.0000],
        [0.3104, 0.3080, 0.3816, 0.0000, 0.0000],
        [0.2757, 0.2216, 0.2749, 0.2279, 0.0000],
        [0.1672, 0.2185, 0.2319, 0.2172, 0.1652]], grad_fn=<SelectBackward0>)

In [53]:
# Output: every position becomes the wei-weighted average of the raw tokens it
# is allowed to see.
out = torch.matmul(wei, xb)   # (B, T, T) @ (B, T, C) -> (B, T, C)
out.shape

torch.Size([2, 5, 3])

In [55]:
out

tensor([[[0.0516, 0.7512, 0.2985],
         [0.3308, 0.4479, 0.2216],
         [0.4993, 0.4662, 0.1734],
         [0.5183, 0.5189, 0.2925],
         [0.5848, 0.5365, 0.4767]],

        [[0.3966, 0.5235, 0.4951],
         [0.2870, 0.4561, 0.6321],
         [0.2080, 0.3668, 0.6084],
         [0.2775, 0.3504, 0.5227],
         [0.3824, 0.3633, 0.4628]]], grad_fn=<UnsafeViewBackward0>)

In [68]:
# Aggregate a learned projection of each token (its "value") instead of the raw
# token: a linear map from C channels to head_size, without a bias term.
value = nn.Linear(in_features=C, out_features=head_size, bias=False)
v = value(xb)                 # (B, T, head_size)
out = torch.matmul(wei, v)    # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out.shape

torch.Size([2, 5, 16])