In [3]:
!pip install ultralytics segmentation-models-pytorch albumentations --quiet


import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from ultralytics import YOLO
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.8/16.8 MB[0m [31m94.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m91.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m78.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.



In [21]:
# ============================================================================
# SECTION 1: CONFIGURATION
# ============================================================================

class Config:
    """Configuration for inference pipeline"""
    
    # Model paths - UPDATE THESE TO YOUR UPLOADED MODELS
    YOLO_MODEL_PATH = '/kaggle/input/glaucomamodels/tensorflow2/default/1/yolov8l_od_oc_best.pt'
    RESNET_MODEL_PATH = '/kaggle/input/glaucomamodels/tensorflow2/default/1/best_resnet_model.pth'
    UNET_VESSEL_MODEL_PATH = '/kaggle/input/glaucomamodels/tensorflow2/default/1/unet_vessel_best.pth'
    
    # Test images directory
    TEST_IMAGES_DIR = '/kaggle/input/testimage/image1734.png'
    
    # Inference parameters
    IMAGE_SIZE = 512
    YOLO_CONF = 0.25
    YOLO_IOU = 0.45
    
    # Decision fusion weights (adjusted without U-Net OD/OC)
    WEIGHT_RESNET = 0.60    # Increased from 0.35
    WEIGHT_YOLO = 0.20      # Increased from 0.15 (quality check)
    WEIGHT_VESSEL = 0.20    # Increased from 0.10
    
    # Classification thresholds
    THRESHOLD_NORMAL = 0.35
    THRESHOLD_SUSPICIOUS = 0.65

config = Config()

print("="*60)
print("GLAUCOMA DETECTION INFERENCE PIPELINE - BOOK 5")
print("="*60)
print("\nConfiguration:")
print(f"  YOLO model: {config.YOLO_MODEL_PATH}")
print(f"  ResNet model: {config.RESNET_MODEL_PATH}")
print(f"  U-Net Vessel model: {config.UNET_VESSEL_MODEL_PATH}")
print(f"\nDecision weights:")
print(f"  ResNet: {config.WEIGHT_RESNET}")
print(f"  YOLO: {config.WEIGHT_YOLO}")
print(f"  Vessels: {config.WEIGHT_VESSEL}")
print("="*60)



GLAUCOMA DETECTION INFERENCE PIPELINE - BOOK 5

Configuration:
  YOLO model: /kaggle/input/glaucomamodels/tensorflow2/default/1/yolov8l_od_oc_best.pt
  ResNet model: /kaggle/input/glaucomamodels/tensorflow2/default/1/best_resnet_model.pth
  U-Net Vessel model: /kaggle/input/glaucomamodels/tensorflow2/default/1/unet_vessel_best.pth

Decision weights:
  ResNet: 0.6
  YOLO: 0.2
  Vessels: 0.2


In [22]:
# ============================================================================
# SECTION 2: DEVICE SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")



Device: cuda
GPU: Tesla T4


In [23]:
# ============================================================================
# SECTION 3: PREPROCESSING
# ============================================================================

class FundusPreprocessor:
    """Unified preprocessing pipeline"""
    
    def __init__(self, target_size=512, apply_clahe=True, apply_green_channel=True):
        self.target_size = target_size
        self.apply_clahe = apply_clahe
        self.apply_green_channel = apply_green_channel
    
    def preprocess(self, image):
        """Main preprocessing function"""
        # Resize
        img = cv2.resize(image, (self.target_size, self.target_size), 
                        interpolation=cv2.INTER_AREA)
        
        # Green channel extraction
        if self.apply_green_channel and len(img.shape) == 3:
            green_channel = img[:, :, 1]
        else:
            green_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
        
        # CLAHE
        if self.apply_clahe:
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(green_channel)
        else:
            enhanced = green_channel
        
        # Normalize
        normalized = enhanced.astype(np.float32) / 255.0
        
        # Return 3-channel
        return np.stack([normalized]*3, axis=-1)

preprocessor = FundusPreprocessor(target_size=config.IMAGE_SIZE)


In [24]:
# ============================================================================
# SECTION 4: MODEL LOADERS
# ============================================================================

print("\n" + "="*60)
print("LOADING MODELS")
print("="*60)

def load_yolo_model(model_path):
    """Load YOLO v8 model for OD/OC detection"""
    print(f"Loading YOLO from {model_path}...")
    if not Path(model_path).exists():
        raise FileNotFoundError(f"YOLO model not found: {model_path}")
    model = YOLO(model_path)
    print("✓ YOLO loaded")
    return model

def load_resnet_model(model_path, num_classes=2):
    """Load ResNet50 classifier"""
    print(f"Loading ResNet from {model_path}...")
    if not Path(model_path).exists():
        raise FileNotFoundError(f"ResNet model not found: {model_path}")
    
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print("✓ ResNet loaded")
    return model

def load_unet_vessel_model(model_path):
    """Load U-Net for vessel segmentation"""
    print(f"Loading U-Net Vessels from {model_path}...")
    if not Path(model_path).exists():
        raise FileNotFoundError(f"U-Net Vessel model not found: {model_path}")
    
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation=None
    )

    # ✅ Load checkpoint
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # ✅ If it’s a full checkpoint, extract the actual model weights
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint  # in case it’s already a pure state_dict

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("✓ U-Net Vessels loaded")
    return model


# Load all models
try:
    yolo_model = load_yolo_model(config.YOLO_MODEL_PATH)
    resnet_model = load_resnet_model(config.RESNET_MODEL_PATH)
    unet_vessel_model = load_unet_vessel_model(config.UNET_VESSEL_MODEL_PATH)
    print("\n✅ All models loaded successfully!")
except FileNotFoundError as e:
    print(f"\n❌ Error: {e}")
    print("\nPlease ensure all model files are uploaded to Kaggle:")
    print("  1. yolov8l_od_oc_best.pt")
    print("  2. best_resnet_model.pth")
    print("  3. unet_vessel_best.pth")
    raise



LOADING MODELS
Loading YOLO from /kaggle/input/glaucomamodels/tensorflow2/default/1/yolov8l_od_oc_best.pt...
✓ YOLO loaded
Loading ResNet from /kaggle/input/glaucomamodels/tensorflow2/default/1/best_resnet_model.pth...
✓ ResNet loaded
Loading U-Net Vessels from /kaggle/input/glaucomamodels/tensorflow2/default/1/unet_vessel_best.pth...
✓ U-Net Vessels loaded

✅ All models loaded successfully!


In [26]:
# ============================================================================
# SECTION 5: YOLO INFERENCE
# ============================================================================

def yolo_detect_od_oc(image, model, conf=0.25, iou=0.45):
    """
    Detect OD and OC using YOLO
    Returns: detection info and quality score
    """
    results = model(image, conf=conf, iou=iou, verbose=False)
    
    detections = {
        'od_detected': False,
        'oc_detected': False,
        'od_bbox': None,
        'oc_bbox': None,
        'od_conf': 0.0,
        'oc_conf': 0.0,
        'quality_score': 0.0
    }
    
    if len(results) > 0 and results[0].boxes is not None:
        boxes = results[0].boxes
        
        for box in boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            bbox = box.xyxy[0].cpu().numpy()
            
            if cls == 0:  # Optic Disc
                detections['od_detected'] = True
                detections['od_bbox'] = bbox
                detections['od_conf'] = conf
            elif cls == 1:  # Optic Cup
                detections['oc_detected'] = True
                detections['oc_bbox'] = bbox
                detections['oc_conf'] = conf
    
    # Quality score based on detection confidence
    if detections['od_detected'] and detections['oc_detected']:
        detections['quality_score'] = (detections['od_conf'] + detections['oc_conf']) / 2
    elif detections['od_detected']:
        detections['quality_score'] = detections['od_conf'] *0.5
    else:
        detections['quality_score'] = 0.0
    
    return detections

# ============================================================================
# SECTION 6: RESNET INFERENCE
# ============================================================================

def resnet_classify(image, model):
    """
    Classify image using ResNet
    Returns: probability of glaucoma
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Convert to uint8 if needed
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    
    img_tensor = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        glaucoma_prob = probs[0, 1].item()  # Probability of class 1 (glaucoma)
    
    return glaucoma_prob


In [27]:
# ============================================================================
# SECTION 7: U-NET VESSEL SEGMENTATION
# ============================================================================

def unet_segment_vessels(image, model):
    """
    Segment vessels using U-Net
    Returns: vessel density score
    """
    preprocessed = preprocessor.preprocess(image)
    preprocessed_uint8 = (preprocessed * 255).astype(np.uint8)
    
    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    
    img_tensor = transform(image=preprocessed_uint8)['image'].unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(img_tensor)
        pred_sigmoid = torch.sigmoid(output).squeeze().cpu().numpy()
        vessel_mask = (pred_sigmoid > 0.5).astype(np.uint8)
    
    # Calculate vessel density (percentage of vessel pixels)
    vessel_density = np.mean(vessel_mask)
    
    return vessel_density, vessel_mask

In [28]:
# ============================================================================
# SECTION 8: DECISION FUSION MODULE
# ============================================================================

def fuse_predictions(resnet_prob, yolo_quality, vessel_density):
    """
    Fuse all model outputs into final decision
    
    Inputs:
        resnet_prob: 0-1, probability from ResNet classifier
        yolo_quality: 0-1, YOLO detection quality
        vessel_density: 0-1, vessel coverage percentage
    
    Returns:
        final_score: 0-1, final glaucoma risk score
        decision: 'Normal' / 'Suspicious' / 'Glaucoma'
        confidence: confidence level
        explanation: detailed reasoning
    """
    
    # Convert vessel density to risk (lower density = higher risk in glaucoma)
    # Typical healthy vessel density: 10-15%, glaucoma: 5-8%
    vessel_risk = 1.0 - (vessel_density / 0.15)  # Normalize assuming 15% is normal
    vessel_risk = np.clip(vessel_risk, 0.0, 1.0)
    
    # Weighted fusion (NO U-Net OD/OC component)
    final_score = (
        config.WEIGHT_RESNET * resnet_prob +
        config.WEIGHT_YOLO * (1.0 - yolo_quality) +  # Lower quality = higher risk
        config.WEIGHT_VESSEL * vessel_risk
    )
    
    # Decision thresholds
    if final_score < config.THRESHOLD_NORMAL:
        decision = "Normal"
        confidence = "High" if final_score < 0.25 else "Moderate"
    elif final_score < config.THRESHOLD_SUSPICIOUS:
        decision = "Suspicious"
        confidence = "Moderate"
    else:
        decision = "Glaucoma"
        confidence = "High" if final_score > 0.75 else "Moderate"
    
    # Generate explanation
    explanation = []
    
    if resnet_prob > 0.7:
        explanation.append(f"Deep learning classifier indicates HIGH glaucoma risk ({resnet_prob:.2f})")
    elif resnet_prob > 0.5:
        explanation.append(f"Deep learning classifier indicates MODERATE glaucoma risk ({resnet_prob:.2f})")
    else:
        explanation.append(f"Deep learning classifier indicates LOW glaucoma risk ({resnet_prob:.2f})")
    
    if yolo_quality < 0.5:
        explanation.append(f"Image quality is LOW ({yolo_quality:.2f}) - OD/OC detection uncertain")
    elif yolo_quality > 0.8:
        explanation.append(f"Image quality is GOOD ({yolo_quality:.2f}) - clear optic disc detected")
    
    if vessel_density < 0.08:
        explanation.append(f"Vessel density is LOW ({vessel_density:.3f}) - possible vascular dropout (glaucoma sign)")
    elif vessel_density < 0.10:
        explanation.append(f"Vessel density is BORDERLINE ({vessel_density:.3f})")
    else:
        explanation.append(f"Vessel density is NORMAL ({vessel_density:.3f})")
    
    explanation.append("Note: For CDR measurement, use Book 4 (Standalone CDR Calculator)")
    
    return final_score, decision, confidence, explanation


In [29]:
# ============================================================================
# SECTION 9: COMPLETE INFERENCE PIPELINE
# ============================================================================

def predict_glaucoma(image_path, visualize=True):
    """
    Complete glaucoma detection pipeline
    
    Args:
        image_path: Path to fundus image
        visualize: Whether to generate visualization
    
    Returns:
        results: Dictionary with all predictions
        fig: matplotlib figure (if visualize=True)
    """
    
    # Load image
    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")
    
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    print(f"\n{'='*60}")
    print(f"Processing: {Path(image_path).name}")
    print('='*60)
    
    # ---- Step 1: YOLO Detection ----
    print("\n[1/3] Running YOLO OD/OC detection...")
    yolo_results = yolo_detect_od_oc(image_rgb, yolo_model, 
                                     conf=config.YOLO_CONF, iou=config.YOLO_IOU)
    print(f"  OD detected: {yolo_results['od_detected']} (conf: {yolo_results['od_conf']:.3f})")
    print(f"  OC detected: {yolo_results['oc_detected']} (conf: {yolo_results['oc_conf']:.3f})")
    print(f"  Quality score: {yolo_results['quality_score']:.3f}")
    
    # ---- Step 2: ResNet Classification ----
    print("\n[2/3] Running ResNet classification...")
    preprocessed_img = preprocessor.preprocess(image_rgb)
    resnet_prob = resnet_classify(preprocessed_img, resnet_model)
    print(f"  Glaucoma probability: {resnet_prob:.3f}")
    
    # ---- Step 3: U-Net Vessel Segmentation ----
    print("\n[3/3] Running U-Net vessel segmentation...")
    vessel_density, vessel_mask = unet_segment_vessels(image_rgb, unet_vessel_model)
    print(f"  Vessel density: {vessel_density:.3f}")
    
    # ---- Step 4: Decision Fusion ----
    print("\n[4/4] Fusing predictions...")
    final_score, decision, confidence, explanation = fuse_predictions(
        resnet_prob, yolo_results['quality_score'], vessel_density
    )
    
    print(f"\n{'='*60}")
    print(f"FINAL RESULTS")
    print('='*60)
    print(f"Decision: {decision}")
    print(f"Confidence: {confidence}")
    print(f"Risk Score: {final_score:.3f}")
    print(f"\nExplanation:")
    for i, exp in enumerate(explanation, 1):
        print(f"  {i}. {exp}")
    print('='*60)
    
    # Compile results
    results = {
        'decision': decision,
        'confidence': confidence,
        'risk_score': final_score,
        'resnet_prob': resnet_prob,
        'vessel_density': vessel_density,
        'yolo_detections': yolo_results,
        'explanation': explanation
    }
    
    # ---- Visualization ----
    if visualize:
        fig = visualize_results(image_rgb, results, vessel_mask, yolo_results)
        return results, fig
    
    return results, None


In [30]:
# ============================================================================
# SECTION 10: VISUALIZATION
# ============================================================================

def visualize_results(image, results, vessel_mask, yolo_results):
    """Create comprehensive visualization"""
    
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3)
    
    # Color map for decision
    decision_colors = {
        'Normal': 'green',
        'Suspicious': 'orange',
        'Glaucoma': 'red'
    }
    decision_color = decision_colors.get(results['decision'], 'gray')
    
    # Row 1: Input, Preprocessed, YOLO, Decision
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(image)
    ax1.set_title("Original Image", fontsize=12, weight='bold')
    ax1.axis('off')
    
    preprocessed = preprocessor.preprocess(image)
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(preprocessed)
    ax2.set_title("Preprocessed", fontsize=12, weight='bold')
    ax2.axis('off')
    
    # YOLO detections
    ax3 = fig.add_subplot(gs[0, 2])
    img_yolo = image.copy()
    if yolo_results['od_bbox'] is not None:
        x1, y1, x2, y2 = yolo_results['od_bbox'].astype(int)
        cv2.rectangle(img_yolo, (x1, y1), (x2, y2), (255, 0, 0), 3)
        cv2.putText(img_yolo, f"OD: {yolo_results['od_conf']:.2f}", 
                   (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
    if yolo_results['oc_bbox'] is not None:
        x1, y1, x2, y2 = yolo_results['oc_bbox'].astype(int)
        cv2.rectangle(img_yolo, (x1, y1), (x2, y2), (0, 255, 0), 3)
        cv2.putText(img_yolo, f"OC: {yolo_results['oc_conf']:.2f}", 
                   (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    ax3.imshow(img_yolo)
    ax3.set_title("YOLO Detections", fontsize=12, weight='bold')
    ax3.axis('off')
    
    # Final Decision
    ax4 = fig.add_subplot(gs[0, 3])
    ax4.text(0.5, 0.6, results['decision'], 
            ha='center', va='center', fontsize=32, 
            color=decision_color, weight='bold')
    ax4.text(0.5, 0.4, f"Risk: {results['risk_score']:.3f}", 
            ha='center', va='center', fontsize=16)
    ax4.text(0.5, 0.3, f"Confidence: {results['confidence']}", 
            ha='center', va='center', fontsize=14)
    ax4.set_xlim(0, 1)
    ax4.set_ylim(0, 1)
    ax4.axis('off')
    ax4.set_title("Final Decision", fontsize=12, weight='bold')
    
    # Row 2: Vessel segmentation and metrics
    ax5 = fig.add_subplot(gs[1, 0])
    vessel_overlay = cv2.resize(image, vessel_mask.shape)
    vessel_overlay[vessel_mask == 1] = [255, 255, 0]
    ax5.imshow(vessel_overlay)
    ax5.set_title(f"Vessels (Density: {results['vessel_density']:.3f})", 
                 fontsize=12, weight='bold')
    ax5.axis('off')
    
    # Metrics
    ax6 = fig.add_subplot(gs[1, 1:3])
    metrics_text = f"""
MODEL OUTPUTS (No CDR - use Book 4 for CDR)
{'='*50}

ResNet Probability:  {results['resnet_prob']:.3f}
Vessel Density:      {results['vessel_density']:.3f}
YOLO Quality:        {results['yolo_detections']['quality_score']:.3f}

FUSION WEIGHTS:
{'='*50}
ResNet Weight:       {config.WEIGHT_RESNET}
YOLO Weight:         {config.WEIGHT_YOLO}
Vessel Weight:       {config.WEIGHT_VESSEL}

THRESHOLDS:
{'='*50}
Normal:              < {config.THRESHOLD_NORMAL}
Suspicious:          {config.THRESHOLD_NORMAL} - {config.THRESHOLD_SUSPICIOUS}
Glaucoma:            ≥ {config.THRESHOLD_SUSPICIOUS}
    """
    ax6.text(0.05, 0.95, metrics_text, transform=ax6.transAxes,
            fontsize=10, verticalalignment='top', family='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
    ax6.axis('off')
    
    # Explanation
    ax7 = fig.add_subplot(gs[1, 3])
    explanation_text = "Clinical Interpretation:\n" + "─" * 30 + "\n"
    for i, exp in enumerate(results['explanation'], 1):
        explanation_text += f"{i}. {exp}\n\n"
    ax7.text(0.05, 0.95, explanation_text, transform=ax7.transAxes,
            fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))
    ax7.axis('off')
    
    plt.suptitle(f"Glaucoma Detection Report (3-Model Fusion)", 
                fontsize=14, weight='bold', y=0.98)
    
    return fig


In [31]:
# ============================================================================
# SECTION 11: BATCH INFERENCE
# ============================================================================

def batch_inference(image_dir, output_csv='batch_results.csv'):
    """
    Run inference on all images in a directory
    
    Args:
        image_dir: Directory containing fundus images
        output_csv: Output CSV file path
    """
    image_dir = Path(image_dir)
    image_files = (list(image_dir.glob('*.jpg')) + 
                  list(image_dir.glob('*.png')) +
                  list(image_dir.glob('*.jpeg')))
    
    print(f"\n{'='*60}")
    print(f"BATCH PROCESSING: {len(image_files)} images")
    print('='*60)
    
    batch_results = []
    
    for img_path in image_files:
        try:
            results, _ = predict_glaucoma(str(img_path), visualize=False)
            
            batch_results.append({
                'image': img_path.name,
                'decision': results['decision'],
                'confidence': results['confidence'],
                'risk_score': results['risk_score'],
                'resnet_prob': results['resnet_prob'],
                'vessel_density': results['vessel_density'],
                'yolo_quality': results['yolo_detections']['quality_score']
            })
            
            print(f"✓ {img_path.name}: {results['decision']} (risk: {results['risk_score']:.3f})")
            
        except Exception as e:
            print(f"✗ {img_path.name}: Error - {e}")
            batch_results.append({
                'image': img_path.name,
                'decision': 'ERROR',
                'error': str(e)
            })
    
    # Save results
    import pandas as pd
    df = pd.DataFrame(batch_results)
    df.to_csv(output_csv, index=False)
    print(f"\n✓ Results saved to {output_csv}")
    
    # Summary statistics
    if len(df[df['decision'] != 'ERROR']) > 0:
        print("\n" + "="*60)
        print("BATCH SUMMARY")
        print("="*60)
        print(df['decision'].value_counts())
        print(f"\nAverage Risk Score: {df[df['decision'] != 'ERROR']['risk_score'].mean():.3f}")
    
    return df

In [32]:
# ============================================================================
# SECTION 12: TEST ON SAMPLE IMAGES
# ============================================================================

print("\n" + "="*60)
print("TESTING INFERENCE PIPELINE")
print("="*60)

# Find test images
test_dir = Path(config.TEST_IMAGES_DIR)
if test_dir.exists():
    test_images = (list(test_dir.glob('*.jpg')) + 
                  list(test_dir.glob('*.png')) +
                  list(test_dir.glob('*.jpeg')))
    
    if len(test_images) > 0:
        print(f"\nFound {len(test_images)} test images")
        
        # Test on first 3 images
        for img_path in test_images[:3]:
            try:
                results, fig = predict_glaucoma(str(img_path), visualize=True)
                
                plt.savefig(f'result_{img_path.stem}.png', 
                           dpi=150, bbox_inches='tight')
                plt.show()
                
            except Exception as e:
                print(f"\n✗ Error processing {img_path.name}: {e}")
    else:
        print("\n⚠️ No test images found in test directory")
        print("Please upload test images to run inference")
else:
    print(f"\n⚠️ Test directory not found: {test_dir}")
    print("Please upload test images to Kaggle")

print("\n" + "="*60)
print("✅ INFERENCE PIPELINE READY!")
print("="*60)
print("\nUsage:")
print("  # Single image:")
print("  results, fig = predict_glaucoma('path/to/image.jpg')")
print("\n  # Batch processing:")
print("  df = batch_inference('/kaggle/input/test-images/')")


TESTING INFERENCE PIPELINE

⚠️ No test images found in test directory
Please upload test images to run inference

✅ INFERENCE PIPELINE READY!

Usage:
  # Single image:
  results, fig = predict_glaucoma('path/to/image.jpg')

  # Batch processing:
  df = batch_inference('/kaggle/input/test-images/')
