# 🧠🔬 Cross-Attention Transformer for MRI + Proteomic Fusion

## A Complete Walkthrough: From Theory to Implementation

This notebook demonstrates how to build a **cross-attention transformer** that fuses MRI spatial features with proteomic biomarkers for Alzheimer's Disease classification.

### 📋 **What You'll Learn:**
1. **PyTorch vs Custom Implementation** - When to use each
2. **Positional Encoding** - Spatial (MRI) vs Categorical (Proteins)
3. **Cross-Attention Mechanism** - How modalities "talk" to each other
4. **Attention Visualization** - Understanding model decisions
5. **End-to-End Training** - From data to predictions

### 🎯 **Architecture Overview:**
```
MRI Patches [B,100,768] ──┐
                          │
                          ├── Cross-Attention ──> Classification
                          │
Proteins [B,8,8] ─────────┘
```


## 📦 **Setup & Imports**


In [1]:
# Install required packages (if needed)
# !pip install torch torchvision matplotlib seaborn scikit-learn

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.model_selection import train_test_split

import math
from typing import Dict, Tuple, Optional, List
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")


Using device: cpu
PyTorch version: 2.8.0+cpu


## 🤔 **Design Decisions: PyTorch Built-in vs Custom Implementation**

### ✅ **Use PyTorch Built-in:**
- **`nn.MultiheadAttention`** - Mature, optimized, well-tested
- **`nn.TransformerEncoder/TransformerEncoderLayer`** - For self-attention blocks
- **`nn.Embedding`** - For categorical positional encoding
- **`nn.Linear`, `nn.LayerNorm`, `nn.GELU`** - Standard components

### 🔧 **Custom Implementation:**
- **Cross-attention orchestration** - PyTorch supplies the attention primitive (`nn.MultiheadAttention`) but no ready-made bidirectional cross-modal block
- **3D spatial positional encoding** - Domain-specific for brain imaging
- **Multimodal fusion strategy** - Research-specific architecture
- **Attention visualization tools** - For interpretability

### 🎯 **Best of Both Worlds:**
This approach gives us **reliability** (PyTorch) + **flexibility** (custom) + **maintainability** (less code to debug)!
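
To make the split concrete, here is a minimal sketch of the built-in piece the rest of the notebook leans on: `nn.MultiheadAttention` used as cross-attention. The tensor sizes and variable names are illustrative assumptions, not values from a real dataset.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed toy sizes: 64 MRI patches, 8 protein tokens, shared width 256.
batch, n_patches, n_proteins, dim = 2, 64, 8, 256
mri_tokens = torch.randn(batch, n_patches, dim)
protein_tokens = torch.randn(batch, n_proteins, dim)

# The same built-in module performs self- or cross-attention; only the inputs differ.
cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=8, batch_first=True)

# MRI patches are the queries; proteins supply the keys and values.
attended, weights = cross_attn(query=mri_tokens, key=protein_tokens, value=protein_tokens)

print(attended.shape)  # one updated vector per MRI patch
print(weights.shape)   # one weight per (patch, protein) pair, averaged over heads
```

The query sequence sets the output length, so `attended` keeps the MRI patch count while each row of `weights` spreads over the proteins. Everything beyond this call (running it in both directions, residuals, normalization, pooling) is the custom part.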


## 🎯 **Step 1: Positional Encoding - The Foundation**

Attention by itself ignores token order, so each modality needs positional information that matches its structure:
- **MRI**: Where a patch sits in the brain (hippocampus vs cortex)
- **Proteins**: Which biological pathway a marker belongs to (amyloid vs inflammation)
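
A quick way to see why this matters is to shuffle the tokens and check that a pooled attention output does not change. The sketch below uses a small stand-in layer and random tensors; the sizes and names are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy self-attention layer with no positional information (dropout defaults to 0).
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()

patches = torch.randn(1, 10, 32)   # 10 hypothetical patch embeddings
perm = torch.randperm(10)
shuffled = patches[:, perm, :]     # same patches, different order

with torch.no_grad():
    out, _ = attn(patches, patches, patches)
    out_shuffled, _ = attn(shuffled, shuffled, shuffled)

# Mean pooling discards the order, so both pooled vectors should match up to float error.
pooled = out.mean(dim=1)
pooled_shuffled = out_shuffled.mean(dim=1)
print(torch.allclose(pooled, pooled_shuffled, atol=1e-5))
```

Without positional encoding the model sees a bag of patches: swapping a hippocampal patch with a cortical one leaves the pooled representation unchanged. Adding a position-dependent term to each token breaks that symmetry.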


In [2]:
class Spatial3DPositionalEncoding(nn.Module):
    """
    3D spatial positional encoding for MRI brain patches.
    Encodes actual anatomical coordinates in 3D brain space.
    """
    
    def __init__(self, embed_dim: int, max_patches: int = 1000):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Project 3D coordinates (x,y,z) to embedding dimension
        self.coord_projection = nn.Sequential(
            nn.Linear(3, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim)
        )
        
        # Default grid coordinates for patches (if real coordinates unavailable)
        self.register_buffer('default_coords', self._create_default_grid(max_patches))
    
    def _create_default_grid(self, max_patches: int) -> torch.Tensor:
        """Create a default 3D grid of coordinates."""
        # Assume cubic grid: find cube root
        grid_size = int(np.ceil(max_patches ** (1/3)))
        
        # Create 3D coordinate grid
        coords = torch.linspace(-1, 1, grid_size)
        xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij')
        
        # Flatten and take first max_patches coordinates
        grid_coords = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1)
        return grid_coords[:max_patches]
    
    def forward(self, x: torch.Tensor, coordinates: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [batch_size, num_patches, embed_dim] - MRI patch embeddings
            coordinates: [batch_size, num_patches, 3] - 3D spatial coordinates (optional)
        """
        batch_size, num_patches, _ = x.shape
        
        if coordinates is not None:
            coords = coordinates
        else:
            # Use default grid coordinates
            coords = self.default_coords[:num_patches].unsqueeze(0).repeat(batch_size, 1, 1)
        
        # Project coordinates to embedding space
        pos_encoding = self.coord_projection(coords)
        
        return x + pos_encoding


class CategoricalPositionalEncoding(nn.Module):
    """
    Categorical positional encoding for proteomic features.
    Groups proteins by biological pathway/function using PyTorch Embedding.
    """
    
    def __init__(self, embed_dim: int, protein_categories: List[str]):
        super().__init__()
        self.protein_categories = protein_categories
        self.num_categories = len(protein_categories)
        
        # Use PyTorch's optimized Embedding layers
        self.category_embedding = nn.Embedding(self.num_categories, embed_dim)
        self.position_embedding = nn.Embedding(50, embed_dim)  # Max 50 proteins per category
        
        # Initialize embeddings with small random values
        nn.init.normal_(self.category_embedding.weight, std=0.02)
        nn.init.normal_(self.position_embedding.weight, std=0.02)
    
    def forward(self, x: torch.Tensor, 
                category_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [batch_size, num_proteins, embed_dim]
            category_ids: [num_proteins] - category ID for each protein
        """
        batch_size, num_proteins, _ = x.shape
        
        if category_ids is None:
            # Default: distribute proteins across categories
            category_ids = torch.arange(num_proteins, device=x.device) % self.num_categories
        
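        # Position within category: count earlier proteins that share the same category,
        # so the IDs stay correct when custom category_ids are passed in
        same_category = category_ids.unsqueeze(0) == category_ids.unsqueeze(1)
        position_ids = same_category.long().tril(diagonal=-1).sum(dim=1)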
        
        # Get embeddings
        category_emb = self.category_embedding(category_ids)  # [num_proteins, embed_dim]
        position_emb = self.position_embedding(position_ids)  # [num_proteins, embed_dim]
        
        # Combine and add to input
        pos_encoding = category_emb + position_emb
        return x + pos_encoding.unsqueeze(0)  # Broadcast across batch


# Define protein categories for AD research
AD_PROTEIN_CATEGORIES = [
    'amyloid_pathway',    # Aβ40, Aβ42
    'tau_pathway',        # p-tau, t-tau
    'inflammation',       # cytokines
    'neurodegeneration',  # neurofilament
    'synaptic',           # synaptic proteins
    'vascular',           # vascular markers
    'metabolic',          # metabolic proteins
    'other'               # uncategorized
]

print(f"Defined {len(AD_PROTEIN_CATEGORIES)} protein categories for AD research")
print("Categories:", AD_PROTEIN_CATEGORIES)


Defined 8 protein categories for AD research
Categories: ['amyloid_pathway', 'tau_pathway', 'inflammation', 'neurodegeneration', 'synaptic', 'vascular', 'metabolic', 'other']
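
The fallback grid used when no real coordinates are supplied is easier to reason about in isolation. The sketch below re-creates the same idea as a standalone function; `default_grid` is a hypothetical helper written for illustration, not a replacement for the class method.

```python
import math
import torch


def default_grid(num_patches: int) -> torch.Tensor:
    """Hypothetical standalone version of the fallback coordinate grid."""
    # Side length of the smallest cube that can hold num_patches points.
    side = math.ceil(num_patches ** (1 / 3))
    axis = torch.linspace(-1.0, 1.0, side)
    xx, yy, zz = torch.meshgrid(axis, axis, axis, indexing="ij")
    coords = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1)
    # Keep only as many points as there are patches.
    return coords[:num_patches]


coords = default_grid(64)
print(coords.shape)            # 4 x 4 x 4 grid, flattened
print(coords[0], coords[-1])   # opposite corners of the cube
```

One consequence worth knowing: when the patch count is not a perfect cube, the slice removes points from the end of the flattened grid. With 100 patches the grid has 125 points and the whole `x = 1` plane is dropped, so the fallback coordinates are no longer symmetric. Passing real coordinates avoids this.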


## 🔄 **Step 2: Cross-Attention Transformer (PyTorch + Custom)**

Here we use **PyTorch's built-in** components where possible, but implement **custom cross-modal orchestration**.

### 🏗️ **Architecture Strategy:**
1. **Self-attention**: Use `nn.TransformerEncoderLayer` (PyTorch built-in)
2. **Cross-attention**: Use `nn.MultiheadAttention` (PyTorch built-in)
3. **Orchestration**: Custom logic to coordinate cross-modal attention


In [5]:
class MultimodalCrossAttentionTransformer(nn.Module):
    """
    Complete multimodal transformer using PyTorch components + custom cross-modal logic.
    """
    
    def __init__(self, 
                 mri_embed_dim: int = 768,
                 proteomic_embed_dim: int = 8,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 num_self_layers: int = 2,
                 num_cross_layers: int = 2,
                 num_classes: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        
        # Input projections to common hidden dimension
        self.mri_projection = nn.Linear(mri_embed_dim, hidden_dim)
        self.proteomic_projection = nn.Linear(proteomic_embed_dim, hidden_dim)
        
        # Positional encodings
        self.mri_pos_encoder = Spatial3DPositionalEncoding(hidden_dim)
        self.protein_pos_encoder = CategoricalPositionalEncoding(hidden_dim, AD_PROTEIN_CATEGORIES)
        
        # Self-attention layers using PyTorch TransformerEncoderLayer
        mri_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.mri_self_encoder = nn.TransformerEncoder(mri_encoder_layer, num_self_layers)
        
        protein_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.protein_self_encoder = nn.TransformerEncoder(protein_encoder_layer, num_self_layers)
        
        # Cross-attention layers using PyTorch MultiheadAttention
        self.mri_to_protein_attention = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_cross_layers)
        ])
        
        self.protein_to_mri_attention = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_cross_layers)
        ])
        
        # Layer norms for residual connections
        self.mri_cross_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_cross_layers)])
        self.protein_cross_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_cross_layers)])
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, mri_embeddings: torch.Tensor, proteomic_embeddings: torch.Tensor,
                mri_coordinates: Optional[torch.Tensor] = None,
                protein_categories: Optional[torch.Tensor] = None,
                return_attention: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the multimodal transformer.
        
        Args:
            mri_embeddings: [batch_size, num_patches, mri_embed_dim]
            proteomic_embeddings: [batch_size, num_proteins, proteomic_embed_dim]
            mri_coordinates: [batch_size, num_patches, 3] - optional 3D coordinates
            protein_categories: [num_proteins] - optional category assignments
            return_attention: whether to return attention weights for visualization
        """
        # Project to common hidden dimension
        mri_hidden = self.mri_projection(mri_embeddings)
        protein_hidden = self.proteomic_projection(proteomic_embeddings)
        
        # Add positional encodings
        mri_hidden = self.mri_pos_encoder(mri_hidden, mri_coordinates)
        protein_hidden = self.protein_pos_encoder(protein_hidden, protein_categories)
        
        # Apply dropout
        mri_hidden = self.dropout(mri_hidden)
        protein_hidden = self.dropout(protein_hidden)
        
        # Self-attention within each modality (PyTorch built-in)
        mri_hidden = self.mri_self_encoder(mri_hidden)
        protein_hidden = self.protein_self_encoder(protein_hidden)
        
        # Cross-attention between modalities (custom orchestration with PyTorch attention)
        attention_weights = {}
        
        for i, (mri_cross_attn, protein_cross_attn, mri_norm, protein_norm) in enumerate(
            zip(self.mri_to_protein_attention, self.protein_to_mri_attention, 
                self.mri_cross_norms, self.protein_cross_norms)
        ):
            # MRI attending to proteins
            mri_attended, mri_attn = mri_cross_attn(
                query=mri_hidden, key=protein_hidden, value=protein_hidden,
                need_weights=return_attention
            )
            mri_hidden = mri_norm(mri_hidden + mri_attended)
            
            # Proteins attending to MRI (uses the MRI stream already updated above)
            protein_attended, protein_attn = protein_cross_attn(
                query=protein_hidden, key=mri_hidden, value=mri_hidden,
                need_weights=return_attention
            )
            protein_hidden = protein_norm(protein_hidden + protein_attended)
            
            if return_attention:
                attention_weights[f'mri_to_protein_layer_{i}'] = mri_attn
                attention_weights[f'protein_to_mri_layer_{i}'] = protein_attn
        
        # Global pooling and fusion
        mri_pooled = torch.mean(mri_hidden, dim=1)  # [batch, hidden_dim]
        protein_pooled = torch.mean(protein_hidden, dim=1)  # [batch, hidden_dim]
        
        # Concatenate modalities for classification
        fused_features = torch.cat([mri_pooled, protein_pooled], dim=1)  # [batch, hidden_dim*2]
        
        # Classification
        logits = self.classifier(fused_features)
        
        outputs = {'logits': logits}
        if return_attention:
            outputs['attention_weights'] = attention_weights
            
        return outputs


print("✅ Cross-attention transformer implemented using PyTorch + custom components!")


✅ Cross-attention transformer implemented using PyTorch + custom components!
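
The model passes only `need_weights` to `nn.MultiheadAttention`, so the returned maps are averaged over heads. For visualization it can help to look at individual heads as well. The sketch below shows both forms on a small stand-in layer; the sizes and tensor names are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()
mri_tokens = torch.randn(2, 16, 64)      # stand-in for projected MRI patches
protein_tokens = torch.randn(2, 8, 64)   # stand-in for projected protein tokens

with torch.no_grad():
    # Default: weights averaged over heads -> [batch, queries, keys]
    _, avg_weights = attn(mri_tokens, protein_tokens, protein_tokens, need_weights=True)
    # Per-head weights -> [batch, heads, queries, keys]
    _, head_weights = attn(mri_tokens, protein_tokens, protein_tokens,
                           need_weights=True, average_attn_weights=False)

print(avg_weights.shape, head_weights.shape)
# Averaging the per-head maps recovers the default output.
print(torch.allclose(head_weights.mean(dim=1), avg_weights, atol=1e-6))

# Share of attention each protein receives, averaged over batch and patches.
protein_share = avg_weights.mean(dim=(0, 1))
print(protein_share)
```

Each row of a map sums to 1 over the keys, so `protein_share` also sums to 1 and can be read as a rough ranking of which proteins the MRI patches attend to. Treat it as a descriptive summary rather than evidence of biological importance.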


## 🧪 **Step 3: Test the Complete Model**

Let's test our hybrid PyTorch + custom implementation with sample data to see the cross-attention in action!


In [6]:
# Create model
model = MultimodalCrossAttentionTransformer(
    mri_embed_dim=768,
    proteomic_embed_dim=8,
    hidden_dim=256,
    num_heads=8,
    num_cross_layers=2
).to(device)

# Create sample data
batch_size = 4
num_mri_patches = 64
num_proteins = 8

mri_embeddings = torch.randn(batch_size, num_mri_patches, 768).to(device)
proteomic_embeddings = torch.randn(batch_size, num_proteins, 8).to(device)

# Optional: provide spatial coordinates and protein categories
mri_coordinates = torch.randn(batch_size, num_mri_patches, 3).to(device)
protein_category_ids = torch.tensor([0, 0, 1, 1, 2, 3, 4, 5]).to(device)

print("Input shapes:")
print(f"  MRI embeddings: {mri_embeddings.shape}")
print(f"  Proteomic embeddings: {proteomic_embeddings.shape}")
print(f"  MRI coordinates: {mri_coordinates.shape}")
print(f"  Protein categories: {protein_category_ids.shape}")

# Forward pass
with torch.no_grad():
    outputs = model(
        mri_embeddings=mri_embeddings,
        proteomic_embeddings=proteomic_embeddings,
        mri_coordinates=mri_coordinates,
        protein_categories=protein_category_ids,
        return_attention=True
    )

print("\nModel outputs:")
print(f"  Logits shape: {outputs['logits'].shape}")
print(f"  Number of attention maps: {len(outputs['attention_weights'])}")
print(f"  Attention map keys: {list(outputs['attention_weights'].keys())}")

# Check model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("\nModel statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

print("\n✅ Model test successful!")


Input shapes:
  MRI embeddings: torch.Size([4, 64, 768])
  Proteomic embeddings: torch.Size([4, 8, 8])
  MRI coordinates: torch.Size([4, 64, 3])
  Protein categories: torch.Size([8])

Model outputs:
  Logits shape: torch.Size([4, 2])
  Number of attention maps: 4
  Attention map keys: ['mri_to_protein_layer_0', 'protein_to_mri_layer_0', 'mri_to_protein_layer_1', 'protein_to_mri_layer_1']

Model statistics:
  Total parameters: 4,625,794
  Trainable parameters: 4,625,794

✅ Model test successful!
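
The parameter total printed above can be reproduced by hand, which is a useful check that the architecture is wired as intended. The sketch below assumes the constructor arguments used in this test (hidden size 256, two self-attention layers per modality, two cross-attention layers per direction) and counts weights plus biases. The `linear` helper is a hypothetical convenience, not part of the model.

```python
def linear(n_in: int, n_out: int) -> int:
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


hidden, mri_dim, prot_dim, n_cat, classes = 256, 768, 8, 8, 2

projections = linear(mri_dim, hidden) + linear(prot_dim, hidden)
mri_pos = linear(3, hidden // 2) + linear(hidden // 2, hidden)
protein_pos = n_cat * hidden + 50 * hidden   # category and position embedding tables

# Multi-head attention: packed Q/K/V projection plus output projection.
mha = linear(hidden, 3 * hidden) + linear(hidden, hidden)
layer_norm = 2 * hidden                      # scale and shift
encoder_layer = (mha + linear(hidden, 4 * hidden)
                 + linear(4 * hidden, hidden) + 2 * layer_norm)

self_attention = 2 * 2 * encoder_layer       # 2 modalities x 2 layers
cross_attention = 2 * 2 * (mha + layer_norm) # 2 directions x 2 layers
classifier = (linear(2 * hidden, hidden) + linear(hidden, hidden // 2)
              + linear(hidden // 2, classes))

total = (projections + mri_pos + protein_pos
         + self_attention + cross_attention + classifier)
print(f"{total:,}")  # prints 4,625,794
```

Roughly two thirds of the parameters sit in the four self-attention encoder layers, and the number of attention heads does not affect the count because the heads split the same projection matrices.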
