In [1]:

import torch

from torch.nn import functional as F



## Attention

* attention converts "x" into K, Q, V and performs the attention mechanism by the matrix multiplication between K and Q


In [2]:

N = 32 

x = torch.randn(N, 40, 512)
x.shape


torch.Size([32, 40, 512])


## Q


In [3]:

wq = torch.randn(N, 512, 64)
wq.shape


torch.Size([32, 512, 64])

In [4]:

bq = torch.randn(  N, 40, 64  )
bq.shape


torch.Size([32, 40, 64])

In [5]:

Q = torch.matmul(  x, wq  ) 
Q.shape


torch.Size([32, 40, 64])

In [6]:

Q = Q + bq
Q.shape


torch.Size([32, 40, 64])


## K 


In [7]:

wk = torch.randn(N, 512, 64)
wk.shape


torch.Size([32, 512, 64])

In [8]:

bk = torch.randn(  N, 40, 64  )
bk.shape


torch.Size([32, 40, 64])

In [9]:

K = torch.matmul(  x, wk  ) 
K.shape


torch.Size([32, 40, 64])

In [10]:

K = K + bk
K.shape


torch.Size([32, 40, 64])


## Attention Q*K = [N, 40, 40]


In [11]:

attention_scores = torch.matmul(   Q, K.transpose( -2, -1 )   )
attention_scores.shape


torch.Size([32, 40, 40])


## V


In [12]:

wv = torch.randn(N, 512, 64)
wv.shape


torch.Size([32, 512, 64])

In [13]:

bv = torch.randn(  N, 40, 64  )
bv.shape


torch.Size([32, 40, 64])

In [14]:

V = torch.matmul(  x, wv  ) 
V.shape


torch.Size([32, 40, 64])

In [15]:

V = V + bv
V.shape


torch.Size([32, 40, 64])

In [16]:

out = torch.matmul( attention_scores , V )
out.shape


torch.Size([32, 40, 64])


## Concatenate All 8 heads


In [17]:

list_head = [ out for i in range(8) ]


In [18]:

for j in range(len(list_head)):
    print(   list_head[j].shape   )


torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])
torch.Size([32, 40, 64])


In [19]:

out_cat = torch.cat(  list_head, dim = -1  )
out_cat.shape


torch.Size([32, 40, 512])


## Another projection for the concatenated 8 heads


In [20]:

8*64


512

In [21]:

w0 = torch.randn(   N, 8*64, 512   )
w0.shape


torch.Size([32, 512, 512])

In [22]:

b0 = torch.randn(  N,  40,  512  )
b0.shape


torch.Size([32, 40, 512])

In [23]:

z = torch.matmul( out_cat, w0  )
z.shape


torch.Size([32, 40, 512])

In [24]:

z = z + b0
z.shape


torch.Size([32, 40, 512])


## The Mask


In [25]:

tril_def = torch.tril(
              torch.ones(10, 10)      ## should be 40 but using 10 for viz
)
tril_def.shape


torch.Size([10, 10])

In [26]:

tril_def 


tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [27]:

## this is just to record tril_deg as a buffer that is not updated during training

'''


import torch.nn as nn

my_tril_reg = nn.Module.register_buffer('tril', tril_def)
my_tril_reg

'''


"\n\n\nimport torch.nn as nn\n\nmy_tril_reg = nn.Module.register_buffer('tril', tril_def)\nmy_tril_reg\n\n"


## Batch of 32 sentences in the attention matrix 40x40


In [28]:

attention_scores.shape


torch.Size([32, 40, 40])

In [29]:

size10_attention = torch.randn(   N, 10, 10  )
size10_attention.shape


torch.Size([32, 10, 10])


## Use the tril for masking


In [30]:

tril_def[:10, :10].shape


torch.Size([10, 10])

In [31]:

tril_def[:10, :10] == 0


tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [32]:

size10_attention = size10_attention.masked_fill(
                            tril_def[:10, :10] == 0,
                            float('-inf')
    
)
size10_attention.shape


torch.Size([32, 10, 10])

In [34]:

size10_attention[0]  ## just 1 of the 32 in batch


tensor([[-0.8178,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.3788,  0.1090,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.7846,  0.9569, -1.8265,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.9664,  1.0718,  0.8340,  0.6549,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 1.4578,  2.2975,  0.9330, -1.4564, -0.2981,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.1892,  0.4582,  1.8721,  0.0903, -0.1273,  0.3352,    -inf,    -inf,
            -inf,    -inf],
        [-0.5006, -2.3496, -0.3927, -1.1587,  0.6523, -1.2205,  0.8650,    -inf,
            -inf,    -inf],
        [-1.4079, -0.8742,  1.4084, -0.2154, -0.2828,  0.1901, -0.6626, -0.6787,
            -inf,    -inf],
        [ 0.0438, -0.6462, -0.2794,  0.6257,  0.6305,  0.2509,  0.2491, -0.1652,
          0.5916,    -inf],
        [-0.9955,  


## Negative infinities

* softmax makes negative infinities close to zero


In [35]:

size10_attention[0]


tensor([[-0.8178,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.3788,  0.1090,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.7846,  0.9569, -1.8265,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.9664,  1.0718,  0.8340,  0.6549,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 1.4578,  2.2975,  0.9330, -1.4564, -0.2981,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.1892,  0.4582,  1.8721,  0.0903, -0.1273,  0.3352,    -inf,    -inf,
            -inf,    -inf],
        [-0.5006, -2.3496, -0.3927, -1.1587,  0.6523, -1.2205,  0.8650,    -inf,
            -inf,    -inf],
        [-1.4079, -0.8742,  1.4084, -0.2154, -0.2828,  0.1901, -0.6626, -0.6787,
            -inf,    -inf],
        [ 0.0438, -0.6462, -0.2794,  0.6257,  0.6305,  0.2509,  0.2491, -0.1652,
          0.5916,    -inf],
        [-0.9955,  

In [36]:

size10_attention_softmax = F.softmax( size10_attention, dim=-1)
size10_attention_softmax.shape


torch.Size([32, 10, 10])

In [38]:

size10_attention_softmax[0]


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3804, 0.6196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4422, 0.5253, 0.0325, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0505, 0.3879, 0.3058, 0.2557, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2419, 0.5601, 0.1431, 0.0131, 0.0418, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0674, 0.1287, 0.5293, 0.0891, 0.0717, 0.1138, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0965, 0.0152, 0.1075, 0.0500, 0.3057, 0.0470, 0.3781, 0.0000, 0.0000,
         0.0000],
        [0.0286, 0.0488, 0.4787, 0.0944, 0.0882, 0.1416, 0.0603, 0.0594, 0.0000,
         0.0000],
        [0.0925, 0.0464, 0.0670, 0.1655, 0.1663, 0.1138, 0.1136, 0.0750, 0.1600,
         0.0000],
        [0.0304, 0.1168, 0.1546, 0.0466, 0.3802, 0.0775, 0.0587, 0.0399, 0.0416,
         0.0538]])


## Assume batch of only one sentence


In [50]:

size_1_attention = torch.randn(  1, 6, 6  )
size_1_attention.shape


torch.Size([1, 6, 6])

In [51]:

size_1_attention


tensor([[[-0.0570,  0.0402, -1.5891,  0.1496, -0.3673,  2.0504],
         [-0.5728, -1.3474, -0.1137,  0.3970,  0.2269, -0.2877],
         [ 1.6589,  2.7081,  0.2418,  0.9370, -0.5362, -1.2214],
         [ 0.1721, -0.3412, -0.5787, -0.5658,  1.6224, -1.5423],
         [ 1.0458, -0.9708, -0.6614,  0.5996, -0.5732, -0.7677],
         [-0.4138, -1.1113, -0.9601, -0.5756, -0.4101, -1.1919]]])

In [52]:

size_1_attention = size_1_attention.masked_fill(
                            tril_def[:6, :6] == 0,
                            float('-inf')
    
)
size_1_attention.shape


torch.Size([1, 6, 6])

In [53]:

size_1_attention.shape


torch.Size([1, 6, 6])

In [54]:

size_1_attention


tensor([[[-0.0570,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.5728, -1.3474,    -inf,    -inf,    -inf,    -inf],
         [ 1.6589,  2.7081,  0.2418,    -inf,    -inf,    -inf],
         [ 0.1721, -0.3412, -0.5787, -0.5658,    -inf,    -inf],
         [ 1.0458, -0.9708, -0.6614,  0.5996, -0.5732,    -inf],
         [-0.4138, -1.1113, -0.9601, -0.5756, -0.4101, -1.1919]]])

In [55]:

size_1_attention_softmax = F.softmax( size_1_attention, dim=-1)
size_1_attention_softmax.shape


torch.Size([1, 6, 6])

In [56]:

size_1_attention_softmax


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6845, 0.3155, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2440, 0.6968, 0.0592, 0.0000, 0.0000, 0.0000],
         [0.3924, 0.2348, 0.1852, 0.1876, 0.0000, 0.0000],
         [0.4646, 0.0618, 0.0843, 0.2973, 0.0920, 0.0000],
         [0.2278, 0.1134, 0.1319, 0.1937, 0.2286, 0.1046]]])


## Token Embeddings


In [62]:

import torch
import torch.nn as nn



In [63]:

# Suppose we have a vocab of 1000 tokens
vocab_size = 1000
embedding_dim = 512

seq_len = 10
batch_size = 2


In [64]:

# Fake input: token IDs
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # [2, 10]


In [65]:

tokens.shape


torch.Size([2, 10])

In [66]:

tokens


tensor([[109, 694, 180, 383, 270,  43, 383, 961, 183, 845],
        [ 77, 588, 779, 664, 402, 764,  97, 369, 781, 295]])

In [67]:

# Define embedding layer
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)


In [68]:

# Get embedded vectors
x = embedding(tokens)  # shape: [2, 10, 512]
print(x.shape)


torch.Size([2, 10, 512])



## Now Positional and token embedding together


In [69]:

import torch
import torch.nn as nn


In [70]:

# Parameters
vocab_size     = 1000     # Size of the vocabulary
embedding_dim  = 64       # Dimension of embedding
seq_len        = 10       # Max sequence length
batch_size     = 2        # Number of sequences in a batch


In [72]:

# Token indices input (simulated)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # shape: [2, 10]
tokens

tensor([[666, 462, 946, 723, 790, 482, 457, 149, 684, 926],
        [566, 525,  65,  62, 991, 726, 912, 234, 903, 990]])

In [73]:

# Token embedding
token_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)


In [74]:

# Positional embedding (learned)
position_embedding = nn.Embedding(num_embeddings=seq_len, embedding_dim=embedding_dim)


In [75]:

tokens.shape


torch.Size([2, 10])

In [76]:

# Get token embeddings
x_token = token_embedding(tokens)  # shape: [2, 10, 64]
x_token.shape


torch.Size([2, 10, 64])

In [77]:

# Create position indices for each token in the sequence
positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)  # shape: [2, 10]
positions


tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [78]:

positions.shape


torch.Size([2, 10])

In [79]:

# Get position embeddings
x_pos = position_embedding(positions)  # shape: [2, 10, 64]
x_pos.shape


torch.Size([2, 10, 64])

In [80]:

# Add both embeddings
x = x_token + x_pos  # shape: [2, 10, 64]


In [81]:


print(x.shape)  # Final embedded representation with position info


torch.Size([2, 10, 64])
