In [1]:
import torch
import torch.nn as nn

In [2]:
class MultiHeadCausalAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    The input is projected into queries, keys and values. These are split into
    ``num_heads`` heads of size ``d_out // num_heads`` and combined with scaled
    dot-product attention under a causal mask, so each token attends only to
    itself and earlier tokens. The heads are then concatenated and passed
    through a final linear projection.

    Args:
        d_in: Size of the last input dimension (d_model in the paper).
        d_out: Total output size; must be divisible by ``num_heads``.
        context_length: Maximum number of tokens covered by the causal mask.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias=False):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"
            )

        self.d_in = d_in
        self.d_out = d_out
        self.context_length = context_length
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias
        self.dropout = nn.Dropout(p=dropout)  # applied to attention weights

        self.d_k = d_out // num_heads  # per-head dimension (d_k in the paper)

        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)

        # Ones strictly above the diagonal mark "future" positions to be hidden.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def _split_heads(self, t, batch_size, num_tokens):
        """Reshape (batch, tokens, d_out) -> (batch, num_heads, tokens, d_k)."""
        return t.view(batch_size, num_tokens, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, X):
        """Apply causal multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape (batch_size, num_tokens, d_in) with
               ``num_tokens <= context_length`` (the mask is only that large).

        Returns:
            Tensor of shape (batch_size, num_tokens, d_out).
        """
        batch_size, num_tokens, _ = X.shape  # last dim is d_in (d_model in the paper)

        # Projections, each (batch_size, num_tokens, d_out) -> per-head layout
        # (batch_size, num_heads, num_tokens, d_k) so attention runs per head.
        query = self._split_heads(self.Wq(X), batch_size, num_tokens)
        key = self._split_heads(self.Wk(X), batch_size, num_tokens)
        value = self._split_heads(self.Wv(X), batch_size, num_tokens)

        # Scores: (b, h, T, d_k) @ (b, h, d_k, T) -> (b, h, T, T)
        attn_scores = query @ key.transpose(2, 3)

        # Hide future tokens. Slice the mask because the batch may contain
        # fewer tokens than the supported context_length.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.d_k ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted values: (b, h, T, T) @ (b, h, T, d_k) -> (b, h, T, d_k)
        context_vec = attn_weights @ value

        # Merge heads: (b, h, T, d_k) -> (b, T, h, d_k) -> (b, T, d_out).
        # transpose() yields a non-contiguous tensor, so it must be made
        # contiguous before view().
        context_vec = (
            context_vec.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_tokens, self.d_out)
        )

        return self.out_proj(context_vec)

        

