## ---------------------------------------Multi-Head Attention---------------------------------

- In theory, multi-head attention involves creating multiple instances of the causal attention mechanism and concatenating their outputs
- In code, I did this by implementing a simple multi-head attention wrapper class that stacks multiple instances of my previous causal attention class (`CasualAttentionV1`)

In [3]:
# Creating the token embeddings - hard-coded example values
import torch
# NOTE(review): output_dim is not referenced in the cells shown here - confirm it is used elsewhere
output_dim = 3

# One row per token of the example sentence; every token has a 3-dimensional embedding
inputs = torch.tensor([
    [0.43, 0.15, 0.89], # Your    # X1
    [0.55, 0.87, 0.66], # journey # X2
    [0.57, 0.85, 0.64], # begins  # X3
    [0.22, 0.58, 0.33], # with    # X4
    [0.77, 0.25, 0.10], # one     # X5
    [0.05, 0.80, 0.55] # step     # X6
])
# Last expression of the cell, so the notebook echoes the tensor
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [4]:
# Creating the input structure: a batch made of two copies of the same sequence

# Stack along a new leading dimension, giving shape (2, 6, 3)
batch = torch.stack((inputs, inputs), dim=0)
# Number of tokens per sequence (dimension 1 of the batch)
context_length = batch.shape[1]
# Input embedding size and projected output size used by the attention modules
d_in, d_out = 3, 2
# Echo the batch shape as the cell output
batch.shape

torch.Size([2, 6, 3])

In [5]:
# Causal Attention Class (the class name keeps the `Casual` spelling)
from torch import nn
class CasualAttentionV1(nn.Module):
    """Single-head causal self-attention with dropout, applied to batched input.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the key, query and value projections.
        dropout_rate: Probability handed to the dropout layer that is applied
            to the attention weights.
        context_length: Side length of the square causal mask.
        bias_units: Whether the three linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, dropout_rate, context_length, bias_units=False):
        super().__init__()
        # Trainable projections, created in key -> query -> value order
        self.w_key = nn.Linear(d_in, d_out, bias=bias_units)
        self.w_query = nn.Linear(d_in, d_out, bias=bias_units)
        self.w_value = nn.Linear(d_in, d_out, bias=bias_units)
        self.dropout = nn.Dropout(dropout_rate)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it travels with the module but is not trained.
        future_positions = torch.triu(
            torch.ones(context_length, context_length),
            diagonal=1,
        )
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Return the context vectors for a 3-D input `x`.

        The three-way unpacking of `x.shape` requires exactly three
        dimensions; the middle one is treated as the number of tokens.
        """
        _, seq_len, _ = x.shape

        # Project the input into key / query / value space
        k = self.w_key(x)
        q = self.w_query(x)
        v = self.w_value(x)

        # Raw scores: every query against every key within each batch element
        scores = torch.matmul(q, k.transpose(1, 2))

        # Crop the mask to the current sequence length and blank out (in place)
        # every position that lies to the right of the diagonal
        causal_mask = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(causal_mask, -torch.inf)

        # Scale by sqrt(key dimension), normalise per row, then apply dropout
        scale = k.shape[-1] ** 0.5
        weights = torch.softmax(scores / scale, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of the value vectors
        return torch.matmul(weights, v)

In [6]:
# Instantiation using our batch sample
# Unpack the three dimensions of the batch; context_length is rebound to the token count
batches, context_length, dimensions = batch.shape
# NOTE(review): no .eval() call is made, so dropout (rate 0.1) is presumably active -
# the two identical batch elements produce different rows in the output below
ca = CasualAttentionV1(d_in=3, d_out=2, dropout_rate=0.1 ,context_length=context_length)
context_vectors = ca(batch)
# Echo the context vectors together with their shape
context_vectors, context_vectors.shape

(tensor([[[-0.5242,  0.5078],
          [-0.6026,  0.5712],
          [-0.6281,  0.5896],
          [-0.5521,  0.5221],
          [-0.4230,  0.4001],
          [-0.3174,  0.2891]],
 
         [[-0.5242,  0.5078],
          [-0.6026,  0.5712],
          [-0.3963,  0.3732],
          [-0.5521,  0.5221],
          [-0.4848,  0.4248],
          [-0.4326,  0.3966]]], grad_fn=<UnsafeViewBackward0>),
 torch.Size([2, 6, 2]))

In [7]:
# Example of a wrapper class 

# Base class
class Worker:
    """Base object of the wrapper example: it only knows its own name."""

    def __init__(self, name):
        # Kept as given; it is interpolated into the greeting on demand
        self.name = name

    def get_result(self):
        """Return the greeting string for this worker."""
        greeting = f"Hello from {self.name}"
        return greeting
# Wrapper class
class Manager:
    """Wrapper of the example: owns one Worker per name and merges their results."""

    def __init__(self, names):
        # Build one Worker instance for every name, in the order given
        staff = []
        for person in names:
            staff.append(Worker(person))
        self.workers = staff

    def get_combined_result(self):
        """Return every worker's result joined by a "|" separator."""
        # Collect each instance's output first, then join them
        pieces = [member.get_result() for member in self.workers]
        return "|".join(pieces)

# Initialization: build a Manager over three names and print the joined results
names = ["mel", "bob", "chel"]
manager = Manager(names)
print(manager.get_combined_result())

Hello from mel|Hello from bob|Hello from chel


In [12]:
class MultiHeadAttentionWrapper(nn.Module):
    """Runs several independent CasualAttentionV1 heads and concatenates their outputs.

    Args:
        d_in: Passed to every head as its `d_in`.
        d_out: Passed to every head as its `d_out`.
        context_length: Passed to every head as its `context_length`.
        dropout_rate: Passed to every head as its `dropout_rate`.
        num_heads: Number of head instances to create.
        bias_units: Passed to every head as its `bias_units`.
    """

    def __init__(self, d_in, d_out, context_length, dropout_rate, num_heads, bias_units=False):
        super().__init__()
        # A ModuleList (rather than a plain list) registers each head as a submodule
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            # Note the argument order expected by the head: dropout_rate before context_length
            single_head = CasualAttentionV1(d_in, d_out, dropout_rate, context_length, bias_units)
            self.heads.append(single_head)

    def forward(self, x):
        """Apply every head to `x` and join the results along the last dimension."""
        per_head_outputs = [attend(x) for attend in self.heads]
        return torch.cat(per_head_outputs, dim=-1)


## Notes - On the shape of the final output 

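Let the shape of each head's context vectors be (6, 2), i.e. (number of tokens, d_out), ignoring the batch dimension.

The heads are concatenated along the last dimension, so the final context vector has shape (6, 2 x number of heads).

With number of heads = 2, the final context vector has shape (6, 2 x 2) = (6, 4); including the batch dimension it is (2, 6, 4).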

In [13]:
# Creating the final context vector
# The two trailing positional arguments are dropout_rate=0.2 and num_heads=2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.2, 2)
context_vector = mha(batch)
# Echo the tensor and its shape; two heads with d_out=2 give a last dimension of 4
context_vector, context_vector.shape

(tensor([[[-0.6709,  0.4473,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0252,  0.0927],
          [-0.7199,  0.2780,  0.1457,  0.1900],
          [-0.4445,  0.1822,  0.1457,  0.1629],
          [-0.6189,  0.1582,  0.1778,  0.1674],
          [-0.4527,  0.1411,  0.1382,  0.1361]],
 
         [[-0.6709,  0.4473, -0.0491,  0.1806],
          [-0.7067,  0.3251, -0.0252,  0.0927],
          [-0.4755,  0.2188,  0.1457,  0.1900],
          [-0.4438,  0.1858,  0.1457,  0.1629],
          [-0.4091,  0.0465,  0.1778,  0.1674],
          [-0.3942,  0.1231,  0.0991,  0.1052]]], grad_fn=<CatBackward0>),
 torch.Size([2, 6, 4]))