# 🏥 Multimodal Medical AI: Practical Implementation

## Table of Contents
1. [Image-Text Alignment with Medical Data](#practice-1-image-text-alignment)
2. [ECG Signal Processing](#practice-2-ecg-signal-processing)
3. [Multimodal Fusion Strategies](#practice-3-multimodal-fusion)
4. [Building a Simple Medical VLM](#practice-4-medical-vlm)
5. [Attention-based Fusion](#practice-5-attention-fusion)

### 🎯 Learning Objectives
- Implement basic multimodal fusion techniques
- Process medical signals (ECG) for deep learning
- Build image-text encoders for medical data
- Apply attention mechanisms for cross-modal fusion
- Evaluate multimodal model performance

## Installing and Importing Essential Libraries

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision transformers scikit-learn scipy pandas numpy matplotlib seaborn

# Import essential libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Visualization settings
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11
sns.set_style('whitegrid')

print("✅ All libraries loaded successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

---
## Practice 1: Image-Text Alignment with Medical Data

### 🎯 Learning Objectives
- Understand contrastive learning for image-text pairs
- Implement a simple CLIP-style alignment
- Calculate cosine similarity between modalities

### 📖 Key Concepts
**Contrastive Learning:** Maximize similarity for matching pairs, minimize for non-matching pairs

**InfoNCE Loss:** $\mathcal{L} = -\log \frac{\exp(\text{sim}(i, t) / \tau)}{\sum_{j} \exp(\text{sim}(i, t_j) / \tau)}$
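
To make the formula concrete, the sketch below evaluates the image-to-text InfoNCE loss on two small hand-written similarity matrices. The helper name `info_nce` and the toy values are illustrative assumptions, not part of the original notebook. It subtracts the row-wise maximum before exponentiating, a standard trick that keeps the computation stable for small temperatures.

```python
import numpy as np

def info_nce(sim, tau=0.07):
    """Mean image-to-text InfoNCE loss for an (N, N) similarity matrix.

    Row i holds sim(image_i, text_j); the matching pair sits on the diagonal.
    (Hypothetical helper for illustration.)
    """
    logits = sim / tau
    # Subtract the row-wise max before exponentiating (log-sum-exp trick)
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

# Toy similarity matrices (illustrative values, not real data)
aligned = np.array([[0.9, 0.1, 0.0],
                    [0.2, 0.8, 0.1],
                    [0.0, 0.1, 0.9]])
uniform = np.full((3, 3), 0.5)  # every text looks equally similar to every image

loss_aligned = info_nce(aligned)
loss_uniform = info_nce(uniform)
print(f"Aligned pairs:  {loss_aligned:.4f}")
print(f"Uninformative:  {loss_uniform:.4f} (log N = {np.log(3):.4f})")
```

When every candidate is equally similar, the loss equals $\log N$, which is a useful chance-level reference when reading loss values later on.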

In [None]:
# 1.1 Simulate medical image and text embeddings
def generate_medical_embeddings(n_samples=100, embedding_dim=128):
    """
    Simulate embeddings for chest X-rays and radiology reports
    In practice, these would come from CNN and text encoders
    """
    # Simulate image embeddings (from ResNet/ViT)
    image_embeddings = np.random.randn(n_samples, embedding_dim)
    
    # Simulate text embeddings (from BERT/BioClinicalBERT)
    # Add some correlation with images for matching pairs
    text_embeddings = image_embeddings + np.random.randn(n_samples, embedding_dim) * 0.3
    
    # Normalize to unit vectors
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
    
    return image_embeddings, text_embeddings

# Generate embeddings
img_emb, txt_emb = generate_medical_embeddings(n_samples=50)

print(f"Image embeddings shape: {img_emb.shape}")
print(f"Text embeddings shape: {txt_emb.shape}")
print(f"\nSample image embedding (first 5 dims): {img_emb[0][:5]}")
print(f"Sample text embedding (first 5 dims): {txt_emb[0][:5]}")

In [None]:
# 1.2 Calculate similarity matrix and visualize
def calculate_similarity_matrix(img_emb, txt_emb):
    """
    Calculate cosine similarity between all image-text pairs
    """
    # Cosine similarity = dot product of normalized vectors
    similarity_matrix = img_emb @ txt_emb.T
    
    return similarity_matrix

# Calculate similarities
sim_matrix = calculate_similarity_matrix(img_emb, txt_emb)

# Visualize
plt.figure(figsize=(10, 8))
sns.heatmap(sim_matrix[:20, :20], cmap='RdYlBu_r', center=0, 
            square=True, linewidths=0.5, cbar_kws={"shrink": 0.8})
plt.title('Image-Text Similarity Matrix (First 20 samples)', fontsize=14, fontweight='bold')
plt.xlabel('Text Index', fontsize=12)
plt.ylabel('Image Index', fontsize=12)
plt.tight_layout()
plt.show()

print(f"\n✅ Similarity matrix shape: {sim_matrix.shape}")
print(f"Diagonal (matching pairs) mean similarity: {np.diag(sim_matrix).mean():.4f}")
print(f"Off-diagonal (non-matching) mean similarity: {(sim_matrix.sum() - np.diag(sim_matrix).sum()) / (sim_matrix.size - len(sim_matrix)):.4f}")

In [None]:
# 1.3 Implement InfoNCE (Contrastive) Loss
def contrastive_loss(similarity_matrix, temperature=0.07):
    """
    Calculate InfoNCE loss for contrastive learning
    
    Args:
        similarity_matrix: (N, N) cosine similarity matrix
        temperature: temperature parameter for scaling
    """
    n = similarity_matrix.shape[0]
    
    # Scale by temperature
    logits = similarity_matrix / temperature
    
    # Labels are diagonal indices (matching pairs)
    labels = np.arange(n)
    
    # Calculate cross-entropy loss
    # Image-to-text
    i2t_loss = -np.mean(np.log(np.exp(logits[labels, labels]) / np.exp(logits).sum(axis=1)))
    
    # Text-to-image
    t2i_loss = -np.mean(np.log(np.exp(logits[labels, labels]) / np.exp(logits).sum(axis=0)))
    
    # Total loss
    total_loss = (i2t_loss + t2i_loss) / 2
    
    return total_loss, i2t_loss, t2i_loss

# Calculate loss
total_loss, i2t_loss, t2i_loss = contrastive_loss(sim_matrix)

print("üìä Contrastive Loss Results:")
print("=" * 50)
print(f"Image-to-Text Loss: {i2t_loss:.4f}")
print(f"Text-to-Image Loss: {t2i_loss:.4f}")
print(f"Total Loss: {total_loss:.4f}")
print("\nüí° Lower loss indicates better alignment between modalities")

---
## Practice 2: ECG Signal Processing

### 🎯 Learning Objectives
- Generate and visualize ECG waveforms
- Apply signal preprocessing techniques
- Extract features for arrhythmia detection

### 📖 Key Concepts
**ECG Components:** P wave (atrial depolarization), QRS complex (ventricular depolarization), T wave (repolarization)

**Preprocessing:** Filtering, baseline correction, normalization, R-peak detection
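
R-peak detection is listed above as a preprocessing step, so here is a minimal sketch of one way to do it with `scipy.signal.find_peaks`. The toy signal, the 0.5 amplitude threshold, and the 250 ms minimum spacing are assumptions chosen for this illustration; real recordings usually need filtering first and thresholds tuned to the data.

```python
import numpy as np
from scipy.signal import find_peaks

fs = 360          # sampling rate in Hz (assumed, matching the notebook default)
heart_rate = 75   # beats per minute used to build the toy signal
duration = 10     # seconds

# Toy ECG: one narrow Gaussian "R wave" per beat plus mild noise
t = np.arange(duration * fs) / fs
beat_times = np.arange(0.5, duration, 60 / heart_rate)
ecg = sum(np.exp(-0.5 * ((t - bt) / 0.015) ** 2) for bt in beat_times)
rng = np.random.default_rng(0)
ecg = ecg + rng.normal(0, 0.02, size=t.shape)

# R peaks: maxima above the threshold, at least 250 ms apart
peaks, _ = find_peaks(ecg, height=0.5, distance=int(0.25 * fs))

rr_intervals = np.diff(peaks) / fs        # seconds between consecutive beats
estimated_hr = 60 / rr_intervals.mean()   # beats per minute
print(f"Detected {len(peaks)} R peaks, estimated HR: {estimated_hr:.1f} bpm")
```

The RR intervals computed here are a common starting point for rhythm features such as mean heart rate or beat-to-beat variability.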

In [None]:
# 2.1 Generate synthetic ECG signal
def generate_ecg_signal(duration=10, sampling_rate=360, heart_rate=75):
    """
    Generate synthetic ECG signal (simplified)
    
    Args:
        duration: signal duration in seconds
        sampling_rate: samples per second
        heart_rate: beats per minute
    """
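    # Sample times spaced exactly 1/sampling_rate apart (endpoint excluded)
    n_samples = int(duration * sampling_rate)
    t = np.arange(n_samples) / sampling_rate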
    
    # Calculate beats
    beat_interval = 60 / heart_rate  # seconds per beat
    
    ecg = np.zeros_like(t)
    
    # Generate QRS complexes
    for beat_time in np.arange(0, duration, beat_interval):
        beat_idx = int(beat_time * sampling_rate)
        
        # QRS complex (simplified Gaussian)
        qrs_width = 0.08  # seconds
        qrs_sigma = qrs_width * sampling_rate / 6
        qrs_indices = np.arange(max(0, beat_idx - int(3*qrs_sigma)), 
                                min(len(t), beat_idx + int(3*qrs_sigma)))
        
        if len(qrs_indices) > 0:
            ecg[qrs_indices] += 1.5 * np.exp(-0.5 * ((qrs_indices - beat_idx) / qrs_sigma) ** 2)
        
        # P wave (before QRS)
        p_idx = beat_idx - int(0.15 * sampling_rate)
        if p_idx > 0 and p_idx < len(t):
            p_width = int(0.08 * sampling_rate)
            p_indices = np.arange(max(0, p_idx - p_width), min(len(t), p_idx + p_width))
            if len(p_indices) > 0:
                ecg[p_indices] += 0.3 * np.exp(-0.5 * ((p_indices - p_idx) / (p_width/3)) ** 2)
        
        # T wave (after QRS)
        t_idx = beat_idx + int(0.25 * sampling_rate)
        if t_idx < len(t):
            t_width = int(0.15 * sampling_rate)
            t_indices = np.arange(max(0, t_idx - t_width), min(len(t), t_idx + t_width))
            if len(t_indices) > 0:
                ecg[t_indices] += 0.5 * np.exp(-0.5 * ((t_indices - t_idx) / (t_width/3)) ** 2)
    
    # Add baseline and noise
    baseline = 0.1 * np.sin(2 * np.pi * 0.2 * t)  # baseline wander
    noise = np.random.normal(0, 0.05, len(t))  # random noise
    ecg = ecg + baseline + noise
    
    return t, ecg

# Generate normal and abnormal ECG
t_normal, ecg_normal = generate_ecg_signal(duration=5, heart_rate=75)
t_tachy, ecg_tachy = generate_ecg_signal(duration=5, heart_rate=120)  # Tachycardia

# Visualize
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

axes[0].plot(t_normal, ecg_normal, 'b-', linewidth=1.5)
axes[0].set_title('Normal ECG (HR: 75 bpm)', fontsize=13, fontweight='bold')
axes[0].set_ylabel('Amplitude (mV)', fontsize=11)
axes[0].grid(True, alpha=0.3)

axes[1].plot(t_tachy, ecg_tachy, 'r-', linewidth=1.5)
axes[1].set_title('Tachycardia ECG (HR: 120 bpm)', fontsize=13, fontweight='bold')
axes[1].set_xlabel('Time (seconds)', fontsize=11)
axes[1].set_ylabel('Amplitude (mV)', fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✅ Generated ECG signals")
print(f"Signal length: {len(ecg_normal)} samples")
print(f"Sampling rate: 360 Hz")
print(f"Duration: 5 seconds")

In [None]:
# 2.2 ECG Preprocessing Pipeline
def preprocess_ecg(ecg_signal, sampling_rate=360):
    """
    Apply preprocessing steps to ECG signal
    """
    # 1. Baseline correction (simple detrending)
    from scipy import signal as scipy_signal
    ecg_detrended = scipy_signal.detrend(ecg_signal)
    
    # 2. Band-pass filtering (0.5-40 Hz typical for ECG)
    nyquist = sampling_rate / 2
    low = 0.5 / nyquist
    high = 40 / nyquist
    b, a = scipy_signal.butter(4, [low, high], btype='band')
    ecg_filtered = scipy_signal.filtfilt(b, a, ecg_detrended)
    
    # 3. Normalization (0-1 scaling)
    ecg_normalized = (ecg_filtered - ecg_filtered.min()) / (ecg_filtered.max() - ecg_filtered.min())
    
    return ecg_normalized

# Preprocess signals
ecg_normal_processed = preprocess_ecg(ecg_normal)
ecg_tachy_processed = preprocess_ecg(ecg_tachy)

# Visualize preprocessing effect
fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# Before preprocessing
axes[0, 0].plot(t_normal[:720], ecg_normal[:720], 'b-', alpha=0.7)
axes[0, 0].set_title('Before Preprocessing (Normal)', fontweight='bold')
axes[0, 0].set_ylabel('Amplitude')

axes[1, 0].plot(t_tachy[:720], ecg_tachy[:720], 'r-', alpha=0.7)
axes[1, 0].set_title('Before Preprocessing (Tachycardia)', fontweight='bold')
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('Amplitude')

# After preprocessing
axes[0, 1].plot(t_normal[:720], ecg_normal_processed[:720], 'b-', linewidth=1.5)
axes[0, 1].set_title('After Preprocessing (Normal)', fontweight='bold')

axes[1, 1].plot(t_tachy[:720], ecg_tachy_processed[:720], 'r-', linewidth=1.5)
axes[1, 1].set_title('After Preprocessing (Tachycardia)', fontweight='bold')
axes[1, 1].set_xlabel('Time (s)')

plt.tight_layout()
plt.show()

print("✅ ECG preprocessing completed")
print("Applied: Detrending → Band-pass Filtering → Normalization")

---
## Practice 3: Multimodal Fusion Strategies

### 🎯 Learning Objectives
- Implement early and late fusion, and understand where intermediate fusion fits
- Compare fusion strategies on synthetic medical data
- Understand trade-offs between fusion approaches

### 📖 Key Concepts
**Early Fusion:** Concatenate features before model

**Late Fusion:** Combine predictions from separate models

**Intermediate Fusion:** Combine at middle layers
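
Intermediate fusion is the least obvious of the three, so a small PyTorch sketch may help. Each modality gets its own encoder, and the hidden representations are concatenated before a shared classifier. The class name `IntermediateFusionNet`, the layer sizes, and the random inputs are assumptions made for this illustration; the feature dimensions (5 clinical, 10 image) mirror the synthetic data used in this practice.

```python
import torch
import torch.nn as nn

class IntermediateFusionNet(nn.Module):
    """Encode each modality separately, then fuse the hidden representations.

    Hypothetical example model, not part of the original notebook.
    """
    def __init__(self, clinical_dim=5, image_dim=10, hidden_dim=16):
        super().__init__()
        self.clinical_encoder = nn.Sequential(nn.Linear(clinical_dim, hidden_dim), nn.ReLU())
        self.image_encoder = nn.Sequential(nn.Linear(image_dim, hidden_dim), nn.ReLU())
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, clinical, image):
        h_clinical = self.clinical_encoder(clinical)
        h_image = self.image_encoder(image)
        # Fusion happens here, in the middle of the network
        fused = torch.cat([h_clinical, h_image], dim=-1)
        return self.classifier(fused).squeeze(-1)  # one logit per sample

torch.manual_seed(0)
fusion_net = IntermediateFusionNet()
clinical_batch = torch.randn(4, 5)   # 4 patients, 5 clinical features
image_batch = torch.randn(4, 10)     # 4 patients, 10 image features
logits = fusion_net(clinical_batch, image_batch)
probs = torch.sigmoid(logits)
print(logits.shape, probs)
```

Compared with early fusion, each encoder can learn a modality-specific representation first; compared with late fusion, the classifier can still model interactions between modalities.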

In [None]:
# 3.1 Generate synthetic multimodal medical data
def generate_multimodal_data(n_samples=1000):
    """
    Simulate patient data with multiple modalities
    Modality 1: Clinical features (age, BP, lab results)
    Modality 2: Image features (from X-ray/CT)
    Task: Binary classification (disease present or not)
    """
    # Modality 1: Clinical features (5 features)
    clinical_features = np.random.randn(n_samples, 5)
    
    # Modality 2: Image features (10 features)
    image_features = np.random.randn(n_samples, 10)
    
    # Create labels with correlation to both modalities
    clinical_score = clinical_features[:, 0] + clinical_features[:, 1] * 0.5
    image_score = image_features[:, 0] + image_features[:, 2] * 0.7
    
    combined_score = clinical_score + image_score
    labels = (combined_score > 0).astype(int)
    
    return clinical_features, image_features, labels

# Generate data
clinical_data, image_data, labels = generate_multimodal_data(n_samples=1000)

# Split data (train_test_split is already imported in the setup cell)

clinical_train, clinical_test, image_train, image_test, y_train, y_test = train_test_split(
    clinical_data, image_data, labels, test_size=0.2, random_state=42
)

print("üìä Multimodal Dataset Statistics:")
print("=" * 50)
print(f"Total samples: {len(labels)}")
print(f"Training samples: {len(y_train)}")
print(f"Test samples: {len(y_test)}")
print(f"\nClinical features shape: {clinical_data.shape}")
print(f"Image features shape: {image_data.shape}")
print(f"\nClass distribution: {np.bincount(labels)}")
print(f"Positive rate: {labels.mean():.2%}")

In [None]:
# 3.2 Implement different fusion strategies
from sklearn.linear_model import LogisticRegression
# accuracy_score and roc_auc_score are already imported in the setup cell

def early_fusion(clinical_train, image_train, clinical_test, image_test, y_train, y_test):
    """Early Fusion: Concatenate features before model"""
    # Concatenate features
    X_train = np.concatenate([clinical_train, image_train], axis=1)
    X_test = np.concatenate([clinical_test, image_test], axis=1)
    
    # Train model
    model = LogisticRegression(random_state=42, max_iter=1000)
    model.fit(X_train, y_train)
    
    # Predict
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1]
    
    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_proba)
    
    return acc, auc, model

def late_fusion(clinical_train, image_train, clinical_test, image_test, y_train, y_test):
    """Late Fusion: Train separate models and combine predictions"""
    # Train clinical model
    clinical_model = LogisticRegression(random_state=42, max_iter=1000)
    clinical_model.fit(clinical_train, y_train)
    clinical_proba = clinical_model.predict_proba(clinical_test)[:, 1]
    
    # Train image model
    image_model = LogisticRegression(random_state=42, max_iter=1000)
    image_model.fit(image_train, y_train)
    image_proba = image_model.predict_proba(image_test)[:, 1]
    
    # Combine predictions (average)
    combined_proba = (clinical_proba + image_proba) / 2
    y_pred = (combined_proba > 0.5).astype(int)
    
    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, combined_proba)
    
    return acc, auc, (clinical_model, image_model)

def unimodal_baseline(features_train, features_test, y_train, y_test, modality_name):
    """Unimodal baseline: Single modality only"""
    model = LogisticRegression(random_state=42, max_iter=1000)
    model.fit(features_train, y_train)
    
    y_pred = model.predict(features_test)
    y_proba = model.predict_proba(features_test)[:, 1]
    
    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_proba)
    
    return acc, auc, model

# Run all fusion strategies
print("üîÑ Training models with different fusion strategies...\n")

# Unimodal baselines
clinical_acc, clinical_auc, _ = unimodal_baseline(clinical_train, clinical_test, y_train, y_test, "Clinical")
image_acc, image_auc, _ = unimodal_baseline(image_train, image_test, y_train, y_test, "Image")

# Fusion methods
early_acc, early_auc, _ = early_fusion(clinical_train, image_train, clinical_test, image_test, y_train, y_test)
late_acc, late_auc, _ = late_fusion(clinical_train, image_train, clinical_test, image_test, y_train, y_test)

# Display results
results_df = pd.DataFrame({
    'Method': ['Clinical Only', 'Image Only', 'Early Fusion', 'Late Fusion'],
    'Accuracy': [clinical_acc, image_acc, early_acc, late_acc],
    'AUC': [clinical_auc, image_auc, early_auc, late_auc]
})

print("üìä Fusion Strategy Comparison:")
print("=" * 60)
print(results_df.to_string(index=False))
print("\nüí° Multimodal fusion typically outperforms unimodal approaches!")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

methods = results_df['Method']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

axes[0].bar(methods, results_df['Accuracy'], color=colors, alpha=0.8)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('Accuracy by Fusion Method', fontsize=13, fontweight='bold')
axes[0].set_ylim([0.5, 1.0])
axes[0].tick_params(axis='x', rotation=15)

axes[1].bar(methods, results_df['AUC'], color=colors, alpha=0.8)
axes[1].set_ylabel('AUC-ROC', fontsize=12)
axes[1].set_title('AUC-ROC by Fusion Method', fontsize=13, fontweight='bold')
axes[1].set_ylim([0.5, 1.0])
axes[1].tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.show()

---
## Practice 4: Building a Simple Medical Vision-Language Model

### 🎯 Learning Objectives
- Build simple image and text encoders
- Implement projection layers for common embedding space
- Train with contrastive loss

### 📖 Key Concepts
**Vision-Language Model:** Aligns visual and textual representations in a shared space
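
One practical payoff of a shared space is cross-modal retrieval: given an image embedding, rank all text embeddings by cosine similarity and check whether the paired report comes out on top. The sketch below computes Recall@K on synthetic embeddings. The helper `recall_at_k` and the noisy-copy data are assumptions for illustration only.

```python
import numpy as np

def recall_at_k(image_emb, text_emb, k=1):
    """Fraction of images whose paired text (same row index) ranks in the top k.

    Both inputs are assumed to be L2-normalized, so the dot product is cosine similarity.
    (Hypothetical helper for illustration.)
    """
    sim = image_emb @ text_emb.T                   # (N, N) cosine similarities
    top_k = np.argsort(-sim, axis=1)[:, :k]        # best-matching text indices per image
    targets = np.arange(len(image_emb))[:, None]   # paired text index for each image
    return float(np.mean((top_k == targets).any(axis=1)))

rng = np.random.default_rng(0)
image_emb = rng.normal(size=(50, 64))
# Noisy copies stand in for the embeddings of paired reports
text_emb = image_emb + 0.3 * rng.normal(size=(50, 64))
image_emb /= np.linalg.norm(image_emb, axis=1, keepdims=True)
text_emb /= np.linalg.norm(text_emb, axis=1, keepdims=True)

r1 = recall_at_k(image_emb, text_emb, k=1)
r5 = recall_at_k(image_emb, text_emb, k=5)
print(f"Recall@1: {r1:.2f}, Recall@5: {r5:.2f}")
```

Recall@K complements the contrastive loss as an evaluation metric because it does not depend on the temperature.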

In [None]:
# 4.1 Define simple multimodal model architecture
class SimpleMedicalVLM(nn.Module):
    """
    Simplified Medical Vision-Language Model
    """
    def __init__(self, image_dim=512, text_dim=768, projection_dim=256):
        super(SimpleMedicalVLM, self).__init__()
        
        # Image encoder (simplified - normally would be ResNet/ViT)
        self.image_encoder = nn.Sequential(
            nn.Linear(image_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256)
        )
        
        # Text encoder (simplified - normally would be BERT)
        self.text_encoder = nn.Sequential(
            nn.Linear(text_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256)
        )
        
        # Projection heads
        self.image_projection = nn.Linear(256, projection_dim)
        self.text_projection = nn.Linear(256, projection_dim)
        
        # Learnable log of the inverse temperature (CLIP-style logit scale);
        # exp() of this value is 1/0.07 at initialization
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    
    def forward(self, image_features, text_features):
        # Encode
        image_encoded = self.image_encoder(image_features)
        text_encoded = self.text_encoder(text_features)
        
        # Project to common space
        image_projected = self.image_projection(image_encoded)
        text_projected = self.text_projection(text_encoded)
        
        # Normalize
        image_projected = F.normalize(image_projected, dim=-1)
        text_projected = F.normalize(text_projected, dim=-1)
        
        return image_projected, text_projected
    
    def contrastive_loss(self, image_embeddings, text_embeddings):
        """
        Calculate bidirectional contrastive loss
        """
        # Cosine-similarity matrix scaled by the inverse temperature, exp(self.temperature)
        logits = torch.matmul(image_embeddings, text_embeddings.T) * self.temperature.exp()
        
        # Labels are diagonal
        batch_size = image_embeddings.shape[0]
        labels = torch.arange(batch_size, device=image_embeddings.device)
        
        # Cross entropy loss (both directions)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.T, labels)
        
        return (loss_i2t + loss_t2i) / 2

# Initialize model
model = SimpleMedicalVLM(image_dim=512, text_dim=768, projection_dim=256)

print("✅ Medical VLM Model Architecture:")
print("=" * 60)
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# 4.2 Generate synthetic training data and train
def generate_synthetic_vlm_data(n_samples=500, image_dim=512, text_dim=768):
    """
    Generate synthetic image-text pairs for training
    """
    image_features = torch.randn(n_samples, image_dim)
    text_features = torch.randn(n_samples, text_dim)
    # Correlate the first image_dim // 2 text dimensions with the image features
    text_features[:, :image_dim//2] += image_features[:, :image_dim//2] * 0.5
    
    return image_features, text_features

# Generate data
train_images, train_texts = generate_synthetic_vlm_data(n_samples=400)
val_images, val_texts = generate_synthetic_vlm_data(n_samples=100)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
batch_size = 32
n_epochs = 10

train_losses = []
val_losses = []

print("üöÄ Training Medical VLM...\n")

# Training loop
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    n_batches = len(train_images) // batch_size
    
    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        
        batch_images = train_images[start_idx:end_idx]
        batch_texts = train_texts[start_idx:end_idx]
        
        # Forward pass
        image_emb, text_emb = model(batch_images, batch_texts)
        loss = model.contrastive_loss(image_emb, text_emb)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_train_loss = epoch_loss / n_batches
    train_losses.append(avg_train_loss)
    
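    # Validation (one batch of all 100 pairs; InfoNCE grows with the number of
    # negatives, so this value is not directly comparable to the batch-of-32 training loss)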
    model.eval()
    with torch.no_grad():
        val_image_emb, val_text_emb = model(val_images, val_texts)
        val_loss = model.contrastive_loss(val_image_emb, val_text_emb).item()
        val_losses.append(val_loss)
    
    if (epoch + 1) % 2 == 0:
        print(f"Epoch [{epoch+1}/{n_epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

print("\n✅ Training completed!")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss', linewidth=2)
plt.plot(val_losses, label='Validation Loss', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Contrastive Loss', fontsize=12)
plt.title('Medical VLM Training Curve', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---
## Practice 5: Attention-based Fusion

### 🎯 Learning Objectives
- Implement cross-attention mechanism
- Apply attention for multimodal fusion
- Visualize attention weights

### 📖 Key Concepts
**Cross-Attention:** One modality queries another modality

**Attention Formula:** $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
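
Before the multi-head module, it can help to see the formula as a few lines of NumPy for a single head. In this sketch the queries stand in for image patches and the keys and values for text tokens; the function name, shapes, and random inputs are assumptions for illustration.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                        # (n_queries, n_keys)
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 8))   # 3 image patches acting as queries
K = rng.normal(size=(5, 8))   # 5 text tokens acting as keys
V = rng.normal(size=(5, 8))   # values carried by the same 5 text tokens

attended, weights = scaled_dot_product_attention(Q, K, V)
print("Attended shape:", attended.shape)
print("Row sums of attention weights:", weights.sum(axis=1))
```

Each output row is a weighted average of the text values, with weights determined by how well that image patch's query matches each text key.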

In [None]:
# 5.1 Implement cross-attention mechanism
class CrossAttentionFusion(nn.Module):
    """
    Cross-attention module for multimodal fusion
    """
    def __init__(self, dim=128, num_heads=4):
        super(CrossAttentionFusion, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"
        
        # Query, Key, Value projections
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        
        # Output projection
        self.out = nn.Linear(dim, dim)
        
        self.scale = self.head_dim ** -0.5
    
    def forward(self, query_input, key_value_input):
        """
        Args:
            query_input: (batch, seq_len_q, dim) - e.g., image features
            key_value_input: (batch, seq_len_kv, dim) - e.g., text features
        """
        batch_size = query_input.shape[0]
        
        # Linear projections and reshape for multi-head attention
        Q = self.query(query_input).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(key_value_input).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(key_value_input).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (batch, heads, seq_q, seq_kv)
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        attended = torch.matmul(attention_weights, V)  # (batch, heads, seq_q, head_dim)
        
        # Concatenate heads
        attended = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        
        # Output projection
        output = self.out(attended)
        
        return output, attention_weights

# Test cross-attention
cross_attn = CrossAttentionFusion(dim=128, num_heads=4)

# Create sample inputs
sample_image_features = torch.randn(8, 10, 128)  # batch=8, 10 image patches, dim=128
sample_text_features = torch.randn(8, 20, 128)   # batch=8, 20 text tokens, dim=128

# Forward pass
output, attn_weights = cross_attn(sample_image_features, sample_text_features)

print("✅ Cross-Attention Module:")
print("=" * 60)
print(f"Image features shape: {sample_image_features.shape}")
print(f"Text features shape: {sample_text_features.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  (batch, num_heads, seq_query, seq_key_value)")

In [None]:
# 5.2 Visualize attention weights
def visualize_attention(attention_weights, sample_idx=0, head_idx=0):
    """
    Visualize attention weights for interpretation
    """
    # Get attention for a specific sample and head (move to CPU before converting to NumPy)
    attn = attention_weights[sample_idx, head_idx].detach().cpu().numpy()
    
    plt.figure(figsize=(12, 8))
    sns.heatmap(attn, cmap='YlOrRd', cbar=True, 
                xticklabels=[f'T{i}' for i in range(attn.shape[1])],
                yticklabels=[f'I{i}' for i in range(attn.shape[0])],
                linewidths=0.5, linecolor='gray')
    plt.title(f'Cross-Attention Weights (Sample {sample_idx}, Head {head_idx})', 
              fontsize=14, fontweight='bold')
    plt.xlabel('Text Token Index', fontsize=12)
    plt.ylabel('Image Patch Index', fontsize=12)
    plt.tight_layout()
    plt.show()
    
    # Statistics
    print(f"\nüìä Attention Statistics:")
    print(f"Min attention: {attn.min():.4f}")
    print(f"Max attention: {attn.max():.4f}")
    print(f"Mean attention: {attn.mean():.4f}")
    print(f"\nüí° Higher values indicate stronger alignment between image patch and text token")

visualize_attention(attn_weights, sample_idx=0, head_idx=0)

---
## 🎯 Practice Complete!

### Summary of What We Learned:

1. **Image-Text Alignment**: Implemented contrastive learning for medical image-report pairs
2. **ECG Signal Processing**: Generated and preprocessed ECG signals for arrhythmia detection
3. **Multimodal Fusion**: Compared early and late fusion against unimodal baselines, and introduced intermediate fusion as a concept
4. **Medical VLM**: Built a simplified vision-language model from scratch
5. **Attention Mechanisms**: Implemented cross-attention for multimodal fusion

### Key Insights:
- Multimodal fusion outperforms unimodal models when the label depends on more than one modality, as in the synthetic data used here
- Different fusion strategies have different trade-offs
- Contrastive learning aligns different modalities in a shared space
- Attention mechanisms allow dynamic weighting of information

### Next Steps:
- Implement more complex architectures (Transformers, Graph Neural Networks)
- Work with real medical datasets (MIMIC-CXR, PhysioNet)
- Add missing modality handling
- Explore interpretability techniques (Grad-CAM, SHAP)
- Deploy models for clinical decision support

### 📚 Recommended Resources:
- **Papers**: "CLIP" (Radford et al.), "BiomedCLIP", "MedCLIP"
- **Datasets**: MIMIC-CXR, PhysioNet, CheXpert
- **Libraries**: Hugging Face Transformers, PyTorch, timm

---

## 🙏 Thank you for completing this practice!

**Instructor**: Ho-min Park  
**Email**: homin.park@ghent.ac.kr | powersimmani@gmail.com