In [1]:
"""
Benchmark 2: Image-Only Model
Tests whether MRI scans alone can predict AD conversion

Architecture:
- ResNet18 image encoder
- LSTM for temporal modeling
- NO tabular features
- Shows need for multimodal integration
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import numpy as np
import os
import pickle
from lifelines.utils import concordance_index
import warnings
warnings.filterwarnings('ignore')

# Hyperparameters and runtime settings read by ImageOnlyModel and main().
CONFIG = {
    'img_out_dim': 256,  # width of the Linear projection applied to the 512-d ResNet features; also the LSTM input size
    'lstm_hidden': 128,  # LSTM hidden size per direction (the LSTM is bidirectional)
    'lstm_layers': 2,  # number of stacked LSTM layers
    'dropout': 0.3,  # dropout in the risk head, and between LSTM layers when lstm_layers > 1
    'lr': 5e-4,  # AdamW learning rate
    'weight_decay': 1e-4,  # AdamW weight decay
    'epochs': 100,  # upper bound on training epochs (main() may stop earlier)
    'batch_size': 16,  # batch size for every DataLoader built in main()
    'max_seq_len': 10,  # visits kept per patient; shorter sequences are padded to this length
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer GPU when one is visible
}

# ============================================================================
# DATASET
# ============================================================================

class ImageOnlyDataset(Dataset):
    """Per-patient image sequences with survival labels for MCI -> AD conversion.

    For every patient in ``manifest`` that is also in ``valid_patients`` the
    patient's pickled visit table is loaded, a (time_to_event, event) label is
    derived from its ``DX`` / ``Years_bl`` columns, and up to ``max_seq_len``
    visit images that exist on disk are loaded and padded to a fixed length.

    Args:
        manifest: DataFrame with ``PTID`` and ``path`` columns. It is NOT
            modified; path normalisation happens on a derived Series.
        valid_patients: container supporting ``in`` that holds the PTIDs to keep.
        transform: optional callable applied to each RGB PIL image. Padding and
            stacking use torch ops, so it has to return tensors for a patient
            to be kept.
        max_seq_len: maximum number of visits per patient; also the padded length.

    Each item is the tuple
    ``(imgs, times, seq_len, time_to_event, event, ptid)``.
    """

    # Prefix prepended to the (slash-normalised) manifest paths.
    MANIFEST_ROOT = "./AD_Multimodal/TFN_AD/"
    # Rewrite applied to the image paths stored in the per-patient tables.
    IMAGE_ROOT_OLD = "/home/mason/ADNI_Dataset/"
    IMAGE_ROOT_NEW = "./AD_Multimodal/ADNI_Dataset/"
    # Diagnosis labels that count as a conversion event.
    AD_LABELS = ("AD", "Dementia")
    # Patients with fewer usable images than this are dropped.
    MIN_VISITS = 2

    def __init__(self, manifest, valid_patients, transform=None, max_seq_len=10):
        self.sequences = []
        self.transform = transform
        self.max_seq_len = max_seq_len

        # Normalise paths on a derived Series so the caller's DataFrame is left
        # untouched (re-using the manifest must not double-prefix the paths).
        paths = self.MANIFEST_ROOT + manifest["path"].str.replace("\\", "/", regex=False)

        # First manifest row wins for each patient; dict order preserves the
        # order of first appearance.
        first_path = {}
        for ptid, path in zip(manifest["PTID"], paths):
            first_path.setdefault(ptid, path)

        processed = 0
        skipped_not_valid = 0
        failures = []

        for ptid, pkl_path in first_path.items():
            if ptid not in valid_patients:
                skipped_not_valid += 1
                continue

            try:
                record = self._build_sequence(ptid, pkl_path)
            except Exception as exc:
                # Best-effort loading: one unreadable patient must not abort
                # the whole dataset, but the failure is recorded and reported.
                failures.append((ptid, exc))
                continue

            if record is None:
                continue

            self.sequences.append(record)
            processed += 1

        print(f"  Processed: {processed} valid patients")
        print(f"  Skipped (not in valid set): {skipped_not_valid}")
        if failures:
            print(f"  Failed to load: {len(failures)} patients")
            for ptid, exc in failures[:5]:
                print(f"    {ptid}: {type(exc).__name__}: {exc}")

    def _build_sequence(self, ptid, pkl_path):
        """Return the sequence record for one patient, or None if it is unusable."""
        # NOTE: read_pickle can execute arbitrary code; only load trusted files.
        df = pd.read_pickle(pkl_path)

        dx_seq = df["DX"].tolist()
        if "MCI" not in dx_seq:
            return None

        # Event = first AD/Dementia diagnosis strictly after the first MCI visit.
        mci_idx = dx_seq.index("MCI")
        ad_idx = next(
            (i for i in range(mci_idx + 1, len(dx_seq)) if dx_seq[i] in self.AD_LABELS),
            None,
        )
        if ad_idx is not None:
            time_to_event = df["Years_bl"].iloc[ad_idx]
            event = 1
        else:
            # Censored at the last recorded visit.
            time_to_event = df["Years_bl"].iloc[-1]
            event = 0

        # NOTE(review): images are taken from the first visits in the table
        # regardless of ad_idx, so scans at or after the conversion visit can
        # end up in the input sequence - confirm this is intended.
        imgs, times = [], []
        for _, visit in df.iterrows():
            image_path = visit["image_path"].replace(
                self.IMAGE_ROOT_OLD, self.IMAGE_ROOT_NEW
            )
            if not os.path.exists(image_path):
                continue

            # convert() returns a new in-memory image, so the file handle can
            # be released immediately.
            with Image.open(image_path) as raw:
                img = raw.convert("RGB")
            if self.transform is not None:
                img = self.transform(img)

            imgs.append(img)
            times.append(visit["Years_bl"])
            if len(imgs) >= self.max_seq_len:
                break

        valid_visits = len(imgs)
        if valid_visits < self.MIN_VISITS:
            return None

        # Pad with blank frames and repeat the last timestamp; 'seq_len' keeps
        # the number of real visits.
        pad_len = self.max_seq_len - valid_visits
        if pad_len > 0:
            imgs.extend([torch.zeros_like(imgs[-1])] * pad_len)
            times.extend([times[-1]] * pad_len)

        return {
            'ptid': ptid,
            'imgs': torch.stack(imgs),
            'times': np.array(times, dtype=np.float32),
            'seq_len': valid_visits,
            'time_to_event': time_to_event,
            'event': event,
        }

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return (
            seq['imgs'], seq['times'], seq['seq_len'],
            seq['time_to_event'], seq['event'], seq['ptid']
        )

# ============================================================================
# MODEL
# ============================================================================

class ImageOnlyModel(nn.Module):
    """ResNet18 frame encoder + bidirectional LSTM + scalar risk head.

    Args:
        config: mapping providing 'img_out_dim', 'lstm_hidden', 'lstm_layers'
            and 'dropout'.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Image encoder (ResNet18); all but the last 20 parameter tensors are frozen.
        base = self._load_backbone()
        for param in list(base.parameters())[:-20]:
            param.requires_grad = False

        # Drop the final fully-connected layer, keep everything up to the pooling.
        self.features = nn.Sequential(*list(base.children())[:-1])
        self.img_proj = nn.Linear(512, config['img_out_dim'])

        # LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=config['img_out_dim'],
            hidden_size=config['lstm_hidden'],
            num_layers=config['lstm_layers'],
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=config['dropout'] if config['lstm_layers'] > 1 else 0,
            bidirectional=True
        )

        # Risk prediction head (input is forward + backward hidden state).
        self.risk_head = nn.Sequential(
            nn.Linear(config['lstm_hidden'] * 2, 64),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(64, 1)
        )

    @staticmethod
    def _load_backbone():
        """Return an ImageNet-pretrained ResNet18 on old and new torchvision."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision without the multi-weight API: legacy flag.
            return models.resnet18(pretrained=True)
        # IMAGENET1K_V1 is the checkpoint the legacy pretrained=True maps to.
        return models.resnet18(weights=weights_enum.IMAGENET1K_V1)

    def encode_image(self, img):
        """Encode a batch of images into 'img_out_dim'-wide feature vectors."""
        feats = self.features(img)
        feats = feats.view(feats.size(0), -1)
        return self.img_proj(feats)

    def forward(self, img_seq, seq_lengths):
        """Return one risk score per sequence, shape (batch, 1).

        Args:
            img_seq: tensor whose first two dimensions are (batch, time);
                expected to be (batch, time, 3, H, W) - TODO confirm.
            seq_lengths: number of real (non-padded) steps per sequence.
        """
        seq_len = img_seq.shape[1]

        # Encode one time step at a time. Flattening (batch, time) into a
        # single encoder call would change the batch statistics seen by the
        # BatchNorm layers in train mode, so the loop is kept deliberately.
        img_seq_feat = torch.stack(
            [self.encode_image(img_seq[:, t]) for t in range(seq_len)], dim=1
        )

        # pack_padded_sequence needs the lengths as a CPU int64 tensor.
        lengths = torch.as_tensor(seq_lengths, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            img_seq_feat, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)

        # Last two entries of h_n are the top layer's forward / backward states.
        h_final = torch.cat([h_n[-2], h_n[-1]], dim=1)

        return self.risk_head(h_final)

# ============================================================================
# LOSS & TRAINING
# ============================================================================

def cox_loss(risk_scores, times, events):
    """Negative Cox partial log-likelihood, averaged over the observed events.

    Args:
        risk_scores: predicted log-risk per sample; any shape holding one value
            per sample (e.g. ``(B,)``, ``(B, 1)`` or 0-dim for a single sample).
        times: event or censoring time per sample.
        events: 1 where the event was observed, 0 where censored.

    Returns:
        Scalar tensor. It is 0 when the batch contains no observed event.
    """
    # Flatten first so a single-sample batch that was squeezed to 0-dim can
    # still be indexed by the sort order.
    log_risk = risk_scores.reshape(-1)
    times = times.reshape(-1)
    events = events.reshape(-1)

    # Sorted by descending time, the running log-sum-exp at position i covers
    # exactly the samples placed at or before i, i.e. those still at risk.
    # Tied times are ordered however argsort returns them.
    order = torch.argsort(times, descending=True)
    log_risk = log_risk[order]
    events = events[order]

    log_risk_set = torch.logcumsumexp(log_risk, dim=0)
    per_sample = -(log_risk - log_risk_set) * events

    # Epsilon keeps the division finite for batches without any event.
    return per_sample.sum() / (events.sum() + 1e-7)

def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: module called as ``model(imgs, seq_lens)``.
        loader: iterable of ``(imgs, times, seq_lens, t_event, event, ptid)`` batches.
        optimizer: optimizer stepped once per batch.
        device: device the images and survival targets are moved to.

    Returns:
        Mean Cox loss over the batches, or 0.0 when ``loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for imgs, _times, seq_lens, t_event, event, _ptids in loader:
        imgs = imgs.to(device)
        # Lengths stay on the CPU; the model packs sequences with them.
        seq_lens = torch.as_tensor(seq_lens, dtype=torch.long)
        t_event = t_event.float().to(device)
        event = event.float().to(device)

        risk_scores = model(imgs, seq_lens)
        # view(-1) keeps a 1-D tensor even for a batch of one, where a bare
        # squeeze() would collapse it to 0-dim.
        loss = cox_loss(risk_scores.view(-1), t_event, event)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else 0.0

def validate(model, loader, device):
    """Score ``model`` on ``loader`` and return the concordance index.

    Risk scores are negated before being passed to ``concordance_index`` so
    that a higher predicted risk is treated as a shorter predicted time.
    """
    model.eval()
    risks, durations, observed = [], [], []

    with torch.no_grad():
        for imgs, _, seq_lens, t_event, event, _ in loader:
            preds = model(imgs.to(device), torch.LongTensor(seq_lens))

            risks.extend(preds.cpu().numpy().flatten())
            durations.extend(t_event.numpy())
            observed.extend(event.numpy())

    observed = np.array(observed).astype(bool)
    durations = np.array(durations)
    risks = np.array(risks)

    return concordance_index(durations, -risks, observed)

# ============================================================================
# EXPORT FOR R
# ============================================================================

def export_features(model, loader, device, output_path):
    """Write one row of baseline image features per patient to a CSV file.

    Only the first visit of every sequence is encoded (survival data only, no
    longitudinal rows). Each row holds ``PTID``, ``Years_bl`` (time of that
    first visit), ``time_to_event``, ``event`` and the ``img_feat_<k>`` columns.

    Args:
        model: object exposing ``eval()`` and ``encode_image(batch)``.
        loader: iterable of ``(imgs, times, seq_lens, t_event, event, ptids)`` batches.
        device: device the baseline images are moved to for encoding.
        output_path: destination CSV path.

    Returns:
        The exported DataFrame, sorted by ``PTID`` then ``Years_bl``.
    """
    model.eval()
    rows = []

    with torch.no_grad():
        for imgs, times, _seq_lens, t_event, event, ptids in loader:
            # Encode the baseline frame of the whole batch in one call; only
            # that frame is moved to the device. In eval mode every sample is
            # processed independently, so this matches per-sample encoding up
            # to floating-point noise.
            baseline = imgs[:, 0].to(device)
            feats = model.encode_image(baseline).cpu().numpy()

            for i, ptid in enumerate(ptids):
                row = {
                    "PTID": ptid,
                    "Years_bl": float(times[i][0]),
                    "time_to_event": float(t_event[i]),
                    "event": int(event[i])
                }
                row.update(
                    (f"img_feat_{k}", float(v)) for k, v in enumerate(feats[i])
                )
                rows.append(row)

    df = pd.DataFrame(rows).sort_values(['PTID', 'Years_bl'])
    df.to_csv(output_path, index=False)

    n_feats = sum(1 for c in df.columns if c.startswith('img_feat'))
    print(f"\n✓ Exported to {output_path}")
    print(f"  Patients: {df['PTID'].nunique()}, Features: {n_feats}")

    return df

# ============================================================================
# MAIN
# ============================================================================

def main():
    """Train the image-only benchmark, keep the best checkpoint, export features.

    Steps: load the manifest and the pickled list of valid patients, build the
    dataset, split it 80/20 with a fixed seed, train with early stopping on the
    validation C-index, then reload the best weights and export baseline
    features for the full dataset.
    """
    banner = "=" * 80
    print(banner)
    print("BENCHMARK 2: IMAGE-ONLY MODEL")
    print(banner)

    device = CONFIG['device']
    print(f"\nDevice: {device}")

    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")

    # Load valid patients
    print("\nLoading valid patient list...")
    with open('VALID_PATIENTS.pkl', 'rb') as handle:
        valid_patients = pickle.load(handle)
    print(f"Valid patients to process: {len(valid_patients)}")

    # Resize and normalise with the ImageNet channel statistics.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    dataset = ImageOnlyDataset(
        manifest, valid_patients, preprocess, max_seq_len=CONFIG['max_seq_len']
    )
    print(f"Total sequences: {len(dataset)}")

    # 80/20 split, reproducible through the seeded generator.
    train_size = int(0.8 * len(dataset))
    split_sizes = [train_size, len(dataset) - train_size]
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, split_sizes, generator=torch.Generator().manual_seed(42)
    )

    loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=0)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = ImageOnlyModel(CONFIG).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay']
    )
    # Halve the learning rate when the validation C-index stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )

    print("\nTraining...")
    best_c_index = 0
    stale_epochs = 0

    for epoch_no in range(1, CONFIG['epochs'] + 1):
        epoch_loss = train_epoch(model, train_loader, optimizer, device)
        c_idx = validate(model, val_loader, device)

        scheduler.step(c_idx)

        print(f"Epoch {epoch_no}/{CONFIG['epochs']} - Loss: {epoch_loss:.4f}, C-index: {c_idx:.4f}")

        if c_idx > best_c_index:
            # New best: checkpoint and reset the early-stopping counter.
            best_c_index = c_idx
            torch.save(model.state_dict(), 'image_only_model.pth')
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= 15:
            print("Early stopping")
            break

    # Export: reload the best checkpoint and encode every patient, unshuffled.
    model.load_state_dict(torch.load('image_only_model.pth'))
    full_loader = DataLoader(dataset, shuffle=False, **loader_kwargs)

    export_features(model, full_loader, device, "image_only_features.csv")

    print("\n" + banner)
    print(f"✓ BEST C-INDEX: {best_c_index:.4f}")
    print(banner)

if __name__ == "__main__":
    main()

BENCHMARK 2: IMAGE-ONLY MODEL

Device: cuda

Loading valid patient list...
Valid patients to process: 161
  Processed: 161 valid patients
  Skipped (not in valid set): 221
Total sequences: 161

Training...
Epoch 1/100 - Loss: 2.3631, C-index: 0.4829
Epoch 2/100 - Loss: 2.3223, C-index: 0.5265
Epoch 3/100 - Loss: 2.1861, C-index: 0.5701
Epoch 4/100 - Loss: 1.8151, C-index: 0.5919
Epoch 5/100 - Loss: 1.4998, C-index: 0.6231
Epoch 6/100 - Loss: 1.5734, C-index: 0.6137
Epoch 7/100 - Loss: 1.2856, C-index: 0.4766
Epoch 8/100 - Loss: 1.2950, C-index: 0.4798
Epoch 9/100 - Loss: 1.2088, C-index: 0.4766
Epoch 10/100 - Loss: 1.2103, C-index: 0.5514
Epoch 11/100 - Loss: 1.4386, C-index: 0.5607
Epoch 12/100 - Loss: 1.1520, C-index: 0.5202
Epoch 13/100 - Loss: 1.0775, C-index: 0.5109
Epoch 14/100 - Loss: 0.8756, C-index: 0.5826
Epoch 15/100 - Loss: 0.9555, C-index: 0.5763
Epoch 16/100 - Loss: 1.0185, C-index: 0.5140
Epoch 17/100 - Loss: 1.0639, C-index: 0.5327
Epoch 18/100 - Loss: 0.8031, C-index: 