# OCT Image Classification Model
## Classify OCT Images into: Normal, DME, Drusen, or CNV

This notebook implements a deep learning model to classify OCT (Optical Coherence Tomography) retinal images into four categories:
- **Normal**: Healthy retina
- **DME**: Diabetic Macular Edema
- **Drusen**: Deposits under the retina
- **CNV**: Choroidal Neovascularization


## 1. Setup and Installation


In [None]:
# Mount Google Drive at /content/drive so the dataset and model folders
# referenced later in the notebook are reachable from the runtime.
# Requires the `google.colab` package (i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Install required packages (-q keeps pip output quiet)
# NOTE(review): timm and albumentations are not imported anywhere in the
# visible notebook — confirm they are still needed in this install line.
!pip install -q timm albumentations scikit-learn


## 2. Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from pathlib import Path
import json
from PIL import Image
import glob

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

# Seed PyTorch first, then NumPy, with the same fixed value for reproducibility
for _apply_seed in (torch.manual_seed, np.random.seed):
    _apply_seed(42)

# Report the runtime environment (library version and GPU availability)
cuda_ready = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_ready}")
if cuda_ready:
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")


## 3. Configuration


In [None]:
# ============= CONFIGURATION =============
# Update these paths according to your Google Drive structure

# Base path - CHANGE THIS TO YOUR FOLDER PATH
BASE_PATH = '/content/drive/MyDrive/oct_major_project/'

# Dataset paths - Update these according to your folder structure
# One folder of images per class; the keys must match CLASS_NAMES below
DATA_PATHS = {
    'NORMAL': os.path.join(BASE_PATH, 'NORMAL 2.v1i.coco-segmentation/train'),
    'DME': os.path.join(BASE_PATH, 'DME 2.v1i.coco-segmentation/train'),
    'DRUSEN': os.path.join(BASE_PATH, 'drusen 3.v1i.coco-segmentation/train'),
    'CNV': os.path.join(BASE_PATH, 'CNV 2.v1i.coco-segmentation/train')
}

# Model save path (created up front so later save calls cannot fail on a missing folder)
MODEL_SAVE_PATH = os.path.join(BASE_PATH, 'classification_models')
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Class mapping: the position in CLASS_NAMES is the integer label used for training
CLASS_NAMES = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}
IDX_TO_CLASS = {idx: name for name, idx in CLASS_TO_IDX.items()}

# Training hyperparameters
CONFIG = {
    'img_size': 224,
    'batch_size': 32,
    'num_epochs': 50,
    'learning_rate': 0.001,
    'num_classes': len(CLASS_NAMES),  # derived so it cannot drift from CLASS_NAMES
    'train_split': 0.8,
    'val_split': 0.1,
    'test_split': 0.1,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'num_workers': 0,  # Set to 0 for Google Colab to avoid multiprocessing issues
    'model_name': 'resnet50',  # Options: resnet50, resnet18, vgg16, efficientnet_b0
}

# Fail fast on inconsistent settings instead of surfacing them as a KeyError
# or a skewed split several cells later.
if set(DATA_PATHS) != set(CLASS_NAMES):
    raise ValueError(
        f"DATA_PATHS keys {sorted(DATA_PATHS)} do not match CLASS_NAMES {sorted(CLASS_NAMES)}"
    )
_split_total = CONFIG['train_split'] + CONFIG['val_split'] + CONFIG['test_split']
if abs(_split_total - 1.0) > 1e-6:
    raise ValueError(f"train/val/test splits must sum to 1.0, got {_split_total}")

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print(f"\nClass Mapping: {CLASS_TO_IDX}")


In [None]:
class OCTDataset(Dataset):
    """Dataset that reads OCT images from disk with OpenCV and pairs them with labels.

    Each item is ``(image, label)``: the image is loaded from its path,
    converted from OpenCV's BGR channel order to RGB, and passed through
    ``transform`` when one is given.
    """

    def __init__(self, image_paths, labels, transform=None):
        """
        Args:
            image_paths: List of paths to images
            labels: List of labels (indices), one per entry of ``image_paths``
            transform: Optional transform to be applied on images

        Raises:
            ValueError: If ``image_paths`` and ``labels`` differ in length.
        """
        # A length mismatch would otherwise only show up as an IndexError
        # (or silently ignored labels) somewhere in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at ``idx``.

        Returns:
            Tuple ``(image, label)``.

        Raises:
            ValueError: If OpenCV cannot read the image file.
        """
        img_path = self.image_paths[idx]
        # str() so path-like objects are accepted as well as plain strings
        image = cv2.imread(str(img_path))

        # cv2.imread signals an unreadable/missing file by returning None
        if image is None:
            raise ValueError(f"Could not load image: {img_path}")

        # Convert BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Explicit None check: do not rely on the truthiness of a transform object
        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]


## 5. Data Loading and Preparation


In [None]:
def load_dataset_paths(data_paths, class_to_idx=None):
    """Load all image paths and their corresponding labels.

    Args:
        data_paths: Mapping of class name -> folder containing that class's images.
        class_to_idx: Optional mapping of class name -> integer label.
            Defaults to the notebook-level ``CLASS_TO_IDX``.

    Returns:
        Tuple ``(all_image_paths, all_labels)`` of equal-length lists. Paths
        are sorted within each class so that the result (and any seeded split
        made from it) does not depend on filesystem listing order.

    Raises:
        KeyError: If a class name in ``data_paths`` has no entry in the label mapping.
    """
    if class_to_idx is None:
        class_to_idx = CLASS_TO_IDX

    # Suffixes are compared case-insensitively, so each file is matched exactly
    # once (no duplicates on case-insensitive filesystems, no missed '.Jpg').
    image_suffixes = ('.jpg', '.jpeg', '.png')

    all_image_paths = []
    all_labels = []

    print("Loading dataset...")
    for class_name, folder_path in data_paths.items():
        if not os.path.exists(folder_path):
            print(f"Warning: Folder not found: {folder_path}")
            continue

        # Get all image files (jpg, jpeg, png); escape the folder so glob
        # metacharacters in the path are treated literally.
        candidates = glob.glob(os.path.join(glob.escape(folder_path), '*'))
        image_files = sorted(
            path for path in candidates if path.lower().endswith(image_suffixes)
        )

        print(f"  {class_name}: {len(image_files)} images")

        # Add to lists
        all_image_paths.extend(image_files)
        all_labels.extend([class_to_idx[class_name]] * len(image_files))

    print(f"\nTotal images: {len(all_image_paths)}")
    return all_image_paths, all_labels

# Load all data: collect every image path and its integer class label
# from the per-class folders listed in DATA_PATHS
image_paths, labels = load_dataset_paths(DATA_PATHS)


In [None]:
# Split data into train, validation, and test sets
from sklearn.model_selection import train_test_split

# Fail early with a clear message: when every dataset folder is missing the
# loader only prints warnings, and the stratified split below would otherwise
# fail with a much less helpful error.
if not image_paths:
    raise ValueError(
        "No images were found - check BASE_PATH and DATA_PATHS before splitting."
    )

# First split: train+val vs test (stratified so class proportions are kept)
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    image_paths, labels,
    test_size=CONFIG['test_split'],
    random_state=42,
    stratify=labels
)

# Second split: train vs val
# val_split is a fraction of the whole dataset, so rescale it to a fraction
# of the remaining train+val portion.
val_size = CONFIG['val_split'] / (CONFIG['train_split'] + CONFIG['val_split'])
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_val_paths, train_val_labels,
    test_size=val_size,
    random_state=42,
    stratify=train_val_labels
)

print("Dataset Split:")
print(f"  Training samples: {len(train_paths)}")
print(f"  Validation samples: {len(val_paths)}")
print(f"  Test samples: {len(test_paths)}")


## 6. Data Augmentation and Transforms


In [None]:
# Values shared by both preprocessing pipelines
target_size = (CONFIG['img_size'], CONFIG['img_size'])
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flip/rotation/colour jitter, then tensor + normalize
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(target_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std)
])

# Validation/test pipeline: the same resize and normalization, no random augmentation
val_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std)
])

# One dataset per split, each paired with the pipeline that fits it
train_dataset, val_dataset, test_dataset = (
    OCTDataset(split_paths, split_labels, transform=split_transform)
    for split_paths, split_labels, split_transform in (
        (train_paths, train_labels, train_transform),
        (val_paths, val_labels, val_test_transform),
        (test_paths, test_labels, val_test_transform),
    )
)

# Loader options common to all splits; only the training loader shuffles
loader_kwargs = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': CONFIG['num_workers'],
}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("Data loaders created successfully!")


## 7. Visualize Sample Images


In [None]:
def visualize_samples(dataset, num_samples=8):
    """Visualize randomly chosen raw images from the dataset.

    Images are read straight from ``dataset.image_paths`` (no transform or
    normalization) and titled with their class name.

    Args:
        dataset: Object exposing ``image_paths``, ``labels`` and ``len()``.
        num_samples: Maximum number of images to show. Clamped to the dataset
            size; the grid grows in rows of four to fit.

    Raises:
        ValueError: If a sampled image cannot be read from disk.
    """
    sample_count = min(num_samples, len(dataset))
    if sample_count <= 0:
        print("No samples to visualize.")
        return

    # Four columns, as many rows as needed (ceil division); 8 samples -> 2x4, 16x8 in
    n_cols = 4
    n_rows = -(-sample_count // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )
    axes = axes.ravel()

    # Get random samples (without replacement, hence the clamp above)
    indices = np.random.choice(len(dataset), sample_count, replace=False)

    for ax, idx in zip(axes, indices):
        # Get original image (without normalization)
        img_path = dataset.image_paths[idx]
        label = dataset.labels[idx]

        # Load and display image
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Could not load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        ax.imshow(img)
        ax.set_title(f'Class: {IDX_TO_CLASS[label]}', fontsize=12, fontweight='bold')
        ax.axis('off')

    # Hide the unused cells of a partially filled last row
    for ax in axes[sample_count:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize training samples: shows the raw images behind random entries of
# train_dataset, read directly from their paths
print("Sample Training Images:")
visualize_samples(train_dataset)


## 8. Model Architecture


In [None]:
def create_model(model_name='resnet50', num_classes=4, pretrained=True):
    """Build a torchvision backbone whose final layer is replaced by a new head.

    Args:
        model_name: One of 'resnet50', 'resnet18', 'vgg16', 'efficientnet_b0'.
        num_classes: Number of output units of the new classification head.
        pretrained: Forwarded to the torchvision model constructor.

    Returns:
        The model with its classifier swapped for a dropout + linear head
        (resnet50 gets an additional 512-unit hidden layer).

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Guard clause: reject unknown names before building anything
    if model_name not in ('resnet50', 'resnet18', 'vgg16', 'efficientnet_b0'):
        raise ValueError(f"Unsupported model: {model_name}")

    def dropout_linear_head(in_features):
        # Head shared by every architecture except resnet50
        return nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    if model_name == 'resnet50':
        net = models.resnet50(pretrained=pretrained)
        in_features = net.fc.in_features
        net.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        return net

    if model_name == 'resnet18':
        net = models.resnet18(pretrained=pretrained)
        net.fc = dropout_linear_head(net.fc.in_features)
        return net

    if model_name == 'vgg16':
        net = models.vgg16(pretrained=pretrained)
        # Only the last classifier layer (index 6) is replaced
        net.classifier[6] = dropout_linear_head(net.classifier[6].in_features)
        return net

    # Remaining supported name: efficientnet_b0 (whole classifier is replaced)
    net = models.efficientnet_b0(pretrained=pretrained)
    net.classifier = dropout_linear_head(net.classifier[1].in_features)
    return net

# Build the architecture selected in CONFIG and move it to the configured device
model = create_model(
    CONFIG['model_name'],
    CONFIG['num_classes'],
    True
).to(CONFIG['device'])

# Count parameters in a single pass over the model
param_info = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_param_count = sum(count for count, _ in param_info)
trainable_param_count = sum(count for count, trainable in param_info if trainable)

print(f"Model: {CONFIG['model_name']}")
print(f"Total parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")


## 9. Training Setup


In [None]:
# Cross-entropy loss over the class logits
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, starting at the configured learning rate
optimizer = optim.Adam(
    params=model.parameters(),
    lr=CONFIG['learning_rate'],
)

# Halve the learning rate once the monitored value (validation loss, lower is
# better) has not improved for more than 5 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

print("Training setup complete!")


## 10. Training Functions


In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    Args:
        model: Network to optimize; switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and
        the accuracy in percent over the whole epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc='Training')

    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics: read the scalar loss once per batch and reuse it
        batch_loss = loss.item()
        running_loss += batch_loss * images.size(0)
        # detach() instead of the deprecated .data; argmax gives the predicted class index
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    # An empty loader would otherwise surface as a bare ZeroDivisionError
    if total == 0:
        raise ValueError("train_epoch: dataloader produced no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def validate_epoch(model, dataloader, criterion, device):
    """Validate for one epoch.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and
        the accuracy in percent over the whole pass.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc='Validation')

    # No gradients are needed for evaluation
    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Statistics: read the scalar loss once per batch and reuse it
            batch_loss = loss.item()
            running_loss += batch_loss * images.size(0)
            # argmax gives the predicted class index (replaces the deprecated .data access)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })

    # An empty loader would otherwise surface as a bare ZeroDivisionError
    if total == 0:
        raise ValueError("validate_epoch: dataloader produced no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


## 11. Training Loop with Model Checkpointing


In [None]:
# Training history: one entry per epoch for each metric
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# Best model tracking (highest validation accuracy seen so far)
best_val_acc = 0.0
best_model_path = os.path.join(MODEL_SAVE_PATH, 'best_oct_classifier.pth')

print("Starting training...\n")

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 50)

    # Train
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, CONFIG['device']
    )

    # Validate
    val_loss, val_acc = validate_epoch(
        model, val_loader, criterion, CONFIG['device']
    )

    # Update learning rate (the plateau scheduler monitors validation loss)
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print epoch results
    print("\nEpoch Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best model; note that 'epoch' is stored 0-based
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': CONFIG,
            'class_to_idx': CLASS_TO_IDX
        }, best_model_path)
        print(f"  ✓ Best model saved! (Val Acc: {val_acc:.2f}%)")

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = os.path.join(MODEL_SAVE_PATH, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history
        }, checkpoint_path)
        print(f"  ✓ Checkpoint saved at epoch {epoch+1}")

print("\n" + "="*50)
print("Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Best model saved at: {best_model_path}")
print("="*50)


In [None]:
def plot_training_history(history):
    """Draw loss and accuracy curves side by side, save them as a PNG and show them.

    Args:
        history: Dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc'.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, train key, train legend, val key, val legend, y label, title)
    panels = (
        (loss_ax, 'train_loss', 'Train Loss', 'val_loss', 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (acc_ax, 'train_acc', 'Train Accuracy', 'val_acc', 'Val Accuracy',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )

    for ax, train_key, train_legend, val_key, val_legend, y_label, title in panels:
        ax.plot(history[train_key], label=train_legend, marker='o')
        ax.plot(history[val_key], label=val_legend, marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    output_file = os.path.join(MODEL_SAVE_PATH, 'training_history.png')
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()

# Plot the curves recorded during training and save them under MODEL_SAVE_PATH
plot_training_history(history)


## 13. Load Best Model and Evaluate on Test Set


In [None]:
# Load best model
# map_location lets a checkpoint written on a GPU load when this cell runs on
# a different device (e.g. a CPU-only session).
checkpoint = torch.load(best_model_path, map_location=CONFIG['device'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# The checkpoint stores the 0-based loop index; report it 1-based so it matches
# the "Epoch N/M" lines printed during training.
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.2f}%")


In [None]:
def evaluate_model(model, dataloader, device):
    """Run the model over a dataloader and collect its outputs.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple of NumPy arrays ``(predictions, labels, probabilities)`` built
        from the per-sample predicted indices, true labels and softmax outputs.
    """
    model.eval()
    pred_items, label_items, prob_items = [], [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc='Evaluating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_probs = torch.softmax(logits, dim=1)
            # Index of the largest logit per row is the predicted class
            batch_preds = torch.max(logits, 1).indices

            # Accumulate per-sample entries on the CPU
            pred_items += list(batch_preds.cpu().numpy())
            label_items += list(batch_labels.cpu().numpy())
            prob_items += list(batch_probs.cpu().numpy())

    return np.array(pred_items), np.array(label_items), np.array(prob_items)

# Evaluate on test set
# NOTE: this rebinds `test_labels` from the list produced by the split to the
# NumPy array returned by evaluate_model. test_loader is built with
# shuffle=False, so the order still lines up with test_paths.
test_preds, test_labels, test_probs = evaluate_model(model, test_loader, CONFIG['device'])

# Calculate accuracy (fraction in [0, 1]; printed as a percentage)
test_accuracy = accuracy_score(test_labels, test_preds)
print(f"\nTest Accuracy: {test_accuracy * 100:.2f}%")


## 14. Classification Report and Confusion Matrix


In [None]:
# Pass the full label set explicitly so every class keeps its row/column (and
# its name) even if it never occurs in the test labels or predictions.
class_indices = list(range(len(CLASS_NAMES)))


def _save_confusion_heatmap(matrix, fmt, cbar_label, title, filename):
    """Draw one annotated confusion-matrix heatmap, save it and show it."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                cbar_kws={'label': cbar_label})
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_SAVE_PATH, filename), dpi=300, bbox_inches='tight')
    plt.show()


# Classification report
print("\nClassification Report:")
print("="*60)
print(classification_report(test_labels, test_preds, labels=class_indices,
                            target_names=CLASS_NAMES, digits=4))

# Confusion matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(test_labels, test_preds, labels=class_indices)
_save_confusion_heatmap(cm, 'd', 'Count',
                        'Confusion Matrix - Test Set', 'confusion_matrix.png')

# Normalized confusion matrix: each row divided by its total; a class with no
# test samples keeps an all-zero row instead of producing NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'), row_totals,
    out=np.zeros(cm.shape, dtype=float), where=row_totals != 0
)
_save_confusion_heatmap(cm_normalized, '.2%', 'Percentage',
                        'Normalized Confusion Matrix - Test Set',
                        'confusion_matrix_normalized.png')


## 15. Single Image Prediction Function


In [None]:
def predict_single_image(model, image_path, transform, device, class_names):
    """
    Predict the class of a single OCT image

    Args:
        model: Trained model (switched to evaluation mode by this call)
        image_path: Path to the image
        transform: Image transformation pipeline applied to the RGB image
        device: Device to run inference on
        class_names: List of class names, ordered by class index

    Returns:
        predicted_class: Predicted class name
        probabilities: Dictionary of class probabilities (class name -> float)
        original_image: The RGB image as loaded, before any transform

    Raises:
        ValueError: If the image cannot be read, or if ``class_names`` does not
            have one entry per model output.
    """
    # Load and preprocess image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Keep an untouched copy for display, independent of what the transform does
    original_image = image_rgb.copy()

    # Apply transforms
    image_tensor = transform(image_rgb).unsqueeze(0)  # Add batch dimension
    image_tensor = image_tensor.to(device)

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_idx = torch.argmax(probabilities, dim=1).item()

    # A mismatch would otherwise give an IndexError or silently drop classes
    if probabilities.shape[1] != len(class_names):
        raise ValueError(
            f"Model produced {probabilities.shape[1]} outputs but "
            f"{len(class_names)} class names were given"
        )

    # Get class name and probabilities
    predicted_class = class_names[predicted_idx]
    probs_dict = dict(zip(class_names, probabilities[0].tolist()))

    return predicted_class, probs_dict, original_image


def visualize_prediction(image, predicted_class, probabilities):
    """
    Show the image next to a horizontal bar chart of the class probabilities,
    then print the same probabilities from highest to lowest.

    Args:
        image: Image to display
        predicted_class: Name of the predicted class (its bar is drawn green)
        probabilities: Dictionary mapping class name to probability
    """
    fig, (image_ax, bar_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: the image, titled with the prediction
    image_ax.imshow(image)
    image_ax.set_title(f'Predicted: {predicted_class}', fontsize=14, fontweight='bold')
    image_ax.axis('off')

    # Right panel: one bar per class, in the dictionary's order
    class_labels = list(probabilities)
    class_probs = [probabilities[name] for name in class_labels]
    bar_colors = [
        'green' if name == predicted_class else 'blue'
        for name in class_labels
    ]

    bar_ax.barh(class_labels, class_probs, color=bar_colors, alpha=0.7)
    bar_ax.set_xlabel('Probability', fontsize=12)
    bar_ax.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    bar_ax.set_xlim([0, 1])

    # Write each percentage just to the right of its bar
    for row, prob in enumerate(class_probs):
        bar_ax.text(prob + 0.02, row, f'{prob:.2%}',
                    va='center', fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Text report, highest probability first
    divider = "="*50
    print("\n" + divider)
    print(f"Predicted Class: {predicted_class}")
    print(divider)
    print("\nClass Probabilities:")
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    for name, prob in ranked:
        print(f"  {name:10s}: {prob:6.2%}")
    print(divider)

# Confirmation message printed once the prediction helpers in this cell are defined
print("Prediction functions defined!")


In [None]:
# Example: Predict on a random test image
# Pick one index at random and use it for both the path and its true label
random_idx = np.random.randint(0, len(test_paths))
test_image_path = test_paths[random_idx]
true_label = IDX_TO_CLASS[test_labels[random_idx]]

print(f"Test Image: {os.path.basename(test_image_path)}")
print(f"True Label: {true_label}")
print("\nPredicting...")

# Use the validation/test transform (no random augmentation) for inference
predicted_class, probabilities, image = predict_single_image(
    model, test_image_path, val_test_transform, CONFIG['device'], CLASS_NAMES
)

visualize_prediction(image, predicted_class, probabilities)


## 17. Predict on Your Own Image


In [None]:
# TO PREDICT ON YOUR OWN IMAGE:
# 1. Upload your image to Google Colab or provide the path
# 2. Update the image_path variable below
# 3. Run this cell
# (The commented code below calls predict_single_image and visualize_prediction,
#  so the cell defining them must have been run first.)

# Example:
# from google.colab import files
# uploaded = files.upload()  # This will prompt you to upload a file
# image_path = list(uploaded.keys())[0]  # Get the uploaded filename

# OR provide direct path:
# image_path = '/content/drive/MyDrive/oct_major_project/my_test_image.jpg'

# Uncomment and modify the lines below:
# predicted_class, probabilities, image = predict_single_image(
#     model, image_path, val_test_transform, CONFIG['device'], CLASS_NAMES
# )
# visualize_prediction(image, predicted_class, probabilities)

print("Ready to predict on custom images!")
print("Uncomment the code above and provide your image path.")


## 18. Visualize Test Set Predictions (Sample)


In [None]:
def visualize_test_predictions(model, test_paths, test_labels, num_samples=12):
    """Visualize predictions on random test samples.

    Each panel shows one image titled with its true class, the predicted class
    and the predicted class's probability (green if correct, red if wrong).
    The figure is also saved as 'test_predictions_sample.png' in MODEL_SAVE_PATH.

    Args:
        model: Trained model used for the predictions.
        test_paths: Sequence of image paths.
        test_labels: Sequence of integer labels aligned with ``test_paths``.
        num_samples: Maximum number of images to show. Clamped to the number of
            test images; the grid grows in rows of four to fit.
    """
    sample_count = min(num_samples, len(test_paths))
    if sample_count <= 0:
        print("No test samples to visualize.")
        return

    # Four columns, as many rows as needed (ceil division); 12 samples -> 3x4, 20x15 in
    n_cols = 4
    n_rows = -(-sample_count // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    axes = axes.ravel()

    # Get random samples (without replacement, hence the clamp above)
    indices = np.random.choice(len(test_paths), sample_count, replace=False)

    for ax, idx in zip(axes, indices):
        img_path = test_paths[idx]
        true_label = IDX_TO_CLASS[test_labels[idx]]

        # Predict
        predicted_class, probs, image = predict_single_image(
            model, img_path, val_test_transform, CONFIG['device'], CLASS_NAMES
        )

        # Display
        ax.imshow(image)

        # Color code: green if correct, red if wrong
        color = 'green' if predicted_class == true_label else 'red'
        title = f'True: {true_label}\nPred: {predicted_class}\n({probs[predicted_class]:.1%})'
        ax.set_title(title, fontsize=10, fontweight='bold', color=color)
        ax.axis('off')

    # Hide the unused cells of a partially filled last row
    for ax in axes[sample_count:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'test_predictions_sample.png'),
                dpi=300, bbox_inches='tight')
    plt.show()

# Show (and save) predictions for a random sample of the test set
print("Visualizing test predictions...\n")
visualize_test_predictions(model, test_paths, test_labels)


## 19. Save Model Summary and Configuration


In [None]:
# Save model summary and results
summary = {
    'model_name': CONFIG['model_name'],
    'num_classes': CONFIG['num_classes'],
    'class_names': CLASS_NAMES,
    'class_to_idx': CLASS_TO_IDX,
    'img_size': CONFIG['img_size'],
    'training_samples': len(train_paths),
    'validation_samples': len(val_paths),
    'test_samples': len(test_paths),
    'num_epochs': CONFIG['num_epochs'],
    'batch_size': CONFIG['batch_size'],
    'learning_rate': CONFIG['learning_rate'],
    'best_val_accuracy': float(best_val_acc),
    'test_accuracy': float(test_accuracy * 100),
    'total_parameters': sum(p.numel() for p in model.parameters()),
    'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)
}

summary_path = os.path.join(MODEL_SAVE_PATH, 'model_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=4)

print("Model Summary:")
print("="*60)
for key, value in summary.items():
    print(f"{key:25s}: {value}")
print("="*60)
print(f"\nSummary saved to: {summary_path}")


## 20. Export Model for Production (Optional)


In [None]:
# Export a deployment bundle: the model weights plus the metadata needed to
# rebuild the model and decode its outputs (no optimizer state or history)
deployment_model_path = os.path.join(MODEL_SAVE_PATH, 'oct_classifier_deployment.pth')

deployment_bundle = {
    'model_state_dict': model.state_dict(),
    'class_to_idx': CLASS_TO_IDX,
    'idx_to_class': IDX_TO_CLASS,
}
deployment_bundle.update(
    {key: CONFIG[key] for key in ('model_name', 'img_size', 'num_classes')}
)
torch.save(deployment_bundle, deployment_model_path)

print(f"Deployment model saved to: {deployment_model_path}")
bytes_per_mb = 1024*1024
size_mb = os.path.getsize(deployment_model_path) / bytes_per_mb
print(f"File size: {size_mb:.2f} MB")


## 21. How to Load and Use the Saved Model


In [None]:
# Example code to load and use the saved model in a new session.
# The string below is documentation only (it is never executed). It assumes
# the cells defining create_model() and predict_single_image() have been run.

"""
# Load the model
import torch
import torchvision.models as models
import torchvision.transforms as transforms

# Pick the inference device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load checkpoint (map_location lets a GPU-trained checkpoint load on CPU)
checkpoint = torch.load('path/to/best_oct_classifier.pth', map_location=device)

# Create model
model = create_model(
    model_name=checkpoint['config']['model_name'],
    num_classes=checkpoint['config']['num_classes'],
    pretrained=False
)

# Load weights
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

# Rebuild the class list in index order from the mapping stored in the checkpoint
class_to_idx = checkpoint['class_to_idx']
class_names = sorted(class_to_idx, key=class_to_idx.get)

# Define transforms (same preprocessing as validation/test)
img_size = checkpoint['config']['img_size']
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict on new image
predicted_class, probabilities, image = predict_single_image(
    model, 'path/to/image.jpg', transform, device, class_names
)
"""

print("See the code above for how to load and use the model in a new session.")


## Summary

### What This Notebook Does:
1. ✅ Loads OCT images from multiple classes (Normal, DME, Drusen, CNV)
2. ✅ Splits data into train/val/test sets
3. ✅ Applies data augmentation for better generalization
4. ✅ Uses transfer learning (ResNet50, ResNet18, VGG16, or EfficientNet-B0)
5. ✅ Trains the model with automatic checkpointing
6. ✅ Saves the best model based on validation accuracy
7. ✅ Evaluates on test set with detailed metrics
8. ✅ Provides single image prediction functionality
9. ✅ Generates visualizations and reports

### Key Files Generated:
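- `best_oct_classifier.pth` - Best model checkpoint
- `checkpoint_epoch_<N>.pth` - Periodic checkpoint saved every 10 epochs
- `oct_classifier_deployment.pth` - Lightweight deployment model
- `model_summary.json` - Model configuration and results
- `training_history.png` - Training curves
- `confusion_matrix.png` - Confusion matrix visualization
- `confusion_matrix_normalized.png` - Row-normalized confusion matrix
- `test_predictions_sample.png` - Sample test-set predictions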

### Next Steps:
1. Use the trained model for predictions
2. Integrate into a web application
3. Deploy for real-world use

---
**Good luck with your project! 🚀**
