In [None]:
# Install third-party dependencies (IPython shell escapes; rerun after a runtime reset).
# NOTE(review): `efficientnet` does not appear to be imported anywhere in this notebook —
# the model uses EfficientNetV2B0/B1 from tensorflow.keras.applications — so this
# install is probably removable.
!pip install -q efficientnet
!pip install -q albumentations
!pip install -q kaggle
!pip install -q plotly
!pip install -q seaborn
!pip install -q scikit-learn
!pip install -q opencv-python
# tf2onnx is optional: the export cell imports it inside try/except ImportError.
!pip install -q tf2onnx

# Import essential libraries
import os
import sys
import json
import random
import warnings
import zipfile
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Core ML libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# TensorFlow and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model, optimizers, callbacks
from tensorflow.keras.applications import EfficientNetV2B0, EfficientNetV2B1
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Image processing
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Sklearn utilities
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.model_selection import train_test_split

# Suppress warnings
warnings.filterwarnings('ignore')
# NOTE: tensorflow is already imported above, and TF reads TF_CPP_MIN_LOG_LEVEL when
# its native runtime loads, so this may not silence TF's startup logging. Move it
# before `import tensorflow` if those logs matter.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Configure TensorFlow
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {gpu_devices}")

# Mixed precision speeds up training on GPUs but makes it markedly slower on CPU,
# so only enable 'mixed_float16' when a GPU is actually present.
from tensorflow.keras.mixed_precision import Policy
if gpu_devices:
    policy = Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)
    print("✅ Mixed precision enabled")
else:
    policy = Policy('float32')
    tf.keras.mixed_precision.set_global_policy(policy)
    print("ℹ️ No GPU detected - using float32 (mixed precision is slow on CPU)")

# Set random seeds for reproducibility (Python, NumPy and TensorFlow RNGs)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("✅ Environment setup complete!")

In [None]:
from pathlib import Path

# Download the Kaggle archive only if it is not already on disk, then extract it into
# /content (`-o`: overwrite existing files without prompting, `-q`: quiet).
# NOTE(review): assumes Kaggle API credentials are already configured in this runtime.
if not Path("/content/plantdisease.zip").exists():
    !kaggle datasets download -d emmarex/plantdisease -p /content

!unzip -oq /content/plantdisease.zip -d /content/

# Verify the extraction produced a non-empty PlantVillage directory and preview it.
data_dir = Path('/content/PlantVillage')
if not (data_dir.exists() and any(data_dir.iterdir())):
    print("❌ Extraction failed. Please check the .zip file.")
else:
    print("✅ Dataset extracted successfully!")
    print(f"Dataset path: {data_dir}")

    subdirs = sorted(d for d in data_dir.iterdir() if d.is_dir())
    print(f"\nDataset contains {len(subdirs)} subdirectories:")
    # Preview only the first five class folders (JPEG files, both extension spellings).
    for subdir in subdirs[:5]:
        num_images = sum(1 for _ in subdir.glob('*.jpg')) + sum(1 for _ in subdir.glob('*.JPG'))
        print(f"  📁 {subdir.name}: {num_images} images")
    if len(subdirs) > 5:
        print(f"  ... and {len(subdirs) - 5} more directories")

In [None]:
# 📊 Dataset Exploration and Analysis
print("📊 Analyzing dataset structure and class distribution...")

# Define dataset paths.
# The download cell extracts the archive into /content and verifies the result at
# /content/PlantVillage; nothing in this notebook creates /content/plantdisease (only
# the .zip file carries that name), so point at the extracted directory.
DATASET_DIR = Path('/content/PlantVillage')
IMG_SIZE = (224, 224)  # EfficientNetV2 recommended size
BATCH_SIZE = 32

# File patterns counted as images in this overview.
OVERVIEW_IMAGE_GLOBS = ('*.jpg', '*.JPG', '*.png', '*.PNG')

if not DATASET_DIR.is_dir():
    raise FileNotFoundError(
        f"Dataset directory not found: {DATASET_DIR}. Run the download/extraction cell first."
    )

# Get all class directories (sorted so the class order is stable)
class_dirs = sorted(d for d in DATASET_DIR.iterdir() if d.is_dir())
class_names = [d.name for d in class_dirs]
if not class_names:
    raise FileNotFoundError(f"No class sub-directories found in {DATASET_DIR}")

# Derive the class count from the directories actually present instead of
# hard-coding it, so the two can never disagree.
NUM_CLASSES = len(class_names)

print(f"📋 Found {len(class_names)} classes:")
for i, class_name in enumerate(class_names, start=1):
    print(f"  {i:2d}. {class_name}")

# Analyze class distribution
class_counts = {
    class_dir.name: sum(len(list(class_dir.glob(pattern))) for pattern in OVERVIEW_IMAGE_GLOBS)
    for class_dir in class_dirs
}
total_images = sum(class_counts.values())

print(f"\n📈 Dataset Statistics:")
print(f"Total Images: {total_images:,}")
print(f"Total Classes: {len(class_names)}")
print(f"Average per class: {total_images // len(class_names):,}")

# Find min/max class sizes
min_class = min(class_counts, key=class_counts.get)
max_class = max(class_counts, key=class_counts.get)
print(f"Smallest class: {min_class} ({class_counts[min_class]} images)")
print(f"Largest class: {max_class} ({class_counts[max_class]} images)")

# Create class distribution visualization
plt.figure(figsize=(15, 8))
classes_sorted = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)
classes, counts = zip(*classes_sorted)

plt.bar(range(len(classes)), counts, alpha=0.7, color='skyblue', edgecolor='navy')
plt.title('Class Distribution in PlantDisease Dataset', fontsize=16, fontweight='bold')
plt.xlabel('Plant Disease Classes', fontsize=12)
plt.ylabel('Number of Images', fontsize=12)
plt.xticks(range(len(classes)), classes, rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars, offset relative to the tallest bar so they sit just
# above the bars whatever the scale.
label_offset = max(counts) * 0.01
for i, count in enumerate(counts):
    plt.text(i, count + label_offset, str(count), ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()

# Calculate class balance (infinite when some class has no images at all)
smallest_count = min(counts)
class_balance = max(counts) / smallest_count if smallest_count else float('inf')
print(f"\n⚖️ Class Balance Ratio: {class_balance:.2f}:1")
if class_balance > 3:
    print("⚠️  Dataset is imbalanced. We'll use class weights and balanced sampling.")
else:
    print("✅ Dataset is reasonably balanced.")

In [None]:
# 🖼️ Sample Images Visualization
print("🖼️ Displaying sample images from each class...")

import math
from collections import Counter

# Show one random image for up to 36 classes in a 6-column grid. The grid is sized to
# the number of classes shown, so a dataset with fewer classes does not leave rows of
# empty framed axes.
MAX_SAMPLE_CLASSES = 36
GRID_COLS = 6
sample_class_dirs = class_dirs[:MAX_SAMPLE_CLASSES]
grid_rows = max(1, math.ceil(len(sample_class_dirs) / GRID_COLS))
fig, axes = plt.subplots(grid_rows, GRID_COLS,
                         figsize=(20, 20 * grid_rows / GRID_COLS), squeeze=False)
axes = axes.flatten()

for ax, class_dir in zip(axes, sample_class_dirs):
    # Get a random image from this class
    image_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.JPG'))
    if image_files:
        random_image = random.choice(image_files)
        # Close the file handle as soon as the resized copy exists.
        with Image.open(random_image) as img:
            thumbnail = img.resize((150, 150))
        ax.imshow(thumbnail)
        ax.set_title(class_dir.name, fontsize=10, pad=10)
    else:
        ax.text(0.5, 0.5, 'No images', ha='center', va='center')
    ax.axis('off')

# Hide grid cells that have no class assigned.
for ax in axes[len(sample_class_dirs):]:
    ax.axis('off')

plt.suptitle('Sample Images from Plant Disease Dataset', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Display image size analysis
print("\n📏 Analyzing image dimensions...")
sample_sizes = []
for class_dir in class_dirs[:5]:  # Sample from first 5 classes
    image_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.JPG'))
    for img_path in image_files[:10]:  # Sample 10 images per class
        try:
            with Image.open(img_path) as img:
                sample_sizes.append(img.size)
        except (OSError, ValueError):
            # Unreadable or corrupt file (PIL's UnidentifiedImageError is an OSError):
            # skip it rather than abort the whole analysis.
            continue

if sample_sizes:
    widths, heights = zip(*sample_sizes)
    most_common_size = Counter(sample_sizes).most_common(1)[0][0]
    print(f"Sample image dimensions:")
    print(f"  Width range: {min(widths)} - {max(widths)} pixels")
    print(f"  Height range: {min(heights)} - {max(heights)} pixels")
    print(f"  Most common size: {most_common_size}")
    print(f"✅ We'll resize all images to {IMG_SIZE} for consistent training")

In [None]:
# Advanced Data Preprocessing and Augmentation Pipeline
class AdvancedDataGenerator:
    """Stratified file splits and augmented ``tf.data`` pipelines for an image-folder dataset.

    ``dataset_dir`` must contain one sub-directory per class; the sorted sub-directory
    names define the class indices. Images are read with OpenCV, transformed with
    Albumentations (resize, optional augmentation, ImageNet mean/std normalization) and
    yielded to Keras as ``(image, one_hot_label)`` batches.
    """

    # Glob patterns of files treated as images; both spellings are listed because
    # glob matching is case-sensitive on POSIX filesystems.
    IMAGE_PATTERNS = ('*.jpg', '*.JPG')

    def __init__(self, dataset_dir, img_size=(224, 224), batch_size=32):
        """
        Args:
            dataset_dir: root directory holding one sub-directory per class.
            img_size: ``(height, width)`` every image is resized to.
            batch_size: batch size of the datasets built by ``create_tf_dataset``.
        """
        self.dataset_dir = Path(dataset_dir)
        self.img_size = img_size
        self.batch_size = batch_size
        self.class_names = [d.name for d in self._class_dirs()]
        self.num_classes = len(self.class_names)

        # Create class to index mapping
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.class_names)}

        # Calculate class weights for imbalanced dataset
        self.class_weights = self._calculate_class_weights()

        # Define augmentation strategies
        self.train_augmentation = A.Compose([
            A.Resize(height=img_size[0], width=img_size[1]),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.3),
            A.Rotate(limit=15, p=0.3),
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.3
            ),
            A.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=20,
                val_shift_limit=10,
                p=0.3
            ),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.GaussianBlur(blur_limit=3, p=0.1),
            A.OneOf([
                A.GridDistortion(p=0.5),
                A.ElasticTransform(p=0.5),
            ], p=0.2),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        self.val_augmentation = A.Compose([
            A.Resize(height=img_size[0], width=img_size[1]),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _class_dirs(self):
        """Return the class sub-directories of ``dataset_dir`` sorted by name."""
        return sorted(d for d in self.dataset_dir.iterdir() if d.is_dir())

    def _list_images(self, class_dir):
        """Return ``class_dir``'s image files in sorted, filesystem-independent order."""
        return sorted(path for pattern in self.IMAGE_PATTERNS for path in class_dir.glob(pattern))

    def _calculate_class_weights(self):
        """Return ``{class_index: weight}`` with weight = total / (num_classes * count).

        Rare classes get proportionally larger weights. A class with no images never
        occurs in the labels, so its weight is irrelevant; it gets 1.0 (instead of a
        division by zero) so the dict still has an entry for every class index.
        """
        class_counts = {d.name: len(self._list_images(d)) for d in self._class_dirs()}
        total_samples = sum(class_counts.values())

        weights = {}
        for class_name, count in class_counts.items():
            weights[self.class_to_idx[class_name]] = (
                total_samples / (self.num_classes * count) if count else 1.0
            )
        return weights

    def create_dataset_split(self, train_split=0.8, val_split=0.15, test_split=0.05):
        """Split all image paths into stratified train/validation/test sets.

        ``test_split`` is carved off first; the remainder is divided between train and
        validation in the ratio ``train_split : val_split``.

        Returns:
            ``(X_train, y_train), (X_val, y_val), (X_test, y_test)`` where ``X_*`` are
            lists of path strings and ``y_*`` lists of integer class indices.
        """
        all_files = []
        all_labels = []

        # Sorted traversal: train_test_split is seeded, but the split is only
        # reproducible if its input order is too (iterdir/glob order is unspecified).
        for class_dir in self._class_dirs():
            class_idx = self.class_to_idx[class_dir.name]
            for img_file in self._list_images(class_dir):
                all_files.append(str(img_file))
                all_labels.append(class_idx)

        # Create stratified splits
        X_temp, X_test, y_temp, y_test = train_test_split(
            all_files, all_labels,
            test_size=test_split,
            stratify=all_labels,
            random_state=SEED
        )

        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp,
            test_size=val_split / (train_split + val_split),
            stratify=y_temp,
            random_state=SEED
        )

        return (X_train, y_train), (X_val, y_val), (X_test, y_test)

    def load_and_preprocess_image(self, image_path, augmentation):
        """Read one image as RGB, apply ``augmentation`` and return a float32 array.

        Best-effort by design (it runs inside ``tf.py_function`` during training): any
        failure is printed and a zero image of shape ``(*img_size, 3)`` is returned so a
        single bad file does not abort the epoch.
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                # cv2.imread reports missing/unreadable files by returning None.
                raise ValueError("cv2.imread could not read the file")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            image = augmentation(image=image)['image']
            return image.astype(np.float32)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a blank image if loading fails
            return np.zeros((*self.img_size, 3), dtype=np.float32)

    def create_tf_dataset(self, files, labels, augmentation, shuffle=True):
        """Build a batched, prefetched ``tf.data.Dataset`` of ``(image, one_hot_label)``.

        Args:
            files: image path strings.
            labels: integer class indices aligned with ``files``.
            augmentation: Albumentations pipeline applied to every image.
            shuffle: reshuffle the file order every epoch (seeded with ``SEED``).
        """
        def _load(path_tensor):
            # Runs eagerly inside tf.py_function: the tensor holds the UTF-8 path bytes.
            return self.load_and_preprocess_image(path_tensor.numpy().decode('utf-8'), augmentation)

        def load_image_fn(image_path, label):
            image = tf.py_function(_load, [image_path], tf.float32)
            # py_function loses static shape information; restore it for Keras.
            image.set_shape([*self.img_size, 3])
            return image, tf.one_hot(label, self.num_classes)

        dataset = tf.data.Dataset.from_tensor_slices((files, labels))

        if shuffle:
            # Elements are only path strings, so a buffer covering the whole list is
            # cheap and yields a full permutation rather than a local mix.
            dataset = dataset.shuffle(buffer_size=max(len(files), 1), seed=SEED)

        dataset = dataset.map(load_image_fn, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

# Initialize data generator (class discovery and class weights happen in the constructor)
data_gen = AdvancedDataGenerator(
    dataset_dir=DATASET_DIR,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

print("✅ Data generator initialized")
print(f"📊 Found {data_gen.num_classes} classes")
print("🎯 Class weights calculated for balanced training")

# Stratified split of file paths and integer labels
(X_train, y_train), (X_val, y_val), (X_test, y_test) = data_gen.create_dataset_split()

print("\n📊 Dataset Split:")
print(f"  Training: {len(X_train):,} images")
print(f"  Validation: {len(X_val):,} images")
print(f"  Test: {len(X_test):,} images")

# Build the tf.data pipelines; only the training split is augmented and shuffled.
train_dataset, val_dataset, test_dataset = (
    data_gen.create_tf_dataset(files, labels, augmentation, shuffle=shuffle)
    for files, labels, augmentation, shuffle in (
        (X_train, y_train, data_gen.train_augmentation, True),
        (X_val, y_val, data_gen.val_augmentation, False),
        (X_test, y_test, data_gen.val_augmentation, False),
    )
)

print("✅ TensorFlow datasets created successfully!")

In [None]:
# Building advanced EfficientNetV2 model architecture

class PlantDiseaseModel:
    """Plant disease classifier: pretrained EfficientNetV2 backbone plus a dense head.

    Typical use: ``build_model()`` -> ``compile_model()`` -> train the head ->
    ``unfreeze_base_model()`` -> ``compile_model()`` with a lower LR -> fine-tune.
    The built model expects ImageNet mean/std-normalized RGB input (what the
    notebook's Albumentations pipeline produces), not raw [0, 255] pixels.
    """

    def __init__(self, num_classes, img_size=(224, 224), model_name='efficientnetv2-b0'):
        self.num_classes = num_classes
        self.img_size = img_size
        self.model_name = model_name
        self.model = None
        # Backbone sub-model, kept so fine-tuning can reach its layers directly: because
        # it is built on ``input_tensor``, its layers are inlined into ``self.model``
        # rather than appearing there as one nested layer.
        self.base_model = None

    def build_model(self, dropout_rate=0.3, l2_reg=1e-4):
        """Build the backbone + classification head and store it on ``self.model``.

        Args:
            dropout_rate: dropout after the pooled features; the hidden dense layers
                use half of it.
            l2_reg: L2 penalty on every dense kernel of the head.

        Returns:
            The uncompiled ``keras.Model``.

        Raises:
            ValueError: if ``model_name`` is not a supported backbone.
        """
        # Input layer
        inputs = layers.Input(shape=(*self.img_size, 3), name='input_image')

        if self.model_name == 'efficientnetv2-b0':
            backbone_cls = EfficientNetV2B0
        elif self.model_name == 'efficientnetv2-b1':
            backbone_cls = EfficientNetV2B1
        else:
            raise ValueError(f"Unsupported model name: {self.model_name}")

        # include_preprocessing=False: the input pipeline already applies ImageNet
        # mean/std normalization (A.Normalize). Keras' built-in EfficientNetV2
        # preprocessing expects raw [0, 255] pixels and would normalize a second time.
        base_model = backbone_cls(
            weights='imagenet',
            include_top=False,
            input_tensor=inputs,
            pooling='avg',  # Global average pooling
            include_preprocessing=False,
        )

        # Freeze base model initially
        base_model.trainable = False
        self.base_model = base_model

        # Base model output (already pooled)
        x = base_model.output

        # Custom classification head
        x = layers.BatchNormalization(name='bn_1')(x)
        x = layers.Dropout(dropout_rate, name='dropout_1')(x)

        x1 = layers.Dense(512, activation='relu',
                          kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                          name='dense_1')(x)
        x1 = layers.BatchNormalization(name='bn_2')(x1)
        x1 = layers.Dropout(dropout_rate / 2, name='dropout_2')(x1)

        x2 = layers.Dense(256, activation='relu',
                          kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                          name='dense_2')(x1)
        x2 = layers.BatchNormalization(name='bn_3')(x2)
        x2 = layers.Dropout(dropout_rate / 2, name='dropout_3')(x2)

        # Output layer. dtype='float32' keeps the softmax numerically stable when a
        # mixed_float16 global policy is active (a no-op under a float32 policy).
        outputs = layers.Dense(self.num_classes, activation='softmax',
                               kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                               dtype='float32',
                               name='predictions')(x2)

        model = Model(inputs=inputs, outputs=outputs, name='PlantDiseaseDetector')
        self.model = model
        return model

    def unfreeze_base_model(self, unfreeze_layers=-30, freeze_batchnorm=True):
        """Make the top ``abs(unfreeze_layers)`` backbone layers trainable for fine-tuning.

        Args:
            unfreeze_layers: number of top backbone layers to unfreeze (sign ignored;
                the historical negative value mirrors ``layers[:unfreeze_layers]``).
            freeze_batchnorm: keep BatchNormalization layers frozen even inside the
                unfrozen range, as the Keras transfer-learning guide recommends.

        Raises:
            ValueError: if the model has not been built yet.

        The model must be recompiled afterwards for the change to take effect.
        """
        if self.model is None:
            raise ValueError("Model must be built first")

        base_model = getattr(self, 'base_model', None)
        if base_model is None:
            # Fallback for a model assembled elsewhere with a nested backbone layer.
            base_model = next(
                (layer for layer in self.model.layers if 'efficientnet' in layer.name.lower()),
                None,
            )
        if base_model is None:
            print("⚠️ Could not find base EfficientNet layer in the model.")
            return

        backbone_layers = base_model.layers
        n_unfreeze = abs(unfreeze_layers)
        if n_unfreeze > len(backbone_layers):
            print(f"Warning: unfreeze_layers ({unfreeze_layers}) > total layers ({len(backbone_layers)}). Unfreezing all.")
            n_unfreeze = len(backbone_layers)

        # Setting trainable on the backbone propagates to all of its layers; then
        # re-freeze everything below the cutoff.
        base_model.trainable = True
        cutoff = len(backbone_layers) - n_unfreeze
        for layer in backbone_layers[:cutoff]:
            layer.trainable = False

        if freeze_batchnorm:
            for layer in backbone_layers[cutoff:]:
                if isinstance(layer, layers.BatchNormalization):
                    layer.trainable = False

        print(f"✅ Unfroze top {n_unfreeze} layers of base model")

    def compile_model(self, learning_rate=1e-3, label_smoothing=0.1):
        """Compile with Adam (gradient-norm clipping at 1.0) and label-smoothed CE loss.

        Metrics: accuracy, top-3 accuracy, precision and recall.

        Raises:
            ValueError: if the model has not been built yet.
        """
        if self.model is None:
            raise ValueError("Model must be built first")

        optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )

        loss = tf.keras.losses.CategoricalCrossentropy(
            label_smoothing=label_smoothing
        )

        metrics = [
            'accuracy',
            tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_accuracy'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall')
        ]

        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        print("✅ Model compiled successfully")


# Build the model
model_builder = PlantDiseaseModel(num_classes=data_gen.num_classes, img_size=IMG_SIZE)
model = model_builder.build_model(dropout_rate=0.3, l2_reg=1e-4)

print("🏗️ Model Architecture:")
print(f"  Input Shape: {model.input_shape}")
print(f"  Output Shape: {model.output_shape}")
print(f"  Total Parameters: {model.count_params():,}")

# Count trainable parameters (int() so an empty list prints "0", not "0.0")
trainable_params = int(np.sum([np.prod(w.shape) for w in model.trainable_weights]))
print(f"  Trainable Parameters: {trainable_params:,}")

# Compile model
model_builder.compile_model(learning_rate=1e-3)


def _describe_layer_shape(layer):
    """Return a printable output shape for ``layer`` of ``model`` (Keras 2 and 3)."""
    if isinstance(layer, layers.InputLayer):
        return model.input_shape
    try:
        # The symbolic output tensor exists for every layer of a functional model,
        # unlike ``layer.output_shape`` (removed in Keras 3). Computing the shape from
        # ``model.input_shape`` would be wrong for any layer past the input.
        return tuple(layer.output.shape)
    except (AttributeError, ValueError):
        return "unknown"


# Display first & last layers
print("\n🔍 Model Summary (Key Layers):")
for i, layer in enumerate(model.layers[:5], start=1):
    print(f"  {i}. {layer.name}: {_describe_layer_shape(layer)}")

print("  ...")

last_layers = model.layers[-5:]
for i, layer in enumerate(last_layers, start=len(model.layers) - len(last_layers) + 1):
    print(f"  {i}. {layer.name}: {_describe_layer_shape(layer)}")

print("\n✅ Model ready for training!")

In [None]:
# Advanced Training Configuration and Callbacks

class TrainingManager:
    """Run two-phase training (frozen backbone, then fine-tuning) and plot the results."""

    def __init__(self, model, model_builder, class_weights):
        self.model = model
        self.model_builder = model_builder
        self.class_weights = class_weights
        self.history = None

    def get_callbacks(self, patience=7, max_lr=1e-3, max_epochs=50, append_log=False):
        """Build the callbacks for one ``fit`` call.

        Args:
            patience: EarlyStopping patience in epochs without val_accuracy improvement.
            max_lr: peak LR of the warmup + cosine schedule. The scheduler overwrites
                the optimizer's LR at the start of every epoch, so this must match the
                LR intended for the phase.
            max_epochs: length of the cosine decay (normally the phase's epoch count).
            append_log: append to the CSV log instead of overwriting it, so a later
                phase keeps the earlier phase's rows.

        ReduceLROnPlateau is intentionally not used: LearningRateScheduler resets the LR
        from the schedule every epoch, so a plateau reduction would be discarded.
        """
        def schedule(epoch, lr):
            return self._cosine_annealing_schedule(epoch, lr, max_lr=max_lr, max_epochs=max_epochs)

        return [
            # Model checkpointing - save best model
            tf.keras.callbacks.ModelCheckpoint(
                filepath='/content/best_model.h5',
                monitor='val_accuracy',
                save_best_only=True,
                save_weights_only=False,
                mode='max',
                verbose=1
            ),

            # Early stopping
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=patience,
                restore_best_weights=True,
                verbose=1,
                mode='max'
            ),

            # Warmup + cosine annealing schedule scaled to this phase
            tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0),

            # CSV logger for training history
            tf.keras.callbacks.CSVLogger(
                '/content/training_log.csv',
                append=append_log
            )
        ]

    def _cosine_annealing_schedule(self, epoch, lr, max_lr=1e-3, max_epochs=50,
                                   min_lr=1e-7, warmup_epochs=5):
        """Return the LR for ``epoch``: linear warmup, then cosine decay to ``min_lr``.

        ``lr`` (the current LR passed by Keras) is ignored. Epochs past ``max_epochs``
        stay at ``min_lr`` instead of the cosine climbing back up.
        """
        if epoch < warmup_epochs:
            return float(min_lr + (max_lr - min_lr) * epoch / warmup_epochs)
        progress = min(epoch / max(max_epochs, 1), 1.0)
        return float(min_lr + (max_lr - min_lr) * (1 + np.cos(np.pi * progress)) / 2)

    def train_phase1(self, train_dataset, val_dataset, epochs=20, max_lr=1e-3):
        """Phase 1: train the head with the backbone frozen.

        ``max_lr`` is the schedule's peak and should equal the LR the model was compiled with.
        """
        print("🚀 Phase 1: Training with frozen base model...")

        callbacks = self.get_callbacks(patience=5, max_lr=max_lr, max_epochs=epochs)

        self.history_phase1 = self.model.fit(
            train_dataset,
            epochs=epochs,
            validation_data=val_dataset,
            callbacks=callbacks,
            class_weight=self.class_weights,
            verbose=1
        )

        print("✅ Phase 1 training completed")
        return self.history_phase1

    def train_phase2(self, train_dataset, val_dataset, epochs=30, fine_tune_lr=1e-5,
                     unfreeze_layers=-30):
        """Phase 2: unfreeze the top backbone layers and fine-tune at ``fine_tune_lr``."""
        print("🔧 Phase 2: Fine-tuning with unfrozen base model...")

        self.model_builder.unfreeze_base_model(unfreeze_layers=unfreeze_layers)

        # Recompile with lower learning rate (needed for trainable changes to apply)
        self.model_builder.compile_model(learning_rate=fine_tune_lr)

        # The schedule peaks at fine_tune_lr; otherwise it would push the unfrozen
        # backbone back up to the phase-1 learning rate.
        callbacks = self.get_callbacks(patience=10, max_lr=fine_tune_lr,
                                       max_epochs=epochs, append_log=True)

        self.history_phase2 = self.model.fit(
            train_dataset,
            epochs=epochs,
            validation_data=val_dataset,
            callbacks=callbacks,
            class_weight=self.class_weights,
            verbose=1
        )

        print("✅ Phase 2 fine-tuning completed")
        return self.history_phase2

    def _combined_history(self):
        """Return ``(history_dict, phase1_epochs)``; ``phase1_epochs`` is 0 unless both phases ran."""
        phase1 = getattr(self, 'history_phase1', None)
        phase2 = getattr(self, 'history_phase2', None)
        if phase1 is not None and phase2 is not None:
            h1, h2 = phase1.history, phase2.history
            # Only metrics logged in both phases, so every series spans all epochs.
            combined = {key: list(h1[key]) + list(h2[key]) for key in h1 if key in h2}
            return combined, len(h1.get('loss', []))
        if phase1 is not None:
            return phase1.history, 0
        if phase2 is not None:
            return phase2.history, 0
        return {}, 0

    def plot_training_history(self):
        """Plot accuracy, loss, learning rate and top-3 accuracy, then print final val metrics.

        When both phases have run their histories are concatenated and a dashed line
        marks where fine-tuning started.
        """
        history, phase1_epochs = self._combined_history()
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        epochs = range(1, len(history.get('loss', [])) + 1)

        # Accuracy and loss panels share the same layout.
        for ax, metric, label in ((axes[0, 0], 'accuracy', 'Accuracy'),
                                  (axes[0, 1], 'loss', 'Loss')):
            if metric in history:
                ax.plot(epochs, history[metric], 'b-', label=f'Training {label}')
            if f'val_{metric}' in history:
                ax.plot(epochs, history[f'val_{metric}'], 'r-', label=f'Validation {label}')
            if phase1_epochs > 0:
                ax.axvline(x=phase1_epochs, color='green', linestyle='--', alpha=0.7,
                           label='Fine-tuning Start')
            ax.set_title(f'Model {label}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(label)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Learning rate: logged as 'lr' by tf.keras 2.x and 'learning_rate' by Keras 3.
        lr_key = next((key for key in ('lr', 'learning_rate') if key in history), None)
        if lr_key is not None:
            axes[1, 0].plot(epochs, history[lr_key], 'g-', label='Learning Rate')
            axes[1, 0].set_title('Learning Rate Schedule')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Learning Rate')
            axes[1, 0].set_yscale('log')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)

        # Plot top-3 accuracy
        if 'top3_accuracy' in history:
            axes[1, 1].plot(epochs, history['top3_accuracy'], 'purple', label='Training Top-3')
            if 'val_top3_accuracy' in history:
                axes[1, 1].plot(epochs, history['val_top3_accuracy'], 'orange', label='Validation Top-3')
            axes[1, 1].set_title('Top-3 Accuracy')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Top-3 Accuracy')
            axes[1, 1].legend()
            axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # Print final metrics
        if history:
            final_acc = history.get('val_accuracy', [0])[-1]
            final_loss = history.get('val_loss', [float('inf')])[-1]
            final_top3 = history.get('val_top3_accuracy', [0])[-1]

            print(f"\n📊 Final Training Results:")
            print(f"  Validation Accuracy: {final_acc:.4f} ({final_acc*100:.2f}%)")
            print(f"  Validation Loss: {final_loss:.4f}")
            print(f"  Validation Top-3 Accuracy: {final_top3:.4f} ({final_top3*100:.2f}%)")


# Initialize training manager with the built model, its builder (for unfreeze /
# recompile in phase 2) and the per-class loss weights.
trainer = TrainingManager(
    model=model,
    model_builder=model_builder,
    class_weights=data_gen.class_weights,
)

print("✅ Training manager initialized")
print("📋 Training will proceed in 2 phases:\n"
      "  Phase 1: Frozen base model (20 epochs)\n"
      "  Phase 2: Fine-tuning (30 epochs)")
print("\n🎯 Expected final accuracy: >95%")

In [None]:
# 🏃‍♂️ Model Training Execution

# Phase 1: Train with frozen base model
print("=" * 60)
print("PHASE 1: TRANSFER LEARNING WITH FROZEN BASE MODEL")
print("=" * 60)

history_phase1 = trainer.train_phase1(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=20
)

# Display Phase 1 results
trainer.plot_training_history()

# Phase 2: Fine-tune with unfrozen base model.
# "\n" is a real newline separating the banners ("\\n" would print a literal backslash-n).
print("\n" + "=" * 60)
print("PHASE 2: FINE-TUNING WITH UNFROZEN BASE MODEL")
print("=" * 60)

history_phase2 = trainer.train_phase2(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=30,
    fine_tune_lr=1e-5
)

# Display final training results (both phases combined)
trainer.plot_training_history()

print("\n🎉 Training completed successfully!")
print("📁 Best model saved to: /content/best_model.h5")
print("📊 Training log saved to: /content/training_log.csv")

In [None]:
# 📊 Comprehensive Model Evaluation
class ModelEvaluator:
    """Evaluate a trained classifier on a labelled test set and visualise the results.

    ``test_dataset`` is iterated as ``(images, one_hot_labels)`` batches whose label part
    exposes ``.numpy()`` (as a batched ``tf.data.Dataset`` does); ``model`` must provide
    ``predict(images, verbose=0)`` returning per-class probabilities; ``class_names[i]``
    is the display name of class id ``i``.
    """

    def __init__(self, model, test_dataset, class_names):
        self.model = model
        self.test_dataset = test_dataset
        self.class_names = class_names
        self.num_classes = len(class_names)

    def evaluate_model(self):
        """Predict the whole test set and print accuracy and top-3 accuracy.

        Returns:
            ``(y_true_idx, y_pred_idx, y_pred_probs)``: true and predicted class ids
            (1-D int arrays) and the ``(n_samples, n_classes)`` probability matrix.

        Raises:
            ValueError: if the dataset yields no samples.
        """
        print("🔍 Evaluating model on test dataset...")

        prob_batches = []
        label_batches = []
        for batch_images, batch_labels in self.test_dataset:
            prob_batches.append(np.asarray(self.model.predict(batch_images, verbose=0)))
            label_batches.append(np.asarray(batch_labels.numpy()))

        if not prob_batches:
            raise ValueError("test_dataset produced no samples to evaluate")
        y_pred_probs = np.concatenate(prob_batches, axis=0)
        y_true = np.concatenate(label_batches, axis=0)
        if len(y_true) == 0:
            raise ValueError("test_dataset produced no samples to evaluate")

        # Convert one-hot to class indices
        y_true_idx = np.argmax(y_true, axis=1)
        y_pred_idx = np.argmax(y_pred_probs, axis=1)

        test_accuracy = np.mean(y_true_idx == y_pred_idx)

        # Top-3: is the true class among the three highest-probability classes?
        top3_pred = np.argsort(y_pred_probs, axis=1)[:, -3:]
        top3_accuracy = np.mean(np.any(top3_pred == y_true_idx[:, None], axis=1))

        print(f"\n📈 Test Set Performance:")
        print(f"  Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
        print(f"  Top-3 Accuracy: {top3_accuracy:.4f} ({top3_accuracy*100:.2f}%)")

        return y_true_idx, y_pred_idx, y_pred_probs

    def plot_confusion_matrix(self, y_true, y_pred):
        """Draw the confusion matrix heatmap and print per-class accuracy, best first."""
        # Pin the label set so the matrix is always num_classes x num_classes and stays
        # aligned with class_names even when a class is missing from this split.
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(self.num_classes))

        # Per-class accuracy (= recall): correct / true samples of the class; 0 rather
        # than NaN for a class with no true samples.
        support = cm.sum(axis=1)
        class_accuracies = np.divide(
            cm.diagonal(), support,
            out=np.zeros(len(support), dtype=float),
            where=support > 0,
        )

        plt.figure(figsize=(20, 16))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names, yticklabels=self.class_names,
                    cbar_kws={'shrink': 0.8})
        plt.title('Confusion Matrix - Plant Disease Detection', fontsize=16, fontweight='bold')
        plt.xlabel('Predicted Label', fontsize=12)
        plt.ylabel('True Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        print("\n🎯 Per-Class Performance:")
        class_performance = sorted(zip(self.class_names, class_accuracies),
                                   key=lambda item: item[1], reverse=True)
        for class_name, accuracy in class_performance:
            status = "🟢" if accuracy > 0.9 else "🟡" if accuracy > 0.8 else "🔴"
            print(f"  {status} {class_name:25s}: {accuracy:.3f} ({accuracy*100:.1f}%)")

    def plot_classification_report(self, y_true, y_pred):
        """Plot per-class precision / recall / F1 and print the macro and weighted averages."""
        report = classification_report(
            y_true, y_pred,
            labels=np.arange(self.num_classes),
            target_names=self.class_names,
            output_dict=True,
            zero_division=0,
        )
        report_df = pd.DataFrame(report).transpose()

        fig, axes = plt.subplots(1, 3, figsize=(18, 12))
        metrics = ['precision', 'recall', 'f1-score']
        colors = ['skyblue', 'lightgreen', 'lightcoral']

        for ax, metric, color in zip(axes, metrics, colors):
            # Select per-class rows by name instead of slicing off summary rows by
            # position; the set of summary rows sklearn emits is not fixed.
            class_scores = report_df.loc[list(self.class_names), metric]
            sorted_scores = class_scores.sort_values(ascending=True)
            y_pos = np.arange(len(sorted_scores))
            ax.barh(y_pos, sorted_scores.values, color=color, alpha=0.7, edgecolor='navy')
            ax.set_yticks(y_pos)
            ax.set_yticklabels(sorted_scores.index, fontsize=8)
            ax.set_xlabel(metric.capitalize(), fontsize=12)
            ax.set_title(f'{metric.capitalize()} by Class', fontsize=14, fontweight='bold')
            ax.grid(axis='x', alpha=0.3)

            for j, v in enumerate(sorted_scores.values):
                ax.text(v + 0.01, j, f'{v:.3f}', va='center', fontsize=8)

        plt.tight_layout()
        plt.show()

        print(f"\n📊 Overall Performance Summary:")
        print(f"  Macro Average Precision: {report['macro avg']['precision']:.4f}")
        print(f"  Macro Average Recall: {report['macro avg']['recall']:.4f}")
        print(f"  Macro Average F1-Score: {report['macro avg']['f1-score']:.4f}")
        print(f"  Weighted Average F1-Score: {report['weighted avg']['f1-score']:.4f}")

    def analyze_misclassifications(self, y_true, y_pred, y_pred_probs, top_n=10):
        """Print the misclassification rate and the ``top_n`` most frequent true->pred confusions.

        ``y_true``/``y_pred`` are class-id arrays; ``y_pred_probs`` is accepted for
        interface compatibility but not used.
        """
        from collections import Counter

        misclassified_indices = np.where(y_true != y_pred)[0]
        n_total = len(y_true)
        n_wrong = len(misclassified_indices)
        error_rate = n_wrong / n_total * 100 if n_total else 0.0

        print(f"\n❌ Misclassification Analysis:")
        print(f"Total misclassified: {n_wrong} out of {n_total} ({error_rate:.2f}%)")

        # Counter.most_common keeps first-seen order for ties.
        pattern_counts = Counter(
            f"{self.class_names[y_true[idx]]} -> {self.class_names[y_pred[idx]]}"
            for idx in misclassified_indices
        )

        print(f"\n🔍 Top {min(top_n, len(pattern_counts))} Misclassification Patterns:")
        for i, (pattern, frequency) in enumerate(pattern_counts.most_common(top_n), start=1):
            percentage = frequency / n_wrong * 100
            print(f"  {i:2d}. {pattern:50s}: {frequency:3d} cases ({percentage:.1f}%)")

    def plot_confidence_distribution(self, y_pred_probs):
        """Plot histogram and box plot of the top-class probability, then print summary stats."""
        max_confidences = np.max(y_pred_probs, axis=1)

        plt.figure(figsize=(12, 6))

        # Histogram of confidence scores
        plt.subplot(1, 2, 1)
        plt.hist(max_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='navy')
        plt.axvline(np.mean(max_confidences), color='red', linestyle='--', label=f'Mean: {np.mean(max_confidences):.3f}')
        plt.axvline(np.median(max_confidences), color='green', linestyle='--', label=f'Median: {np.median(max_confidences):.3f}')
        plt.title('Prediction Confidence Distribution', fontsize=14, fontweight='bold')
        plt.xlabel('Confidence Score')
        plt.ylabel('Frequency')
        plt.legend()
        plt.grid(alpha=0.3)

        # Box plot of confidence scores. The tick label is set explicitly because
        # boxplot's ``labels=`` keyword was renamed to ``tick_labels`` in Matplotlib 3.9.
        plt.subplot(1, 2, 2)
        plt.boxplot(max_confidences)
        plt.xticks([1], ['All Predictions'])
        plt.title('Confidence Score Distribution', fontsize=14, fontweight='bold')
        plt.ylabel('Confidence Score')
        plt.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

        low_confidence = int(np.sum(max_confidences < 0.8))
        print(f"\n📊 Confidence Statistics:")
        print(f"  Mean Confidence: {np.mean(max_confidences):.4f}")
        print(f"  Median Confidence: {np.median(max_confidences):.4f}")
        print(f"  Std Confidence: {np.std(max_confidences):.4f}")
        print(f"  Low Confidence (<0.8): {low_confidence} samples ({low_confidence/len(max_confidences)*100:.1f}%)")


# Perform comprehensive evaluation on the held-out test split
evaluator = ModelEvaluator(
    model=model,
    test_dataset=test_dataset,
    class_names=data_gen.class_names,
)
y_true, y_pred, y_pred_probs = evaluator.evaluate_model()

# Visual and textual analysis of the predictions
evaluator.plot_confusion_matrix(y_true=y_true, y_pred=y_pred)
evaluator.plot_classification_report(y_true=y_true, y_pred=y_pred)
evaluator.analyze_misclassifications(y_true=y_true, y_pred=y_pred, y_pred_probs=y_pred_probs)
evaluator.plot_confidence_distribution(y_pred_probs=y_pred_probs)

print("\n✅ Model evaluation completed!")


In [None]:
import json
import tensorflow as tf
from pathlib import Path

try:
    import tf2onnx
    TF2ONNX_AVAILABLE = True
except ImportError:
    TF2ONNX_AVAILABLE = False


class ModelExporter:
    """Export a trained Keras model plus the metadata needed to deploy it.

    All artifacts are written under ``export_dir``:

    * a TensorFlow SavedModel (``tensorflow_model/``),
    * Keras ``.h5`` (legacy) and ``.keras`` files,
    * optionally a TFLite model and an ONNX model (both converted from the
      SavedModel),
    * ``class_mapping.json`` / ``class_names.txt``,
    * ``deployment_config.json``.

    Every ``export_*`` method reports progress with ``print`` and returns
    ``None`` (or ``None`` entries) on failure instead of raising.
    """

    def __init__(self, model: "tf.keras.Model", class_names: list, export_dir: str = "exported_models"):
        """Store the model and its metadata, and create the export directory.

        Args:
            model: Trained Keras model to export.
            class_names: Class labels, written in list order to the mapping
                files (presumably matching the model's output index order —
                verify against the training pipeline).
            export_dir: Output directory. It is created, including missing
                parent directories, if it does not exist.
        """
        self.model = model
        self.class_names = class_names
        # Drop the batch dimension. This assumes a single-input model;
        # multi-input models expose a list of shapes here instead.
        self.input_shape = model.input_shape[1:]
        self.export_dir = Path(export_dir)
        # parents=True so nested destinations such as "artifacts/models" work.
        self.export_dir.mkdir(parents=True, exist_ok=True)
        print(f"📤 Initialized model exporter. Files will be saved to: {self.export_dir.resolve()}")

    @property
    def _saved_model_path(self) -> Path:
        """Location of the SavedModel directory inside ``export_dir``."""
        return self.export_dir / 'tensorflow_model'

    def _ensure_saved_model(self, purpose: str) -> bool:
        """Make sure a SavedModel exists, exporting one if necessary.

        Args:
            purpose: Name of the conversion that needs it ("TFLite", "ONNX");
                used only in progress messages.

        Returns:
            True if a SavedModel is available, False if creating it failed.
        """
        if self._saved_model_path.exists():
            return True
        print(f"TensorFlow SavedModel not found, attempting to create it for {purpose} conversion.")
        if self.export_tensorflow_model() is None:
            print(f"Skipping {purpose} export as SavedModel creation failed.")
            return False
        return True

    def export_tensorflow_model(self):
        """Export the model as a TensorFlow SavedModel.

        Returns:
            Path of the SavedModel directory, or ``None`` if the export failed.
        """
        print("💾 Exporting TensorFlow SavedModel...")
        tf_model_path = self._saved_model_path
        try:
            if hasattr(self.model, "export"):
                # Keras 3 (and recent tf.keras) provide Model.export() for an
                # inference-only SavedModel.
                self.model.export(str(tf_model_path))
            else:
                # Older tf.keras releases lack Model.export(); use the TF API.
                tf.saved_model.save(self.model, str(tf_model_path))
        except Exception as e:
            print(f"❌ Failed to export TensorFlow SavedModel: {e}")
            return None
        print(f"✅ TensorFlow model saved to: {tf_model_path}")
        return tf_model_path

    def export_keras_model(self):
        """Save the model in the legacy H5 format and the native ``.keras`` format.

        The two formats are attempted independently. A failure in one, for
        example the legacy H5 writer rejecting the model, no longer prevents
        the other from being written.

        Returns:
            ``(h5_path, keras_path)``. An entry is ``None`` if that format
            could not be saved.
        """
        print("💾 Exporting Keras models...")
        targets = (
            ("H5", self.export_dir / 'plant_disease_model.h5'),
            ("Keras", self.export_dir / 'plant_disease_model.keras'),
        )
        saved = {}
        for label, path in targets:
            try:
                self.model.save(str(path))
            except Exception as e:
                print(f"❌ Failed to export Keras model ({label} format): {e}")
                saved[label] = None
            else:
                print(f"✅ Keras model saved ({label} format): {path}")
                saved[label] = path
        return saved["H5"], saved["Keras"]

    def export_tflite_model(self):
        """Convert the SavedModel to TensorFlow Lite with default optimizations.

        The SavedModel is exported first if it does not exist yet.

        Returns:
            Path of the ``.tflite`` file, or ``None`` if the export failed.
        """
        print("📱 Exporting TensorFlow Lite model...")
        if not self._ensure_saved_model("TFLite"):
            return None

        tflite_path = self.export_dir / 'plant_disease_model.tflite'
        try:
            converter = tf.lite.TFLiteConverter.from_saved_model(str(self._saved_model_path))
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            tflite_model = converter.convert()
            tflite_path.write_bytes(tflite_model)
        except Exception as e:
            print(f"❌ Failed to export TensorFlow Lite model: {e}")
            return None

        size_mb = len(tflite_model) / (1024 * 1024)
        print(f"✅ TensorFlow Lite model saved: {tflite_path}")
        print(f"   Model size: {size_mb:.2f} MB")
        return tflite_path

    def export_onnx_model(self, opset: int = 13):
        """Convert the SavedModel to ONNX by running ``tf2onnx.convert``.

        Args:
            opset: ONNX opset version passed to tf2onnx (default 13).

        Returns:
            Path of the ``.onnx`` file, or ``None`` if tf2onnx is unavailable
            or the conversion failed.
        """
        # Imported before any try/except so the except clauses below can
        # always resolve subprocess.CalledProcessError.
        import subprocess
        import sys

        print("🔄 Exporting ONNX model...")
        if not TF2ONNX_AVAILABLE:
            print("❌ tf2onnx not installed. Skipping ONNX export.")
            print("   Install with: pip install tf2onnx")
            return None
        if not self._ensure_saved_model("ONNX"):
            return None

        onnx_path = self.export_dir / 'plant_disease_model.onnx'
        command = [
            # Use the running interpreter: a bare "python" may resolve to a
            # different environment that lacks tf2onnx.
            sys.executable, "-m", "tf2onnx.convert",
            "--saved-model", str(self._saved_model_path),
            "--output", str(onnx_path),
            "--opset", str(opset),
        ]
        try:
            result = subprocess.run(command, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to export ONNX model (subprocess error): {e}")
            print("stdout:", e.stdout)
            print("stderr:", e.stderr)
            return None
        except Exception as e:
            print(f"❌ Failed to export ONNX model: {e}")
            return None

        print("tf2onnx stdout:", result.stdout)
        print("tf2onnx stderr:", result.stderr)
        print(f"✅ ONNX model saved: {onnx_path}")
        return onnx_path

    def create_class_mapping(self):
        """Write the class list as ``class_mapping.json`` and ``class_names.txt``.

        ``display_name`` replaces underscores with spaces, collapses the
        resulting runs of spaces (as in ``Pepper__bell___Bacterial_spot``),
        and title-cases the result.

        Returns:
            ``(json_path, txt_path)``.
        """
        print("📋 Creating class mapping files...")
        class_mapping = {
            "classes": [
                {
                    "id": i,
                    "name": class_name,
                    "display_name": " ".join(class_name.replace('_', ' ').split()).title(),
                }
                for i, class_name in enumerate(self.class_names)
            ],
            "num_classes": len(self.class_names),
            "model_info": {
                "input_shape": list(self.input_shape),
            },
        }

        json_path = self.export_dir / 'class_mapping.json'
        with json_path.open('w', encoding='utf-8') as f:
            json.dump(class_mapping, f, indent=2)

        txt_path = self.export_dir / 'class_names.txt'
        with txt_path.open('w', encoding='utf-8') as f:
            f.write("\n".join(self.class_names))

        print("✅ Class mapping files created:")
        print(f"  JSON: {json_path}")
        print(f"  Text: {txt_path}")
        return json_path, txt_path

    def create_deployment_config(self):
        """Write ``deployment_config.json`` describing the model and its inputs.

        Returns:
            Path of the written configuration file.
        """
        print("⚙️ Creating deployment configuration...")
        config = {
            "model": {
                "name": "Plant Disease Detection Model V2",
                "version": "2.0.0",
                "description": "Advanced EfficientNetV2-based plant disease detection",
                "framework": "tensorflow",
                "input_shape": list(self.input_shape),
            },
            # NOTE(review): ImageNet mean/std are recorded here, but nothing in
            # this class confirms the training pipeline normalized inputs this
            # way (Keras EfficientNetV2 models can include their own rescaling).
            # Verify against the data generator before relying on these values.
            "preprocessing": {
                "resize": list(self.input_shape[:2]),
                "normalize": True,
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
            },
            "classes": self.class_names,
            "deployment": {
                "recommended_format": "tensorflow_savedmodel",
                "mobile_format": "tflite",
                "cross_platform_format": "onnx",
            },
        }

        config_path = self.export_dir / 'deployment_config.json'
        with config_path.open('w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        print(f"✅ Deployment config created: {config_path}")
        return config_path

    def export_all(self, *, include_tflite: bool = False, include_onnx: bool = False):
        """Run every export step and collect the resulting paths.

        TFLite and ONNX conversions are opt-in. By default both are skipped,
        as before.

        Args:
            include_tflite: Also export a TensorFlow Lite model.
            include_onnx: Also export an ONNX model.

        Returns:
            Dict mapping artifact kind to its path (``None`` for failed steps).
            The keys ``'tflite'`` and ``'onnx'`` are present only when requested.
        """
        print("\n🚀 Starting complete model export process...\n")
        exported_files = {}

        exported_files['tensorflow'] = self.export_tensorflow_model()
        exported_files['keras_h5'], exported_files['keras_new'] = self.export_keras_model()

        skipped = []
        if include_tflite:
            exported_files['tflite'] = self.export_tflite_model()
        else:
            skipped.append("TFLite")
        if include_onnx:
            exported_files['onnx'] = self.export_onnx_model()
        else:
            skipped.append("ONNX")

        exported_files['class_json'], exported_files['class_txt'] = self.create_class_mapping()
        exported_files['config'] = self.create_deployment_config()

        note = f" (Note: {' and '.join(skipped)} exports were skipped)" if skipped else ""
        print(f"\n🎉 Model export completed!{note}\n")
        print(f"📁 All available files saved to: {self.export_dir.resolve()}")
        return exported_files


# --- Execute the export process ---
# Wrap the trained model together with the dataset's label order, then write
# every artifact. The returned dict maps each artifact kind to its path.
exporter = ModelExporter(model, class_names=data_gen.class_names)
exported_files_info = exporter.export_all()