# Implementation of Attention

In this notebook, we are going to implement several attention mechanisms from scratch using PyTorch.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Here we are implementing a simple version of attention using only the dot product. The input is a sequence of vectors (for example, word embeddings) and the output is a sequence of the same length in which each vector is a weighted sum of all the input vectors.


The attention mechanism allows each position to attend to all positions in the sequence, creating a weighted combination based on similarity (dot product).

This simplified version does not include trainable parameters, but it allows us to understand the core concept of attention.

In [2]:
from mermaid import Mermaid

# Flowchart of simplified attention for the query token "journey".
# The numbers are the rounded values computed in the cells below
# (raw dot-product scores, then their softmax), so the picture and the
# code outputs agree; the weights sum to 1.
diagram = """
graph TD
    A["Input: [Your, journey, starts, with, one, step]"] --> B["Select Query: journey (index 1)"]
    A --> C["All tokens as Keys"]
    B --> D["Compute Dot Products"]
    C --> D
    D --> E["Attention Scores: [0.95, 1.50, 1.48, 0.84, 0.71, 1.09]"]
    E --> F["Apply Softmax"]
    F --> G["Attention Weights: [0.14, 0.24, 0.23, 0.12, 0.11, 0.16]"]
    G --> H["Weighted Sum"]
    A --> H
    H --> I["Context Vector"]

    style B fill:#ff9999
    style I fill:#99ff99
"""

Mermaid(diagram)

In [3]:
# One 3-dimensional embedding per token of the example sentence
# "Your journey starts with one step" (one row per token).
inputs = torch.tensor([
    (0.43, 0.15, 0.89),  # Your
    (0.55, 0.87, 0.66),  # journey
    (0.57, 0.85, 0.64),  # starts
    (0.22, 0.58, 0.33),  # with
    (0.77, 0.25, 0.10),  # one
    (0.05, 0.80, 0.55),  # step
])


In [4]:
# The second token ("journey") plays the role of the query.
query = inputs[1]

# Unnormalised attention score of every token with respect to the query:
# one dot product per row of `inputs`, stacked into a 1-D tensor.
attention_score2 = torch.stack([torch.dot(key, query) for key in inputs])

attention_score2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [5]:
# Softmax turns the raw scores into positive weights that sum to 1.
attention_weights2 = torch.softmax(attention_score2, dim=0)
attention_weights2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [6]:
query = inputs[1]

# Context vector = attention-weighted sum of all token vectors,
# accumulated one token at a time.
context_vec_2 = torch.zeros_like(query)
for weight, value in zip(attention_weights2, inputs):
    context_vec_2 += weight * value
context_vec_2

tensor([0.4419, 0.6515, 0.5683])

In [7]:
class DotProductAttention(nn.Module):
    """
    Dot Product Attention Mechanism

    Parameter-free attention: every query is compared with every vector of a
    sequence through a dot product, the scores are normalised with a softmax,
    and the context vector is the correspondingly weighted sum of the
    sequence vectors. The sequence acts as both keys and values.
    """
    def __init__(self):
        super().__init__()

    def forward(self, query, sequence=None):
        """
        Forward pass of the attention mechanism.
        Args:
            query (torch.Tensor): A single query of shape (d_model,) or a
                set of queries of shape (num_queries, d_model).
            sequence (torch.Tensor, optional): The vectors to attend over,
                of shape (seq_len, d_model). Defaults to ``query`` itself,
                i.e. self-attention.
        Returns:
            dict: A dictionary containing the context vector and attention weights.
                "context_vector" has shape (num_queries, d_model) and
                "attention_weights" has shape (num_queries, seq_len); the
                leading dimension is dropped when ``query`` is 1-D.
        """
        # Treat a lone query vector as a one-row batch so the same matrix
        # code handles both cases; the extra dimension is removed on return.
        single_query = query.dim() == 1
        queries = query.unsqueeze(0) if single_query else query

        # Attend over the explicit sequence rather than a notebook global.
        if sequence is None:
            sequence = queries
        elif sequence.dim() == 1:
            sequence = sequence.unsqueeze(0)

        # Rows index queries and columns index keys, so the softmax over the
        # last dimension normalises each query's weights across the sequence.
        atten_score = queries @ sequence.transpose(-2, -1)
        atten_weights = F.softmax(atten_score, dim=-1)
        context_vec = atten_weights @ sequence

        if single_query:
            context_vec = context_vec.squeeze(0)
            atten_weights = atten_weights.squeeze(0)
        return {"context_vector": context_vec, "attention_weights": atten_weights}


# Run parameter-free self-attention over the whole sequence and pull the
# two results out of the returned dictionary by key.
attention = DotProductAttention()
attention_output = attention(inputs)
context_vec = attention_output["context_vector"]
atten_weights = attention_output["attention_weights"]
context_vec, atten_weights

(tensor([[0.4421, 0.5931, 0.5790],
         [0.4419, 0.6515, 0.5683],
         [0.4431, 0.6496, 0.5671],
         [0.4304, 0.6298, 0.5510],
         [0.4671, 0.5910, 0.5266],
         [0.4177, 0.6503, 0.5645]]),
 tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
         [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
         [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
         [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
         [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
         [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]]))

## Self-attention implementation

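This time we will implement self-attention with trainable query, key and value weight matrices and scaled scores. As before, the input sequence attends to itself, but the comparison now happens in learned projection spaces. This is a key component of transformer models.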

In [8]:
x_2 = inputs[1]      # second token ("journey"), used as the example query
d_in = x_2.size(0)   # size of each input embedding
d_out = 2            # size of the projected query/key/value vectors

In [9]:
torch.manual_seed(123)

# Query, key and value projection matrices, drawn in that order from the
# seeded generator. Use requires_grad=True for training.
W_q, W_k, W_v = (
    nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [10]:
# Project the single token x_2 into query, key and value space.
query_2, keys_2, values_2 = (torch.matmul(x_2, weight) for weight in (W_q, W_k, W_v))

query_2

tensor([-1.1729, -0.0048])

In [11]:
# Key and value projections for every token in the sequence.
keys = torch.matmul(inputs, W_k)
values = torch.matmul(inputs, W_v)
keys, values

(tensor([[-0.1823, -0.6888],
         [-0.1142, -0.7676],
         [-0.1443, -0.7728],
         [ 0.0434, -0.3580],
         [-0.6467, -0.6476],
         [ 0.3262, -0.3395]]),
 tensor([[ 0.1196, -0.3566],
         [ 0.4107,  0.6274],
         [ 0.4091,  0.6390],
         [ 0.2436,  0.4182],
         [ 0.2653,  0.6668],
         [ 0.2728,  0.3242]]))

In [14]:
# Unscaled score of the query against the key of the same token (index 1).
keys_2 = keys[1]
attn_score22 = torch.dot(query_2, keys_2)
attn_score22

tensor(0.1376)

In [15]:
# Scores of query_2 against all keys at once, via one matrix product.
atten_score2 = torch.matmul(query_2, keys.T)
atten_score2

tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809])

In [16]:
d_k = keys.size(-1)

# Scores are divided by sqrt(d_k) before the softmax: the "scaled" part
# of scaled dot-product attention.
atten_weights2 = F.softmax(atten_score2 / d_k ** 0.5, dim=-1)
atten_weights2

tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117])

In [18]:
# Context vector for the query: attention-weighted sum of the value vectors.
context_vec2 = torch.matmul(atten_weights2, values)
context_vec2

tensor([0.2854, 0.4081])

In [25]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot Product Attention Mechanism

    Single-head self-attention: the input is projected into queries, keys and
    values by three linear layers, and each output row is the softmax-weighted
    sum of the values, with scores scaled by the square root of the key size.

    Args:
        d_in (int): Size of each input vector.
        d_out (int): Size of the projected query/key/value vectors.
        qkv_bias (bool): Whether the three projections use a bias term.
    """
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """
        Compute the context vectors for a sequence.
        Args:
            x (torch.Tensor): Input of shape (seq_len, d_in), optionally with
                leading batch dimensions, e.g. (batch, seq_len, d_in).
        Returns:
            torch.Tensor: Context vectors of shape (..., seq_len, d_out).
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Swap only the last two dimensions: `.T` reverses every dimension,
        # which breaks (and is deprecated for) batched, non-2-D input.
        atten_scores = queries @ keys.transpose(-2, -1)  # omega
        # Scale by sqrt(d_k) before normalising each query's scores over the keys.
        atten_weights = F.softmax(atten_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vectors = atten_weights @ values
        return context_vectors



In [26]:
# Seed first: the nn.Linear layers are randomly initialised on construction.
torch.manual_seed(123)
attention = ScaledDotProductAttention(3, 2)  # d_in=3, d_out=2

# Self-attention over the full six-token sequence.
context_vec3 = attention(inputs)
context_vec3

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)