# Multi-Modal Fusion Model
## Image + IoT + Geo -> Disease Classification + Outbreak Prediction + Severity

**Architecture**:
- Swin backbone (frozen) -> 256-d visual features
- Temporal Transformer (4L, 8H) -> 256-d temporal features
- Spatial MLP -> 128-d spatial features
- Cross-Attention + Gated Fusion -> 640-d
- Three prediction heads (disease class, 7-day outbreak risk, severity)

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.models.fusion_model import MultiModalFusionModel
from src.models.losses import MultiTaskLoss
from src.utils.config import ConfigManager
from src.utils.seed import set_seed

# Load the project configuration and apply its seed through the project's
# set_seed helper before anything else is created.
cfg = ConfigManager('../../configs/config.yaml')
set_seed(cfg.seed)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 1. Model Architecture Verification

In [None]:
# Per-branch settings, kept in named dicts so the sizes are easy to compare
# with the architecture summary at the top of the notebook.
temporal_config = dict(
    num_features=7, d_model=128, nhead=8,
    num_layers=4, output_dim=256, sequence_length=30,
)
spatial_config = dict(input_dim=3, hidden_dim=64, output_dim=128)
fusion_config = dict(fused_dim=640)

# Build the fusion model; the Swin variant comes from the loaded config and
# the backbone is requested frozen.
model = MultiModalFusionModel(
    num_classes=15,
    swin_model_name=cfg.model.variant,
    freeze_swin=True,
    feature_dim=768,
    visual_proj_dim=256,
    temporal_config=temporal_config,
    spatial_config=spatial_config,
    fusion_config=fusion_config,
)

# Tally parameters in a single pass, split by whether they require gradients.
total = 0
trainable = 0
for param in model.parameters():
    count = param.numel()
    total += count
    if param.requires_grad:
        trainable += count

print(f'Total parameters: {total:,}')
print(f'Trainable parameters: {trainable:,}')
# Everything that does not require gradients is reported under the Swin label.
print(f'Frozen (Swin): {total - trainable:,}')

In [None]:
# Smoke-test the forward pass with random inputs and print the shape of each
# prediction head's output.
batch_size = 2
images = torch.randn(batch_size, 3, 224, 224)
iot_seq = torch.randn(batch_size, 30, 7)  # 30 and 7 match sequence_length / num_features above
geo = torch.randn(batch_size, 3)          # 3 matches the spatial input_dim above

# Put the model in evaluation mode first: torch.no_grad() only disables
# gradient tracking, it does not turn off dropout or stop batch-norm layers
# (if the model has any) from updating their running statistics on this noise.
# The model stays in eval mode afterwards; call model.train() before training.
model.eval()
with torch.no_grad():
    outputs = model(images, iot_seq, geo)

print(f'Disease logits: {outputs["disease_logits"].shape}')
print(f'Outbreak risk:  {outputs["outbreak_risk"].shape}')
print(f'Severity:       {outputs["severity_logits"].shape}')

## 2. Multi-Task Loss Verification

In [None]:
# Exercise the multi-task loss wrapper with three fixed scalar losses, one per
# task, and print what it returns.
mtl = MultiTaskLoss(
    num_tasks=3,
    task_names=['disease', 'outbreak', 'severity'],
)

# Stand-in per-task losses; requires_grad=True so they look like real losses.
losses = [
    torch.tensor(value, requires_grad=True)
    for value in (1.5, 0.3, 0.8)
]

# The wrapper is called with the list of losses and returns a combined value
# plus a dict of per-task entries.
total, loss_dict = mtl(losses)
print(f'Total loss: {total.item():.4f}')
print('\nPer-task details:')
for name, value in loss_dict.items():
    print(f'  {name}: {value:.4f}')

## 3. Training

Run the full fusion training from the project root (two levels above this notebook, where `configs/` and `models/` live):
```bash
python train_fusion.py --config configs/config.yaml --swin_checkpoint models/checkpoints/best_swin_classifier.pth
```

## 4. Results Analysis (after training)

In [None]:
from pathlib import Path


def describe_checkpoint(ckpt):
    """Return the summary lines to print for a loaded fusion checkpoint.

    Both fields are read with ``.get`` so that a checkpoint saved without
    one of them is reported as ``N/A`` instead of raising ``KeyError``.

    Args:
        ckpt: Mapping loaded from the checkpoint file.

    Returns:
        A list of two strings: the best-epoch line and the best
        validation-loss line.
    """
    return [
        f'Best epoch: {ckpt.get("epoch", "N/A")}',
        f'Best val_total_loss: {ckpt.get("best_val_total_loss", "N/A")}',
    ]


ckpt_path = Path('../../models/checkpoints/best_fusion_model.pth')
if ckpt_path.exists():
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. map_location='cpu' lets this run without a GPU.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    for line in describe_checkpoint(ckpt):
        print(line)
else:
    print('Fusion model checkpoint not found. Train the model first.')