# PlantVillage Plant Disease Classifier Training

This notebook trains a deep learning model to classify plant diseases using the PlantVillage dataset.

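➡️ **Before running, set `DATA_DIR` in the configuration cell below to your local PlantVillage folder** (one sub-folder of images per class). This notebook requires TensorFlow to be installed in the kernel's environment.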

In [2]:
# 1. Environment check: GPU availability
# Lists visible GPUs and asks TensorFlow to allocate GPU memory on demand
# instead of reserving it all up front.
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU found. Training will use CPU (slow).")
else:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"GPUs found: {gpus}")
    except RuntimeError as err:
        # Memory growth can only be configured before the GPUs are initialized
        # (e.g. this fails if the cell is re-run after TF has touched the GPU);
        # report the problem and carry on.
        print(err)

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
# 2. Set up paths and hyperparameters
import os

# PlantVillage data directory: one sub-folder of images per class, as
# flow_from_directory expects. Set the PLANTVILLAGE_DIR environment variable to
# point elsewhere without editing the notebook.
DATA_DIR = os.environ.get(
    'PLANTVILLAGE_DIR',
    r'C:\Users\parth\Downloads\Agrostuff\data\PlantVillage',  # Update if needed
)
MODEL_DIR = 'model'  # output folder, relative to the kernel's working directory
IMG_SIZE = 224       # images are resized to IMG_SIZE x IMG_SIZE pixels
BATCH_SIZE = 32
EPOCHS = 10

# Surface a bad path here, in the config cell, rather than as an obscure error
# from the data-loading cell later on.
if not os.path.isdir(DATA_DIR):
    print(f"WARNING: DATA_DIR not found: {DATA_DIR!r}. "
          "Set PLANTVILLAGE_DIR or edit DATA_DIR before loading data.")

os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
# 3. Data preparation
import pickle

from tensorflow.keras.preprocessing.image import ImageDataGenerator

VALIDATION_SPLIT = 0.2  # fraction of each class's images held out for validation
SHUFFLE_SEED = 42       # makes the training-batch order reproducible

# NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled to [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input); rescale=1/255 gives
# [0, 1]. Kept as-is so the app's inference preprocessing still matches — if you
# change it here, change it in the app too.

# Training generator: rescaling plus random augmentation.
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    horizontal_flip=True,
    zoom_range=0.2
)

# Validation generator: same rescaling and same split, but NO augmentation, so
# validation metrics are measured on unmodified images. Subset membership is
# determined by validation_split and the sorted file list, so both generators
# agree on which files are training vs. validation.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

train_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    subset='training',
    shuffle=True,
    seed=SHUFFLE_SEED
)

val_gen = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    subset='validation',
    # Order does not affect validation metrics; a fixed order keeps predictions
    # aligned with val_gen.filenames / val_gen.classes for later analysis.
    shuffle=False
)

if train_gen.samples == 0 or val_gen.samples == 0:
    raise ValueError(
        f"No images found for training/validation under {DATA_DIR!r}; "
        "expected one sub-folder of images per class."
    )

# Save class names for later use in the app, ordered by integer label so that
# position i of the model's softmax output corresponds to class_names[i].
class_names = sorted(train_gen.class_indices, key=train_gen.class_indices.get)
with open(os.path.join(MODEL_DIR, 'class_names.pkl'), 'wb') as f:
    pickle.dump(class_names, f)
print('Class names saved.')

In [None]:
# 4. Build the model (MobileNetV2 transfer learning)
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Pretrained ImageNet convolutional backbone without its original classifier top.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

# Freeze every backbone layer so that only the new head below is trained.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

# New head: global average pooling, one hidden ReLU layer, then a softmax over
# the classes discovered by the training generator.
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(128, activation='relu')(pooled)
class_probs = Dense(train_gen.num_classes, activation='softmax')(hidden)

model = Model(inputs=base_model.input, outputs=class_probs)
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# 5. Train the model
# Fits on the training generator and evaluates on the validation generator at
# the end of each epoch; per-epoch metrics are kept in `history.history`.
history = model.fit(
    x=train_gen,
    epochs=EPOCHS,
    validation_data=val_gen,
)

In [None]:
# 6. Save the trained model
# Writes the model in HDF5 (.h5) format inside MODEL_DIR.
model_path = os.path.join(MODEL_DIR, 'plant_disease_model.h5')
model.save(model_path)
print('Model saved to', model_path)

## Done!

- The trained model is saved as `plant_disease_model.h5` in the `model` directory (`MODEL_DIR`).
- The class names are saved as `class_names.pkl` in that same directory.
- You can now use this model in your Streamlit app!