In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from CustomDataset import CustomImageDataset
from Models import ResNet50, ResNet101, ResNet50_MCDropout
from utils.Plots import (
    plot_confusion, plot_roc, reliability_diagram,
    expected_calibration_error, brier_score, entropy_hist
)

In [None]:
# Configuration: every value a reader is expected to tune lives in this cell.

# Dataset variant passed to CustomImageDataset, e.g. "Binary", "External", "Merged".
mode = "Binary"

# Architecture key: "resnet50" | "resnet101" | "resnet50_mcdo".
model_name = "resnet50"

# Number of training epochs and mini-batch size used by the DataLoaders.
epochs = 10
batch_size = 32

# Seed applied to the torch and numpy global RNGs.
seed = 777

# Directory that receives the saved checkpoint.
save_dir = "results/baseline"

In [None]:
# Dataset sanity check: build the train / val / test splits with the
# CustomDataset defaults (only `mode` and `build_div` are passed) and report
# how many items each split contains.
ds_train, ds_val, ds_test = (
    CustomImageDataset(mode=mode, build_div=split_name)
    for split_name in ("train", "val", "test")
)
print(f"Train: {len(ds_train)}, Val: {len(ds_val)}, Test: {len(ds_test)}")

In [None]:
# Model instantiation
# Seed the global RNGs *before* building the network. Any random weight
# initialisation done inside the constructor would otherwise run unseeded and
# differ between kernel restarts, even though `seed` is set in the config cell.
torch.manual_seed(seed)
np.random.seed(seed)

# Lower-cased architecture key -> constructor. The last two keys are aliases
# for the same MC-Dropout variant.
model_builders = {
    "resnet50": ResNet50,
    "resnet101": ResNet101,
    "resnet50_mcdo": ResNet50_MCDropout,
    "resnet50_mcdropout": ResNet50_MCDropout,
}

model_key = model_name.lower()
if model_key not in model_builders:
    raise ValueError(f"Unknown model type: {model_name}")

# Every architecture is built with 3 input channels and a single label output.
model = model_builders[model_key](input_channel=3, label_num=1)
model

In [None]:
# Training (simple, readable loop)
# Reads `model`, `ds_train`, `ds_val` and the config values from earlier cells;
# defines `device`, `train_loader`, `val_loader`, `criterion` and `optimizer`
# for the cells that follow.
os.makedirs(save_dir, exist_ok=True)

# Seed the torch and numpy global RNGs right before training starts.
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0   # sum of per-sample losses over the epoch
    correct = 0          # number of correctly classified samples
    n_seen = 0           # number of samples processed so far
    for x, y in train_loader:
        x = x.to(device)
        # Targets are cast to float and given a trailing dim to match the logits.
        y = y.to(device).float().unsqueeze(1)

        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `criterion` returns the batch mean (default reduction), so weight it
        # by the batch size. Averaging the batch means directly would
        # over-weight a short final batch.
        n_batch = y.size(0)
        running_loss += loss.item() * n_batch
        with torch.no_grad():
            # Metric bookkeeping only; no gradient tracking needed.
            preds = (torch.sigmoid(logits) >= 0.5).int()
            correct += (preds == y.int()).sum().item()
        n_seen += n_batch
    print(f"Epoch {epoch}/{epochs} | loss={running_loss/n_seen:.4f} | acc={correct/n_seen:.4f}")

In [None]:
# Validation & save best
# Scores the model once on the validation split, then writes its weights.
# NOTE(review): this cell runs a single time after the training loop, so the
# file named "best_*" holds the final-epoch weights. No per-epoch model
# selection takes place.
model.eval()
val_loss = 0.0   # sum of per-sample losses
correct = 0      # number of correctly classified samples
n_seen = 0       # number of samples processed
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device)
        y = y.to(device).float().unsqueeze(1)
        logits = model(x)

        # `criterion` returns the batch mean, so weight it by the batch size.
        # This makes the reported loss a true per-sample average even when the
        # last batch is short.
        n_batch = y.size(0)
        val_loss += criterion(logits, y).item() * n_batch
        correct += ((torch.sigmoid(logits) >= 0.5).int() == y.int()).sum().item()
        n_seen += n_batch
print(f"Val | loss={val_loss/n_seen:.4f} | acc={correct/n_seen:.4f}")

# Build the checkpoint path once so the saved file and the log line cannot drift.
ckpt_path = f"{save_dir}/best_{model_name}.pt"
torch.save(model.state_dict(), ckpt_path)
print("Saved:", ckpt_path)

In [None]:
# Evaluation & calibration
# Reload the saved weights, score the test split, then draw the diagnostic
# plots and print the calibration metrics.
test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False)
state_dict = torch.load(f"{save_dir}/best_{model_name}.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Collect one probability array and one label array per batch, then join once.
prob_chunks, label_chunks = [], []
with torch.no_grad():
    for x, y in test_loader:
        batch_logits = model(x.to(device)).squeeze(1)
        prob_chunks.append(torch.sigmoid(batch_logits).cpu().numpy())
        label_chunks.append(y.numpy())

# With no batches, fall back to an empty array (what np.array([]) would give).
all_p = np.concatenate(prob_chunks) if prob_chunks else np.array([])
all_y = np.concatenate(label_chunks) if label_chunks else np.array([])

# Hard predictions use the same 0.5 threshold as the training/validation accuracy.
hard_preds = (all_p >= 0.5).astype(int)
plot_confusion(all_y, hard_preds)
plot_roc(all_y, all_p)
reliability_diagram(all_y, all_p)
print("ECE:", expected_calibration_error(all_y, all_p))
print("Brier:", brier_score(all_y, all_p))
_ = entropy_hist(all_p)