## --------------------------Multi-Head Attention with weight splits--------------------------

In [27]:
# Creating the multi-head attention compact class
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one shared projection per role.

    Keys, queries and values are each produced by a single ``nn.Linear`` of
    width ``d_out`` and then split into ``num_heads`` heads of size
    ``d_out // num_heads``. An upper-triangular mask stops every position
    from attending to positions that come after it.

    Args:
        d_in: Size of the last axis of the input.
        d_out: Size of the last axis of the output; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        context_length: Side length of the causal mask, i.e. the longest
            sequence the mask covers.
        dropout_rate: Dropout probability applied to the attention weights.
        bias_units: Whether the key/query/value projections use a bias term.
            The output projection always uses ``nn.Linear``'s default bias.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        num_heads: int,
        context_length: int,
        dropout_rate: float,
        bias_units: bool = False,
    ) -> None:
        super().__init__()
        assert d_out % num_heads == 0, "dimensions out must be divisible by number of heads"
        # Getting the head dimensions
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Initializing the key query value projections - each maps d_in -> d_out
        self.w_key = nn.Linear(d_in, d_out, bias=bias_units)
        self.w_query = nn.Linear(d_in, d_out, bias=bias_units)
        self.w_value = nn.Linear(d_in, d_out, bias=bias_units)
        # Initializing the final projection layer - optional - maps d_out -> d_out
        self.out_proj = nn.Linear(d_out, d_out)
        # Creating the causal mask: ones strictly above the diagonal mark the
        # (query, key) pairs where the key lies in the query's future.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )
        # Creating the dropout layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the causally masked multi-head attention context vectors.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``. ``num_tokens`` must
                not exceed ``context_length`` because the causal mask is only
                that large.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        # The last axis of the input is d_in, which need not equal d_out, so
        # it must not be reused as the output width further down.
        b, num_tokens, _ = x.shape
        # Getting the key query value matrices - (b, num_tokens, d_out)
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)
        # Splitting the last axis into heads - (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        # Grouping by number of heads - (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        # Getting the attention scores - (b, num_heads, num_tokens, num_tokens)
        attention_scores = queries @ keys.transpose(2, 3)
        # Masking the attention scores: future positions get -inf so that the
        # softmax assigns them zero weight.
        attention_scores.masked_fill_(
            self.mask[:num_tokens, :num_tokens].bool(),
            -torch.inf,
        )
        # Scaling the attention scores by sqrt(head_dim)
        attention_scores = attention_scores / self.head_dim ** 0.5
        # Getting the attention weights
        attention_weights = torch.softmax(attention_scores, dim=-1)
        # Implementing the dropout layer
        attention_weights = self.dropout(attention_weights)
        # Getting the context vector - (b, num_heads, num_tokens, head_dim)
        context_vector = attention_weights @ values
        # Moving the head axis back - (b, num_tokens, num_heads, head_dim)
        context_vector = context_vector.transpose(1, 2)
        # Combining the result of multiple heads - d_out = num_heads * head_dim.
        # contiguous() is needed because view() cannot merge axes of a
        # transposed (non-contiguous) tensor.
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        # Passing the final context vector into the projection layer - optional
        context_vector = self.out_proj(context_vector)
        return context_vector

In [28]:
# Toy input: a sequence of three tokens, each a 6-dimensional embedding

output_dim = 3
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # token 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # token 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],  # token 3
])
print(inputs.shape)
# Duplicate the sequence along a new leading axis to get a batch of two
batch = torch.stack([inputs, inputs])
b, context_length, d_in = batch.size()
print(batch.shape)

torch.Size([3, 6])
torch.Size([2, 3, 6])


In [29]:
# Build the attention module for the toy batch and run one forward pass

num_heads = 2
d_out = d_in  # output width kept equal to the input width for this demo
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    num_heads=num_heads,
    context_length=context_length,
    dropout_rate=0.15,
)
context_vectors = mha(batch)
context_vectors, context_vectors.shape

(tensor([[[ 0.2708,  0.0741,  0.4923, -0.0958, -0.2626, -0.0549],
          [ 0.2410, -0.1707,  0.2888, -0.0914, -0.3085, -0.0217],
          [ 0.2825, -0.0746,  0.3702, -0.0928, -0.2867, -0.0395]],
 
         [[ 0.2171, -0.0048,  0.4475, -0.1367, -0.2203,  0.0122],
          [ 0.2895, -0.1292,  0.2980, -0.0443, -0.2448,  0.0181],
          [ 0.2610, -0.1991,  0.3199, -0.1254, -0.2818, -0.0571]]],
        grad_fn=<ViewBackward0>),
 torch.Size([2, 3, 6]))