# In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve

# Reproducibility: seed the stdlib RNG and the PyTorch RNG with one value
_SEED = 42
random.seed(_SEED)
torch.manual_seed(_SEED)

# Filesystem locations: dataset root folder and checkpoint output file
data_dir = r'D:\raj\catract detection\processed_images'
model_save_path = 'cataract_model.pth'

# Hyperparameters: mini-batch size, square input edge (pixels), epoch count
batch_size, input_size, epochs = 16, 224, 3

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data Transforms
# Both phases resize to a square (input_size x input_size) image, convert it
# to a tensor and normalize with the same fixed per-channel mean/std.
# Only the training pipeline adds a random horizontal flip.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_target_shape = (input_size, input_size)

_train_steps = [
    transforms.Resize(_target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
_test_steps = [
    transforms.Resize(_target_shape),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'test': transforms.Compose(_test_steps),
}

# Load Datasets
def _load_phase(phase):
    """Build the ImageFolder rooted at <data_dir>/<phase> with that phase's transforms."""
    phase_root = os.path.join(data_dir, phase)
    return datasets.ImageFolder(phase_root, data_transforms[phase])


def _random_subset(dataset, limit):
    """Wrap `dataset` in a Subset of at most `limit` randomly chosen indices."""
    total = len(dataset)
    keep = min(total, limit)
    return Subset(dataset, random.sample(range(total), keep))


# 'train' is loaded before 'test', matching the dictionary's key order
image_datasets = {
    'train': _load_phase('train'),
    'test': _load_phase('test'),
}

# Optional: Reduce dataset size for quick testing
reduce_data = True
if reduce_data:
    reduced_size = 50
    # Each phase is sampled independently, in the dictionary's iteration order
    image_datasets = {
        phase: _random_subset(full_set, reduced_size)
        for phase, full_set in image_datasets.items()
    }

# Data Loaders
# Only the training loader is shuffled. The test loader keeps a fixed order so
# that evaluation visits the samples in the same sequence on every run.
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'))
    for x in ['train', 'test']
}

# Class names
# Position i is the display name for label index i.
# NOTE(review): presumably this matches ImageFolder's class-to-index mapping
# for the dataset folders -- confirm against the folder names on disk.
class_names = ['cataract', 'normal']

# Model
# Start from ResNet-18 with its default pretrained weights, then swap the
# final fully connected layer for one that emits one logit per class.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
_classifier_head = nn.Linear(in_features=num_ftrs, out_features=len(class_names))
model.fc = _classifier_head
model = model.to(device)

# Loss and Optimizer
# Cross-entropy over the class logits; Adam updates every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training Loop
# Every epoch makes one full pass over the training loader and then prints the
# sample-weighted mean loss and the accuracy measured during that pass.
num_train_samples = len(dataloaders['train'].dataset)  # constant across epochs

for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}\n{"-"*20}')
    model.train()
    running_loss, running_corrects = 0.0, 0

    for inputs, labels in dataloaders['train']:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers: no legacy `.data` access and no
        # tensor-typed counter that needs converting afterwards.
        preds = outputs.argmax(dim=1)
        # The criterion returns a batch mean, so weight it by the batch size
        # to get a correct average when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / num_train_samples
    epoch_acc = running_corrects / num_train_samples

    print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

# Save the model
# Only the parameter dictionary (state_dict) is written to disk, so loading it
# later requires rebuilding the same architecture first.
_trained_weights = model.state_dict()
torch.save(_trained_weights, model_save_path)
print(f'Model saved to {model_save_path}')

# Evaluation
def evaluate_model(model):
    """Evaluate `model` on the test loader and report classification metrics.

    Prints a classification report, shows a confusion-matrix heatmap and
    prints the ROC AUC score when the ground truth contains more than one
    class.

    The set of reported classes is the union of the labels seen in the ground
    truth and in the predictions, so a class that is only ever predicted
    (never present in the test labels) no longer causes a mismatch between
    the label list and the class names.

    The ROC AUC is computed from the softmax probability of class index 1
    rather than from hard predictions; this assumes a two-class model.

    Args:
        model: The network to evaluate; it is switched to eval mode.

    Returns:
        None. All results are printed or plotted.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_scores = []  # probability assigned to class index 1, per sample

    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            inputs = inputs.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)

            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())
            all_scores.extend(probs[:, 1].cpu().numpy())

    # Dynamically handle class names: report every class that occurs in either
    # the ground truth or the predictions, and pass the label ids explicitly so
    # names, report rows and matrix axes always line up.
    true_labels = np.unique(all_labels)
    reported_labels = np.unique(all_labels + all_preds)
    used_class_names = [class_names[label] for label in reported_labels]

    # zero_division=0 keeps the report quiet for classes with no true samples
    print(classification_report(
        all_labels,
        all_preds,
        labels=reported_labels,
        target_names=used_class_names,
        zero_division=0,
    ))
    cm = confusion_matrix(all_labels, all_preds, labels=reported_labels)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=used_class_names, yticklabels=used_class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    # ROC AUC is undefined unless the ground truth contains both classes
    if len(true_labels) > 1:
        roc_score = roc_auc_score(all_labels, all_scores)
        print(f'ROC AUC Score: {roc_score:.4f}')
    else:
        print('ROC AUC Score: Cannot be computed (only one class present)')

# Report test-set metrics for the model trained above
evaluate_model(model)

# Grad-CAM Visualization
def show_gradcam(model, image_path):
    """Display a Grad-CAM heatmap for the class the model predicts on one image.

    The image is loaded as RGB, passed through the test transform pipeline and
    fed to the model. Activations of the last `nn.Conv2d` module (in module
    registration order) are captured together with the gradient of the
    predicted-class score with respect to them. The channel-weighted,
    ReLU-clipped sum of the activations is rescaled to 0-255, resized to the
    original image size and drawn over the image.

    Args:
        model: The network to explain; it is switched to eval mode.
        image_path: Path of the image file to open.

    Returns:
        None. The overlay is shown with matplotlib. If the model contains no
        convolution layer a message is printed and nothing is plotted.
    """
    model.eval()

    img = Image.open(image_path).convert('RGB')
    transform = data_transforms['test']
    input_tensor = transform(img).unsqueeze(0).to(device)

    # Pick the last convolution layer encountered while walking the modules
    final_conv = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            final_conv = module

    if final_conv is None:
        print("No convolution layer found!")
        return

    gradients = []
    activations = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # A tensor-level hook receives the gradient w.r.t. this output during
        # backward; it replaces the deprecated module-level backward hook.
        if output.requires_grad:
            output.register_hook(gradients.append)

    forward_handle = final_conv.register_forward_hook(forward_hook)
    try:
        # Forward pass
        output = model(input_tensor)
        pred_class = output.argmax().item()

        # Backward pass from the predicted-class score only
        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # Always detach the hook, even when the forward/backward pass fails,
        # so the model is not left with a stale hook attached.
        forward_handle.remove()

    # Process Grad-CAM: index [0] drops the batch dimension (single image)
    gradients_ = gradients[0].detach().cpu().numpy()[0]
    activations_ = activations[0].detach().cpu().numpy()[0]
    # One weight per channel: the spatial mean of that channel's gradient
    weights = gradients_.mean(axis=(1, 2))

    # Weighted sum over channels -> a single 2-D map
    cam = np.tensordot(weights, activations_, axes=1).astype(np.float32)

    cam = np.maximum(cam, 0)
    peak = cam.max()
    if peak != 0:  # an all-zero map is left as is to avoid dividing by zero
        cam = cam / peak
    cam = np.uint8(cam * 255)

    # Resize and overlay
    cam = Image.fromarray(cam).resize(img.size, Image.BILINEAR)
    plt.imshow(img)
    plt.imshow(cam, cmap='jet', alpha=0.5)
    plt.title(f'Grad-CAM for class: {class_names[pred_class]}')
    plt.axis('off')
    plt.show()

# Example: show_gradcam(model, r'D:\raj\catract detection\processed_images\test\cataract\sample.jpg')

# Inference function
def predict_image(model, image_path):
    """Classify a single image file and print the predicted class name.

    The image is opened as RGB, run through the test transform pipeline and
    passed to the model in eval mode. Inference runs under `torch.no_grad()`
    so no autograd graph is built.

    Args:
        model: The network used for prediction; it is switched to eval mode.
        image_path: Path of the image file to classify.

    Returns:
        The predicted class name (an entry of `class_names`).
    """
    model.eval()

    img = Image.open(image_path).convert('RGB')
    transform = data_transforms['test']
    input_tensor = transform(img).unsqueeze(0).to(device)

    # Gradients are not needed for inference
    with torch.no_grad():
        output = model(input_tensor)

    _, pred = torch.max(output, 1)
    predicted_name = class_names[pred.item()]
    print(f'Predicted class: {predicted_name}')
    return predicted_name

# Example: predict_image(model, r'D:\raj\catract detection\processed_images\test\normal\sample.jpg')