<a href="https://colab.research.google.com/github/smwaingeni-ai/AgriX-AfricaDeepTech2025/blob/main/notebooks/crop_disease_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**🔹 Step 1: Install Required Libraries**

In [1]:
!pip install tensorflow matplotlib



**Step 2: Import Libraries**

In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import pathlib
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models


**Step 3: Clone the GitHub Repo and Load the Dataset**

In [7]:
# Remove old folder (optional)
!rm -rf AgriX-AfricaDeepTech2025

# Clone again
!git clone https://github.com/smwaingeni-ai/AgriX-AfricaDeepTech2025.git

# Change directory
%cd AgriX-AfricaDeepTech2025


Cloning into 'AgriX-AfricaDeepTech2025'...
remote: Enumerating objects: 596, done.[K
remote: Counting objects: 100% (325/325), done.[K
remote: Compressing objects: 100% (243/243), done.[K
remote: Total 596 (delta 190), reused 81 (delta 81), pack-reused 271 (from 2)[K
Receiving objects: 100% (596/596), 24.39 MiB | 14.16 MiB/s, done.
Resolving deltas: 100% (229/229), done.
/content/AgriX-AfricaDeepTech2025


In [8]:
import os

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Parameters
data_dir = "/content/AgriX-AfricaDeepTech2025/data/crops/plantvillage_subset"
img_size = (128, 128)  # (height, width) every image is resized to
batch_size = 16
seed = 42              # fixes the generators' shuffling order for reproducible runs

# Fail fast with an actionable message if the clone cell was skipped or the data moved.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"Dataset folder not found: {data_dir} - run the clone cell first."
    )

# One generator rescales pixel values by 1/255 and reserves 20% of the images for
# validation; `subset` selects which side of that split each iterator yields.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_generator = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
    seed=seed,
)

val_generator = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    seed=seed,
)

# An empty split would otherwise only fail later, inside model.fit().
assert train_generator.samples > 0, f"No training images found under {data_dir}"
assert val_generator.samples > 0, f"No validation images found under {data_dir}"

print("✅ Detected classes:", list(train_generator.class_indices.keys()))


Found 35 images belonging to 3 classes.
Found 8 images belonging to 3 classes.
✅ Detected classes: ['Healthy', 'Maize___Leaf_Spot', 'Tomato___Bacterial_spot']


**Step 4: Build the CNN Model**

In [9]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation is reproducible.
tf.keras.utils.set_random_seed(42)

# Three conv/pool stages followed by a small dense head. The input shape is derived
# from `img_size` (set in the data-loading cell) so it always matches what the
# generators produce; the output width follows the number of detected classes.
model = models.Sequential([
    layers.Input(shape=(*img_size, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D(2, 2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(train_generator.num_classes, activation='softmax')
])

# Categorical cross-entropy pairs with the one-hot labels from class_mode='categorical'.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()


**Step 5: Train the Model**

In [10]:
# Fit on the training split for 5 epochs; Keras scores the validation split after
# each epoch, and the per-epoch metrics are kept in `history.history`.
history = model.fit(
    x=train_generator,
    validation_data=val_generator,
    epochs=5,
)


Epoch 1/5


  self._warn_if_super_not_called()


[1m2/3[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m0s[0m 175ms/step - accuracy: 0.2303 - loss: 1.4289



[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 520ms/step - accuracy: 0.2437 - loss: 1.4729 - val_accuracy: 0.5000 - val_loss: 1.2066
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 432ms/step - accuracy: 0.5139 - loss: 1.1154 - val_accuracy: 0.3750 - val_loss: 1.1386
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 513ms/step - accuracy: 0.3810 - loss: 1.0884 - val_accuracy: 0.3750 - val_loss: 1.0550
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 589ms/step - accuracy: 0.3492 - loss: 1.0406 - val_accuracy: 0.7500 - val_loss: 0.9896
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 400ms/step - accuracy: 0.6557 - loss: 0.9487 - val_accuracy: 0.7500 - val_loss: 0.8424


**Step 6: Export to TFLite for Mobile Use**

In [11]:
import os

# Convert the trained Keras model to a TensorFlow Lite flatbuffer
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save model. The path is relative to the current working directory (the repo root
# after the clone cell's `%cd`), so the absolute location is reported below.
tflite_path = "crop_disease_model.tflite"
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

# The model only outputs scores by class index; save the class names next to it,
# one per line in index order, so a mobile app can map predictions back to labels.
class_indices = train_generator.class_indices
class_names = sorted(class_indices, key=class_indices.get)
labels_path = os.path.splitext(tflite_path)[0] + "_labels.txt"
with open(labels_path, "w", encoding="utf-8") as f:
    f.write("\n".join(class_names) + "\n")

print(f"✅ TFLite model saved to: {os.path.abspath(tflite_path)}")
print(f"✅ Labels saved to: {os.path.abspath(labels_path)}")


Saved artifact at '/tmp/tmpa19pgsoh'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 3), dtype=tf.float32, name=None)
Captures:
  140060670751568: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140063349912400: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670754448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670746384: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670752144: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670751760: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670750416: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670752720: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670748880: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060670749840: TensorSpec(shape=(), dtype=tf.resource, name=None)
✅ TFLite model s

In [22]:
import tensorflow as tf
import os

# ✅ Get class names
class_names = list(train_generator.class_indices.keys())

# === 1. Define and train a simple CNN model ===
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(128, 128, 3)),
    tf.keras.layers.Conv2D(16, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(len(class_names), activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(train_generator, epochs=5)

# === 2. Convert model to TFLite ===
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# === 3. Save the TFLite model ===
output_path = "/content/AgriX-AfricaDeepTech2025/prototype/tflite_model/crop_disease_model.tflite"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, "wb") as f:
    f.write(tflite_model)

print(f"✅ TFLite model saved to: {output_path}")

# === 4. Verify ===
!ls -lh /content/AgriX-AfricaDeepTech2025/prototype/tflite_model/


Epoch 1/5




[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 204ms/step - accuracy: 0.3612 - loss: 1.8212
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 231ms/step - accuracy: 0.3641 - loss: 1.1686
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 455ms/step - accuracy: 0.6520 - loss: 0.8655
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 308ms/step - accuracy: 0.7277 - loss: 0.8472
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 415ms/step - accuracy: 0.5995 - loss: 0.8321
Saved artifact at '/tmp/tmpt_q1q74r'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name='keras_tensor_41')
Output Type:
  TensorSpec(shape=(None, 3), dtype=tf.float32, name=None)
Captures:
  140060528942480: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140060528941520: TensorSpec(shape=(), dtype=tf.resource, name=None