In [None]:
import os
import random
import shutil
import zipfile

import joblib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from google.colab import drive
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array


# Report the TensorFlow build and pick a batch size based on whether a GPU
# is visible: 64 with a GPU, 32 without.
print("TensorFlow Version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
BATCH_SIZE = 64 if gpu_devices else 32
if not gpu_devices:
    print("⚠️ No GPU detected. Go to Runtime > Change runtime type > T4 GPU")
else:
    print(f"✅ GPU Detected: {gpu_devices}")

# Expose Google Drive under /content/drive for the cells that read from and
# write to it.
drive.mount('/content/drive')

In [None]:
# Location of the dataset archive on Drive and of the local extraction target.
zip_path = '/content/drive/MyDrive/practice/archive.zip'
extract_path = '/content/plant_data_extracted'

# Extraction goes to a scratch directory that is renamed to extract_path only
# after extractall() has finished. If the runtime is interrupted mid-unzip,
# extract_path never exists in a half-filled state, so the next run redoes the
# extraction instead of silently training on an incomplete dataset.
partial_path = extract_path + '.partial'

if not os.path.exists(extract_path):
    # Drop leftovers from a previously interrupted extraction, if any.
    shutil.rmtree(partial_path, ignore_errors=True)
    print("Unzipping dataset... (This will take ~3-5 mins due to 2.7GB size)")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(partial_path)
    # Same filesystem, destination known not to exist: a plain rename.
    os.replace(partial_path, extract_path)
    print("Unzipping Complete.")
else:
    print("Dataset already extracted. Skipping unzip.")


# The archive nests the augmented dataset inside a folder of the same name.
base_inner = os.path.join(extract_path, 'New Plant Diseases Dataset(Augmented)', 'New Plant Diseases Dataset(Augmented)')
train_dir = os.path.join(base_inner, 'train')
valid_dir = os.path.join(base_inner, 'valid')


# Loose test images live directly under test/test.
test_dir_raw = os.path.join(extract_path, 'test', 'test')


# Sanity check of the expected layout before any data is loaded.
print(f"Train Folder Found: {os.path.exists(train_dir)}")
print(f"Test Images Found: {len(os.listdir(test_dir_raw)) if os.path.exists(test_dir_raw) else 0}")


# Spatial size every image is resized to before it reaches the network.
IMG_SIZE = (224, 224)

# Training images are rescaled to [0, 1] and randomly rotated, shifted and
# flipped; validation images are only rescaled.
# NOTE(review): Keras' MobileNetV2 preprocess_input scales pixels to [-1, 1],
# whereas this pipeline feeds [0, 1] — confirm this is intentional and keep
# inference code on the same scaling.
augmentation = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation)
valid_datagen = ImageDataGenerator(rescale=1./255)

print("\n--- Loading Data ---")
# Both splits are read with the same size, batch size and one-hot labels.
flow_options = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)
train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
valid_generator = valid_datagen.flow_from_directory(valid_dir, **flow_options)


# ImageNet-pretrained MobileNetV2 without its classifier, used as a frozen
# feature extractor.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(*IMG_SIZE, 3),
)
base_model.trainable = False

# New classification head: pool the feature map, apply light dropout, then
# one softmax unit per class found in the training directory.
pooled = GlobalAveragePooling2D()(base_model.output)
regularised = Dropout(0.2)(pooled)
predictions = Dense(train_generator.num_classes, activation='softmax')(regularised)

model = Model(inputs=base_model.input, outputs=predictions)
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)


# Fit for five epochs, evaluating on the validation generator after each one;
# the returned History object is kept for later inspection.
print(f"\nStarting Training with Batch Size {BATCH_SIZE}...")
history = model.fit(train_generator, validation_data=valid_generator, epochs=5)


print("\n--- Testing on Unseen Images ---")
# class_indices maps label -> index; invert it so a predicted index can be
# turned back into a label.
class_map = {v: k for k, v in train_generator.class_indices.items()}

# Only regular files with an image extension are candidates, so a stray
# sub-folder or metadata file cannot break load_img. A missing test folder
# yields an empty list instead of raising.
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')
if os.path.isdir(test_dir_raw):
    test_files = [
        name for name in os.listdir(test_dir_raw)
        if name.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(test_dir_raw, name))
    ]
else:
    test_files = []

# Preview at most five randomly chosen images.
selected_files = random.sample(test_files, min(len(test_files), 5))

if not selected_files:
    print(f"No test images found in {test_dir_raw!r}; skipping predictions.")
else:
    images = [
        load_img(os.path.join(test_dir_raw, name), target_size=IMG_SIZE)
        for name in selected_files
    ]
    # Same 1/255 rescale the training and validation generators apply.
    batch = np.stack([img_to_array(img) for img in images]) / 255.0

    # One forward pass for the whole preview batch.
    pred = model.predict(batch, verbose=0)
    pred_indices = np.argmax(pred, axis=1)
    confidences = np.max(pred, axis=1)

    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    for ax, img, idx, confidence in zip(axes, images, pred_indices, confidences):
        ax.imshow(img)
        ax.set_title(f"{class_map[int(idx)]}\n({confidence*100:.1f}%)", fontsize=10)
    # Hide ticks and frames on every slot, including unused ones when fewer
    # than five images were available.
    for ax in axes:
        ax.axis('off')

    plt.show()

In [None]:
train_dir = '/content/plant_data_extracted/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train'
classes_path = '/content/drive/MyDrive/Plant_Project/plant_disease_classes.pkl'

# Keras' flow_from_directory derives classes from the sub-directories of the
# training folder in sorted order. Keep directories only, so a stray file in
# train_dir cannot shift the label indices away from the model's outputs.
class_names = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

# The Drive project folder may not exist yet on a fresh account.
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
joblib.dump(class_names, classes_path)

print(f"Success! Saved {len(class_names)} disease classes to Google Drive.")
print("Here are the first few classes just to be sure:", class_names[:3])

In [None]:
# Write the trained model to the Drive project folder as a single .h5 file.
model.save(filepath='/content/drive/MyDrive/Plant_Project/plant_disease_model.h5')