In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix
import seaborn as sns
from tqdm.notebook import tqdm

In [2]:
DATA_DIR = "../data"  # Points to the base data directory
MODEL_SAVE_PATH = "../models/best_model.pth"
BATCH_SIZE = 32
NUM_EPOCHS = 40
LEARNING_RATE = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
print(f"Using device: {DEVICE}")


Using device: cuda


In [4]:
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    # --- Data Augmentation ---
    transforms.RandomHorizontalFlip(),  # Randomly flip horizontally
    transforms.RandomRotation(10),       # Randomly rotate by up to 10 degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), #small shifts
    # --------------------------
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [5]:
# Cell 5
class LensingDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ["no", "sphere", "vort"]
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.samples = self._make_dataset()
        self.error_counts = {cls_name: 0 for cls_name in self.classes}

    def _make_dataset(self):
        samples = []
        for target_class in self.classes:
            class_dir = os.path.join(self.root_dir, target_class)
            if not os.path.isdir(class_dir):
                continue
            for filename in os.listdir(class_dir):
                if filename.endswith(".npy"):
                    path = os.path.join(class_dir, filename)
                    item = (path, self.class_to_idx[target_class])
                    samples.append(item)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            image_np = np.load(img_path)

            if len(image_np.shape) == 3:
                if image_np.shape == (1, 150, 150):
                    image_np = image_np[0]
                elif image_np.shape == (1, 1, 150):
                    square_size = int(np.sqrt(150))
                    image_np = image_np[0, 0, :square_size*square_size].reshape(square_size, square_size)
                    image_np = np.repeat(np.repeat(image_np, 150//square_size, axis=0), 150//square_size, axis=1)
                else:
                    image_np = image_np.reshape(image_np.shape[1], image_np.shape[2])

            if image_np.dtype == np.float64 or image_np.dtype == np.float32:
                image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min() + 1e-8)
                image_np = (image_np * 255).astype(np.uint8)
            elif image_np.dtype != np.uint8:
                image_np = image_np.astype(np.uint8)

            if len(image_np.shape) == 2:
                image_pil = Image.fromarray(image_np).convert('RGB')
            elif len(image_np.shape) == 3 and image_np.shape[2] == 3:
                image_pil = Image.fromarray(image_np)
            else:
                image_pil = Image.fromarray(image_np).convert('RGB')

            if self.transform:
                image = self.transform(image_pil)
            else:
                image = transforms.ToTensor()(image_pil)
            return image, label
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            self.error_counts[self.classes[label]] += 1
            return torch.zeros((3, 224, 224)), label

    def report_errors(self):
        print("Data loading errors:")
        for cls_name, count in self.error_counts.items():
            print(f"  {cls_name}: {count}")


def collate_fn(batch):
    batch = [item for item in batch if item is not None and item[0] is not None]
    if len(batch) == 0:
        return torch.zeros((0, 3, 224, 224)), torch.zeros(0)
    return torch.utils.data.dataloader.default_collate(batch)

In [6]:
# Cell 6 (Experiment 1: Unfreeze layer3 and layer4)

train_transforms = data_transforms
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "train"), transform=train_transforms)
val_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "val"), transform=val_transforms)

train_dataset.report_errors()
val_dataset.report_errors()

print("Training set class distribution:")
for i, cls_name in enumerate(train_dataset.classes):
    count = sum(1 for _, label in train_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")

print("Validation set class distribution:")
for i, cls_name in enumerate(val_dataset.classes):
    count = sum(1 for _, label in val_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)

test_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "test"), transform=val_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)

test_dataset.report_errors()
print("Test set class distribution:")
for i, cls_name in enumerate(test_dataset.classes):
    count = sum(1 for _, label in test_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")

# --- Model and Optimizer (Experiment 1) ---
num_classes = len(train_dataset.classes)
model = models.resnet18(pretrained=True)

# Unfreeze layer3 and layer4
for param in model.layer3.parameters():
    param.requires_grad = True
for param in model.layer4.parameters():
    param.requires_grad = True

# Optimizer with differential learning rates
optimizer = optim.Adam([
    {'params': model.layer3.parameters(), 'lr': 1e-5},
    {'params': model.layer4.parameters(), 'lr': 1e-5},
    {'params': model.fc.parameters(), 'lr': 1e-4}
], lr=1e-4)

model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(DEVICE)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
criterion = nn.CrossEntropyLoss()

Data loading errors:
  no: 0
  sphere: 0
  vort: 0
Data loading errors:
  no: 0
  sphere: 0
  vort: 0
Training set class distribution:
  no: 10000
  sphere: 10000
  vort: 10000
Validation set class distribution:
  no: 2500
  sphere: 2500
  vort: 2500
Data loading errors:
  no: 0
  sphere: 0
  vort: 0
Test set class distribution:
  no: 0
  sphere: 0
  vort: 0




In [7]:
# Cell 8 (Revised - Learning Rate Once Per Epoch)
def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs, device):
    best_val_auc = 0.0
    train_losses = []
    val_losses = []
    train_auc_scores = []
    val_auc_scores = []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        model.train()
        running_loss = 0.0
        all_labels = []
        all_probs = []
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            all_labels.extend(labels.cpu().numpy())
            probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            all_probs.extend(probs)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        train_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        train_auc_scores.append(train_auc)
        print(f"  Train Loss: {epoch_loss:.4f}  Train AUC: {train_auc:.4f}")

        model.eval()
        running_loss = 0.0
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                all_labels.extend(labels.cpu().numpy())
                probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
                all_probs.extend(probs)

        epoch_loss = running_loss / len(val_loader.dataset)
        val_losses.append(epoch_loss)
        val_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        val_auc_scores.append(val_auc)
        print(f"  Val Loss: {epoch_loss:.4f}    Val AUC: {val_auc:.4f}")

        scheduler.step(epoch_loss)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"  Best model saved at epoch {epoch+1}")

        # ADDED: Print learning rate *after* validation and saving
        print(f"  Current Learning Rate: {optimizer.param_groups[0]['lr']}")

    return train_losses, val_losses, train_auc_scores, val_auc_scores

train_losses, val_losses, train_auc_scores, val_auc_scores = train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, NUM_EPOCHS, DEVICE)

Epoch 1/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 1.1226  Train AUC: 0.5399


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 1.0774    Val AUC: 0.5831
  Best model saved at epoch 1
  Current Learning Rate: 1e-05
Epoch 2/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 1.0605  Train AUC: 0.6083


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 1.0504    Val AUC: 0.6364
  Best model saved at epoch 2
  Current Learning Rate: 1e-05
Epoch 3/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 1.0295  Train AUC: 0.6462


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 1.0346    Val AUC: 0.6618
  Best model saved at epoch 3
  Current Learning Rate: 1e-05
Epoch 4/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 1.0080  Train AUC: 0.6658


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 1.0265    Val AUC: 0.6765
  Best model saved at epoch 4
  Current Learning Rate: 1e-05
Epoch 5/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9947  Train AUC: 0.6779


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9783    Val AUC: 0.6961
  Best model saved at epoch 5
  Current Learning Rate: 1e-05
Epoch 6/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9780  Train AUC: 0.6924


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9811    Val AUC: 0.7045
  Best model saved at epoch 6
  Current Learning Rate: 1e-05
Epoch 7/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9681  Train AUC: 0.7003


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9640    Val AUC: 0.7133
  Best model saved at epoch 7
  Current Learning Rate: 1e-05
Epoch 8/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9600  Train AUC: 0.7067


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9402    Val AUC: 0.7280
  Best model saved at epoch 8
  Current Learning Rate: 1e-05
Epoch 9/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9479  Train AUC: 0.7157


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9402    Val AUC: 0.7271
  Current Learning Rate: 1e-05
Epoch 10/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9395  Train AUC: 0.7217


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9206    Val AUC: 0.7349
  Best model saved at epoch 10
  Current Learning Rate: 1e-05
Epoch 11/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9290  Train AUC: 0.7285


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9163    Val AUC: 0.7435
  Best model saved at epoch 11
  Current Learning Rate: 1e-05
Epoch 12/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9189  Train AUC: 0.7362


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.9089    Val AUC: 0.7541
  Best model saved at epoch 12
  Current Learning Rate: 1e-05
Epoch 13/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9111  Train AUC: 0.7422


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8955    Val AUC: 0.7507
  Current Learning Rate: 1e-05
Epoch 14/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.9036  Train AUC: 0.7466


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8855    Val AUC: 0.7621
  Best model saved at epoch 14
  Current Learning Rate: 1e-05
Epoch 15/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8982  Train AUC: 0.7500


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8853    Val AUC: 0.7624
  Best model saved at epoch 15
  Current Learning Rate: 1e-05
Epoch 16/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8892  Train AUC: 0.7559


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8934    Val AUC: 0.7648
  Best model saved at epoch 16
  Current Learning Rate: 1e-05
Epoch 17/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8851  Train AUC: 0.7584


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8531    Val AUC: 0.7815
  Best model saved at epoch 17
  Current Learning Rate: 1e-05
Epoch 18/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8797  Train AUC: 0.7615


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8514    Val AUC: 0.7809
  Current Learning Rate: 1e-05
Epoch 19/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8747  Train AUC: 0.7649


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8608    Val AUC: 0.7754
  Current Learning Rate: 1e-05
Epoch 20/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8623  Train AUC: 0.7724


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8580    Val AUC: 0.7839
  Best model saved at epoch 20
  Current Learning Rate: 1e-05
Epoch 21/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8563  Train AUC: 0.7755


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8448    Val AUC: 0.7856
  Best model saved at epoch 21
  Current Learning Rate: 1e-05
Epoch 22/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8530  Train AUC: 0.7781


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8327    Val AUC: 0.7934
  Best model saved at epoch 22
  Current Learning Rate: 1e-05
Epoch 23/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8448  Train AUC: 0.7820


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8320    Val AUC: 0.7933
  Current Learning Rate: 1e-05
Epoch 24/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8417  Train AUC: 0.7839


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8148    Val AUC: 0.8005
  Best model saved at epoch 24
  Current Learning Rate: 1e-05
Epoch 25/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8398  Train AUC: 0.7854


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8273    Val AUC: 0.8007
  Best model saved at epoch 25
  Current Learning Rate: 1e-05
Epoch 26/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8344  Train AUC: 0.7881


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8086    Val AUC: 0.8038
  Best model saved at epoch 26
  Current Learning Rate: 1e-05
Epoch 27/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8279  Train AUC: 0.7913


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8119    Val AUC: 0.8015
  Current Learning Rate: 1e-05
Epoch 28/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8247  Train AUC: 0.7937


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7978    Val AUC: 0.8094
  Best model saved at epoch 28
  Current Learning Rate: 1e-05
Epoch 29/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8211  Train AUC: 0.7957


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8048    Val AUC: 0.8074
  Current Learning Rate: 1e-05
Epoch 30/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8121  Train AUC: 0.8013


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.8146    Val AUC: 0.8084
  Current Learning Rate: 1e-05
Epoch 31/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8123  Train AUC: 0.8003


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7970    Val AUC: 0.8128
  Best model saved at epoch 31
  Current Learning Rate: 1e-05
Epoch 32/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8077  Train AUC: 0.8020


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7914    Val AUC: 0.8107
  Current Learning Rate: 1e-05
Epoch 33/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.8030  Train AUC: 0.8053


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7790    Val AUC: 0.8190
  Best model saved at epoch 33
  Current Learning Rate: 1e-05
Epoch 34/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7990  Train AUC: 0.8071


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7933    Val AUC: 0.8146
  Current Learning Rate: 1e-05
Epoch 35/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7918  Train AUC: 0.8112


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7836    Val AUC: 0.8185
  Current Learning Rate: 1e-05
Epoch 36/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7894  Train AUC: 0.8130


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7720    Val AUC: 0.8232
  Best model saved at epoch 36
  Current Learning Rate: 1e-05
Epoch 37/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7851  Train AUC: 0.8152


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7839    Val AUC: 0.8189
  Current Learning Rate: 1e-05
Epoch 38/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7803  Train AUC: 0.8172


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7798    Val AUC: 0.8193
  Current Learning Rate: 1e-05
Epoch 39/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7773  Train AUC: 0.8189


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7602    Val AUC: 0.8278
  Best model saved at epoch 39
  Current Learning Rate: 1e-05
Epoch 40/40


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7765  Train AUC: 0.8187


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7709    Val AUC: 0.8228
  Current Learning Rate: 1e-05


In [None]:
NUM_EPOCHS = 15

# Continue training!  Do NOT re-initialize the model or optimizer.
train_losses_cont, val_losses_cont, train_auc_scores_cont, val_auc_scores_cont = train_model(
    model, criterion, optimizer, scheduler, train_loader, val_loader, NUM_EPOCHS, DEVICE
)

Epoch 1/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7737  Train AUC: 0.8202


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7591    Val AUC: 0.8286
  Best model saved at epoch 1
  Current Learning Rate: 1e-05
Epoch 2/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7724  Train AUC: 0.8214


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7643    Val AUC: 0.8313
  Best model saved at epoch 2
  Current Learning Rate: 1e-05
Epoch 3/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7672  Train AUC: 0.8240


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7521    Val AUC: 0.8347
  Best model saved at epoch 3
  Current Learning Rate: 1e-05
Epoch 4/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7632  Train AUC: 0.8255


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7678    Val AUC: 0.8260
  Current Learning Rate: 1e-05
Epoch 5/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7613  Train AUC: 0.8268


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7701    Val AUC: 0.8276
  Current Learning Rate: 1e-05
Epoch 6/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7553  Train AUC: 0.8300


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7486    Val AUC: 0.8353
  Best model saved at epoch 6
  Current Learning Rate: 1e-05
Epoch 7/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7552  Train AUC: 0.8289


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7452    Val AUC: 0.8357
  Best model saved at epoch 7
  Current Learning Rate: 1e-05
Epoch 8/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7491  Train AUC: 0.8314


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7498    Val AUC: 0.8359
  Best model saved at epoch 8
  Current Learning Rate: 1e-05
Epoch 9/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

  Train Loss: 0.7484  Train AUC: 0.8326


Validation:   0%|          | 0/235 [00:00<?, ?it/s]

  Val Loss: 0.7448    Val AUC: 0.8370
  Best model saved at epoch 9
  Current Learning Rate: 1e-05
Epoch 10/15


Training:   0%|          | 0/938 [00:00<?, ?it/s]

In [None]:
train_transforms = data_transforms
val_transforms = transforms.Compose([  # Keep validation transforms simple
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
train_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "train"), transform=train_transforms)
val_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "val"), transform=val_transforms)
test_dataset = LensingDataset(root_dir=os.path.join(DATA_DIR, "dataset", "test"), transform=val_transforms)


train_dataset.report_errors()
val_dataset.report_errors()
test_dataset.report_errors()

print("Training set class distribution:")
for i, cls_name in enumerate(train_dataset.classes):
    count = sum(1 for _, label in train_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")

print("Validation set class distribution:")
for i, cls_name in enumerate(val_dataset.classes):
    count = sum(1 for _, label in val_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")
print("Test set class distribution:")
for i, cls_name in enumerate(test_dataset.classes):
    count = sum(1 for _, label in test_dataset.samples if label == i)
    print(f"  {cls_name}: {count}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)


# Define layer groups
def get_layer_groups(model):
    layer_groups = [
        list(model.conv1.parameters()) + list(model.bn1.parameters()),  # Early layers
        list(model.layer1.parameters()) + list(model.layer2.parameters()),  # Middle layers
        list(model.layer3.parameters()) + list(model.layer4.parameters()),  # Late layers
        list(model.fc.parameters())  # Fully connected layer
    ]
    return layer_groups

# Define learning rate grid
lr_grid = {
    'early': [1e-7, 1e-6],
    'middle': [1e-6, 1e-5],
    'late': [1e-5, 1e-4],
    'fc': [1e-4, 1e-3]
}

best_val_auc = 0.0
best_lrs = {}

# Iterate through all combinations of learning rates
for lr_early in lr_grid['early']:
    for lr_middle in lr_grid['middle']:
        for lr_late in lr_grid['late']:
            for lr_fc in lr_grid['fc']:

                print(f"Training with LR: early={lr_early}, middle={lr_middle}, late={lr_late}, fc={lr_fc}")

                # Create a *new* model and optimizer for each combination
                model = models.resnet18(pretrained=True)
                for param in model.parameters():
                    param.requires_grad = True  # Unfreeze all layers
                model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))
                model = model.to(DEVICE)

                layer_groups = get_layer_groups(model)
                optimizer = optim.Adam([
                    {'params': layer_groups[0], 'lr': lr_early},
                    {'params': layer_groups[1], 'lr': lr_middle},
                    {'params': layer_groups[2], 'lr': lr_late},
                    {'params': layer_groups[3], 'lr': lr_fc}
                ])

                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=False)
                criterion = nn.CrossEntropyLoss()


                # Train for a small number of epochs
                train_losses, val_losses, train_auc_scores, val_auc_scores = train_model(
                    model, criterion, optimizer, scheduler, train_loader, val_loader, NUM_EPOCHS, DEVICE
                )

                # Get the *best* validation AUC from this training run
                current_best_val_auc = max(val_auc_scores)
                print(f"  Best Validation AUC: {current_best_val_auc:.4f}")

                # If this is the best combination so far, save it
                if current_best_val_auc > best_val_auc:
                    best_val_auc = current_best_val_auc
                    best_lrs = {'early': lr_early, 'middle': lr_middle, 'late': lr_late, 'fc': lr_fc}
                    print(f"  New best AUC: {best_val_auc:.4f} with LR: {best_lrs}")
                    torch.save(model.state_dict(), MODEL_SAVE_PATH) #save the best model from the grid search

print(f"Best Learning Rates: {best_lrs}")
print(f"Best Validation AUC: {best_val_auc:.4f}")

NameError: name 'plot_training_history' is not defined

In [None]:
def plot_training_history(train_losses, val_losses, train_auc, val_auc):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b-', label='Training Loss')
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_auc, 'b-', label='Training AUC')
    plt.plot(epochs, val_auc, 'r-', label='Validation AUC')
    plt.title('Training and Validation AUC')
    plt.xlabel('Epochs')
    plt.ylabel('AUC')
    plt.legend()

    plt.tight_layout()
    plt.show()

SyntaxError: invalid syntax. Perhaps you forgot a comma? (2320105060.py, line 25)

In [None]:
def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs, device):
    best_val_auc_run = 0.0  # Keep track of best AUC *within* the run
    train_losses = []
    val_losses = []
    train_auc_scores = []
    val_auc_scores = []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        model.train()
        running_loss = 0.0
        all_labels = []
        all_probs = []
        for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            all_labels.extend(labels.cpu().numpy())
            probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            all_probs.extend(probs)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        train_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        train_auc_scores.append(train_auc)
        print(f"  Train Loss: {epoch_loss:.4f}  Train AUC: {train_auc:.4f}")

        model.eval()
        running_loss = 0.0
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation", leave=False):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                all_labels.extend(labels.cpu().numpy())
                probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
                all_probs.extend(probs)

        epoch_loss = running_loss / len(val_loader.dataset)
        val_losses.append(epoch_loss)
        val_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        val_auc_scores.append(val_auc)
        print(f"  Val Loss: {epoch_loss:.4f}    Val AUC: {val_auc:.4f}")

        scheduler.step(epoch_loss)
         # Track best val AUC within the run
        if val_auc > best_val_auc_run:
            best_val_auc_run = val_auc


        print(f"  Current Learning Rate: {optimizer.param_groups[0]['lr']}")

    return train_losses, val_losses, train_auc_scores, val_auc_scores

Evaluating: 0it [00:00, ?it/s]

ValueError: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.

In [13]:
#Cell 11
def plot_roc_curve(labels, probabilities, num_classes):
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(num_classes):
        class_labels = [1 if label == i else 0 for label in labels]
        fpr[i], tpr[i], _ = roc_curve(class_labels, [prob[i] for prob in probabilities])
        roc_auc[i] = auc(fpr[i], tpr[i])

    plt.figure(figsize=(8, 6))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], label=f'Class {i} (AUC = {roc_auc[i]:.2f})')

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-Class ROC Curve')
    plt.legend(loc="lower right")
    plt.show()

plot_roc_curve(test_labels, test_probs, num_classes)

ValueError: y_true takes value in {} and pos_label is not specified: either make y_true take value in {0, 1} or {-1, 1} or pass pos_label explicitly.

In [None]:
# Cell 12

def plot_confusion_matrix(labels, probabilities, classes):
     preds = np.argmax(probabilities, axis=1)
     cm = confusion_matrix(labels, preds)
     plt.figure(figsize=(8, 6))
     sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
     plt.title("Confusion Matrix")
     plt.xlabel("Predicted Label")
     plt.ylabel("True Label")
     plt.show()
plot_confusion_matrix(test_labels, test_probs, train_dataset.classes)