In [None]:
import os

folder_path = '/kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4' 

try:
    class_names = [item for item in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, item))]
    
    print(f"✅ Found {len(class_names)} classes in the folder '{os.path.abspath(folder_path)}':\n")
    
    for name in sorted(class_names):
        print(name)

except FileNotFoundError:
    print(f"❌ Error: The folder path '{folder_path}' was not found. Please check the path and try again.")

In [None]:
import os
import shutil
import collections

# --- CONFIGURATION ---

SOURCE_DIRECTORY = '/kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4'


OUTPUT_DIRECTORY = 'Organized_Dataset'

# --- SCRIPT LOGIC ---

# The list of all your original folder names
folder_list = [
    "American_Bollworm_on_Cotton", "Anthracnose_on_Cotton", "Apple_Black_rot",
    "Apple_Cedar_apple_rust", "Apple_healthy", "Apple_scab", "Army_worm",
    "Bacterial_blight_in_Cotton", "Becterial_Blight_in_Rice", "Blueberry_healthy",
    "Brownspot", "Cherry_(including_sour)_Powdery_mildew", "Cherry_(including_sour)_healthy",
    "Corn_(maize)_Cercospora_leaf_spot_Gray_leaf_spot", "Corn_(maize)_Common_rust_",
    "Corn_(maize)_Northern_Leaf_Blight", "Corn_(maize)_healthy", "Cotton_Aphid",
    "Cotton_Healthy", "Cotton_Leaf_Curl", "Cotton_mealy_bug", "Cotton_whitefly",
    "Flag_Smut", "GrapeLeaf_blight_(Isariopsis_Leaf_Spot)", "Grape_Black_rot",
    "Grape_Esca_(Black_Measles)", "Grape_healthy", "Maize_ear_rot",
    "Maize_fall_armyworm", "Maize_stem_borer", "Orange_Haunglongbing_(Citrus_greening)",
    "Peach_Bacterial_spot", "Peach_healthy", "Pepper_bell_Bacterial_spot",
    "Pepper_bell_healthy", "Pink_bollworm_in_cotton", "Potato_Early_blight",
    "Potato_Late_blight", "Potato_healthy", "Raspberry_healthy", "Red_cotton_bug",
    "Rice_Blast", "Soybean_healthy", "Squash_Powdery_mildew",
    "Strawberry_Leaf_scorch", "Strawberry_healthy", "Sugarcane_Healthy",
    "Sugarcane_Mosaic", "Sugarcane_RedRot", "Sugarcane_RedRust",
    "Sugarcane_Yellow_Rust", "Thirps_on_cotton", "Tomato___Bacterial_spot",
    "Tomato___Early_blight", "Tomato___Late_blight", "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites_Two-spotted_spider_mite",
    "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus", "Tomato___healthy", "Tungro",
    "Wheat_Brown_leaf_Rust", "Wheat_Healthy", "Wheat_Stem_fly",
    "Wheat___Yellow_Rust", "Wheat_aphid", "Wheat_black_rust",
    "Wheat_leaf_blight", "Wheat_mite", "Wheat_powdery_mildew", "Wheat_scab", "Wilt"
]

# Dictionary to map keywords to a standard category name
PLANT_MAP = {
    'Apple': 'Apple', 'Cotton': 'Cotton', 'cotton': 'Cotton', 'Rice': 'Rice',
    'Blueberry': 'Blueberry', 'Cherry': 'Cherry', 'Corn': 'Corn', 'Maize': 'Corn',
    'Grape': 'Grape', 'Orange': 'Orange', 'Peach': 'Peach', 'Pepper': 'Pepper',
    'Potato': 'Potato', 'Raspberry': 'Raspberry', 'Soybean': 'Soybean',
    'Squash': 'Squash', 'Strawberry': 'Strawberry', 'Sugarcane': 'Sugarcane',
    'Tomato': 'Tomato', 'Wheat': 'Wheat'
}

def get_new_names(original_name):
    """Parses the original folder name to get the new category and subfolder name."""
    name_lower = original_name.lower()

    # Priority 1: Handle special separators like '_on_' and '_in_'
    if '_on_' in name_lower:
        parts = original_name.split('_on_')
        return parts[1].capitalize(), parts[0]
    if '_in_' in name_lower:
        parts = original_name.split('_in_')
        return parts[1].capitalize(), parts[0]
    
    # Priority 2: General keyword search
    for keyword, category in PLANT_MAP.items():
        if keyword.lower() in name_lower:
            # Clean up the name for the new subfolder
            # This is a bit complex to handle all cases like Tomato___, Corn_(maize)_ etc.
            new_subfolder_name = original_name.replace(keyword, '').replace('___', '_').strip('_')
            if '(maize)' in new_subfolder_name:
                new_subfolder_name = new_subfolder_name.replace('(maize)', '').strip('_')
            if '(including_sour)' in new_subfolder_name:
                 new_subfolder_name = new_subfolder_name.replace('(including_sour)', '').strip('_')
            if 'Leaf' in new_subfolder_name and category == 'Grape':
                 new_subfolder_name = new_subfolder_name.replace('Leaf', '_leaf')
            
            # If cleaning results in an empty name, use original name part
            if not new_subfolder_name:
                new_subfolder_name = original_name.split('_')[-1]

            return category, new_subfolder_name

    return "Unclassified", original_name

def main():
    """Main function to create the new structure and copy files."""
    print(f"Preparing to copy and organize folders into '{OUTPUT_DIRECTORY}'...")
    
    # Create the main output directory if it doesn't exist
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    
    processed_count = 0
    for original_folder_name in folder_list:
        source_path = os.path.join(SOURCE_DIRECTORY, original_folder_name)
        
        # Check if the source folder actually exists
        if not os.path.isdir(source_path):
            print(f"⚠️  Warning: Source folder not found, skipping: '{original_folder_name}'")
            continue
            
        category, new_subfolder_name = get_new_names(original_folder_name)
        
        # Create the new category path (e.g., 'Organized_Dataset/Apple')
        target_category_path = os.path.join(OUTPUT_DIRECTORY, category)
        
        # Create the final destination path (e.g., 'Organized_Dataset/Apple/Black_rot')
        destination_path = os.path.join(target_category_path, new_subfolder_name)
        
        print(f"Copying: '{original_folder_name}'  ->  '{category}/{new_subfolder_name}'")
        
        try:
            # The core command: copy the entire directory tree
            shutil.copytree(source_path, destination_path)
            processed_count += 1
        except FileExistsError:
            print(f"    -  Skipped: Destination folder already exists.")
        except Exception as e:
            print(f"    -  ❌ Error copying '{original_folder_name}': {e}")
            
    print(f"\n✅ Done! Processed and copied {processed_count} folders.")
    print(f"Your new, organized dataset is ready in the '{OUTPUT_DIRECTORY}' folder.")


if __name__ == "__main__":
    main()

In [1]:
import os
import shutil

# --- CONFIGURATION ---
SOURCE_DIRECTORY = '/kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4'
OUTPUT_DIRECTORY = 'Organized_Dataset_Tomato_Potato' # New output folder

# ❗ KEY CHANGE: Define a list of the only plant categories you want to process.
# The script will ignore any folder that doesn't belong to these categories.
CLASSES_TO_PROCESS = ['Tomato', 'Potato']

# --- SCRIPT LOGIC ---

PLANT_MAP = {
    'Apple': 'Apple', 'Cotton': 'Cotton', 'Rice': 'Rice',
    'Blueberry': 'Blueberry', 'Cherry': 'Cherry', 'Corn': 'Corn', 'Maize': 'Corn',
    'Grape': 'Grape', 'Orange': 'Orange', 'Peach': 'Peach', 'Pepper': 'Pepper',
    'Potato': 'Potato', 'Raspberry': 'Raspberry', 'Soybean': 'Soybean',
    'Squash': 'Squash', 'Strawberry': 'Strawberry', 'Sugarcane': 'Sugarcane',
    'Tomato': 'Tomato', 'Wheat': 'Wheat'
}

def get_new_names(original_name):
    """Parses the original folder name to get the new category and subfolder name."""
    name_lower = original_name.lower()

    if '_on_' in name_lower:
        parts = original_name.split('_on_')
        return parts[1].capitalize(), parts[0].replace('_', ' ')
    if '_in_' in name_lower:
        parts = original_name.split('_in_')
        return parts[1].capitalize(), parts[0].replace('_', ' ')
    
    for keyword, category in PLANT_MAP.items():
        if keyword.lower() in name_lower:
            new_subfolder_name = original_name.replace(keyword, '').replace('___', '_').strip(' _')
            new_subfolder_name = new_subfolder_name.replace('(maize)', '').replace('(including_sour)', '').strip(' _')
            
            if not new_subfolder_name or new_subfolder_name.lower() == 'healthy':
                new_subfolder_name = 'healthy'
                
            return category, new_subfolder_name.replace('_', ' ').capitalize()

    return "Unclassified", original_name

def main():
    """Main function to create the new structure and copy files."""
    print(f"Preparing to organize folders for {CLASSES_TO_PROCESS}...")
    print(f"Source: '{SOURCE_DIRECTORY}'")

    try:
        all_folders = [f for f in os.listdir(SOURCE_DIRECTORY) if os.path.isdir(os.path.join(SOURCE_DIRECTORY, f))]
        print(f"Found {len(all_folders)} total folders to check.")
    except FileNotFoundError:
        print(f"❌ Error: The source directory was not found: '{SOURCE_DIRECTORY}'")
        return
        
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    
    processed_count = 0
    for original_folder_name in all_folders:
        source_path = os.path.join(SOURCE_DIRECTORY, original_folder_name)
        
        category, new_subfolder_name = get_new_names(original_folder_name)
        
        # ❗ KEY CHANGE: Check if the detected category is in our target list.
        if category in CLASSES_TO_PROCESS:
            # If it is, proceed with copying the folder.
            destination_path = os.path.join(OUTPUT_DIRECTORY, category, new_subfolder_name)
            
            print(f"Copying: '{original_folder_name}'  ->  '{category}/{new_subfolder_name}'")
            
            try:
                shutil.copytree(source_path, destination_path)
                processed_count += 1
            except FileExistsError:
                print(f"    - Skipped: Destination folder already exists.")
            except Exception as e:
                print(f"    - ❌ Error copying '{original_folder_name}': {e}")
        # If the category is not 'Tomato' or 'Potato', the script simply ignores it
        # and moves to the next folder.
            
    print(f"\n✅ Done! Processed and copied {processed_count} folders related to {CLASSES_TO_PROCESS}.")
    print(f"Your new, organized dataset is ready in the '{OUTPUT_DIRECTORY}' folder.")


if __name__ == "__main__":
    main()

Preparing to organize folders for ['Tomato', 'Potato']...
Source: '/kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4'
Found 74 total folders to check.
Copying: 'Tomato___Late_blight'  ->  'Tomato/Late blight'
Copying: 'Tomato___healthy'  ->  'Tomato/Healthy'
Copying: 'Tomato___Spider_mites_Two-spotted_spider_mite'  ->  'Tomato/Spider mites two-spotted spider mite'
Copying: 'Potato_Late_blight'  ->  'Potato/Late blight'
Copying: 'Tomato___Early_blight'  ->  'Tomato/Early blight'
Copying: 'Tomato___Septoria_leaf_spot'  ->  'Tomato/Septoria leaf spot'
Copying: 'Potato_healthy'  ->  'Potato/Healthy'
Copying: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus'  ->  'Tomato/Yellow leaf curl virus'
Copying: 'Tomato___Bacterial_spot'  ->  'Tomato/Bacterial spot'
Copying: 'Tomato___Target_Spot'  ->  'Tomato/Target spot'
Copying: 'Potato_Early_blight'  ->  'Potato/Early blight'
Copying: 'Tomato___Tomato_mosaic_virus'  ->  'Tomato/Mosaic virus'
Copying: 'Tomato___Leaf_Mold'  ->  'Tomato/Leaf 

In [2]:
import os
import shutil
import random

# --- CONFIGURATION ---
# The root directory of the previously organized dataset (Stage 0 output)
SOURCE_ROOT = '/kaggle/working/Organized_Dataset_Tomato_Potato'
# The root directory for the final Stage 1 (Coarse Classification) splits
DESTINATION_ROOT = 'Stage_1_Splits'

# Define the plant categories to process (must match the folder names in SOURCE_ROOT)
PLANT_CATEGORIES = ['Tomato', 'Potato']

# Define the desired split ratios (must sum to 1.0)
SPLIT_RATIOS = {
    'train': 0.70,
    'validation': 0.15,
    'test': 0.15
}

def split_data():
    """
    Combines all disease images per plant, shuffles them, and splits them
    into the final train, validation, and test directories for Stage 1.
    """
    if sum(SPLIT_RATIOS.values()) != 1.0:
        print("❌ Error: Split ratios must sum exactly to 1.0. Check your configuration.")
        return

    print(f"Starting data split process for Stage 1 (Train: {SPLIT_RATIOS['train']:.0%}, Valid: {SPLIT_RATIOS['validation']:.0%}, Test: {SPLIT_RATIOS['test']:.0%})")

    # 1. Create the necessary destination directories
    for split_type in SPLIT_RATIOS.keys():
        for category in PLANT_CATEGORIES:
            os.makedirs(os.path.join(DESTINATION_ROOT, split_type, category), exist_ok=True)
            
    total_images_processed = 0

    # 2. Process each plant category
    for category in PLANT_CATEGORIES:
        source_category_path = os.path.join(SOURCE_ROOT, category)
        
        if not os.path.exists(source_category_path):
            print(f"⚠️ Warning: Source folder not found for {category} at {source_category_path}. Skipping.")
            continue
            
        print(f"\n--- Processing {category} ---")
        
        # Collect all image file paths across all disease subfolders
        all_image_paths = []
        for root, _, files in os.walk(source_category_path):
            for file in files:
                # Basic check to ensure we only process image files
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    all_image_paths.append(os.path.join(root, file))

        # 3. Shuffle and split the paths
        random.shuffle(all_image_paths)
        total_count = len(all_image_paths)
        print(f"Total images found: {total_count}")
        
        if total_count == 0:
             print(f"Skipping {category}: No images found.")
             continue

        # Calculate split indices
        train_end = int(total_count * SPLIT_RATIOS['train'])
        validation_end = train_end + int(total_count * SPLIT_RATIOS['validation'])

        # Split the list
        train_files = all_image_paths[:train_end]
        validation_files = all_image_paths[train_end:validation_end]
        test_files = all_image_paths[validation_end:] # Test takes the remainder

        # Store files to copy by split type
        split_files = {
            'train': train_files,
            'validation': validation_files,
            'test': test_files
        }

        # 4. Copy files to the final destination structure
        for split_type, file_list in split_files.items():
            destination_dir = os.path.join(DESTINATION_ROOT, split_type, category)
            print(f"  - Copying {len(file_list):5d} files to {split_type}/{category}...")
            
            for src_path in file_list:
                # Use os.path.basename to get only the filename (flattening the disease structure)
                dst_path = os.path.join(destination_dir, os.path.basename(src_path))
                shutil.copy2(src_path, dst_path) # copy2 preserves metadata
                total_images_processed += 1
                
    print(f"\n✅ Data splitting complete. Total images copied: {total_images_processed}")
    print(f"The Stage 1 dataset is ready in the '{DESTINATION_ROOT}' folder.")

if __name__ == "__main__":
    split_data()


Starting data split process for Stage 1 (Train: 70%, Valid: 15%, Test: 15%)

--- Processing Tomato ---
Total images found: 19006
  - Copying 13304 files to train/Tomato...
  - Copying  2850 files to validation/Tomato...
  - Copying  2852 files to test/Tomato...

--- Processing Potato ---
Total images found: 2344
  - Copying  1640 files to train/Potato...
  - Copying   351 files to validation/Potato...
  - Copying   353 files to test/Potato...

✅ Data splitting complete. Total images copied: 21350
The Stage 1 dataset is ready in the 'Stage_1_Splits' folder.


In [3]:
import os
import shutil
import random

# --- CONFIGURATION ---
# SOURCE_ROOT: Source Directory, which is the output from the initial organization script
SOURCE_ROOT = '/kaggle/working/Organized_Dataset_Tomato_Potato'
# DESTINATION_ROOT: Destination Directory for the final Stage 2 split dataset
DESTINATION_ROOT = 'Stage_2_Splits'

# Plant Categories to process (must match the top-level folders in SOURCE_ROOT)
PLANT_CATEGORIES = ['Tomato', 'Potato']

# Desired split ratios: Train / Validation / Test (must sum to 1.0)
SPLIT_RATIOS = {
    'train': 0.70,
    'validation': 0.15,
    'test': 0.15
}

def split_disease_data():
    """
    For each plant category, this script divides the data by disease into separate 
    train, validation, and test directories. This prepares the dataset for
    the specialized Stage 2 models.
    """
    if sum(SPLIT_RATIOS.values()) != 1.0:
        print("❌ Error: Split ratios must sum exactly to 1.0. Please check the configuration.")
        return

    print(f"Starting data split process for Stage 2 (Train: {SPLIT_RATIOS['train']:.0%}, Validation: {SPLIT_RATIOS['validation']:.0%}, Test: {SPLIT_RATIOS['test']:.0%})")

    total_images_processed = 0

    # 1. Loop through each plant category (e.g., 'Tomato', 'Potato')
    for category in PLANT_CATEGORIES:
        source_category_path = os.path.join(SOURCE_ROOT, category)
        
        if not os.path.exists(source_category_path):
            print(f"⚠️ Warning: Source folder not found for {category} at {source_category_path}. Skipping.")
            continue
            
        print(f"\n--- Processing {category} ---")
        
        # 2. Identify disease subfolders (which serve as the specific labels)
        disease_folders = [d for d in os.listdir(source_category_path) 
                           if os.path.isdir(os.path.join(source_category_path, d))]

        for disease_name in disease_folders:
            source_disease_path = os.path.join(source_category_path, disease_name)
            
            # 3. Collect all image file paths for this disease
            all_image_paths = [os.path.join(source_disease_path, f) 
                               for f in os.listdir(source_disease_path) 
                               if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

            random.shuffle(all_image_paths)
            total_count = len(all_image_paths)
            
            if total_count == 0:
                 print(f"  - No images found for disease {disease_name}. Skipping.")
                 continue
                 
            print(f"  - Disease '{disease_name}': Total {total_count} images.")

            # 4. Calculate Split Indices
            train_end = int(total_count * SPLIT_RATIOS['train'])
            validation_end = train_end + int(total_count * SPLIT_RATIOS['validation'])

            train_files = all_image_paths[:train_end]
            validation_files = all_image_paths[train_end:validation_end]
            test_files = all_image_paths[validation_end:]

            split_files = {
                'train': train_files,
                'validation': validation_files,
                'test': test_files
            }

            # 5. Copy files to the destination structure
            for split_type, file_list in split_files.items():
                # Destination Path: DESTINATION_ROOT / PLANT / SPLIT_TYPE / DISEASE
                destination_dir = os.path.join(DESTINATION_ROOT, category, split_type, disease_name)
                os.makedirs(destination_dir, exist_ok=True)
                
                # Copying {len(file_list)} images to {split_type}/{disease_name}... (commented out for cleaner output)

                for src_path in file_list:
                    dst_path = os.path.join(destination_dir, os.path.basename(src_path))
                    shutil.copy2(src_path, dst_path)
                    total_images_processed += 1
                
    print(f"\n✅ Data splitting complete. Total images copied: {total_images_processed}")
    print(f"The Stage 2 dataset is ready in the '{DESTINATION_ROOT}' folder.")

if __name__ == "__main__":
    split_disease_data()


Starting data split process for Stage 2 (Train: 70%, Validation: 15%, Test: 15%)

--- Processing Tomato ---
  - Disease 'Septoria leaf spot': Total 1940 images.
  - Disease 'Spider mites two-spotted spider mite': Total 1676 images.
  - Disease 'Late blight': Total 2009 images.
  - Disease 'Leaf mold': Total 1061 images.
  - Disease 'Healthy': Total 1684 images.
  - Disease 'Bacterial spot': Total 2234 images.
  - Disease 'Target spot': Total 1422 images.
  - Disease 'Mosaic virus': Total 452 images.
  - Disease 'Early blight': Total 1105 images.
  - Disease 'Yellow leaf curl virus': Total 5423 images.

--- Processing Potato ---
  - Disease 'Late blight': Total 1095 images.
  - Disease 'Healthy': Total 152 images.
  - Disease 'Early blight': Total 1097 images.

✅ Data splitting complete. Total images copied: 21350
The Stage 2 dataset is ready in the 'Stage_2_Splits' folder.


In [4]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import os

# --- CONFIGURATION ---
# Define the root path where your split data is located
DATA_ROOT = '/kaggle/working/Stage_1_Splits'
# Model parameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 30
# Name of the saved model file
MODEL_FILENAME = 'stage1_classifier.h5'

# --- 1. DATA PREPARATION ---
print("Initializing data generators...")

# Define data augmentation and preprocessing for the training set
train_datagen = ImageDataGenerator(
    rescale=1./255, # Normalize pixel values
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Only normalization for validation and test sets (no augmentation)
valid_test_datagen = ImageDataGenerator(rescale=1./255)

# Load training data
train_generator = train_datagen.flow_from_directory(
    os.path.join(DATA_ROOT, 'train'),
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary', # Use 'binary' for 2 classes (Tomato vs Potato)
    shuffle=True
)

# Load validation data
validation_generator = valid_test_datagen.flow_from_directory(
    os.path.join(DATA_ROOT, 'validation'),
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False # Do not shuffle validation data
)

# --- 2. MODEL DEFINITION (Simple Custom CNN) ---
print("\nDefining Stage 1 CNN model...")
# The model will learn to distinguish between the two plant types (Tomato=0, Potato=1)
model = Sequential([
    # Block 1
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D((2, 2)),
    
    # Block 2
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Block 3
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Classification Head
    Flatten(),
    Dropout(0.5), # Regularization to prevent overfitting
    Dense(128, activation='relu'),
    # Output layer for binary classification (Tomato vs. Potato)
    Dense(1, activation='sigmoid') 
])

# --- 3. MODEL COMPILATION ---
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

model.summary()

# --- 4. CALLBACKS ---
# Stop training if validation loss doesn't improve for 5 epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Save the best model based on validation accuracy
model_checkpoint = ModelCheckpoint(
    MODEL_FILENAME,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# --- 5. TRAINING ---
print("\nStarting Stage 1 training...")

history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    callbacks=[early_stopping, model_checkpoint]
)

print("\n--- Training Finished ---")
print(f"The best performing Stage 1 model has been saved as '{MODEL_FILENAME}'.")

# Optional: Evaluate on the test set (using the validation generator for a simple check)
# It's better to load the test generator separately if needed, but for simplicity we skip here.
print("\nLoading the best model weights for final summary...")
best_model = tf.keras.models.load_model(MODEL_FILENAME)
loss, accuracy = best_model.evaluate(validation_generator, verbose=0)
print(f"Final Validation Accuracy: {accuracy*100:.2f}%")

print("\nYour Stage 1 model is ready!")


2025-10-22 07:44:18.148041: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761119058.366168      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761119058.454393      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Initializing data generators...
Found 14934 images belonging to 2 classes.
Found 3201 images belonging to 2 classes.

Defining Stage 1 CNN model...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
I0000 00:00:1761119071.756087      37 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1761119071.756797      37 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5



Starting Stage 1 training...


  self._warn_if_super_not_called()


Epoch 1/30


I0000 00:00:1761119075.726318     129 service.cc:148] XLA service 0x7bac340035e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1761119075.727516     129 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1761119075.727538     129 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1761119076.057812     129 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m  3/466[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m35s[0m 76ms/step - accuracy: 0.5226 - loss: 0.6918

I0000 00:00:1761119079.456195     129 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 138ms/step - accuracy: 0.8799 - loss: 0.3359
Epoch 1: val_accuracy improved from -inf to 0.92500, saving model to stage1_classifier.h5
[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 148ms/step - accuracy: 0.8799 - loss: 0.3358 - val_accuracy: 0.9250 - val_loss: 0.1786
Epoch 2/30
[1m  1/466[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m10s[0m 22ms/step - accuracy: 0.9688 - loss: 0.1087




Epoch 2: val_accuracy improved from 0.92500 to 0.92531, saving model to stage1_classifier.h5
[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.9688 - loss: 0.1087 - val_accuracy: 0.9253 - val_loss: 0.1800
Epoch 3/30
[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step - accuracy: 0.9085 - loss: 0.2272
Epoch 3: val_accuracy improved from 0.92531 to 0.94844, saving model to stage1_classifier.h5
[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 138ms/step - accuracy: 0.9085 - loss: 0.2272 - val_accuracy: 0.9484 - val_loss: 0.1375
Epoch 4/30
[1m  1/466[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m9s[0m 21ms/step - accuracy: 0.8438 - loss: 0.4022
Epoch 4: val_accuracy improved from 0.94844 to 0.94969, saving model to stage1_classifier.h5
[1m466/466[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.8438 - loss: 0.4022 - val_accuracy: 0.9497 - val_loss: 0.1379
Epoch 5/30
[1m466/466[0m [3

In [5]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import os

# --- CONFIGURATION ---
# Base directory for the Stage 2 data splits
DATA_ROOT = '/kaggle/working/Stage_2_Splits'
# Specific sub-folder for this model's data
PLANT_FOLDER = 'Tomato'

# Model parameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 30
# Name of the saved model file
MODEL_FILENAME = 'stage2_tomato_classifier.h5'

# --- 1. DATA PREPARATION ---
print(f"Initializing data generators for {PLANT_FOLDER} diseases...")

# Define data augmentation and preprocessing for the training set
train_datagen = ImageDataGenerator(
    rescale=1./255, # Normalize pixel values
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Only normalization for validation and test sets
valid_test_datagen = ImageDataGenerator(rescale=1./255)

# Full path to the training data for this specific plant
train_dir = os.path.join(DATA_ROOT, PLANT_FOLDER, 'train')
validation_dir = os.path.join(DATA_ROOT, PLANT_FOLDER, 'validation')

# Load training data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical', # Use 'categorical' for multi-class disease classification
    shuffle=True
)

# Load validation data
validation_generator = valid_test_datagen.flow_from_directory(
    validation_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

# Determine the number of output classes (diseases + healthy)
NUM_CLASSES = train_generator.num_classes
print(f"Number of classes (diseases) detected for {PLANT_FOLDER}: {NUM_CLASSES}")

# --- 2. MODEL DEFINITION (CNN Architecture) ---
print(f"\nDefining Stage 2 CNN model for {PLANT_FOLDER}...")
model = Sequential([
    # Block 1
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D((2, 2)),
    
    # Block 2
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Block 3
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Classification Head
    Flatten(),
    Dropout(0.5), 
    Dense(128, activation='relu'),
    # Output layer: uses NUM_CLASSES units and softmax for multi-class probability
    Dense(NUM_CLASSES, activation='softmax') 
])

# --- 3. MODEL COMPILATION ---
model.compile(
    optimizer='adam',
    # Use categorical_crossentropy for multi-class classification
    loss='categorical_crossentropy', 
    metrics=['accuracy']
)

model.summary()

# --- 4. CALLBACKS ---
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model_checkpoint = ModelCheckpoint(
    MODEL_FILENAME,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# --- 5. TRAINING ---
print(f"\nStarting Stage 2 training for {PLANT_FOLDER}...")

history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    callbacks=[early_stopping, model_checkpoint]
)

print("\n--- Training Finished ---")
print(f"The best performing Stage 2 {PLANT_FOLDER} model has been saved as '{MODEL_FILENAME}'.")


Initializing data generators for Tomato diseases...
Found 13300 images belonging to 10 classes.
Found 2847 images belonging to 10 classes.
Number of classes (diseases) detected for Tomato: 10

Defining Stage 2 CNN model for Tomato...



Starting Stage 2 training for Tomato...
Epoch 1/30
[1m415/415[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 136ms/step - accuracy: 0.4148 - loss: 1.7044
Epoch 1: val_accuracy improved from -inf to 0.56889, saving model to stage2_tomato_classifier.h5
[1m415/415[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 143ms/step - accuracy: 0.4151 - loss: 1.7036 - val_accuracy: 0.5689 - val_loss: 1.3660
Epoch 2/30
[1m  1/415[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m9s[0m 22ms/step - accuracy: 0.6250 - loss: 0.9953
Epoch 2: val_accuracy improved from 0.56889 to 0.58452, saving model to stage2_tomato_classifier.h5
[1m415/415[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.6250 - loss: 0.9953 - val_accuracy: 0.5845 - val_loss: 1.3130
Epoch 3/30
[1m415/415[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 130ms/step - accuracy: 0.7050 - loss: 0.8839
Epoch 3: val_accuracy improved from 0.58452 to 0.62571, saving model to stage2_tomato_classifier.h5


In [6]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import os

# --- CONFIGURATION ---
# Base directory for the Stage 2 data splits
DATA_ROOT = '/kaggle/working/Stage_2_Splits'
# Specific sub-folder for this model's data
PLANT_FOLDER = 'Potato'

# Model parameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 30
# Name of the saved model file
MODEL_FILENAME = 'stage2_potato_classifier.h5'

# --- 1. DATA PREPARATION ---
print(f"Initializing data generators for {PLANT_FOLDER} diseases...")

# Define data augmentation and preprocessing for the training set
train_datagen = ImageDataGenerator(
    rescale=1./255, # Normalize pixel values
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Only normalization for validation and test sets
valid_test_datagen = ImageDataGenerator(rescale=1./255)

# Full path to the training data for this specific plant
train_dir = os.path.join(DATA_ROOT, PLANT_FOLDER, 'train')
validation_dir = os.path.join(DATA_ROOT, PLANT_FOLDER, 'validation')

# Load training data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical', # Use 'categorical' for multi-class disease classification
    shuffle=True
)

# Load validation data
validation_generator = valid_test_datagen.flow_from_directory(
    validation_dir,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

# Determine the number of output classes (diseases + healthy)
NUM_CLASSES = train_generator.num_classes
print(f"Number of classes (diseases) detected for {PLANT_FOLDER}: {NUM_CLASSES}")

# --- 2. MODEL DEFINITION (CNN Architecture) ---
print(f"\nDefining Stage 2 CNN model for {PLANT_FOLDER}...")
model = Sequential([
    # Block 1
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D((2, 2)),
    
    # Block 2
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Block 3
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    
    # Classification Head
    Flatten(),
    Dropout(0.5), 
    Dense(128, activation='relu'),
    # Output layer: uses NUM_CLASSES units and softmax for multi-class probability
    Dense(NUM_CLASSES, activation='softmax') 
])

# --- 3. MODEL COMPILATION ---
model.compile(
    optimizer='adam',
    # Use categorical_crossentropy for multi-class classification
    loss='categorical_crossentropy', 
    metrics=['accuracy']
)

model.summary()

# --- 4. CALLBACKS ---
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model_checkpoint = ModelCheckpoint(
    MODEL_FILENAME,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# --- 5. TRAINING ---
print(f"\nStarting Stage 2 training for {PLANT_FOLDER}...")

history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    callbacks=[early_stopping, model_checkpoint]
)

print("\n--- Training Finished ---")
print(f"The best performing Stage 2 {PLANT_FOLDER} model has been saved as '{MODEL_FILENAME}'.")


Initializing data generators for Potato diseases...
Found 1639 images belonging to 3 classes.
Found 350 images belonging to 3 classes.
Number of classes (diseases) detected for Potato: 3

Defining Stage 2 CNN model for Potato...



Starting Stage 2 training for Potato...
Epoch 1/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 160ms/step - accuracy: 0.4680 - loss: 1.0145
Epoch 1: val_accuracy improved from -inf to 0.67188, saving model to stage2_potato_classifier.h5
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 183ms/step - accuracy: 0.4690 - loss: 1.0124 - val_accuracy: 0.6719 - val_loss: 0.7469
Epoch 2/30
[1m 1/51[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1s[0m 22ms/step - accuracy: 0.6875 - loss: 0.8062
Epoch 2: val_accuracy improved from 0.67188 to 0.67500, saving model to stage2_potato_classifier.h5
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 10ms/step - accuracy: 0.6875 - loss: 0.8062 - val_accuracy: 0.6750 - val_loss: 0.7378
Epoch 3/30
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 133ms/step - accuracy: 0.6651 - loss: 0.7607
Epoch 3: val_accuracy improved from 0.67500 to 0.82500, saving model to stage2_potato_classifier.h5
[1m51/51

In [8]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- CONFIGURATION ---
# Assumed locations for the three trained model files
STAGE1_MODEL_PATH = '/kaggle/working/stage1_classifier.h5'
STAGE2_TOMATO_MODEL_PATH = '/kaggle/working/stage2_tomato_classifier.h5'
STAGE2_POTATO_MODEL_PATH = '/kaggle/working/stage2_potato_classifier.h5'

# Directory where the Stage 2 data splits are located (needed to reliably get class indices)
STAGE2_DATA_ROOT = 'Stage_2_Splits'

# Image size used during training
IMAGE_SIZE = (128, 128)

# --- GLOBAL MODEL AND LABEL INITIALIZATION ---
# Load models globally once to avoid loading them repeatedly for every prediction
STAGE1_MODEL = None
STAGE2_TOMATO_MODEL = None
STAGE2_POTATO_MODEL = None

try:
    print("Loading Stage 1 Model...")
    STAGE1_MODEL = tf.keras.models.load_model(STAGE1_MODEL_PATH)

    print("Loading Stage 2 Tomato Model...")
    STAGE2_TOMATO_MODEL = tf.keras.models.load_model(STAGE2_TOMATO_MODEL_PATH)
    
    print("Loading Stage 2 Potato Model...")
    STAGE2_POTATO_MODEL = tf.keras.models.load_model(STAGE2_POTATO_MODEL_PATH)
    print("All models loaded successfully.")

except Exception as e:
    print(f"❌ ERROR: Could not load one or more models. Ensure files are present: {e}")
    # The functions will handle None models.


def get_stage_labels(plant_folder):
    """
    Dynamically or statically retrieves the class labels for a Stage 2 model.
    This mimics the alphabetical ordering used by ImageDataGenerator during training.
    """
    if not os.path.exists(os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train')):
        # Fallback to assumed alphabetical labels if data directory isn't available
        if plant_folder == 'Tomato':
            return {
                0: 'Bacterial spot', 1: 'Early blight', 2: 'Healthy', 3: 'Late blight',
                4: 'Leaf mold', 5: 'Mosaic virus', 6: 'Septoria leaf spot', 
                7: 'Spider mites two-spotted spider mite', 8: 'Target spot', 9: 'Yellow leaf curl virus'
            }
        elif plant_folder == 'Potato':
            return {0: 'Early blight', 1: 'Healthy', 2: 'Late blight'}
        return {}
    
    # Use ImageDataGenerator to reliably get class indices
    temp_datagen = ImageDataGenerator(rescale=1./255)
    generator = temp_datagen.flow_from_directory(
        os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train'),
        target_size=IMAGE_SIZE,
        batch_size=1,
        class_mode='categorical',
        shuffle=False
    )
    return {v: k for k, v in generator.class_indices.items()}

# Define Stage 1 labels (Binary Classification)
STAGE1_LABELS = {0: 'Potato', 1: 'Tomato'} 

# Define Stage 2 labels using the helper function
TOMATO_DISEASE_LABELS = get_stage_labels('Tomato')
POTATO_DISEASE_LABELS = get_stage_labels('Potato')

# --- PREPROCESSING ---

def preprocess_image(img_path):
    """Loads an image, resizes it, and converts it to a NumPy array for prediction."""
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        img_array /= 255.0  # Rescale to 0-1 range
        return img_array
    except FileNotFoundError:
        return None
    except Exception:
        return None

# --- PREDICTION FUNCTION ---

def predict_disease(image_path):
    """
    Executes the two-stage inference pipeline for a single image.
    Returns (plant_label, disease_label, confidence)
    """
    global STAGE1_MODEL, STAGE2_TOMATO_MODEL, STAGE2_POTATO_MODEL

    if not all([STAGE1_MODEL, STAGE2_TOMATO_MODEL, STAGE2_POTATO_MODEL]):
        return "N/A", "N/A", 0.0

    processed_image = preprocess_image(image_path)
    if processed_image is None:
        return "N/A", "N/A", 0.0

    # 1. STAGE 1: COARSE CLASSIFICATION (Plant Type)
    stage1_pred = STAGE1_MODEL.predict(processed_image, verbose=0)
    
    plant_index = 1 if stage1_pred[0][0] >= 0.5 else 0
    plant_label = STAGE1_LABELS.get(plant_index, "Unknown Plant")
    
    # 2. STAGE 2: FINE-GRAINED CLASSIFICATION (Disease)
    if plant_label == 'Tomato':
        stage2_model = STAGE2_TOMATO_MODEL
        disease_labels = TOMATO_DISEASE_LABELS
    elif plant_label == 'Potato':
        stage2_model = STAGE2_POTATO_MODEL
        disease_labels = POTATO_DISEASE_LABELS
    else:
        return plant_label, "Inference Failed (Stage 1 error)", 0.0
        
    stage2_pred = stage2_model.predict(processed_image, verbose=0)
    
    disease_index = np.argmax(stage2_pred[0])
    disease_confidence = np.max(stage2_pred[0])
    disease_label = disease_labels.get(disease_index, "Unknown Disease Index")
    
    return plant_label, disease_label, disease_confidence


Loading Stage 1 Model...
Loading Stage 2 Tomato Model...
Loading Stage 2 Potato Model...
All models loaded successfully.
Found 13300 images belonging to 10 classes.
Found 1639 images belonging to 3 classes.


In [13]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- CONFIGURATION ---
STAGE1_MODEL_PATH = '/kaggle/working/stage1_classifier.h5'
STAGE2_TOMATO_MODEL_PATH = '/kaggle/working/stage2_tomato_classifier.h5'
STAGE2_POTATO_MODEL_PATH = '/kaggle/working/stage2_potato_classifier.h5'
STAGE2_DATA_ROOT = 'Stage_2_Splits' 
IMAGE_SIZE = (128, 128)

# Define Stage 1 labels (Binary Classification)
STAGE1_LABELS = {0: 'Potato', 1: 'Tomato'} 

# --- UTILITY FUNCTION FOR LABEL RETRIEVAL ---

def get_stage_labels(plant_folder):
    """
    Retrieves the class labels for a Stage 2 model, ensuring alphabetical order.
    """
    if not os.path.exists(os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train')):
        # Fallback to hardcoded alphabetical labels if training data path is unavailable
        if plant_folder == 'Tomato':
            return {
                0: 'Bacterial spot', 1: 'Early blight', 2: 'Healthy', 3: 'Late blight',
                4: 'Leaf mold', 5: 'Mosaic virus', 6: 'Septoria leaf spot', 
                7: 'Spider mites two-spotted spider mite', 8: 'Target spot', 9: 'Yellow leaf curl virus'
            }
        elif plant_folder == 'Potato':
            return {0: 'Early blight', 1: 'Healthy', 2: 'Late blight'}
        return {}
    
    # Dynamic label retrieval using ImageDataGenerator (recommended if path exists)
    temp_datagen = ImageDataGenerator(rescale=1./255)
    generator = temp_datagen.flow_from_directory(
        os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train'),
        target_size=IMAGE_SIZE,
        batch_size=1,
        class_mode='categorical',
        shuffle=False
    )
    # Note: Returns mapping of index -> class_name
    return {index: name for name, index in generator.class_indices.items()}

def preprocess_image(img_path):
    """Loads an image, resizes it, and converts it to a NumPy array for prediction."""
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        img_array /= 255.0  # Rescale to 0-1 range
        return img_array
    except Exception:
        return None

# --- TWO-STAGE CLASSIFIER CLASS (The Unified Model Wrapper) ---

class TwoStageClassifier:
    """
    A unified wrapper class that loads and manages all three models 
    for the two-stage prediction pipeline.
    """
    def __init__(self):
        self.stage1_model = None
        self.stage2_tomato_model = None
        self.stage2_potato_model = None
        self.tomato_labels = get_stage_labels('Tomato')
        self.potato_labels = get_stage_labels('Potato')
        self._load_models()

    def _load_models(self):
        """Loads all three Keras models from the specified paths."""
        try:
            print("Loading Stage 1 Model...")
            self.stage1_model = tf.keras.models.load_model(STAGE1_MODEL_PATH)

            print("Loading Stage 2 Tomato Model...")
            self.stage2_tomato_model = tf.keras.models.load_model(STAGE2_TOMATO_MODEL_PATH)
            
            print("Loading Stage 2 Potato Model...")
            self.stage2_potato_model = tf.keras.models.load_model(STAGE2_POTATO_MODEL_PATH)
            print("All models loaded successfully.")

        except Exception as e:
            print(f"❌ ERROR: Could not load one or more models. Check files: {e}")
            # Ensure all models are None if loading fails
            self.stage1_model = None
            self.stage2_tomato_model = None
            self.stage2_potato_model = None


    def predict(self, image_path):
        """
        Executes the two-stage inference pipeline for a single image.
        Returns (plant_label, disease_label, confidence)
        """
        if not all([self.stage1_model, self.stage2_tomato_model, self.stage2_potato_model]):
            return "N/A", "N/A", 0.0

        processed_image = preprocess_image(image_path)
        if processed_image is None:
            return "N/A", "N/A", 0.0

        # 1. STAGE 1: COARSE CLASSIFICATION (Plant Type)
        stage1_pred = self.stage1_model.predict(processed_image, verbose=0)
        
        # Binary prediction using a threshold of 0.5 (assuming a sigmoid output from Stage 1)
        # Assuming Potato is index 0 and Tomato is index 1
        plant_index = 1 if stage1_pred[0][0] >= 0.5 else 0 
        plant_label = STAGE1_LABELS.get(plant_index, "Unknown Plant")
        
        # 2. STAGE 2: FINE-GRAINED CLASSIFICATION (Disease)
        if plant_label == 'Tomato':
            stage2_model = self.stage2_tomato_model
            disease_labels = self.tomato_labels
        elif plant_label == 'Potato':
            stage2_model = self.stage2_potato_model
            disease_labels = self.potato_labels
        else:
            return plant_label, "Inference Failed (Stage 1 error)", 0.0
            
        stage2_pred = stage2_model.predict(processed_image, verbose=0)
        
        disease_index = np.argmax(stage2_pred[0])
        disease_confidence = np.max(stage2_pred[0])
        disease_label = disease_labels.get(disease_index, "Unknown Disease Index")
        
        return plant_label, disease_label, disease_confidence


# --- EVALUATION FUNCTION (Testing Logic) ---

def evaluate_pipeline():
    """
    Iterates through the entire test dataset, runs the TwoStageClassifier, 
    and calculates overall prediction accuracy.
    """
    # Initialize the unified classifier
    classifier = TwoStageClassifier()
    
    if not os.path.exists(STAGE2_DATA_ROOT):
        print(f"❌ Error: Stage 2 data root not found at '{STAGE2_DATA_ROOT}'.")
        return

    total_images = 0
    correct_plant_predictions = 0
    correct_disease_predictions = 0
    
    print("--- Starting Two-Stage Pipeline Evaluation ---")
    
    plants_to_evaluate = ['Tomato', 'Potato']

    for true_plant in plants_to_evaluate:
        plant_test_root = os.path.join(STAGE2_DATA_ROOT, true_plant, 'test')

        if not os.path.exists(plant_test_root):
            print(f"⚠️ Warning: Test data for {true_plant} not found at '{plant_test_root}'. Skipping.")
            continue
            
        print(f"Processing test data for: {true_plant}")

        for root, dirs, files in os.walk(plant_test_root):
            image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

            if not image_files:
                continue

            true_disease = os.path.basename(root)
            
            if true_disease == 'test':
                continue

            for file_name in image_files:
                
                total_images += 1
                image_path = os.path.join(root, file_name)

                # Use the unified class method for prediction
                predicted_plant, predicted_disease, confidence = classifier.predict(image_path)
                
                # 1. Check Stage 1 Accuracy (Plant Type)
                plant_match = (predicted_plant == true_plant)
                if plant_match:
                    correct_plant_predictions += 1
                
                # 2. Check Stage 2 Accuracy (Disease) - ONLY COUNT IF STAGE 1 WAS CORRECT
                if plant_match and (predicted_disease == true_disease):
                    correct_disease_predictions += 1

                if total_images % 100 == 0:
                    print(f"Processed {total_images} images...")

    print(f"\n--- Evaluation Complete. Processed {total_images} total images. ---")
    
    # Calculate Metrics
    plant_accuracy = (correct_plant_predictions / total_images) * 100 if total_images > 0 else 0
    disease_accuracy = (correct_disease_predictions / total_images) * 100 if total_images > 0 else 0

    print(f"\n[METRICS]")
    print(f"1. Stage 1 (Plant Type) Accuracy: {plant_accuracy:.2f}% ({correct_plant_predictions}/{total_images})")
    print(f"2. Overall Pipeline (Disease) Accuracy: {disease_accuracy:.2f}% ({correct_disease_predictions}/{total_images})")
    

if __name__ == "__main__":
    evaluate_pipeline()


Found 13300 images belonging to 10 classes.
Found 1639 images belonging to 3 classes.
Loading Stage 1 Model...
Loading Stage 2 Tomato Model...
Loading Stage 2 Potato Model...
All models loaded successfully.
--- Starting Two-Stage Pipeline Evaluation ---
Processing test data for: Tomato
Processed 100 images...
Processed 200 images...
Processed 300 images...
Processed 400 images...
Processed 500 images...
Processed 600 images...
Processed 700 images...
Processed 800 images...
Processed 900 images...
Processed 1000 images...
Processed 1100 images...
Processed 1200 images...
Processed 1300 images...
Processed 1400 images...
Processed 1500 images...
Processed 1600 images...
Processed 1700 images...
Processed 1800 images...
Processed 1900 images...
Processed 2000 images...
Processed 2100 images...
Processed 2200 images...
Processed 2300 images...
Processed 2400 images...
Processed 2500 images...
Processed 2600 images...
Processed 2700 images...
Processed 2800 images...
Processing test data f

In [14]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- CONFIGURATION ---
STAGE1_MODEL_PATH = '/kaggle/working/stage1_classifier.h5'
STAGE2_TOMATO_MODEL_PATH = '/kaggle/working/stage2_tomato_classifier.h5'
STAGE2_POTATO_MODEL_PATH = '/kaggle/working/stage2_potato_classifier.h5'
STAGE2_DATA_ROOT = 'Stage_2_Splits' 
IMAGE_SIZE = (128, 128)

# Define Stage 1 labels (Binary Classification)
STAGE1_LABELS = {0: 'Potato', 1: 'Tomato'} 

# --- UTILITY FUNCTION FOR LABEL RETRIEVAL ---

def get_stage_labels(plant_folder):
    """
    Retrieves the class labels for a Stage 2 model, ensuring alphabetical order.
    """
    if not os.path.exists(os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train')):
        # Fallback to hardcoded alphabetical labels if training data path is unavailable
        if plant_folder == 'Tomato':
            return {
                0: 'Bacterial spot', 1: 'Early blight', 2: 'Healthy', 3: 'Late blight',
                4: 'Leaf mold', 5: 'Mosaic virus', 6: 'Septoria leaf spot', 
                7: 'Spider mites two-spotted spider mite', 8: 'Target spot', 9: 'Yellow leaf curl virus'
            }
        elif plant_folder == 'Potato':
            return {0: 'Early blight', 1: 'Healthy', 2: 'Late blight'}
        return {}
    
    # Dynamic label retrieval using ImageDataGenerator (recommended if path exists)
    temp_datagen = ImageDataGenerator(rescale=1./255)
    generator = temp_datagen.flow_from_directory(
        os.path.join(STAGE2_DATA_ROOT, plant_folder, 'train'),
        target_size=IMAGE_SIZE,
        batch_size=1,
        class_mode='categorical',
        shuffle=False
    )
    # Note: Returns mapping of index -> class_name
    return {index: name for name, index in generator.class_indices.items()}

def preprocess_image(img_path):
    """Loads an image, resizes it, and converts it to a NumPy array for prediction."""
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        img_array /= 255.0  # Rescale to 0-1 range
        return img_array
    except Exception:
        return None

# --- TWO-STAGE CLASSIFIER CLASS (The Unified Model Wrapper) ---

class TwoStageClassifier:
    """
    A unified wrapper class that loads and manages all three models 
    for the two-stage prediction pipeline.
    """
    def __init__(self):
        self.stage1_model = None
        self.stage2_tomato_model = None
        self.stage2_potato_model = None
        self.tomato_labels = get_stage_labels('Tomato')
        self.potato_labels = get_stage_labels('Potato')
        self._load_models()

    def _load_models(self):
        """Loads all three Keras models from the specified paths."""
        try:
            print("Loading Stage 1 Model...")
            self.stage1_model = tf.keras.models.load_model(STAGE1_MODEL_PATH)

            print("Loading Stage 2 Tomato Model...")
            self.stage2_tomato_model = tf.keras.models.load_model(STAGE2_TOMATO_MODEL_PATH)
            
            print("Loading Stage 2 Potato Model...")
            self.stage2_potato_model = tf.keras.models.load_model(STAGE2_POTATO_MODEL_PATH)
            print("All models loaded successfully.")

        except Exception as e:
            print(f"❌ ERROR: Could not load one or more models. Check files: {e}")
            # Ensure all models are None if loading fails
            self.stage1_model = None
            self.stage2_tomato_model = None
            self.stage2_potato_model = None


    def predict(self, image_path):
        """
        Executes the two-stage inference pipeline for a single image.
        Returns (plant_label, disease_label, confidence)
        """
        if not all([self.stage1_model, self.stage2_tomato_model, self.stage2_potato_model]):
            return "N/A", "N/A", 0.0

        processed_image = preprocess_image(image_path)
        if processed_image is None:
            return "N/A", "N/A", 0.0

        # 1. STAGE 1: COARSE CLASSIFICATION (Plant Type)
        stage1_pred = self.stage1_model.predict(processed_image, verbose=0)
        
        # Binary prediction using a threshold of 0.5 (assuming a sigmoid output from Stage 1)
        # Assuming Potato is index 0 and Tomato is index 1
        plant_index = 1 if stage1_pred[0][0] >= 0.5 else 0 
        plant_label = STAGE1_LABELS.get(plant_index, "Unknown Plant")
        
        # 2. STAGE 2: FINE-GRAINED CLASSIFICATION (Disease)
        if plant_label == 'Tomato':
            stage2_model = self.stage2_tomato_model
            disease_labels = self.tomato_labels
        elif plant_label == 'Potato':
            stage2_model = self.stage2_potato_model
            disease_labels = self.potato_labels
        else:
            return plant_label, "Inference Failed (Stage 1 error)", 0.0
            
        stage2_pred = stage2_model.predict(processed_image, verbose=0)
        
        disease_index = np.argmax(stage2_pred[0])
        disease_confidence = np.max(stage2_pred[0])
        disease_label = disease_labels.get(disease_index, "Unknown Disease Index")
        
        return plant_label, disease_label, disease_confidence


# --- EVALUATION FUNCTION (Testing Logic) ---

def evaluate_pipeline():
    """
    Iterates through the entire test dataset, runs the TwoStageClassifier, 
    and calculates overall prediction accuracy, including a per-class breakdown.
    """
    # Initialize the unified classifier
    classifier = TwoStageClassifier()
    
    if not os.path.exists(STAGE2_DATA_ROOT):
        print(f"❌ Error: Stage 2 data root not found at '{STAGE2_DATA_ROOT}'.")
        return

    # Dictionary to store results: Key = "Plant: Disease", Value = {'Correct': int, 'Total': int}
    class_results = {}
    
    total_images = 0
    correct_plant_predictions = 0
    correct_disease_predictions = 0
    
    print("--- Starting Two-Stage Pipeline Evaluation ---")
    
    plants_to_evaluate = ['Tomato', 'Potato']

    for true_plant in plants_to_evaluate:
        plant_test_root = os.path.join(STAGE2_DATA_ROOT, true_plant, 'test')

        if not os.path.exists(plant_test_root):
            print(f"⚠️ Warning: Test data for {true_plant} not found at '{plant_test_root}'. Skipping.")
            continue
            
        print(f"Processing test data for: {true_plant}")

        for root, dirs, files in os.walk(plant_test_root):
            image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

            if not image_files:
                continue

            true_disease = os.path.basename(root)
            
            if true_disease == 'test':
                continue

            # Define the unique class key
            class_key = f"{true_plant}: {true_disease}"
            if class_key not in class_results:
                class_results[class_key] = {'Correct': 0, 'Total': 0}

            for file_name in image_files:
                
                total_images += 1
                class_results[class_key]['Total'] += 1
                image_path = os.path.join(root, file_name)

                # Use the unified class method for prediction
                predicted_plant, predicted_disease, confidence = classifier.predict(image_path)
                
                # 1. Check Stage 1 Accuracy (Plant Type)
                plant_match = (predicted_plant == true_plant)
                if plant_match:
                    correct_plant_predictions += 1
                
                # 2. Check Stage 2 Accuracy (Disease) - ONLY COUNT IF STAGE 1 WAS CORRECT
                # This is the overall pipeline accuracy for the specific class
                if plant_match and (predicted_disease == true_disease):
                    correct_disease_predictions += 1
                    class_results[class_key]['Correct'] += 1


                if total_images % 100 == 0:
                    print(f"Processed {total_images} images...")

    print(f"\n--- Evaluation Complete. Processed {total_images} total images. ---")
    
    # --- METRICS CALCULATION AND DISPLAY ---
    
    # Calculate Overall Metrics
    plant_accuracy = (correct_plant_predictions / total_images) * 100 if total_images > 0 else 0
    disease_accuracy = (correct_disease_predictions / total_images) * 100 if total_images > 0 else 0

    print(f"\n[OVERALL METRICS]")
    print(f"1. Stage 1 (Plant Type) Accuracy: {plant_accuracy:.2f}% ({correct_plant_predictions}/{total_images})")
    print(f"2. Overall Pipeline (Disease) Accuracy: {disease_accuracy:.2f}% ({correct_disease_predictions}/{total_images})")
    
    # Display Per-Class Metrics
    print("\n[PER-CLASS ACCURACY BREAKDOWN (Stage 1 AND Stage 2 Must Be Correct)]")
    
    # Sort results alphabetically for cleaner output
    sorted_class_keys = sorted(class_results.keys())

    # Header
    print(f"{'Class':<50} {'Correct':>7} {'Total':>7} {'Accuracy (%)':>15}")
    print("-" * 80)
    
    for class_key in sorted_class_keys:
        res = class_results[class_key]
        correct = res['Correct']
        total = res['Total']
        
        # Calculate per-class accuracy
        accuracy = (correct / total) * 100 if total > 0 else 0.0
        
        print(f"{class_key:<50} {correct:>7} {total:>7} {accuracy:>15.2f}")


if __name__ == "__main__":
    evaluate_pipeline()


Found 13300 images belonging to 10 classes.
Found 1639 images belonging to 3 classes.
Loading Stage 1 Model...
Loading Stage 2 Tomato Model...
Loading Stage 2 Potato Model...
All models loaded successfully.
--- Starting Two-Stage Pipeline Evaluation ---
Processing test data for: Tomato
Processed 100 images...
Processed 200 images...
Processed 300 images...
Processed 400 images...
Processed 500 images...
Processed 600 images...
Processed 700 images...
Processed 800 images...
Processed 900 images...
Processed 1000 images...
Processed 1100 images...
Processed 1200 images...
Processed 1300 images...
Processed 1400 images...
Processed 1500 images...
Processed 1600 images...
Processed 1700 images...
Processed 1800 images...
Processed 1900 images...
Processed 2000 images...
Processed 2100 images...
Processed 2200 images...
Processed 2300 images...
Processed 2400 images...
Processed 2500 images...
Processed 2600 images...
Processed 2700 images...
Processed 2800 images...
Processing test data f

In [23]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- CONFIGURATION ---
STAGE1_MODEL_PATH = '/kaggle/working/stage1_classifier.h5'
STAGE2_TOMATO_MODEL_PATH = '/kaggle/working/stage2_tomato_classifier.h5'
STAGE2_POTATO_MODEL_PATH = '/kaggle/working/stage2_potato_classifier.h5' 

# NOTE: This path is only needed to dynamically load labels, but we will use
# hardcoded labels if models are not loaded in a training environment.
STAGE2_DATA_ROOT = 'Stage_2_Splits' 

IMAGE_SIZE = (128, 128)
STAGE1_LABELS = {0: 'Potato', 1: 'Tomato'} 

# --- UTILITY FUNCTIONS ---

def get_stage_labels(plant_folder):
    """
    Retrieves the class labels for a Stage 2 model, ensuring alphabetical order.
    Uses hardcoded labels as a primary defense against environment errors.
    """
    if plant_folder == 'Tomato':
        return {
            0: 'Bacterial spot', 1: 'Early blight', 2: 'Healthy', 3: 'Late blight',
            4: 'Leaf mold', 5: 'Mosaic virus', 6: 'Septoria leaf spot', 
            7: 'Spider mites two-spotted spider mite', 8: 'Target spot', 9: 'Yellow leaf curl virus'
        }
    elif plant_folder == 'Potato':
        return {0: 'Early blight', 1: 'Healthy', 2: 'Late blight'}
    return {}

def preprocess_image(img_path):
    """Loads an image, resizes it, and converts it to a NumPy array for prediction."""
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        img_array /= 255.0  # Rescale to 0-1 range
        return img_array
    except FileNotFoundError:
        print(f"❌ Error: Image file not found at path: {img_path}")
        return None
    except Exception as e:
        print(f"❌ Error processing image {img_path}: {e}")
        return None

# --- TWO-STAGE CLASSIFIER CLASS ---

class TwoStageClassifier:
    """
    A unified wrapper class that loads and manages all three models 
    for the two-stage prediction pipeline.
    """
    def __init__(self):
        self.stage1_model = None
        self.stage2_tomato_model = None
        self.stage2_potato_model = None
        
        # Load labels from utility function
        self.tomato_labels = get_stage_labels('Tomato')
        self.potato_labels = get_stage_labels('Potato')
        self._load_models()

    def _load_models(self):
        """Loads all three Keras models from the specified paths."""
        try:
            print("Loading Models...")
            self.stage1_model = tf.keras.models.load_model(STAGE1_MODEL_PATH)
            self.stage2_tomato_model = tf.keras.models.load_model(STAGE2_TOMATO_MODEL_PATH)
            self.stage2_potato_model = tf.keras.models.load_model(STAGE2_POTATO_MODEL_PATH)
            print("✅ All models loaded successfully.")

        except Exception as e:
            print(f"❌ CRITICAL ERROR: Could not load one or more models. Ensure files ({STAGE1_MODEL_PATH}, etc.) are in the correct directory.")
            print(f"Details: {e}")
            self.stage1_model = None
            self.stage2_tomato_model = None
            self.stage2_potato_model = None


    def predict(self, image_path):
        """
        Executes the two-stage inference pipeline for a single image.
        Returns (plant_label, disease_label, confidence)
        """
        if not all([self.stage1_model, self.stage2_tomato_model, self.stage2_potato_model]):
            return "N/A", "Models Not Loaded", 0.0

        processed_image = preprocess_image(image_path)
        if processed_image is None:
            return "N/A", "Image Processing Failed", 0.0

        # 1. STAGE 1: COARSE CLASSIFICATION (Plant Type)
        stage1_pred = self.stage1_model.predict(processed_image, verbose=0)
        
        # Determine plant based on sigmoid output (index 1 is Tomato, index 0 is Potato)
        plant_index = 1 if stage1_pred[0][0] >= 0.5 else 0 
        plant_label = STAGE1_LABELS.get(plant_index, "Unknown Plant")
        
        # 2. STAGE 2: FINE-GRAINED CLASSIFICATION (Disease)
        if plant_label == 'Tomato':
            stage2_model = self.stage2_tomato_model
            disease_labels = self.tomato_labels
        elif plant_label == 'Potato':
            stage2_model = self.stage2_potato_model
            disease_labels = self.potato_labels
        else:
            return plant_label, "Inference Failed", 0.0
            
        stage2_pred = stage2_model.predict(processed_image, verbose=0)
        
        disease_index = np.argmax(stage2_pred[0])
        disease_confidence = np.max(stage2_pred[0])
        disease_label = disease_labels.get(disease_index, "Unknown Disease Index")
        
        return plant_label, disease_label, disease_confidence


# --- INTERACTIVE EXECUTION ---

if __name__ == "__main__":
    # Initialize the classifier (models are loaded here)
    classifier = TwoStageClassifier()

    if classifier.stage1_model is None:
        print("\nCannot proceed with prediction because models failed to load. Please check file paths.")
    else:
        print("\n--- Two-Stage Leaf Disease Predictor ---")
        
        # Prompt user for image path
        image_path = input("Enter the path to the leaf image (e.g., path/to/leaf.jpg): ").strip()
        
        if not image_path:
            print("No path entered. Exiting.")
        elif not os.path.exists(image_path):
             print(f"File not found: {image_path}. Please check the path and try again.")
        else:
            print("\nProcessing image...")
            
            # Run the prediction pipeline
            plant, disease, confidence = classifier.predict(image_path)

            print("\n--- RESULTS ---")
            
            # Displaying the plant type first, as requested
            print(f"1. Plant Type:      {plant}")
            
            # Displaying the disease/condition
            print(f"2. Condition:       {disease}")
            
            # Displaying the confidence score
            if confidence > 0.0:
                 print(f"3. Confidence:      {confidence:.4f}")
            print("-----------------")
            


Loading Models...
✅ All models loaded successfully.

--- Two-Stage Leaf Disease Predictor ---


Enter the path to the leaf image (e.g., path/to/leaf.jpg):  /kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4/Tomato___Bacterial_spot/01375198-62af-4c40-bddf-f3c11107200b___GCREC_Bact.Sp_5914.JPG



Processing image...

--- RESULTS ---
1. Plant Type:      Tomato
2. Condition:       Bacterial spot
3. Confidence:      0.9999
-----------------


In [24]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- CONFIGURATION ---
STAGE1_MODEL_PATH = '/kaggle/working/stage1_classifier.h5'
STAGE2_TOMATO_MODEL_PATH = '/kaggle/working/stage2_tomato_classifier.h5'
STAGE2_POTATO_MODEL_PATH = '/kaggle/working/stage2_potato_classifier.h5' 

# NOTE: This path is only needed to dynamically load labels
STAGE2_DATA_ROOT = 'Stage_2_Splits' 

IMAGE_SIZE = (128, 128)
STAGE1_LABELS = {0: 'Potato', 1: 'Tomato'} 

# --- UTILITY FUNCTIONS ---

def get_stage_labels(plant_folder):
    """
    Retrieves the class labels for a Stage 2 model, ensuring alphabetical order.
    Uses hardcoded labels for reliability in deployment.
    """
    if plant_folder == 'Tomato':
        return {
            0: 'Bacterial spot', 1: 'Early blight', 2: 'Healthy', 3: 'Late blight',
            4: 'Leaf mold', 5: 'Mosaic virus', 6: 'Septoria leaf spot', 
            7: 'Spider mites two-spotted spider mite', 8: 'Target spot', 9: 'Yellow leaf curl virus'
        }
    elif plant_folder == 'Potato':
        return {0: 'Early blight', 1: 'Healthy', 2: 'Late blight'}
    return {}

def preprocess_image(img_path):
    """Loads an image, resizes it, and converts it to a NumPy array for prediction."""
    try:
        img = image.load_img(img_path, target_size=IMAGE_SIZE)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        img_array /= 255.0  # Rescale to 0-1 range
        return img_array
    except Exception:
        # Silently skip files that aren't valid images
        return None

# --- TWO-STAGE CLASSIFIER CLASS ---
# (Reused from single_image_predictor.py)

class TwoStageClassifier:
    """
    A unified wrapper class that loads and manages all three models 
    for the two-stage prediction pipeline.
    """
    def __init__(self):
        self.stage1_model = None
        self.stage2_tomato_model = None
        self.stage2_potato_model = None
        
        self.tomato_labels = get_stage_labels('Tomato')
        self.potato_labels = get_stage_labels('Potato')
        self._load_models()

    def _load_models(self):
        """Loads all three Keras models from the specified paths."""
        try:
            print("Loading Models...")
            self.stage1_model = tf.keras.models.load_model(STAGE1_MODEL_PATH)
            self.stage2_tomato_model = tf.keras.models.load_model(STAGE2_TOMATO_MODEL_PATH)
            self.stage2_potato_model = tf.keras.models.load_model(STAGE2_POTATO_MODEL_PATH)
            print("✅ All models loaded successfully.")

        except Exception as e:
            print(f"❌ CRITICAL ERROR: Could not load one or more models. Ensure files ({STAGE1_MODEL_PATH}, etc.) are in the correct directory.")
            print(f"Details: {e}")
            self.stage1_model = None
            self.stage2_tomato_model = None
            self.stage2_potato_model = None


    def predict(self, image_path):
        """
        Executes the two-stage inference pipeline for a single image.
        Returns (plant_label, disease_label, confidence)
        """
        if not all([self.stage1_model, self.stage2_tomato_model, self.stage2_potato_model]):
            return "N/A", "Models Not Loaded", 0.0

        processed_image = preprocess_image(image_path)
        if processed_image is None:
            return "N/A", "Image Processing Failed", 0.0

        # 1. STAGE 1: COARSE CLASSIFICATION (Plant Type)
        stage1_pred = self.stage1_model.predict(processed_image, verbose=0)
        
        # Determine plant based on sigmoid output (index 1 is Tomato, index 0 is Potato)
        plant_index = 1 if stage1_pred[0][0] >= 0.5 else 0 
        plant_label = STAGE1_LABELS.get(plant_index, "Unknown Plant")
        
        # 2. STAGE 2: FINE-GRAINED CLASSIFICATION (Disease)
        if plant_label == 'Tomato':
            stage2_model = self.stage2_tomato_model
            disease_labels = self.tomato_labels
        elif plant_label == 'Potato':
            stage2_model = self.stage2_potato_model
            disease_labels = self.potato_labels
        else:
            return plant_label, "Inference Failed", 0.0
            
        stage2_pred = stage2_model.predict(processed_image, verbose=0)
        
        disease_index = np.argmax(stage2_pred[0])
        disease_confidence = np.max(stage2_pred[0])
        disease_label = disease_labels.get(disease_index, "Unknown Disease Index")
        
        return plant_label, disease_label, disease_confidence


# --- BATCH EXECUTION FUNCTION ---

def batch_predict(folder_path):
    """
    Iterates through all files in the given folder and prints the prediction for each image.
    """
    classifier = TwoStageClassifier()

    if classifier.stage1_model is None:
        return

    if not os.path.isdir(folder_path):
        print(f"\n❌ Error: Folder not found at path: {folder_path}")
        return

    print(f"\n--- Starting Batch Prediction for folder: {folder_path} ---")
    results = []
    
    # Walk through the directory (including subdirectories if needed, but typically only files in the root)
    for root, dirs, files in os.walk(folder_path):
        for file_name in files:
            # Check for common image extensions
            if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_path = os.path.join(root, file_name)
                
                # Run the prediction pipeline
                plant, disease, confidence = classifier.predict(image_path)

                results.append({
                    "FileName": file_name,
                    "Plant": plant,
                    "Condition": disease,
                    "Confidence": f"{confidence:.4f}"
                })
                
                print(f"[{file_name}] -> PLANT: {plant:<6} | DISEASE: {disease:<40} | CONFIDENCE: {confidence:.4f}")

    if results:
        df = pd.DataFrame(results)
        print("\n--- Summary DataFrame ---")
        # Display the results in a clean table format
        print(df.to_markdown(index=False))
    else:
        print("No image files found in the specified folder.")

# --- INTERACTIVE EXECUTION ---

if __name__ == "__main__":
    print("\n--- Two-Stage Batch Disease Predictor ---")
    
    # Prompt user for folder path
    folder_path = input("Enter the path to the folder containing leaf images: ").strip()
    
    if not folder_path:
        print("No path entered. Exiting.")
    else:
        batch_predict(folder_path)



--- Two-Stage Batch Disease Predictor ---


Enter the path to the folder containing leaf images:  /kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4/Tomato___Bacterial_spot


Loading Models...
✅ All models loaded successfully.

--- Starting Batch Prediction for folder: /kaggle/input/combined-dataset1to4-modified/Combined_Dataset1to4/Tomato___Bacterial_spot ---
[095f2dd4-7e65-44ab-a867-c5d9634ec532___GCREC_Bact.Sp_3801.JPG] -> PLANT: Tomato | DISEASE: Bacterial spot                           | CONFIDENCE: 0.9944
[Bacterial_spots2277_jpg.rf.aab397a71c8e464d67441af26235969f.jpg] -> PLANT: Potato | DISEASE: Late blight                              | CONFIDENCE: 0.7132
[a2ea3cb7-5c9e-4d3c-9e6c-42f7eb5f98a4___GCREC_Bact.Sp_6076.JPG] -> PLANT: Tomato | DISEASE: Early blight                             | CONFIDENCE: 0.9381
[3926a14d-ed26-4c2b-9dcc-a15370eae355___GCREC_Bact.Sp_5647.JPG] -> PLANT: Tomato | DISEASE: Bacterial spot                           | CONFIDENCE: 0.9972
[922416fb-0c08-4edd-894b-bbea0ba183f3___GCREC_Bact.Sp_5900.JPG] -> PLANT: Tomato | DISEASE: Bacterial spot                           | CONFIDENCE: 0.9960
[bacterial-canker4x640-1nz1vm7_jpg.rf.ca