<a href="https://colab.research.google.com/github/syedmahmoodiagents/transformers/blob/main/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import math

In [None]:
class SingleAttention(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    The query, key and value inputs are each passed through their own
    ``nn.Linear(embed_dim, embed_dim)`` layer, combined with scaled
    dot-product attention, and the result is passed through a final
    ``nn.Linear(embed_dim, embed_dim)`` output layer.
    """

    def __init__(self, embed_dim: int):
        """Create the four projection layers.

        Args:
            embed_dim: Size of the last dimension of every input and of the
                output.
        """
        super().__init__()
        self.embed_dim = embed_dim

        # Linear layers to transform the input embeddings to queries, keys, and values
        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)

        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Only the last two dimensions are interpreted, so the inputs may be
        unbatched ``(seq, embed_dim)``, batched ``(batch, seq, embed_dim)``,
        or carry additional leading dimensions.

        Args:
            query: Tensor of shape ``(..., query_len, embed_dim)``.
            key: Tensor of shape ``(..., key_len, embed_dim)``.
            value: Tensor of shape ``(..., key_len, embed_dim)``.

        Returns:
            A tuple ``(out, attention_weights)`` where ``out`` has shape
            ``(..., query_len, embed_dim)`` and ``attention_weights`` has shape
            ``(..., query_len, key_len)``; the weights sum to 1 over the last
            dimension.
        """
        # Linear transformations to get queries, keys, and values
        query = self.query_linear(query)  # (..., query_len, embed_dim)
        key = self.key_linear(key)        # (..., key_len, embed_dim)
        value = self.value_linear(value)  # (..., key_len, embed_dim)

        # Scaled dot-product attention. The scale comes from the configured
        # embed_dim rather than from unpacking query.size(), so the input rank
        # is not constrained to exactly three dimensions.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        attention_weights = torch.softmax(scores, dim=-1)  # (..., query_len, key_len)
        context = torch.matmul(attention_weights, value)   # (..., query_len, embed_dim)

        # Final linear transformation
        out = self.out_linear(context)

        return out, attention_weights


In [None]:

# Example usage: small random tensors for a cross-attention call in which the
# key/value sequence is longer than the query sequence.
embed_dim, batch_size = 128, 32
query_len, key_len = 10, 15

# Each tensor is laid out as (batch_size, sequence_length, embed_dim).
query = torch.randn((batch_size, query_len, embed_dim))
key = torch.randn((batch_size, key_len, embed_dim))
value = torch.randn((batch_size, key_len, embed_dim))


In [None]:
# Build the single-head attention layer and run it on the example tensors.
model = SingleAttention(embed_dim)
out, attention_weights = model(query, key, value)

# Prints the output shape, (32, 10, 128), followed on the next line by the
# attention-weights shape, (32, 10, 15).
print(out.shape, attention_weights.shape, sep="\n")

torch.Size([32, 10, 128])
torch.Size([32, 10, 15])


**Multi-Head Attention**

In [None]:
class MHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split evenly across ``num_heads`` heads; each
    head attends independently over ``head_dim = embed_dim // num_heads``
    features and the heads are concatenated before a final linear layer.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        """Create the projection layers.

        Args:
            embed_dim: Size of the last dimension of the inputs and output.
            num_heads: Number of attention heads; must divide ``embed_dim``.

        Raises:
            AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, embed_dim) to (batch, num_heads, seq, head_dim)."""
        batch_size, seq_length, _ = tensor.size()
        return tensor.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, key=None, value=None, *, return_weights=False):
        """Apply multi-head attention.

        Called as ``forward(x)`` this is self-attention over ``x``. Passing
        ``key`` (and optionally ``value``) lets the positions of ``x`` attend
        over a different sequence, whose length may differ from that of ``x``.

        Args:
            x: Query input of shape ``(batch, query_len, embed_dim)``.
            key: Optional key input of shape ``(batch, key_len, embed_dim)``.
                Defaults to ``x``.
            value: Optional value input of shape
                ``(batch, key_len, embed_dim)``. Defaults to ``key``.
            return_weights: When true, also return the attention weights.

        Returns:
            ``out`` of shape ``(batch, query_len, embed_dim)``; or, when
            ``return_weights`` is true, the tuple ``(out, attention_weights)``
            with weights of shape ``(batch, num_heads, query_len, key_len)``.
        """
        if key is None:
            key = x
        if value is None:
            value = key

        # Linear transformations, then split into multiple heads. Each input
        # is reshaped using its own sequence length so key/value may be longer
        # or shorter than the query.
        query = self._split_heads(self.query_linear(x))
        key = self._split_heads(self.key_linear(key))
        value = self._split_heads(self.value_linear(value))

        # Scaled Dot-Product Attention (scaled per head, hence head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, value)

        # Concatenate heads: the transpose makes the tensor non-contiguous, so
        # .contiguous() is required before .view().
        batch_size, query_len = x.size(0), x.size(1)
        context = context.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)

        # Final linear transformation
        out = self.out_linear(context)

        if return_weights:
            return out, attention_weights
        return out



In [None]:
# Self-attention demo: 8 heads over a 128-dimensional embedding.
embed_dim, num_heads = 128, 8

# Input laid out as (batch_size, seq_length, embed_dim).
x = torch.randn((32, 10, embed_dim))

self_attention = MHeadAttention(embed_dim, num_heads)
out = self_attention(x)

# The output keeps the input's shape: (32, 10, 128).
print(out.shape)

torch.Size([32, 10, 128])
