# In[ ]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from CustomDataset import CustomImageDataset
from Models import ResNet50, ResNet101, ResNet50_MCDropout
from utils.Plots import plot_confusion, plot_roc, reliability_diagram, expected_calibration_error, brier_score, entropy_hist

# In[ ]:
# Configuration
mode = "Binary"            # passed to CustomImageDataset; the rest of the notebook assumes a binary task
model_name = "resnet50"    # used in the checkpoint filename
epochs = 10
batch_size = 32
seed = 777
save_dir = "results/baseline"

# Apply the seed to every RNG the pipeline may touch (weight init, DataLoader
# shuffling, any random augmentation) so runs are reproducible. Previously
# `seed` was defined but never used.
import random

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA

# In[ ]:
# Dataset sanity check: build the train/val/test splits (in that order) and
# report how many samples each one holds.
ds_train, ds_val, ds_test = (
    CustomImageDataset(mode=mode, build_div=split) for split in ("train", "val", "test")
)
print(f"Train: {len(ds_train)}, Val: {len(ds_val)}, Test: {len(ds_test)}")

# In[ ]:
# Model instantiation.
# Only ResNet50 is wired up here, yet the checkpoint saved later is named after
# `model_name`. Fail fast rather than silently training an architecture that
# differs from what the config (and the checkpoint filename) claims.
if model_name != "resnet50":
    raise ValueError(f"This notebook only builds ResNet50, but model_name={model_name!r}")
# label_num=1 with the sigmoid/BCEWithLogitsLoss head used for training is a
# binary-only setup.
if mode != "Binary":
    raise ValueError(f"label_num=1 assumes binary classification, but mode={mode!r}")
model = ResNet50(input_channel=3, label_num=1)
model  # displayed by the notebook as the cell's output

# In[ ]:
# Training: Adam + BCEWithLogitsLoss on a single-logit output, printing the
# per-sample mean loss and the accuracy for every epoch.
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(ds_val, batch_size=batch_size, shuffle=False)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()  # takes raw logits; applies the sigmoid internally

for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    correct = 0
    for imgs, labels in train_loader:
        # Presumably labels arrive as (B,); unsqueeze to (B, 1) floats to match
        # the single-logit output expected by BCEWithLogitsLoss.
        imgs, labels = imgs.to(device), labels.to(device).float().unsqueeze(1)
        preds = model(imgs)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # criterion returns the batch *mean*; weight it by the batch size so the
        # epoch average is a true per-sample mean even when the last batch is
        # smaller than batch_size.
        total_loss += loss.item() * labels.size(0)
        # Metric only: detach so no autograd graph is recorded for these ops.
        correct += ((torch.sigmoid(preds.detach()) >= 0.5).int() == labels.int()).sum().item()
    avg_loss = total_loss / len(ds_train)
    acc = correct / len(ds_train)
    print(f"Epoch {epoch}/{epochs}: Loss={avg_loss:.4f}, Acc={acc:.4f}")

# In[ ]:
# Validation & Save: score the trained model on the validation split, then
# write its state_dict to save_dir.
import os

model.eval()
val_loss = 0.0
correct = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device).float().unsqueeze(1)
        preds = model(imgs)
        loss = criterion(preds, labels)
        # Weight the batch-mean loss by batch size so the reported value is a
        # per-sample mean, unbiased by a smaller final batch.
        val_loss += loss.item() * labels.size(0)
        correct += ((torch.sigmoid(preds) >= 0.5).int() == labels.int()).sum().item()
print(f"Val Loss: {val_loss/len(ds_val):.4f}, Val Acc: {correct/len(ds_val):.4f}")

# NOTE(review): despite the "best_" prefix, these are the weights after the
# final training epoch; the training loop performs no best-epoch selection.
ckpt_path = f"{save_dir}/best_{model_name}.pt"
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), ckpt_path)
print("Saved model to", ckpt_path)

# In[ ]:
# Evaluation & Calibration: reload the saved checkpoint, collect sigmoid
# probabilities on the test split, and report classification and calibration
# diagnostics.
test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False)
model.load_state_dict(torch.load(f"{save_dir}/best_{model_name}.pt", map_location=device))
model.eval()
all_labels, all_probs = [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        # Single-logit output (B, 1) -> (B,) so it lines up with the labels.
        logits = model(imgs.to(device)).squeeze(1)
        all_probs.extend(torch.sigmoid(logits).cpu().numpy())
        all_labels.extend(labels.numpy())

# Convert once to proper 1-D arrays instead of rebuilding an array for every
# metric call (and instead of passing a list of 0-d numpy scalars).
all_probs = np.asarray(all_probs)
all_labels = np.asarray(all_labels)

plot_confusion(all_labels, all_probs >= 0.5)
plot_roc(all_labels, all_probs)
reliability_diagram(all_labels, all_probs)
print("ECE:", expected_calibration_error(all_labels, all_probs))
print("Brier Score:", brier_score(all_labels, all_probs))
entropy_hist(all_probs)