<a href="https://colab.research.google.com/github/pksX01/AI-practice/blob/main/Deep%20Learning/Attention%20Mechanism/multi_head_mask_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Masked Attention

In [10]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with an optional mask.

    Queries, keys and values are each obtained from a bias-free linear
    projection mapping ``d_model`` features to ``d_model`` features.

    Args:
        row_dim: One of the two dimensions swapped when the keys are
            transposed.
        col_dim: The other swapped dimension. Its size in the projected keys
            sets the scaling factor, and softmax is taken along it.
        d_model: Input and output feature size of each projection.
    """

    def __init__(self, row_dim=0, col_dim=1, d_model=2):
        super().__init__()
        self.row_dim, self.col_dim = row_dim, col_dim

        def make_projection():
            return nn.Linear(d_model, d_model, bias=False)

        # Build the projections in q, k, v order: each one draws from the
        # global RNG, so the order determines the weights under a fixed seed.
        self.W_q = make_projection()
        self.W_k = make_projection()
        self.W_v = make_projection()

    def forward(self, q_encoding, k_encoding, v_encoding, mask=None):
        """Return the attention-weighted combination of the projected values.

        Args:
            q_encoding: Input to the query projection (last dimension must be
                ``d_model``).
            k_encoding: Input to the key projection.
            v_encoding: Input to the value projection.
            mask: Optional mask handed to ``masked_fill``; scores at positions
                where it is True are replaced by ``-1e9`` before the softmax.

        Returns:
            The softmax-normalised attention weights multiplied by the
            projected values.
        """
        queries = self.W_q(q_encoding)
        keys = self.W_k(k_encoding)
        values = self.W_v(v_encoding)

        # Scale the raw dot products by sqrt(size of the keys along col_dim).
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = (queries @ keys.transpose(self.row_dim, self.col_dim)) / scale

        if mask is not None:
            # A very negative score becomes (effectively) zero after softmax.
            scores = scores.masked_fill(mask=mask, value=-1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return weights @ values

In [11]:
# Toy encodings: three rows with two features each. Queries, keys and values
# hold the same numbers, but every name is bound to its own tensor.
q_encoding = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

k_encoding = q_encoding.clone()

v_encoding = q_encoding.clone()

## Calculate Self Attention

In [12]:
# Seed the global RNG first so the randomly initialised projection weights,
# and therefore the displayed output, are reproducible.
torch.manual_seed(42)
attention = Attention(d_model=2, row_dim=0, col_dim=1)
# No mask is passed: every row attends to every row.
attention(q_encoding, k_encoding, v_encoding, mask=None)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

## Calculate Masked Attention

In [13]:
# Boolean mask that is True strictly above the main diagonal, i.e. at the
# positions whose scores will be overwritten before the softmax.
mask = torch.ones(3, 3).tril() == 0
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [14]:
# NOTE: no seed is set in this cell, so the freshly initialised weights depend
# on whatever state the global RNG was left in by previously executed cells.
mask_attention = Attention(d_model=2, row_dim=0, col_dim=1)
mask_attention(q_encoding, k_encoding, v_encoding, mask)

tensor([[-0.3970, -0.2253],
        [-0.3488,  0.1166],
        [-0.7190, -0.8447]], grad_fn=<MmBackward0>)

# Multi-head attention

In [15]:
class MultiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and concatenate their outputs.

    Args:
        row_dim: Forwarded to every head (dimension swapped with ``col_dim``
            when the keys are transposed).
        col_dim: Forwarded to every head; also the dimension along which the
            per-head outputs are concatenated.
        d_model: Feature size forwarded to every head.
        num_heads: Number of ``Attention`` heads to create.
    """

    def __init__(self, row_dim=0, col_dim=1, d_model=2, num_heads=1):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim

        # Forward the constructor arguments to each head. Previously the heads
        # were always built with row_dim=0, col_dim=1, d_model=2, so any other
        # configuration passed here was silently ignored.
        self.heads = nn.ModuleList(
            [
                Attention(row_dim=row_dim, col_dim=col_dim, d_model=d_model)
                for _ in range(num_heads)
            ]
        )

    def forward(self, q_encoding, k_encoding, v_encoding, mask=None):
        """Apply every head to the same inputs and join the results.

        Args:
            q_encoding: Input passed to each head as the query encoding.
            k_encoding: Input passed to each head as the key encoding.
            v_encoding: Input passed to each head as the value encoding.
            mask: Optional mask passed unchanged to each head.

        Returns:
            The per-head outputs concatenated along ``col_dim``.
        """
        head_outputs = [
            head(q_encoding, k_encoding, v_encoding, mask) for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

## Calculate attention for single head

In [17]:
# Seed before construction so the single head gets reproducible weights.
torch.manual_seed(42)
singleHeadAttention = MultiHeadAttention(num_heads=1, d_model=2, row_dim=0, col_dim=1)
# Fourth positional argument is the mask; None means unmasked attention.
singleHeadAttention(q_encoding, k_encoding, v_encoding, None)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

## Calculate attention for multiple heads

In [18]:
# Seed before construction so both heads get reproducible weights.
torch.manual_seed(42)
multiHeadAttention = MultiHeadAttention(num_heads=2, d_model=2, row_dim=0, col_dim=1)
# Unmasked call; the two head outputs are concatenated along col_dim.
multiHeadAttention(q_encoding, k_encoding, v_encoding, None)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)

## Calculate masked attention for single head

In [19]:
# Seed before construction so the single head gets reproducible weights.
torch.manual_seed(42)
# Boolean mask that is True strictly above the main diagonal.
mask = torch.ones(3, 3).tril() == 0
singleMaskedHeadAttention = MultiHeadAttention(num_heads=1, d_model=2, row_dim=0, col_dim=1)
singleMaskedHeadAttention(q_encoding, k_encoding, v_encoding, mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<CatBackward0>)

## Calculate masked attention for multiple heads

In [21]:
# Seed before construction so both heads get reproducible weights.
torch.manual_seed(42)
# Boolean mask that is True strictly above the main diagonal.
mask = torch.ones(3, 3).tril() == 0
multiMaskedHeadAttention = MultiHeadAttention(num_heads=2, d_model=2, row_dim=0, col_dim=1)
# The same mask is handed to every head.
multiMaskedHeadAttention(q_encoding, k_encoding, v_encoding, mask)

tensor([[ 0.6038,  0.7434, -0.3970, -0.2253],
        [-0.0062,  0.6072, -0.3488,  0.1166],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)