# (Detour) Simplified attention block

Paper: [Simplifying Transformer Blocks -- He & Hofmann](https://arxiv.org/abs/2311.01906). The paper shows that skip connections can be removed from the transformer block, which increases training efficiency.

## Imports

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional

## Attention definition

In [None]:
class SimplifiedMQA(nn.Module):
    """Simplified Multi-Query Attention implementation based on "Simplifying Transformer Blocks" (He & Hofmann, 2024).

    Queries are projected per head, while a single key projection (and, in the
    first layer, a single value projection) is shared by all heads. The
    attention matrix applied to the values is the "shaped attention"
    ``alpha * I + beta * softmax(QK^T / sqrt(head_dim)) - gamma * C`` with
    learnable per-head scalars ``alpha``, ``beta`` and ``gamma``. There is no
    output projection.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias: bool = False,
        split_into_chunks: bool = True,
        zero_init_query: bool = False,
        is_casual: bool = True,
        do_dropout: bool = False,
    ):
        """Initializes the Simplified Multi-Query Attention module.

        Args:
            d_in (int): Input dimension
            d_out (int): Output dimension
            context_length (int): Maximum sequence length
            dropout (float): Dropout probability
            num_heads (int): Number of attention heads
            qkv_bias (bool): Whether to include bias in query/key/value projections
            split_into_chunks (bool): If True, splits input directly into heads for values (used in all layers except first).
                                    If False, projects input to head_dim using W_v (used in first layer only).
            zero_init_query (bool): Whether to initialize query weights to zero (as in original paper)
            is_casual (bool): Whether to use causal masking (for autoregressive models).
                              The parameter name keeps its historical spelling for backward compatibility.
            do_dropout (bool): Whether to apply dropout to attention weights (recommended for fine-tuning)

        Raises:
            AssertionError: If ``d_out`` is not divisible by ``num_heads``.
            ValueError: If ``split_into_chunks`` is True but ``d_in != d_out``
                (the input cannot then be split into ``num_heads`` chunks of ``head_dim``).
        """
        super().__init__()
        assert (d_out % num_heads == 0), \
        "d_out must be divisible by num_heads"
        if split_into_chunks and d_in != d_out:
            # forward() reshapes x itself into (num_heads, head_dim) chunks, which
            # is only possible when the input width equals d_out.
            raise ValueError(
                "split_into_chunks=True requires d_in == d_out "
                f"(got d_in={d_in}, d_out={d_out})"
            )

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # For the first layer, project input from d_in to head_dim; otherwise split it into h chunks
        self.split_into_chunks = split_into_chunks
        self.is_casual = is_casual
        self.do_dropout = do_dropout

        # Query projects to full d_out dimension (will be split into heads)
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Key and Value project only to head_dim (shared across heads)
        self.W_k = nn.Linear(d_in, self.head_dim, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, self.head_dim, bias=qkv_bias) if not self.split_into_chunks else None

        # Initialize W_q to zero if necessary (as mentioned in original paper)
        if zero_init_query:
            nn.init.zeros_(self.W_q.weight)
            if self.W_q.bias is not None:
                nn.init.zeros_(self.W_q.bias)

        # Learnable shaped-attention scalars, one per head
        self.alpha = nn.Parameter(torch.ones(num_heads))  # shape: (num_heads,)
        self.beta = nn.Parameter(torch.ones(num_heads))   # shape: (num_heads,)
        self.gamma = nn.Parameter(torch.ones(num_heads))  # shape: (num_heads,)

        # No outwards projection as the outwards projection layer often approximates the identity matrix (according to the paper)

        self.dropout = nn.Dropout(dropout)
        # Strict upper triangle == positions a token must NOT attend to (its future)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies shaped multi-query attention to ``x``.

        Args:
            x (torch.Tensor): Input of shape (batch, num_tokens, d_in).

        Returns:
            torch.Tensor: Output of shape (batch, num_tokens, d_out).
        """
        batch, num_tokens, _ = x.shape

        # Mask for future tokens (future tokens must not influence past tokens)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Project queries to full dimension and split into heads
        queries = self.W_q(x).view(batch, num_tokens, self.num_heads, self.head_dim)
        queries = queries.transpose(1, 2)  # (batch, num_heads, num_tokens, head_dim)

        # Keys are a single head shared by every query head
        keys = self.W_k(x)  # shape: (batch, num_tokens, head_dim)

        if self.split_into_chunks:
            # Values are the raw input split into heads; reshape (not view) so
            # non-contiguous inputs are accepted as well.
            values = x.reshape(batch, num_tokens, self.num_heads, self.head_dim)
            values = values.transpose(1, 2)  # (batch, num_heads, num_tokens, head_dim)
        else:
            # Single projected value head, broadcast across heads
            values = self.W_v(x).unsqueeze(1)  # (batch, 1, num_tokens, head_dim)

        # Reshape alpha, beta, gamma for broadcasting over batch and sequence dims
        alpha = self.alpha.view(1, self.num_heads, 1, 1)  # shape: (1, num_heads, 1, 1)
        beta = self.beta.view(1, self.num_heads, 1, 1)    # shape: (1, num_heads, 1, 1)
        gamma = self.gamma.view(1, self.num_heads, 1, 1)  # shape: (1, num_heads, 1, 1)

        # Identity component; built in x's dtype so half/double modules do not
        # end up with a float32 attention matrix that mismatches the values.
        identity = torch.eye(num_tokens, device=x.device, dtype=x.dtype)[None, None, :, :]

        # Centering matrix (uniform attention)
        C = torch.ones(num_tokens, num_tokens, device=x.device, dtype=x.dtype) / num_tokens
        C = C[None, None, :, :]
        if self.is_casual:
            # masked_fill is out-of-place: the result must be kept, otherwise the
            # centering term averages over future tokens too.
            C = C.masked_fill(mask_bool, 0)
            C = C / (C.sum(dim=-1, keepdim=True) + 1e-8)  # Rows renormalized to sum to 1 after causal masking

        # Compute attention scores; keys are broadcast across the num_heads dimension
        attn_scores = queries @ keys.unsqueeze(1).transpose(-2, -1)
        if self.is_casual:
            attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # Attention weights
        attn_weights = torch.softmax(
            attn_scores / self.head_dim**0.5, dim=-1
        )

        # Dropout attention weights if the user opted in (as in fine-tuning in the original paper)
        if self.do_dropout:
            attn_weights = self.dropout(attn_weights)

        # Shaped attention
        shaped_attn = (alpha * identity + beta * attn_weights - gamma * C)

        # (batch, num_heads, num_tokens, head_dim) -> (batch, num_tokens, num_heads, head_dim)
        context_vec = (shaped_attn @ values).transpose(1, 2)

        # Merge heads back into the model dimension
        context_vec = context_vec.contiguous().view(batch, num_tokens, self.d_out)

        return context_vec

## Transformer Block

In [None]:
class SimplifiedTransformerBlock(nn.Module):
    """Simplified Transformer Block based on "Simplifying Transformer Blocks" (He & Hofmann, 2024).

    The attention path and the MLP path both read the same input and their
    outputs are summed (the MLP output scaled by a learnable ``beta_ff``);
    the input itself is not added back.
    """
    def __init__(self, 
                 d_model: int,
                 num_heads: int,
                 mlp_dim: int,  # Usually 4*d_model
                 context_length: int,
                 dropout: float = 0.1,
                 layer_idx: int = 0,
                 split_into_chunks: Optional[bool] = None,
                 beta_ff: float = 0.1,
                 activation=F.relu):
        """Initializes the Simplified Transformer Block.
    
        Args:
            d_model (int): Model dimension
            num_heads (int): Number of attention heads
            mlp_dim (int): Dimension of MLP hidden layer (usually 4*d_model)
            context_length (int): Maximum sequence length
            dropout (float): Dropout probability
            layer_idx (int): Index of this layer in the stack
            split_into_chunks (Optional[bool]): Whether to split values into chunks or use projection.
                                            If None, defaults to True for all layers except first.
                                            An explicit True/False is always honoured.
            beta_ff (float): Initial scale for MLP output (paper suggests 0.1 for depth 18)
            activation: Activation function to use in MLP
        """
        super().__init__()
        
        # Only fall back to the layer-index default when the caller gave no
        # explicit choice; a truthiness test would silently turn an explicit
        # False into True on every layer after the first.
        if split_into_chunks is None:
            split_into_chunks = layer_idx != 0
        
        # Attention path
        self.attention = SimplifiedMQA(
            d_in=d_model,
            d_out=d_model,
            context_length=context_length,
            dropout=dropout,
            num_heads=num_heads,
            split_into_chunks=split_into_chunks  # By default only the first layer uses value projection
        )
        
        # MLP path
        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.activation = activation
        
        # Learnable MLP output scale, initialized as mentioned in paper
        self.beta_ff = nn.Parameter(torch.tensor(beta_ff))
        
    def forward(self, x):
        """Runs the attention and MLP paths on ``x`` and returns their weighted sum.

        Args:
            x: Input tensor; passed unchanged to both the attention module and the MLP.

        Returns:
            ``attention(x) + beta_ff * mlp(x)``.
        """
        # Parallel processing of attention and MLP paths
        attn_out = self.attention(x)
        mlp_out = self.mlp_out(self.activation(self.mlp_in(x)))
        
        # Combine paths (MLP output scaled by the learnable beta_ff)
        return attn_out + self.beta_ff * mlp_out

class SimplifiedTransformer(nn.Module):
    """Stack of ``SimplifiedTransformerBlock`` layers applied one after another.

    Args:
        num_layers: Number of blocks to create.
        **block_args: Keyword arguments forwarded unchanged to every block;
            ``layer_idx`` is supplied automatically from the block's position.
    """

    def __init__(self, num_layers, **block_args):
        super().__init__()
        stack = nn.ModuleList()
        for depth in range(num_layers):
            # Each block learns its own position so it can pick layer-specific defaults
            stack.append(SimplifiedTransformerBlock(layer_idx=depth, **block_args))
        self.layers = stack

    def forward(self, x):
        """Feeds ``x`` through every block in order and returns the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden