In [None]:
import wandb
from wandb.integration.keras import WandbMetricsLogger
import os
# Never hardcode credentials in a notebook: they leak through version control and
# shared copies. Reuse a key already present in the environment (e.g. a shell
# export or Colab secret) and only prompt interactively when none is set.
if not os.environ.get('WANDB_API_KEY'):
    from getpass import getpass
    os.environ['WANDB_API_KEY'] = getpass('Enter your W&B API key: ')

In [None]:
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# Root of the PlantVillage "color" dataset; each class is a sub-directory read by
# flow_from_directory. Set PLANT_DATA_DIR to point elsewhere when not on Colab.
base_dir = os.environ.get('PLANT_DATA_DIR', '/content/plant_disease_data/PlantVillage/color')

In [None]:
# Square input resolution in pixels, and number of images per generator batch.
img_size, batch_size = 256, 32

In [None]:
# Random augmentations applied to training images only; validation_split=0.2
# reserves 20% of the images for the 'validation' subset.
augmentation_params = {
    'rotation_range': 25,           # max rotation, in degrees
    'width_shift_range': 0.2,       # fraction of image width
    'height_shift_range': 0.2,      # fraction of image height
    'shear_range': 0.2,
    'zoom_range': 0.3,
    'brightness_range': [0.8, 1.2],
    'horizontal_flip': True,
    'fill_mode': 'nearest',         # how pixels exposed by shifts/rotations are filled
}
train_data_gen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=0.2,
    **augmentation_params,
)

In [None]:
# Validation images get the same preprocessing but no augmentation. The split
# fraction must equal the training generator's so both carve up the same files.
val_data_gen = ImageDataGenerator(validation_split=0.2, preprocessing_function=preprocess_input)

In [None]:
# Training subset (the part not reserved by validation_split), shuffled each
# epoch; the seed makes shuffling/augmentation reproducible.
train_generator = train_data_gen.flow_from_directory(
    directory=base_dir,
    subset='training',
    class_mode='categorical',
    target_size=(img_size, img_size),
    batch_size=batch_size,
    shuffle=True,
    seed=42,
)

In [None]:
# Validation subset, deliberately NOT shuffled so batch order is stable across
# evaluations.
validation_generator = val_data_gen.flow_from_directory(
    directory=base_dir,
    subset='validation',
    class_mode='categorical',
    target_size=(img_size, img_size),
    batch_size=batch_size,
    shuffle=False,
    seed=42,
)

In [None]:
# Sanity check: pull one training batch and inspect the pixel range that
# reaches the model, plus the number of class folders discovered.
batch = next(train_generator)
batch_images = batch[0]
print("Input range:", batch_images.min(), "to", batch_images.max())
print("Number of classes:", train_generator.num_classes)

In [None]:
from tensorflow.keras import Sequential, models, layers, optimizers
from tensorflow.keras.layers import Dropout, Dense, Flatten, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import AdamW

In [None]:
# Pretrained EfficientNetV2-B0 feature extractor (ImageNet weights, no classifier head).
# The input shape is tied to img_size so it always matches the generators' target_size
# instead of silently drifting if img_size is changed.
convolutional_base = EfficientNetV2B0(
    weights='imagenet',
    include_top=False,
    input_shape=(img_size, img_size, 3),
)
convolutional_base.summary()

In [None]:
# Derive the number of output classes from the data rather than hardcoding it,
# so the softmax head always matches the class sub-directories found on disk.
num_of_classes = train_generator.num_classes

model = models.Sequential([
    convolutional_base,
    GlobalAveragePooling2D(),   # collapse spatial feature maps to one vector per image
    Dropout(0.5),
    Dense(num_of_classes, activation='softmax'),
])

model.summary()

In [None]:
# AdamW with decoupled weight decay. Labels are one-hot (class_mode='categorical'
# in the generators), hence categorical cross-entropy.
optimizer = AdamW(learning_rate=0.001, weight_decay=0.01)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Start the W&B run that WandbMetricsLogger reports to; reinit=True lets this
# cell be re-run in the same kernel.
# NOTE(review): the run name says "epoch5" but the fit cell trains for 3 epochs.
wandb.init(project="plant-disease-cnn", name="best-EfficientNet-epoch5-test", reinit=True)

# Stop when val_loss has not improved for 5 epochs and restore the best weights.
# NOTE(review): with epochs=3 in the fit cell, a patience of 5 can never trigger.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)

# Write best_model.keras whenever val_loss reaches a new minimum.
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', mode='min', save_best_only=True, verbose=1)


In [None]:
# Train, logging per-epoch metrics to W&B, checkpointing the best-val_loss model,
# and stopping early if val_loss plateaus.
fit_callbacks = [WandbMetricsLogger(), early_stop, checkpoint]
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=3,
    callbacks=fit_callbacks,
)

In [None]:
# Score the validation subset with the weights currently held by `model`.
val_loss, val_acc = model.evaluate(validation_generator, verbose=1)
print(
    f"\nValidation Accuracy: {val_acc:.4f}",
    f"Validation Loss: {val_loss:.4f}",
    sep="\n",
)

In [None]:
# Learning curves per epoch: accuracy (left) and loss (right).
# The second series on each panel comes from model.fit(validation_data=...),
# i.e. it is VALIDATION data — the previous "Test" legend label was wrong.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

for ax, metric, title in ((acc_ax, 'accuracy', 'Accuracy'), (loss_ax, 'loss', 'Loss')):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history[f'val_{metric}'], label='Validation')
    ax.set(title=f'{title} per epoch', xlabel='Epoch', ylabel=title)
    ax.legend(loc='upper left')

fig.tight_layout()
plt.show()

In [None]:
import os
import numpy as np
from PIL import Image
import json

def load_and_preprocess_image(image_path, target_size=(256, 256), show=True):
    """Load one image file and turn it into a model-ready batch of size 1.

    Mirrors the training pipeline. The image is converted to RGB and resized
    with nearest-neighbour interpolation, which is the default used by
    ``flow_from_directory``. It is then passed through the same
    ``preprocess_input`` as the generators.

    Args:
        image_path: Path to the image file.
        target_size: ``(height, width)`` in pixels, the same convention as the
            Keras ``target_size`` argument. Defaults to ``(256, 256)``.
        show: If True (default), display the resized image with matplotlib.

    Returns:
        Array of shape ``(1, height, width, 3)`` after ``preprocess_input``.
    """
    height, width = target_size
    # The context manager closes the file handle once the pixels are loaded
    # (convert() forces the load, so `img` is independent of the file).
    with Image.open(image_path) as raw_img:
        # PIL's resize takes (width, height); Keras' target_size is (height, width).
        # NEAREST matches flow_from_directory's default interpolation, whereas
        # PIL's own default would be bicubic.
        img = raw_img.convert('RGB').resize((width, height), Image.NEAREST)

    if show:
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    # Keras generators yield float32 batches by default; feed the same dtype here.
    img_array = np.asarray(img, dtype=np.float32)
    img_array = np.expand_dims(img_array, axis=0)
    return preprocess_input(img_array)


def predict_image_class(model, image_path, class_indices, verbose=False):
    """Predict the class label for a single image file.

    Args:
        model: Trained Keras classifier whose output is one probability row per image.
        image_path: Path to the image to classify.
        class_indices: Mapping from class index to label. Int keys (as built from
            ``train_generator.class_indices``) and str keys (as obtained after a
            round-trip through ``class_indices.json``) are both accepted.
        verbose: If True, print the prediction shape, the row sum and the argmax
            index as a sanity check.

    Returns:
        The label mapped to the highest-scoring class index.

    Raises:
        KeyError: If the predicted index is absent from ``class_indices``.
    """
    preprocessed_img = load_and_preprocess_image(image_path)
    predictions = model.predict(preprocessed_img)
    predicted_class_index = int(np.argmax(predictions, axis=1)[0])

    if verbose:
        print(predictions.shape)        # expected (1, num_classes)
        print(np.sum(predictions[0]))   # ~1.0 for a softmax head
        print(predicted_class_index)

    # json.dump turns int dict keys into strings, so fall back to the str key
    # when the mapping was reloaded from class_indices.json.
    if predicted_class_index in class_indices:
        return class_indices[predicted_class_index]
    return class_indices[str(predicted_class_index)]


In [None]:
# Invert Keras' {label: index} mapping into {index: label} so argmax outputs
# can be decoded back to class names.
class_indices = dict(
    (index, label) for label, index in train_generator.class_indices.items()
)

In [None]:
# Show the index -> class-name mapping as this cell's rich output.
class_indices

In [None]:
# Persist the index -> label mapping for use at inference time. The context
# manager guarantees the file is flushed and closed; a bare open() may leave
# the JSON partially written. Note that JSON object keys are always strings,
# so reloading yields {"0": ...} rather than {0: ...}.
with open('class_indices.json', 'w') as fp:
    json.dump(class_indices, fp)

In [None]:
# Sanity-check a prediction on one PlantVillage image. The path is built from
# base_dir so it follows any change to the dataset root.
# NOTE(review): this file lives inside base_dir and is most likely in the
# training subset — use a held-out image for a meaningful check.
image_path = os.path.join(base_dir, 'Corn_(maize)___Common_rust_', 'RS_Rust 1565.JPG')
predicted_class_name = predict_image_class(model, image_path, class_indices)
print(f'The predicted class is: {predicted_class_name}')

In [None]:
# Export an inference artifact into this directory (relative to the current
# working directory).
export_dir = "EfficientnetV2B0_plant_disease_model"
model.export(export_dir)

In [None]:
import shutil

# Zip the exported model folder for download. The same CWD-relative path that
# model.export() wrote to is used, so this also works outside Colab. On Colab the
# CWD is /content, which makes it identical to the previous absolute path.
shutil.make_archive("EfficientnetV2B0_plant_disease_model", 'zip', "EfficientnetV2B0_plant_disease_model")

In [None]:
# Save the trained model in the native Keras (.keras) format.
model.save(filepath='plant_disease_model.keras')