## Phase-3: Baseline

Step-0a: Data Ingestion and Initial Preprocessing

In [1]:
from google.colab import drive
drive.mount('/content/drive')

# Go to your project folder
%cd /content/drive/MyDrive/multimodal_mammography


Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1BPq115f9Nu1zGGIf_pYtngsK8hJDkhdW/multimodal_mammography


In [2]:
import importlib.util

def load_module_from_path(name, path):
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


In [3]:
# Load environment setup
env = load_module_from_path("env", "setup/environment.py")
install = load_module_from_path("install", "setup/install_colab.py")
_ = load_module_from_path("imports", "setup/imports.py")  # No functions to call

# Run setup
install.install_dependencies()
env.suppress_warnings()
env.set_seed(42)
device = env.get_device()


🔄 Detected Google Colab environment.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted.
📦 Installing required packages...
✅ Dependencies installed.
🔁 Seed set to 42
 Using device: cuda


Step-0b: Loading Required csvs' and extracting/exploring images


In [4]:
# ✅ Load the dynamic module
data_loader = load_module_from_path("data_loader", "data/load_data.py")

# ✅ Correct CSV paths
metadata_path    = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/metadata.csv"
breast_anno_path = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/breast-level_annotations.csv"
finding_anno_path = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/finding_annotations.csv"

# ✅ Load and view data
metadata_df, breast_df, finding_df = data_loader.load_mammo_data(
    metadata_path,
    breast_anno_path,
    finding_anno_path,
    verbose=False
)


In [5]:
import pandas as pd

image_df=pd.read_csv("/content/drive/MyDrive/multimodal_mammography/dataset/csv/image_df_upsampled_studywise.csv")

In [6]:
print(image_df.columns)

Index(['image_id', 'study_id', 'filename', 'birads', 'birads_dir', 'density',
       'laterality', 'view_position', 'split', 'finding_categories',
       'finding_birads_clean', 'xmin', 'ymin', 'xmax', 'ymax', 'has_bbox',
       'age', 'birads_binary', 'birads_cleaned', 'birads_study_level',
       'finding_mass', 'finding_suspicious_calcification',
       'finding_focal_asymmetry', 'finding_asymmetry',
       'finding_global_asymmetry', 'finding_architectural_distortion',
       'finding_skin_thickening', 'finding_skin_retraction',
       'finding_nipple_retraction', 'finding_suspicious_lymph_node',
       'finding_no_finding', 'image_path', 'case_category', 'upsampled'],
      dtype='object')


In [7]:
import zipfile
import os

# Path to your zip file
zip_path = "/content/drive/MyDrive/multimodal_mammography/dataset/zipped_folder/birads_preprocessed_dataset.zip"

# Destination folder to extract files
extract_dir = "/content/birads_preprocessed_dataset"

# Make sure the directory exists
os.makedirs(extract_dir, exist_ok=True)

# Unzip the dataset
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print("Extraction complete.")
print("Extracted to:", extract_dir)


Extraction complete.
Extracted to: /content/birads_preprocessed_dataset


In [8]:
import os

# List a few extracted files/folders
for root, dirs, files in os.walk(extract_dir):
    print("Root:", root)
    print("Subdirs:", dirs[:5])   # show first 5 dirs
    print("Files:", files[:5])   # show first 5 files
    break


Root: /content/birads_preprocessed_dataset
Subdirs: ['training', 'test']
Files: ['image_df_upsampled_preprocessed.csv']


In [9]:
import os
from collections import defaultdict

base_dir = "/content/birads_preprocessed_dataset"

def find_small_studies(base_dir, min_images=4):
    small_studies = defaultdict(list)

    for split in ["training", "test"]:
        for case in ["normal", "abnormal"]:
            case_path = os.path.join(base_dir, split, case)
            if not os.path.exists(case_path):
                continue

            for study in os.listdir(case_path):
                study_path = os.path.join(case_path, study)
                if not os.path.isdir(study_path):
                    continue

                imgs = [f for f in os.listdir(study_path) if f.endswith(".png")]
                if len(imgs) < min_images:
                    small_studies[(split, case, study)] = imgs

    return small_studies

small_studies = find_small_studies(base_dir)

if small_studies:
    print("⚠️ Studies with fewer than 4 images:")
    for (split, case, study), imgs in small_studies.items():
        print(f"- {split}/{case}/{study} -> {len(imgs)} images: {imgs}")
else:
    print("✅ All studies have at least 4 images.")


✅ All studies have at least 4 images.


In [10]:
import os
from collections import Counter

base_dir = "/content/birads_preprocessed_dataset"
splits = ["training", "test"]
classes = ["normal", "abnormal"]

# Dictionary to store study -> image count
study_image_counts = {}

for split in splits:
    split_path = os.path.join(base_dir, split)
    for cls in classes:
        cls_path = os.path.join(split_path, cls)
        if not os.path.exists(cls_path):
            continue
        for study in os.listdir(cls_path):
            study_path = os.path.join(cls_path, study)
            if os.path.isdir(study_path):
                images = [f for f in os.listdir(study_path) if f.endswith(".png")]
                study_image_counts[study] = len(images)

# Summarize the distribution of images per study
count_distribution = Counter(study_image_counts.values())
print("Image count per study distribution:")
for n_images, n_studies in sorted(count_distribution.items()):
    print(f"{n_images} images: {n_studies} studies")

# Optional: total studies
print(f"\nTotal studies counted: {len(study_image_counts)}")


Image count per study distribution:
4 images: 7999 studies

Total studies counted: 7999


In [11]:
import os
import pandas as pd
from tqdm import tqdm  # import tqdm

# Paths
base_dir = "/content/birads_preprocessed_dataset"
original_csv = os.path.join(base_dir, "image_df_upsampled_preprocessed.csv")
fixed_csv = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load original CSV
df_orig = pd.read_csv(original_csv)

# Ensure string types for safe matching
df_orig["study_id"] = df_orig["study_id"].astype(str)
df_orig["filename"] = df_orig["filename"].astype(str)

# Prepare list for final rows
rows = []

splits = ["training", "test"]
classes = ["normal", "abnormal"]

for split in splits:
    split_path = os.path.join(base_dir, split)
    if not os.path.exists(split_path):
        continue

    for cls in classes:
        cls_path = os.path.join(split_path, cls)
        if not os.path.exists(cls_path):
            continue

        study_list = [s for s in os.listdir(cls_path) if os.path.isdir(os.path.join(cls_path, s))]
        for study in tqdm(study_list, desc=f"{split}/{cls} studies"):
            study_path = os.path.join(cls_path, study)

            images = sorted([f for f in os.listdir(study_path) if f.endswith(".png")])
            if len(images) != 4:
                continue  # only keep studies with exactly 4 images

            for img in images:
                # Try to get metadata from original CSV
                match = df_orig[(df_orig["study_id"] == study) &
                                (df_orig["filename"] == img)]
                if not match.empty:
                    row = match.iloc[0].copy()
                    row["image_path"] = os.path.join(study_path, img)  # update path
                else:
                    # If missing in original CSV, create minimal row with placeholders
                    row = {col: -1 for col in df_orig.columns}  # -1 as placeholder
                    row["study_id"] = study
                    row["filename"] = img
                    row["image_path"] = os.path.join(study_path, img)
                    row["split"] = split
                    row["case_category"] = cls

                rows.append(row)

# Build DataFrame
df_fixed = pd.DataFrame(rows)

# Save CSV
df_fixed.to_csv(fixed_csv, index=False)

# Summary
print(f"✅ Fixed CSV saved at: {fixed_csv}")
print(f"Total studies included: {df_fixed['study_id'].nunique()}")
print(f"Total images included: {len(df_fixed)}")


training/normal studies: 100%|██████████| 5065/5065 [02:05<00:00, 40.31it/s]
training/abnormal studies: 100%|██████████| 1934/1934 [00:46<00:00, 41.97it/s]
test/normal studies: 100%|██████████| 916/916 [00:23<00:00, 39.20it/s]
test/abnormal studies: 100%|██████████| 84/84 [00:01<00:00, 46.02it/s]


✅ Fixed CSV saved at: /content/birads_preprocessed_dataset/image_df_preprocessed_fixed.csv
Total studies included: 7999
Total images included: 31996


In [12]:
print(df_fixed.columns)

Index(['image_id', 'study_id', 'filename', 'birads', 'birads_dir', 'density',
       'laterality', 'view_position', 'split', 'finding_categories',
       'finding_birads_clean', 'xmin', 'ymin', 'xmax', 'ymax', 'has_bbox',
       'age', 'birads_binary', 'birads_cleaned', 'birads_study_level',
       'finding_mass', 'finding_suspicious_calcification',
       'finding_focal_asymmetry', 'finding_asymmetry',
       'finding_global_asymmetry', 'finding_architectural_distortion',
       'finding_skin_thickening', 'finding_skin_retraction',
       'finding_nipple_retraction', 'finding_suspicious_lymph_node',
       'finding_no_finding', 'image_path', 'case_category', 'upsampled',
       'preprocessed_path'],
      dtype='object')


In [13]:
import os
import pandas as pd

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load CSV metadata
df = pd.read_csv(csv_path)

# Build sets for quick lookup
expected_study_ids = set(df["study_id"].astype(str).unique())
expected_image_ids = set(df["image_id"].astype(str).unique())

issues = []

# Iterate over splits and classes
for split in ["training", "test"]:
    for cls in ["normal", "abnormal"]:
        cls_path = os.path.join(base_dir, split, cls)
        if not os.path.exists(cls_path):
            issues.append(f"Missing folder: {cls_path}")
            continue

        # Iterate over studies
        for study in os.listdir(cls_path):
            study_path = os.path.join(cls_path, study)
            if not os.path.isdir(study_path):
                continue

            # Validate study ID
            if study not in expected_study_ids:
                issues.append(f"Study folder '{study}' not found in CSV")

            # Validate image files
            for img in os.listdir(study_path):
                if img.endswith(".png"):
                    img_id = os.path.splitext(img)[0]  # remove extension
                    if img_id not in expected_image_ids:
                        issues.append(f"Image '{img}' in '{study_path}' not found in CSV")

# Summary
if not issues:
    print("✅ Dataset structure matches CSV metadata and is valid.")
else:
    print("⚠️ Issues found:")
    for issue in issues:
        print("-", issue)


✅ Dataset structure matches CSV metadata and is valid.


In [14]:
import pandas as pd
import os

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load CSV
df = pd.read_csv(csv_path)

# Iterate over all columns and print unique values and counts
for col in df.columns:
    unique_vals = df[col].nunique()
    print(f"{col}: {unique_vals} unique values")
    # Optionally, show top 10 most frequent values
    print(df[col].value_counts().head(10))
    print("-" * 50)


image_id: 19996 unique values
image_id
2bd9c72b886e97da1aff1361962c6acc    24
6266ffa44d75d2edc9d3c725b20b6d49    24
7dbf6830cc06730cfe74cd58937f89a8    24
2973bcf878fad1e9edade25be62602ce    24
a7acc2e02a4944c4fc72e32507b17fa7    22
85a6579cbdc403cfc4dde0a8149ed855    22
3cd51ee99070c4d625d52b848d5e9bfc    22
10e0f362333df810ac84a9db8fb3fd42    22
c0d6b03b2add28581aec656ad0d10613    21
b0124dae990a237fc01f625feea67a52    21
Name: count, dtype: int64
--------------------------------------------------
study_id: 7999 unique values
study_id
77d2b897870fd48aacdc9b2bbad1ef52            4
9b720facd58a23aef24bee0823ba5a27_dup2638    4
28eb668601ac25349e36c8cdd040b41b            4
7fa05485d8e90a042cd57d5bb2206b57            4
7d62d74422a284f22062355f3c772c8f            4
55a12c640ccebc100e67e21476b57285            4
fe23c1647f7617ef219a0a0e07c9eec5_dup2724    4
428b656fce3168763e8f2fccb4ffdfba_dup2685    4
317986c9303c4a6b9e6d015d67baf4bf            4
571e2132db7005421ab241f54439e7bb          

In [15]:
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [16]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import torchvision.transforms as T
from sklearn.model_selection import train_test_split

# ----------------------------
# Paths
# ----------------------------
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# ----------------------------
# Load CSV and map labels
# ----------------------------
df = pd.read_csv(csv_path)
df["birads_binary"] = df["birads_binary"].map({"normal": 0, "abnormal": 1})  # map to numeric

# ----------------------------
# Image transforms
# ----------------------------
IMAGE_TRANSFORMS = {
    "train": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ]),
    "val": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ]),
    "test": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ])
}

# ----------------------------
# Study-level Dataset
# ----------------------------
class MammogramStudyDataset(Dataset):
    def __init__(self, df, split="training", transform=None):
        self.df_split = df[df["split"] == split].copy()
        self.transform = transform

        # Group studies
        self.study_groups = {}
        for study_id, group in self.df_split.groupby("study_id"):
            valid_images = [row["image_path"] for _, row in group.iterrows() if os.path.exists(row["image_path"])]
            if len(valid_images) == 4:  # keep only complete studies
                self.study_groups[study_id] = {
                    "image_paths": valid_images,
                    "label": int(group["birads_binary"].iloc[0])
                }

        self.study_ids = list(self.study_groups.keys())

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        study_data = self.study_groups[study_id]

        images = []
        for img_path in study_data["image_paths"]:
            img = Image.open(img_path).convert("RGB")
            if self.transform:
                img = self.transform(img)
            images.append(img)

        images = torch.stack(images)  # shape: (4, C, H, W)
        label = study_data["label"]

        return images, label

# ----------------------------
# Create datasets
# ----------------------------
train_dataset = MammogramStudyDataset(df, split="training", transform=IMAGE_TRANSFORMS["train"])
val_dataset   = MammogramStudyDataset(df, split="training", transform=IMAGE_TRANSFORMS["val"])
test_dataset  = MammogramStudyDataset(df, split="test", transform=IMAGE_TRANSFORMS["test"])

# Train/Val split by study
train_ids, val_ids = train_test_split(train_dataset.study_ids, test_size=0.2, random_state=42)
train_dataset.study_ids = train_ids
val_dataset.study_ids = val_ids

# ----------------------------
# Dataloaders
# ----------------------------
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# ----------------------------
# Summary
# ----------------------------
print("Train studies:", len(train_dataset))
print("Val studies:", len(val_dataset))
print("Test studies:", len(test_dataset))


Train studies: 5599
Val studies: 1400
Test studies: 1000


In [17]:
print("Total studies in training split:", len(train_dataset.study_ids))


Total studies in training split: 5599


In [18]:
import torch
import torch.nn as nn
import torchvision.models as models

class StudyLevelResNet(nn.Module):
    def __init__(self, backbone="resnet50", pretrained=True, num_classes=1):
        super(StudyLevelResNet, self).__init__()

        # Backbone selection
        if backbone == "resnet50":
            self.cnn = models.resnet50(pretrained=pretrained)
            in_features = self.cnn.fc.in_features
            self.cnn.fc = nn.Identity()  # remove final classification layer
        elif backbone == "efficientnet_b0":
            self.cnn = models.efficientnet_b0(pretrained=pretrained)
            in_features = self.cnn.classifier[1].in_features
            self.cnn.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Study-level classification head
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)  # binary classification
        )

    def forward(self, x):
        """
        x: (batch_size, num_views, C, H, W)
        num_views = 4 (study-level images)
        """
        batch_size, num_views, C, H, W = x.shape
        x = x.view(batch_size * num_views, C, H, W)  # flatten views into batch

        # Extract features per view
        feats = self.cnn(x)  # (batch_size*num_views, in_features)
        feats = feats.view(batch_size, num_views, -1)  # group by study

        # Fuse across views (mean pooling)
        fused = feats.mean(dim=1)  # (batch_size, in_features)

        # Classification head
        out = self.fc(fused).squeeze(1)  # (batch_size,)
        return out


In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

# Model setup
model = StudyLevelResNet(backbone="resnet50", pretrained=True, num_classes=1)
model = model.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Scheduler (without verbose)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",    # maximize monitored metric (e.g., val AUROC)
    factor=0.5,    # reduce LR by 50%
    patience=2     # wait 2 epochs before reducing
)

print(f"Model ready on device: {device}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 181MB/s]


Model ready on device: cuda


In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, brier_score_loss

# -----------------------------
# Metrics helper
# -----------------------------
def compute_metrics(y_true, y_pred_probs):
    y_pred = (y_pred_probs >= 0.5).astype(int)
    return {
        "AUROC": roc_auc_score(y_true, y_pred_probs),
        "AUPRC": average_precision_score(y_true, y_pred_probs),
        "Accuracy": accuracy_score(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
        "Brier": brier_score_loss(y_true, y_pred_probs)
    }

# -----------------------------
# Training one epoch
# -----------------------------
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    all_labels, all_preds = [], []

    loop = tqdm(loader, desc="Training", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.float().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(torch.sigmoid(outputs).detach().cpu().numpy())

        loop.set_postfix(loss=loss.item())

    metrics = compute_metrics(np.array(all_labels), np.array(all_preds))
    return total_loss / len(loader.dataset), metrics

# -----------------------------
# Validation
# -----------------------------
def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels, all_preds = [], []

    loop = tqdm(loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.float().to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(torch.sigmoid(outputs).cpu().numpy())

            loop.set_postfix(loss=loss.item())

    metrics = compute_metrics(np.array(all_labels), np.array(all_preds))
    return total_loss / len(loader.dataset), metrics

# -----------------------------
# Full training loop
# -----------------------------
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs=10):
    best_val_auroc = 0.0

    for epoch in range(1, epochs + 1):
        print(f"\n=== Epoch {epoch}/{epochs} ===")

        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_metrics = validate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train AUROC: {train_metrics['AUROC']:.4f} | Val AUROC: {val_metrics['AUROC']:.4f}")
        print(f"Train Accuracy: {train_metrics['Accuracy']:.4f} | Val Accuracy: {val_metrics['Accuracy']:.4f}")

        # Step scheduler with validation AUROC
        scheduler.step(val_metrics['AUROC'])

        # Optional: save best model
        if val_metrics['AUROC'] > best_val_auroc:
            best_val_auroc = val_metrics['AUROC']
            torch.save(model.state_dict(), "best_studylevel_model.pth")
            print("✅ Saved new best model.")

    print(f"\nTraining complete. Best Val AUROC: {best_val_auroc:.4f}")


In [21]:
# -----------------------------
# Test evaluation
# -----------------------------
def evaluate_test(model, test_loader, device):
    model.eval()
    all_labels, all_preds = [], []

    loop = tqdm(test_loader, desc="Testing", leave=False)
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.float().to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    metrics = compute_metrics(all_labels, all_preds)

    print("\n=== Test Metrics ===")
    print(f"AUROC  : {metrics['AUROC']:.4f}")
    print(f"AUPRC  : {metrics['AUPRC']:.4f}")
    print(f"Accuracy: {metrics['Accuracy']:.4f}")
    print(f"F1     : {metrics['F1']:.4f}")
    print(f"Brier  : {metrics['Brier']:.4f}")
    return metrics


In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import os

def run_training(model, train_loader, val_loader, test_loader, device,
                 epochs=10, lr=1e-4, weight_decay=1e-4, checkpoint_path="best_model.pth"):

    # Loss, optimizer, scheduler
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

    best_val_auroc = 0.0

    for epoch in range(1, epochs + 1):
        # -----------------------------
        # Training
        # -----------------------------
        model.train()
        train_loss = 0.0
        all_labels, all_preds = [], []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False)
        for images, labels in loop:
            images, labels = images.to(device), labels.float().to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(torch.sigmoid(outputs).detach().cpu().numpy())

            loop.set_postfix(loss=loss.item())

        train_loss /= len(train_loader.dataset)
        train_metrics = compute_metrics(np.array(all_labels), np.array(all_preds))

        # -----------------------------
        # Validation
        # -----------------------------
        model.eval()
        val_loss = 0.0
        all_labels, all_preds = [], []

        loop = tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [Val]", leave=False)
        with torch.no_grad():
            for images, labels in loop:
                images, labels = images.to(device), labels.float().to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(torch.sigmoid(outputs).cpu().numpy())

                loop.set_postfix(loss=loss.item())

        val_loss /= len(val_loader.dataset)
        val_metrics = compute_metrics(np.array(all_labels), np.array(all_preds))

        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
              f"Val AUROC: {val_metrics['AUROC']:.4f}")

        # -----------------------------
        # Scheduler & Checkpoint
        # -----------------------------
        scheduler.step(val_metrics['AUROC'])

        if val_metrics['AUROC'] > best_val_auroc:
            best_val_auroc = val_metrics['AUROC']
            torch.save(model.state_dict(), checkpoint_path)
            print(f"✅ Saved best model at epoch {epoch} (Val AUROC: {best_val_auroc:.4f})")

    # -----------------------------
    # Load best model & test
    # -----------------------------
    model.load_state_dict(torch.load(checkpoint_path))
    model.to(device)
    print("\n=== Evaluating on Test Set ===")
    test_metrics = evaluate_test(model, test_loader, device)

    return model, train_metrics, val_metrics, test_metrics


In [23]:
model = StudyLevelResNet(backbone="resnet50", pretrained=True).to(device)

trained_model, train_metrics, val_metrics, test_metrics = run_training(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=10,
    lr=1e-4,
    weight_decay=1e-4,
    checkpoint_path="best_studylevel_model.pth"
)




Epoch 1/10 | Train Loss: 0.5958 | Val Loss: 0.5862 | Val AUROC: 0.6374
✅ Saved best model at epoch 1 (Val AUROC: 0.6374)




Epoch 2/10 | Train Loss: 0.5874 | Val Loss: 0.6100 | Val AUROC: 0.6119




Epoch 3/10 | Train Loss: 0.5849 | Val Loss: 0.5740 | Val AUROC: 0.7017
✅ Saved best model at epoch 3 (Val AUROC: 0.7017)




Epoch 4/10 | Train Loss: 0.5786 | Val Loss: 0.5562 | Val AUROC: 0.7010




Epoch 5/10 | Train Loss: 0.5628 | Val Loss: 0.5354 | Val AUROC: 0.7271
✅ Saved best model at epoch 5 (Val AUROC: 0.7271)




Epoch 6/10 | Train Loss: 0.5456 | Val Loss: 0.5083 | Val AUROC: 0.7709
✅ Saved best model at epoch 6 (Val AUROC: 0.7709)




Epoch 7/10 | Train Loss: 0.5251 | Val Loss: 0.4836 | Val AUROC: 0.8002
✅ Saved best model at epoch 7 (Val AUROC: 0.8002)




Epoch 8/10 | Train Loss: 0.5018 | Val Loss: 0.4691 | Val AUROC: 0.8149
✅ Saved best model at epoch 8 (Val AUROC: 0.8149)




Epoch 9/10 | Train Loss: 0.4672 | Val Loss: 0.4159 | Val AUROC: 0.8552
✅ Saved best model at epoch 9 (Val AUROC: 0.8552)




Epoch 10/10 | Train Loss: 0.4363 | Val Loss: 0.4020 | Val AUROC: 0.8956
✅ Saved best model at epoch 10 (Val AUROC: 0.8956)

=== Evaluating on Test Set ===


                                                          


=== Test Metrics ===
AUROC  : 0.5866
AUPRC  : 0.1593
Accuracy: 0.9010
F1     : 0.1239
Brier  : 0.0879


