In [None]:
"""
Train Disease Percentage Prediction Model
Optimized version using kagglehub and tf.data pipeline
"""

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import kagglehub
import json

# -----------------------------
# Config
# -----------------------------
# Download the PlantVillage dataset via kagglehub and point DATA_DIR at the
# folder that holds one sub-directory per crop/disease class.
path = kagglehub.dataset_download("emmarex/plantdisease")
DATA_DIR = os.path.join(path, "PlantVillage")
IMG_SIZE = (224, 224)
BATCH_SIZE = 16
EPOCHS = 25

# Focus on common farmer crops
FARMER_CROPS = ["Tomato", "Potato", "Apple", "Corn", "Grape", "Strawberry", "Pepper", "Soybean"]

# Upper bound on the number of images taken from each class folder.
MAX_IMAGES_PER_CLASS = 400
# Half-width of the uniform noise added to every pseudo-label.
LABEL_JITTER = 0.05
# Seed for the label noise so that repeated runs produce identical targets.
LABEL_SEED = 42
# File extensions (lower-case) that are treated as images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _severity_for_class(class_name):
    """Return the heuristic disease fraction (0.0-1.0) for a class folder name.

    The value is derived purely from keywords in the folder name:
    "healthy" -> 0.0, "early"/"spot" -> 0.2, "late"/"severe" -> 0.8,
    anything else -> 0.5. Checks are applied in that order.
    """
    lowered = class_name.lower()
    if "healthy" in lowered:
        return 0.0
    if "early" in lowered or "spot" in lowered:
        return 0.2
    if "late" in lowered or "severe" in lowered:
        return 0.8
    return 0.5


# -----------------------------
# Step 1: Collect image paths and labels
# -----------------------------
print("=" * 60)
print("Disease Percentage Model Training")
print("=" * 60)

# Filter classes to farmer-relevant crops. Only directories are kept so a
# stray file whose name happens to contain a crop name cannot break the
# per-class os.listdir() call below.
classes = [
    cls for cls in sorted(os.listdir(DATA_DIR))
    if os.path.isdir(os.path.join(DATA_DIR, cls))
    and any(crop.lower() in cls.lower() for crop in FARMER_CROPS)
]

print(f"\nFound {len(classes)} relevant crop disease classes")

# Create disease percentage mapping (class folder name -> base severity).
disease_mapping = {cls: _severity_for_class(cls) for cls in classes}

# Save mapping
with open('disease_mapping.json', 'w', encoding='utf-8') as mapping_file:
    json.dump(disease_mapping, mapping_file, indent=2)
print("Disease mapping saved")

# Collect image paths and labels
label_rng = np.random.default_rng(LABEL_SEED)
image_paths = []
labels = []

for cls_name in classes:
    cls_path = os.path.join(DATA_DIR, cls_name)
    # os.listdir() returns entries in arbitrary order; sort before truncating
    # so the same images are selected on every run and every machine.
    img_files = sorted(
        name for name in os.listdir(cls_path)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )[:MAX_IMAGES_PER_CLASS]  # Limit per class

    print(f"Loading {len(img_files)} images from {cls_name}")

    image_paths.extend(os.path.join(cls_path, name) for name in img_files)
    # Add slight variation to disease percentage (seeded, one draw per image).
    jitter = label_rng.uniform(-LABEL_JITTER, LABEL_JITTER, size=len(img_files))
    labels.extend(disease_mapping[cls_name] + jitter)

if not image_paths:
    # Fail here with a clear message instead of an opaque error later when
    # min()/max() is taken over an empty label array.
    raise RuntimeError(
        f"No images found under {DATA_DIR!r} for crops {FARMER_CROPS}"
    )

image_paths = np.array(image_paths)
# Jitter can push labels slightly outside [0, 1]; clamp them back.
labels = np.clip(np.array(labels, dtype=np.float32), 0.0, 1.0)

print(f"\nTotal images collected: {len(image_paths)}")
print(f"Label distribution - Min: {labels.min():.3f}, Max: {labels.max():.3f}, Mean: {labels.mean():.3f}")

# -----------------------------
# Step 2: Create tf.data.Dataset
# -----------------------------
print("\nCreating tf.data pipeline...")


def process_image(file_path, label):
    """Load a single image from disk and pair it with its label.

    The file at ``file_path`` is read, decoded into three colour channels,
    resized to ``IMG_SIZE`` and divided by 255. ``label`` is returned
    unchanged alongside the image.
    """
    raw_bytes = tf.io.read_file(file_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, IMG_SIZE)
    return resized / 255.0, label

# Create train/validation datasets.
#
# The split is decided ONCE, on the raw arrays, using a fixed permutation.
# Shuffling the whole tf.data pipeline and then calling skip()/take() is not
# safe: shuffle() reshuffles on every iteration by default, so the set of
# samples seen as "validation" would change each epoch and overlap with
# samples the model has already trained on.
def _make_pipeline(paths, targets, shuffle):
    """Build a batched, prefetched dataset of (image, label) pairs.

    paths/targets are parallel arrays of file paths and labels. When
    ``shuffle`` is true the samples are reshuffled at the start of every
    epoch (before decoding, so the shuffle buffer only holds paths).
    """
    ds = tf.data.Dataset.from_tensor_slices((paths, targets))
    if shuffle:
        ds = ds.shuffle(
            buffer_size=max(len(targets), 1),
            seed=42,
            reshuffle_each_iteration=True,
        )
    ds = ds.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Split into train and validation.
# val_size is the number of validation batches (about 20% of the data).
val_size = int(0.2 * len(labels) / BATCH_SIZE)
val_count = val_size * BATCH_SIZE

split_order = np.random.default_rng(42).permutation(len(labels))
val_idx = split_order[:val_count]
train_idx = split_order[val_count:]

val_ds = _make_pipeline(image_paths[val_idx], labels[val_idx], shuffle=False)
train_ds = _make_pipeline(image_paths[train_idx], labels[train_idx], shuffle=True)

# Full batched dataset (validation batches first, then training batches),
# kept so that code referring to `dataset` continues to work.
dataset = val_ds.concatenate(train_ds)

print(f"Training batches: {tf.data.experimental.cardinality(train_ds).numpy()}")
print(f"Validation batches: {tf.data.experimental.cardinality(val_ds).numpy()}")

# -----------------------------
# Step 3: Model Definition with Data Augmentation
# -----------------------------
print("\nBuilding model...")

# Data augmentation for better generalization.
# These layers are only active in training mode; at inference they pass
# the input through unchanged.
data_aug = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.15),
    layers.RandomContrast(0.1)
])

# Base model with fine-tuning. The input shape follows IMG_SIZE so that the
# network always matches what the data pipeline produces.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(*IMG_SIZE, 3),
    include_top=False,
    weights="imagenet"
)

# Unfreeze last 150 layers for fine-tuning
base_model.trainable = True
for layer in base_model.layers[:-150]:
    layer.trainable = False

# Build complete model
inputs = tf.keras.Input(shape=(*IMG_SIZE, 3))
x = data_aug(inputs)
# process_image() already scales pixels to [0, 1]. MobileNetV2 expects
# [-1, 1]; mobilenet_v2.preprocess_input assumes [0, 255] input and would
# squeeze [0, 1] data into roughly [-1, -0.992], so map [0, 1] -> [-1, 1]
# explicitly instead.
x = layers.Rescaling(scale=2.0, offset=-1.0)(x)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode (using their moving statistics) both while fine-tuning and at
# prediction time. Hard-wiring training=True would make the saved model
# normalise with per-batch statistics at inference.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.2)(x)
# Sigmoid keeps the predicted disease fraction inside [0, 1].
outputs = layers.Dense(1, activation="sigmoid")(x)

model = tf.keras.Model(inputs, outputs)

# Compile with Huber loss (robust to outliers)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.Huber(),
    metrics=["mae", "mse"]
)

print("\nModel Summary:")
model.summary()

# -----------------------------
# Step 4: Training
# -----------------------------
training_banner = "=" * 60
print("\n" + training_banner)
print("Starting Training...")
print(training_banner)

# Stop once validation loss has not improved for 5 epochs and roll the
# weights back to the best epoch seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True, verbose=1
)

# Halve the learning rate after 3 epochs without validation improvement,
# never going below 1e-7.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=3, min_lr=1e-7, verbose=1
)

# Write the model to disk whenever validation loss reaches a new minimum.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'disease_percentage_model_best.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

callbacks = [early_stopping, reduce_lr, best_checkpoint]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

# -----------------------------
# Step 5: Evaluation and Save
# -----------------------------
section_rule = "=" * 60
print("\n" + section_rule)
print("Training Complete! Evaluating...")
print(section_rule)

# Final evaluation: evaluate() returns [loss, mae, mse], matching the loss
# and the metric order passed to compile().
results = model.evaluate(val_ds, verbose=0)
final_loss, final_mae, final_mse = results[:3]
print("\nFinal Validation Results:")
print(f"  Loss: {final_loss:.4f}")
print(f"  MAE: {final_mae:.4f}")
print(f"  MSE: {final_mse:.4f}")
print(f"  RMSE: {np.sqrt(final_mse):.4f}")

# Sample predictions: compare the first five items of one validation batch
# against their targets.
print("\nSample Predictions:")
sample_batch = next(iter(val_ds))
sample_images, sample_targets = sample_batch[0], sample_batch[1]
sample_preds = model.predict(sample_images[:5], verbose=0)
sample_actuals = sample_targets[:5].numpy()

for sample_no, (pred, actual) in enumerate(zip(sample_preds, sample_actuals), start=1):
    predicted = pred[0]  # the model emits one value per image
    error = abs(predicted - actual)
    print(f"  Sample {sample_no}: Predicted={predicted:.3f}, Actual={actual:.3f}, Error={error:.3f}")

# Save complete pipeline
# NOTE: the .pkl is a pickle of the model together with the training history;
# only load it from trusted sources.
import joblib
joblib.dump((model, history.history), "plant_disease_pipeline_final.pkl", compress=3)
model.save('disease_percentage_model.h5')

print("\n✅ Models saved successfully!")
for artifact_line in (
    "  - plant_disease_pipeline_final.pkl (complete pipeline)",
    "  - disease_percentage_model.h5 (Keras model)",
    "  - disease_percentage_model_best.h5 (best checkpoint)",
    "  - disease_mapping.json (class mapping)",
):
    print(artifact_line)

print("\n" + section_rule)
print("Training Pipeline Complete!")
print(section_rule)

Using Colab cache for faster access to the 'plantdisease' dataset.
Disease Percentage Model Training

Found 15 relevant crop disease classes
Disease mapping saved
Loading 400 images from Pepper__bell___Bacterial_spot
Loading 400 images from Pepper__bell___healthy
Loading 400 images from Potato___Early_blight
Loading 400 images from Potato___Late_blight
Loading 152 images from Potato___healthy
Loading 400 images from Tomato_Bacterial_spot
Loading 400 images from Tomato_Early_blight
Loading 400 images from Tomato_Late_blight
Loading 400 images from Tomato_Leaf_Mold
Loading 400 images from Tomato_Septoria_leaf_spot
Loading 400 images from Tomato_Spider_mites_Two_spotted_spider_mite
Loading 400 images from Tomato__Target_Spot
Loading 400 images from Tomato__Tomato_YellowLeaf__Curl_Virus
Loading 373 images from Tomato__Tomato_mosaic_virus
Loading 400 images from Tomato_healthy

Total images collected: 5725
Label distribution - Min: 0.000, Max: 0.850, Mean: 0.314

Creating tf.data pipeline..


Starting Training...
Epoch 1/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 141ms/step - loss: 0.0315 - mae: 0.1952 - mse: 0.0629
Epoch 1: val_loss improved from inf to 0.07048, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 156ms/step - loss: 0.0314 - mae: 0.1951 - mse: 0.0629 - val_loss: 0.0705 - val_mae: 0.3388 - val_mse: 0.1410 - learning_rate: 1.0000e-04
Epoch 2/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 143ms/step - loss: 0.0180 - mae: 0.1428 - mse: 0.0359
Epoch 2: val_loss improved from 0.07048 to 0.05192, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 152ms/step - loss: 0.0180 - mae: 0.1428 - mse: 0.0359 - val_loss: 0.0519 - val_mae: 0.2868 - val_mse: 0.1038 - learning_rate: 1.0000e-04
Epoch 3/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 146ms/step - loss: 0.0143 - mae: 0.1263 - mse: 0.0286
Epoch 3: val_loss improved from 0.05192 to 0.04728, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 155ms/step - loss: 0.0143 - mae: 0.1263 - mse: 0.0286 - val_loss: 0.0473 - val_mae: 0.2707 - val_mse: 0.0946 - learning_rate: 1.0000e-04
Epoch 4/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 144ms/step - loss: 0.0135 - mae: 0.1186 - mse: 0.0269
Epoch 4: val_loss improved from 0.04728 to 0.03214, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 154ms/step - loss: 0.0134 - mae: 0.1186 - mse: 0.0269 - val_loss: 0.0321 - val_mae: 0.2258 - val_mse: 0.0643 - learning_rate: 1.0000e-04
Epoch 5/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 144ms/step - loss: 0.0104 - mae: 0.1014 - mse: 0.0208
Epoch 5: val_loss improved from 0.03214 to 0.03190, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 154ms/step - loss: 0.0104 - mae: 0.1014 - mse: 0.0208 - val_loss: 0.0319 - val_mae: 0.2130 - val_mse: 0.0638 - learning_rate: 1.0000e-04
Epoch 6/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 146ms/step - loss: 0.0076 - mae: 0.0855 - mse: 0.0152
Epoch 6: val_loss did not improve from 0.03190
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 153ms/step - loss: 0.0076 - mae: 0.0855 - mse: 0.0152 - val_loss: 0.0555 - val_mae: 0.2434 - val_mse: 0.1110 - learning_rate: 1.0000e-04
Epoch 7/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 145ms/step - loss: 0.0072 - mae: 0.0828 - mse: 0.0144
Epoch 7: val_loss improved from 0.03190 to 0.03149, saving model to disease_percentage_model_best.h5




[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 154ms/step - loss: 0.0072 - mae: 0.0828 - mse: 0.0144 - val_loss: 0.0315 - val_mae: 0.2231 - val_mse: 0.0630 - learning_rate: 1.0000e-04
Epoch 8/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 144ms/step - loss: 0.0069 - mae: 0.0792 - mse: 0.0139
Epoch 8: val_loss did not improve from 0.03149
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 153ms/step - loss: 0.0069 - mae: 0.0792 - mse: 0.0139 - val_loss: 0.0475 - val_mae: 0.2746 - val_mse: 0.0950 - learning_rate: 1.0000e-04
Epoch 9/25
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 145ms/step - loss: 0.0065 - mae: 0.0752 - mse: 0.0129
Epoch 9: val_loss did not improve from 0.03149
[1m287/287[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 151ms/step - loss: 0.0065 - mae: 0.0752 - mse: 0.0129 - val_loss: 0.0536 - val_mae:




✅ Models saved successfully!
  - plant_disease_pipeline_final.pkl (complete pipeline)
  - disease_percentage_model.h5 (Keras model)
  - disease_percentage_model_best.h5 (best checkpoint)
  - disease_mapping.json (class mapping)

Training Pipeline Complete!
