# ResNet Tests on Imagery

In [47]:
import os
import re
import zipfile
import urllib.request
import collections

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, average_precision_score


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets, transforms, models


# Data Loader

In [78]:

np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

image_dir = "Images/2048Tiles"
csv_path = "CSV/Image_Data.csv"

batch_size = 16
img_height = 256
img_width = 256
num_classes = 2
class_names = ['0','1']

Using device: cuda


In [79]:
unmatched_labels = []
class MetastasisDataset(Dataset):
    def __init__(self, image_dir, csv_path, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        df = pd.read_csv(csv_path)

        # Normalize codes (remove dashes/underscores)
        df['Code'] = df['Code'].astype(str).str.replace(r'[-_]', '', regex=True)
        self.label_map = dict(zip(df['Code'], df['Metastasis'].astype(str)))

        self.image_paths = []
        self.labels = []
        self.patient_ids = []
        unmatched_images = 0

        for fname in os.listdir(image_dir):
            # Normalize filename (remove dashes/underscores)
            fname_norm = fname.replace("-", "").replace("_", "")
            matched = False
            for pid in self.label_map:
                pid_norm = pid.replace("-", "").replace("_", "")
                if pid_norm in fname_norm:
                    self.image_paths.append(os.path.join(image_dir, fname))
                    self.labels.append(int(self.label_map[pid]))
                    self.patient_ids.append(pid)
                    matched = True
                    break
            if not matched:
                unmatched_images+=1
                unmatched_labels.append(fname)
        print(f'Number of images that were not matched = {unmatched_images}')
        print(f"Loaded {len(self.image_paths)} labeled images.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        patient_id = self.patient_ids[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label, patient_id

# Transforms & Data Loader

In [80]:
train_transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


full_dataset = MetastasisDataset(image_dir=image_dir, csv_path=csv_path, transform=val_transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_paths, val_paths, train_labels, val_labels, train_ids, val_ids = train_test_split(
    full_dataset.image_paths,
    full_dataset.labels,
    full_dataset.patient_ids,
    test_size=0.2,
    stratify=full_dataset.labels,
    random_state=42
)

# Patient ID workaround begin code:

df = pd.DataFrame({
    "image_path": full_dataset.image_paths,
    "label": full_dataset.labels,
    "patient_id": full_dataset.patient_ids
})

patient_labels = df.groupby("patient_id")["label"].max().reset_index()

train_patients, val_patients = train_test_split(
    patient_labels["patient_id"],
    test_size=0.2,
    stratify=patient_labels["label"],  # stratify by patient-level label
    random_state=42
)

train_df = df[df["patient_id"].isin(train_patients)]
val_df   = df[df["patient_id"].isin(val_patients)]

# Sanity check
print("Train patients:", train_df["patient_id"].nunique())
print("Val patients:", val_df["patient_id"].nunique())
print("Overlap:", set(train_df["patient_id"]) & set(val_df["patient_id"]))  # should be empty
print("Train label counts:\n", train_df["label"].value_counts())
print("Val label counts:\n", val_df["label"].value_counts())

### Old code

'''
# Wrap back into Dataset objects
train_dataset = MetastasisDataset(image_dir=image_dir, csv_path=csv_path, transform=train_transform)
train_dataset.image_paths, train_dataset.labels, train_dataset.patient_ids = train_paths, train_labels, train_ids

val_dataset = MetastasisDataset(image_dir=image_dir, csv_path=csv_path, transform=val_transform)
val_dataset.image_paths, val_dataset.labels, val_dataset.patient_ids = val_paths, val_labels, val_ids
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Train dataset size: {len(train_dataset)}, Val dataset size: {len(val_dataset)}")
'''

# New Code
train_dataset = MetastasisDataset(image_dir=image_dir, csv_path=csv_path, transform=train_transform)
train_dataset.image_paths  = train_df["image_path"].tolist()
train_dataset.labels       = train_df["label"].tolist()
train_dataset.patient_ids  = train_df["patient_id"].tolist()

val_dataset = MetastasisDataset(image_dir=image_dir, csv_path=csv_path, transform=val_transform)
val_dataset.image_paths    = val_df["image_path"].tolist()
val_dataset.labels         = val_df["label"].tolist()
val_dataset.patient_ids    = val_df["patient_id"].tolist()

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Number of images that were not matched = 5805
Loaded 52029 labeled images.
Train patients: 126
Val patients: 32
Overlap: set()
Train label counts:
 label
1    23998
0    18553
Name: count, dtype: int64
Val label counts:
 label
1    5539
0    3939
Name: count, dtype: int64
Number of images that were not matched = 5805
Loaded 52029 labeled images.
Number of images that were not matched = 5805
Loaded 52029 labeled images.


In [81]:
patient_list = []
for i in train_dataset.patient_ids:
    for k in val_dataset.patient_ids:
        if i ==k:
            patient_list.append(i)
        else: pass
patient_list

[]

### Note to add EL-0 to final CSV file as there are no EL-0xxxxx patients represented. Data on these patients in Onkos_Todo.csv in "Raw_CSV" folder

# Block for checking class imbalance

In [82]:
counts = collections.Counter(full_dataset.labels)
train_counts = collections.Counter(train_dataset.labels)
val_counts = collections.Counter(val_dataset.labels)

print(f'Metastasis Count in Full Dataset: {counts[1]}')
print(f'Non-Metastatic Count in Full Dataset: {counts[0]}')
print(f'Class imbalance in the Full dataset is: {counts[1] / (counts[0]+counts[1])} for Metastasis Representation')
print()
print(f'Class Imbalance Count in Train Dataset: {train_counts[1] / (train_counts[0]+train_counts[1])}')
print(f'Class Imbalance Count in Validation Dataset: {val_counts[1] / (val_counts[0]+val_counts[1])}')
print()
if abs((train_counts[1] / (train_counts[0]+train_counts[1])) - val_counts[1] / (val_counts[0]+val_counts[1])) < 0.03:
    print("Training and Test sets are balanced")
else:    
    print("Training and Test sets are not balanced")

Metastasis Count in Full Dataset: 29537
Non-Metastatic Count in Full Dataset: 22492
Class imbalance in the Full dataset is: 0.5677026273808837 for Metastasis Representation

Class Imbalance Count in Train Dataset: 0.5639820450753213
Class Imbalance Count in Validation Dataset: 0.5844059928254907

Training and Test sets are balanced


# Getting Pretrained Layers and adding new layers

#### ResNet

#### DenseNet

In [53]:
def densenet_option(num=201):
    match num:
        case 121:
            return models.densenet121(weights=models.DenseNet121_Weights.DEFAULT), "DenseNet121"
        case 169:
            return models.densenet169(weights=models.DenseNet169_Weights.DEFAULT), "DenseNet169"
        case 201:
            return models.densenet201(weights=models.DenseNet201_Weights.DEFAULT), "DenseNet201"
        case 264:
            return models.densenet264(weights=models.DenseNet264_Weights.DEFAULT), "DenseNet264"


def freeze_all_except(model, train_blocks=None, train_classifier=True):
    """
    train_blocks: list of block indices you want to train, e.g. [4] or [3,4]
    train_classifier: True/False — whether classifier is trainable
    """
    # First freeze everything
    for param in model.parameters():
        param.requires_grad = False

    # Train classifier (optional)
    if train_classifier:
        for param in model.classifier.parameters():
            param.requires_grad = True

    # Train selected DenseNet blocks
    if train_blocks is not None:
        for block_num in train_blocks:
            block_name = f"denseblock{block_num}"
            block = getattr(model.features, block_name, None)
            if block is not None:
                for param in block.parameters():
                    param.requires_grad = True
            else:
                print(f"[Warning] {block_name} not found in model")

In [54]:
# Load model
# Select the DenseNet Variant select from [121, 169, 201, 264]
# model, densenet_name = densenet_option(169) - example
model, densenet_name = densenet_option(201)

# Replace classifier - customizes the classifier
# returns pretrained features to be used as input into classifier
# to remap
in_features = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, num_classes)             
)

# ======== APPLY DENSENET ABLATION HERE =========
# OPTIONS:
#freeze_all_except(model, train_blocks=[], train_classifier=True)   # Only classifier
freeze_all_except(model, train_blocks=[4])        # Train only DenseBlock4 + classifier
#freeze_all_except(model, train_blocks=[3,4])   # Train DenseBlock3 + DenseBlock4 + classifier
#freeze_all_except(model, train_blocks=[2,3,4]) # Train last DenseBlock2 + DenseBlock3 +DenseBlock4 + classifier
# ===============================================

# Move model to device, device set in Data Loader - CPU or GPU
model.to(device)
print(f"{densenet_name} Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

DenseNet201 Parameters: 7,470,850


# Training Utility Functions

In [55]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels, _ in tqdm(dataloader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


def evaluate(model, dataloader, criterion, device):
    """Evaluate the model."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for inputs, labels, _ in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            
            # Get probabilities for positive class
            probs = F.softmax(outputs, dim=1)
            all_probs.extend(probs[:, 1].cpu().numpy())
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    cm = confusion_matrix(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, digits=4)
    
    # Calculate AUROC and AUPRC
    auroc = roc_auc_score(all_labels, all_probs)
    auprc = average_precision_score(all_labels, all_probs)
    
    return epoch_loss, epoch_acc, cm, report, auroc, auprc


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001):
    """
    Train and evaluate a model.
    Returns:
        Dictionary with training history
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 30)
        
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, cm, report, val_auroc, val_auprc = evaluate(model, val_loader, criterion, device)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print("\nConfusion Matrix:\n", cm)
        
        scheduler.step()
    
    return history

def evaluate_test_set(model, val_loader, device):
    """
    Evaluate model on test set and print metrics.
    """
    print("\n" + "="*60)
    print("FINAL TEST SET EVALUATION")
    print("="*60)
    
    criterion = nn.CrossEntropyLoss()
    test_loss, test_acc, cm, report, auroc, auprc = evaluate(model, val_loader, criterion, device)
    
    print(f'\nTest Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {test_acc:.2f}%')
    print(f'Test AUROC: {auroc:.4f}')
    print(f'Test AUPRC: {auprc:.4f}')
    print("\nConfusion Matrix:")
    print(cm)
    print("\nClassification Report:")
    print(report)
    
    return test_loss, test_acc, auroc, auprc
    
def plot_training_history(history, title="Training History"):
    """Plot training and validation loss/accuracy."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(history['train_loss'], label='Train Loss')
    ax1.plot(history['val_loss'], label='Val Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'{title} - Loss')
    ax1.legend()
    ax1.grid(True)

    ax2.plot(history['train_acc'], label='Train Acc')
    ax2.plot(history['val_acc'], label='Val Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title(f'{title} - Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()


# Executive Block

In [56]:
if __name__ == "__main__":
    print("=" * 60)
    print("Testing Pretrained ResNet34 on Image Slides")
    print("=" * 60)
    test_epochs = 10
    results = {}
    
    try:
        # Train on train/val sets
        history = train_model(model, train_loader, val_loader, num_epochs=test_epochs, lr=1e-4)
        results["ResNet34"] = history["val_acc"][-1]
        plot_training_history(history, "ResNet34")
        
        # ADD THIS: Evaluate on test set
        test_loss, test_acc, test_auroc, test_auprc = evaluate_test_set(model, val_loader, device)
        results["ResNet34_test_acc"] = test_acc
        results["ResNet34_test_auroc"] = test_auroc
        results["ResNet34_test_auprc"] = test_auprc
        
    except Exception as e:
        print(f"Error in ResNet34: {e}")
        results["ResNet34"] = 0


Testing Pretrained ResNet34 on Image Slides

Epoch 1/10
------------------------------


Training: 100%|██████████| 2660/2660 [05:15<00:00,  8.43it/s]
Evaluating: 100%|██████████| 593/593 [00:53<00:00, 11.02it/s]


Train Loss: 0.4289, Train Acc: 79.72%
Val Loss: 0.4640, Val Acc: 78.28%

Confusion Matrix:
 [[2944  995]
 [1064 4475]]

Epoch 2/10
------------------------------


Training:  44%|████▍     | 1169/2660 [02:33<03:15,  7.62it/s]


KeyboardInterrupt: 

# Sifting Val Loader to find Specific Proteins (For testing on specifics)

In [57]:
from torch.utils.data import Subset
import os

def get_stain_loader(original_loader, stain_name):
    """
    Create a new loader with only specified stain (keeps original intact).
    
    Args:
        original_loader: Existing DataLoader
        stain_name: Stain to filter for
    
    Returns:
        New DataLoader with filtered subset
    """
    dataset = original_loader.dataset
    
    # Find indices matching the stain
    stain_indices = []
    for idx, path in enumerate(dataset.image_paths):
        filename = os.path.basename(path)
        if stain_name.upper() in filename.upper():
            stain_indices.append(idx)
    
    print(f"{stain_name}: {len(stain_indices)} images found")
    
    # Create subset
    subset = Subset(dataset, stain_indices)
    
    # Create new loader
    stain_loader = DataLoader(
        subset, 
        batch_size=original_loader.batch_size, 
        shuffle=False
    )
    
    return stain_loader

In [58]:
val_loader_he = get_stain_loader(val_loader, 'HE')
val_loader_mitf = get_stain_loader(val_loader, 'MITF')
val_loader_anx = get_stain_loader(val_loader, 'ANX')
val_loader_bcl2 = get_stain_loader(val_loader, 'BCL2')
val_loader_bcl3 = get_stain_loader(val_loader, 'BCL3')
val_loader_pbp = get_stain_loader(val_loader, 'PBP')
val_loader_pir = get_stain_loader(val_loader, 'PIR')

print(f"Original val_loader: {len(val_loader.dataset)} images")
print(f"HE-only loader: {len(val_loader_he.dataset)} images")

test_loss, test_acc, test_auroc, test_auprc = evaluate_test_set(model, val_loader_anx, device)
results["ResNet34_test_acc"] = test_acc
results["ResNet34_test_auroc"] = test_auroc
results["ResNet34_test_auprc"] = test_auprc

HE: 1416 images found
MITF: 1782 images found
ANX: 1182 images found
BCL2: 1240 images found
BCL3: 1194 images found
PBP: 1303 images found
PIR: 1319 images found
Original val_loader: 9478 images
HE-only loader: 1416 images

FINAL TEST SET EVALUATION


Evaluating: 100%|██████████| 74/74 [00:05<00:00, 13.86it/s]


Test Loss: 0.7044
Test Accuracy: 73.52%
Test AUROC: 0.7814
Test AUPRC: 0.8439

Confusion Matrix:
[[262 189]
 [124 607]]

Classification Report:
              precision    recall  f1-score   support

           0     0.6788    0.5809    0.6260       451
           1     0.7626    0.8304    0.7950       731

    accuracy                         0.7352      1182
   macro avg     0.7207    0.7057    0.7105      1182
weighted avg     0.7306    0.7352    0.7305      1182






# Begin Code to Evaluate Logits and Softmax for Patient level classification

In [87]:
def aggregate_by_patient_comprehensive(df, patient_id_length=-8):
    """
    Aggregate predictions by patient with comprehensive checks and statistics.
    
    Args:
        df: DataFrame with patch-level predictions
        patient_id_length: Number of characters from filename to use as patient ID
    
    Returns:
        aggregated_df: DataFrame with patient-level predictions
    """
    # Extract patient ID
    df['patient_id'] = df['filename'].str[:patient_id_length]
    
    # Check for label consistency within patients
    label_check = df.groupby('patient_id')['true_label'].nunique()
    inconsistent = label_check[label_check > 1]
    if len(inconsistent) > 0:
        print(f"WARNING: {len(inconsistent)} patients have inconsistent labels across patches!")
        print(inconsistent)
    
    # Aggregate
    aggregated = df.groupby('patient_id').agg({
        'true_label': lambda x: x.mode()[0] if len(x.mode()) > 0 else x.iloc[0],  # Most common label
        'prob_no_metastasis': 'mean',
        'prob_metastasis': 'mean',
        'predicted_class': 'count'  # Will be renamed to num_patches
    }).reset_index()
    
    aggregated.rename(columns={'predicted_class': 'num_patches'}, inplace=True)
    
    # Create predicted class based on averaged probabilities
    aggregated['predicted_class'] = (aggregated['prob_metastasis'] > aggregated['prob_no_metastasis']).astype(int)
    
    # Confidence is the max probability
    aggregated['confidence'] = aggregated[['prob_no_metastasis', 'prob_metastasis']].max(axis=1)
    
    # Correctness
    aggregated['correct'] = (aggregated['predicted_class'] == aggregated['true_label']).astype(int)
    
    # Reorder columns
    aggregated = aggregated[[
        'patient_id',
        'num_patches',
        'true_label',
        'predicted_class',
        'prob_no_metastasis',
        'prob_metastasis',
        'confidence',
        'correct'
    ]]
    
    # Calculate metrics
    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
    
    accuracy = aggregated['correct'].mean() * 100
    auroc = roc_auc_score(aggregated['true_label'], aggregated['prob_metastasis'])
    cm = confusion_matrix(aggregated['true_label'], aggregated['predicted_class'])
    
    print("="*70)
    print("PATIENT-LEVEL AGGREGATION RESULTS")
    print("="*70)
    print(f"Total patches: {len(df)}")
    print(f"Total patients: {len(aggregated)}")
    print(f"Avg patches/patient: {aggregated['num_patches'].mean():.1f} (min: {aggregated['num_patches'].min()}, max: {aggregated['num_patches'].max()})")
    print(f"\nPatient-level accuracy: {accuracy:.2f}%")
    print(f"Patient-level AUROC: {auroc:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"                 Predicted")
    print(f"               No Met  Metastasis")
    print(f"True No Met      {cm[0,0]:4d}    {cm[0,1]:4d}")
    print(f"True Metastasis  {cm[1,0]:4d}    {cm[1,1]:4d}")
    
    print(f"\nClassification Report:")
    print(classification_report(aggregated['true_label'], aggregated['predicted_class'],
                                target_names=['No Metastasis', 'Metastasis'],
                                digits=4))
    
    return aggregated

In [88]:
def save_patch_predictions_to_csv(model, dataloader, device, output_csv='patch_level_predictions.csv'):
    """
    Save individual patch predictions to CSV (handles both Dataset and Subset).
    
    Args:
        model: Trained CNN model
        dataloader: DataLoader containing validation data
        device: torch device
        output_csv: Path to save CSV file
    
    Returns:
        df: DataFrame with patch-level predictions
    """
    model.eval()
    
    csv_data = []
    
    # Handle Subset vs regular Dataset
    dataset = dataloader.dataset
    if hasattr(dataset, 'dataset'):
        # It's a Subset, get the underlying dataset
        original_dataset = dataset.dataset
        indices = dataset.indices
    else:
        # It's a regular dataset
        original_dataset = dataset
        indices = range(len(dataset))
    
    print("Computing patch-level predictions...")
    with torch.no_grad():
        for batch_idx, (inputs, labels, patient_ids) in enumerate(tqdm(dataloader, desc="Processing patches")):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Get raw logits
            logits = model(inputs)
            
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=1)
            
            # Get predictions
            _, predicted = torch.max(logits, 1)
            
            # Process each image in the batch
            for i in range(len(labels)):
                # Get the actual dataset index
                sample_idx = batch_idx * dataloader.batch_size + i
                if sample_idx < len(indices):
                    actual_idx = indices[sample_idx]
                    
                    # Get filename from original dataset
                    img_path = original_dataset.image_paths[actual_idx]
                    filename = os.path.basename(img_path)
                    
                    # Add to CSV data
                    csv_data.append({
                        'filename': filename,
                        'true_label': labels[i].item(),
                        'predicted_class': predicted[i].item(),
                        'prob_no_metastasis': probs[i, 0].item(),
                        'prob_metastasis': probs[i, 1].item(),
                        'confidence': probs[i, predicted[i]].item(),
                        'correct': int(predicted[i].item() == labels[i].item())
                    })
    
    # Create DataFrame and save to CSV
    df = pd.DataFrame(csv_data)
    df.to_csv(output_csv, index=False)
    
    print(f"\n✓ Saved {len(df)} patch predictions to: {output_csv}")
    print(f"\nCSV Preview:")
    print(df.head(10).to_string(index=False))
    
    return df


# Usage
patch_df = save_patch_predictions_to_csv(
    model=model,
    dataloader=val_loader,
    device=device,
    output_csv='patch_level_predictions.csv'
)

# Then aggregate to patient level
print("\n" + "="*70)
print("Aggregating to patient level...")
patient_df = aggregate_by_patient_comprehensive(patch_df, patient_id_length=11)
patient_df.to_csv('he_patient_level_predictions.csv', index=False)
print("✓ Saved patient-level predictions")

Computing patch-level predictions...


Processing patches:  48%|████▊     | 286/593 [00:26<00:28, 10.82it/s]


KeyboardInterrupt: 

In [89]:
df = pd.read_csv('patch_level_predictions.csv')
df = df.sort_values(by='filename')
df = df.reset_index()
df = df.drop('index',axis=1)

# Check Metrics After Consolodating

In [90]:
# Before aggregation
print("Before aggregation:")
print(f"  Total patches: {len(df)}")
print(f"  Unique filenames: {df['filename'].nunique()}")

# After aggregation
patient_df = aggregate_by_patient_robust(df, trim_chars=3)

# Verify no data loss
print("\nVerification:")
print(f"  Patches in original: {len(df)}")
print(f"  Patches in aggregated (sum): {patient_df['num_patches'].sum()}")
print(f"  Match: {len(df) == patient_df['num_patches'].sum()}")

# Check for any patients with only 1 patch (might be suspicious)
single_patch = patient_df[patient_df['num_patches'] == 1]
if len(single_patch) > 0:
    print(f"\n⚠️  {len(single_patch)} patients have only 1 patch:")
    print(single_patch[['patient_id', 'sample_filename']].head(10))

Before aggregation:
  Total patches: 9478
  Unique filenames: 9478

✓ Aggregated 9478 patches into 9478 patients
  Avg patches/patient: 1.0
  Patient-level accuracy: 75.40%

Verification:
  Patches in original: 9478
  Patches in aggregated (sum): 9478
  Match: True

⚠️  9478 patients have only 1 patch:
                                       patient_id  \
0  ast_b28775-anx - 2018-05-28 15.23.57_slide_10.   
1  ast_b28775-anx - 2018-05-28 15.23.57_slide_12.   
2  ast_b28775-anx - 2018-05-28 15.23.57_slide_13.   
3  ast_b28775-anx - 2018-05-28 15.23.57_slide_14.   
4  ast_b28775-anx - 2018-05-28 15.23.57_slide_15.   
5  ast_b28775-anx - 2018-05-28 15.23.57_slide_16.   
6  ast_b28775-anx - 2018-05-28 15.23.57_slide_17.   
7  ast_b28775-anx - 2018-05-28 15.23.57_slide_18.   
8  ast_b28775-anx - 2018-05-28 15.23.57_slide_19.   
9   ast_b28775-anx - 2018-05-28 15.23.57_slide_2.   

                                     sample_filename  
0  AST_B28775-ANX - 2018-05-28 15.23.57_slide_10.png  
1 

In [91]:
def aggregate_by_patient_robust(df, trim_chars=3, case_sensitive=False):
    """
    Robust patient aggregation with error handling.
    
    Args:
        df: DataFrame with patch-level predictions
        trim_chars: Number of characters to remove from end (use 0 for none)
        case_sensitive: Whether patient IDs should be case-sensitive
    """
    df = df.copy()
    
    # Handle null/empty filenames
    df = df[df['filename'].notna() & (df['filename'] != '')]
    
    # Extract patient ID
    if trim_chars > 0:
        # Check if any filenames are too short
        too_short = df['filename'].str.len() <= trim_chars
        if too_short.any():
            print(f"⚠️  WARNING: {too_short.sum()} filenames are too short, skipping trim for those")
            df['patient_id'] = df.apply(
                lambda row: row['filename'][:-trim_chars] if len(row['filename']) > trim_chars else row['filename'],
                axis=1
            )
        else:
            df['patient_id'] = df['filename'].str[:-trim_chars]
    else:
        df['patient_id'] = df['filename']
    
    # Strip whitespace
    df['patient_id'] = df['patient_id'].str.strip()
    
    # Convert to lowercase if not case-sensitive
    if not case_sensitive:
        df['patient_id'] = df['patient_id'].str.lower()
    
    # Check for label inconsistencies
    label_check = df.groupby('patient_id')['true_label'].nunique()
    inconsistent = label_check[label_check > 1]
    if len(inconsistent) > 0:
        print(f"⚠️  WARNING: {len(inconsistent)} patients have inconsistent labels!")
        for pid in inconsistent.index:
            labels = df[df['patient_id'] == pid]['true_label'].unique()
            print(f"   {pid}: {labels}")
    
    # Aggregate
    aggregated = df.groupby('patient_id').agg({
        'true_label': lambda x: x.mode()[0] if len(x.mode()) > 0 else x.iloc[0],  # Most common
        'prob_no_metastasis': 'mean',
        'prob_metastasis': 'mean',
        'filename': ['count', 'first']  # Count and sample filename
    }).reset_index()
    
    # Flatten column names
    aggregated.columns = ['patient_id', 'true_label', 'prob_no_metastasis', 
                         'prob_metastasis', 'num_patches', 'sample_filename']
    
    # Predicted class
    aggregated['predicted_class'] = (aggregated['prob_metastasis'] > aggregated['prob_no_metastasis']).astype(int)
    
    # Confidence and correctness
    aggregated['confidence'] = aggregated[['prob_no_metastasis', 'prob_metastasis']].max(axis=1)
    aggregated['correct'] = (aggregated['predicted_class'] == aggregated['true_label']).astype(int)
    
    # Summary
    print(f"\n✓ Aggregated {len(df)} patches into {len(aggregated)} patients")
    print(f"  Avg patches/patient: {aggregated['num_patches'].mean():.1f}")
    print(f"  Patient-level accuracy: {aggregated['correct'].mean() * 100:.2f}%")
    
    return aggregated


# Usage
patient_df = aggregate_by_patient_robust(df, trim_chars=3, case_sensitive=False)


✓ Aggregated 9478 patches into 9478 patients
  Avg patches/patient: 1.0
  Patient-level accuracy: 75.40%
