In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

"""
MMFormer_LiFS_Improved_v2

This file implements an improved MMFormer for Liver Fibrosis Staging (LiFS),
with missing-modality compensation via a Delta Function, following:
- Modality-specific encoder (Sec. 2.3)
- Intra-modality Transformer (Sec. 2.3, Eq. (4)-(8))
- Delta Function for missing-modality compensation (Sec. 2.4, Eq. (9)-(13))
- Modality-correlated encoder / cross-modality Transformer fusion (Sec. 2.5, Eq. (14)-(17))
- Classification head (Sec. 2.6)
"""

# =========================
# Global configuration
# =========================
# Module-level hyper-parameters read by the classes below.
# Paper-aligned typical settings (per the original notes):
#   C_base = 8, C_t = 256, h = 4 heads, L = 4 Transformer blocks, M = 3 modalities
basic_dims: int = 8                          # C_base; encoder stage s has basic_dims * 2**(s-1) channels
transformer_basic_dims: int = 256            # C_t; token width of every Transformer stack
mlp_dim: int = 1024                          # hidden width of each Transformer FFN
num_heads: int = 4                           # attention heads per Transformer block
depth: int = 4                               # blocks per Transformer stack
num_modals: int = 3                          # modality count (t1, t2, dwi)
num_classes: int = 2                         # default class count
fusion_dim: int = transformer_basic_dims     # width after fusion projection / classifier input

# =========================
# Normalization helper
# =========================
def normalization(planes, norm='bn'):
    """
    Build a normalization layer for 5D (B, C, D, H, W) feature maps.

    Args:
        planes: number of channels to normalize.
        norm: 'bn' -> BatchNorm3d (this function's default),
              'gn' -> GroupNorm with 4 groups (planes must be divisible by 4),
              'in' -> InstanceNorm3d with learnable affine parameters.
              The encoder's conv units request 'in' (Sec. 2.3).

    Raises:
        ValueError: if `norm` is not one of the options above.
    """
    factories = (
        ('bn', lambda: nn.BatchNorm3d(planes)),
        ('gn', lambda: nn.GroupNorm(4, planes)),
        ('in', lambda: nn.InstanceNorm3d(planes, affine=True)),
    )
    for key, build in factories:
        if norm == key:
            return build()
    raise ValueError(f'Unsupported norm: {norm}')

# =========================
# 3D Conv Unit used in the encoder
# =========================
class general_conv3d_prenorm(nn.Module):
    """
    3D conv unit: Conv3d -> normalization -> activation -> optional Dropout3d.

    Despite the "prenorm" name, normalization is applied to the convolution's
    output (post-conv), matching the residual conv unit of Sec. 2.3
    (IN + LeakyReLU + dropout with the defaults).

    Args:
        in_ch, out_ch: input / output channels.
        k_size, stride, padding, pad_type: forwarded to nn.Conv3d
            (pad_type becomes `padding_mode`).
        norm: option string for `normalization` ('bn' | 'gn' | 'in').
        act_type: 'lrelu' selects LeakyReLU(relufactor); any other value selects ReLU.
        relufactor: negative slope of the LeakyReLU.
        dropout_rate: Dropout3d probability; a value <= 0 disables dropout (Identity).
    """
    def __init__(self, in_ch, out_ch, k_size=3, stride=1, padding=1, pad_type='zeros',
                 norm='in', act_type='lrelu', relufactor=0.2, dropout_rate=0.1):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, k_size, stride, padding, padding_mode=pad_type, bias=True)
        self.norm = normalization(out_ch, norm)
        if act_type == 'lrelu':
            self.act = nn.LeakyReLU(negative_slope=relufactor, inplace=True)
        else:
            self.act = nn.ReLU(inplace=True)
        if dropout_rate > 0:
            self.dropout = nn.Dropout3d(dropout_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        y = self.act(y)  # in-place activation on the fresh norm output
        return self.dropout(y)

# =========================
# Modality-specific Encoder (5-stage 3D Res-CNN)
# =========================
class Encoder(nn.Module):
    """
    5-stage residual 3D convolutional encoder for a single-channel volume.

    Paper correspondence (Sec. 2.3):
    - Stage 1 keeps full resolution: a plain reflect-padded Conv3d stem
      (no norm/activation) followed by two conv units.
    - Stages 2-5 each open with a stride-2 conv unit (ceil-halving D/H/W and
      doubling channels), followed by two conv units.
    - Channels per stage: basic_dims * 2**(s-1) -> 8, 16, 32, 64, 128 for C_base = 8.
    - Residual inside every stage: x_s = x_s + Conv(Conv(x_s))  (Eq. (2) style).
    - Total down-sampling 2**4 = 16, so the output is ~ (D, H, W) / 16.
    - A Dropout3d at rate 2 * dropout_rate is applied to the final feature map.
    """
    def __init__(self, dropout_rate=0.1):
        super().__init__()

        def unit(c_in, c_out, stride=1):
            # Reflect-padded conv unit with the unit's default IN + LeakyReLU.
            return general_conv3d_prenorm(c_in, c_out, stride=stride, pad_type='reflect',
                                          dropout_rate=dropout_rate)

        c1, c2, c3, c4, c5 = (basic_dims * 2 ** s for s in range(5))

        # Stage 1 (no down-sampling)
        self.e1_c1 = nn.Conv3d(1, c1, 3, 1, 1, padding_mode='reflect', bias=True)
        self.e1_c2 = unit(c1, c1)
        self.e1_c3 = unit(c1, c1)

        # Stage 2
        self.e2_c1 = unit(c1, c2, stride=2)
        self.e2_c2 = unit(c2, c2)
        self.e2_c3 = unit(c2, c2)

        # Stage 3
        self.e3_c1 = unit(c2, c3, stride=2)
        self.e3_c2 = unit(c3, c3)
        self.e3_c3 = unit(c3, c3)

        # Stage 4
        self.e4_c1 = unit(c3, c4, stride=2)
        self.e4_c2 = unit(c4, c4)
        self.e4_c3 = unit(c4, c4)

        # Stage 5
        self.e5_c1 = unit(c4, c5, stride=2)
        self.e5_c2 = unit(c5, c5)
        self.e5_c3 = unit(c5, c5)

        # Extra regularization on the encoder output
        self.final_dropout = nn.Dropout3d(dropout_rate * 2)

    def forward(self, x):
        # (entry conv, first residual conv, second residual conv) for each stage
        stages = (
            (self.e1_c1, self.e1_c2, self.e1_c3),
            (self.e2_c1, self.e2_c2, self.e2_c3),
            (self.e3_c1, self.e3_c2, self.e3_c3),
            (self.e4_c1, self.e4_c2, self.e4_c3),
            (self.e5_c1, self.e5_c2, self.e5_c3),
        )
        out = x
        for entry, conv_a, conv_b in stages:
            out = entry(out)
            out = out + conv_b(conv_a(out))
        return self.final_dropout(out)

# =========================
# Self-Attention + Transformer Blocks
# =========================
class SelfAttention(nn.Module):
    """
    Multi-head self-attention (MSA) over a token sequence.

    Used by both the intra-modality and cross-modality Transformers (Sec. 2.3 / 2.5).

    Args:
        dim: token width C; must be divisible by `heads` for the head split.
        heads: number of attention heads.
        qkv_bias: whether the fused Q/K/V projection has a bias.
        qk_scale: logit scale; any falsy value (None or 0) falls back to head_dim ** -0.5.
        dropout_rate: dropout on the attention weights and on the output projection.

    Input/output: (B, N, C) -> (B, N, C).
    """
    def __init__(self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.2):
        super().__init__()
        self.num_heads = heads
        per_head = dim // heads
        self.scale = qk_scale or per_head ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

class Residual(nn.Module):
    """
    Skip-connection wrapper returning ``fn(x) + x``.

    Matches the residual forms of Eq. (7)(8)(16)(17). The wrapped branch is
    evaluated first and the input is added on the right.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x

class PreNorm(nn.Module):
    """
    Pre-LN wrapper: applies LayerNorm(dim) to the input, then `fn`.

    Computes ``fn(LN(x))``. The residual add, if any, is done by an outer wrapper.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)

class PreNormDrop(nn.Module):
    """
    Pre-LN wrapper with output dropout: ``Dropout(fn(LN(x)))``.

    This is an implementation detail for regularization. When `fn` is SelfAttention,
    this dropout stacks on top of that module's own projection dropout.
    """
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout_rate)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        transformed = self.fn(normed)
        return self.drop(transformed)

class GELU(nn.Module):
    """Module wrapper around ``F.gelu`` (default, non-approximate form)."""
    def forward(self, x):
        activated = F.gelu(x)
        return activated

class FeedForward(nn.Module):
    """
    Transformer FFN (Sec. 2.3 / 2.5): Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: token width (input and output).
        hidden_dim: expansion width of the hidden layer.
        dropout_rate: dropout after the activation and after the output projection.
    """
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # net.0: expand
            GELU(),                       # net.1
            nn.Dropout(dropout_rate),     # net.2
            nn.Linear(hidden_dim, dim),   # net.3: project back to dim
            nn.Dropout(dropout_rate),     # net.4
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    """
    Stack of `depth` pre-LN Transformer encoder layers. It is shared by the
    intra-modality and cross-modality paths.

    Each layer computes (Eq. (4)-(8) / Eq. (15)-(17)):
        x = x + pos                       # sinusoidal PE re-added at every layer
        x = x + Dropout(MSA(LN(x)))       # attention sub-block
        x = x + FFN(LN(x))                # feed-forward sub-block
    LayerNorm is applied to each sub-block's input (pre-LN), not after the residual add.

    The attribute names `cross_attn` / `cross_ffn` are kept as-is, even for
    intra-modality use, because renaming them would change state_dict keys.
    """
    def __init__(self, embedding_dim, depth, heads, mlp_dim, dropout_rate=0.2):
        super().__init__()
        attn_layers = [
            Residual(PreNormDrop(embedding_dim, dropout_rate,
                                 SelfAttention(embedding_dim, heads, dropout_rate=dropout_rate)))
            for _ in range(depth)
        ]
        ffn_layers = [
            Residual(PreNorm(embedding_dim, FeedForward(embedding_dim, mlp_dim, dropout_rate)))
            for _ in range(depth)
        ]
        self.cross_attn = nn.ModuleList(attn_layers)
        self.cross_ffn = nn.ModuleList(ffn_layers)
        self.depth = depth

    def forward(self, x, pos):
        for layer in range(self.depth):
            attn_block = self.cross_attn[layer]
            ffn_block = self.cross_ffn[layer]
            # PE is added before attention, so the skip path also carries it.
            x = attn_block(x + pos)
            x = ffn_block(x)
        return x


# =========================
# Delta Function: Missing Modality Compensation
# =========================
class DeltaFunction(nn.Module):
    """
    Delta Function for missing-modality compensation (Sec. 2.4).

    - Learnable per-modality calibration (Eq. (9)):
        T^_m = ((T_m - mu_m) / (sigma_m + eps)) * w_m
    - Reference tokens: per-sample mean of the modalities present for that
      sample (Eq. (10)).
    - Proxy features: self-attention over the reference tokens (Eq. (11)),
      modality-specific calibration (Eq. (12)), FFN refinement and attenuation
      by `alpha` (Eq. (13)).

    Missingness is handled per sample. A modality absent for the whole batch
    (token entry None) is replaced entirely by its proxy. A modality present for
    some samples but missing for others keeps its calibrated tokens where
    present and receives the proxy only in the missing samples' slots.

    Args:
        feature_dim: token width C (must be divisible by `num_heads`).
        num_modals: number of modalities M (size of the calibration tables).
        dropout_rate: dropout in the attention and compensation network.
        num_heads: heads of the proxy-generating attention.
        alpha: attenuation applied to every synthesized proxy.
    """
    def __init__(self, feature_dim, num_modals=3, dropout_rate=0.2, num_heads=4, alpha=0.3):
        super().__init__()
        self.num_modals = num_modals
        self.feature_dim = feature_dim
        # Eq. (13): down-weights inferred (proxy) features relative to real ones.
        self.alpha = alpha

        # Learnable calibration parameters: mu_m, sigma_m, w_m (Eq. (9))
        self.delta_mean = nn.Parameter(torch.zeros(num_modals, feature_dim))
        self.delta_std = nn.Parameter(torch.ones(num_modals, feature_dim))
        self.delta_weight = nn.Parameter(torch.ones(num_modals, feature_dim))

        # Cross-modal attention used to synthesize proxy features (Eq. (11))
        self.cross_modal_attention = nn.MultiheadAttention(
            feature_dim, num_heads=num_heads, dropout=dropout_rate, batch_first=True
        )

        # Compensation network (FFN refinement), Eq. (12)-(13)
        self.compensation_net = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, feature_dim),
            nn.Dropout(dropout_rate)
        )

    def _calibrate(self, tokens, modality):
        """Eq. (9): ((tokens - mu_m) / (sigma_m + 1e-8)) * w_m for modality index `modality`."""
        dm = self.delta_mean[modality].view(1, 1, -1)
        ds = self.delta_std[modality].view(1, 1, -1)
        dw = self.delta_weight[modality].view(1, 1, -1)
        # NOTE(review): sigma_m is unconstrained; if it drifts near -1e-8 the
        # division blows up. Kept as written to match Eq. (9).
        return (tokens - dm) / (ds + 1e-8) * dw

    def forward(self, modal_tokens, missing_mask):
        """
        Args:
            modal_tokens: list (one entry per modality) of (B, N, C) tensors or None.
                All non-None entries must share the same shape (they are stacked).
            missing_mask: (B, num_modals), 1 = present, 0 = missing (float/int/bool).

        Returns:
            list with one entry per modality, each (B, N, C). Missing samples and
            modalities are filled with attenuated proxies. All entries are None only
            when no modality is available at all.
        """
        # ---- Step A: calibrate available modalities (Eq. (9)) ----
        compensated = []
        for i, tokens in enumerate(modal_tokens):
            if tokens is not None and missing_mask[:, i].sum() > 0:
                avail_i = missing_mask[:, i].to(device=tokens.device, dtype=tokens.dtype).view(-1, 1, 1)
                # Zero the slots of samples that do not actually have this modality.
                compensated.append(self._calibrate(tokens, i) * avail_i)
            else:
                compensated.append(None)

        present = [i for i, tokens in enumerate(compensated) if tokens is not None]
        if not present:
            # Nothing to build a reference from; every slot stays None.
            return compensated

        ref_like = compensated[present[0]]
        avail = missing_mask.to(device=ref_like.device, dtype=ref_like.dtype)  # (B, M)

        # ---- Step B: Eq. (10) reference = per-sample mean over present modalities ----
        # Missing slots are already zero, so the sum only contains real features;
        # dividing by the per-sample count avoids diluting it with zeros.
        total = torch.stack([compensated[i] for i in present], dim=0).sum(dim=0)
        counts = avail[:, present].sum(dim=1).clamp(min=1.0).view(-1, 1, 1)
        ref_feature = total / counts

        # A modality needs a proxy if it is absent for the whole batch or for any sample.
        needs_proxy = [
            tokens is None or bool((avail[:, i] <= 0).any())
            for i, tokens in enumerate(compensated)
        ]
        if not any(needs_proxy):
            return compensated

        # Eq. (11): the attention inputs are identical for every missing modality,
        # so compute the shared proxy once.
        shared_proxy, _ = self.cross_modal_attention(ref_feature, ref_feature, ref_feature)

        for i, needed in enumerate(needs_proxy):
            if not needed:
                continue
            # Eq. (12): modality-specific calibration; Eq. (13): refine + attenuate.
            proxy = self.compensation_net(self._calibrate(shared_proxy, i)) * self.alpha
            if compensated[i] is None:
                compensated[i] = proxy
            else:
                # Keep real tokens where present; fill missing samples with the proxy.
                keep = avail[:, i].view(-1, 1, 1) > 0
                compensated[i] = torch.where(keep, compensated[i], proxy)

        return compensated

# =========================
# Main Model: MMFormer_LiFS_Improved_v2
# =========================
class MMFormer_LiFS_Improved_v2(nn.Module):
    """
    Multi-modal 3D classifier with missing-modality compensation.

    Pipeline:
    - Modality-specific encoders + 1x1x1 projection to tokens (Sec. 2.3)
    - Intra-modality transformer per modality (Sec. 2.3)
    - Delta compensation for missing modalities (Sec. 2.4)
    - Cross-modality transformer fusion (Sec. 2.5)
    - Token pooling + classification head (Sec. 2.6)

    Args:
        num_classes: number of output logits per sample.
        dropout_rate: dropout used by the encoders, projections, transformers,
            delta function and fusion projection. The classifier head uses its
            own fixed rates (0.4 / 0.3).
    """
    def __init__(self, num_classes=2, dropout_rate=0.2):
        super().__init__()
        # forward() needs the head size for its all-missing fallback output.
        # The module-level `num_classes` global only matches the default value.
        self.num_classes = num_classes

        # Modality-specific 3D Res-CNN encoders (Sec. 2.3)
        self.encoders = nn.ModuleDict({
            't1': Encoder(dropout_rate=dropout_rate),
            't2': Encoder(dropout_rate=dropout_rate),
            'dwi': Encoder(dropout_rate=dropout_rate)
        })

        # 1x1x1 conv projection to token embedding dim C_t=256 (Sec. 2.3)
        self.encode_proj = nn.ModuleDict({
            't1': nn.Sequential(nn.Conv3d(basic_dims * 16, transformer_basic_dims, 1), nn.Dropout3d(dropout_rate)),
            't2': nn.Sequential(nn.Conv3d(basic_dims * 16, transformer_basic_dims, 1), nn.Dropout3d(dropout_rate)),
            'dwi': nn.Sequential(nn.Conv3d(basic_dims * 16, transformer_basic_dims, 1), nn.Dropout3d(dropout_rate))
        })

        # Intra-modality transformers (Sec. 2.3, Eq. (4)-(8))
        self.intra_transformers = nn.ModuleDict({
            't1': Transformer(transformer_basic_dims, depth, num_heads, mlp_dim, dropout_rate),
            't2': Transformer(transformer_basic_dims, depth, num_heads, mlp_dim, dropout_rate),
            'dwi': Transformer(transformer_basic_dims, depth, num_heads, mlp_dim, dropout_rate)
        })

        # Pre/post LN around fusion tokens (Sec. 2.5 / 2.6)
        self.norm_before = nn.LayerNorm(transformer_basic_dims)
        self.norm_after = nn.LayerNorm(transformer_basic_dims)

        # Delta Function (Sec. 2.4, Eq. (9)-(13))
        self.delta_function = DeltaFunction(transformer_basic_dims, num_modals, dropout_rate)

        # Cross-modality transformer fusion (Sec. 2.5, Eq. (14)-(17))
        self.multimodal_transformer = Transformer(transformer_basic_dims, depth, num_heads, mlp_dim, dropout_rate)

        # Fusion projection (Sec. 2.6)
        self.fuse_proj = nn.Sequential(
            nn.Linear(transformer_basic_dims, fusion_dim),
            nn.Dropout(dropout_rate)
        )

        # Classification head (Sec. 2.6)
        # NOTE: BatchNorm1d raises in training mode when a batch has a single sample.
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.GELU(),
            nn.LayerNorm(fusion_dim),
            nn.Dropout(0.4),
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(fusion_dim // 2),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim // 2, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Re-initialize all submodules in place.

        - Conv3d: Kaiming normal (fan_out, leaky_relu); bias zeroed.
        - Linear (including subclasses such as attention output projections):
          Xavier uniform with gain 0.8; bias zeroed.
        - BatchNorm3d / BatchNorm1d / LayerNorm: weight 1, bias 0.
        Other parameters (e.g. InstanceNorm affine, MultiheadAttention in_proj)
        keep their PyTorch default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.8)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm3d, nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _create_pos_encoding(self, N, dim, device):
        """
        Sinusoidal positional encoding of shape (1, N, dim).

        Even channels hold sin(pos * div) and odd channels cos(pos * div), with
        div = 10000 ** (-2i / dim). Odd `dim` is handled by dropping the last
        frequency for the cosine half. It is added at each transformer layer
        (Eq. (4) and Eq. (15)).
        """
        pe = torch.zeros(1, N, dim, device=device)
        pos = torch.arange(N, device=device).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * -(math.log(10000.0) / dim))

        pe[0, :, 0::2] = torch.sin(pos * div)
        if dim % 2 == 1:
            pe[0, :, 1::2] = torch.cos(pos * div[:-1])
        else:
            pe[0, :, 1::2] = torch.cos(pos * div)

        return pe

    def _encode_modality(self, modality, modal_input):
        """
        Sec. 2.3 pipeline for one modality: 3D encoder -> 1x1x1 projection ->
        tokens -> intra-modality Transformer with per-layer PE (Eq. (4)-(8)).

        Args:
            modality: key into the per-modality ModuleDicts ('t1' | 't2' | 'dwi').
            modal_input: (B, 1, D, H, W) single-modality volume.

        Returns:
            (B, N, transformer_basic_dims) tokens, with N = d*h*w of the encoder output.
        """
        feat = self.encoders[modality](modal_input)        # (B, basic_dims*16, d, h, w)
        tokens = self.encode_proj[modality](feat)           # (B, C_t, d, h, w)
        tokens = tokens.flatten(2).transpose(1, 2)          # (B, N, C_t)
        pos = self._create_pos_encoding(tokens.size(1), transformer_basic_dims, tokens.device)
        return self.intra_transformers[modality](tokens, pos)

    def forward(self, x, missing_mask):
        """
        Args:
            x: (B, 3, D, H, W) input volumes for modalities [t1, t2, dwi].
            missing_mask: (B, 3) modality availability mask, 1=present, 0=missing.

        Returns:
            logits: (B, num_classes)

        Paper correspondence:
        - Sec. 2.3: modality-specific encoder -> tokenization -> intra-modality transformer
        - Sec. 2.4: delta compensation for missing modalities
        - Sec. 2.5: concatenate + cross-modality transformer fusion
        - Sec. 2.6: LN + pooling + classifier
        """
        batch_size = x.size(0)
        device = x.device
        modality_names = ['t1', 't2', 'dwi']

        # ===== Step 1: modality-specific encoding (Sec. 2.3) =====
        # A modality is encoded for the whole batch if at least one sample has it;
        # the delta function then applies missing_mask per sample.
        modal_tokens = [
            self._encode_modality(modality, x[:, i:i + 1]) if missing_mask[:, i].sum() > 0 else None
            for i, modality in enumerate(modality_names)
        ]

        # ===== Step 2: Delta function compensation (Sec. 2.4, Eq. (9)-(13)) =====
        compensated_tokens = self.delta_function(modal_tokens, missing_mask)

        # ===== Step 3: gather valid tokens (present + proxy) =====
        valid_tokens = [token for token in compensated_tokens if token is not None]
        if not valid_tokens:
            # Extreme case: all modalities missing for the whole batch.
            # NOTE: this constant output is not connected to the autograd graph.
            return torch.zeros(batch_size, self.num_classes, device=device)

        # Eq. (14): align token lengths (truncate to the shortest), then
        # concatenate along the token dimension -> (B, total_N, C).
        min_length = min(token.size(1) for token in valid_tokens)
        fused = torch.cat([token[:, :min_length, :] for token in valid_tokens], dim=1)

        # ===== Step 4: cross-modality transformer fusion (Sec. 2.5, Eq. (15)-(17)) =====
        fused = self.norm_before(fused)
        pos = self._create_pos_encoding(fused.size(1), transformer_basic_dims, device)
        fused = self.multimodal_transformer(fused, pos)
        fused = self.norm_after(fused)

        # ===== Step 5: pooling + classification (Sec. 2.6) =====
        pooled = fused.mean(dim=1)              # global average pooling over tokens
        fused_vec = self.fuse_proj(pooled)      # (B, fusion_dim)
        return self.classifier(fused_vec)       # (B, num_classes)
