In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam

# Input resolution fed to the network (224x224 is ResNet50's standard ImageNet size).
IMG_WIDTH, IMG_HEIGHT = 224, 224
# Images per batch, shared by the training and validation generators.
BATCH_SIZE = 32

# Use ResNet50's own preprocessing (RGB->BGR conversion plus ImageNet
# per-channel mean subtraction) instead of rescaling to [0, 1]. The pretrained,
# frozen ImageNet weights were trained on inputs prepared this way; 1/255
# scaling feeds them a pixel distribution they have never seen.
# Keras applies preprocessing_function after resizing and augmentation.
resnet_preprocess = tf.keras.applications.resnet50.preprocess_input

# Training data: light random augmentation (shear, zoom, horizontal flip) to
# reduce overfitting, followed by ResNet50 preprocessing.
train_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)
# Validation data: the same preprocessing, but no augmentation.
test_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)

# Set directories for the dataset (Update paths accordingly).
# flow_from_directory expects one subdirectory per class inside each.
train_dir = 'path_to_train_data'
validation_dir = 'path_to_validation_data'

# Keras defines target_size as (height, width).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical')

# shuffle=False keeps validation batches in a fixed order, so later predictions
# can be matched against validation_generator.classes (e.g. for a confusion
# matrix). It does not change the per-epoch validation metrics.
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False)

# Load ResNet50 with ImageNet weights and without its 1000-class classifier
# head, so it can act as a feature extractor. Keras expects input_shape as
# (height, width, channels).
base_model = ResNet50(weights='imagenet', include_top=False,
                      input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze the whole backbone; setting trainable on the model propagates to every
# layer inside it, so only the new head below is trained.
base_model.trainable = False

# One output unit per class subdirectory discovered by the training generator.
num_classes = len(train_generator.class_indices)

# Classification head for rice leaf disease. GlobalAveragePooling2D reduces the
# backbone's spatial feature map to one 2048-wide vector per image. Flatten
# would instead produce a 7*7*2048 = 100,352-wide vector for 224x224 inputs,
# giving the following Dense(512) layer ~51M trainable weights, which overfits
# easily on a small dataset.
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax'),
])

# categorical_crossentropy matches the one-hot labels produced by
# class_mode='categorical' in the data generators.
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])

# Number of full passes over the training generator.
EPOCHS = 10

# Fit the model's trainable layers on the training images. After each epoch,
# Keras scores the validation set; the per-epoch metrics end up in
# history.history.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS,
)

# Persist architecture + weights. The .h5 extension selects Keras' HDF5 format.
model.save('resnet50_rice_leaf_model.h5')

# Draw the training and validation accuracy curves, one point per epoch,
# labelled with their history keys.
metric_log = history.history
for curve_name in ('accuracy', 'val_accuracy'):
    plt.plot(metric_log[curve_name], label=curve_name)

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Accuracy is a fraction, so show its full [0, 1] range.
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
