# Multi-head Self-Attention (MHSA)

## Code

In [None]:
import torch
import math
from torch import Tensor
from jaxtyping import Float
import torch.nn as nn
from cs336_basics.softmax import Softmax
from cs336_basics.scaled_dot_product_attention import ScaledDotProductAttention
 
 
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention whose projection weights are call arguments.

    Projection matrices are not stored on the module; ``forward`` receives the
    query/key/value/output weights on every call and applies each one as
    ``x @ W.T``. The actual attention computation is delegated to
    ``ScaledDotProductAttention``.
    """

    def __init__(self, d_model: int, num_heads: int) -> None:
        """
        Args:
            d_model: Model width. Stored on the module; ``forward`` derives the
                per-head width from the projection outputs instead.
            num_heads: Number of heads the projected features are split into.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention()

    def _split_heads(self, x: Tensor) -> Tensor:
        """Reshape ``(..., seq_len, heads * d_head)`` into ``(..., heads, seq_len, d_head)``."""
        *lead, n_tokens, width = x.shape
        head_width = width // self.num_heads
        # ``view`` needs a contiguous tensor; fresh matmul outputs satisfy that.
        per_head = x.view(*lead, n_tokens, self.num_heads, head_width)
        return per_head.transpose(-3, -2)

    @staticmethod
    def _merge_heads(x: Tensor) -> Tensor:
        """Inverse of ``_split_heads``: ``(..., heads, seq_len, d)`` -> ``(..., seq_len, heads * d)``."""
        *lead, n_heads, n_tokens, head_width = x.shape
        return x.transpose(-3, -2).reshape(*lead, n_tokens, n_heads * head_width)

    def forward(
            self,
            q_proj_weight : torch.Tensor,
            k_proj_weight : torch.Tensor,
            v_proj_weight : torch.Tensor,
            o_proj_weight : torch.Tensor,
            in_features : torch.Tensor
        ) -> Tensor:
        """Apply causal multi-head self-attention to ``in_features``.

        Args:
            q_proj_weight, k_proj_weight, v_proj_weight: Projection matrices,
                each applied as ``in_features @ W.T``.
            o_proj_weight: Output projection applied to the concatenated heads.
            in_features: Input whose last two axes are ``(seq_len, d_in)``.

        Returns:
            The attended, output-projected features, ``(..., seq_len, d_out)``.
        """
        # Project the inputs and split each projection into heads in one pass.
        queries, keys, values = (
            self._split_heads(in_features @ weight.T)
            for weight in (q_proj_weight, k_proj_weight, v_proj_weight)
        )

        # Boolean mask that is True on and below the diagonal.
        # NOTE(review): assumes ScaledDotProductAttention treats True as "may attend".
        n_tokens = queries.shape[-2]
        causal = torch.ones(
            n_tokens, n_tokens, dtype=torch.bool, device=in_features.device
        ).tril()

        context = self.attention(queries, keys, values, mask=causal)
        return self._merge_heads(context) @ o_proj_weight.T

In [None]:
import torch
import math
from torch import Tensor
from jaxtyping import Float
import torch.nn as nn
from cs336_basics.softmax import Softmax
from cs336_basics.scaled_dot_product_attention import ScaledDotProductAttention
from cs336_basics.rope import RoPE

class MultiHeadSelfAttentionRoPE(nn.Module):
    
    def __init__(
            self,
            d_model : int,
            num_heads : int,
            max_seq_len : int,
            theta : float
        ) -> None:
        
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention()
        self.rope = RoPE(
            d_k=d_model,
            max_seq_len=max_seq_len,
            theta=theta,
        )
 
    
    def forward(
            self,
            q_proj_weight : torch.Tensor,
            k_proj_weight : torch.Tensor,
            v_proj_weight : torch.Tensor,
            o_proj_weight : torch.Tensor,
            in_features : torch.Tensor,
            token_positions : torch.Tensor
        ) -> Tensor:

        # Linear Projections
        q = in_features @ q_proj_weight.T
        k = in_features @ k_proj_weight.T
        v = in_features @ v_proj_weight.T

        # Reshape and Transpose for the projections
        def split_heads(x: Tensor):
            *batch_dims, seq_len, d_total = x.shape
            d_head = d_total // self.num_heads
            return x.view(*batch_dims, seq_len, self.num_heads, d_head).transpose(-3, -2)
        
        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        


        # Create causal mask
        seq_len = q.shape[-2]
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=in_features.device, dtype=torch.bool))

        attn_output = self.attention(q, k, v, mask=causal_mask)

        # Concatenate Heads
        *batch_dims, num_heads, seq_len, d_v_head = attn_output.shape
        combined = attn_output.transpose(-3, -2).reshape(*batch_dims, seq_len, num_heads * d_v_head)

        return combined @ o_proj_weight.T