# ResNetHE - Metastasis Classification from Histopathology Images

This notebook trains and evaluates a DenseNet-based model (not a ResNet, despite the notebook title) for binary classification of metastasis in histopathology image tiles.

## Pipeline Overview
1. Load and preprocess image tiles with corresponding labels
2. Split data by patient (no patient overlap between train/val)
3. Train DenseNet with transfer learning
4. Evaluate on validation set (patch-level and patient-level)
5. Generate predictions and aggregate by patient

## 1. Imports

In [None]:
import os
import re
import collections

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    roc_auc_score,
    average_precision_score
)

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset

# Torchvision
import torchvision
from torchvision import datasets, transforms, models

## 2. Configuration

In [None]:
# --- Reproducibility ---------------------------------------------------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# --- Device ------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# --- Paths (relative to the working directory) -------------------------------
IMAGE_DIR = "Images/2048Tiles"
CSV_PATH = "CSV/Image_Data.csv"

# --- Hyperparameters ---------------------------------------------------------
BATCH_SIZE = 16
IMG_HEIGHT = IMG_WIDTH = 256
NUM_CLASSES = 2
# Class names are the string forms of the integer class ids: ['0', '1'].
CLASS_NAMES = [str(class_id) for class_id in range(NUM_CLASSES)]

## 3. Dataset Class

In [None]:
class MetastasisDataset(Dataset):
    """
    Dataset of histopathology image tiles labelled with patient metastasis status.

    Labels come from a CSV with a ``Code`` (patient code) column and a
    ``Metastasis`` column whose values must parse as ``int``. Hyphens and
    underscores are removed from both codes and image filenames, and an image
    is assigned to a patient when the normalised code occurs as a substring of
    the normalised filename. When several codes match, the longest one wins, so
    a code that is a prefix of another (e.g. ``P1`` vs ``P12``) cannot capture
    the other patient's images.

    Each item is ``(image, label, patient_id)``, where ``patient_id`` is the
    normalised CSV code.
    """

    # Filenames that matched no patient code. Class-level on purpose: it
    # accumulates across every instance created in the session (debug aid).
    unmatched_labels = []

    def __init__(self, image_dir, csv_path, transform=None):
        """
        Args:
            image_dir (str): Directory whose entries are the image tiles.
            csv_path (str): CSV file with ``Code`` and ``Metastasis`` columns.
            transform (callable, optional): Applied to each loaded PIL image.
        """
        self.image_dir = image_dir
        self.transform = transform

        # Load labels, normalising codes the same way filenames are normalised below.
        df = pd.read_csv(csv_path)
        df['Code'] = df['Code'].astype(str).str.replace(r'[-_]', '', regex=True)
        self.label_map = dict(zip(df['Code'], df['Metastasis'].astype(str)))

        # Longest codes first (sorted() is stable, so equal-length codes keep
        # CSV order). Empty codes are skipped: "" is a substring of every name.
        candidate_codes = sorted(
            (code for code in self.label_map if code), key=len, reverse=True
        )

        self.image_paths = []
        self.labels = []
        self.patient_ids = []
        self.unmatched_files = []  # unmatched filenames for this instance only

        # Sorted so sample order does not depend on filesystem listing order.
        for fname in sorted(os.listdir(image_dir)):
            fname_norm = fname.replace("-", "").replace("_", "")
            pid = next((code for code in candidate_codes if code in fname_norm), None)

            if pid is None:
                self.unmatched_files.append(fname)
                MetastasisDataset.unmatched_labels.append(fname)
                continue

            self.image_paths.append(os.path.join(image_dir, fname))
            self.labels.append(int(self.label_map[pid]))
            self.patient_ids.append(pid)

        print(f'Number of images not matched: {len(self.unmatched_files)}')
        print(f"Loaded {len(self.image_paths)} labeled images.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load tile *idx* as RGB, apply the transform, return (image, label, patient_id)."""
        # Context manager closes the underlying file handle promptly.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.patient_ids[idx]

## 4. Data Transforms

In [None]:
# ImageNet channel statistics, matching the pretrained DenseNet weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transform(*augmentations):
    """Resize to (IMG_HEIGHT, IMG_WIDTH), apply *augmentations*, then tensorise and normalise."""
    return transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


# Training: random horizontal flips as the only augmentation.
train_transform = _build_transform(transforms.RandomHorizontalFlip())
# Validation / inference: deterministic preprocessing only.
val_transform = _build_transform()

## 5. Data Loading with Patient-Level Split

In [None]:
def _restrict_dataset(base_dataset, rows, transform):
    """Return a shallow copy of *base_dataset* holding only *rows* and using *transform*."""
    import copy

    subset = copy.copy(base_dataset)
    subset.transform = transform
    subset.image_paths = rows["image_path"].tolist()
    subset.labels = rows["label"].tolist()
    subset.patient_ids = rows["patient_id"].tolist()
    return subset


def create_patient_split_loaders(image_dir, csv_path, batch_size, test_size=0.2, random_state=42):
    """
    Create train/val DataLoaders with patient-level splitting.

    Ensures no patient appears in both training and validation sets,
    which is critical for valid evaluation in medical imaging. The image
    directory is scanned once; the train and validation datasets are
    shallow copies of that scan restricted to their patients.

    Args:
        image_dir (str): Directory of image tiles.
        csv_path (str): Label CSV (see MetastasisDataset).
        batch_size (int): Batch size for both loaders.
        test_size (float): Fraction of *patients* placed in validation.
        random_state (int): Seed for the patient split.

    Returns:
        tuple: (train_loader, val_loader, train_dataset, val_dataset).

    Raises:
        ValueError: if no image could be matched to a labelled patient.
    """
    # Single directory scan shared by both splits.
    full_dataset = MetastasisDataset(
        image_dir=image_dir,
        csv_path=csv_path,
        transform=val_transform
    )
    if len(full_dataset) == 0:
        raise ValueError(
            f"No labelled images found in {image_dir!r} for the codes in {csv_path!r}"
        )

    df = pd.DataFrame({
        "image_path": full_dataset.image_paths,
        "label": full_dataset.labels,
        "patient_id": full_dataset.patient_ids
    })

    # Patient-level label: max => any positive patch makes the patient positive.
    patient_labels = df.groupby("patient_id")["label"].max().reset_index()

    # Split patients (not patches), stratified by patient label.
    train_patients, val_patients = train_test_split(
        patient_labels["patient_id"],
        test_size=test_size,
        stratify=patient_labels["label"],
        random_state=random_state
    )

    train_df = df[df["patient_id"].isin(train_patients)]
    val_df = df[df["patient_id"].isin(val_patients)]

    # Sanity checks
    print("\nPatient-level split summary:")
    print(f"  Train patients: {train_df['patient_id'].nunique()}")
    print(f"  Val patients: {val_df['patient_id'].nunique()}")
    print(f"  Overlap: {set(train_df['patient_id']) & set(val_df['patient_id'])}")
    print(f"\nTrain label distribution:\n{train_df['label'].value_counts().to_string()}")
    print(f"\nVal label distribution:\n{val_df['label'].value_counts().to_string()}")

    train_dataset = _restrict_dataset(full_dataset, train_df, train_transform)
    val_dataset = _restrict_dataset(full_dataset, val_df, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, train_dataset, val_dataset


def check_class_balance(train_dataset, val_dataset, tolerance=0.03):
    """
    Print class-balance statistics for the train and validation sets.

    Args:
        train_dataset, val_dataset: Objects exposing a ``labels`` sequence of 0/1 ints.
        tolerance (float): Maximum absolute difference in positive ratio that
            still counts as "balanced".

    Returns:
        tuple: (train_ratio, val_ratio), each the fraction of label-1 samples
        among label-0/1 samples, or NaN for a split with no such samples.
    """
    def _positive_ratio(counts):
        n = counts[0] + counts[1]
        return counts[1] / n if n else float("nan")

    train_counts = collections.Counter(train_dataset.labels)
    val_counts = collections.Counter(val_dataset.labels)

    train_ratio = _positive_ratio(train_counts)
    val_ratio = _positive_ratio(val_counts)

    print("\nClass Balance Analysis:")
    print(f"  Train - Metastasis: {train_counts[1]}, No Metastasis: {train_counts[0]}")
    print(f"  Train - Metastasis ratio: {train_ratio:.3f}")
    print(f"  Val - Metastasis: {val_counts[1]}, No Metastasis: {val_counts[0]}")
    print(f"  Val - Metastasis ratio: {val_ratio:.3f}")

    # NaN compares False, so an empty split falls through to the warning.
    if abs(train_ratio - val_ratio) < tolerance:
        print("  ✓ Training and validation sets are balanced")
    else:
        print("  ⚠ Training and validation sets have different class distributions")

    return train_ratio, val_ratio

## 6. Model Configuration

In [None]:
def densenet_option(num=201):
    """
    Build an ImageNet-pretrained torchvision DenseNet.

    Args:
        num (int): Depth variant, one of 121, 169 or 201.

    Returns:
        tuple: (model, display_name), e.g. (DenseNet-201 model, "DenseNet201").

    Raises:
        ValueError: if *num* is not a supported variant.
    """
    supported = {121: "DenseNet121", 169: "DenseNet169", 201: "DenseNet201"}
    if num not in supported:
        raise ValueError(f"Invalid DenseNet variant: {num}. Choose from {list(supported)}")

    display_name = supported[num]
    # torchvision naming: builder `densenetNNN`, weights enum `DenseNetNNN_Weights`.
    builder = getattr(models, display_name.lower())
    default_weights = getattr(models, f"{display_name}_Weights").DEFAULT
    return builder(weights=default_weights), display_name


def freeze_all_except(model, train_blocks=None, train_classifier=True):
    """
    Freeze every parameter, then re-enable gradients for selected parts.

    Args:
        model: DenseNet-style module with ``classifier`` and ``features``
            attributes; blocks are looked up as ``features.denseblock<N>``.
        train_blocks (iterable of int, optional): Dense-block numbers to train
            (e.g. [4] or [3, 4]). None trains no blocks.
        train_classifier (bool): Whether the classifier head stays trainable.

    Missing blocks are reported with a printed warning rather than an error.
    """
    model.requires_grad_(False)

    if train_classifier:
        model.classifier.requires_grad_(True)

    blocks_to_train = () if train_blocks is None else train_blocks
    for block_num in blocks_to_train:
        block_name = f"denseblock{block_num}"
        block = getattr(model.features, block_name, None)
        if block is None:
            print(f"[Warning] {block_name} not found in model")
            continue
        block.requires_grad_(True)


def create_model(densenet_variant=201, train_blocks=(4,), num_classes=NUM_CLASSES,
                 hidden_dim=256, dropout=0.5):
    """
    Create and configure a DenseNet model for metastasis classification.

    The pretrained classifier is replaced by
    Linear(in_features, hidden_dim) -> ReLU -> Dropout(dropout) -> Linear(hidden_dim, num_classes),
    then everything except the head and *train_blocks* is frozen.

    Args:
        densenet_variant (int): 121, 169 or 201.
        train_blocks (iterable of int or None): Dense blocks left trainable;
            None or an empty sequence trains the classifier only. Defaults to
            an immutable tuple (a list default would be shared across calls).
        num_classes (int): Number of output logits.
        hidden_dim (int): Width of the hidden layer in the new head.
        dropout (float): Dropout probability in the new head.

    Returns:
        tuple: (model moved to ``device``, model display name).
    """
    model, model_name = densenet_option(densenet_variant)

    in_features = model.classifier.in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes)
    )

    freeze_all_except(model, train_blocks=train_blocks, train_classifier=True)

    model.to(device)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    shown_blocks = None if train_blocks is None else list(train_blocks)
    print(f"\n{model_name} configured:")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Training blocks: {shown_blocks}")

    return model, model_name

## 7. Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over *dataloader*.

    Batches are ``(inputs, labels, patient_ids)`` triples (as yielded by
    MetastasisDataset); patient ids are ignored here.

    Returns:
        tuple: (mean loss per sample, accuracy in percent).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels, _ in tqdm(dataloader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes a mean-reduced criterion (CrossEntropyLoss default): scale back
        # to a batch sum so a short final batch doesn't skew the epoch mean.
        loss_sum += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_epoch() received an empty dataloader")

    return loss_sum / total, 100.0 * correct / total


def evaluate(model, dataloader, criterion, device):
    """
    Evaluate *model* on *dataloader* without updating weights.

    Batches are ``(inputs, labels, patient_ids)`` triples. The positive-class
    score used for AUROC/AUPRC is the softmax probability of class 1.

    Returns:
        tuple: (mean loss per sample, accuracy %, confusion matrix over all
        model classes, classification-report string, AUROC, AUPRC).
        AUROC and AUPRC are NaN when the labels contain a single class
        (e.g. a small stain subset), where they are undefined.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    num_classes = 0
    all_labels, all_preds, all_probs = [], [], []

    with torch.no_grad():
        for inputs, labels, _ in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Assumes a mean-reduced criterion; weight by batch size for a per-sample mean.
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()
            num_classes = outputs.size(1)

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())
            all_probs.extend(F.softmax(outputs, dim=1)[:, 1].cpu().tolist())

    if total == 0:
        raise ValueError("evaluate() received an empty dataloader")

    # Fixed label set keeps the confusion matrix square even if a class is absent.
    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    report = classification_report(
        all_labels, all_preds, labels=class_ids, digits=4, zero_division=0
    )

    if len(set(all_labels)) < 2:
        auroc = auprc = float("nan")
    else:
        auroc = roc_auc_score(all_labels, all_probs)
        auprc = average_precision_score(all_labels, all_probs)

    return loss_sum / total, 100.0 * correct / total, cm, report, auroc, auprc


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, step_size=5, gamma=0.1):
    """
    Train *model* with Adam + StepLR, evaluating on *val_loader* after every epoch.

    Only parameters with ``requires_grad=True`` are optimised.

    Args:
        model: Network to train; moved to the global ``device``.
        train_loader, val_loader: DataLoaders yielding (inputs, labels, patient_ids).
        num_epochs (int): Number of epochs.
        lr (float): Initial Adam learning rate.
        step_size (int): StepLR period in epochs.
        gamma (float): StepLR decay factor.

    Returns:
        dict: per-epoch lists under 'train_loss', 'train_acc', 'val_loss',
        'val_acc', 'val_auroc' and 'val_auprc'.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    history = {
        key: [] for key in
        ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'val_auroc', 'val_auprc')
    }

    for epoch in range(1, num_epochs + 1):
        print(f'\nEpoch {epoch}/{num_epochs}')
        print('-' * 30)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, cm, _report, val_auroc, val_auprc = evaluate(
            model, val_loader, criterion, device
        )

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auroc'].append(val_auroc)
        history['val_auprc'].append(val_auprc)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Val AUROC: {val_auroc:.4f}, Val AUPRC: {val_auprc:.4f}')
        print(f"Confusion Matrix:\n{cm}")

        scheduler.step()

    return history


def evaluate_test_set(model, test_loader, device):
    """
    Run a final evaluation with cross-entropy loss and print a full report.

    Returns:
        tuple: (loss, accuracy %, AUROC, AUPRC) as computed by ``evaluate``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("FINAL TEST SET EVALUATION")
    print(banner)

    results = evaluate(model, test_loader, nn.CrossEntropyLoss(), device)
    loss, accuracy, matrix, report_text, auroc, auprc = results

    print(f'\nTest Loss: {loss:.4f}')
    print(f'Test Accuracy: {accuracy:.2f}%')
    print(f'Test AUROC: {auroc:.4f}')
    print(f'Test AUPRC: {auprc:.4f}')
    print(f"\nConfusion Matrix:\n{matrix}")
    print(f"\nClassification Report:\n{report_text}")

    return loss, accuracy, auroc, auprc


def plot_training_history(history, title="Training History"):
    """
    Plot train/validation loss and accuracy curves side by side.

    Args:
        history (dict): Lists under 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' (as returned by train_model).
        title (str): Prefix for both subplot titles.

    Returns:
        matplotlib.figure.Figure: The figure, so callers can save it.
    """
    def _epochs(series):
        # 1-based epoch numbers, matching the "Epoch k/N" lines printed in training.
        return range(1, len(series) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (ax_loss, 'loss', 'Loss', 'Loss', 'Loss'),
        (ax_acc, 'acc', 'Acc', 'Accuracy (%)', 'Accuracy'),
    )
    for ax, key, short_name, ylabel, long_name in panels:
        train_series = history[f'train_{key}']
        val_series = history[f'val_{key}']
        ax.plot(_epochs(train_series), train_series, label=f'Train {short_name}')
        ax.plot(_epochs(val_series), val_series, label=f'Val {short_name}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'{title} - {long_name}')
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()
    return fig

## 8. Stain-Specific Evaluation

In [None]:
def get_stain_loader(original_loader, stain_name):
    """
    Create a DataLoader containing only images of a specified stain type.

    A file matches when *stain_name* occurs (case-insensitively) anywhere in
    its basename. This is a plain substring test, so short names such as
    'HE' may also match unrelated characters; check the printed counts.

    Args:
        original_loader: Existing DataLoader over a dataset exposing
            ``image_paths``, or over a ``torch.utils.data.Subset`` of one.
        stain_name (str): Stain type to filter for (e.g. 'HE', 'MITF').

    Returns:
        DataLoader: Unshuffled loader over a Subset of the underlying
        dataset, with the original loader's batch size.
    """
    dataset = original_loader.dataset

    # Resolve one level of Subset so the new Subset indexes the base dataset
    # directly (keeps compatibility with save_patch_predictions_to_csv).
    if isinstance(dataset, Subset):
        base_dataset = dataset.dataset
        candidate_indices = list(dataset.indices)
    else:
        base_dataset = dataset
        candidate_indices = range(len(base_dataset.image_paths))

    needle = stain_name.upper()
    stain_indices = [
        idx for idx in candidate_indices
        if needle in os.path.basename(base_dataset.image_paths[idx]).upper()
    ]

    print(f"{stain_name}: {len(stain_indices)} images found")

    return DataLoader(
        Subset(base_dataset, stain_indices),
        batch_size=original_loader.batch_size,
        shuffle=False
    )

## 9. Prediction Export and Patient Aggregation

In [None]:
def save_patch_predictions_to_csv(model, dataloader, device, output_csv='patch_level_predictions.csv'):
    """
    Save individual patch predictions to CSV file.

    Handles both regular Datasets exposing ``image_paths`` and Subsets of one
    (for stain-specific evaluation). The data is re-iterated in a fixed
    sequential order (same batch size, ``shuffle=False``), so each prediction
    is paired with the correct filename even if *dataloader* shuffles.

    Returns:
        pd.DataFrame: Columns filename, true_label, predicted_class,
        prob_no_metastasis, prob_metastasis, confidence, correct.
    """
    model.eval()

    dataset = dataloader.dataset
    if hasattr(dataset, 'dataset'):
        # Subset: position k maps to base-dataset index indices[k].
        base_dataset = dataset.dataset
        indices = dataset.indices
    else:
        base_dataset = dataset
        indices = range(len(dataset))

    # Sequential loader: the k-th sample seen is exactly dataset[k].
    ordered_loader = DataLoader(
        dataset,
        batch_size=dataloader.batch_size or 1,
        shuffle=False,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
    )

    rows = []
    offset = 0
    print("Computing patch-level predictions...")
    with torch.no_grad():
        for inputs, labels, _ in tqdm(ordered_loader, desc="Processing patches"):
            logits = model(inputs.to(device))
            # Move results to CPU once per batch rather than syncing per element.
            probs = F.softmax(logits, dim=1).cpu()
            predicted = logits.argmax(dim=1).cpu()
            labels = labels.cpu()

            for i in range(labels.size(0)):
                pred = predicted[i].item()
                true = labels[i].item()
                img_path = base_dataset.image_paths[indices[offset + i]]
                rows.append({
                    'filename': os.path.basename(img_path),
                    'true_label': true,
                    'predicted_class': pred,
                    'prob_no_metastasis': probs[i, 0].item(),
                    'prob_metastasis': probs[i, 1].item(),
                    'confidence': probs[i, pred].item(),
                    'correct': int(pred == true)
                })
            offset += labels.size(0)

    df = pd.DataFrame(rows)
    df.to_csv(output_csv, index=False)

    print(f"\n✓ Saved {len(df)} patch predictions to: {output_csv}")
    print(f"\nCSV Preview:\n{df.head(10).to_string(index=False)}")

    return df


def aggregate_by_patient_robust(df, trim_chars=3, case_sensitive=False):
    """
    Aggregate patch-level predictions to patient-level using mean probability pooling.

    The patient key is the filename with its last *trim_chars* characters
    removed (filenames no longer than that are kept whole), stripped of
    surrounding whitespace and, unless *case_sensitive*, lower-cased.

    Args:
        df (pd.DataFrame): Patch rows with 'filename', 'true_label',
            'prob_no_metastasis' and 'prob_metastasis'. Not modified.
        trim_chars (int): Trailing characters to drop; <= 0 keeps the full name.
        case_sensitive (bool): Keep patient-key case when True.

    Returns:
        pd.DataFrame: One row per patient with columns patient_id, true_label
        (most frequent patch label), prob_no_metastasis, prob_metastasis
        (patch means), num_patches, sample_filename (first filename seen),
        predicted_class, confidence, correct.
    """
    result_columns = [
        'patient_id', 'true_label', 'prob_no_metastasis', 'prob_metastasis',
        'num_patches', 'sample_filename', 'predicted_class', 'confidence', 'correct'
    ]

    # Drop rows without a usable filename. .copy() makes this an independent
    # frame so the column assignments below don't hit SettingWithCopyWarning.
    df = df.loc[df['filename'].notna() & (df['filename'] != '')].copy()
    if df.empty:
        print("⚠️  WARNING: no rows with a usable filename; nothing to aggregate")
        return pd.DataFrame(columns=result_columns)

    filenames = df['filename'].astype(str)

    if trim_chars > 0:
        too_short = filenames.str.len() <= trim_chars
        if too_short.any():
            print(f"⚠️  WARNING: {too_short.sum()} filenames are too short, skipping trim")
        # Trim where possible; keep too-short names whole.
        df['patient_id'] = filenames.str[:-trim_chars].where(~too_short, filenames)
    else:
        df['patient_id'] = filenames

    df['patient_id'] = df['patient_id'].str.strip()
    if not case_sensitive:
        df['patient_id'] = df['patient_id'].str.lower()

    # Every patch of a patient should carry the same label.
    label_check = df.groupby('patient_id')['true_label'].nunique()
    inconsistent = label_check[label_check > 1]
    if len(inconsistent) > 0:
        print(f"⚠️  WARNING: {len(inconsistent)} patients have inconsistent labels!")
        for pid in inconsistent.index:
            labels = df.loc[df['patient_id'] == pid, 'true_label'].unique()
            print(f"   {pid}: {labels}")

    aggregated = df.groupby('patient_id').agg({
        'true_label': lambda x: x.mode()[0] if len(x.mode()) > 0 else x.iloc[0],
        'prob_no_metastasis': 'mean',
        'prob_metastasis': 'mean',
        'filename': ['count', 'first']
    }).reset_index()

    aggregated.columns = [
        'patient_id', 'true_label', 'prob_no_metastasis',
        'prob_metastasis', 'num_patches', 'sample_filename'
    ]

    aggregated['predicted_class'] = (
        aggregated['prob_metastasis'] > aggregated['prob_no_metastasis']
    ).astype(int)
    aggregated['confidence'] = aggregated[['prob_no_metastasis', 'prob_metastasis']].max(axis=1)
    aggregated['correct'] = (aggregated['predicted_class'] == aggregated['true_label']).astype(int)

    print(f"\n✓ Aggregated {len(df)} patches into {len(aggregated)} patients")
    print(f"  Avg patches/patient: {aggregated['num_patches'].mean():.1f}")
    print(f"  Patient-level accuracy: {aggregated['correct'].mean() * 100:.2f}%")

    return aggregated


def aggregate_by_patient_comprehensive(df, patient_id_length=-8):
    """
    Aggregate predictions by patient with comprehensive metrics and reporting.

    Patient probabilities are the means over that patient's patches; the
    patient prediction is the class with the larger mean probability.

    Args:
        df (pd.DataFrame): Patch rows with 'filename', 'true_label',
            'predicted_class', 'prob_no_metastasis' and 'prob_metastasis'
            (as written by save_patch_predictions_to_csv). Not modified.
        patient_id_length (int): Slice stop applied to each filename
            (``filename[:patient_id_length]``): positive keeps that many leading
            characters, negative drops that many trailing ones.

    Returns:
        pd.DataFrame: One row per patient with columns patient_id, num_patches,
        true_label (most frequent patch label), predicted_class,
        prob_no_metastasis, prob_metastasis, confidence, correct.
    """
    # Work on a copy so the caller's patch-level frame is not modified.
    df = df.copy()
    df['patient_id'] = df['filename'].str[:patient_id_length]

    # Every patch of a patient should carry the same label.
    label_check = df.groupby('patient_id')['true_label'].nunique()
    inconsistent = label_check[label_check > 1]
    if len(inconsistent) > 0:
        print(f"WARNING: {len(inconsistent)} patients have inconsistent labels!")
        print(inconsistent)

    aggregated = df.groupby('patient_id').agg({
        'true_label': lambda x: x.mode()[0] if len(x.mode()) > 0 else x.iloc[0],
        'prob_no_metastasis': 'mean',
        'prob_metastasis': 'mean',
        'predicted_class': 'count'  # patch count per patient
    }).reset_index()

    aggregated.rename(columns={'predicted_class': 'num_patches'}, inplace=True)
    aggregated['predicted_class'] = (
        aggregated['prob_metastasis'] > aggregated['prob_no_metastasis']
    ).astype(int)
    aggregated['confidence'] = aggregated[['prob_no_metastasis', 'prob_metastasis']].max(axis=1)
    aggregated['correct'] = (aggregated['predicted_class'] == aggregated['true_label']).astype(int)

    aggregated = aggregated[[
        'patient_id', 'num_patches', 'true_label', 'predicted_class',
        'prob_no_metastasis', 'prob_metastasis', 'confidence', 'correct'
    ]]

    binary_labels = [0, 1]
    accuracy = aggregated['correct'].mean() * 100
    # AUROC is undefined when only one class is present among patients.
    if aggregated['true_label'].nunique() < 2:
        auroc = float('nan')
    else:
        auroc = roc_auc_score(aggregated['true_label'], aggregated['prob_metastasis'])
    # Fixed labels keep the matrix 2x2 so the indexed printout below is safe.
    cm = confusion_matrix(
        aggregated['true_label'], aggregated['predicted_class'], labels=binary_labels
    )

    print("=" * 70)
    print("PATIENT-LEVEL AGGREGATION RESULTS")
    print("=" * 70)
    print(f"Total patches: {len(df)}")
    print(f"Total patients: {len(aggregated)}")
    print(f"Avg patches/patient: {aggregated['num_patches'].mean():.1f} "
          f"(min: {aggregated['num_patches'].min()}, max: {aggregated['num_patches'].max()})")
    print(f"\nPatient-level accuracy: {accuracy:.2f}%")
    print(f"Patient-level AUROC: {auroc:.4f}")
    print("\nConfusion Matrix:")
    print("                 Predicted")
    print("               No Met  Metastasis")
    print(f"True No Met      {cm[0,0]:4d}    {cm[0,1]:4d}")
    print(f"True Metastasis  {cm[1,0]:4d}    {cm[1,1]:4d}")
    print("\nClassification Report:")
    print(classification_report(
        aggregated['true_label'],
        aggregated['predicted_class'],
        labels=binary_labels,
        target_names=['No Metastasis', 'Metastasis'],
        digits=4,
        zero_division=0
    ))

    return aggregated

## 10. Load Data

In [None]:
# Build patient-disjoint train/validation loaders, then compare their label mix.
loaders_and_datasets = create_patient_split_loaders(IMAGE_DIR, CSV_PATH, batch_size=BATCH_SIZE)
train_loader, val_loader, train_dataset, val_dataset = loaders_and_datasets

check_class_balance(train_dataset, val_dataset)

## 11. Create Model

In [None]:
DENSENET_VARIANT = 201   # one of 121, 169, 201
TRAIN_BLOCKS = [4]       # e.g. [4], [3, 4], [2, 3, 4], or [] for classifier only

model, model_name = create_model(densenet_variant=DENSENET_VARIANT, train_blocks=TRAIN_BLOCKS)

## 12. Train Model

In [None]:
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

# Fit the model, then plot per-epoch loss/accuracy for both splits.
history = train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE)
plot_training_history(history, title=model_name)

## 13. Final Evaluation (on the validation set — no separate held-out test split exists)

In [None]:
# NOTE(review): there is no separate held-out test split; this reuses the
# validation loader monitored during training, so results may be optimistic
# if hyperparameters were tuned against it.
test_loss, test_acc, test_auroc, test_auprc = evaluate_test_set(
    model=model, test_loader=val_loader, device=device
)

## 14. Stain-Specific Evaluation (Optional)

In [None]:
# Stain-specific views of the validation set (built in this order, so the
# per-stain counts print in the same sequence).
stain_loaders = {
    stain: get_stain_loader(val_loader, stain)
    for stain in ('HE', 'MITF', 'ANX', 'BCL2', 'BCL3', 'PBP', 'PIR')
}
val_loader_he = stain_loaders['HE']
val_loader_mitf = stain_loaders['MITF']
val_loader_anx = stain_loaders['ANX']
val_loader_bcl2 = stain_loaders['BCL2']
val_loader_bcl3 = stain_loaders['BCL3']
val_loader_pbp = stain_loaders['PBP']
val_loader_pir = stain_loaders['PIR']

# Evaluate on a single stain (example: ANX).
test_loss, test_acc, test_auroc, test_auprc = evaluate_test_set(model, val_loader_anx, device)

## 15. Save Predictions and Aggregate by Patient

In [None]:
PATCH_CSV = 'patch_level_predictions.csv'
PATIENT_CSV = 'patient_level_predictions.csv'

# One CSV row per validation patch.
patch_df = save_patch_predictions_to_csv(
    model=model, dataloader=val_loader, device=device, output_csv=PATCH_CSV
)

# Patient key = first 11 filename characters (filename-format specific — confirm
# against the tile naming scheme).
patient_df = aggregate_by_patient_comprehensive(patch_df, patient_id_length=11)
patient_df.to_csv(PATIENT_CSV, index=False)
print("✓ Saved patient-level predictions")