In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


#### Weight Aggregation

In [31]:
#We have a Tensor (B,T,C) where B is Batch, T is time (no of words in context), and C is (channel) embedding for each token
#We want to have one Token talk to other token but only previous to it. (Like token 2 can talk to 0,1 and not to 3 and so forth)
B = 2
T = 5
C = 3
xb = torch.rand((B,T,C))

In [33]:
xb
#Shape Type
# [[                        #Channel1
#   [Word1 Embedding]
#   [Word2 Embedding]
#   [Word3 Embedding]
#  ]
#  [                        #Channel2
#   [Word1 Embedding]
#   [Word2 Embedding]
#   [Word3 Embedding]
#  ]

tensor([[[0.2357, 0.6186, 0.3991],
         [0.1029, 0.5228, 0.8372],
         [0.9153, 0.9635, 0.1719],
         [0.0969, 0.2464, 0.5394],
         [0.3521, 0.2202, 0.7020]],

        [[0.7796, 0.2707, 0.5217],
         [0.1820, 0.5898, 0.4388],
         [0.4741, 0.3602, 0.0515],
         [0.1895, 0.8833, 0.1996],
         [0.2332, 0.0990, 0.9172]]])

In [None]:
xbow = torch.zeros((B,T,C))
for b in range(B):              #for each batch 
    for t in range(T):          #for each word
        xprev = xb[b,:t+1]  #(t,C)
        xbow[b,t] = torch.mean(xprev,0)  #around dim = 0 meaning xbow[b,t] = (C)

In [None]:
xbow
#Shape Type
# [ [ Word1 New Embedding]
#   [ Word2 New Embedding = Word1 + Word2 / 2]
#   [ Word3 New Embedding = Word1 + Word2 + Word3 / 3] 
#...
#  ]

# By using mean as calculation we are creating new embedidng where each word embedding is dependent on its values and previous words.

tensor([[[0.2357, 0.6186, 0.3991],
         [0.1693, 0.5707, 0.6182],
         [0.4180, 0.7016, 0.4694],
         [0.3377, 0.5878, 0.4869],
         [0.3406, 0.5143, 0.5299]],

        [[0.7796, 0.2707, 0.5217],
         [0.4808, 0.4303, 0.4803],
         [0.4786, 0.4069, 0.3373],
         [0.4063, 0.5260, 0.3029],
         [0.3717, 0.4406, 0.4258]]])

#### Demo Matrix Multiplication for Communication between words

In [None]:
#We are now exploring the same thing with multiplication as calculative step.

# a is just a helping matrix
# b is xb                       (input)
# c is xbow                     (output)    
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b 

In [8]:
a

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [9]:
b

tensor([[9., 9.],
        [9., 5.],
        [2., 5.]])

In [10]:
c

tensor([[20., 19.],
        [20., 19.],
        [20., 19.]])

In [None]:
#Problem:
#   1. Our current is effected by future words.
#Solutin:
#   1. Masking for preventing future token access (chaning the a matrix to lowe triangular matrix)
a = torch.tril(torch.ones(3,3))
c = a @ b

In [12]:
a

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [13]:
b

tensor([[9., 9.],
        [9., 5.],
        [2., 5.]])

In [14]:
c

tensor([[ 9.,  9.],
        [18., 14.],
        [20., 19.]])

In [None]:
#Still we are summing the values of all the tokens. We need to change it to mean.
#Change a matrix again;
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a,1,keepdim=True)
c = a @ b

In [16]:
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [17]:
b

tensor([[9., 9.],
        [9., 5.],
        [2., 5.]])

In [18]:
c

tensor([[9.0000, 9.0000],
        [9.0000, 7.0000],
        [6.6667, 6.3333]])

#### Method 1

In [None]:
# We can achieve the same thing with a single matrix multiplication (without using for loop) by using broadcasting.

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ xb    #(T,T) -> (B,T,T) Batch Multiplication    # (B,T,T) @ (B,T,C) -> (B,T,C)

#torch.allclose check is two tensor are matching or not (element wise) (works with floating type as it check with some tolerance)
torch.allclose(xbow,xbow2)

True

In [20]:
xb

tensor([[[0.5008, 0.8865, 0.6845],
         [0.1078, 0.5549, 0.4455],
         [0.5259, 0.5692, 0.6997],
         [0.5621, 0.8024, 0.0466],
         [0.7508, 0.4546, 0.5460]],

        [[0.7871, 0.9321, 0.0516],
         [0.1905, 0.2732, 0.1106],
         [0.4148, 0.5582, 0.4033],
         [0.4960, 0.7807, 0.5988],
         [0.8282, 0.4706, 0.7160]]])

In [21]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

In [22]:
xbow2

tensor([[[0.5008, 0.8865, 0.6845],
         [0.3043, 0.7207, 0.5650],
         [0.3782, 0.6702, 0.6099],
         [0.4242, 0.7033, 0.4691],
         [0.4895, 0.6535, 0.4845]],

        [[0.7871, 0.9321, 0.0516],
         [0.4888, 0.6026, 0.0811],
         [0.4642, 0.5878, 0.1885],
         [0.4721, 0.6361, 0.2911],
         [0.5433, 0.6030, 0.3761]]])

#### Method 2

In [23]:
tril = torch.tril(torch.ones((T,T)))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei,dim=-1)
xbow3 = wei @ xb
torch.allclose(xbow,xbow3)

True

In [24]:
tril

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [25]:
torch.zeros((T,T)).masked_fill(tril == 0, float('-inf'))

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [26]:
F.softmax(torch.zeros((T,T)).masked_fill(tril == 0, float('-inf')),dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

#### Single Head of Attention

In [27]:
#Input
B = 2
T = 5 
C = 3
xb = torch.rand((B,T,C))
xb.shape

torch.Size([2, 5, 3])

In [None]:
head_size = 16
#Every token at each position will give one key vector and one query vector and one value vector
#Query (what are u looking for)
#Key (what do i contain)

#my query dot product with all other key to get wei matrix
key = nn.Linear(C,head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

In [61]:
k = key(xb)     #(B,T,16)
q = query(xb)   #(B,T,16)

In [62]:
xb.shape, k.shape, q.shape

(torch.Size([2, 5, 3]), torch.Size([2, 5, 16]), torch.Size([2, 5, 16]))

In [63]:
wei = q @ k.transpose(-2,-1)        #Transpose dim = -1 with -2 (T,16) -> (16,T)
wei.shape       #(B,T,T)

torch.Size([2, 5, 5])

In [64]:
k.transpose(-2,-1).shape

torch.Size([2, 16, 5])

In [65]:
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril== 0, float('-inf'))
wei = F.softmax(wei,dim=-1)

In [66]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5280, 0.4720, 0.0000, 0.0000, 0.0000],
        [0.3104, 0.3080, 0.3816, 0.0000, 0.0000],
        [0.2757, 0.2216, 0.2749, 0.2279, 0.0000],
        [0.1672, 0.2185, 0.2319, 0.2172, 0.1652]], grad_fn=<SelectBackward0>)

In [53]:
#Output
out = wei @ xb
out.shape

torch.Size([2, 5, 3])

In [55]:
out

tensor([[[0.0516, 0.7512, 0.2985],
         [0.3308, 0.4479, 0.2216],
         [0.4993, 0.4662, 0.1734],
         [0.5183, 0.5189, 0.2925],
         [0.5848, 0.5365, 0.4767]],

        [[0.3966, 0.5235, 0.4951],
         [0.2870, 0.4561, 0.6321],
         [0.2080, 0.3668, 0.6084],
         [0.2775, 0.3504, 0.5227],
         [0.3824, 0.3633, 0.4628]]], grad_fn=<UnsafeViewBackward0>)

In [68]:
value = nn.Linear(C, head_size, bias = False)
v = value(xb)
out = wei @ v
out.shape

torch.Size([2, 5, 16])