In [1]:
import os

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from tensorflow.keras.applications.efficientnet_v2 import EfficientNetV2S
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import set_random_seed

# Reproducibility: seeds the Python, NumPy and TensorFlow RNGs in one call.
SEED = 42
set_random_seed(SEED)

# Dataset root, built with os.path.join so it resolves on any OS (a literal with
# backslash separators only works on Windows). flow_from_directory expects one
# sub-directory per class under each split and labels images by that name.
DATASET_ROOT = os.path.join('archive',
                            'New Plant Diseases Dataset(Augmented)',
                            'New Plant Diseases Dataset(Augmented)')
train_dir = os.path.join(DATASET_ROOT, 'train')
valid_dir = os.path.join(DATASET_ROOT, 'valid')

img_size = (224, 224)
batch_size = 32

# No rescale=1./255 here: EfficientNetV2 (include_preprocessing=True by default)
# contains its own rescaling layer and expects raw pixel values in [0, 255].
# Dividing by 255 as well squeezes every input into a tiny range near -1 after
# that layer, throwing away most of the signal the pretrained weights rely on.
train_datagen = ImageDataGenerator()
valid_datagen = ImageDataGenerator()

# Seeded so the shuffle order (and any batch subset taken later) is reproducible.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    seed=SEED)

valid_generator = valid_datagen.flow_from_directory(
    valid_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    seed=SEED)

# Transfer-learning classifier: an ImageNet-pretrained EfficientNetV2-S backbone
# (not frozen here, so the whole network is fine-tuned), global average pooling,
# then a softmax layer with one unit per class discovered by the training generator.
# The input shape is derived from img_size so it cannot drift from the generators.
model = Sequential([
    EfficientNetV2S(include_top=False, input_shape=img_size + (3,), weights='imagenet'),
    GlobalAveragePooling2D(),
    Dense(train_generator.num_classes, activation='softmax'),
])

# `learning_rate` replaces the deprecated `lr` alias, which newer Keras optimizers reject.
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])


# Keep the best-so-far model (by validation accuracy) on disk so it can be
# reloaded after training. The previous 'path/to/checkpoints/...' value was a
# placeholder; write under ./checkpoints and make sure that directory exists.
checkpoint_path = os.path.join('checkpoints', 'efficientnetv2.h5')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint_callback = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', verbose=1,
                                      save_best_only=True, mode='max')

# Stop at the first epoch whose validation accuracy does not beat the best so far
# (the same rule the former hand-rolled loop applied). A single fit() call avoids
# that loop's extra full validation pass per epoch -- fit already reports
# val_accuracy -- and its per-epoch checkpoint reloads.
early_stopping_callback = EarlyStopping(monitor='val_accuracy', mode='max', patience=0,
                                        restore_best_weights=True, verbose=1)

epochs = 10
history = model.fit(train_generator,
                    epochs=epochs,
                    validation_data=valid_generator,
                    callbacks=[checkpoint_callback, early_stopping_callback])

for epoch, val_accuracy in enumerate(history.history['val_accuracy'], start=1):
    print(f'Validation accuracy after epoch {epoch}: {val_accuracy:.2f}')

# Reload the best checkpoint explicitly: depending on the Keras version,
# EarlyStopping may only restore the best weights when it actually stops early.
if os.path.exists(checkpoint_path):
    model.load_weights(checkpoint_path)
else:
    print("Checkpoint file not found.")

loss, accuracy = model.evaluate(valid_generator)
print(f'Final validation accuracy: {accuracy:.2f}')


def collect_batches(generator, max_batches=None):
    """Materialise the first ``max_batches`` batches of a Keras directory iterator.

    Keras directory iterators loop forever, so ``for batch in generator`` never
    terminates; batches are indexed explicitly instead. Images and labels come
    from the same ``generator[i]`` call, so they stay aligned even when the
    iterator shuffles (drawing X and y in separate passes would not be).

    Args:
        generator: indexable batch source supporting ``len()`` (e.g. the result
            of ``flow_from_directory``) whose items are ``(images, labels)``.
        max_batches: maximum number of batches to read; ``None`` reads all.

    Returns:
        ``(images, labels)``: images stacked along axis 0, and labels as integer
        class indices (one-hot rows are converted with argmax).

    Raises:
        ValueError: if no batches would be read.
    """
    n_batches = len(generator)
    if max_batches is not None:
        n_batches = min(n_batches, max_batches)
    if n_batches <= 0:
        raise ValueError('collect_batches: no batches to read')

    image_batches, label_batches = [], []
    for i in range(n_batches):
        images, labels = generator[i]
        image_batches.append(images)
        # class_mode='categorical' yields one-hot rows; sklearn expects class ids,
        # otherwise it fits a multi-output model and scores subset accuracy.
        label_batches.append(labels.argmax(axis=1) if labels.ndim > 1 else labels)
    return np.concatenate(image_batches, axis=0), np.concatenate(label_batches, axis=0)


# Raw pixels are large: a 224x224x3 image at the default float32 is ~0.6 MB, so
# the full training split (69,768 images as reported by flow_from_directory)
# would need ~42 GB as one dense array. Only a capped number of batches is used;
# flow_from_directory shuffles by default, so the first batches form a
# mixed-class sample. Set RF_MAX_BATCHES = None to use everything if memory allows.
RF_MAX_BATCHES = 100
RF_RANDOM_STATE = 42

X_train, y_train = collect_batches(train_generator, RF_MAX_BATCHES)
X_valid, y_valid = collect_batches(valid_generator, RF_MAX_BATCHES)

rf_model = RandomForestClassifier(n_jobs=-1, random_state=RF_RANDOM_STATE)
rf_model.fit(X_train.reshape(X_train.shape[0], -1), y_train)

rf_accuracy = rf_model.score(X_valid.reshape(X_valid.shape[0], -1), y_valid)
print(f'Random Forest accuracy: {rf_accuracy:.2f}')

Found 69768 images belonging to 38 classes.
Found 17452 images belonging to 38 classes.




 184/2181 [=>............................] - ETA: 7:26:27 - loss: 0.8282 - accuracy: 0.7610

KeyboardInterrupt: 