# Plant Leaf Classification Using Keras and Transfer Learning

This simple image classifier has been trained to predict the type of plant from a single leaf. The model's training dataset covers 12 plant species. For 10 of them, leaves are labelled as either healthy or diseased, giving 20 classes; diseased Bael and healthy Basil are the 21st and 22nd classes. There are 4503 images in total: 2278 of healthy leaves and 2225 of diseased leaves, for a total dataset size of 7.3 GB. Class names are listed in the notebook below. The dataset is available from Kaggle [here](https://www.kaggle.com/datasets/csafrit2/plant-leaves-for-image-classification).

The model has been validated during training, using 5 images from each class. Some of the images in the test folder were used for "blind testing" after deployment.

The final accuracy achieved with the MobileNetV2 transfer-learning model was 88%, but further improvement may be possible with more rigorous hyperparameter and augmentation testing.


# Machine Learning Code

In [None]:
import os
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, ImageDataGenerator
from tensorflow.keras.layers import Dense, Conv2D, GlobalAveragePooling2D,Input
from tensorflow.keras import callbacks, optimizers
import numpy as np
from google.colab import drive

In [None]:
drive.mount('/content/drive') # mount Google Drive

Mounted at /content/drive


In [None]:
#!ls

drive  sample_data


## 1. Unzip and Save Data

In [None]:
#%cd /content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset/

/content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset


In [None]:
#!unzip PlantClassification_dataset.zip

Archive:  PlantClassification_dataset.zip
  inflating: Plants_2/images to predict/0001_0170.JPG  
  inflating: Plants_2/images to predict/0003_0179.JPG  
  inflating: Plants_2/images to predict/0005_0268.JPG  
  inflating: Plants_2/images to predict/0008_0148.JPG  
  inflating: Plants_2/images to predict/0015_0123.JPG  
  inflating: Plants_2/images to predict/0016_0118.JPG  
  inflating: Plants_2/images to predict/0019_0276.JPG  
  inflating: Plants_2/images to predict/0020_0271.JPG  
  inflating: Plants_2/test/Alstonia Scholaris diseased (P2a)/0014_0006.JPG  
  inflating: Plants_2/test/Alstonia Scholaris diseased (P2a)/0014_0007.JPG  
  inflating: Plants_2/test/Alstonia Scholaris diseased (P2a)/0014_0008.JPG  
  inflating: Plants_2/test/Alstonia Scholaris diseased (P2a)/0014_0009.JPG  
  inflating: Plants_2/test/Alstonia Scholaris diseased (P2a)/0014_0010.JPG  
  inflating: Plants_2/test/Alstonia Scholaris healthy (P2b)/0003_0006.JPG  
  inflating: Plants_2/test/Alstonia Scholaris hea

## 2. Data Explore & Prepare

In [None]:
# Count the number of images in each sub-folder of the train folder
train_path = "/content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset/Plants_2/train/"

for i in os.listdir(train_path):
  print(i, len(os.listdir(os.path.join(train_path, i))))

Alstonia Scholaris diseased (P2a) 244
Alstonia Scholaris healthy (P2b) 168
Arjun diseased (P1a) 222
Arjun healthy (P1b) 210
Bael diseased (P4b) 107
Basil healthy (P8) 137
Chinar diseased (P11b) 110
Chinar healthy (P11a) 93
Gauva diseased (P3b) 131
Gauva healthy (P3a) 267
Jamun diseased (P5b) 335
Jamun healthy (P5a) 268
Jatropha diseased (P6b) 114
Jatropha healthy (P6a) 123
Lemon diseased (P10b) 67
Lemon healthy (P10a) 149
Mango diseased (P0b) 255
Mango healthy (P0a) 159
Pomegranate diseased (P9b) 261
Pomegranate healthy (P9a) 277
Pongamia Pinnata diseased (P7b) 265
Pongamia Pinnata healthy (P7a) 312
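
The per-class counts above are uneven, which is worth keeping in mind when reading the overall accuracy figure. As a small sketch, the tally below copies those counts into a dictionary and summarises the split between healthy and diseased training images. The variable names are illustrative and not part of the notebook.

```python
# Training-image counts copied from the listing printed above
train_counts = {
    "Alstonia Scholaris diseased (P2a)": 244,
    "Alstonia Scholaris healthy (P2b)": 168,
    "Arjun diseased (P1a)": 222,
    "Arjun healthy (P1b)": 210,
    "Bael diseased (P4b)": 107,
    "Basil healthy (P8)": 137,
    "Chinar diseased (P11b)": 110,
    "Chinar healthy (P11a)": 93,
    "Gauva diseased (P3b)": 131,
    "Gauva healthy (P3a)": 267,
    "Jamun diseased (P5b)": 335,
    "Jamun healthy (P5a)": 268,
    "Jatropha diseased (P6b)": 114,
    "Jatropha healthy (P6a)": 123,
    "Lemon diseased (P10b)": 67,
    "Lemon healthy (P10a)": 149,
    "Mango diseased (P0b)": 255,
    "Mango healthy (P0a)": 159,
    "Pomegranate diseased (P9b)": 261,
    "Pomegranate healthy (P9a)": 277,
    "Pongamia Pinnata diseased (P7b)": 265,
    "Pongamia Pinnata healthy (P7a)": 312,
}

total = sum(train_counts.values())
diseased = sum(n for name, n in train_counts.items() if " diseased " in name)
healthy = total - diseased

# Classes with the fewest and the most training images
smallest = min(train_counts, key=train_counts.get)
largest = max(train_counts, key=train_counts.get)

print(f"total={total}, healthy={healthy}, diseased={diseased}")
print(f"smallest class: {smallest} ({train_counts[smallest]})")
print(f"largest class: {largest} ({train_counts[largest]})")
```

The total agrees with the number of training images reported by `flow_from_directory` later in the notebook.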


In [None]:
# Set up the paths to the validation data (used to evaluate the model during training) and the test data (used as the blind test)
test_path = "/content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset/Plants_2/test"
valid_path = "/content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset/Plants_2/valid"

In [None]:
# Class names.
class_names = sorted([f for f in os.listdir(train_path) if not f.startswith('.')])

for i in range(len(class_names)):
    print(i, class_names[i])

0 Alstonia Scholaris diseased (P2a)
1 Alstonia Scholaris healthy (P2b)
2 Arjun diseased (P1a)
3 Arjun healthy (P1b)
4 Bael diseased (P4b)
5 Basil healthy (P8)
6 Chinar diseased (P11b)
7 Chinar healthy (P11a)
8 Gauva diseased (P3b)
9 Gauva healthy (P3a)
10 Jamun diseased (P5b)
11 Jamun healthy (P5a)
12 Jatropha diseased (P6b)
13 Jatropha healthy (P6a)
14 Lemon diseased (P10b)
15 Lemon healthy (P10a)
16 Mango diseased (P0b)
17 Mango healthy (P0a)
18 Pomegranate diseased (P9b)
19 Pomegranate healthy (P9a)
20 Pongamia Pinnata diseased (P7b)
21 Pongamia Pinnata healthy (P7a)
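
Each class name follows the pattern "species, health status, code in parentheses". As a hedged sketch, the snippet below splits the names into those three parts and counts the distinct species. The regular expression and the helper `parse_class_name` are assumptions made for illustration; they are not part of the notebook.

```python
import re

class_names = [
    "Alstonia Scholaris diseased (P2a)", "Alstonia Scholaris healthy (P2b)",
    "Arjun diseased (P1a)", "Arjun healthy (P1b)",
    "Bael diseased (P4b)", "Basil healthy (P8)",
    "Chinar diseased (P11b)", "Chinar healthy (P11a)",
    "Gauva diseased (P3b)", "Gauva healthy (P3a)",
    "Jamun diseased (P5b)", "Jamun healthy (P5a)",
    "Jatropha diseased (P6b)", "Jatropha healthy (P6a)",
    "Lemon diseased (P10b)", "Lemon healthy (P10a)",
    "Mango diseased (P0b)", "Mango healthy (P0a)",
    "Pomegranate diseased (P9b)", "Pomegranate healthy (P9a)",
    "Pongamia Pinnata diseased (P7b)", "Pongamia Pinnata healthy (P7a)",
]

# Assumed layout: "<species> <healthy|diseased> (<code>)"
pattern = re.compile(
    r"^(?P<species>.+) (?P<status>healthy|diseased) \((?P<code>P\d+[ab]?)\)$"
)

def parse_class_name(name):
    """Split a class folder name into (species, status, code)."""
    match = pattern.match(name)
    if match is None:
        raise ValueError(f"Unexpected class name: {name!r}")
    return match.group("species"), match.group("status"), match.group("code")

parsed = [parse_class_name(name) for name in class_names]
species = sorted({sp for sp, _, _ in parsed})

print(len(class_names), "classes across", len(species), "species")
```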


The function below takes the indicated inputs and returns batches of images resized to target_size. Here the pixel values are scaled to the range -1 to 1, ready to be consumed by the MobileNetV2 model.<br> Since shuffle=True, the images are reshuffled after each epoch during training.<br>
"dir_path" is the location of the data. Each class is expected to be stored in its own sub-folder.

In [None]:
def img_data(dir_path, target_size, batch, class_list, preprocessing):
  # Apply the model-specific preprocessing function if one is given
  if preprocessing:
    gen_object = ImageDataGenerator(preprocessing_function=preprocessing)
  else:
    gen_object = ImageDataGenerator()
  # class_mode="sparse" yields integer labels that index into class_list
  return gen_object.flow_from_directory(dir_path,
                                        target_size=target_size,
                                        batch_size=batch,
                                        class_mode="sparse",
                                        classes=class_list,
                                        shuffle=True)
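
The scaling to the -1 to 1 range mentioned above can be illustrated in plain Python. This is a minimal sketch of the arithmetic (divide by 127.5, subtract 1) that maps 8-bit pixel values into that range; `scale_pixel` is a hypothetical helper written for illustration, and in the notebook the scaling is delegated to `preprocess_input`.

```python
def scale_pixel(value):
    """Map an 8-bit pixel value (0-255) into the [-1, 1] range."""
    return value / 127.5 - 1.0

# Black, mid-grey and white pixel values
for raw in (0, 127.5, 255):
    print(raw, "->", scale_pixel(raw))
```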

In [None]:
# Pass the same sorted class list to both generators so that label indices match between
# training and validation (os.listdir does not guarantee any particular ordering)
train_data_gen = img_data(train_path,(224,224),64,class_names,preprocess_input)
valid_data_gen = img_data(valid_path,(224,224),64,class_names,preprocess_input)

# preprocess_input is a function that ensures the data is prepared in the way expected by
# the model being used (e.g. MobileNetV2)

Found 4274 images belonging to 22 classes.
Found 110 images belonging to 22 classes.


In [None]:
os.listdir(train_path)

['Alstonia Scholaris diseased (P2a)',
 'Alstonia Scholaris healthy (P2b)',
 'Arjun diseased (P1a)',
 'Arjun healthy (P1b)',
 'Bael diseased (P4b)',
 'Basil healthy (P8)',
 'Chinar diseased (P11b)',
 'Chinar healthy (P11a)',
 'Gauva diseased (P3b)',
 'Gauva healthy (P3a)',
 'Jamun diseased (P5b)',
 'Jamun healthy (P5a)',
 'Jatropha diseased (P6b)',
 'Jatropha healthy (P6a)',
 'Lemon diseased (P10b)',
 'Lemon healthy (P10a)',
 'Mango diseased (P0b)',
 'Mango healthy (P0a)',
 'Pomegranate diseased (P9b)',
 'Pomegranate healthy (P9a)',
 'Pongamia Pinnata diseased (P7b)',
 'Pongamia Pinnata healthy (P7a)']

## 3. Import & Train Mobilenetv2

In [None]:
# Import the MobileNetV2 model, which takes an input shape of (224,224,3) and is pre-trained with ImageNet weights
base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
    input_shape=(224,224,3),
    alpha=1.0,
    include_top=False,
    weights='imagenet',
    input_tensor=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'
)

In [None]:
base_model.summary()

Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 Conv1 (Conv2D)                 (None, 112, 112, 32  864         ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 bn_Conv1 (BatchNormalization)  (None, 112, 112, 32  128         ['Conv1[0][0]']                  
                                )                                              

In [None]:
base_model.trainable = False # fix the already trained weights

In [None]:
# Attach classifier to the base
model = tf.keras.models.Sequential()
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(256,activation='relu'))
model.add(Dense(22,activation='softmax'))

In [None]:
model.summary()

Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d_7   (None, 1280)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_12 (Dense)            (None, 256)               327936    
                                                                 
 dense_13 (Dense)            (None, 22)                5654      
                                                                 
Total params: 2,591,574
Trainable params: 333,590
Non-trainable params: 2,257,984
_________________________________________________________________
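
The parameter counts in the summary can be checked by hand: a Dense layer has one weight per input-output pair plus one bias per output unit. The sketch below reproduces the figures reported above; `dense_params` is a hypothetical helper used only for this check.

```python
def dense_params(inputs, units):
    """Number of parameters in a Dense layer: weights plus biases."""
    return inputs * units + units

base_params = 2_257_984               # frozen MobileNetV2 base, from the summary
hidden = dense_params(1280, 256)      # pooled features -> Dense(256)
output = dense_params(256, 22)        # Dense(256) -> Dense(22)

trainable = hidden + output
total = base_params + trainable

print(hidden, output, trainable, total)
```

Only the two Dense layers are trained, which is why the trainable count is a small fraction of the total.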


In [None]:
# Labels are integers (class_mode="sparse"), so the sparse variant of the loss is used
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
early_stop = callbacks.EarlyStopping(monitor="val_loss",patience=5, mode="min") # stop if val_loss has not improved for 5 epochs
model_checkpoint_save = callbacks.ModelCheckpoint("model_wt.hd5", save_best_only=True,monitor="val_accuracy",mode='max') # save the best model

In [None]:
# batch_size is not passed here because the generators already yield batches of 64
model.fit(train_data_gen, validation_data=valid_data_gen,callbacks=[early_stop,model_checkpoint_save],
          epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
<keras.callbacks.History at 0x7f866e402f70>
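
The early-stopping rule used above (patience of 5 on the validation loss) can be sketched in plain Python. This is a simplified illustration of the patience logic, not the Keras implementation, and the loss values are made up for the example.

```python
def early_stop_epoch(val_losses, patience=5):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation losses: the best value occurs at epoch 3
losses = [1.00, 0.80, 0.70, 0.75, 0.72, 0.71, 0.74, 0.73]
print(early_stop_epoch(losses))
```

With these values, five epochs pass after epoch 3 without a new best loss, so the function reports epoch 8. A run whose loss keeps improving returns `None` and trains for all requested epochs.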

## 4. Deploy & Test with Blind Data:

In [1]:
!pip install -q streamlit # install streamlit for App creation


In [5]:
%%writefile app.py

import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess_input

# Load the trained model (compile=False because it is only used for inference)
model = tf.keras.models.load_model('/content/drive/MyDrive/Colab Notebooks/Datasets/PlantClassification_dataset/model_wt.hd5',compile=False)

uploaded_file = st.file_uploader("Choose an image file to predict", type="jpg")
st.text("ONLY upload one of the following: ")
st.text("Alstonia Scholaris/Arjun/Bael/Basil/Chinar/Mango")
st.text("/Gauva/Jamun/Jatropha/Lemon/Pomegranate/Pongamia")

map_dict = { 
  0:'Alstonia Scholaris diseased (P2a)',
 1: 'Alstonia Scholaris healthy (P2b)',
 2: 'Arjun diseased (P1a)',
 3: 'Arjun healthy (P1b)',
 4: 'Bael diseased (P4b)',
 5: 'Basil healthy (P8)',
 6: 'Chinar diseased (P11b)',
 7: 'Chinar healthy (P11a)',
 8: 'Gauva diseased (P3b)',
 9: 'Gauva healthy (P3a)',
 10: 'Jamun diseased (P5b)',
 11: 'Jamun healthy (P5a)',
 12: 'Jatropha diseased (P6b)',
 13: 'Jatropha healthy (P6a)',
 14: 'Lemon diseased (P10b)',
 15: 'Lemon healthy (P10a)',
 16: 'Mango diseased (P0b)',
 17: 'Mango healthy (P0a)',
 18: 'Pomegranate diseased (P9b)',
 19: 'Pomegranate healthy (P9a)',
 20: 'Pongamia Pinnata diseased (P7b)',
 21: 'Pongamia Pinnata healthy (P7a)'
}


if uploaded_file is not None:
    # Convert the uploaded file to an OpenCV image (decoded as BGR), then to RGB
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(opencv_image, (224, 224))

    # Display the image
    st.image(opencv_image, channels="RGB")

    # Scale the pixel values as MobileNetV2 expects and add a batch dimension
    resized = mobilenet_v2_preprocess_input(resized)
    img_reshape = resized[np.newaxis, ...]

    # Predict the class when the button is pressed
    generate_pred = st.button("Prediction is:")
    if generate_pred:
        prediction = int(model.predict(img_reshape).argmax())
        st.title("{}".format(map_dict[prediction]))

Overwriting app.py
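
The app reports only the single most likely class. As a hedged sketch of an alternative, the helper below ranks the classes by predicted probability so the top few candidates could be shown together. `top_k` and the four-class example are hypothetical and not part of `app.py`.

```python
def top_k(probabilities, class_names, k=3):
    """Return the k (class name, probability) pairs with the highest probability."""
    ranked = sorted(zip(class_names, probabilities),
                    key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical softmax output for four classes
names = ["Arjun healthy (P1b)", "Bael diseased (P4b)",
         "Basil healthy (P8)", "Lemon healthy (P10a)"]
probs = [0.10, 0.05, 0.70, 0.15]

for name, p in top_k(probs, names, k=2):
    print(f"{name}: {p:.2f}")
```

In the app, `probabilities` would be one row of the model's output and `class_names` the values of `map_dict` in index order.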


In [6]:
!streamlit run app.py & npx localtunnel --port 8501  # open webpage for deployment

Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.

  You can now view your Streamlit app in your browser.

  Network URL: http://172.28.0.12:8501
  External URL: http://34.171.252.187:8501

npx: installed 22 in 1.726s
your url is: https://fancy-corners-spend-34-171-252-187.loca.lt
2023-03-16 20:43:09.740891: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-16 20:43:09.870348: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-p