# 🌾 Wheat Disease Detection - Model Training
**Classes:** Healthy | Others | one class per `Unhealthy/` sub-folder (e.g. Smut, Yellow Rust)

**Architecture:** ResNet50 Transfer Learning (2-Phase Training)

## Instructions:
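1. Upload `Wheat.zip` to the Colab runtime so that it is available at `/content/Wheat.zip` (if the file is on Google Drive, mount Drive and copy it to that path first).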
2. Run all cells in order.
3. After training, **download** `wheat_disease_model.h5` and `wheat_class_indices.json`.
4. Place both files into your project's `Ml_Models/` folder.

In [None]:
# ================================================
# CELL 1: Install Dependencies
# ================================================
# IPython shell escape (`!`): installs OpenCV (imported later as `cv2`) and
# TensorFlow into the notebook runtime. No versions are pinned, so pip
# installs whatever it resolves at run time.
# NOTE(review): scikit-learn and matplotlib are imported in Cell 2 but not
# installed here - presumably preinstalled on Colab; confirm.
!pip install opencv-python tensorflow


In [None]:
# ================================================
# CELL 2: Import Libraries
# ================================================
# Standard library: path handling and JSON export of the class mapping
import os
# OpenCV: image decoding, colour conversion and resizing
import cv2
import json
# NumPy: in-memory image / label arrays
import numpy as np
# TensorFlow / Keras: model definition, pretrained backbone, training loop
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
# scikit-learn: train/validation split and balanced class weights
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# Matplotlib: training-curve plots
import matplotlib.pyplot as plt

# Quick confirmation that every import resolved, plus the TF version in use
print("All libraries imported successfully!")
print("TensorFlow Version:", tf.__version__)


In [None]:
# ================================================
# CELL 3: Upload and Extract Dataset
# ================================================
# Option A: If you uploaded Wheat.zip directly to Colab runtime:
# Extracts the archive into /content/ (-q: quiet, -d: destination directory).
# The later cells expect this to produce the folder /content/wheat.
# NOTE(review): the message below is printed even if unzip fails, because a
# `!` shell escape does not raise on a non-zero exit status.
!unzip -q /content/Wheat.zip -d /content/
print("Folder extracted successfully!")


In [None]:
# ================================================
# CELL 4: Verify Dataset Structure
# ================================================
# Visual sanity check of the extracted layout before loading. Cell 6 reads
# /content/wheat/Healthy, /content/wheat/Others and every sub-folder of
# /content/wheat/Unhealthy.
!ls /content/wheat
print("\n--- Healthy folder ---")
# `ls | wc -l` counts directory entries (all files, not only images)
!ls /content/wheat/Healthy | wc -l
print("\n--- Others folder ---")
!ls /content/wheat/Others | wc -l
print("\n--- Unhealthy folder (sub-classes) ---")
# Each sub-folder listed here becomes its own class in Cell 6
!ls /content/wheat/Unhealthy


In [None]:
# ================================================
# CELL 5: Configuration
# ================================================

# Root folder of the extracted dataset
DATA_DIR = '/content/wheat'

# Output artifacts: the trained model and the class-name -> index mapping
MODEL_SAVE_PATH, MAPPING_SAVE_PATH = (
    '/content/wheat_disease_model.h5',
    '/content/wheat_class_indices.json',
)

# Size every image is resized to, and number of samples per training batch
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Epoch budgets: phase 1 trains the new head, phase 2 fine-tunes the backbone
EPOCHS_PHASE_1, EPOCHS_PHASE_2 = 15, 20

# Echo the settings so they are visible in the notebook output
print(
    "Configuration set!",
    f"Data directory : {DATA_DIR}",
    f"Model will be saved to: {MODEL_SAVE_PATH}",
    f"Class mapping will be saved to: {MAPPING_SAVE_PATH}",
    sep="\n",
)


In [None]:
# ================================================
# CELL 6: Load Images into Memory
# ================================================
print("Loading images into memory...")

# Accumulators filled by load_folder(): one entry per loaded image in
# `images` / `labels`, and one entry per distinct class in `class_names`
# (a class's position in that list is its integer label).
images, labels, class_names = [], [], []

# ================================
# SET YOUR LIMIT HERE
# ================================
# Upper bound on the number of images taken from any single class folder
MAX_IMAGES_PER_CLASS = 1100

def load_folder(folder_path, class_name):
    """Load up to MAX_IMAGES_PER_CLASS images from one folder into the dataset.

    Every readable image is converted from BGR to RGB, resized to IMG_SIZE and
    appended to the module-level ``images`` list, with the matching class
    index appended to ``labels``. ``class_name`` is registered in
    ``class_names`` on first use; its position in that list is the class
    index. A one-line summary is printed when the folder has been processed.

    Args:
        folder_path: Directory to scan (not recursive).
        class_name: Label assigned to every image found in the folder.

    Returns:
        None. A path that is not an existing directory is ignored without
        registering the class.
    """
    # isdir (not exists): a stray regular file must not reach os.listdir.
    if not os.path.isdir(folder_path):
        return
    if class_name not in class_names:
        class_names.append(class_name)
    class_idx = class_names.index(class_name)

    loaded_count = 0
    skipped_count = 0

    # Sorted so the subset kept under the cap is the same on every run;
    # os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(folder_path)):
        # Case-insensitive match so '.PNG', '.JPEG', '.Jpg' are not dropped.
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        # Past the cap: keep iterating only to count the left-over images,
        # so the summary reports images rather than arbitrary files.
        if loaded_count >= MAX_IMAGES_PER_CLASS:
            skipped_count += 1
            continue

        img_path = os.path.join(folder_path, filename)
        try:
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports an undecodable file by returning None
                # instead of raising.
                print(f"Error reading {img_path}: file could not be decoded as an image")
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, IMG_SIZE)
        except cv2.error as e:
            # Best effort: report the bad file and keep loading the rest.
            print(f"Error reading {img_path}: {e}")
            continue

        images.append(img)
        labels.append(class_idx)
        loaded_count += 1

    print(f"[{class_name}] Loaded: {loaded_count} images (Skipped: {skipped_count} extra images)")

# -------------------------------------------------
# Load Healthy class
# -------------------------------------------------
load_folder(os.path.join(DATA_DIR, 'Healthy'), 'Healthy')

# -------------------------------------------------
# Load Others class (non-wheat images)
# -------------------------------------------------
load_folder(os.path.join(DATA_DIR, 'Others'), 'Others')

# -------------------------------------------------
# Load Unhealthy sub-classes (Smut, Yellow Rust, etc.)
# -------------------------------------------------
# Each sub-folder of Unhealthy/ becomes its own class, named after the folder.
unhealthy_dir = os.path.join(DATA_DIR, 'Unhealthy')
if os.path.isdir(unhealthy_dir):
    # sorted(): os.listdir order is arbitrary, and the order in which classes
    # are first seen decides their integer index. Sorting keeps the
    # class -> index mapping identical from one run to the next.
    for sub_dir in sorted(os.listdir(unhealthy_dir)):
        path = os.path.join(unhealthy_dir, sub_dir)
        if os.path.isdir(path):
            load_folder(path, sub_dir)

# Fail here with a clear message instead of an obscure error in a later cell.
if not images:
    raise RuntimeError(
        f"No images were loaded from {DATA_DIR}. "
        "Check that the dataset was extracted and contains the expected folders."
    )

# Convert to Numpy Arrays
X = np.array(images)
y = np.array(labels)

print(f"\nSUCCESS: Loaded {len(X)} total images.")
print(f"Classes Found ({len(class_names)}): {class_names}")


In [None]:
# ================================================
# CELL 7: Preprocess & Split Data
# ================================================
# Apply the ResNet input preprocessing once, up front, to the whole dataset.
# NOTE: anything that feeds images to the saved model later must apply the
# same preprocess_input call, because it is not part of the model graph.
print("Preprocessing images for ResNet50...")
X = tf.keras.applications.resnet.preprocess_input(X)

# Stratified split keeps the class proportions equal in both subsets;
# the fixed random_state makes the split reproducible.
print("Splitting data into Training (80%) and Validation (20%)...")
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training Set: {len(X_train)} images")
print(f"Validation Set: {len(X_val)} images")

# Handle class imbalance automatically
present_classes = np.unique(y_train)
weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=y_train
)
# Key each weight by its actual label rather than by its position in
# `weights`: the two differ whenever a registered class has no training
# samples (e.g. an empty class folder), which would shift every later weight
# onto the wrong class.
class_weights = {int(label): float(w) for label, w in zip(present_classes, weights)}
print("Class Weights calculated!")
print(f"Class Weights: {class_weights}")


In [None]:
# ================================================
# CELL 8: Build the ResNet50 Transfer Learning Model
# ================================================
print("Building the Transfer Learning Model...")

# Shape of one input image: height, width, 3 colour channels
input_shape = (*IMG_SIZE, 3)

# Random geometric augmentations, bundled as a small Sequential model
augmentation_steps = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomTranslation(0.1, 0.1),
]
data_augmentation = keras.Sequential(augmentation_steps)

# ImageNet-pretrained ResNet50 without its classifier head, frozen for phase 1
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
base_model.trainable = False

# Classification head stacked on the frozen backbone
inputs = keras.Input(shape=input_shape)
augmented = data_augmentation(inputs)
# training=False: the backbone is always called in inference mode
feature_maps = base_model(augmented, training=False)
pooled = layers.GlobalAveragePooling2D()(feature_maps)
pooled = layers.Dropout(0.4)(pooled)
hidden = layers.Dense(256, activation='relu')(pooled)
hidden = layers.Dropout(0.4)(hidden)
# One softmax output unit per class discovered while loading the data
outputs = layers.Dense(len(class_names), activation='softmax')(hidden)

model = keras.Model(inputs, outputs)

# Training callbacks, shared by both training phases
stop_when_stalled = EarlyStopping(patience=8, restore_best_weights=True, monitor='val_accuracy')
keep_best_model = ModelCheckpoint(MODEL_SAVE_PATH, save_best_only=True, monitor='val_accuracy')
shrink_learning_rate = ReduceLROnPlateau(factor=0.2, patience=3, monitor='val_loss')
callbacks = [stop_when_stalled, keep_best_model, shrink_learning_rate]

model.summary()


In [None]:
# ================================================
# CELL 9: Phase 1 - Train Custom Head Layers
# ================================================
print("Starting PHASE 1: Training Custom Head Layers")

# Labels are integer class indices, hence the sparse cross-entropy loss
head_optimizer = Adam(1e-3)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=head_optimizer,
    metrics=['accuracy'],
)

# With the backbone frozen, only the newly added head layers are updated
validation_pair = (X_val, y_val)
history1 = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS_PHASE_1,
    validation_data=validation_pair,
    class_weight=class_weights,
    callbacks=callbacks,
)


In [None]:
# ================================================
# CELL 10: Phase 2 - Fine-Tune Top 30 Layers
# ================================================
print("Starting PHASE 2: Fine-Tuning Top 30 Layers")

# Make the whole backbone trainable, then freeze everything again except
# its last 30 layers
base_model.trainable = True
layers_to_keep_frozen = base_model.layers[:-30]
for frozen_layer in layers_to_keep_frozen:
    frozen_layer.trainable = False

# Changing `trainable` only takes effect after compiling again; the learning
# rate is far lower than in phase 1 so the pretrained weights move gently
fine_tune_optimizer = Adam(1e-5)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=fine_tune_optimizer,
    metrics=['accuracy'],
)

validation_pair = (X_val, y_val)
history2 = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS_PHASE_2,
    validation_data=validation_pair,
    class_weight=class_weights,
    callbacks=callbacks,
)


In [None]:
# ================================================
# CELL 11: Save Results & Plot Training Curves
# ================================================

# Save the mapping file so your Flask app knows which index is which class
class_mapping = {name: int(idx) for idx, name in enumerate(class_names)}
class_mapping_inverted = {int(idx): name for idx, name in enumerate(class_names)}

with open(MAPPING_SAVE_PATH, 'w') as f:
    json.dump(class_mapping, f)

print("\n--- TRAINING COMPLETE ---")
print(f"Model saved to: {MODEL_SAVE_PATH}")
print(f"Mapping saved to: {MAPPING_SAVE_PATH}")
print(f"Your App Classes: {class_mapping_inverted}")

# Plotting: each curve is phase 1 followed by phase 2
acc = history1.history['accuracy'] + history2.history['accuracy']
val_acc = history1.history['val_accuracy'] + history2.history['val_accuracy']
loss = history1.history['loss'] + history2.history['loss']
val_loss = history1.history['val_loss'] + history2.history['val_loss']

# Number of epochs phase 1 actually ran. EarlyStopping can end it before
# EPOCHS_PHASE_1, so the marker must use the recorded history length rather
# than the configured budget. Points are plotted at x = 0, 1, 2, ..., so this
# value is the x position of the first fine-tuning epoch.
phase_1_epochs_run = len(history1.history['loss'])

# (title, training curve, training label, validation curve, validation label)
panels = [
    ('Wheat Model - Accuracy', acc, 'Training Acc', val_acc, 'Validation Acc'),
    ('Wheat Model - Loss', loss, 'Training Loss', val_loss, 'Validation Loss'),
]

plt.figure(figsize=(12, 4))
for position, (title, train_curve, train_label, val_curve, val_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.axvline(x=phase_1_epochs_run, color='red', linestyle='--', label='Fine-Tuning Starts')
    plt.legend()
    plt.title(title)
plt.show()


In [None]:
# ================================================
# CELL 12: Download Model Files
# ================================================
# Send both training artifacts from the Colab runtime to the local machine
from google.colab import files

# (message shown to the user, file to download), in download order
downloads = (
    ("Downloading wheat_disease_model.h5 ...", MODEL_SAVE_PATH),
    ("Downloading wheat_class_indices.json ...", MAPPING_SAVE_PATH),
)
for position, (announcement, artifact_path) in enumerate(downloads):
    # Blank line between consecutive downloads keeps the output readable
    if position:
        print()
    print(announcement)
    files.download(artifact_path)

print("\nDone! Place both files in your project's Ml_Models/ folder.")
