The masked attention formula is given by: $\text{MaskedAttention}(Q, K, V, M) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$.

In [1]:
 import torch
 from torch.nn import Embedding

 def token_vectorization(prompt):
   """Turn a whitespace-separated prompt into per-token embedding vectors.

   Every word is lower-cased and mapped to an integer id taken from its
   position in the sorted word list. The ids are looked up in a freshly
   created ``Embedding`` table that is seeded first, so the same prompt
   always yields the same vectors.

   Args:
     prompt: Text to embed. Words are separated by any run of whitespace,
       so repeated, leading or trailing spaces do not create empty tokens.

   Returns:
     A detached tensor of shape ``(num_tokens, 3)``. A prompt without any
     word yields a ``(0, 3)`` tensor.

   Note:
     Calls ``torch.manual_seed(123)``, which reseeds torch's global RNG.
   """
   # split() without an argument collapses whitespace runs; split(" ") would
   # return "" for every extra space and embed it as a phantom token.
   words = prompt.split()
   # Sorting is done on the original casing and keys are lower-cased after,
   # so a word that repeats keeps the index of its last sorted occurrence.
   word_dic = {w.lower(): idx for idx, w in enumerate(sorted(words))}
   tokens_int = [word_dic[w.lower()] for w in words]
   # Explicit integer dtype: torch.tensor([]) defaults to float, which an
   # Embedding lookup does not accept as indices.
   tokens_tf = torch.tensor(tokens_int, dtype=torch.long)
   vocab_size = 50000
   embedding_dim = 3
   torch.manual_seed(123)
   embedder = Embedding(vocab_size, embedding_dim)
   token_embedding = embedder(tokens_tf).detach()
   return token_embedding

In [2]:
def initialize_weights(token_embedding, d_q=24, d_k=24, d_v=28, seed=123):
  """Project token embeddings into query, key and value matrices.

  Three projection matrices are drawn with ``torch.rand`` (query first, then
  key, then value), wrapped in ``torch.nn.Parameter`` and applied to
  ``token_embedding`` by matrix multiplication.

  Args:
    token_embedding: Tensor whose last dimension is the embedding size.
    d_q: Width of the query projection.
    d_k: Width of the key projection; must equal ``d_q`` so that the
      query/key product used for attention scores is defined.
    d_v: Width of the value projection.
    seed: Value passed to ``torch.manual_seed`` before the weights are
      drawn (this reseeds torch's global RNG). Pass ``None`` to leave the
      RNG state untouched.

  Returns:
    Tuple ``(q, k, v, d_k)``: the three projected tensors and the key width.

  Raises:
    ValueError: If ``d_q`` and ``d_k`` differ.
  """
  if d_q != d_k:
    raise ValueError(
        f"d_q ({d_q}) and d_k ({d_k}) must match to compute q @ k.T")
  if seed is not None:
    torch.manual_seed(seed)
  # The last axis is the one contracted by the matmuls below; for the 2-D
  # input this function is used with, it is the same as shape[1].
  d = token_embedding.shape[-1]
  w_query = torch.nn.Parameter(torch.rand(d, d_q))
  w_key = torch.nn.Parameter(torch.rand(d, d_k))
  w_value = torch.nn.Parameter(torch.rand(d, d_v))
  q = token_embedding @ w_query
  k = token_embedding @ w_key
  v = token_embedding @ w_value
  return q, k, v, d_k

In [3]:
def mask_matrix(dim):
  """Build the additive causal mask for a ``dim`` x ``dim`` score matrix.

  Returns a float tensor of shape ``(dim, dim)`` holding 0 on and below the
  main diagonal and ``-inf`` strictly above it.
  """
  # True exactly at the positions strictly above the diagonal.
  above_diagonal = torch.ones(dim, dim).triu(diagonal=1).bool()
  # Start from zeros and overwrite only the flagged positions.
  return torch.zeros(dim, dim).masked_fill(above_diagonal, float("-inf"))

In [4]:
import math
from torch.nn.functional import softmax

def calc_attention(q,k,v,d_k,token_embedding):
  """Compute masked scaled dot-product attention.

  Scores ``q @ k.T`` are divided by ``sqrt(d_k)``, the mask returned by
  ``mask_matrix`` is added, a softmax is taken over the last axis and the
  resulting weights are applied to ``v``.

  Args:
    q: Query matrix.
    k: Key matrix.
    v: Value matrix.
    d_k: Key width used to scale the scores.
    token_embedding: Accepted but not read by this function.

  Returns:
    The context vectors, ``softmax(scores / sqrt(d_k) + mask) @ v``.
  """
  scaled_scores = (q @ k.T) / math.sqrt(d_k)
  # The mask is square, sized by the number of score columns.
  causal_mask = mask_matrix(scaled_scores.shape[1])
  attention_weights = softmax(scaled_scores + causal_mask, dim=-1)
  return attention_weights @ v

In [6]:
# Read the prompt text from standard input (the cell waits for the user).
prompt = input()
# Look up an embedding vector for every word of the prompt.
token_embedding = token_vectorization(prompt)
# Project the embeddings into query/key/value matrices; d_k is the key width
# that calc_attention uses to scale the scores.
q, k, v, d_k  = initialize_weights(token_embedding)
# Bare last expression, so the notebook displays the returned context vectors.
calc_attention(q,k,v,d_k, token_embedding)

Hello World I am an attention seeker I love attention


tensor([[-0.0031, -0.1732, -0.1618,  0.0456,  0.0057, -0.1686, -0.2620, -0.1475,
         -0.0319, -0.0568, -0.0197,  0.1504, -0.1717,  0.0673, -0.2114,  0.0125,
         -0.1724,  0.2598, -0.1221,  0.1350,  0.1590,  0.1032, -0.0372,  0.1667,
         -0.0777,  0.2618, -0.0834, -0.0971],
        [-0.8196, -0.6501, -1.0132, -0.9124, -0.2872, -0.6141, -0.2721, -0.6837,
         -0.8824, -1.7377, -0.7412, -1.5834, -1.1017, -0.6477, -0.7898, -0.9227,
         -1.1314, -1.4216, -0.8238, -1.7673, -0.7232, -0.6134, -1.5221, -1.3918,
         -1.6002, -0.8852, -0.2614,  0.1580],
        [-0.2711, -0.2106, -0.2155, -0.2527, -0.0690, -0.1550, -0.1670, -0.2228,
         -0.2738, -0.4673, -0.3073, -0.3992, -0.2247, -0.2208, -0.1635, -0.3456,
         -0.3649, -0.3816, -0.1677, -0.4282, -0.3239, -0.3247, -0.3502, -0.3985,
         -0.4569, -0.2914, -0.2056, -0.0531],
        [ 1.2176,  1.4388,  0.6619,  0.6606,  0.1180,  0.8869,  2.0562,  1.4213,
          1.1974,  1.5436,  1.8411,  0.4929,  0.5993