A comparatively efficient implementation of a multi-head attention class with dropout.

In the CausalAttention notebook, the last class implemented is a wrapper that takes multiple Causal Attention modules, serially projects the input through each module's trainable weights to get context vectors, and then stacks those vectors together in the output. This is much less efficient because the number of projection matrix multiplications grows with the number of heads.

Instead, we are going to create matrices that hold the Q, K, V weights for all heads, multiply once (i.e. do the projections once), and then split the output apart to get the context vectors from each head. Of course, we will take this all the way this time around and combine the per-head context vectors into the final output of multi-head attention.

In [58]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one fused projection per Q/K/V.

    A single ``d_in -> d_out`` linear layer is used for each of the queries,
    keys and values; the ``d_out`` features are then split into ``num_heads``
    chunks of ``head_dim`` features so every head is computed in the same
    batched matrix multiplication.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output (all heads combined).
            Must be divisible by ``num_heads``.
        context_length: Side length of the square causal mask that is
            registered as a buffer; ``forward`` slices it to the actual
            number of tokens.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads. Must be positive.
        qkv_bias: Whether the query/key/value projections use a bias term.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``d_out`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout,
                 num_heads, qkv_bias=False):
        super().__init__()
        #Fail early with a clear message: without these checks the floor
        #division below silently truncates and the error only surfaces later
        #as an opaque shape mismatch inside forward()
        if num_heads <= 0:
            raise ValueError(
                f"num_heads must be a positive integer, got {num_heads}")
        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})")

        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        #The number of dimensions per head
        self.head_dim = self.d_out // self.num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        #Upper-triangular ones above the diagonal mark the "future" positions
        #that each token is not allowed to attend to
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1))

    def forward(self, x):
        """Compute causal multi-head attention over ``x``.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        b, num_tokens, d_in = x.shape

        print(f"W_key weight shape: {self.W_key.weight.shape}")
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        print(f"keys.shape: {keys.shape}")

        #Split the d_out features of every token into num_heads chunks of
        #head_dim features each
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        print(f"keys.shape after splitting into heads: {keys.shape}")

        #Now here, we are going to "rearrange" the matrices such that we go from
        #[b, num_tokens, num_heads, head_dim] -> [b, num_heads, num_tokens, head_dim]
        #so that the matrix multiplications below run per head
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        #From here on it is business as usual; the tensors are arranged as
        #batches -> heads -> tokens -> embedded vectors
        attn_scores = queries @ keys.transpose(2, 3)

        #The mask was built for context_length tokens; slice it down to the
        #number of tokens actually present in this input
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        #-inf scores become exactly 0 after the softmax
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        #scaled dot product
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        #apply the dropout if there is one specified
        attn_weights = self.dropout(attn_weights)

        #calculate the context vectors
        context_vectors = attn_weights @ values

        #Go back to (b, tokens, heads, head_dim) so the per-head vectors of
        #each token sit next to each other
        context_vectors = context_vectors.transpose(1, 2)

        #Merge the heads: (b, tokens, heads, head_dim) -> (b, tokens, d_out).
        #contiguous() is needed because view() cannot merge the dimensions of
        #the transposed tensor directly
        context_vectors = context_vectors.contiguous().view(b, num_tokens, self.d_out)

        #apply a linear projection that will be useful when this class is used in training
        context_vectors = self.out_proj(context_vectors)

        return context_vectors
        

        


In [59]:
#sizes of the input embedding and of the output embedding
d_in = 800
d_out = 400

#random input laid out as (batch size, context length, number of embedded dimensions)
batch = torch.randn(8, 1024, d_in)

batch.shape

torch.Size([8, 1024, 800])

In [60]:
#two heads, no dropout; the causal mask is sized to the context length of the batch
mha = MultiHeadAttention(d_in, d_out, batch.shape[1], 0, 2)

context_vecs = mha(batch)

#`context_vecs.d` is not a Tensor attribute and raises AttributeError;
#display the values detached from the autograd graph instead
context_vecs.detach()


W_key weight shape: torch.Size([400, 800])
keys.shape: torch.Size([8, 1024, 400])
keys.shape after splitting into heads: torch.Size([8, 1024, 2, 200])


tensor([[[ 5.3731e-01,  2.0384e-01,  3.5028e-01,  ...,  2.7758e-01,
          -3.8424e-02, -5.1152e-02],
         [ 3.1804e-02,  3.3741e-01,  5.9904e-02,  ...,  9.1153e-02,
           1.8804e-01,  3.5074e-02],
         [-2.0838e-01,  2.0049e-01,  4.2469e-02,  ...,  1.8851e-01,
           2.5226e-01, -5.1680e-02],
         ...,
         [-3.9543e-02,  3.3366e-02, -1.7590e-02,  ...,  4.8678e-03,
          -4.5569e-02,  3.3840e-02],
         [-4.5406e-02,  3.3032e-02, -1.7882e-02,  ...,  6.3432e-04,
          -4.0340e-02,  2.9276e-02],
         [-4.2710e-02,  4.0721e-02, -8.8204e-03,  ...,  8.4825e-03,
          -4.1777e-02,  2.4589e-02]],

        [[-1.9143e-02, -3.2137e-01,  7.5531e-01,  ..., -1.8179e-01,
           6.0165e-03,  2.6764e-01],
         [ 1.4512e-01, -7.8187e-02,  5.1812e-01,  ..., -1.0876e-01,
          -1.8792e-01,  1.1768e-01],
         [ 1.7365e-01,  8.4582e-02,  4.2504e-01,  ..., -7.4203e-02,
          -1.9752e-01, -3.3649e-02],
         ...,
         [-2.6509e-02,  1