# ResNet-50 Training Notebook
This notebook fine-tunes an ImageNet-pretrained ResNet-50 to classify fundus images into two classes, `glaucoma` and `normal`.

## Dataset and Preprocessing

In [3]:
import os

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Dataset location. Defaults to the original local path; set the
# GLAUCOMA_DATA_DIR environment variable to point the notebook elsewhere.
data_dir = os.environ.get(
    "GLAUCOMA_DATA_DIR",
    r"C:\Users\Saif\Desktop\Glaucoma Detection\data\train",
)

# Seed for the train/validation split so the same images land in the
# validation set on every run (and the saved "best" model is comparable).
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32

# Fail fast with a clear message if the dataset folder is missing.
if not os.path.exists(data_dir):
    raise FileNotFoundError(f"‚ùå Dataset path not found: {data_dir}")
print(f"‚úÖ Dataset path found: {data_dir}")
print(f"Class folders: {os.listdir(data_dir)}")

# ImageNet channel statistics, matching the pretrained ResNet-50 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms: resize plus random flip/rotation augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation transforms: deterministic (no random augmentation) so validation
# loss/accuracy measure the model, not the luck of the augmentation draw.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Two views of the same folder: one augmented (training), one not (validation).
# A single dataset passed to random_split would force the training transforms
# onto the validation subset as well.
full_dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
eval_dataset = datasets.ImageFolder(data_dir, transform=val_transforms)
# The index split below is shared by both views, so they must list the same
# files in the same order.
assert full_dataset.samples == eval_dataset.samples, "dataset views are misaligned"
print(f"‚úÖ Found {len(full_dataset)} images across {len(full_dataset.classes)} classes: {full_dataset.classes}")

# Seeded train/validation split over indices, applied to the matching view.
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_dataset = Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = Subset(eval_dataset, shuffled_indices[train_size:])

# DataLoaders: shuffle only the training data.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"‚úÖ Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")


✅ Dataset path found: C:\Users\Saif\Desktop\Glaucoma Detection\data\train
Class folders: ['glaucoma', 'normal']
✅ Found 17830 images across 2 classes: ['glaucoma', 'normal']
✅ Train size: 14264, Val size: 3566


## ResNet-50 Model, Loss, and Optimizer

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models

# Train on a GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output classes. The data cell's ImageFolder reported classes
# ['glaucoma', 'normal'], so logit 0 = glaucoma and logit 1 = normal.
NUM_CLASSES = 2
LEARNING_RATE = 1e-4

# ImageNet-pretrained backbone. torchvision >= 0.13 deprecates
# `pretrained=True` in favour of the `weights` enum (IMAGENET1K_V1 is the
# checkpoint `pretrained=True` used to load); older versions lack the enum.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Replace the ImageNet classification head with a NUM_CLASSES-way linear layer.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, NUM_CLASSES)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)




## Training Pipeline (ResNet Only)

In [5]:
from tqdm.auto import tqdm

epochs = 20
best_val_acc = 0.0

# NOTE(review): the recorded run of this cell managed only ~17 s per batch
# (tqdm estimated ~2 h per epoch) and was interrupted manually; confirm that
# `device` is 'cuda' and that data loading is not the bottleneck before
# launching the full 20-epoch run.


def run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to run. It is put in train mode when ``optimizer`` is
            given and in eval mode otherwise.
        loader: iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device each batch is moved to before the forward pass.
        optimizer: if given, gradients are computed and one optimizer step is
            taken per batch (training). If ``None``, the pass runs with
            gradients disabled (validation).
        desc: label shown on the tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both weighted per sample.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)
    total_loss, total_correct, total_seen = 0.0, 0, 0

    with torch.set_grad_enabled(is_train):
        for inputs, labels in tqdm(loader, desc=desc, leave=True):
            inputs, labels = inputs.to(device), labels.to(device)

            if is_train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # Weight by batch size so a short final batch doesn't skew the mean.
            total_loss += loss.item() * batch_size
            # .item() keeps the counter a plain Python int instead of a device tensor.
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

    if total_seen == 0:
        raise ValueError(f"{desc or 'loader'}: loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(epochs):
    epoch_tag = f"Epoch {epoch+1}/{epochs}"

    # TRAINING
    epoch_loss, epoch_acc = run_epoch(
        model, train_loader, criterion, device,
        optimizer=optimizer, desc=f"{epoch_tag} [Train]",
    )

    # VALIDATION
    val_epoch_loss, val_epoch_acc = run_epoch(
        model, val_loader, criterion, device, desc=f"{epoch_tag} [Val]",
    )

    print(f"üìä Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} | "
          f"Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}")

    # SAVE BEST MODEL (checkpoint whenever validation accuracy improves)
    if val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), "best_resnet_model.pth")
        print(f"‚úÖ Model saved with Val Acc: {best_val_acc:.4f}")

print(f"üéØ Training Complete. Best Val Accuracy: {best_val_acc:.4f}")


Epoch 1/20 [Train]:   3%|███▋                                                                                                                   | 14/446 [03:57<2:02:00, 16.94s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import random

import torch

# Same channel statistics as the transforms.Normalize call in the data cell,
# shaped (3, 1, 1) to broadcast over a (C, H, W) image tensor.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

class_names = full_dataset.classes
sample_img, sample_label = random.choice(val_dataset)

# Predict on a single validation image (add a batch dimension of 1).
model.eval()
with torch.no_grad():
    output = model(sample_img.unsqueeze(0).to(device))
pred_idx = output.argmax(dim=1).item()

# Undo Normalize and clamp to [0, 1]; imshow would otherwise clip the
# normalized values and show distorted colours. Then move channels last (H, W, C).
display_img = (sample_img * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0)

fig, ax = plt.subplots()
ax.imshow(display_img)
ax.set_title(f"Predicted: {class_names[pred_idx]}, Actual: {class_names[sample_label]}")
ax.axis('off')
plt.show()
