In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import time


Simulate data and attention matrices

In [2]:
# parameters
seed = 0            # fixed RNG seed so a fresh "Restart & Run All" reproduces the same numbers
n_batch = 4         # number of sequences in the batch
n_embed = 10        # embedding dimensionality
context_length = 8  # number of tokens per sequence
vocab_size = 40     # number of distinct token ids

# Seed the global RNG before the first random tensor is created; without this
# the token ids drawn below differ on every kernel restart.
torch.manual_seed(seed)

# input data: random token ids drawn from [0, vocab_size)
data  = torch.randint(vocab_size,(n_batch,context_length))
# [batch,tokens]

In [3]:
data  # append .shape to show only the dimensions
# one batch holding n_batch = 4 sequences (rows), each made of
# context_length = 8 token ids (columns) -> shape [batch, tokens]

tensor([[ 7, 20, 22,  7, 36, 14, 27, 33],
        [16, 38, 38,  7,  9, 15, 27, 20],
        [11,  8, 24, 27,  8, 10, 37, 32],
        [ 9, 32, 29, 28, 39, 21, 32,  7]])

In [4]:
# token-embedding lookup table: one n_embed-dimensional row per vocabulary entry
embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embed)

# Bias-free linear projections used by attention, created in the order
# key (Wk in notes), query (Wq in notes), value.
key, query, value = (
    nn.Linear(in_features=n_embed, out_features=n_embed, bias=False)
    for _ in range(3)
)
# these weights are trainable, and static once training has completed

In [12]:
# size of the embedding table: [vocab_size, n_embed]
embeddings.weight.size()

torch.Size([40, 10])

process the data

In [13]:
# look up the embedding vector of every token id in the batch
x = embeddings(data)

# weight the embedded data with the three projections, before any attention
# is applied (k is K in notes)
k, q, v = (proj(x) for proj in (key, query, value))

In [14]:
# projected keys: [n_batch, context_length, n_embed]
k.size()

torch.Size([4, 8, 10])

In [15]:
# token embeddings: [n_batch, context_length, n_embed]
x.size()

torch.Size([4, 8, 10])

In [16]:
# the embedding table itself stays [vocab_size, n_embed]
embeddings.weight.size()

torch.Size([40, 10])

In [17]:
# key projection weights (Wk): [n_embed, n_embed]
key.weight.size()

torch.Size([10, 10])

In [18]:
# print data sizes
print(f'   Data matrix: {data.shape}') #[batch x seq_len]
print(f'Embeddingfs matrix: {embeddings.weight.shape}') #[n_vocab,n_embed]
print(f'Token embeddings: {x.shape}') #batch x seq_len x n_embed

# sizes of the projection weight matrices
print('')
print(f'   Size of Q: {query.weight.shape}')
print(f'   Size of K: {key.weight.shape}')
print(f'   Size of V: {value.weight.shape}')

# sizes of the projected data. The Q(x)/K(x) labels previously printed the
# other tensor's shape (k under "Q(x)", q under "K(x)"); each label now
# reports its own tensor.
print('')
print(f'    Size of Q(x): {q.shape}') #[batch x seq_len x n_embed]
print(f'    Size of K(x): {k.shape}')
print(f'    Size of V(x): {v.shape}')

   Data matrix: torch.Size([4, 8])
Embeddingfs matrix: torch.Size([40, 10])
Token embeddings: torch.Size([4, 8, 10])

   Size of Q: torch.Size([10, 10])
   Size of K: torch.Size([10, 10])
   Size of V: torch.Size([10, 10])

    Size of Q(x): torch.Size([4, 8, 10])
    Size of K(x): torch.Size([4, 8, 10])
    Size of V(x): torch.Size([4, 8, 10])


In [57]:
# inspect the projected queries, q = query(x)
q

tensor([[[-0.6868,  0.1647,  0.3240, -1.4790, -0.3916,  0.4981,  0.4791,
          -0.5527, -0.5694, -0.6859],
         [-0.6059,  0.7051, -0.2647, -0.1895, -0.1648, -0.1711,  0.9039,
           0.1445, -0.5653, -0.4835],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [ 0.3330, -0.3519, -0.7846,  0.2211, -0.8384, -1.0135,  0.7510,
          -0.1089,  1.2854,  0.4599],
         [-1.2047,  1.0327,  0.2075, -0.1562, -0.4745, -0.5841,  0.7178,
           0.2405, -0.3858, -1.2063],
         [ 0.4388,  0.2682, -0.6617, -0.8138, -0.0886, -0.2311, -0.2104,
           0.1695, -0.7296, -0.6092],
         [ 0.4201, -0.1744, -0.0847,  0.5388, -1.0043, -0.6665,  0.2823,
          -0.1734,  1.6383,  0.4710]],

        [[ 0.0119,  0.0916,  0.4064, -0.8448,  0.0294,  0.2446, -0.1907,
           0.3070, -0.2472, -0.6603],

implement self-attention

In [19]:
## manual implementation
# raw attention scores: dot product of every query with every key.
# Only the last two (non-batch) dimensions of k are swapped; the leading
# batch dimension stays in place.
qk = torch.matmul(q, k.transpose(-1, -2))

# variance-scale the scores by 1/sqrt(n_embed)
qk_scaled = qk * n_embed**-.5

# causal mask: ones on and below the diagonal mark the tokens each position
# is allowed to see; everything above the diagonal is a future token
pastmask = torch.ones(n_batch, context_length, context_length).tril()
# overwrite future positions with -inf in place (equivalent to adding a
# matrix of zeros/-infs), so they receive zero weight after the softmax
qk_scaled.masked_fill_(pastmask == 0, -torch.inf)

# softmax over the key dimension
qk_softmax = torch.softmax(qk_scaled, dim=-1)

# final attention output: softmax-weighted combination of the values
actsManual = torch.matmul(qk_softmax, v)

print(f'Shape of activations (manual): {actsManual.shape}')
# [batch, context, n_embed]

Shape of activations (manual): torch.Size([4, 8, 10])


In [23]:
# inspect the attention weights: softmax was taken over the last dimension and
# future positions were masked to -inf beforehand
qk_softmax

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5499, 0.4501, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3357, 0.4158, 0.2486, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2666, 0.2825, 0.1844, 0.2666, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2691, 0.2073, 0.1128, 0.2691, 0.1417, 0.0000, 0.0000, 0.0000],
         [0.1475, 0.1535, 0.1946, 0.1475, 0.1740, 0.1829, 0.0000, 0.0000],
         [0.1491, 0.1075, 0.1857, 0.1491, 0.1567, 0.1373, 0.1144, 0.0000],
         [0.1643, 0.0751, 0.1812, 0.1643, 0.1567, 0.1243, 0.0813, 0.0529]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5429, 0.4571, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3726, 0.3137, 0.3137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3594, 0.1988, 0.1988, 0.2430, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1794, 0.2273, 0.2273, 0.1740, 0.1919, 0.0000, 0.0000, 0.0000],
         [0.2009, 0.148

In [31]:
# Hand check of actsManual[0][1][0], using numbers copied from the displayed
# qk_softmax and v tensors of this run:
attn_w0, attn_w1 = 0.5499, 0.4501   # qk_softmax[0][1][0], qk_softmax[0][1][1]
val_0, val_1 = 0.6730, -0.1883      # v[0][0][0], v[0][1][0]
attn_w0 * val_0 + attn_w1 * val_1

0.28532887000000007

In [30]:
# inspect the projected values, v = value(x)
v

tensor([[[ 0.6730,  0.1252,  0.6332,  0.0487, -0.3828,  0.0085, -0.3135,
           0.4310,  0.9521,  0.0350],
         [-0.1883, -0.3793,  0.6623,  0.5946,  0.2775, -1.3248, -0.2373,
          -0.6922,  0.2907,  0.8717],
         [ 0.3188, -0.1964,  0.4477,  1.0626,  0.9831, -0.1714, -0.7947,
           0.1711,  0.5835, -0.5877],
         [ 0.6730,  0.1252,  0.6332,  0.0487, -0.3828,  0.0085, -0.3135,
           0.4310,  0.9521,  0.0350],
         [ 0.2426, -0.5918, -0.6923,  0.4786,  0.0748,  0.1482,  0.0862,
           1.7254,  0.3556, -1.0381],
         [ 0.0975,  0.0820, -0.6348, -0.2166, -0.6738,  0.8520,  0.5396,
           0.6224, -0.0365, -0.5814],
         [ 0.2975, -0.0358, -0.3771, -0.3353, -0.3761,  0.5109,  0.3797,
           0.5869, -0.4366, -0.4776],
         [-0.7766,  0.2731,  0.2990,  0.0730,  0.3169, -0.3460,  0.3567,
          -0.3694, -0.5198, -0.0664]],

        [[-0.8386,  0.2882, -0.3978, -0.4237, -0.1172, -0.5148,  0.4298,
           0.5018, -0.1479, -0.3020],

In [27]:
# inspect the manual attention output (qk_softmax @ v)
actsManual

tensor([[[ 6.7305e-01,  1.2515e-01,  6.3322e-01,  4.8683e-02, -3.8284e-01,
           8.5107e-03, -3.1349e-01,  4.3095e-01,  9.5213e-01,  3.4964e-02],
         [ 2.8538e-01, -1.0186e-01,  6.4632e-01,  2.9439e-01, -8.5667e-02,
          -5.9159e-01, -2.7919e-01, -7.4521e-02,  6.5443e-01,  4.1156e-01],
         [ 2.2687e-01, -1.6449e-01,  5.9921e-01,  5.2769e-01,  2.3120e-01,
          -5.9056e-01, -4.0142e-01, -1.0059e-01,  5.8549e-01,  2.2809e-01],
         [ 3.6439e-01, -7.6630e-02,  6.0724e-01,  3.8985e-01,  5.5531e-02,
          -4.0133e-01, -3.8069e-01,  6.5758e-02,  6.9730e-01,  1.5655e-01],
         [ 3.9353e-01, -1.1729e-01,  4.3049e-01,  3.3716e-01, -2.7025e-02,
          -2.6838e-01, -2.9534e-01,  3.5227e-01,  6.8890e-01, -1.3894e-02],
         [ 2.9165e-01, -1.4752e-01,  1.3899e-01,  3.5609e-01,  1.0737e-02,
          -5.2617e-02, -1.6982e-01,  4.6818e-01,  4.9419e-01, -2.5719e-01],
         [ 3.2516e-01, -1.2552e-01,  1.0441e-01,  2.8270e-01, -2.5650e-02,
           2.6965e-

In [24]:
# attention weights: [n_batch, context_length, context_length]
qk_softmax.size()

torch.Size([4, 8, 8])

In [26]:
v.size()
# qk_softmax [batch, seq_len, seq_len] @ v [batch, seq_len, n_dim]
# -> the final result is of shape [batch, seq_len, n_dim]

torch.Size([4, 8, 10])

In [59]:
# raw query-key scores: [n_batch, context_length, context_length]
qk.size()

torch.Size([4, 8, 8])

In [47]:
#pytorch implementation
actsTorch = F.scaled_dot_product_attention(q,k,v,is_causal=True)
print(f'Shape of activations (PyTorch): {actsTorch.shape}')

Shape of activations (PyTorch): torch.Size([4, 8, 10])


In [49]:
# Show both results for the first sequence in the batch, for a visual comparison
print(actsManual[0,:,:])
print(' ')
print(actsTorch[0,:,:])

# The printout above covers one sequence and relies on eyeballing, so also
# compare every element of every sequence numerically. The tolerance allows
# for float32 rounding differences between the two implementations.
max_abs_diff = (actsManual - actsTorch).abs().max().item()
print(' ')
print(f'Max abs difference (all batches): {max_abs_diff:.2e}')
assert torch.allclose(actsManual, actsTorch, atol=1e-5), \
    'manual and PyTorch attention outputs disagree'

tensor([[-0.2197, -0.9711,  0.4389, -1.5539,  0.3279,  1.5416, -0.4719,  0.0030,
         -0.1869, -0.0189],
        [-0.2865, -0.8780,  0.2941, -1.2914,  0.0968,  1.3652, -0.1323, -0.3451,
         -0.0482, -0.2714],
        [-0.5426, -0.7807,  0.0061, -0.8907, -0.0743,  1.0662, -0.2795, -0.3974,
         -0.2424, -0.3955],
        [-0.7070, -0.7072, -0.1901, -0.6075, -0.2148,  0.8582, -0.3117, -0.4860,
         -0.3359, -0.5113],
        [-0.8164, -0.6582, -0.3208, -0.4190, -0.3083,  0.7197, -0.3332, -0.5450,
         -0.3981, -0.5885],
        [-1.1114, -0.5058, -0.7399,  0.1425, -0.5604,  0.3268, -0.3326, -0.7598,
         -0.5601, -0.8261],
        [-0.5738, -0.7201, -0.1763, -0.6897, -0.0685,  0.8902, -0.1385, -0.5112,
         -0.3192, -0.4192],
        [-0.5730, -0.3869, -0.4905, -0.0513, -0.4686,  0.1792, -0.1489, -0.5100,
         -0.1328, -0.1744]], grad_fn=<SliceBackward0>)
 
tensor([[-0.2197, -0.9711,  0.4389, -1.5539,  0.3279,  1.5416, -0.4719,  0.0030,
         -0.1869, 

In [None]:
# Performance could be improved further with optimisations, e.g.:
    # JIT (just-in-time) compilation for F.scaled_dot_product_attention()
    # setting the floating-point precision