# Enhanced Cocoa Disease Segmentation Model

This notebook provides a comprehensive solution for cocoa disease segmentation with:
- Fixed mask creation logic
- Data validation and quality checks
- Class balancing and proper metrics
- Robust training pipeline
- Enhanced inference and visualization

## Key Improvements:
1. **Fixed Binary Output Problem**: Proper multi-class segmentation (0, 1, 2, 3)
2. **Enhanced Mask Processing**: Better overlap handling and validation
3. **Class Balancing**: Handle imbalanced datasets effectively
4. **Better Metrics**: IoU, F1-score, and detailed per-class evaluation
5. **Robust Inference**: Post-processing and confidence scoring

In [None]:
# Import required libraries
import cv2
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Dropout
from tensorflow.keras.models import Model
from sklearn.utils.class_weight import compute_class_weight

# Import our custom modules
from data_validation import (
    validate_mask_content, analyze_mask_directory, print_analysis_summary,
    visualize_mask_samples, validate_image_mask_pairs
)
from model_utils import (
    create_custom_metrics, calculate_class_weights, create_weighted_loss,
    evaluate_model_detailed, print_evaluation_results, plot_confusion_matrix,
    create_training_callbacks, plot_training_history
)
from inference import CocoaDiseasePredictor, print_prediction_summary

print(f"TensorFlow version: {tf.__version__}")
print("✅ All modules imported successfully!")

In [None]:
# Configuration
IMG_SIZE = (128, 128)
NUM_CLASSES = 4  # background, healthy, black_pod_rot, pod_borer
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 0.001

# Data paths - Update these to your actual data paths
BASE_DATA_DIR = '/content/drive/MyDrive/data'  # Update this path
TRAIN_IMG_DIR = os.path.join(BASE_DATA_DIR, 'images/train')
TRAIN_MASK_DIR = os.path.join(BASE_DATA_DIR, 'masks/train/Multiclass')
VAL_IMG_DIR = os.path.join(BASE_DATA_DIR, 'images/val')
VAL_MASK_DIR = os.path.join(BASE_DATA_DIR, 'masks/val/Multiclass')

# Model save paths
MODEL_SAVE_DIR = '/content/drive/MyDrive/models'  # Update this path
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
BEST_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, 'enhanced_model.keras')

print(f"Configuration loaded:")
print(f"  Image size: {IMG_SIZE}")
print(f"  Number of classes: {NUM_CLASSES}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Model will be saved to: {BEST_MODEL_PATH}")

## Step 1: Fixed Mask Creation with Proper Overlap Handling

In [None]:
def enhanced_combine_masks(healthy_dir, black_pod_rot_dir, pod_borer_dir, output_dir, priority_order=None):
    """
    Enhanced mask combination with proper overlap handling.
    
    Args:
        healthy_dir: Directory containing healthy masks
        black_pod_rot_dir: Directory containing black pod rot masks
        pod_borer_dir: Directory containing pod borer masks
        output_dir: Output directory for combined masks
        priority_order: List defining class priority for overlaps [default: [3, 2, 1]]
                       Higher priority classes override lower priority ones
    
    Returns:
        dict: Processing results and statistics
    """
    if priority_order is None:
        priority_order = [3, 2, 1]  # pod_borer > black_pod_rot > healthy
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Collect mask filenames from all three class directories, so pods that only
    # have a disease mask (and no healthy mask) are not silently skipped
    class_dirs = {1: healthy_dir, 2: black_pod_rot_dir, 3: pod_borer_dir}
    mask_files = set()
    for class_dir in class_dirs.values():
        for ext in ['*.png', '*.jpg', '*.jpeg']:
            mask_files.update(os.path.basename(p) for p in glob.glob(os.path.join(class_dir, ext)))
    mask_files = sorted(mask_files)
    
    print(f"🔍 Found {len(mask_files)} mask files to process")
    
    if len(mask_files) == 0:
        print(f"⚠️  WARNING: No mask files found in {list(class_dirs.values())}")
        return {'error': 'No mask files found'}
    
    processing_stats = {
        'total_processed': 0,
        'successful': 0,
        'errors': 0,
        'overlap_cases': 0,
        'class_distribution': Counter()
    }
    
    for i, basename in enumerate(mask_files):
        # Count every file up front, so skipped files are included in the total
        processing_stats['total_processed'] += 1
        
        if i % 100 == 0:
            print(f"  Processing {i+1}/{len(mask_files)}: {basename}")
        
        try:
            # Read individual class masks (None where a class has no mask for this pod)
            raw_masks = {
                class_id: cv2.imread(os.path.join(class_dir, basename), cv2.IMREAD_GRAYSCALE)
                for class_id, class_dir in class_dirs.items()
            }
            available = [m for m in raw_masks.values() if m is not None]
            
            if not available:
                print(f"  ⚠️  Warning: Couldn't read any mask for {basename}")
                processing_stats['errors'] += 1
                continue
            
            reference_shape = available[0].shape
            if any(m.shape != reference_shape for m in available):
                print(f"  ⚠️  Warning: Class masks for {basename} differ in size, skipping")
                processing_stats['errors'] += 1
                continue
            
            # Initialize combined mask with background (0)
            combined_mask = np.zeros(reference_shape, dtype=np.uint8)
            
            # Binarize each class mask (the 127 threshold also absorbs JPEG noise);
            # a missing mask contributes no pixels
            class_masks = {
                class_id: (m > 127) if m is not None else np.zeros(reference_shape, dtype=bool)
                for class_id, m in raw_masks.items()
            }
            
            # An overlap exists when per-class pixel counts sum to more than their union
            total_class_pixels = sum(np.sum(mask) for mask in class_masks.values())
            union_pixels = np.sum(np.logical_or.reduce(list(class_masks.values())))
            
            if total_class_pixels > union_pixels:
                processing_stats['overlap_cases'] += 1
            
            # Paint classes from lowest to highest priority, so higher-priority
            # classes overwrite lower-priority ones where masks overlap
            for class_id in reversed(priority_order):
                if class_id in class_masks:
                    combined_mask[class_masks[class_id]] = class_id
            
            # Safety check: the combined mask should only contain 0-3
            unique_values = np.unique(combined_mask)
            if not set(unique_values.tolist()).issubset({0, 1, 2, 3}):
                print(f"  ⚠️  Warning: Invalid values in combined mask for {basename}: {unique_values}")
                # Clamp invalid values
                combined_mask = np.clip(combined_mask, 0, 3)
            
            # Count class distribution (np.bincount is much faster than Counter on pixels)
            class_counts = np.bincount(combined_mask.ravel(), minlength=4)
            for class_id, count in enumerate(class_counts):
                if count > 0:
                    processing_stats['class_distribution'][class_id] += int(count)
            
            # Save combined mask as PNG (lossless, so class IDs survive)
            output_path = os.path.join(output_dir, os.path.splitext(basename)[0] + '.png')
            cv2.imwrite(output_path, combined_mask)
            
            processing_stats['successful'] += 1
            
        except Exception as e:
            print(f"  ❌ Error processing {basename}: {e}")
            processing_stats['errors'] += 1
    
    # Print summary
    print(f"\n✅ Mask combination completed:")
    print(f"  Total processed: {processing_stats['total_processed']}")
    print(f"  Successful: {processing_stats['successful']}")
    print(f"  Errors: {processing_stats['errors']}")
    print(f"  Overlap cases handled: {processing_stats['overlap_cases']}")
    
    # Print class distribution
    total_pixels = sum(processing_stats['class_distribution'].values())
    if total_pixels > 0:
        print(f"\n📊 Class distribution:")
        class_names = {0: 'Background', 1: 'Healthy', 2: 'Black Pod Rot', 3: 'Pod Borer'}
        for class_id in sorted(processing_stats['class_distribution'].keys()):
            count = processing_stats['class_distribution'][class_id]
            percentage = (count / total_pixels) * 100
            name = class_names.get(class_id, f'Class {class_id}')
            print(f"  {name}: {count:,} pixels ({percentage:.2f}%)")
    
    return processing_stats

# Only run mask combination if the multiclass directories don't exist or are empty
if not os.path.exists(TRAIN_MASK_DIR) or len(os.listdir(TRAIN_MASK_DIR)) == 0:
    print("🔧 Creating enhanced multiclass masks for TRAINING data...")
    train_stats = enhanced_combine_masks(
        healthy_dir=os.path.join(BASE_DATA_DIR, 'masks/train/healthy'),
        black_pod_rot_dir=os.path.join(BASE_DATA_DIR, 'masks/train/black_pod_rot'),
        pod_borer_dir=os.path.join(BASE_DATA_DIR, 'masks/train/pod_borer'),
        output_dir=TRAIN_MASK_DIR
    )
else:
    print(f"✅ Training multiclass masks already exist in {TRAIN_MASK_DIR}")

if not os.path.exists(VAL_MASK_DIR) or len(os.listdir(VAL_MASK_DIR)) == 0:
    print("\n🔧 Creating enhanced multiclass masks for VALIDATION data...")
    val_stats = enhanced_combine_masks(
        healthy_dir=os.path.join(BASE_DATA_DIR, 'masks/val/healthy'),
        black_pod_rot_dir=os.path.join(BASE_DATA_DIR, 'masks/val/black_pod_rot'),
        pod_borer_dir=os.path.join(BASE_DATA_DIR, 'masks/val/pod_borer'),
        output_dir=VAL_MASK_DIR
    )
else:
    print(f"✅ Validation multiclass masks already exist in {VAL_MASK_DIR}")
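The overlap rule can be checked in isolation on tiny synthetic masks. The sketch below uses a hypothetical helper, `combine_by_priority`, that mirrors the painting loop in `enhanced_combine_masks` (classes painted lowest priority first, so the highest priority wins); it uses only NumPy and reads no files.

```python
import numpy as np

def combine_by_priority(class_masks, priority_order=(3, 2, 1)):
    """Hypothetical helper mirroring the painting loop in enhanced_combine_masks.

    class_masks: dict {class_id: boolean array}, all with the same shape.
    priority_order: highest priority first, as in the notebook's default.
    """
    shape = next(iter(class_masks.values())).shape
    combined = np.zeros(shape, dtype=np.uint8)  # 0 = background
    # Paint lowest priority first; later (higher-priority) classes overwrite
    for class_id in reversed(priority_order):
        if class_id in class_masks:
            combined[class_masks[class_id]] = class_id
    return combined

# A 1x4 strip: pixel 1 is marked both healthy and pod borer,
# pixel 2 is marked both healthy and black pod rot
healthy = np.array([[False, True, True, False]])
black_pod_rot = np.array([[False, False, True, False]])
pod_borer = np.array([[False, True, False, False]])

combined = combine_by_priority({1: healthy, 2: black_pod_rot, 3: pod_borer})
print(combined)  # [[0 3 2 0]]
```

Reversing the priority order to `(1, 2, 3)` would let Healthy win every overlap, which is why the default puts the disease classes first.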

## Step 2: Data Validation and Quality Assessment

In [None]:
# Validate training masks
print("🔍 VALIDATING TRAINING MASKS")
print("="*50)
train_mask_analysis = analyze_mask_directory(TRAIN_MASK_DIR, sample_size=100)
print_analysis_summary(train_mask_analysis)

# Validate validation masks  
print("\n🔍 VALIDATING VALIDATION MASKS")
print("="*50)
val_mask_analysis = analyze_mask_directory(VAL_MASK_DIR, sample_size=50)
print_analysis_summary(val_mask_analysis)
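`analyze_mask_directory` lives in the custom `data_validation` module, so its exact checks are not shown here. As an assumption about what a useful per-mask check looks like, the hypothetical `check_mask_values` below flags the most common cause of binary output: masks saved as 0/255 instead of class IDs 0-3.

```python
import numpy as np

def check_mask_values(mask, num_classes=4):
    """Hypothetical per-mask check: returns (is_valid, message)."""
    values = set(np.unique(mask).tolist())
    if values <= set(range(num_classes)):
        return True, f"ok, classes present: {sorted(values)}"
    if values <= {0, 255}:
        # After the pipeline clips to NUM_CLASSES-1, 255 becomes class 3,
        # so the model would only ever learn two classes from such masks
        return False, "binary 0/255 mask, not class IDs"
    bad = sorted(v for v in values if v >= num_classes)
    return False, f"out-of-range values: {bad}"

good = np.array([[0, 1], [2, 3]], dtype=np.uint8)
binary = np.array([[0, 255], [255, 0]], dtype=np.uint8)
print(check_mask_values(good))
print(check_mask_values(binary))
```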

In [None]:
# Validate image-mask pairs
print("🔍 VALIDATING IMAGE-MASK PAIRS")
print("="*50)

train_pair_validation = validate_image_mask_pairs(TRAIN_IMG_DIR, TRAIN_MASK_DIR, sample_size=10)
val_pair_validation = validate_image_mask_pairs(VAL_IMG_DIR, VAL_MASK_DIR, sample_size=5)

print(f"\n📊 Training pairs: {train_pair_validation.get('total_possible_pairs', 0)} total, {train_pair_validation.get('valid_pairs', 0)} valid")
print(f"📊 Validation pairs: {val_pair_validation.get('total_possible_pairs', 0)} total, {val_pair_validation.get('valid_pairs', 0)} valid")
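Counting valid pairs does not say which files are orphaned. A small sketch, assuming pairs are matched by filename stem as in `get_enhanced_matching_paths` (Step 3), lists images without masks and masks without images; `find_orphans` is a hypothetical helper that works on plain filename lists.

```python
import os

def find_orphans(image_files, mask_files):
    """Match files by stem (name without extension) and report the leftovers."""
    image_stems = {os.path.splitext(os.path.basename(f))[0] for f in image_files}
    mask_stems = {os.path.splitext(os.path.basename(f))[0] for f in mask_files}
    return {
        'paired': sorted(image_stems & mask_stems),
        'images_without_masks': sorted(image_stems - mask_stems),
        'masks_without_images': sorted(mask_stems - image_stems),
    }

report = find_orphans(
    ['pod_001.jpg', 'pod_002.jpg', 'pod_003.JPG'],
    ['pod_001.png', 'pod_003.png', 'pod_004.png'],
)
print(report)
```

In practice the inputs would be `os.listdir(TRAIN_IMG_DIR)` and `os.listdir(TRAIN_MASK_DIR)`. Note that the glob patterns in Step 3 use lowercase extensions and are case-sensitive on Linux (including Colab), so a file like `pod_003.JPG` is paired here but would be skipped there.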

In [None]:
# Visualize sample masks to understand the data
print("🎨 Visualizing sample masks...")
visualize_mask_samples(TRAIN_MASK_DIR, num_samples=4)

## Step 3: Enhanced Data Loading with Validation

In [None]:
def get_enhanced_matching_paths(image_dir, mask_dir, image_exts=('jpg', 'jpeg', 'png'), mask_exts=('png',)):
    """Enhanced version with validation of matching image-mask pairs."""
    images = []
    for ext in image_exts:
        images += glob.glob(os.path.join(image_dir, f'*.{ext}'))

    masks = []
    for ext in mask_exts:
        masks += glob.glob(os.path.join(mask_dir, f'*.{ext}'))

    # Create dictionaries for matching
    image_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in images}
    mask_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in masks}

    # Find common keys
    common_keys = sorted(set(image_dict.keys()) & set(mask_dict.keys()))

    matched_images = [image_dict[k] for k in common_keys]
    matched_masks = [mask_dict[k] for k in common_keys]

    print(f"📊 Matched {len(matched_images)} image-mask pairs from {len(images)} images and {len(masks)} masks")

    if len(common_keys) == 0:
        print("❌ WARNING: No matching image/mask pairs found!")
        print("Sample image keys:", list(image_dict.keys())[:5])
        print("Sample mask keys:", list(mask_dict.keys())[:5])

    return matched_images, matched_masks

def enhanced_process_data(image_path, mask_path):
    """Enhanced data processing with validation."""
    # Load and process image. tf.io.decode_image handles JPEG, PNG, BMP and GIF;
    # expand_animations=False guarantees a 3-D tensor so tf.image.resize works.
    # A Python try/except cannot pick the decoder here: dataset.map traces this
    # function into a graph once, so decode errors only surface at run time.
    image = tf.io.read_file(image_path)
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.resize(image, IMG_SIZE)
    image = tf.cast(image, tf.float32) / 255.0
    
    # Load and process mask
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, IMG_SIZE, method='nearest')  # nearest keeps class IDs intact
    mask = tf.cast(mask, tf.int32)  # Keep as int32 for sparse_categorical_crossentropy
    mask = tf.squeeze(mask, axis=-1)  # Remove channel dimension for sparse format
    
    # Clamp out-of-range values. This is a safety net, not validation: a 0/255
    # binary mask would silently become classes 0/3, so check masks in Step 2.
    mask = tf.clip_by_value(mask, 0, NUM_CLASSES-1)
    
    return image, mask

def create_enhanced_dataset(image_dir, mask_dir, batch_size=8, shuffle=True, augment=False):
    """Create enhanced dataset with validation and optional augmentation."""
    image_paths, mask_paths = get_enhanced_matching_paths(image_dir, mask_dir)
    
    if len(image_paths) == 0:
        print("❌ ERROR: No matching image/mask pairs found!")
        return None
    
    # Create a dataset of (image_path, mask_path) pairs
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    
    # Shuffle the (cheap) file paths before decoding, so the buffer can span the
    # whole dataset without holding decoded images in memory
    if shuffle:
        dataset = dataset.shuffle(len(image_paths), reshuffle_each_iteration=True)
    
    dataset = dataset.map(enhanced_process_data, num_parallel_calls=tf.data.AUTOTUNE)
    
    # Add data augmentation for training (random ops are re-drawn every epoch)
    if augment:
        dataset = dataset.map(augment_data, num_parallel_calls=tf.data.AUTOTUNE)
    
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    return dataset

def augment_data(image, mask):
    """Simple data augmentation that preserves mask integrity."""
    # Random horizontal flip
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(tf.expand_dims(mask, -1))
        mask = tf.squeeze(mask, -1)
    
    # Random brightness adjustment (only for image)
    image = tf.image.random_brightness(image, 0.1)
    
    # Random contrast adjustment (only for image)  
    image = tf.image.random_contrast(image, 0.9, 1.1)
    
    # Ensure image values stay in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)
    
    return image, mask

# Create enhanced datasets
print("🔧 Creating enhanced datasets...")
train_dataset = create_enhanced_dataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, 
                                       batch_size=BATCH_SIZE, shuffle=True, augment=True)
val_dataset = create_enhanced_dataset(VAL_IMG_DIR, VAL_MASK_DIR, 
                                     batch_size=BATCH_SIZE, shuffle=False, augment=False)

if train_dataset is None or val_dataset is None:
    print("❌ ERROR: Failed to create datasets. Please check your file paths and masks.")
else:
    print("✅ Enhanced datasets created successfully!")
    
    # Quick validation of dataset contents
    print("\n🔍 Validating dataset contents...")
    for images, masks in train_dataset.take(1):
        print(f"  Batch image shape: {images.shape}")
        print(f"  Batch mask shape: {masks.shape}")
        print(f"  Image value range: [{tf.reduce_min(images):.3f}, {tf.reduce_max(images):.3f}]")
        print(f"  Mask unique values: {tf.unique(tf.reshape(masks, [-1]))[0].numpy()}")
        break
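`enhanced_process_data` resizes masks with `method='nearest'`. Any interpolating resize averages neighbouring class IDs, and the average of two valid IDs can be a third, wrong class. The NumPy sketch below (both helpers are hypothetical, and a 1-D linear blend stands in for bilinear resizing) shows a boundary between Healthy (1) and Pod Borer (3) acquiring Black Pod Rot (2) pixels.

```python
import numpy as np

def nearest_upsample_1d(labels, factor):
    """Nearest-neighbour upsampling: each output pixel copies one input pixel."""
    return np.repeat(labels, factor)

def linear_upsample_1d(labels, factor):
    """Linear interpolation between input pixel centres (bilinear, in 1-D)."""
    n = len(labels)
    out_positions = (np.arange(n * factor) + 0.5) / factor - 0.5
    return np.interp(out_positions, np.arange(n), labels)

row = np.array([1, 3])  # Healthy next to Pod Borer
print(nearest_upsample_1d(row, 4))                      # [1 1 1 1 3 3 3 3]
print(np.rint(linear_upsample_1d(row, 4)).astype(int))  # [1 1 1 2 2 3 3 3]
```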

## Step 4: Calculate Class Weights for Balanced Training

In [None]:
# Calculate class weights to handle imbalanced data
print("⚖️  Calculating class weights for balanced training...")
class_weights = calculate_class_weights(TRAIN_MASK_DIR, num_classes=NUM_CLASSES, method='balanced')

print(f"\n📊 Class weights will be used to balance training:")
class_names = ['Background', 'Healthy', 'Black Pod Rot', 'Pod Borer']
for i, (class_id, weight) in enumerate(class_weights.items()):
    print(f"  {class_names[i]}: {weight:.3f}")
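`calculate_class_weights(..., method='balanced')` comes from the custom `model_utils` module. Assuming it follows the same convention as scikit-learn's `compute_class_weight('balanced', ...)` (imported above but not called), each class gets `w_c = N / (C * n_c)`, where `N` is the total pixel count, `C` the number of classes and `n_c` the pixel count of class `c`. A NumPy sketch from per-class pixel counts:

```python
import numpy as np

def balanced_class_weights(pixel_counts):
    """w_c = N / (C * n_c), the 'balanced' heuristic used by scikit-learn.

    pixel_counts: per-class pixel counts, index = class ID.
    Classes with zero pixels get weight 0 here to avoid dividing by zero
    (an assumption of this sketch).
    """
    counts = np.asarray(pixel_counts, dtype=np.float64)
    total, num_classes = counts.sum(), len(counts)
    weights = np.zeros_like(counts)
    present = counts > 0
    weights[present] = total / (num_classes * counts[present])
    return {class_id: float(w) for class_id, w in enumerate(weights)}

# Illustrative counts (not from the real dataset): background dominates
weights = balanced_class_weights([70_000, 20_000, 8_000, 2_000])
print(weights)  # background ~0.357, pod borer 12.5
```

A class that is 35 times rarer than background ends up weighted 35 times higher, so very rare classes can dominate the loss; some pipelines cap or take the square root of these weights for that reason.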

## Step 5: Enhanced U-Net Model with Better Architecture

In [None]:
def build_enhanced_unet(input_shape=(128, 128, 3), num_classes=4, dropout_rate=0.3):
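    """
    Build a U-Net for multi-class segmentation.

    Compared with a plain U-Net, this version adds:
    - Batch normalization after every convolution, for more stable training
    - Dropout after every pooling and upsampling stage, for regularization

    The skip connections (standard in U-Net) pass high-resolution encoder
    features to the decoder. The four 2x2 poolings require input height and
    width to be divisible by 16 (128 -> 8 at the bottleneck).
    """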
    inputs = Input(input_shape, name='input_image')
    
    # Encoder (Contracting Path)
    # Block 1
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = tf.keras.layers.BatchNormalization()(c1)
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    c1 = tf.keras.layers.BatchNormalization()(c1)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout_rate)(p1)
    
    # Block 2
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = tf.keras.layers.BatchNormalization()(c2)
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    c2 = tf.keras.layers.BatchNormalization()(c2)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout_rate)(p2)
    
    # Block 3
    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    c3 = tf.keras.layers.BatchNormalization()(c3)
    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
    c3 = tf.keras.layers.BatchNormalization()(c3)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout_rate)(p3)
    
    # Block 4
    c4 = Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
    c4 = tf.keras.layers.BatchNormalization()(c4)
    c4 = Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
    c4 = tf.keras.layers.BatchNormalization()(c4)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout_rate)(p4)
    
    # Bridge (Bottleneck)
    c5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(p4)
    c5 = tf.keras.layers.BatchNormalization()(c5)
    c5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(c5)
    c5 = tf.keras.layers.BatchNormalization()(c5)
    c5 = Dropout(dropout_rate)(c5)
    
    # Decoder (Expansive Path)
    # Block 6
    u6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout_rate)(u6)
    c6 = Conv2D(512, (3, 3), activation='relu', padding='same')(u6)
    c6 = tf.keras.layers.BatchNormalization()(c6)
    c6 = Conv2D(512, (3, 3), activation='relu', padding='same')(c6)
    c6 = tf.keras.layers.BatchNormalization()(c6)
    
    # Block 7
    u7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout_rate)(u7)
    c7 = Conv2D(256, (3, 3), activation='relu', padding='same')(u7)
    c7 = tf.keras.layers.BatchNormalization()(c7)
    c7 = Conv2D(256, (3, 3), activation='relu', padding='same')(c7)
    c7 = tf.keras.layers.BatchNormalization()(c7)
    
    # Block 8
    u8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout_rate)(u8)
    c8 = Conv2D(128, (3, 3), activation='relu', padding='same')(u8)
    c8 = tf.keras.layers.BatchNormalization()(c8)
    c8 = Conv2D(128, (3, 3), activation='relu', padding='same')(c8)
    c8 = tf.keras.layers.BatchNormalization()(c8)
    
    # Block 9
    u9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout_rate)(u9)
    c9 = Conv2D(64, (3, 3), activation='relu', padding='same')(u9)
    c9 = tf.keras.layers.BatchNormalization()(c9)
    c9 = Conv2D(64, (3, 3), activation='relu', padding='same')(c9)
    c9 = tf.keras.layers.BatchNormalization()(c9)
    
    # Output layer
    outputs = Conv2D(num_classes, (1, 1), activation='softmax', name='segmentation_output')(c9)
    
    model = Model(inputs=[inputs], outputs=[outputs], name='enhanced_unet')
    
    return model

# Build enhanced model
print("🏗️  Building enhanced U-Net model...")
model = build_enhanced_unet(input_shape=(*IMG_SIZE, 3), num_classes=NUM_CLASSES)

# Create custom metrics
custom_metrics = create_custom_metrics(NUM_CLASSES)

# Create weighted loss function
weighted_loss = create_weighted_loss(class_weights)

# Compile model with enhanced configuration
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=weighted_loss,
    metrics=['accuracy'] + list(custom_metrics.values())
)

print("✅ Enhanced model compiled successfully!")
print(f"📊 Model summary:")
model.summary()
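`create_weighted_loss` is also defined in `model_utils`. One common form, assumed here rather than taken from that module, is a weighted sparse categorical cross-entropy: each pixel's loss `-log p[y]` is multiplied by the weight of its true class `y`, then averaged over pixels. Some implementations divide by the sum of the pixel weights instead; this NumPy sketch uses the plain mean.

```python
import numpy as np

def weighted_sparse_ce(y_true, y_prob, class_weights, eps=1e-7):
    """Weighted sparse categorical cross-entropy, averaged over pixels.

    y_true: int array of class IDs, shape (...,)
    y_prob: softmax probabilities, shape (..., num_classes)
    class_weights: sequence of per-class weights, index = class ID
    """
    y_true = np.asarray(y_true)
    probs = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1.0)
    # Probability the model assigned to each pixel's true class
    p_true = np.take_along_axis(probs, y_true[..., None], axis=-1)[..., 0]
    pixel_weights = np.asarray(class_weights, dtype=np.float64)[y_true]
    return float(np.mean(-np.log(p_true) * pixel_weights))

# Two pixels: a confident, correct background pixel and an uncertain pod-borer pixel
y_true = np.array([0, 3])
y_prob = np.array([[0.9, 0.05, 0.03, 0.02],
                   [0.4, 0.1, 0.1, 0.4]])
print(weighted_sparse_ce(y_true, y_prob, [1.0, 1.0, 1.0, 1.0]))
print(weighted_sparse_ce(y_true, y_prob, [1.0, 1.0, 1.0, 5.0]))  # pod-borer errors cost more
```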

## Step 6: Enhanced Training with Comprehensive Monitoring

In [None]:
# Create training callbacks
callbacks = create_training_callbacks(BEST_MODEL_PATH, patience=15)

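# Add an exponential learning-rate schedule: lr = LEARNING_RATE * 0.95**epoch.
# Note: LearningRateScheduler sets the rate at the start of every epoch, so it
# overrides any ReduceLROnPlateau that create_training_callbacks may include.
callbacks.append(
    tf.keras.callbacks.LearningRateScheduler(
        lambda epoch: LEARNING_RATE * 0.95 ** epoch
    )
)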

print(f"🚀 Starting enhanced training with {len(callbacks)} callbacks...")
print(f"📊 Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Class weights: {[f'{w:.3f}' for w in class_weights.values()]}")
print(f"  Model will be saved to: {BEST_MODEL_PATH}")

# Train the model
if train_dataset is not None and val_dataset is not None:
    try:
        history = model.fit(
            train_dataset,
            epochs=EPOCHS,
            validation_data=val_dataset,
            callbacks=callbacks,
            verbose=1
        )
        
        print("\n✅ Training completed successfully!")
        
        # Plot training history
        plot_training_history(history, save_path=os.path.join(MODEL_SAVE_DIR, 'training_history.png'))
        
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("❌ Cannot start training: datasets not available")
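The scheduler in the training cell decays the learning rate geometrically, `lr(epoch) = LEARNING_RATE * 0.95**epoch`, where Keras passes a 0-based epoch index. Printing a few values shows how far it falls over the configured epochs (a standalone sketch that re-declares the two constants with the notebook's values):

```python
LEARNING_RATE = 0.001
EPOCHS = 50

def exp_schedule(epoch, base_lr=LEARNING_RATE, decay=0.95):
    """Same formula as the LearningRateScheduler lambda in the training cell."""
    return base_lr * decay ** epoch

for epoch in (0, 10, 25, EPOCHS - 1):
    print(f"epoch {epoch:2d}: lr = {exp_schedule(epoch):.2e}")
```

By the last epoch the rate is about 8% of its starting value (roughly 8.1e-05), so late epochs make only small updates even if early stopping has not triggered.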

## Step 7: Comprehensive Model Evaluation

In [None]:
# Load the best saved model for evaluation, or fall back to the in-memory model
if os.path.exists(BEST_MODEL_PATH):
    print(f"📈 Loading best model from: {BEST_MODEL_PATH}")
    # compile=False: the custom weighted loss does not need to be deserialized
    best_model = tf.keras.models.load_model(BEST_MODEL_PATH, compile=False)
    
    # Recompile with metrics for evaluation
    best_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=weighted_loss,
        metrics=['accuracy'] + list(custom_metrics.values())
    )
    
    print("✅ Best model loaded and compiled for evaluation")
else:
    print(f"⚠️  Best model not found at: {BEST_MODEL_PATH}")
    print("Using the current in-memory model for evaluation...")
    best_model = model

# Perform detailed evaluation with whichever model is available
if val_dataset is not None:
    print("\n🔍 Performing detailed evaluation...")
    evaluation_results = evaluate_model_detailed(best_model, val_dataset, class_names)
    
    # Print detailed results
    print_evaluation_results(evaluation_results, class_names)
    
    # Plot confusion matrix
    plot_confusion_matrix(evaluation_results['confusion_matrix'], class_names)
else:
    print("⚠️  Validation dataset not available for evaluation")
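`evaluate_model_detailed` and `print_evaluation_results` are custom helpers, so their exact formulas are not shown here. Assuming a confusion matrix with true classes as rows and predicted classes as columns (the scikit-learn convention), per-class IoU and F1 (Dice) follow from its diagonal and margins: IoU = TP / (TP + FP + FN) and F1 = 2TP / (2TP + FP + FN). A NumPy sketch:

```python
import numpy as np

def per_class_iou_f1(cm):
    """Per-class IoU and F1 from a confusion matrix (rows = true, cols = predicted).

    Classes absent from both labels and predictions get NaN.
    """
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as c, but truly another class
    fn = cm.sum(axis=1) - tp  # truly c, but predicted as another class
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tp / (tp + fp + fn)
        f1 = 2 * tp / (2 * tp + fp + fn)
    return iou, f1

cm = np.array([[5, 1],
               [2, 2]])
iou, f1 = per_class_iou_f1(cm)
print(iou)
print(f1)
```

`np.nanmean(iou)` then gives the mean IoU over the classes that actually occur. Unlike pixel accuracy, this is not inflated by a dominant background class.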

## Step 8: Enhanced Inference and Testing

In [None]:
# Create enhanced predictor
if os.path.exists(BEST_MODEL_PATH):
    print("🔮 Creating enhanced predictor...")
    predictor = CocoaDiseasePredictor(BEST_MODEL_PATH, input_size=IMG_SIZE)
    print("✅ Enhanced predictor created successfully!")
    
    # Test on a validation image
    if val_dataset is not None:
        print("\n🧪 Testing predictor on validation data...")
        
        # Get a sample from validation dataset
        for val_images, val_masks in val_dataset.take(1):
            # Take first image from batch
            sample_image = val_images[0].numpy()
            sample_mask = val_masks[0].numpy()
            
            # Convert to correct format for predictor
            sample_image_uint8 = (sample_image * 255).astype(np.uint8)
            
            # Make prediction
            results = predictor.predict(sample_image_uint8, visualize=True)
            
            # Print detailed summary
            print_prediction_summary(results)
            
            break
    
else:
    print("⚠️  Best model not available for enhanced inference")

## Step 9: Interactive Testing Interface

In [None]:
# Interactive testing function
def test_uploaded_image():
    """Interactive function to test uploaded images."""
    try:
        from google.colab import files
        
        print("📁 Upload an image to test the enhanced model:")
        uploaded = files.upload()
        
        if not uploaded:
            print("No files uploaded!")
            return
        
        for filename in uploaded.keys():
            print(f"\n🔍 Analyzing: {filename}")
            print("="*50)
            
            # Predict using the predictor created in Step 8
            if 'predictor' in globals():
                results = predictor.predict(filename, visualize=True)
                print_prediction_summary(results)
            else:
                print("⚠️  Enhanced predictor not available: run the Step 8 cell first")
                
    except ImportError:
        print("📝 This interactive function is designed for Google Colab")
        print("For local testing, use: predictor.predict('path/to/image.jpg')")

# Create test interface
print("🎮 Enhanced testing interface ready!")
print("Call test_uploaded_image() to upload and test images")

# Uncomment the line below to start interactive testing
# test_uploaded_image()

## Summary and Next Steps

### ✅ Improvements Implemented:

1. **Fixed Binary Output Problem**: 
   - Enhanced mask creation with proper overlap handling
   - Data validation to ensure 4-class output (0, 1, 2, 3)
   - Proper softmax activation for multi-class prediction

2. **Enhanced Data Pipeline**:
   - Comprehensive data validation utilities
   - Improved image-mask pairing verification
   - Data augmentation for better generalization

3. **Better Model Training**:
   - Class balancing with computed weights
   - Enhanced U-Net with batch normalization and dropout
   - Comprehensive metrics (IoU, Dice, per-class metrics)
   - Improved callbacks and monitoring

4. **Robust Inference Pipeline**:
   - Post-processing with morphological operations
   - Confidence scoring and uncertainty estimation
   - Comprehensive visualization and analysis
   - Disease severity assessment with recommendations

### 🎯 Expected Outcomes:
- Model should now properly predict all 4 classes
- Diseased pods correctly classified as black_pod_rot or pod_borer
- More realistic training progression (not 100% accuracy immediately)
- Better generalization with detailed per-class performance metrics

### 🚀 Usage:
```python
# For single image prediction
predictor = CocoaDiseasePredictor('enhanced_model.keras')
results = predictor.predict('image.jpg')

# For batch prediction
batch_results = predictor.batch_predict(['img1.jpg', 'img2.jpg'])
```