In [35]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import time


Simulate data and attention matrices

In [36]:
# Simulation settings
n_batch, context_length = 4, 8   # rows in the batch, tokens per row
n_embed = 10                     # embedding dimensionality
vocab_size = 40                  # number of distinct token ids

# Random integer token ids drawn from [0, vocab_size); shape [batch, tokens].
# No seed is set, so the values change on every run.
data = torch.randint(low=0, high=vocab_size, size=(n_batch, context_length))

In [37]:
# Display the raw token ids (use `data.shape` instead to see only the dimensions).
data#.shape
# Each row is one sequence of 8 token ids (context_length); the batch holds 4 such rows (n_batch).

tensor([[15, 33, 35, 35, 35, 11, 38,  2],
        [ 8,  0, 16, 33, 25, 17,  9, 36],
        [23, 12,  6, 22, 24, 17,  2,  0],
        [20, 32, 13, 11, 17,  0,  6, 15]])

In [38]:
# Token-embedding lookup table: one n_embed-dimensional row per vocabulary entry
embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embed)

# Bias-free square linear maps for the attention projections, created in the
# order key (Wk in notes), query (Wq in notes), value.
# Their weights are trainable and stay fixed once training has completed.
key, query, value = (
    nn.Linear(in_features=n_embed, out_features=n_embed, bias=False)
    for _ in range(3)
)

process the data

In [39]:
# Look up an embedding vector for every token id in the batch
x = embeddings(data)

# Weight the embedded data with the three projection matrices (pre-attention);
# k is K in the notes.
k, q, v = key(x), query(x), value(x)

In [56]:
# Shape of the key projections: [batch, seq_len, n_embed]
k.shape

torch.Size([4, 8, 10])

In [40]:
# Shape of the token embeddings: [batch, seq_len, n_embed]
x.shape

torch.Size([4, 8, 10])

In [41]:
# Shape of the embedding lookup table: [vocab_size, n_embed]
embeddings.weight.shape

torch.Size([40, 10])

In [42]:
# Shape of the key weight matrix (Wk): [n_embed, n_embed]
key.weight.shape

torch.Size([10, 10])

In [44]:
# Print the sizes of the input data, the embedding table and the embedded tokens
print(f'   Data matrix: {data.shape}') #[batch x seq_len]
print(f'Embeddingfs matrix: {embeddings.weight.shape}') #[n_vocab,n_embed]
print(f'Token embeddings: {x.shape}') #batch x seq_len x n_embed

# Sizes of the projection weight matrices
print('')
print(f'   Size of Q: {query.weight.shape}')
print(f'   Size of K: {key.weight.shape}')
print(f'   Size of V: {value.weight.shape}')

# Sizes of the projected data. Each label is paired with its own tensor:
# Q(x) reports q and K(x) reports k.
print('')
print(f'    Size of Q(x): {q.shape}') #[batch x seq_len x n_embed]
print(f'    Size of K(x): {k.shape}')
print(f'    Size of V(x): {v.shape}')

   Data matrix: torch.Size([4, 8])
Embeddingfs matrix: torch.Size([40, 10])
Token embeddings: torch.Size([4, 8, 10])

   Size of Q: torch.Size([10, 10])
   Size of K: torch.Size([10, 10])
   Size of V: torch.Size([10, 10])

    Size of Q(x): torch.Size([4, 8, 10])
    Size of K(x): torch.Size([4, 8, 10])
    Size of V(x): torch.Size([4, 8, 10])


In [57]:
# Display the full query projections (one row of n_embed values per token)
q

tensor([[[-0.6868,  0.1647,  0.3240, -1.4790, -0.3916,  0.4981,  0.4791,
          -0.5527, -0.5694, -0.6859],
         [-0.6059,  0.7051, -0.2647, -0.1895, -0.1648, -0.1711,  0.9039,
           0.1445, -0.5653, -0.4835],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [-1.2047,  1.0327,  0.2075, -0.1562, -0.4745, -0.5841,  0.7178,
           0.2405, -0.3858, -1.2063],
         [ 0.4388,  0.2682, -0.6617, -0.8138, -0.0886, -0.2311, -0.2104,
           0.1695, -0.7296, -0.6092],
         [ 0.4201, -0.1744, -0.0847,  0.5388, -1.0043, -0.6665,  0.2823,
          -0.1734,  1.6383,  0.4710]],

        [[ 0.0119,  0.0916,  0.4064, -0.8448,  0.0294,  0.2446, -0.1907,
           0.3070, -0.2472, -0.6603],

implement self-attention

In [46]:
## Manual implementation

# Raw dot-product similarity between every query and every key. (It would only
# be a true cosine similarity if the vectors were normalised to unit length.)
# Only the last two dimensions of k are swapped; the batch dimension stays first.
qk = torch.matmul(q, k.transpose(-2, -1))

# Variance-scale the scores by 1/sqrt(n_embed)
qk_scaled = torch.mul(qk, n_embed**-.5)

# Causal mask: the lower-triangular ones mark the current and past tokens
pastmask = torch.ones(n_batch, context_length, context_length).tril()
# Future positions are set to -inf (equivalent to adding a matrix of zeros/-infs)
qk_scaled = qk_scaled.masked_fill(pastmask == 0, -torch.inf)

# Softmax across the last dimension turns each row of scores into weights
qk_softmax = qk_scaled.softmax(dim=-1)

# Final attention mechanism: weighted combination of the value vectors
actsManual = torch.matmul(qk_softmax, v)

print(f'Shape of activations (manual): {actsManual.shape}')
# [batch, context, n_embed]

Shape of activations (manual): torch.Size([4, 8, 10])


In [47]:
# PyTorch's built-in attention with causal masking enabled; no extra mask
# and no dropout (both spelled out here at their default values).
actsTorch = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)
print(f'Shape of activations (PyTorch): {actsTorch.shape}')

Shape of activations (PyTorch): torch.Size([4, 8, 10])


In [49]:
# Show the first batch element from both implementations, one after the other,
# separated by a line containing a single space, for a visual comparison.
print(actsManual[0,:,:], ' ', actsTorch[0,:,:], sep='\n')

tensor([[-0.2197, -0.9711,  0.4389, -1.5539,  0.3279,  1.5416, -0.4719,  0.0030,
         -0.1869, -0.0189],
        [-0.2865, -0.8780,  0.2941, -1.2914,  0.0968,  1.3652, -0.1323, -0.3451,
         -0.0482, -0.2714],
        [-0.5426, -0.7807,  0.0061, -0.8907, -0.0743,  1.0662, -0.2795, -0.3974,
         -0.2424, -0.3955],
        [-0.7070, -0.7072, -0.1901, -0.6075, -0.2148,  0.8582, -0.3117, -0.4860,
         -0.3359, -0.5113],
        [-0.8164, -0.6582, -0.3208, -0.4190, -0.3083,  0.7197, -0.3332, -0.5450,
         -0.3981, -0.5885],
        [-1.1114, -0.5058, -0.7399,  0.1425, -0.5604,  0.3268, -0.3326, -0.7598,
         -0.5601, -0.8261],
        [-0.5738, -0.7201, -0.1763, -0.6897, -0.0685,  0.8902, -0.1385, -0.5112,
         -0.3192, -0.4192],
        [-0.5730, -0.3869, -0.4905, -0.0513, -0.4686,  0.1792, -0.1489, -0.5100,
         -0.1328, -0.1744]], grad_fn=<SliceBackward0>)
 
tensor([[-0.2197, -0.9711,  0.4389, -1.5539,  0.3279,  1.5416, -0.4719,  0.0030,
         -0.1869, 

In [None]:
# performance could be improved with optimisations, e.g.:
    # a JIT (just-in-time) compiler for F.scaled_dot_product_attention()
    # setting the floating-point precision