#### Comparing Efficient Multi-Head Attention Implementations
###### This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc.

In [None]:
import torch

# Seed the RNG so the dummy inputs (and every weight initialized afterwards) are reproducible.
torch.manual_seed(123)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"PyTorch version: {torch.__version__}")

# Dummy input batch shared by every attention implementation below:
# batch_size sequences of context_len tokens, each an embed_dim-dimensional vector.
batch_size, context_len, embed_dim = 8, 1024, 768
embeddings = torch.randn(batch_size, context_len, embed_dim, device=device)

###### To run all the code in this notebook, please make sure you are using PyTorch 2.5 or later (FlexAttention is not included in earlier PyTorch releases).
###### If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (please note that PyTorch 2.5 requires Python 3.9 or later).
###### For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org

In [None]:
# pip install --upgrade torch torchvision torchaudio

#### 1) The CausalAttention MHA wrapper class from chapter 3

In [None]:
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal self-attention (the chapter 3 version).

    Args:
        d_in: feature size of each input token.
        d_out: size of the query/key/value projections and of the output.
        context_length: size of the precomputed causal mask; inputs may not be longer.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Projections are created in query, key, value order, which fixes the
        # order in which they draw from the seeded RNG.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions a token may not attend to.
        future_positions = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Map x of shape (batch, num_tokens, d_in) to context vectors (batch, num_tokens, d_out)."""
        _, seq_len, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        scores = q @ k.transpose(-2, -1)
        is_future = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(is_future, -torch.inf)

        # Scaled dot-product: divide by sqrt(d_out) before the softmax.
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)
        return weights @ v


class Ch03_MHA_Wrapper(nn.Module):
    """Multi-head attention built from independent CausalAttention heads.

    Each of the num_heads heads produces d_out features. The head outputs are
    concatenated along the last dimension (num_heads * d_out features) and mixed by a
    square linear projection of that same width.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        )
        concat_dim = num_heads * d_out
        self.out_proj = nn.Linear(concat_dim, concat_dim)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        return self.out_proj(torch.cat(per_head, dim=-1))

# Wrapper variant: 12 independent heads with embed_dim // 12 features each, so the
# concatenated output is again embed_dim wide.
mha_ch03_wrapper = Ch03_MHA_Wrapper(
    d_in=embed_dim,
    d_out=embed_dim // 12,
    num_heads=12,
    context_length=context_len,
    dropout=0.0,
    qkv_bias=False,
)
mha_ch03_wrapper = mha_ch03_wrapper.to(device)

out = mha_ch03_wrapper(embeddings)
print(out.shape)



#### 2) The multi-head attention class from chapter 3

In [None]:
class Ch03_MHA(nn.Module):
    """Multi-head causal self-attention with all heads computed in one pass (chapter 3).

    The query/key/value projections each map d_in -> d_out. The d_out features are
    split into num_heads heads of size head_dim = d_out // num_heads, attended
    independently under a causal mask, concatenated back to d_out features, and passed
    through a d_out -> d_out output projection.

    Args:
        d_in: feature size of each input token.
        d_out: total output size; must be divisible by num_heads.
        context_length: size of the precomputed causal mask; inputs may not be longer.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # The merged head outputs have d_out features, so the output projection is
        # d_out -> d_out. (nn.Linear(d_in, d_out) only worked when d_in == d_out.)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark future positions to be masked out.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Map x of shape (b, num_tokens, d_in) to (b, num_tokens, d_out)."""
        b, num_tokens, _ = x.shape

        keys = self.W_key(x)  # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dim into heads and move heads before tokens:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Causal mask truncated to the current sequence length
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads back into d_out = num_heads * head_dim features
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

# Chapter 3 multi-head attention: embed_dim -> embed_dim projections split into 12 heads.
mha_ch03 = Ch03_MHA(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_len,
    dropout=0.0,
    qkv_bias=False,
)
mha_ch03 = mha_ch03.to(device)

out = mha_ch03(embeddings)
print(out.shape)

         

#### 3) An alternative multi-head attention with combined weights
###### The code for the MultiHeadAttentionCombinedQKV class below is based on code that was kindly shared by Rayed Bin Wahed

###### The main difference between the MultiHeadAttentionCombinedQKV class and the chapter 3 MultiHeadAttention class (Ch03_MHA above) is that MultiHeadAttentionCombinedQKV uses a single weight matrix, self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias), instead of separate weight matrices:

###### self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
###### self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
###### self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
###### Here, self.qkv combines all three weight matrices self.W_query, self.W_key, and self.W_value to carry out the query, key, and value computation in a single step

###### Using queries, keys, values = qkv.unbind(0), we obtain the individual query, key, and value tensors, which are then used in the same way as the query, key, and value tensors in the chapter 3 MultiHeadAttention class

In [None]:
import torch.nn as nn

class MultiHeadAttentionCombinedQKV(nn.Module):
    """Multi-head causal self-attention using one fused query/key/value projection.

    A single nn.Linear(d_in, 3 * d_out) produces queries, keys and values at once.
    They are split into num_heads heads of size head_dim = d_out // num_heads,
    attended under a causal mask, merged back to d_out features and projected.

    Args:
        d_in: feature size of each input token.
        d_out: total output size; must be divisible by num_heads.
        num_heads: number of attention heads.
        context_length: size of the precomputed causal mask; inputs may not be longer.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the fused projection has a bias term.
    """

    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # 1s strictly above the diagonal mark future positions to be masked out.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Map x of shape (b, num_tokens, d_in) to (b, num_tokens, d_out)."""
        batch_size, num_tokens, _ = x.shape

        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv.unbind(0)

        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        # Scale by 1/sqrt(head_dim), i.e. divide by head_dim**0.5. Dividing by
        # head_dim**-0.5 (as before) multiplied the scores by sqrt(head_dim) instead.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ values

        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.transpose(1, 2)

        # Merge heads into d_out features. Do not use x.shape[-1] here: that is d_in,
        # which differs from d_out in general.
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.d_out)

        return self.proj(context_vec)

# Fused-QKV variant: one embed_dim -> 3 * embed_dim projection, split into 12 heads.
mha_combined_qkv = MultiHeadAttentionCombinedQKV(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_len,
    dropout=0.0,
    qkv_bias=False,
)
mha_combined_qkv = mha_combined_qkv.to(device)

out = mha_combined_qkv(embeddings)
print(out.shape)

#### 4) Multi-head attention with Einsum

###### Implementing multi-head attention using Einstein summation via torch.einsum

In [None]:
import math

class MHAEinsum(nn.Module):
    """Multi-head causal self-attention written with torch.einsum.

    The query/key/value weights use the nn.Linear.weight layout, shape (d_out, d_in),
    and are applied as x @ W.T. Heads have size head_dim = d_out // num_heads.

    Args:
        d_in: feature size of each input token.
        d_out: total output size; must be divisible by num_heads.
        context_length: size of the precomputed causal mask; inputs may not be longer.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # (d_out, d_in) weights; their values are overwritten in reset_parameters().
        self.W_query = nn.Parameter(torch.randn(d_out, d_in))
        self.W_key = nn.Parameter(torch.randn(d_out, d_in))
        self.W_value = nn.Parameter(torch.randn(d_out, d_in))

        if qkv_bias:
            self.bias_q = nn.Parameter(torch.zeros(d_out))
            self.bias_k = nn.Parameter(torch.zeros(d_out))
            self.bias_v = nn.Parameter(torch.zeros(d_out))
        else:
            self.register_parameter("bias_q", None)
            self.register_parameter("bias_k", None)
            self.register_parameter("bias_v", None)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark future positions to be masked out.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize Q/K/V weights with kaiming_uniform_(a=sqrt(5)) and biases with
        U(-1/sqrt(fan_in), 1/sqrt(fan_in)), the scheme nn.Linear uses by default."""
        # The first weight was previously misspelled `self.W_quey`, which raised AttributeError.
        nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))
        if self.bias_q is not None:
            # For a (d_out, d_in) weight, fan_in is d_in.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_q, -bound, bound)
            nn.init.uniform_(self.bias_k, -bound, bound)
            nn.init.uniform_(self.bias_v, -bound, bound)

    def forward(self, x):
        """Map x of shape (b, n, d_in) to (b, n, d_out)."""
        b, n, _ = x.shape

        # Linear projections: x (b, n, d_in) with W (d_out, d_in) -> (b, n, d_out), i.e. x @ W.T.
        # The contraction runs over d_in; the old "bnd,di->bni" contracted over d_out.
        Q = torch.einsum("bni,oi->bno", x, self.W_query)
        K = torch.einsum("bni,oi->bno", x, self.W_key)
        V = torch.einsum("bni,oi->bno", x, self.W_value)

        if self.bias_q is not None:
            Q = Q + self.bias_q
            K = K + self.bias_k
            V = V + self.bias_v

        # (b, n, d_out) -> (b, num_heads, n, head_dim)
        Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (b, num_heads, n, n)
        scores = torch.einsum("bhnd,bhmd->bhnm", Q, K) / (self.head_dim ** 0.5)

        # The (n, n) causal mask broadcasts over the batch and head dimensions.
        scores = scores.masked_fill(self.mask[:n, :n].bool(), -torch.inf)

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: (b, num_heads, n, head_dim)
        context_vec = torch.einsum("bhnm,bhmd->bhnd", attn_weights, V)

        # Merge heads back into d_out features and project.
        context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)
        return self.out_proj(context_vec)

# Einsum variant: same sizes as the other implementations (embed_dim -> embed_dim, 12 heads).
mha_einsum = MHAEinsum(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_len,
    dropout=0.0,
    qkv_bias=False,
)
mha_einsum = mha_einsum.to(device)

out = mha_einsum(embeddings)
print(out.shape)



#### 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention
###### The implementation below uses PyTorch's scaled_dot_product_attention function, which can dispatch to memory-efficient fused kernels such as FlashAttention when the hardware and inputs support them

In [None]:
class MHAPyTorchScaledDotProduct(nn.Module):
    """Multi-head causal self-attention that delegates the attention step to
    torch.nn.functional.scaled_dot_product_attention.

    A fused nn.Linear(d_in, 3 * d_out) produces queries, keys and values. These are split
    into num_heads heads of size head_dim = d_out // num_heads and attended with
    is_causal=True. The heads are then merged back to d_out features and projected.

    Args:
        d_in: feature size of each input token.
        d_out: total output size; must be divisible by num_heads.
        num_heads: number of attention heads.
        context_length: stored for reference; causality comes from is_causal=True.
        dropout: attention dropout probability, applied only in training mode.
        qkv_bias: whether the fused projection has a bias term.
    """

    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        # A float probability (not an nn.Dropout module); forwarded as dropout_p.
        self.dropout = dropout

    def forward(self, x):
        """Map x of shape (b, num_tokens, d_in) to (b, num_tokens, d_out)."""
        # Indentation is now consistent; the old body mixed 7- and 8-space
        # indents and raised IndentationError.
        batch_size, num_tokens, _ = x.shape

        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv.unbind(0)

        # scaled_dot_product_attention has no notion of train/eval mode and applies
        # dropout_p whenever it is > 0, so disable dropout outside training here.
        use_dropout = self.dropout if self.training else 0.0

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True
        )

        # Merge heads: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out)
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        return self.proj(context_vec)


In [None]:
# SDPA variant: the attention step runs inside PyTorch's scaled_dot_product_attention.
mha_pytorch_scaled = MHAPyTorchScaledDotProduct(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_len,
    dropout=0.0,
    qkv_bias=False,
)
mha_pytorch_scaled = mha_pytorch_scaled.to(device)

out = mha_pytorch_scaled(embeddings)
print(out.shape)