In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import json
import os
import textwrap
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms

# Project imports
from src.models import DRModel
from src.utils import BenGrahamPreprocessor
from src.xai import GradCAM, IntegratedGradients

# Environment report: library version and whether a GPU is visible.
for _label, _value in (("PyTorch version", torch.__version__),
                       ("CUDA available", torch.cuda.is_available())):
    print(f"{_label}: {_value}")

## 1. Load Model & Configuration

In [None]:
# Configuration
CONFIG = {
    'model': {
        'backbone': 'efficientnet_b5',
        'num_classes': 5,
        'head_type': 'regression',
        'pretrained': False,  # We'll load weights
        'dropout': 0.5,
        'pooling': 'gem'
    },
    'input_size': 456,
    'checkpoint_path': '../checkpoints/last.ckpt',
    'thresholds_path': '../checkpoints/thresholds.json'
}

# Class names for ICDR scale
CLASS_NAMES = {
    0: 'No DR',
    1: 'Mild NPDR',
    2: 'Moderate NPDR', 
    3: 'Severe NPDR',
    4: 'Proliferative DR'
}

# Clinical descriptions
CLINICAL_DESCRIPTIONS = {
    0: 'No visible signs of diabetic retinopathy. Continue annual screening.',
    1: 'Mild nonproliferative DR. Microaneurysms only. Annual follow-up recommended.',
    2: 'Moderate nonproliferative DR. More than microaneurysms. Follow-up in 6 months.',
    3: 'Severe nonproliferative DR. Significant hemorrhages. Refer to ophthalmologist.',
    4: 'Proliferative DR. Neovascularization present. Urgent referral required.'
}

SEVERITY_COLORS = {
    0: '#2ECC71',  # Green
    1: '#F1C40F',  # Yellow
    2: '#E67E22',  # Orange
    3: '#E74C3C',  # Red
    4: '#9B59B6'   # Purple
}

In [None]:
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model (architecture must match the checkpoint being loaded)
model = DRModel(
    backbone=CONFIG['model']['backbone'],
    num_classes=CONFIG['model']['num_classes'],
    head_type=CONFIG['model']['head_type'],
    pretrained=CONFIG['model']['pretrained'],
    dropout=CONFIG['model']['dropout'],
    pooling=CONFIG['model']['pooling']
)

# Load weights if available
if os.path.exists(CONFIG['checkpoint_path']):
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may
    # reject checkpoints that pickle extra objects; only relax that for trusted files.
    checkpoint = torch.load(CONFIG['checkpoint_path'], map_location=device)
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Strip a leading 'model.' (presumably added by a training wrapper). Only the
    # prefix is removed: str.replace would also delete 'model.' in the middle of a key.
    prefix = 'model.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    load_report = model.load_state_dict(state_dict, strict=False)
    print("✓ Loaded trained weights")
    # strict=False skips mismatched keys silently; report them so a wrong or
    # partially matching checkpoint does not quietly leave layers randomly initialised.
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f"⚠ {len(load_report.missing_keys)} missing / "
              f"{len(load_report.unexpected_keys)} unexpected keys when loading checkpoint")
else:
    print("⚠ No checkpoint found, using random weights for demo")

model = model.to(device)
model.eval()

# Load thresholds
if os.path.exists(CONFIG['thresholds_path']):
    with open(CONFIG['thresholds_path'], 'r') as f:
        thresholds_data = json.load(f)
    THRESHOLDS = thresholds_data['thresholds']
    print(f"✓ Loaded optimized thresholds: {THRESHOLDS}")
else:
    THRESHOLDS = [0.5, 1.5, 2.5, 3.5]
    print(f"Using default thresholds: {THRESHOLDS}")

# apply_thresholds() returns one past the index of the last threshold exceeded,
# which is only the intended ordinal grade when the cut-points are ascending.
if list(THRESHOLDS) != sorted(THRESHOLDS):
    print(f"⚠ Thresholds were not ascending; sorting them: {sorted(THRESHOLDS)}")
    THRESHOLDS = sorted(THRESHOLDS)

In [None]:
# Preprocessing pipeline: the project's Ben Graham preprocessor, sized to the model input.
preprocessor = BenGrahamPreprocessor(output_size=CONFIG['input_size'])

# Model-input transform: image -> float tensor, then per-channel normalisation
# with the standard ImageNet mean/std.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

## 2. Inference Functions

In [None]:
def preprocess_image(image_path, preprocessor, transform, size):
    """Load an image from disk and prepare it for the model.

    Args:
        image_path: path (str or Path) of the image file.
        preprocessor: object with a ``process(rgb_image)`` method (Ben Graham preprocessing).
        transform: callable turning the preprocessed image into a model-input tensor.
        size: side length in pixels that the preprocessed image is resized to.

    Returns:
        (processed, tensor): the resized, preprocessed RGB image and the
        result of ``transform`` applied to it.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(str(image_path))
    # cv2.imread does not raise for a missing/unreadable file; it returns None,
    # which would otherwise surface later as an opaque cvtColor error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Apply Ben Graham preprocessing
    processed = preprocessor.process(image)

    # Resize to model input size
    processed = cv2.resize(processed, (size, size))

    # Apply transforms
    tensor = transform(processed)

    return processed, tensor


def apply_thresholds(prediction, thresholds):
    """Map a continuous regression score to an integer grade.

    The grade is one more than the index of the last threshold the score
    strictly exceeds, or 0 when it exceeds none. With ascending thresholds
    this is simply the number of thresholds below the score.
    """
    return max(
        (idx + 1 for idx, cut in enumerate(thresholds) if prediction > cut),
        default=0,
    )


def predict_single(image_path, model, preprocessor, transform, thresholds, device):
    """Grade one image with a single forward pass (no test-time augmentation).

    Returns a dict with the raw regression score, the thresholded grade, its
    display name and clinical description, and the preprocessed image.
    """
    processed_image, image_tensor = preprocess_image(
        image_path, preprocessor, transform, CONFIG['input_size']
    )
    # The model expects a leading batch axis.
    batch = image_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        score = model(batch).squeeze().cpu().item()

    grade = apply_thresholds(score, thresholds)
    return {
        'raw_score': score,
        'class': grade,
        'class_name': CLASS_NAMES[grade],
        'description': CLINICAL_DESCRIPTIONS[grade],
        'processed_image': processed_image,
    }

In [None]:
def predict_with_tta(image_path, model, preprocessor, transform, thresholds, device, n_tta=8):
    """Predict the DR grade with test-time augmentation (TTA).

    The image is scored under up to 8 distinct flip/rotation views (the
    symmetries of a square image). The regression outputs are averaged before
    thresholding.

    Args:
        image_path: path of the image to grade.
        model: regression model returning one score per image.
        preprocessor, transform: passed through to ``preprocess_image``.
        thresholds: ascending cut-points used by ``apply_thresholds``.
        device: torch device the batch is moved to.
        n_tta: number of views to use (1-8); the unmodified image is always first.

    Returns:
        dict with the mean score ('raw_score'), its std, the per-view scores
        (1-D array), the grade, its name/description, the preprocessed image and
        a rough 'confidence' heuristic in [0, 1].

    Raises:
        ValueError: if ``n_tta`` < 1.
    """
    if n_tta < 1:
        raise ValueError(f"n_tta must be >= 1, got {n_tta}")

    # Preprocess
    processed_image, base_tensor = preprocess_image(
        image_path, preprocessor, transform, CONFIG['input_size']
    )

    # The 8 distinct views of a square (C, H, W) tensor; dims 1/2 are H/W.
    # Flipping both axes already equals a 180° rotation, so only 90° and 270°
    # rotations are added; the two flipped+rotated views (the diagonal
    # reflections) complete the set without duplicates.
    hflip = torch.flip(base_tensor, dims=[2])
    tta_tensors = [
        base_tensor,                                # original
        hflip,                                      # horizontal flip
        torch.flip(base_tensor, dims=[1]),          # vertical flip
        torch.flip(base_tensor, dims=[1, 2]),       # both flips (= 180° rotation)
        torch.rot90(base_tensor, 1, dims=[1, 2]),   # 90° rotation
        torch.rot90(base_tensor, 3, dims=[1, 2]),   # 270° rotation
        torch.rot90(hflip, 1, dims=[1, 2]),         # flip + 90° (diagonal reflection)
        torch.rot90(hflip, 3, dims=[1, 2]),         # flip + 270° (anti-diagonal reflection)
    ][:n_tta]

    # Stack and predict
    batch = torch.stack(tta_tensors).to(device)

    with torch.no_grad():
        # flatten() rather than squeeze(): a single view must still give a 1-D array.
        outputs = model(batch).flatten().cpu().numpy()

    # Average predictions
    avg_output = outputs.mean()
    std_output = outputs.std()

    # Apply thresholds
    pred_class = apply_thresholds(avg_output, thresholds)

    return {
        'raw_score': avg_output,
        'std': std_output,
        'individual_scores': outputs,
        'class': pred_class,
        'class_name': CLASS_NAMES[pred_class],
        'description': CLINICAL_DESCRIPTIONS[pred_class],
        'processed_image': processed_image,
        # Rough heuristic only: shrinks as the views disagree, clipped to [0, 1].
        'confidence': float(np.clip(1 - std_output / 2, 0.0, 1.0)),
    }

## 3. Single Image Inference Demo

In [None]:
# Demo input: point this at a real fundus photograph to run on actual data.
sample_image_path = Path('../data/aptos/processed/sample_image.png')

# Without real data, draw a simple synthetic fundus-like picture so the rest
# of the notebook can still run end to end.
if not sample_image_path.exists():
    print("No sample image found. Creating synthetic demo...")
    demo_image = np.zeros((512, 512, 3), dtype=np.uint8)
    # Two concentric filled discs (colours given in RGB order).
    for radius, colour in ((200, (100, 50, 50)), (180, (150, 80, 60))):
        cv2.circle(demo_image, (256, 256), radius, colour, -1)
    # A bright spot standing in for the optic disc, and one "vessel" line.
    cv2.circle(demo_image, (200, 200), 20, (200, 100, 100), -1)
    cv2.line(demo_image, (150, 256), (350, 256), (120, 60, 60), 3)

    demo_image_path = Path('../data/demo_image.png')
    demo_image_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite expects BGR, so convert from the RGB colours used above.
    cv2.imwrite(str(demo_image_path), cv2.cvtColor(demo_image, cv2.COLOR_RGB2BGR))
    sample_image_path = demo_image_path
    print(f"Created demo image at {sample_image_path}")

In [None]:
# Run inference: score the sample image with TTA and print a text summary.
if sample_image_path.exists():
    result = predict_with_tta(
        sample_image_path, model, preprocessor, transform, THRESHOLDS, device
    )

    rule = "=" * 50
    summary_lines = [
        rule,
        "PREDICTION RESULT",
        rule,
        f"Raw Score: {result['raw_score']:.3f} (std: {result['std']:.3f})",
        f"Predicted Class: {result['class']} - {result['class_name']}",
        f"Confidence: {result['confidence']:.1%}",
        f"\nClinical Note: {result['description']}",
        rule,
    ]
    print("\n".join(summary_lines))

In [None]:
# Visualize result: original image, preprocessed model input, and where the
# regression score falls on the grade scale.
if sample_image_path.exists():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original (OpenCV loads BGR and returns None instead of raising on failure)
    original = cv2.imread(str(sample_image_path))
    if original is None:
        raise FileNotFoundError(f"Could not read image: {sample_image_path}")
    original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
    axes[0].imshow(original)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Preprocessed
    axes[1].imshow(result['processed_image'])
    axes[1].set_title('Ben Graham Preprocessed')
    axes[1].axis('off')

    # Prediction visualization. Bar heights are only a visual highlight
    # (1.0 for the predicted grade, 0.1 otherwise); they are NOT probabilities.
    grades = list(range(5))
    colors = [SEVERITY_COLORS[g] for g in grades]
    bars = axes[2].bar(grades, [1 if g == result['class'] else 0.1 for g in grades], color=colors)

    # Highlight predicted class
    bars[result['class']].set_edgecolor('black')
    bars[result['class']].set_linewidth(3)

    # The regression score is on the same scale as the grade positions on the
    # x axis, so mark it there as a vertical line (with the decision thresholds
    # for reference) instead of rescaling it onto the unrelated bar-height axis.
    for t in THRESHOLDS:
        axes[2].axvline(x=t, color='gray', linestyle=':', linewidth=1)
    axes[2].axvline(x=result['raw_score'], color='red', linestyle='--', linewidth=2,
                    label=f'Score: {result["raw_score"]:.2f}')
    axes[2].set_xticks(grades)
    axes[2].set_xticklabels(['No DR', 'Mild', 'Mod', 'Severe', 'PDR'])
    axes[2].set_yticks([])
    axes[2].set_ylabel('Prediction')
    axes[2].set_title(f'Prediction: {result["class_name"]}')
    axes[2].legend()

    plt.tight_layout()
    plt.show()

## 4. Grad-CAM Visualization

In [None]:
# Initialize Grad-CAM
# For EfficientNet, we use the last convolutional layer
# NOTE(review): assumes `model.backbone` exposes an indexable `features`
# container (torchvision-style EfficientNet). timm EfficientNets name their
# layers differently (e.g. `blocks` / `conv_head`), so confirm against DRModel.
target_layer = model.backbone.features[-1]  # Last feature layer
gradcam = GradCAM(model, target_layer)

In [None]:
def visualize_gradcam(image_path, model, gradcam, preprocessor, transform, device):
    """Compute a Grad-CAM heatmap for one image and blend it over the image.

    Args:
        image_path: path of the image to explain.
        model: the model being explained (kept for interface compatibility;
            ``gradcam`` already wraps it).
        gradcam: callable returning ``(heatmap, prediction)`` for a batched input tensor.
        preprocessor, transform: passed through to ``preprocess_image``.
        device: torch device the input tensor is moved to.

    Returns:
        (processed, heatmap, superimposed, pred): the preprocessed image, the
        2-D heatmap clipped to [0, 1] and resized to the image, the RGB overlay,
        and the raw prediction returned by ``gradcam``.
    """
    # Preprocess
    processed, tensor = preprocess_image(
        image_path, preprocessor, transform, CONFIG['input_size']
    )
    tensor = tensor.unsqueeze(0).to(device)

    # Generate Grad-CAM
    heatmap, pred = gradcam(tensor)
    # detach(): the CAM may still be attached to the autograd graph, and .numpy()
    # refuses tensors that require grad. squeeze() drops singleton batch/channel
    # axes so OpenCV receives a 2-D map.
    heatmap = np.squeeze(heatmap.detach().cpu().numpy()).astype(np.float32)
    # Clamp to [0, 1] (NaN -> 0) before the uint8 conversion below, where
    # out-of-range values would wrap around instead of saturating.
    heatmap = np.clip(np.nan_to_num(heatmap, nan=0.0), 0.0, 1.0)

    # Resize heatmap to image size
    heatmap = cv2.resize(heatmap, (processed.shape[1], processed.shape[0]))

    # Create colored heatmap (OpenCV colour maps are BGR)
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Superimpose on original
    # NOTE(review): addWeighted needs `processed` to match the uint8 colour map in
    # dtype and shape; assumes the preprocessor returns uint8 RGB — confirm.
    superimposed = cv2.addWeighted(processed, 0.6, heatmap_colored, 0.4, 0)

    return processed, heatmap, superimposed, pred

In [None]:
if sample_image_path.exists():
    # Explain the sample image and map the raw score to a grade.
    processed, heatmap, superimposed, pred = visualize_gradcam(
        sample_image_path, model, gradcam, preprocessor, transform, device
    )
    score = pred.item()
    pred_class = apply_thresholds(score, THRESHOLDS)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # Heatmap panel gets its own colour scale.
    im = axes[1].imshow(heatmap, cmap='jet')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    # Keep pixels with attention above 0.5; dim the rest to 30% brightness.
    mask = heatmap > 0.5
    highlighted = processed.copy()
    highlighted[~mask] = highlighted[~mask] * 0.3

    axes[0].imshow(processed)
    axes[2].imshow(superimposed)
    axes[3].imshow(highlighted)

    panel_titles = ['Preprocessed Image', 'Grad-CAM Heatmap', 'Overlay', 'High Attention Regions']
    for ax, title in zip(axes, panel_titles):
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(f'Grad-CAM Analysis - Prediction: {CLASS_NAMES[pred_class]} (Score: {score:.3f})', fontsize=14)
    plt.tight_layout()
    plt.show()

## 5. Batch Inference

In [None]:
def batch_predict(image_paths, model, preprocessor, transform, thresholds, device, batch_size=16):
    """Run thresholded regression inference over many images in mini-batches.

    Args:
        image_paths: iterable of image paths (str or Path).
        model: regression model returning one score per image.
        preprocessor, transform: passed through to ``preprocess_image``.
        thresholds: cut-points used by ``apply_thresholds``.
        device: torch device each batch is moved to.
        batch_size: number of images per forward pass (>= 1).

    Returns:
        DataFrame with columns image, raw_score, predicted_class, class_name:
        one row per input, in input order. When no paths are given, it is empty
        but still has these columns.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    columns = ['image', 'raw_score', 'predicted_class', 'class_name']
    paths = list(image_paths)  # allow generators such as Path.glob()
    results = []

    for start in range(0, len(paths), batch_size):
        batch_paths = paths[start:start + batch_size]
        batch = torch.stack([
            preprocess_image(path, preprocessor, transform, CONFIG['input_size'])[1]
            for path in batch_paths
        ]).to(device)

        with torch.no_grad():
            # reshape(-1) keeps a 1-D array even for a single-image batch.
            outputs = model(batch).reshape(-1).cpu().numpy()

        for path, output in zip(batch_paths, outputs):
            score = float(output)
            pred_class = apply_thresholds(score, thresholds)
            results.append({
                'image': Path(path).name,
                'raw_score': score,
                'predicted_class': pred_class,
                'class_name': CLASS_NAMES[pred_class],
            })

    return pd.DataFrame(results, columns=columns)

In [None]:
# Example: batch inference on the test set (first N images in sorted filename order)
test_images_dir = Path('../data/aptos/test_images')
MAX_BATCH_IMAGES = 50

if test_images_dir.exists():
    # Path.glob() yields files in filesystem order, which is not guaranteed to be
    # stable; sort so the "first 50" subset is reproducible across runs/machines.
    image_paths = sorted(test_images_dir.glob('*.png'))[:MAX_BATCH_IMAGES]

    if image_paths:
        print(f"Running inference on {len(image_paths)} images...")
        results_df = batch_predict(
            image_paths, model, preprocessor, transform, THRESHOLDS, device
        )

        print("\nResults:")
        print(results_df.head(10))

        # Distribution
        print("\nPrediction Distribution:")
        print(results_df['class_name'].value_counts())
    else:
        print(f"No .png images found in {test_images_dir}. Skipping batch inference demo.")
else:
    print("Test images directory not found. Skipping batch inference demo.")

## 6. Clinical Report Generation

In [None]:
def generate_clinical_report(image_path, result, save_path=None, width=66):
    """Render a fixed-width, box-drawn text report for one DR screening result.

    Args:
        image_path: path (str or Path) of the screened image; only its file name is shown.
        result: prediction dict with 'class', 'class_name', 'description' and
            'raw_score' keys (as returned by predict_single / predict_with_tta).
            An optional 'confidence' is shown as a percentage, or 'N/A' if absent.
        save_path: if given, the report is also written there as UTF-8 text.
        width: number of characters between the left and right box borders.

    Returns:
        The report string (with a leading and a trailing newline).

    Raises:
        KeyError: if ``result['class']`` is not a grade 0-4.
    """
    recommendations = {
        0: ['Continue annual diabetic eye examinations',
            'Maintain blood sugar control',
            'Regular blood pressure monitoring'],
        1: ['Follow-up examination in 12 months',
            'Optimize glycemic control',
            'Consider referral if progression'],
        2: ['Follow-up examination in 6 months',
            'Ophthalmology consultation recommended',
            'Intensive glycemic management'],
        3: ['Urgent ophthalmology referral',
            'Consider panretinal photocoagulation',
            'Close blood sugar monitoring'],
        4: ['IMMEDIATE ophthalmology referral',
            'Treatment required within 2 weeks',
            'High risk of vision loss']
    }

    # Usable text columns: 2 spaces of left padding + 1 space before the right border.
    text_width = width - 3

    def row(text=''):
        # Pad (or truncate) to the inner width so every right border lines up.
        return f"║  {text[:text_width]:<{text_width}} ║"

    def wrapped_rows(text, indent=''):
        # Wrap on word boundaries rather than slicing mid-word or dropping text.
        return [row(chunk) for chunk in textwrap.wrap(text, text_width, subsequent_indent=indent)]

    separator = '╠' + '═' * width + '╣'
    confidence = result.get('confidence')
    # A missing confidence (e.g. predict_single results) is unknown, not 0%.
    confidence_text = 'N/A' if confidence is None else f"{confidence:.1%}"

    lines = [
        '╔' + '═' * width + '╗',
        '║' + 'DIABETIC RETINOPATHY SCREENING REPORT'.center(width) + '║',
        separator,
        row(f"Image: {Path(image_path).name}"),
        separator,
        row(),
        row(f"DIAGNOSIS: {result['class_name']}"),
        row(f"Grade: {result['class']}/4"),
        row(f"Confidence: {confidence_text}"),
        row(),
        separator,
        row('CLINICAL NOTES:'),
        *wrapped_rows(result['description']),
        row(),
        separator,
        row('RECOMMENDATIONS:'),
    ]
    for rec in recommendations[result['class']]:
        lines.extend(wrapped_rows(f"• {rec}", indent='  '))
    lines += [
        row(),
        separator,
        row(f"AI Analysis Score: {result['raw_score']:.3f}"),
        row('Analysis performed by: DR Detection Model v1.0'),
        row(),
        row('⚠️  This is an AI-assisted screening tool. Final diagnosis'),
        row('   should be made by a qualified ophthalmologist.'),
        '╚' + '═' * width + '╝',
    ]
    report = '\n' + '\n'.join(lines) + '\n'

    if save_path:
        # Box-drawing characters are not encodable in some platform default
        # encodings (e.g. cp1252), so write UTF-8 explicitly.
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(report)

    return report

In [None]:
# Clinical report for the sample image. Inference is re-run here so this cell
# does not depend on `result` from an earlier cell.
if sample_image_path.exists():
    result = predict_with_tta(
        sample_image_path, model, preprocessor, transform,
        thresholds=THRESHOLDS, device=device,
    )
    report = generate_clinical_report(image_path=sample_image_path, result=result)
    print(report)

## 7. Performance Visualization

In [None]:
# Create severity scale visualization. Each colour band spans exactly the score
# range that apply_thresholds maps to that grade, so band edges coincide with
# the dashed threshold lines.
fig, ax = plt.subplots(figsize=(12, 2))

ordered_thresholds = sorted(THRESHOLDS)
scale_min = min(-0.5, ordered_thresholds[0] - 0.5)
scale_max = max(4.5, ordered_thresholds[-1] + 0.5)
scale_edges = [scale_min, *ordered_thresholds, scale_max]

# Draw one band per grade between consecutive edges
for grade, (left, right) in enumerate(zip(scale_edges[:-1], scale_edges[1:])):
    ax.axvspan(left, right, facecolor=SEVERITY_COLORS[grade], alpha=0.8)
    # Build the multi-line name outside the f-string: a backslash inside an
    # f-string expression is a SyntaxError before Python 3.12.
    name_label = CLASS_NAMES[grade].replace(' ', '\n')
    ax.text((left + right) / 2, 0.5, f"Grade {grade}\n{name_label}",
            ha='center', va='center', fontsize=10, fontweight='bold')

# Add thresholds
for t in THRESHOLDS:
    ax.axvline(x=t, color='black', linestyle='--', linewidth=2)
    ax.text(t, 1.1, f'{t:.2f}', ha='center', fontsize=8)

ax.set_xlim(scale_min, scale_max)
ax.set_ylim(0, 1)
ax.set_xticks(range(5))
ax.set_yticks([])
ax.set_title('ICDR Severity Scale with Optimized Thresholds', fontsize=12)
ax.set_xlabel('Regression Score')

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:

1. ✅ Loading trained DR detection model
2. ✅ Single image inference with preprocessing
3. ✅ Test-Time Augmentation (TTA) for more stable predictions
4. ✅ Grad-CAM visualization for explainability
5. ✅ Batch inference for multiple images
6. ✅ Clinical report generation

### Key Points
- The model uses regression output with optimized thresholds
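- TTA averages the scores of flipped and rotated copies of the image, which usually makes predictions more stable
- Grad-CAM highlights the image regions that most influenced the prediction (a coarse, approximate explanation)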
- Always verify AI predictions with clinical expertise