# 🌾 GrainPalette - Rice Type Classification Using Transfer Learning (MobileNetV2)

In [None]:
# 📦 Install dependencies
# IPython shell escape (the leading "!"), not plain Python: installs TensorFlow
# (which provides tf.keras) and Matplotlib quietly (-q). Versions are not pinned,
# so results may vary with whatever releases pip resolves.
!pip install -q tensorflow matplotlib

In [None]:
# 📚 Import libraries
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
# 📁 Mount Google Drive (optional: if using dataset from drive)
# Makes the user's Drive available under /content/drive, which is the prefix
# dataset_path uses in the next cell. The google.colab package is provided by
# the Colab runtime, so this cell is expected to fail elsewhere; skip it when
# the dataset lives on local disk and point dataset_path there instead.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 📂 Set dataset directory (adjust this to your own path)
# Assembled from the Drive root and the folder name so either is easy to swap;
# the trailing '' component keeps the final path separator. Presumably holds one
# sub-folder per rice class, which is the layout flow_from_directory expects.
dataset_path = os.path.join('/content/drive/MyDrive', 'rice_dataset', '')

In [None]:
# 🔄 Data Preprocessing
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

IMG_SIZE = 224
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2  # fraction of each class held out for validation
SEED = 42  # fixes the shuffle order of the training batches

# Training generator: model-specific preprocessing plus light augmentation.
# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1], which is what
# preprocess_input produces (a plain rescale=1./255 would give [0, 1] and feed
# out-of-distribution inputs to the frozen backbone). Any inference code that
# loads the saved model must apply the same preprocess_input.
datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VALIDATION_SPLIT,
    horizontal_flip=True,
    zoom_range=0.2,
)

# Validation generator: identical preprocessing, but no augmentation, so
# validation accuracy is measured on unmodified images. Keras splits each
# class's sorted file list by validation_split, so keeping the same value in
# both generators keeps the training and validation subsets disjoint.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VALIDATION_SPLIT,
)

train_data = datagen.flow_from_directory(
    dataset_path,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED,
)

val_data = val_datagen.flow_from_directory(
    dataset_path,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    # Stable order keeps predictions aligned with val_data.classes; batch order
    # does not affect the validation metrics reported by fit().
    shuffle=False,
)

In [None]:
# 🧠 Load MobileNetV2 base model
# ImageNet-pretrained backbone without its classification head (include_top=False),
# sized for the IMG_SIZE x IMG_SIZE RGB images produced by the generators.
base_model = MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
# Freeze every backbone weight so only the new head is trained.
base_model.trainable = False

In [None]:
# 🔧 Build the classifier
# New head on the frozen backbone: pool the final feature maps to one vector,
# pass it through a 128-unit ReLU layer, then a softmax over the classes that
# the training generator discovered in the dataset directory.
pooled_features = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(128, activation='relu')(pooled_features)
predictions = Dense(train_data.num_classes, activation='softmax')(hidden)

model = Model(inputs=base_model.input, outputs=predictions)

In [None]:
# ⚙️ Compile the model
# Categorical cross-entropy matches the one-hot labels from class_mode='categorical';
# Adam is used at its default learning rate, spelled out here for visibility.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

In [None]:
# 🏋️ Train the model
# Ten passes over the training generator; val_data is evaluated after each
# epoch, and the per-epoch metrics are kept in history.history.
history = model.fit(
    x=train_data,
    epochs=10,
    validation_data=val_data,
)

In [None]:
# 📊 Plot training accuracy
# One curve per split so the gap between training and validation accuracy is
# visible across epochs.
accuracy_series = (('accuracy', 'Train'), ('val_accuracy', 'Validation'))
for metric_key, series_label in accuracy_series:
    plt.plot(history.history[metric_key], label=series_label)
plt.title('Accuracy over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# 💾 Save model
# The .h5 suffix makes Keras write the legacy HDF5 format; the file lands in
# the current working directory.
model_path = 'rice_type_classifier.h5'
model.save(model_path)
print(f'Model saved as {model_path}')