In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau


In [2]:
train_dir = 'C:\\Users\\KIIT\\Downloads\\riceleaf\\train'
test_dir = 'C:\\Users\\KIIT\\Downloads\\riceleaf\\test'

# Define image dimensions and batch size
img_height, img_width = 224, 224
batch_size = 5

# Use ImageDataGenerator for data augmentation and normalization
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(rescale=1./255)

# Generate batches of augmented data for training and validation
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)

validation_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)


Found 4153 images belonging to 4 classes.
Found 1779 images belonging to 4 classes.


In [9]:
# Load MobileNetV2 model pretrained on ImageNet without top layer
base_model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

# Add custom top layers for classification
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(len(train_generator.class_indices), activation='softmax')(x)

# Combine base model and top layers
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze base layers during initial training
for layer in base_model.layers:
    layer.trainable = False

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()


In [5]:
# Define callbacks (optional: EarlyStopping, ReduceLROnPlateau)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)

# Train the model
epochs = 4
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    callbacks=[early_stopping, reduce_lr]
)


Epoch 1/4


  self._warn_if_super_not_called()


[1m831/831[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m170s[0m 199ms/step - accuracy: 0.8531 - loss: 0.4336 - val_accuracy: 0.9545 - val_loss: 0.1045 - learning_rate: 0.0010
Epoch 2/4
[1m831/831[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m162s[0m 195ms/step - accuracy: 0.9669 - loss: 0.0900 - val_accuracy: 0.9719 - val_loss: 0.0645 - learning_rate: 0.0010
Epoch 3/4
[1m831/831[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m157s[0m 188ms/step - accuracy: 0.9791 - loss: 0.0579 - val_accuracy: 0.9893 - val_loss: 0.0309 - learning_rate: 0.0010
Epoch 4/4
[1m831/831[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m175s[0m 209ms/step - accuracy: 0.9807 - loss: 0.0727 - val_accuracy: 0.9882 - val_loss: 0.0343 - learning_rate: 0.0010


In [7]:
# Evaluate the model on the test set
test_loss, test_acc = model.evaluate(validation_generator, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')


356/356 - 36s - 102ms/step - accuracy: 0.9893 - loss: 0.0309
Test accuracy: 0.9893


In [None]:
w = 40
h = 30
fig = plt.figure(figsize=(12, 8))
columns = 5
rows = 3

# Ensure we don't exceed the number of available images/labels
num_images = min(columns * rows, len(train_labels), len(train_images))

for i in range(1, num_images + 1):
    ax = fig.add_subplot(rows, columns, i)
    if train_labels[i - 1] == 0:
        ax.title.set_text('Bacterialblight')
    elif train_labels[i - 1] == 1:
        ax.title.set_text('Blast')
    elif train_labels[i - 1] == 2:
        ax.title.set_text('Brownspot')
    else:
        ax.title.set_text('Tungro')
    plt.imshow(train_images[i - 1], interpolation='nearest')

plt.tight_layout()
plt.show()