In [None]:
import os
import numpy as np
from glob import glob
from PIL import Image
from tqdm import tqdm
import torch
import random
from torch.utils.data import Dataset, DataLoader, random_split,Subset
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR,CosineAnnealingLR,SequentialLR,LinearLR
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, f1_score
from torchvision import models
from torchvision.models import (
    EfficientNet_B0_Weights,
    EfficientNet_B1_Weights,
    EfficientNet_B2_Weights,
    EfficientNet_B3_Weights,
    EfficientNet_B4_Weights)
from torch.nn.functional import sigmoid
import cv2

import seaborn as sns

In [100]:
SEED = 42


random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [7]:
def compute_mean_std(root_dir):
    exts = ("*.png",)
    sum_   = np.zeros(3, dtype=np.float64)
    sum_sq = np.zeros(3, dtype=np.float64)
    cnt    = 0

    files = []
    for split in ("train", "test"):
        for cls in ("normal","cataract"):
            for ext in exts:
                files += glob(os.path.join(root_dir, "processed_images", split, cls, ext))

    for f in tqdm(files, desc="Computing mean/std"):
        img = np.array(Image.open(f).convert("RGB")) / 255.0
        sum_   += img.mean(axis=(0,1))
        sum_sq += (img**2).mean(axis=(0,1))
        cnt    += 1

    mean = sum_ / cnt
    var  = (sum_sq / cnt) - mean**2
    std  = np.sqrt(var)
    return mean.tolist(), std.tolist()

In [99]:
_MEAN, _STD=[0.6257231324993875, 0.4934742948338769, 0.42569583700621416],[0.25667137400692847, 0.2345312511218496, 0.2305881956020596]

In [113]:
class ImageDataset(Dataset):
    def __init__(self, dir, split="train", transform=None):
        self.transform = transform
        self.samples = []
        split_dir = os.path.join(dir, "processed_images", split)
        for label_name, label_idx in [("normal", 0), ("cataract", 1)]:
            for img_path in glob(os.path.join(split_dir, label_name, "*.png")):
                self.samples.append((img_path, label_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

In [None]:
def get_dataloaders(
    root,
    batch_size=32,
    val_frac=0.2,
    num_workers=2,
    train_transform=None,
    val_transform=None
):

    raw_ds = ImageDataset(root, "train", transform=None)
    n = len(raw_ds)
    val_size = int(n * val_frac)
    train_size = n - val_size

    all_indices = list(range(n))
    train_indices, val_indices = random_split(all_indices, [train_size, val_size])

    train_ds = Subset(
        ImageDataset(root, "train", transform=train_transform),
        train_indices
    )
    val_ds = Subset(
        ImageDataset(root, "train", transform=val_transform),
        val_indices
    )
    test_ds = ImageDataset(root, "test", transform=val_transform)

    return (
        DataLoader(train_ds, batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True),
        DataLoader(val_ds,   batch_size, shuffle=False, num_workers=num_workers, pin_memory=True),
        DataLoader(test_ds,  batch_size, shuffle=False, num_workers=num_workers, pin_memory=True),
    )

In [93]:
def build_model(backbone="resnet50", pretrained=True):
    if backbone.startswith("resnet"):
        model = getattr(models, backbone)(pretrained=pretrained)
        in_feats = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(in_feats, 1)
        
        )

    elif backbone.startswith("efficientnet"):
        if pretrained:
            weights_map = {
                "efficientnet_b0": EfficientNet_B0_Weights.DEFAULT,
                "efficientnet_b1": EfficientNet_B1_Weights.DEFAULT,
                "efficientnet_b2": EfficientNet_B2_Weights.DEFAULT,
                "efficientnet_b3": EfficientNet_B3_Weights.DEFAULT,
                "efficientnet_b4": EfficientNet_B4_Weights.DEFAULT
            }
            weights = weights_map.get(backbone, None)
            if weights is None:
                raise ValueError(f"Pretrained weights not available for {backbone}")
            model = getattr(models, backbone)(weights=weights)
        else:
            model = getattr(models, backbone)(weights=None)

        in_feats = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_feats, 1)
        
        )

    else:
        raise ValueError(f"Unsupported backbone: {backbone}")

    return model

In [94]:
def train_single_epoch(model, train_loader, optimizer, scheduler, loss_fn, device):
    model.train()
    t_loss, preds, lbl = 0, [], []

    for imgs, labs in tqdm(train_loader, desc="Training", leave=False):
        imgs, labs = imgs.to(device), labs.to(device).unsqueeze(1).float()

        optimizer.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labs)
        loss.backward()
        optimizer.step()
        scheduler.step()


        probs = torch.sigmoid(logits.detach().cpu())
        labs_cpu = labs.detach().cpu()

        if probs.ndim > 1:
            probs = probs.squeeze(1)
        if labs_cpu.ndim > 1:
            labs_cpu = labs_cpu.squeeze(1)

        preds += probs.tolist()
        lbl += labs_cpu.tolist()
        t_loss += loss.item() * imgs.size(0)

    avg_loss = t_loss / len(train_loader.dataset)
    auc = roc_auc_score(lbl, preds)
    acc = accuracy_score(lbl, [p > 0.5 for p in preds])
    return avg_loss, auc, acc

In [95]:
def eval_one_epoch(model, loader, loss_fn, device):
    model.eval()
    preds, lbl = [], []
    total_loss = 0.0

    with torch.no_grad():
        for imgs, labs in loader:
            imgs, labs = imgs.to(device), labs.to(device).unsqueeze(1).float()

            logits = model(imgs)
            loss = loss_fn(logits, labs)
            total_loss += loss.item() * imgs.size(0)

            probs = torch.sigmoid(logits.detach().cpu())
            labs_cpu = labs.detach().cpu()
            if probs.ndim > 1:   probs = probs.squeeze(1)
            if labs_cpu.ndim > 1: labs_cpu = labs_cpu.squeeze(1)

            preds += probs.tolist()
            lbl   += labs_cpu.tolist()

    avg_loss = total_loss / len(loader.dataset)
    auc      = roc_auc_score(lbl, preds)
    acc      = accuracy_score(lbl, [p>0.5 for p in preds])
    return avg_loss, auc, acc, preds, lbl

In [None]:
def evaluate_model_outputs(y_true, y_pred_logits, threshold=0.5):
    
    probs = torch.sigmoid(torch.tensor(y_pred_logits)).numpy()
    preds = (probs > threshold).astype(int)
    y_true = np.array(y_true)

    auc = roc_auc_score(y_true, probs)
    f1 = f1_score(y_true, preds)
    report = classification_report(y_true, preds, target_names=["Normal", "Cataract"])
    cm = confusion_matrix(y_true, preds)

    print("\n[Classification Report]\n")
    print(report)
    print(f"ROC AUC Score: {auc:.4f}")
    print(f"F1 Score: {f1:.4f}\n")


    plt.figure(figsize=(5,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Normal", "Cataract"], yticklabels=["Normal", "Cataract"])
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.show()

    return {
        "classification_report": report,
        "roc_auc": auc,
        "f1_score": f1,
        "confusion_matrix": cm
    }


In [96]:
clean_transforms_iter1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
])

train_transforms_iter1 = transforms.Compose([
    transforms.Resize(400),
    transforms.CenterCrop(380),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.1, 0.1),
    transforms.RandomAdjustSharpness(0.7),
    transforms.ToTensor(),
    transforms.Normalize(_MEAN, _STD)
])

val_transforms_iter1 = transforms.Compose([
    transforms.Resize(400),
    transforms.CenterCrop(380),
    transforms.ToTensor(),
    transforms.Normalize(_MEAN, _STD)
])

In [None]:
DATA_ROOT = "data/raw/cataract-image-dataset/processed_images"
BATCH_SIZE   = 32
VAL_FRAC     = 0.2

BACKBONE     = "efficientnet_b4"     # "resnet_50"
PRETRAINED   = True
LR           = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS       = 10

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda


In [115]:
train_loader_iter1, val_loader_iter1, test_loader_iter1 = get_dataloaders(
    root=DATA_ROOT,
    batch_size=8,
    val_frac=0.2,
    num_workers=2,
    train_transform=train_transforms_iter1,
    val_transform=val_transforms_iter1
)

491 ---- <torch.utils.data.dataset.Subset object at 0x792017d88b10> ... <torch.utils.data.dataset.Subset object at 0x792017d88590>


In [116]:
model_iter1 = build_model(backbone="efficientnet_b4", pretrained=True).to(DEVICE)

for param in model_iter1.features[:3].parameters():
    param.requires_grad = False

In [117]:
bce_loss_iter1=nn.BCEWithLogitsLoss()

opt_iter1 = AdamW(model_iter1.parameters(), lr=1e-4, weight_decay=1e-4)

warmup_epochs_iter1 = 2

scheduler_iter1 = SequentialLR(
    opt_iter1,
    schedulers=[
        LinearLR(opt_iter1, start_factor=0.1, total_iters=warmup_epochs_iter1),
        CosineAnnealingLR(opt_iter1, T_max=EPOCHS - warmup_epochs_iter1)
    ],
    milestones=[warmup_epochs_iter1]
)

In [119]:
best_val_auc_iter1 = 0.0
history_iter1 = {
    'train_loss': [],
    'val_loss': [],
    'train_auc': [],
    'val_auc': [],
    'train_acc': [],
    'val_acc': []
}

for epoch in range(1, EPOCHS + 1):

    train_loss, train_auc, train_acc = train_single_epoch(
        model_iter1, train_loader_iter1,
        opt_iter1, scheduler_iter1,
        bce_loss_iter1, DEVICE
    )

 
    val_loss, val_auc, val_acc, val_preds, val_labels = eval_one_epoch(
        model_iter1, val_loader_iter1,
        bce_loss_iter1, DEVICE
    )


    history_iter1['train_loss'].append(train_loss)
    history_iter1['val_loss'].append(val_loss)
    history_iter1['train_auc'].append(train_auc)
    history_iter1['val_auc'].append(val_auc)
    history_iter1['train_acc'].append(train_acc)
    history_iter1['val_acc'].append(val_acc)


    print(f"Epoch {epoch}/{EPOCHS}  "
          f"Train L:{train_loss:.4f} AUC:{train_auc:.4f} Acc:{train_acc:.4f} | "
          f"Val   L:{val_loss:.4f} AUC:{val_auc:.4f} Acc:{val_acc:.4f}")


    if val_auc > best_val_auc_iter1:
        best_val_auc_iter1 = val_auc
        torch.save(model_iter1.state_dict(), "best_model_effnet.pth")

                                                         

Epoch 1/10  Train L:0.2141 AUC:0.9809 Acc:0.9338 | Val   L:0.1517 AUC:0.9899 Acc:0.9388


                                                         

Epoch 2/10  Train L:0.1591 AUC:0.9882 Acc:0.9491 | Val   L:0.1273 AUC:0.9912 Acc:0.9388


                                                         

Epoch 3/10  Train L:0.1237 AUC:0.9933 Acc:0.9542 | Val   L:0.1540 AUC:0.9916 Acc:0.9490


                                                         

Epoch 4/10  Train L:0.1197 AUC:0.9921 Acc:0.9618 | Val   L:0.1638 AUC:0.9907 Acc:0.9490


                                                         

Epoch 5/10  Train L:0.1063 AUC:0.9940 Acc:0.9593 | Val   L:0.1084 AUC:0.9920 Acc:0.9592


                                                         

Epoch 6/10  Train L:0.0960 AUC:0.9939 Acc:0.9771 | Val   L:0.1099 AUC:0.9945 Acc:0.9592


                                                         

Epoch 7/10  Train L:0.0634 AUC:0.9981 Acc:0.9822 | Val   L:0.1405 AUC:0.9937 Acc:0.9490


                                                         

Epoch 8/10  Train L:0.0621 AUC:0.9985 Acc:0.9771 | Val   L:0.1561 AUC:0.9937 Acc:0.9490


                                                         

Epoch 9/10  Train L:0.0641 AUC:0.9980 Acc:0.9771 | Val   L:0.1408 AUC:0.9941 Acc:0.9490


                                                         

Epoch 10/10  Train L:0.0427 AUC:0.9987 Acc:0.9898 | Val   L:0.1487 AUC:0.9941 Acc:0.9490


In [None]:
model.load_state_dict(torch.load("best_model_effnet.pth"))
model.eval()


test_loss, test_auc, test_acc, test_preds, test_labels = eval_one_epoch(
    model, test_loader, loss_fn, DEVICE
)

evaluate_predictions(test_labels, test_preds)

print(f"\n Test — Loss: {test_loss:.4f} | AUC: {test_auc:.4f} | Accuracy: {test_acc:.4f}")