In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Union

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the last dimension of the input by its root-mean-square and then
    applies a learned per-feature gain. Unlike LayerNorm, no mean is
    subtracted and no bias is added.

    Args:
        dim: Size of the last (feature) dimension; length of the gain vector.
        eps: Constant added to the mean square before the square root, for
            numerical stability.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learned per-feature gain, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x / torch.sqrt(mean_square + self.eps)
        return self.weight * normalized


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embeddings (RoPE).

    Precomputes cos/sin tables for positions ``0 .. max_seq_len - 1`` and
    rotates the last dimension of the input using the "rotate-half"
    formulation: element ``i`` of the first half and element ``i`` of the
    second half of the feature dimension form one rotated pair.

    Args:
        dim: Size of the rotated (last) dimension, typically the per-head size.
        max_seq_len: Number of positions covered by the precomputed tables.
        base: Base of the geometric frequency progression.
    """
    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # One inverse frequency per rotated pair: base^(-2i/dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache()

    def _build_cache(self):
        """Build cos/sin tables of shape [1, max_seq_len, 1, 2 * len(inv_freq)]."""
        seq_idx = torch.arange(self.max_seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq_idx, self.inv_freq)
        # Duplicate so both halves of the feature dim share the same angles.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Singleton axes 0 and 2 let the tables broadcast against inputs laid
        # out as [batch, seq, heads, dim].
        self.register_buffer("cos_cached", emb.cos()[None, :, None, :])
        self.register_buffer("sin_cached", emb.sin()[None, :, None, :])

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Rotate ``x`` by the angles of positions ``0 .. seq_len - 1``.

        Args:
            x: Tensor whose axis 1 is the sequence axis and whose last axis has
                size ``dim`` (e.g. [batch, seq_len, heads, dim]).
            seq_len: Number of positions to apply; should equal ``x.shape[1]``.

        Returns:
            Rotated tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_seq_len``.
        """
        if seq_len > self.max_seq_len:
            # Slicing the cache past its end would silently truncate it and
            # fail later with an opaque broadcasting error.
            raise ValueError(
                f"seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len}) "
                "of the precomputed rotary tables"
            )
        return self.apply_rotary_emb(x, self.cos_cached[:, :seq_len],
                                     self.sin_cached[:, :seq_len])

    @staticmethod
    def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply the rotate-half rotation given broadcastable cos/sin tables.

        With ``x = [x1, x2]`` split at the middle of the last axis, returns
        ``[x1 * cos - x2 * sin, x2 * cos + x1 * sin]``.
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]

        rotated = torch.cat(
            (-x2 * sin[..., :x2.shape[-1]] + x1 * cos[..., :x1.shape[-1]],
             x2 * cos[..., :x2.shape[-1]] + x1 * sin[..., :x1.shape[-1]]),
            dim=-1
        )
        return rotated


class MultiheadAttentionWithRoPE(nn.Module):
    """Pure Multi-head Attention mechanism with RoPE

    This module only handles the attention computation.
    Normalization, dropout, and residual connections are handled externally.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_seq_len: Longest sequence covered by the rotary tables.
        attn_dropout: Dropout probability applied to the attention weights.
        rope_base: Base of the rotary frequency progression.
    """
    def __init__(
        self, 
        d_model: int,
        n_heads: int,
        max_seq_len: int = 4096,
        attn_dropout: float = 0.1,
        rope_base: int = 10000
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Rotary embeddings, applied per head (hence head_dim)
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len, rope_base)

        # Attention dropout (applied to attention weights)
        self.attn_dropout = nn.Dropout(attn_dropout)

    def forward(
        self, 
        x: torch.Tensor,
        causal_mask: bool = False,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass - returns raw attention output without residual

        Args:
            x: Input tensor of shape [batch, seq_len, d_model]
            causal_mask: Whether to apply causal masking
            attention_mask: Optional mask broadcastable to the score tensor
                [batch, n_heads, seq_len, seq_len]; entries equal to 0 are
                excluded from attention.

        Returns:
            Attention output of shape [batch, seq_len, d_model]

        Query rows whose keys are all masked out receive all-zero attention
        weights instead of the NaNs a softmax over an all ``-inf`` row produces.
        """
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim]
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)

        # Apply rotary positional embeddings to Q and K
        q = self.rope(q, seq_len)
        k = self.rope(k, seq_len)

        # Transpose for attention computation: [batch, n_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores: [batch, n_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask if requested (strictly-upper triangle = future keys)
        if causal_mask:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))

        # Apply custom attention mask if provided. Only this mask can remove
        # every key of a row (the causal mask always keeps the diagonal), so
        # only here do we need to guard against all -inf rows.
        fully_masked = None
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
            fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
            scores = scores.masked_fill(fully_masked, 0.0)

        # Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            # A row with nothing to attend to contributes nothing.
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.attn_dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        # Apply output projection
        output = self.out_proj(attn_output)

        return output


class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network with a configurable activation.

    Supported activations:
        * ``"swiglu"``: ``w2(silu(w1(x)) * w3(x))`` (gated, three projections)
        * ``"gelu"``:   ``w2(gelu(w1(x)))``
        * ``"relu"``:   ``w2(relu(w1(x)))``

    Args:
        d_model: Input/output feature size.
        d_ff: Hidden size; defaults to ``4 * d_model``.
        activation: One of the names listed above.

    Raises:
        ValueError: If ``activation`` is not a supported name (previously any
            unknown string, e.g. a typo, silently fell back to ReLU).
    """
    _SUPPORTED_ACTIVATIONS = ("swiglu", "gelu", "relu")

    def __init__(
        self,
        d_model: int,
        d_ff: Optional[int] = None,
        activation: str = "swiglu"
    ):
        super().__init__()

        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {self._SUPPORTED_ACTIVATIONS}"
            )

        if d_ff is None:
            d_ff = 4 * d_model

        self.activation_type = activation

        # Up- and down-projections shared by every activation variant.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        if activation == "swiglu":
            # Extra "value" projection multiplied by the SiLU-gated branch.
            self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation_type == "swiglu":
            return self.w2(F.silu(self.w1(x)) * self.w3(x))
        if self.activation_type == "gelu":
            return self.w2(F.gelu(self.w1(x)))
        return self.w2(F.relu(self.w1(x)))
        
class CausalInductionSelfAttentionBlock(nn.Module):
    """Induced Set Attention Block (ISAB) with an optional causal mode.

    Learnable inducing points first attend to the normalised input
    ("compression"); every input position then attends to the compressed
    inducing points ("decompression"). The result is added to the input and
    followed by a pre-norm feed-forward sub-layer, so both residual
    connections are handled inside this block.

    Causal mode gives every position its own copy of the inducing points and
    lets the copy for position ``i`` attend only to positions ``0..i``.
    Position ``i`` then reads only from its own copy, so no information flows
    from later positions. This materialises
    ``seq_len * n_inducing_points * seq_len`` compression scores, i.e. it is
    quadratic in sequence length. Only the non-causal path has the usual
    linear-in-``seq_len`` ISAB cost.

    Args:
        d_model: Model (embedding) dimension.
        n_heads: Number of heads for both attention layers.
        n_inducing_points: Number of learnable inducing points.
        dropout: Dropout used inside both attention layers, after the FFN and
            on the attention residual branch.
    """

    def __init__(self, d_model: int, n_heads: int, n_inducing_points: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_inducing_points = n_inducing_points

        # Learnable inducing points
        self.inducing_points = nn.Parameter(torch.randn(n_inducing_points, d_model))

        # Attention layers
        self.compress_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.decompress_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        # Layer norms
        self.compress_norm = RMSNorm(d_model)
        self.decompress_norm = RMSNorm(d_model) 

        # Feedforward
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.ffn_norm = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def _create_causal_inducing_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Additive attention mask for the causal compression step.

        Row ``r`` belongs to inducing point ``r % n_inducing_points`` of
        position ``p = r // n_inducing_points`` and may attend to key
        positions ``0..p`` (a staircase / lower-triangular pattern).

        Returns:
            Float tensor of shape [seq_len * n_inducing_points, seq_len] with
            0 where attention is allowed and -inf elsewhere.
        """
        query_pos = torch.arange(seq_len, device=device).repeat_interleave(self.n_inducing_points)
        key_pos = torch.arange(seq_len, device=device)
        future = key_pos.unsqueeze(0) > query_pos.unsqueeze(1)
        mask = torch.zeros(seq_len * self.n_inducing_points, seq_len, device=device)
        return mask.masked_fill(future, float('-inf'))

    def forward(self, x: torch.Tensor, causal_mask: bool = True) -> torch.Tensor:
        """Compression, decompression, residual, then pre-norm FFN with residual.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            causal_mask: Whether to apply causal masking

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape
        n_ind = self.n_inducing_points

        # Step 1: compression; keys/values are the compress-normalised input
        x_norm = self.compress_norm(x)

        if causal_mask:
            # One copy of the inducing points per position, laid out
            # position-major: [batch, seq_len * n_ind, d_model]
            inducing = self.inducing_points.repeat(seq_len, 1)
            inducing = inducing.unsqueeze(0).expand(batch_size, -1, -1)

            # The copy for position i may only attend to positions 0..i
            compress_mask = self._create_causal_inducing_mask(seq_len, x.device)
            compressed, _ = self.compress_attn(
                query=inducing,
                key=x_norm,
                value=x_norm,
                attn_mask=compress_mask,
                need_weights=False
            )

            # Step 2: causal decompression, batched over all positions.
            # Folding positions into the batch axis means each position's
            # single query attends only to its own n_ind compressed points.
            per_position = compressed.reshape(batch_size * seq_len, n_ind, self.d_model)
            queries = self.decompress_norm(x).reshape(batch_size * seq_len, 1, self.d_model)
            output, _ = self.decompress_attn(
                query=queries,
                key=per_position,
                value=per_position,
                need_weights=False
            )
            output = output.reshape(batch_size, seq_len, self.d_model)

        else:
            # Non-causal path: standard ISAB
            inducing = self.inducing_points.unsqueeze(0).expand(batch_size, -1, -1)

            # Compress
            compressed, _ = self.compress_attn(
                query=inducing,
                key=x_norm,
                value=x_norm,
                need_weights=False
            )

            # Decompress
            x_norm = self.decompress_norm(x)
            output, _ = self.decompress_attn(
                query=x_norm,
                key=compressed,
                value=compressed,
                need_weights=False
            )

        # Add residual
        x = x + self.dropout(output)

        # FFN with residual
        x = x + self.ffn(self.ffn_norm(x))

        return x
    

class MemoryEfficientCausalISAB(nn.Module):
    """Memory-efficient Causal ISAB using chunked processing.

    In causal mode, each position ``i`` gets its own copy of the inducing
    points. That copy attends only to positions ``0..i``, and position ``i``
    then decompresses only from its own copy, so no position can see later
    positions. Positions are processed ``chunk_size`` at a time, so at most
    ``chunk_size * n_inducing_points * seq_len`` compression scores are
    materialised at once. Results do not depend on ``chunk_size``.

    The non-causal path is a standard ISAB with one shared set of inducing
    points.

    Args:
        d_model: Model (embedding) dimension.
        n_heads: Number of heads for both attention layers.
        n_inducing_points: Number of learnable inducing points.
        dropout: Dropout for the attention layers, the FFN and the residual.
        chunk_size: Number of positions processed per causal chunk.
    """

    def __init__(self, d_model: int, n_heads: int, n_inducing_points: int, 
                 dropout: float = 0.1, chunk_size: int = 64):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_inducing_points = n_inducing_points
        self.chunk_size = chunk_size

        # Single set of learnable inducing points; causal mode replicates it
        # per position at run time.
        self.inducing_points = nn.Parameter(torch.randn(n_inducing_points, d_model))

        # Attention layers
        self.compress_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.decompress_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )

        # Norms and FFN
        self.compress_norm = RMSNorm(d_model)
        self.decompress_norm = RMSNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.ffn_norm = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _causal_chunk(self, x_norm_prefix: torch.Tensor, x_chunk: torch.Tensor,
                      start_idx: int) -> torch.Tensor:
        """Causal compress/decompress for one chunk of positions.

        Args:
            x_norm_prefix: Compress-normalised input for positions
                ``0 .. end_idx - 1`` (the chunk's last position included),
                shape [batch, end_idx, d_model].
            x_chunk: Raw input for the chunk, [batch, chunk_len, d_model].
            start_idx: Absolute position of the first row of ``x_chunk``.

        Returns:
            Decompressed output for the chunk, [batch, chunk_len, d_model].
        """
        batch_size, chunk_len, _ = x_chunk.shape
        end_idx = x_norm_prefix.shape[1]
        n_ind = self.n_inducing_points
        device = x_chunk.device

        # One copy of the inducing points per chunk position, position-major:
        # [batch, chunk_len * n_ind, d_model]
        inducing = self.inducing_points.repeat(chunk_len, 1)
        inducing = inducing.unsqueeze(0).expand(batch_size, -1, -1)

        # The copy for absolute position p may attend to key positions 0..p only.
        query_pos = torch.arange(start_idx, start_idx + chunk_len, device=device).repeat_interleave(n_ind)
        key_pos = torch.arange(end_idx, device=device)
        compress_mask = torch.zeros(chunk_len * n_ind, end_idx, device=device, dtype=x_norm_prefix.dtype)
        compress_mask = compress_mask.masked_fill(key_pos.unsqueeze(0) > query_pos.unsqueeze(1), float('-inf'))

        compressed, _ = self.compress_attn(
            query=inducing,
            key=x_norm_prefix,
            value=x_norm_prefix,
            attn_mask=compress_mask,
            need_weights=False
        )

        # Fold positions into the batch axis so each position's query only sees
        # its own n_ind compressed points.
        per_position = compressed.reshape(batch_size * chunk_len, n_ind, self.d_model)
        queries = self.decompress_norm(x_chunk).reshape(batch_size * chunk_len, 1, self.d_model)
        decompressed, _ = self.decompress_attn(
            query=queries,
            key=per_position,
            value=per_position,
            need_weights=False
        )
        return decompressed.reshape(batch_size, chunk_len, self.d_model)

    def forward(self, x: torch.Tensor, causal_mask: bool = True) -> torch.Tensor:
        """Attention sub-layer (with residual) followed by pre-norm FFN (with residual).

        Args:
            x: Input tensor [batch, seq_len, d_model]
            causal_mask: Whether position i may only depend on positions 0..i

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape

        if causal_mask:
            # The norm acts per token, so normalising once and slicing
            # prefixes is equivalent to normalising each prefix separately.
            x_norm = self.compress_norm(x)
            output = torch.zeros_like(x)

            for start_idx in range(0, seq_len, self.chunk_size):
                end_idx = min(start_idx + self.chunk_size, seq_len)
                output[:, start_idx:end_idx] = self._causal_chunk(
                    x_norm[:, :end_idx], x[:, start_idx:end_idx], start_idx
                )

            # Residual connection
            x = x + self.dropout(output)
        else:
            # Non-causal: standard ISAB (already efficient)
            inducing = self.inducing_points.unsqueeze(0).expand(batch_size, -1, -1)
            x_norm = self.compress_norm(x)
            compressed, _ = self.compress_attn(
                query=inducing, key=x_norm, value=x_norm, need_weights=False
            )
            x_norm = self.decompress_norm(x)
            output, _ = self.decompress_attn(
                query=x_norm, key=compressed, value=compressed, need_weights=False
            )
            x = x + self.dropout(output)

        # FFN with residual
        x = x + self.ffn(self.ffn_norm(x))
        return x



class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: RoPE self-attention, then an FFN.

    Both sub-layers follow the same recipe, ``x + dropout(sublayer(norm(x)))``.
    All normalisation, dropout and residual handling lives in this block, so
    the attention and FFN modules stay free of it.

    Args:
        d_model: Model (embedding) dimension.
        n_heads: Number of attention heads.
        d_ff: FFN hidden size (the FFN's own default applies when ``None``).
        max_seq_len: Longest sequence covered by the rotary tables.
        dropout: Used both inside attention (on weights) and on each residual branch.
        rope_base: Base of the rotary frequency progression.
        activation: FFN activation name.
        eps: Epsilon for both RMSNorm layers.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: Optional[int] = None,
        max_seq_len: int = 4096,
        dropout: float = 0.1,
        rope_base: int = 10000,
        activation: str = "swiglu",
        eps: float = 1e-6
    ):
        super().__init__()

        self.attention = MultiheadAttentionWithRoPE(
            d_model=d_model,
            n_heads=n_heads,
            max_seq_len=max_seq_len,
            attn_dropout=dropout,
            rope_base=rope_base
        )
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, activation=activation)

        # One pre-norm per sub-layer
        self.attn_norm = RMSNorm(d_model, eps=eps)
        self.ffn_norm = RMSNorm(d_model, eps=eps)

        # Shared dropout applied to each sub-layer output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: bool = False,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the attention and FFN sub-layers, each as ``x + dropout(f(norm(x)))``.

        Args:
            x: Input tensor [batch, seq_len, d_model].
            causal_mask: Forwarded to the attention module.
            attention_mask: Forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.attention(
            self.attn_norm(x),
            causal_mask=causal_mask,
            attention_mask=attention_mask
        )
        hidden = x + self.dropout(attended)

        transformed = self.ffn(self.ffn_norm(hidden))
        return hidden + self.dropout(transformed)


class TransformerEncoder(nn.Module):
    """Complete Transformer Encoder: a stack of pre-norm blocks plus a final RMSNorm.

    Args:
        n_layers: Number of ``TransformerEncoderBlock`` layers.
        d_model: Model (embedding) dimension.
        n_heads: Number of attention heads per layer.
        d_ff: FFN hidden size (the FFN's own default applies when ``None``).
        max_seq_len: Longest sequence covered by the rotary tables.
        dropout: Dropout rate used by every layer.
        rope_base: Base of the rotary frequency progression.
        activation: FFN activation name.
        eps: Epsilon for all RMSNorm layers.
    """
    def __init__(
        self,
        n_layers: int,
        d_model: int,
        n_heads: int,
        d_ff: Optional[int] = None,
        max_seq_len: int = 4096,
        dropout: float = 0.1,
        rope_base: int = 10000,
        activation: str = "swiglu",
        eps: float = 1e-6
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                max_seq_len=max_seq_len,
                dropout=dropout,
                rope_base=rope_base,
                activation=activation,
                eps=eps
            )
            for _ in range(n_layers)
        ])

        # Final layer normalization
        self.final_norm = RMSNorm(d_model, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        return_hidden_states: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
        """Run every layer, then the final norm.

        Args:
            x: Input tensor [batch, seq_len, d_model].
            causal_mask: Forwarded to every layer.
            attention_mask: Forwarded to every layer.
            return_hidden_states: Also return each layer's output.

        Returns:
            The normalised output, or ``(output, hidden_states)`` when
            ``return_hidden_states`` is True. ``hidden_states`` holds one entry
            per layer, taken before the final norm and excluding the input.
        """
        hidden_states = [] if return_hidden_states else None

        for layer in self.layers:
            x = layer(x, causal_mask=causal_mask, attention_mask=attention_mask)
            if return_hidden_states:
                hidden_states.append(x)

        x = self.final_norm(x)

        if return_hidden_states:
            return x, hidden_states
        return x
    
class CausalInductionTransformerEncoderBlock(nn.Module):
    """Encoder block built around a CausalInductionSelfAttentionBlock.

    The inner block already contains its own residual attention and FFN
    sub-layers. Optionally a second pre-norm FFN sub-layer
    (``x + dropout(ffn(norm(x)))``) is stacked on top.

    Args:
        d_model: Model (embedding) dimension.
        n_heads: Number of attention heads.
        n_inducing_points: Number of inducing points in the inner block.
        d_ff: Hidden size of the optional external FFN.
        dropout: Dropout rate for the inner block and the external FFN residual.
        activation: Activation of the optional external FFN.
        eps: Epsilon of the external FFN's RMSNorm.
        use_external_ffn: Whether to add the extra FFN sub-layer.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_inducing_points: int,
        d_ff: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "swiglu",
        eps: float = 1e-6,
        use_external_ffn: bool = True
    ):
        super().__init__()

        self.cisab = CausalInductionSelfAttentionBlock(
            d_model=d_model,
            n_heads=n_heads,
            n_inducing_points=n_inducing_points,
            dropout=dropout
        )

        self.use_external_ffn = use_external_ffn
        if not use_external_ffn:
            return

        # Extra pre-norm FFN sub-layer
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, activation=activation)
        self.ffn_norm = RMSNorm(d_model, eps=eps)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: bool = True,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the inner block, then the optional external FFN sub-layer.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            causal_mask: Forwarded to the inner block.
            attention_mask: Accepted for interface compatibility; ignored.

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        hidden = self.cisab(x, causal_mask=causal_mask)

        if not self.use_external_ffn:
            return hidden

        return hidden + self.dropout(self.ffn(self.ffn_norm(hidden)))


class CausalInductionTransformerEncoder(nn.Module):
    """Stack of CausalInductionTransformerEncoderBlocks followed by a final RMSNorm.

    Each layer may use its own number of inducing points.
    """
    def __init__(
        self,
        n_layers: int,
        d_model: int,
        n_heads: int,
        n_inducing_points: int,
        d_ff: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "swiglu",
        eps: float = 1e-6,
        use_external_ffn: bool = True,
        varying_inducing_points: Optional[list] = None
    ):
        """Build the layer stack.

        Args:
            n_layers: Number of encoder layers.
            d_model: Model (embedding) dimension.
            n_heads: Number of attention heads.
            n_inducing_points: Inducing points per layer, used when
                ``varying_inducing_points`` is not given.
            d_ff: Hidden size of each layer's external FFN.
            dropout: Dropout rate.
            activation: Activation of each layer's external FFN.
            eps: Epsilon for the RMSNorm layers.
            use_external_ffn: Whether each layer adds an extra FFN sub-layer.
            varying_inducing_points: Optional per-layer inducing-point counts;
                its length must equal ``n_layers``.
        """
        super().__init__()

        self.n_layers = n_layers
        self.d_model = d_model

        if varying_inducing_points is None:
            inducing_points_per_layer = [n_inducing_points] * n_layers
        else:
            assert len(varying_inducing_points) == n_layers, \
                "Length of varying_inducing_points must match n_layers"
            inducing_points_per_layer = varying_inducing_points

        self.layers = nn.ModuleList(
            CausalInductionTransformerEncoderBlock(
                d_model=d_model,
                n_heads=n_heads,
                n_inducing_points=inducing_points_per_layer[layer_idx],
                d_ff=d_ff,
                dropout=dropout,
                activation=activation,
                eps=eps,
                use_external_ffn=use_external_ffn
            )
            for layer_idx in range(n_layers)
        )

        self.final_norm = RMSNorm(d_model, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: bool = True,
        attention_mask: Optional[torch.Tensor] = None,
        return_hidden_states: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
        """Run every layer, then the final norm.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            causal_mask: Forwarded to every layer.
            attention_mask: Forwarded to every layer (the layers ignore it).
            return_hidden_states: Also return each layer's (pre-final-norm) output.

        Returns:
            Output tensor [batch, seq_len, d_model], or
            ``(output, hidden_states)`` when ``return_hidden_states`` is True.
        """
        collected = [] if return_hidden_states else None

        for layer in self.layers:
            x = layer(x, causal_mask=causal_mask, attention_mask=attention_mask)
            if collected is not None:
                collected.append(x)

        normalized = self.final_norm(x)
        return (normalized, collected) if return_hidden_states else normalized


class HybridTransformerEncoder(nn.Module):
    """Encoder that mixes standard RoPE attention blocks with CISAB blocks.

    The intent is to use standard attention for some layers (e.g. early ones)
    and induced-set attention for others. Layer selection is controlled by
    ``cisab_layers`` or ``cisab_start_layer``.
    """
    def __init__(
        self,
        n_layers: int,
        d_model: int,
        n_heads: int,
        n_inducing_points: int,
        d_ff: Optional[int] = None,
        max_seq_len: int = 4096,
        dropout: float = 0.1,
        rope_base: int = 10000,
        activation: str = "swiglu",
        eps: float = 1e-6,
        cisab_layers: Optional[list] = None,
        cisab_start_layer: Optional[int] = None
    ):
        """Build the mixed layer stack.

        Args:
            cisab_layers: Indices of layers that use CISAB (e.g. [2, 3, 5]).
                Takes precedence over ``cisab_start_layer``.
            cisab_start_layer: Use CISAB for this layer index and every later one.
            If neither is given, the second half of the layers
            (index >= n_layers // 2) uses CISAB.
        """
        super().__init__()

        self.n_layers = n_layers
        self.d_model = d_model

        # Per-layer flag: True -> CISAB block, False -> standard attention block
        if cisab_layers is not None:
            use_cisab = [layer_idx in cisab_layers for layer_idx in range(n_layers)]
        else:
            first_cisab = cisab_start_layer if cisab_start_layer is not None else n_layers // 2
            use_cisab = [layer_idx >= first_cisab for layer_idx in range(n_layers)]

        self.layers = nn.ModuleList()
        for is_cisab in use_cisab:
            if is_cisab:
                block = CausalInductionTransformerEncoderBlock(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_inducing_points=n_inducing_points,
                    d_ff=d_ff,
                    dropout=dropout,
                    activation=activation,
                    eps=eps,
                    use_external_ffn=True
                )
            else:
                block = TransformerEncoderBlock(
                    d_model=d_model,
                    n_heads=n_heads,
                    d_ff=d_ff,
                    max_seq_len=max_seq_len,
                    dropout=dropout,
                    rope_base=rope_base,
                    activation=activation,
                    eps=eps
                )
            self.layers.append(block)

        self.final_norm = RMSNorm(d_model, eps=eps)

        # Kept so callers can inspect which layers use CISAB
        self.use_cisab = use_cisab

    def forward(
        self,
        x: torch.Tensor,
        causal_mask: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        return_hidden_states: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
        """Run every layer, then the final norm.

        Returns the normalised output, or ``(output, hidden_states)`` when
        ``return_hidden_states`` is True (one pre-final-norm entry per layer).
        """
        collected = [] if return_hidden_states else None

        for layer in self.layers:
            x = layer(x, causal_mask=causal_mask, attention_mask=attention_mask)
            if collected is not None:
                collected.append(x)

        normalized = self.final_norm(x)
        return (normalized, collected) if return_hidden_states else normalized


In [2]:
import pytest
import torch
import torch.nn as nn
import numpy as np
from typing import Optional
import warnings

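# The classes under test are defined in the previous notebook cell and are
# already in scope here. If they are moved into a standalone module, import
# them from there instead, e.g. (avoid naming the module "transformers", which
# would shadow the Hugging Face package of the same name):
# from transformer_blocks import (
#     RMSNorm, RotaryPositionalEmbedding, MultiheadAttentionWithRoPE,
#     FeedForwardNetwork, CausalInductionSelfAttentionBlock,
#     TransformerEncoderBlock, TransformerEncoder,
#     CausalInductionTransformerEncoderBlock, CausalInductionTransformerEncoder,
#     HybridTransformerEncoder
# )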

class TestRMSNorm:
    """Unit tests for RMSNorm."""

    def test_output_shape(self):
        """RMSNorm must not change the shape of its input."""
        inputs = torch.randn(2, 10, 128)
        assert RMSNorm(dim=128)(inputs).shape == inputs.shape

    def test_normalization_effect(self):
        """With unit gain, every output vector should have RMS close to 1."""
        layer = RMSNorm(dim=64)
        scaled_inputs = torch.randn(4, 8, 64) * 10  # large-variance input
        normed = layer(scaled_inputs)

        rms = normed.pow(2).mean(dim=-1).sqrt()
        assert torch.allclose(rms, torch.ones_like(rms), atol=0.1)

    def test_gradient_flow(self):
        """Backprop through RMSNorm must yield finite input gradients."""
        layer = RMSNorm(dim=32)
        inputs = torch.randn(2, 5, 32, requires_grad=True)
        layer(inputs).sum().backward()

        assert inputs.grad is not None
        assert not torch.isnan(inputs.grad).any()


class TestRotaryPositionalEmbedding:
    """Unit tests for RotaryPositionalEmbedding."""

    def test_output_shape(self):
        """RoPE is a rotation, so the output shape must equal the input shape."""
        embedding = RotaryPositionalEmbedding(dim=64, max_seq_len=100)
        inputs = torch.randn(2, 10, 8, 64)  # [batch, seq, heads, dim]
        assert embedding(inputs, seq_len=10).shape == inputs.shape

    def test_position_dependency(self):
        """Identical vectors at different positions must be rotated differently."""
        embedding = RotaryPositionalEmbedding(dim=64, max_seq_len=100)
        rotated = embedding(torch.ones(1, 3, 1, 64), seq_len=3)

        for earlier, later in ((0, 1), (1, 2)):
            assert not torch.allclose(rotated[0, earlier], rotated[0, later])

    def test_deterministic(self):
        """Applying RoPE twice to the same input must give the same result."""
        embedding = RotaryPositionalEmbedding(dim=64, max_seq_len=100)
        inputs = torch.randn(2, 10, 4, 64)
        first, second = (embedding(inputs, seq_len=10) for _ in range(2))
        assert torch.allclose(first, second)


class TestCausality:
    """Extensive tests for causal masking in attention mechanisms.

    The perturbation tests share one pattern: compute a reference output,
    change every position after ``pos``, and require outputs at positions
    ``0..pos`` to stay within a small tolerance of the reference.
    """

    def test_standard_attention_causality(self):
        """Test that causal mask prevents attending to future positions"""
        batch_size = 2
        seq_len = 10
        d_model = 64

        attn = MultiheadAttentionWithRoPE(
            d_model=d_model,
            n_heads=4,
            max_seq_len=100
        )
        attn.eval()  # Disable dropout for deterministic testing

        # Random input, so each position carries a distinct vector
        x = torch.randn(batch_size, seq_len, d_model)

        # Reference output with causal mask
        with torch.no_grad():
            output_causal = attn(x, causal_mask=True)

        # Perturb future positions and check if past positions remain unchanged
        for pos in range(seq_len - 1):
            x_perturbed = x.clone()
            # Significantly perturb all positions after 'pos'
            x_perturbed[:, pos+1:, :] += 100.0

            with torch.no_grad():
                output_perturbed = attn(x_perturbed, causal_mask=True)

            # Output at positions up to 'pos' should remain unchanged
            assert torch.allclose(
                output_causal[:, :pos+1, :],
                output_perturbed[:, :pos+1, :],
                atol=1e-5
            ), f"Causality violated at position {pos}"

    def test_attention_mask_shape(self):
        """Test that attention weights have correct causal structure.

        NOTE(review): this test re-implements ``forward`` in a subclass and
        applies its own ``triu`` mask. It therefore verifies this local
        re-implementation, not the masking inside
        ``MultiheadAttentionWithRoPE.forward``. Consider exposing the weights
        from the real module instead.
        """
        d_model = 32
        n_heads = 2
        seq_len = 8

        # Create a custom attention module that exposes attention weights
        class AttentionWithWeights(MultiheadAttentionWithRoPE):
            def forward(self, x, causal_mask=False, attention_mask=None):
                batch_size, seq_len, _ = x.shape

                # Compute Q, K, V
                q = self.q_proj(x)
                k = self.k_proj(x)
                v = self.v_proj(x)

                # Reshape for multi-head: [batch, seq, heads, head_dim]
                q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
                k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
                v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)

                # Apply RoPE
                q = self.rope(q, seq_len)
                k = self.rope(k, seq_len)

                # Transpose to [batch, heads, seq, head_dim]
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
                v = v.transpose(1, 2)

                # Compute scores: [batch, heads, seq, seq]
                scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

                if causal_mask:
                    mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
                    scores = scores.masked_fill(mask, float('-inf'))

                # Return only the post-softmax attention weights for inspection
                # (attention_mask is ignored by this test-only forward).
                attn_weights = torch.softmax(scores, dim=-1)
                return attn_weights

        attn = AttentionWithWeights(d_model=d_model, n_heads=n_heads)
        x = torch.randn(1, seq_len, d_model)

        with torch.no_grad():
            attn_weights = attn(x, causal_mask=True)

        # Strictly-future entries (j > i). The [seq, seq] mask broadcasts over
        # the leading batch and head dimensions of attn_weights.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_weights.device),
            diagonal=1,
        )
        # Use `~(w < tol)` rather than `w >= tol` so that NaN weights are also
        # reported as violations.
        leaks = torch.nonzero(~(attn_weights < 1e-6) & future)
        if leaks.numel() > 0:
            _, head, i, j = leaks[0].tolist()
            raise AssertionError(
                f"Non-zero attention from position {i} to future position {j} (head {head})"
            )

    def test_transformer_encoder_causality(self):
        """Test causality in full transformer encoder"""
        batch_size = 2
        seq_len = 12
        d_model = 64

        encoder = TransformerEncoder(
            n_layers=2,
            d_model=d_model,
            n_heads=4,
            dropout=0.0  # No dropout for deterministic testing
        )
        encoder.eval()

        x = torch.randn(batch_size, seq_len, d_model)

        with torch.no_grad():
            output_causal = encoder(x, causal_mask=True)

        # Test that modifying future doesn't affect past
        for pos in range(seq_len - 1):
            x_modified = x.clone()
            x_modified[:, pos+1:, :] = torch.randn_like(x_modified[:, pos+1:, :]) * 10

            with torch.no_grad():
                output_modified = encoder(x_modified, causal_mask=True)

            assert torch.allclose(
                output_causal[:, :pos+1, :],
                output_modified[:, :pos+1, :],
                atol=1e-5
            ), f"Encoder causality violated at position {pos}"


class TestCISABCausality:
    """Extensive tests for Causal Induction Self-Attention Block causality"""

    def test_cisab_basic_causality(self):
        """Test that CISAB maintains causality"""
        batch_size = 2
        seq_len = 16
        d_model = 64

        cisab = CausalInductionSelfAttentionBlock(
            d_model=d_model,
            n_heads=4,
            n_inducing_points=8,
            dropout=0.0
        )
        cisab.eval()

        x = torch.randn(batch_size, seq_len, d_model)

        with torch.no_grad():
            output_causal = cisab(x, causal_mask=True)

        # Modify future positions
        for pos in range(seq_len - 1):
            x_modified = x.clone()
            # Add large perturbation to future positions
            x_modified[:, pos+1:, :] += 50.0

            with torch.no_grad():
                output_modified = cisab(x_modified, causal_mask=True)

            # Check past positions remain unchanged
            assert torch.allclose(
                output_causal[:, :pos+1, :],
                output_modified[:, :pos+1, :],
                atol=1e-4
            ), f"CISAB causality violated at position {pos}"

    def test_cisab_inducing_mask_structure(self):
        """Check the mask returned by ``_create_causal_inducing_mask``.

        Rows are expected in groups of ``n_ind``, one group per sequence
        position. Rows ``pos*n_ind .. (pos+1)*n_ind - 1`` must be 0 for
        columns ``0..pos`` (visible) and ``-inf`` for later columns (hidden).
        """
        # Single source of truth for the inducing-point count. It configures
        # the block AND drives the expected row grouping below, so the two
        # values cannot drift apart.
        n_ind = 4
        cisab = CausalInductionSelfAttentionBlock(
            d_model=32,
            n_heads=2,
            n_inducing_points=n_ind,
            dropout=0.0
        )

        seq_len = 8
        device = torch.device('cpu')

        # Get the mask
        mask = cisab._create_causal_inducing_mask(seq_len, device)

        # Check mask structure
        for pos in range(seq_len):
            start_idx = pos * n_ind
            end_idx = (pos + 1) * n_ind

            # Inducing points at position 'pos' should only see positions 0...pos
            for ind_idx in range(start_idx, end_idx):
                # Should be able to attend to positions 0...pos (mask = 0)
                for j in range(pos + 1):
                    assert mask[ind_idx, j] == 0, \
                        f"Inducing point {ind_idx} at position {pos} cannot see position {j}"

                # Should not attend to positions after pos (mask = -inf)
                for j in range(pos + 1, seq_len):
                    assert mask[ind_idx, j] == float('-inf'), \
                        f"Inducing point {ind_idx} at position {pos} can see future position {j}"

    def test_cisab_vs_non_causal(self):
        """Test that causal and non-causal CISAB produce different results"""
        cisab = CausalInductionSelfAttentionBlock(
            d_model=64,
            n_heads=4,
            n_inducing_points=8,
            dropout=0.0
        )
        cisab.eval()

        x = torch.randn(2, 10, 64)

        with torch.no_grad():
            output_causal = cisab(x, causal_mask=True)
            output_non_causal = cisab(x, causal_mask=False)

        # Outputs should be different
        assert not torch.allclose(output_causal, output_non_causal, atol=1e-3)


class TestCausalInductionTransformer:
    """Tests for the stacked CausalInductionTransformerEncoder."""

    def test_encoder_output_shape(self):
        """A 3-layer CISAB encoder maps [batch, seq, d_model] to the same shape."""
        encoder = CausalInductionTransformerEncoder(
            n_layers=3, d_model=64, n_heads=4, n_inducing_points=8
        )
        inputs = torch.randn(2, 20, 64)
        assert encoder(inputs).shape == inputs.shape

    def test_encoder_causality_preservation(self):
        """Replacing tokens after `pos` must leave outputs at 0..pos unchanged through 4 layers."""
        d_model, seq_len = 128, 24
        encoder = CausalInductionTransformerEncoder(
            n_layers=4, d_model=d_model, n_heads=8, n_inducing_points=16, dropout=0.0
        )
        encoder.eval()

        inputs = torch.randn(2, seq_len, d_model)
        with torch.no_grad():
            reference = encoder(inputs, causal_mask=True)

        # Only every 3rd cut point is checked, to keep the test fast.
        for pos in range(0, seq_len - 1, 3):
            cut = pos + 1
            altered = inputs.clone()
            altered[:, cut:, :] = torch.randn_like(altered[:, cut:, :]) * 5

            with torch.no_grad():
                candidate = encoder(altered, causal_mask=True)

            assert torch.allclose(
                reference[:, :cut, :], candidate[:, :cut, :], atol=1e-4
            ), f"Multi-layer CISAB causality violated at position {pos}"

    def test_varying_inducing_points(self):
        """An encoder given per-layer inducing-point counts still preserves shape."""
        encoder = CausalInductionTransformerEncoder(
            n_layers=4,
            d_model=64,
            n_heads=4,
            n_inducing_points=8,  # default; presumably superseded by the per-layer list
            varying_inducing_points=[4, 8, 12, 16],  # increasing per layer
        )
        inputs = torch.randn(2, 15, 64)
        assert encoder(inputs).shape == inputs.shape


class TestHybridTransformer:
    """Test Hybrid Transformer mixing standard and CISAB blocks"""

    def test_hybrid_construction(self):
        """cisab_layers=[2, 3, 5] should flag exactly those layer indices as CISAB."""
        encoder = HybridTransformerEncoder(
            n_layers=6,
            d_model=64,
            n_heads=4,
            n_inducing_points=8,
            cisab_layers=[2, 3, 5]  # Specific layers use CISAB
        )

        # Check that correct layers are CISAB
        assert encoder.use_cisab == [False, False, True, True, False, True]

    def test_hybrid_causality(self):
        """Test causality in hybrid transformer"""
        batch_size = 2
        seq_len = 16
        d_model = 64

        encoder = HybridTransformerEncoder(
            n_layers=4,
            d_model=d_model,
            n_heads=4,
            n_inducing_points=8,
            cisab_start_layer=2,  # Last 2 layers use CISAB
            dropout=0.0
        )
        encoder.eval()

        x = torch.randn(batch_size, seq_len, d_model)

        with torch.no_grad():
            output_original = encoder(x, causal_mask=True)

        # Test causality preservation at every 2nd cut point. The bound is
        # derived from seq_len (previously a hard-coded 15), so it tracks the
        # sequence length if that changes.
        for pos in range(0, seq_len - 1, 2):
            x_modified = x.clone()
            x_modified[:, pos+1:, :] = torch.randn_like(x_modified[:, pos+1:, :]) * 10

            with torch.no_grad():
                output_modified = encoder(x_modified, causal_mask=True)

            assert torch.allclose(
                output_original[:, :pos+1, :],
                output_modified[:, :pos+1, :],
                atol=1e-4
            ), f"Hybrid transformer causality violated at position {pos}"


class TestEdgeCases:
    """Test edge cases and special scenarios.

    Shape-only tests run their forward pass under ``torch.no_grad()``. They
    never backpropagate, so building an autograd graph would only inflate
    peak memory (notably for the 512-token CISAB run).
    """

    def test_single_token_sequence(self):
        """Test with sequence length of 1"""
        encoder = TransformerEncoder(
            n_layers=2,
            d_model=32,
            n_heads=4
        )

        x = torch.randn(2, 1, 32)
        with torch.no_grad():
            output = encoder(x, causal_mask=True)
        assert output.shape == x.shape

    def test_very_long_sequence(self):
        """Test with long sequences"""
        encoder = CausalInductionTransformerEncoder(
            n_layers=2,
            d_model=64,
            n_heads=4,
            n_inducing_points=32  # More inducing points for long sequence
        )

        x = torch.randn(1, 512, 64)
        with torch.no_grad():
            output = encoder(x, causal_mask=True)
        assert output.shape == x.shape

    def test_gradient_flow_through_cisab(self):
        """Test gradient flow through CISAB"""
        cisab = CausalInductionSelfAttentionBlock(
            d_model=32,
            n_heads=2,
            n_inducing_points=4
        )

        # Gradients are the point of this test, so autograd stays enabled.
        x = torch.randn(2, 8, 32, requires_grad=True)
        output = cisab(x, causal_mask=True)
        loss = output.sum()
        loss.backward()

        assert x.grad is not None
        assert not torch.isnan(x.grad).any()
        assert not torch.isinf(x.grad).any()

    def test_attention_mask_compatibility(self):
        """Test that custom attention masks work"""
        encoder = TransformerEncoder(
            n_layers=2,
            d_model=64,
            n_heads=4
        )

        batch_size = 2
        seq_len = 10
        x = torch.randn(batch_size, seq_len, 64)

        # Custom [seq, seq] attention mask (e.g., for padding). Columns 7-9 are
        # zeroed; presumably 0 means "masked" — confirm against
        # TransformerEncoder's mask convention.
        attention_mask = torch.ones(seq_len, seq_len)
        attention_mask[:, 7:] = 0

        with torch.no_grad():
            output = encoder(x, attention_mask=attention_mask)
        assert output.shape == x.shape


class TestNumericalStability:
    """Numerical-robustness checks: outputs must contain neither NaN nor Inf."""

    def test_rmsnorm_stability(self):
        """RMSNorm on inputs scaled by 1e-8 and by 1e8 stays free of NaN/Inf."""
        norm = RMSNorm(dim=64)

        # Very small magnitudes first, then very large ones
        for scale in (1e-8, 1e8):
            result = norm(torch.randn(2, 5, 64) * scale)
            assert not torch.isnan(result).any()
            assert not torch.isinf(result).any()

    def test_attention_numerical_stability(self):
        """Causal attention on high-variance input (std ~100) stays free of NaN/Inf."""
        attn = MultiheadAttentionWithRoPE(d_model=32, n_heads=2)

        noisy = 100 * torch.randn(2, 8, 32)
        result = attn(noisy, causal_mask=True)
        for check in (torch.isnan, torch.isinf):
            assert not check(result).any()


def test_comprehensive_causality_suite():
    """Run the causality checks for the standard, CISAB, induction and hybrid models.

    Each stage prints a header, creates one instance of its test class, runs
    the listed test methods in order, and prints a success line. A failing
    assertion propagates and stops the remaining stages.
    """
    stages = (
        ("Testing Standard Transformer Causality...",
         TestCausality,
         ("test_standard_attention_causality", "test_transformer_encoder_causality"),
         "✓ Standard Transformer causality tests passed"),
        ("\nTesting CISAB Causality...",
         TestCISABCausality,
         ("test_cisab_basic_causality", "test_cisab_inducing_mask_structure"),
         "✓ CISAB causality tests passed"),
        ("\nTesting Causal Induction Transformer...",
         TestCausalInductionTransformer,
         ("test_encoder_causality_preservation",),
         "✓ Causal Induction Transformer causality tests passed"),
        ("\nTesting Hybrid Transformer...",
         TestHybridTransformer,
         ("test_hybrid_causality",),
         "✓ Hybrid Transformer causality tests passed"),
    )

    for header, suite_cls, method_names, success_msg in stages:
        print(header)
        suite = suite_cls()
        for method_name in method_names:
            getattr(suite, method_name)()
        print(success_msg)

    print("\n✅ All causality tests passed successfully!")


if __name__ == "__main__":
    # Seed torch's global RNG so the random test inputs (and any random weight
    # initialisation) are identical on every run / Restart-and-Run-All.
    torch.manual_seed(0)

    # Causality suite across all architectures. The calls below are a subset
    # of the defined tests; run the file under pytest for full coverage.
    test_comprehensive_causality_suite()

    # Additional basic functionality tests
    print("\nRunning additional functionality tests...")

    # Test RMSNorm
    test_norm = TestRMSNorm()
    test_norm.test_output_shape()
    test_norm.test_normalization_effect()
    test_norm.test_gradient_flow()

    # Test RoPE
    test_rope = TestRotaryPositionalEmbedding()
    test_rope.test_output_shape()
    test_rope.test_position_dependency()
    test_rope.test_deterministic()

    # RMSNorm with extreme input magnitudes
    test_stability = TestNumericalStability()
    test_stability.test_rmsnorm_stability()

    # Test edge cases
    test_edge = TestEdgeCases()
    test_edge.test_single_token_sequence()
    test_edge.test_gradient_flow_through_cisab()

    print("✅ All tests completed successfully!")


Testing Standard Transformer Causality...
✓ Standard Transformer causality tests passed

Testing CISAB Causality...
✓ CISAB causality tests passed

Testing Causal Induction Transformer...
✓ Causal Induction Transformer causality tests passed

Testing Hybrid Transformer...
✓ Hybrid Transformer causality tests passed

✅ All causality tests passed successfully!

Running additional functionality tests...
✅ All tests completed successfully!


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Forward pass of a full-size 8-layer CISAB encoder on random data.
# NOTE(review): the mask layout checked in test_cisab_inducing_mask_structure
# gives every position its own group of n_inducing_points rows
# (512 * 64 = 32768 rows here), so intermediate attention tensors may be very
# large. Confirm this fits in device memory. The model is left in training
# mode (dropout active) with autograd enabled, as in the original cell.
cisab_model = CausalInductionTransformerEncoder(
    n_layers=8,
    d_model=512,
    n_heads=8,
    n_inducing_points=64,
    dropout=0.1
).to(device)
# Allocate the input directly on the target device rather than creating it on
# the CPU and copying it across with .to(device).
X = torch.randn(32, 512, 512, device=device)  # [batch, seq_len, d_model]
Z = cisab_model(X, causal_mask=True)  # [batch, seq_len, d_model]
assert Z.shape == X.shape, f"unexpected output shape {tuple(Z.shape)}"