In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
device="mps"

In [2]:
def calculate_attention(
        query:torch.Tensor,
        key:torch.Tensor,
        value:torch.Tensor,
):
    attention_scores=torch.matmul(query,key.transpose(-2,-1))
    attention_scores=attention_scores/math.sqrt(key.shape[-1])
    attention_scores=F.softmax(attention_scores,dim=-1)
    attention=torch.matmul(attention_scores,value)
    return attention,attention_scores

In [3]:
batch_size=2
num_queries=4
num_keys=16
embed_size=8
query=torch.randn(batch_size,num_queries,embed_size)
keys=torch.randn(batch_size,num_keys,embed_size)
value=torch.randn(batch_size,num_keys,embed_size)


In [4]:
query.shape


torch.Size([2, 4, 8])

In [5]:
query


tensor([[[ 8.2355e-02, -9.0994e-01, -1.0084e+00, -1.6941e+00, -6.4022e-01,
           3.4664e-01, -1.9269e+00, -2.7803e-01],
         [ 2.2504e-01,  1.2899e+00, -1.8094e+00,  2.8041e-01, -2.3472e-01,
          -6.6442e-01, -2.2193e-01,  3.9833e-01],
         [-2.9465e-01,  6.4270e-01, -1.8070e+00, -1.3021e-01, -3.4097e-01,
           1.4162e+00,  2.8721e-01, -2.8228e-02],
         [-1.1559e+00, -1.6789e+00,  1.9322e+00,  5.4015e-02, -1.1656e-01,
          -1.2037e-01, -1.2872e-01, -8.1735e-03]],

        [[-3.5681e-01, -2.3552e-01,  1.7536e+00, -1.9116e+00,  1.3666e+00,
          -1.6470e+00,  1.9247e-01,  1.2102e+00],
         [-1.1653e+00,  1.9983e+00, -4.5601e-01,  7.4543e-01, -8.0717e-01,
           6.5167e-01, -9.2001e-01,  1.0739e+00],
         [-3.1597e-02, -8.4615e-01, -2.8452e-01,  6.0574e-01,  6.9423e-01,
           8.2001e-01,  1.3151e+00,  8.3859e-01],
         [ 3.6834e-01,  1.2041e+00,  1.4952e-03,  3.7204e-01, -1.0653e+00,
          -5.1867e-01, -1.0929e+00,  1.5713e+00]

In [8]:
keys.shape


torch.Size([2, 16, 8])

In [13]:
attention,attention_scores=calculate_attention(query,keys,value)
attention.shape,attention_scores.shape

(torch.Size([2, 4, 8]), torch.Size([2, 4, 16]))

In [14]:
text='attention ! we will train attention .'
text_tokens=text.split()
vocab=set(text_tokens)
vocab_to_idx={token: idx for idx,token in enumerate(vocab)}
print(vocab_to_idx)


{'!': 0, 'we': 1, 'train': 2, 'attention': 3, '.': 4, 'will': 5}


In [15]:
int_tokens=torch.tensor([vocab_to_idx[token] for token in text_tokens])
int_tokens=int_tokens.unsqueeze(0)
print(int_tokens,"\nshape:",int_tokens.shape)

tensor([[3, 0, 1, 5, 2, 3, 4]]) 
shape: torch.Size([1, 7])


In [16]:
embedding_layer=nn.Embedding(num_embeddings=len(vocab),embedding_dim=8)


In [17]:
embedding_layer.weight


Parameter containing:
tensor([[ 1.5885e+00, -1.2211e+00,  5.7886e-01,  3.7819e-01, -9.1898e-01,
         -1.7496e-01,  5.4185e-01, -4.0746e-01],
        [ 4.4848e-01,  2.8228e-04,  1.1137e+00, -3.7905e-01,  4.7805e-01,
          1.2945e+00, -1.9192e+00, -1.6568e+00],
        [ 1.5679e-01, -1.0564e+00,  7.3530e-01,  8.3025e-01, -9.4689e-01,
         -1.7919e+00,  2.1915e-01,  6.8174e-01],
        [ 8.1251e-01,  3.2273e-01,  4.1815e-01,  7.7991e-01, -2.0580e-01,
         -3.6966e-01,  5.4677e-01,  5.3566e-01],
        [-2.3744e-01,  1.0902e-02,  4.6608e-01, -6.6422e-01,  1.6636e+00,
          1.3373e+00, -1.0628e+00,  1.2497e+00],
        [ 1.5756e+00,  1.4329e+00,  1.5827e+00,  8.6508e-01,  2.2641e-03,
         -3.7737e-02,  1.4275e-01, -9.2139e-01]], requires_grad=True)

In [21]:
embedding_layer.weight.shape

torch.Size([6, 8])

In [18]:
embeddings=embedding_layer(int_tokens)
embeddings

tensor([[[ 8.1251e-01,  3.2273e-01,  4.1815e-01,  7.7991e-01, -2.0580e-01,
          -3.6966e-01,  5.4677e-01,  5.3566e-01],
         [ 1.5885e+00, -1.2211e+00,  5.7886e-01,  3.7819e-01, -9.1898e-01,
          -1.7496e-01,  5.4185e-01, -4.0746e-01],
         [ 4.4848e-01,  2.8228e-04,  1.1137e+00, -3.7905e-01,  4.7805e-01,
           1.2945e+00, -1.9192e+00, -1.6568e+00],
         [ 1.5756e+00,  1.4329e+00,  1.5827e+00,  8.6508e-01,  2.2641e-03,
          -3.7737e-02,  1.4275e-01, -9.2139e-01],
         [ 1.5679e-01, -1.0564e+00,  7.3530e-01,  8.3025e-01, -9.4689e-01,
          -1.7919e+00,  2.1915e-01,  6.8174e-01],
         [ 8.1251e-01,  3.2273e-01,  4.1815e-01,  7.7991e-01, -2.0580e-01,
          -3.6966e-01,  5.4677e-01,  5.3566e-01],
         [-2.3744e-01,  1.0902e-02,  4.6608e-01, -6.6422e-01,  1.6636e+00,
           1.3373e+00, -1.0628e+00,  1.2497e+00]]],
       grad_fn=<EmbeddingBackward0>)

In [19]:
embeddings.shape


torch.Size([1, 7, 8])

In [22]:
embedding_dim=8
embedding_layer=nn.Embedding(num_embeddings=len(vocab),embedding_dim=embedding_dim)
query_dense_layer=nn.Linear(in_features=embedding_dim,out_features=8)
key_dense_layer=nn.Linear(in_features=embedding_dim,out_features=8)
value_dense_layer=nn.Linear(in_features=embedding_dim,out_features=8)

In [23]:
embeddings=embedding_layer(int_tokens)
embeddings

tensor([[[ 0.3041,  2.2403,  1.8302, -0.6226, -1.1173, -1.7150, -0.4159,
          -0.4360],
         [-0.0413,  0.1477, -1.5658, -0.0085,  0.7859, -1.3566,  0.1088,
           1.2149],
         [-1.8079,  0.8667,  0.2132,  1.0926,  0.2734, -1.0515, -0.1971,
          -0.4932],
         [ 1.6140,  0.5343, -0.1385,  1.3212, -0.2533,  0.1117,  0.8089,
           1.5170],
         [-0.1078,  0.0337, -0.8248, -0.0400,  0.9999,  0.4553,  0.3001,
           0.6170],
         [ 0.3041,  2.2403,  1.8302, -0.6226, -1.1173, -1.7150, -0.4159,
          -0.4360],
         [ 0.4470,  0.8136,  0.1836, -0.8535,  1.7319, -0.4128,  0.8730,
          -0.7801]]], grad_fn=<EmbeddingBackward0>)

In [24]:
embeddings.shape


torch.Size([1, 7, 8])

In [25]:
query=query_dense_layer(embeddings)
key=key_dense_layer(embeddings)
value=value_dense_layer(embeddings)

In [26]:
query.shape,key.shape,value.shape


(torch.Size([1, 7, 8]), torch.Size([1, 7, 8]), torch.Size([1, 7, 8]))

In [27]:
attention,attention_scores=calculate_attention(query,key,value)
attention.shape,attention_scores.shape

(torch.Size([1, 7, 8]), torch.Size([1, 7, 7]))

In [28]:
attention_scores


tensor([[[0.1042, 0.1930, 0.1378, 0.1430, 0.1718, 0.1042, 0.1461],
         [0.1486, 0.1129, 0.0992, 0.2332, 0.1338, 0.1486, 0.1236],
         [0.1809, 0.0850, 0.1287, 0.1487, 0.1360, 0.1809, 0.1397],
         [0.1588, 0.1405, 0.1473, 0.2149, 0.1051, 0.1588, 0.0745],
         [0.1621, 0.1205, 0.1540, 0.1472, 0.1262, 0.1621, 0.1279],
         [0.1042, 0.1930, 0.1378, 0.1430, 0.1718, 0.1042, 0.1461],
         [0.1293, 0.1555, 0.2022, 0.1089, 0.1414, 0.1293, 0.1333]]],
       grad_fn=<SoftmaxBackward0>)

In [29]:
right_triangular_mask=torch.tril(torch.ones_like(attention_scores))
right_triangular_mask

tensor([[[1., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1., 1., 1.]]])

In [30]:
def calculate_masked_attention(
    values: torch.Tensor,
    keys: torch.Tensor,
    query: torch.Tensor,
    mask: torch.Tensor = None,
):
    attention_scores = torch.matmul(query, keys.transpose(-2, -1))
    attention_scores = attention_scores / math.sqrt(keys.shape[-1])
    if mask is not None:
        attention_scores = torch.where(
            mask == 0,
            torch.full_like(attention_scores, -1e9),
            attention_scores
        )
    attention_scores = F.softmax(attention_scores, dim=-1)
    attention = torch.matmul(attention_scores, values)
    return attention, attention_scores

In [31]:
attention_context,attention_scores=calculate_masked_attention(query,key,value,right_triangular_mask)


In [32]:
attention_context.shape


torch.Size([1, 7, 8])

In [33]:
attention_scores.shape


torch.Size([1, 7, 7])

In [34]:
attention_scores


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8257, 0.1743, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3587, 0.3439, 0.2974, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4028, 0.1402, 0.2340, 0.2230, 0.0000, 0.0000, 0.0000],
         [0.2439, 0.1466, 0.2639, 0.1745, 0.1712, 0.0000, 0.0000],
         [0.1223, 0.2282, 0.0931, 0.2697, 0.1644, 0.1223, 0.0000],
         [0.1476, 0.1448, 0.2325, 0.1169, 0.1178, 0.1476, 0.0928]]],
       grad_fn=<SoftmaxBackward0>)