<a href="https://colab.research.google.com/github/rit-clone/Indecision-app/blob/master/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch

In [2]:
sentence = " Life is short, eat dessert first"
tokens = sentence.split()
sorted_tokens = sorted(tokens)
positions = {token:idx for idx, token in enumerate(sorted_tokens)}
positions

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short,': 5}

In [3]:
result = [positions[token] for token in tokens]
result = torch.tensor(result)
result

tensor([0, 4, 5, 2, 1, 3])

In [6]:
vocab_size = 50_000
torch.manual_seed(123)
embed= torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(result).detach()
embedded_sentence

array([[ 0.3373702 , -0.17777722, -0.3035276 ],
       [ 0.17937961,  1.895148  ,  0.49544638],
       [ 0.26919857, -0.07702024, -1.0204719 ],
       [-0.21963762, -0.37916982,  0.76710707],
       [-0.58801186,  0.3486052 ,  0.66034096],
       [-1.192502  ,  0.6983519 , -1.4097229 ]], dtype=float32)

In [19]:
torch_manual_seed = 123
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

W_query.shape, W_key.shape, W_value.shape


(torch.Size([3, 2]), torch.Size([3, 2]), torch.Size([3, 4]))

In [11]:
x_2 = embedded_sentence[1]
x_2 = torch.from_numpy(x_2)
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([2])
torch.Size([2])
torch.Size([4])


In [20]:
embedded_sentence = torch.tensor(embedded_sentence)
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)


torch.Size([6, 3])
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])


  embedded_sentence = torch.tensor(embedded_sentence)


In [15]:
omega_2 = query_2 @ keys.T
print(omega_2)

tensor([-0.0185,  3.2059, -0.1406, -0.5765,  0.1290, -0.5444],
       grad_fn=<SqueezeBackward4>)


In [17]:
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.0706, 0.6901, 0.0647, 0.0476, 0.0783, 0.0487],
       grad_fn=<SoftmaxBackward0>)


In [18]:
context_vector_2 = attention_weights_2 @ values
print(context_vector_2.shape)
print(context_vector_2)

torch.Size([4])
tensor([1.0684, 0.7869, 0.9349, 1.4151], grad_fn=<SqueezeBackward4>)


In [27]:
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_in = d_in
        self.d_out_kq = d_out_kq
        self.d_out_v = d_out_v
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        # x = torch.from_numpy(x)
        query = x @ self.W_query
        key = x @ self.W_key
        value = x @ self.W_value
        attention_weights = F.softmax(query @ key.T / self.d_out_kq**0.5, dim=-1)
        context_vector = attention_weights @ value
        return context_vector

    def __repr__(self):
        return f"SelfAttention(d_in={self.d_in}, d_out_kq={self.d_out_kq}, d_out_v={self.d_out_v})"


torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4
sa = SelfAttention(d, d_q, d_v)
print(sa(embedded_sentence))


tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2626, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)


# Mutihead Attention

In [28]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [29]:
torch.manual_seed(123)
d, d_out_kq, d_out_v = 3, 2, 1
sa = SelfAttention(d, d_out_kq, d_out_v)
print(sa(embedded_sentence))

tensor([[-0.0185],
        [ 0.4003],
        [-0.1103],
        [ 0.0668],
        [ 0.1180],
        [-0.1827]], grad_fn=<MmBackward0>)


In [33]:
torch.manual_seed(123)
d, d_out_kq, d_out_v, num_heads = 3, 2, 1, 4
mha = MultiHeadAttention(d, d_out_kq, d_out_v, num_heads)
context_vecs = mha(embedded_sentence)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[-0.0185,  0.0170,  0.1999, -0.0860],
        [ 0.4003,  1.7137,  1.3981,  1.0497],
        [-0.1103, -0.1609,  0.0079, -0.2416],
        [ 0.0668,  0.3534,  0.2322,  0.1008],
        [ 0.1180,  0.6949,  0.3157,  0.2807],
        [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])


# Cross Attention

In [37]:
class crossAttention(nn.Module):
  def __init__(self, d_in, d_out_kq, d_out_v):
    super().__init__()
    self.d_in = d_in
    self.d_out_kq = d_out_kq
    self.d_out_v = d_out_v
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

  def forward(self, x, y):
      query = x @ self.W_query
      key = y @ self.W_key
      value = y @ self.W_value
      attn_scores = query @ key.T /(self.d_out_kq**0.5)
      attn_weights = F.softmax(attn_scores, dim=-1)
      context_vector = attn_weights @ value
      return context_vector

In [38]:
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
ca = crossAttention(d_in, d_out_kq, d_out_v)
first_input = embedded_sentence
second_input = torch.rand(8, d_in)
print("First input shape:", first_input.shape)
print("Second input shape:", second_input.shape)

First input shape: torch.Size([6, 3])
Second input shape: torch.Size([8, 3])


In [39]:
context_vectors = ca(first_input, second_input)
print(context_vectors)
print("Output shape:", context_vectors.shape)

tensor([[0.4231, 0.8665, 0.6503, 1.0042],
        [0.4874, 0.9718, 0.7359, 1.1353],
        [0.4054, 0.8359, 0.6258, 0.9667],
        [0.4357, 0.8886, 0.6678, 1.0311],
        [0.4429, 0.9006, 0.6775, 1.0460],
        [0.3860, 0.8021, 0.5985, 0.9250]], grad_fn=<MmBackward0>)
Output shape: torch.Size([6, 4])
