 <h1><b><center> Cassava Leaf Disease Classification </center></b></h1>


## Key takeaways from this notebook:

1. Speaking as a beginner, I know how tricky it can be to understand a problem statement and solve it early on the road to becoming a Data Scientist. This notebook focuses on the basics of understanding a problem statement and implementing its solution.

2. This notebook explains every piece of code and its purpose in building our Deep Learning model.

3. Using this as a baseline model, we can produce predictions with very good accuracy for a noob!

# 1. Problem Description

As the second-largest provider of carbohydrates in Africa, cassava is a key food security crop grown by smallholder farmers because it can withstand harsh conditions. At least 80% of household farms in Sub-Saharan Africa grow this starchy root, but viral diseases are major sources of poor yields. With the help of data science, it may be possible to identify common diseases so they can be treated.

Existing methods of disease detection require farmers to solicit the help of government-funded agricultural experts to visually inspect and diagnose the plants. This suffers from being labor-intensive, low-supply and costly. As an added challenge, effective solutions for farmers must perform well under significant constraints, since African farmers may only have access to mobile-quality cameras with low-bandwidth.

In this competition, we introduce a dataset of 21,397 labeled images collected during a regular survey in Uganda. Most images were crowdsourced from farmers taking photos of their gardens, and annotated by experts at the National Crops Resources Research Institute (NaCRRI) in collaboration with the AI lab at Makerere University, Kampala. This is in a format that most realistically represents what farmers would need to diagnose in real life.


**Problem statement:**

Our task is to classify each cassava image into one of four disease categories or a fifth category indicating a healthy leaf. With the help of such a model, farmers may be able to quickly identify diseased plants, potentially saving their crops before the disease inflicts irreparable damage.

Source: https://www.kaggle.com/c/cassava-leaf-disease-classification

# 2. Problem overview

### 2.1 Data Description

There are 21,397 training images, each carrying one of 5 class labels.


### 2.2 ML/DL Problem

It is a multiclass classification problem where we have to classify a given leaf image into one of the 5 categories.

### 2.3 Performance metric

Since we want to measure how often our model predicts the correct class, the performance metric is categorical accuracy (the accuracy score for multiclass classification), where

Accuracy score = Number of correctly classified images / Total number of images
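The formula above can be sketched in a few lines of NumPy. The labels below are made up for illustration and are not taken from the competition data; `accuracy_score` here is a small helper written for this example.

```python
import numpy as np

def accuracy_score(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))

# Made-up labels for six images (illustrative only)
y_true = [3, 3, 1, 4, 0, 3]
y_pred = [3, 2, 1, 4, 3, 3]

acc = accuracy_score(y_true, y_pred)  # 4 of the 6 predictions are correct
print(f"Accuracy: {acc:.3f}")
```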

# 3. Exploratory Data Analysis

### 3.1 Importing data and necessary libraries

In [None]:
import os
import cv2
import json
import numpy as np
import pandas as pd
import seaborn as sns
from tensorflow import keras
import matplotlib.pyplot as plt

### 3.2 Loading the data and visualizing the images

In [None]:
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
df['label'] = df['label'].astype('str')
df.head()

As you can see, there are 2 columns in our dataframe: one holding the **image id** of each training image and the other holding the **label** for that image.



In [None]:
print(f'There are total {df.shape[0]} images in our train data.')

Let us first understand what our output labels mean, and then we can relate them to the images.

In [None]:
with open('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json', 'r') as file:
    labels = json.load(file)
    
labels

From the JSON file loaded above we can see that labels 0, 1, 2 and 3 correspond to diseased plants and label 4 means the leaf is healthy. The only purpose of this file is to map each numeric label to its disease name.
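As a small sketch of how such a mapping can be used, the snippet below attaches a disease name to every row of a dataframe. `labels_demo` and `df_demo` are hypothetical stand-ins: the dictionary is assumed to mirror the JSON file (the disease names are the ones used in the section headings of this notebook), and the image ids are invented.

```python
import pandas as pd

# Assumed contents of the label map; JSON object keys are always strings
labels_demo = {
    "0": "Cassava Bacterial Blight (CBB)",
    "1": "Cassava Brown Streak Disease (CBSD)",
    "2": "Cassava Green Mottle (CGM)",
    "3": "Cassava Mosaic Disease (CMD)",
    "4": "Healthy",
}

# Invented rows shaped like train.csv, with labels already cast to str
df_demo = pd.DataFrame({
    "image_id": ["a.jpg", "b.jpg", "c.jpg"],
    "label": ["3", "4", "0"],
})

# Series.map looks up each label in the dictionary
df_demo["disease"] = df_demo["label"].map(labels_demo)
print(df_demo)
```

This is also why the labels are cast to strings before the lookup: an integer key such as `3` would not match the string key `"3"`.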

In [None]:
# defining some variables which will be useful later
TRAIN_PATH = '../input/cassava-leaf-disease-classification/train_images'
IMAGE_WIDTH = 300
IMAGE_HEIGHT = 300
EPOCHS = 20
BATCH_SIZE = 32

In [None]:
plt.figure(figsize=(16, 12))
df_sample = df.sample(9).reset_index(drop=True)  # one random image per grid cell
for i in range(9):
    plt.subplot(3, 3, i+1)
    img = cv2.imread(os.path.join(TRAIN_PATH, df_sample.image_id[i]))
    img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))  # cv2 expects (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.axis('off')
    plt.imshow(img)
    plt.title(labels.get(df_sample.label[i]))
plt.tight_layout()
plt.show()

## Let us look at the types of diseases:

### 1 - Cassava Bacterial Blight (CBB)

Xanthomonas axonopodis pv. manihotis is the pathogen that causes bacterial blight of cassava. Originally discovered in Brazil in 1912, the disease has followed cultivation of cassava across the world. Among diseases which afflict cassava worldwide, bacterial blight causes the largest losses in terms of yield.

#### Symptoms:

Symptoms include leaf spotting, wilting, dying, gum oozing on young shoots, and vascular coloration of mature stems and roots of susceptible varieties.

In [None]:
plt.figure(figsize=(12, 7))
df_sample = df[df.label == '0'].sample(3).reset_index(drop=True)
for i in range(3):
    plt.subplot(1, 3, i+1)
    img = cv2.imread(os.path.join(TRAIN_PATH, df_sample.image_id[i]))
    img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))  # cv2 expects (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.axis('off')
    plt.imshow(img)
    plt.title(labels.get(df_sample.label[i]))
plt.tight_layout()
plt.show()

### 2 - Cassava Brown Streak Disease (CBSD)

Cassava brown streak virus disease (CBSD) is a damaging disease of cassava plants, and is especially troublesome in East Africa. It was first identified in 1936 in Tanzania, and has spread to other coastal areas of East Africa, from Kenya to Mozambique. Recently, it was found that two distinct viruses are responsible for the disease: cassava brown streak virus (CBSV) and Ugandan cassava brown streak virus (UCBSV).

#### Symptoms:

CBSD is characterized by severe chlorosis and necrosis on infected leaves, giving them a yellowish, mottled appearance.
Chlorosis may be associated with the veins, spanning from the mid vein, secondary and tertiary veins, or rather in blotches unconnected to veins.
Leaf symptoms vary greatly depending on a variety of factors.
The growing conditions (i.e. altitude, rainfall quantity), plant age, and the virus species account for these differences.
Brown streaks may appear on the stems of the cassava plant. Also, a dry brown-black necrotic rot of the cassava tuber exists, which may progress from a small lesion to the whole root.
Finally, the roots can become constricted due to the tuber rot, stunting growth.

In [None]:
plt.figure(figsize=(12, 7))
df_sample = df[df.label == '1'].sample(3).reset_index(drop=True)
for i in range(3):
    plt.subplot(1, 3, i+1)
    img = cv2.imread(os.path.join(TRAIN_PATH, df_sample.image_id[i]))
    img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))  # cv2 expects (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.axis('off')
    plt.imshow(img)
    plt.title(labels.get(df_sample.label[i]))
plt.tight_layout()
plt.show()

### 3 - Cassava Green Mottle (CGM)

It has not been confirmed to be a nepovirus; these are viruses that are transmitted by nematodes, hence the name. Its distribution is narrow: it is only known from the Solomon Islands. It was first found on Choiseul in the 1970s; more recently (2010), similar symptoms were seen on Malaita.

#### Symptoms:

Look for yellow patterns on the leaves, from small dots to irregular patches of yellow and green.
Look for leaf margins that are distorted.
The plants may be stunted.

In [None]:
plt.figure(figsize=(12, 7))
df_sample = df[df.label == '2'].sample(3).reset_index(drop=True)
for i in range(3):
    plt.subplot(1, 3, i+1)
    img = cv2.imread(os.path.join(TRAIN_PATH, df_sample.image_id[i]))
    img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))  # cv2 expects (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.axis('off')
    plt.imshow(img)
    plt.title(labels.get(df_sample.label[i]))
plt.tight_layout()
plt.show()

### 4 - Cassava Mosaic Disease (CMD)

Cassava mosaic virus is the common name used to refer to any of eleven different species of plant pathogenic virus in the genus Begomovirus. African cassava mosaic virus (ACMV), East African cassava mosaic virus (EACMV), and South African cassava mosaic virus (SACMV) are distinct species of circular single-stranded DNA viruses which are transmitted by whiteflies and primarily infect cassava plants; these have thus far only been reported from Africa.

#### Symptoms:

Initially following infection of a cassava geminivirus in cassava, systemic symptoms develop.
These symptoms include chlorotic mosaic of the leaves, leaf distortion, and stunted growth.
Leaf stalks have a characteristic S-shape.
Infection can be overcome by the plant especially when a rapid onset of symptoms occurs. A slow onset of disease development usually correlates with death of the plant.
Spread of the disease is affected by whiteflies and by environmental factors such as temperature, wind, precipitation and plant density.

In [None]:
plt.figure(figsize=(12, 7))
df_sample = df[df.label == '3'].sample(3).reset_index(drop=True)
for i in range(3):
    plt.subplot(1, 3, i+1)
    img = cv2.imread(os.path.join(TRAIN_PATH, df_sample.image_id[i]))
    img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))  # cv2 expects (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.axis('off')
    plt.imshow(img)
    plt.title(labels.get(df_sample.label[i]))
plt.tight_layout()
plt.show()

### 3.3 Understanding distribution of train data

Let us understand how our class labels are distributed.

In [None]:
plt.figure(figsize=(20, 6))

plt.suptitle('Distribution of class labels', fontsize=20)
plt.subplot(1, 2, 1)
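# value_counts() sorts by frequency, so sort by label to keep wedges and names aligned
class_counts = df.label.value_counts().sort_index()
plt.pie(
    class_counts,
    labels=class_counts.index,
    autopct='%d%%',
    explode=[0.01] * len(class_counts)
)
plt.title('Pie distribution of class labels')
plt.legend()

plt.subplot(1, 2, 2)
sns.countplot(x=df.label, order=class_counts.index)  # pass data by keyword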
plt.title('Bar distribution of class labels')
plt.show()

As you can see, there is a huge imbalance between the class labels. Approximately 61% of the data belongs to class 3 (Cassava Mosaic Disease, CMD), and the remaining 39% is spread across the other four classes.
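The class shares can also be read off as numbers rather than from a chart. The sketch below uses a tiny made-up label column (`toy_labels`) instead of the real `train.csv`, so the printed shares are illustrative only.

```python
import pandas as pd

# Made-up label column: six images of class '3', two of '4', one each of '0' and '1'
toy_labels = pd.Series(['3'] * 6 + ['4'] * 2 + ['0', '1'], name='label')

# normalize=True returns fractions instead of counts;
# sort_index() orders the result by label rather than by frequency
share = toy_labels.value_counts(normalize=True).sort_index()
print((share * 100).round(1))
```

Running the same two calls on `df.label` would give the exact percentage for each of the five classes.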

# 4. Deep learning model

### 4.1 Image augmentation using Keras Image Data Generator

Image augmentation is a technique of applying different transformations to original images which results in multiple transformed copies of the same image. Each copy, however, is different from the other in certain aspects depending on the augmentation techniques you apply like shifting, rotating, flipping, etc.

Applying these small variations to the original image does not change its target class; it only provides a new perspective on how the object might be captured in real life. That is why we use it quite often when building deep learning models. Keras ImageDataGenerator is a gem! It lets you augment your images in real time while your model is still training! You can apply random transformations to each training image as it is passed to the model. This not only makes your model more robust but also saves memory overhead.

#### Advantages of using Keras Image Data Generator:

The main benefit of using the Keras ImageDataGenerator class is that it is designed to provide real-time data augmentation, meaning it generates augmented images **on the fly** while your model is still in the training stage. It only returns the transformed images and does not add them to the original corpus of images. If it did, the model would see the original images multiple times, which would make overfitting more likely.

Another advantage of ImageDataGenerator is that it requires **lower memory usage**. Without this class, we would load all the images at once; with it, we load the images in batches, which saves a lot of memory.
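To make the two ideas above concrete, here is a minimal NumPy sketch of a generator that yields randomly flipped batches. It is not how Keras implements ImageDataGenerator; `augmented_batches` is a hypothetical helper that only illustrates "transform on the fly, one batch at a time, without touching the originals".

```python
import numpy as np

def augmented_batches(images, batch_size, rng):
    """Yield batches of randomly flipped copies; the source array is never modified."""
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size].copy()
        for i in range(len(batch)):
            if rng.random() < 0.5:
                batch[i] = batch[i][:, ::-1].copy()  # horizontal flip (reverse width axis)
            if rng.random() < 0.5:
                batch[i] = batch[i][::-1, :].copy()  # vertical flip (reverse height axis)
        yield batch

rng = np.random.default_rng(0)
# Ten tiny fake "images" of shape (height, width, channels) = (4, 4, 3)
images = np.arange(10 * 4 * 4 * 3, dtype=np.float32).reshape(10, 4, 4, 3)
original = images.copy()

shapes = [batch.shape for batch in augmented_batches(images, batch_size=4, rng=rng)]
print(shapes)
```

Only one batch lives in memory at a time, and every pass over the data produces a different set of flips.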

To understand keras parameters for image data generator: <a href='https://keras.io/api/preprocessing/image/'>click here</a>

In [None]:
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=20,
    shear_range=20,
    zoom_range=0.2,
    height_shift_range=0.1,
    width_shift_range=0.1,
    validation_split=0.2
)

train_imagegen = train_datagen.flow_from_dataframe(
    df,
    directory='../input/cassava-leaf-disease-classification/train_images',
    x_col='image_id',
    y_col='label',
    subset='training',
    target_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    class_mode='categorical',
    batch_size=BATCH_SIZE
)

Since we have provided validation_split=0.2 as a parameter, out of 21397 images our train data generator has selected 17118 images for the training subset.

Here class_mode='categorical' means that the labels are one-hot encoded, e.g. label=4 will be encoded as [0, 0, 0, 0, 1].

Each image will have size (300, 300, 3).
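These three statements can be checked with a few lines of arithmetic. The split rule below (truncate `n * validation_split` to an integer and give the rest to training) is an assumption about how Keras divides the rows; it does reproduce the 17118 reported by the generator.

```python
import numpy as np

n_images = 21397
validation_split = 0.2

# Assumed split rule: the validation subset gets the truncated share
n_valid = int(n_images * validation_split)
n_train = n_images - n_valid
print(n_train, n_valid)

# One-hot encoding of label '4' among the five classes
class_names = ['0', '1', '2', '3', '4']
one_hot = np.eye(len(class_names), dtype=int)[class_names.index('4')]
print(one_hot)

# Shape of a single image: height, width, colour channels (RGB)
image_shape = (300, 300, 3)
```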

In [None]:
valid_datagen = keras.preprocessing.image.ImageDataGenerator(
    validation_split=0.2
)

valid_imagegen = valid_datagen.flow_from_dataframe(
    df,
    directory='../input/cassava-leaf-disease-classification/train_images',
    x_col='image_id',
    y_col='label',
    subset='validation',
    target_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    class_mode='categorical',
    batch_size=BATCH_SIZE
)

### 4.2 Creating DL model

For creating our Deep Learning model we will use the pretrained '**Xception**' model, which is an application of Transfer Learning.

**What is Transfer Learning?**

Transfer learning is a machine learning method where a model developed for a task is reused as the starting point for a model on a second task.

It is a popular approach in deep learning where pre-trained models are used as the starting point on computer vision and natural language processing tasks given the vast compute and time resources required to develop neural network models on these problems and from the huge jumps in skill that they provide on related problems.

To check the overall list of models provided by Keras as a part of transfer learning: <a href='https://keras.io/api/applications/'>click here</a>

The structure of our deep learning model is as follows:

1. Xception model as a part of transfer learning application by Keras.
2. Global Average Pooling layer to average each feature map over its spatial dimensions, reducing the output of Xception to a single vector per image.
3. Dense layer with softmax activation to provide the predicted probabilities for all 5 classes, acting as the output layer.

If you are new to Deep Learning and want to understand the functionality of Global Average Pooling layer: <a href='https://adventuresinmachinelearning.com/global-average-pooling-convolutional-neural-networks/'>click here</a>
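For a quick feel of what the last two layers do, here is a NumPy sketch. The feature-map shape and the random weights are hypothetical stand-ins, not values taken from the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical backbone output: (batch, height, width, channels)
feature_maps = rng.normal(size=(2, 10, 10, 2048))

# Global average pooling: average every channel over the spatial axes
pooled = feature_maps.mean(axis=(1, 2))  # shape (2, 2048)

# Dense layer with 5 units followed by softmax (random weights for illustration)
weights = rng.normal(scale=0.01, size=(2048, 5))
bias = np.zeros(5)
logits = pooled @ weights + bias

# Numerically stable softmax: subtract the row maximum before exponentiating
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

print(pooled.shape, probs.shape)
print(probs.sum(axis=1))  # each row of probabilities sums to 1
```

Pooling removes the spatial dimensions entirely, so the Dense layer needs only one weight per channel and class.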

In [None]:
model = keras.models.Sequential()
model.add(keras.applications.Xception(input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3), 
                                            weights='imagenet', include_top=False))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(5, activation='softmax'))
model.summary()  # summary() prints the table itself and returns None

Let us look at the structure of our Deep Learning model

In [None]:
keras.utils.plot_model(model)

Let us define some of the callbacks for our DL model.

A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc). 

To understand more <a href='https://keras.io/api/callbacks/'>click here</a>

In this model we will be using 3 callbacks as below:

1. **ModelCheckpoint** : Callback to save the Keras model or model weights at some frequency.

In [None]:
model_checkpoint = keras.callbacks.ModelCheckpoint(
    './best_weights.h5',
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min"
)

2. **EarlyStopping** : Stop training when a monitored metric has stopped improving.

In [None]:
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=5,
    verbose=1,
    mode="min",
    restore_best_weights=True,
)
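The stopping rule can be sketched in plain Python. This is a simplified model of the behaviour with `patience=5` and `min_delta=0.001`, not the Keras source, and the validation losses are invented.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.001):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # Improvement of more than min_delta: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Invented validation losses: the last real improvement happens at epoch 3
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.6495, 0.66, 0.68, 0.70, 0.71]
stop_epoch = early_stopping_epoch(val_losses)
print(stop_epoch)
```

Note that the loss of 0.6495 at epoch 6 is lower than the best value of 0.65, but by less than `min_delta`, so in this sketch it does not count as an improvement.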

3. **ReduceLROnPlateau** : Reduce learning rate when a metric has stopped improving.

In [None]:
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.1,
    patience=2,
    verbose=1,
    mode="min",
    min_delta=0.001,
)
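A similar sketch shows how the learning rate would evolve with `factor=0.1` and `patience=2`. It is a simplified simulation that ignores the cooldown and minimum learning rate options; the starting rate of 0.001 is the Keras default for Adam, and the losses are invented.

```python
def simulate_reduce_lr(val_losses, lr=1e-3, factor=0.1, patience=2, min_delta=0.001):
    """Return the learning rate in effect after each epoch."""
    best = float('inf')
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # shrink the learning rate
                wait = 0      # start counting again
        history.append(lr)
    return history

# Invented validation losses with two plateaus
val_losses = [0.90, 0.70, 0.71, 0.72, 0.69, 0.70, 0.71]
lr_history = simulate_reduce_lr(val_losses)
print(lr_history)
```

Because this patience (2) is shorter than the EarlyStopping patience (5), the learning rate gets a chance to drop before training is stopped.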

For this model we will be using **Adam** optimizer and since our output labels are categorical we will be using **categorical_crossentropy** as a loss function.

In [None]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
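To see what the loss measures, here is a NumPy sketch of categorical cross-entropy on one-hot labels. The probabilities are made up, and `categorical_crossentropy` below is a small helper written for this example, not the Keras function itself.

```python
import numpy as np

def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_prob))."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))

# Two one-hot labels: class 3 and class 4
y_true = np.array([[0, 0, 0, 1, 0],
                   [0, 0, 0, 0, 1]], dtype=float)

# Made-up predictions: confident and correct vs. completely unsure
confident = np.array([[0.02, 0.02, 0.02, 0.92, 0.02],
                      [0.05, 0.05, 0.05, 0.05, 0.80]])
unsure = np.full((2, 5), 0.2)

loss_confident = categorical_crossentropy(y_true, confident)
loss_unsure = categorical_crossentropy(y_true, unsure)
print(loss_confident, loss_unsure)
```

The loss only looks at the probability assigned to the true class, so confident correct predictions give a value near 0 while a uniform guess over 5 classes gives log(5).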

We have defined everything we need; it's time to train the model.

In [None]:
history = model.fit(  # fit_generator is deprecated; fit accepts generators directly
    train_imagegen,
    epochs=EPOCHS,
    steps_per_epoch=train_imagegen.samples // BATCH_SIZE,  # integer number of full batches
    validation_data=valid_imagegen,
    validation_steps=valid_imagegen.samples // BATCH_SIZE,
    callbacks=[model_checkpoint, early_stopping, reduce_lr]
)
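The number of steps per epoch is simply the number of batches drawn from each subset. A quick sketch, assuming 17118 training images (the count reported by the generator) and 4279 validation images (the remainder of the 21397):

```python
import math

n_train, n_valid, batch_size = 17118, 4279, 32

# Floor division keeps only full batches, so the last partial batch is skipped
train_steps = n_train // batch_size
valid_steps = n_valid // batch_size

# Rounding up instead would also cover the final, smaller batch
train_steps_all = math.ceil(n_train / batch_size)

print(train_steps, valid_steps, train_steps_all)
```

With floor division, 30 training images and 23 validation images are left out of each pass.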

Since the performance metric for this competition is accuracy, let us plot the train and validation accuracy to monitor our model performance.

In [None]:
plt.figure(figsize=(15, 5))
plt.plot(history.history['accuracy'], 'b*-', label="train_acc")
plt.plot(history.history['val_accuracy'], 'r*-', label="val_acc")
plt.grid()
plt.title("train_acc vs val_acc")
plt.ylabel("Accuracy")
plt.xlabel("Epochs")
plt.legend()
plt.show()

Being a noob myself, I scratched my head for some time before understanding that a notebook used for submission must have **no internet connection**. Since we have used Transfer Learning to build our Deep Learning model, however, we require an internet connection.

This is because requesting '**imagenet**' weights for our '**Xception**' model downloads the weights from Google storage at run time.

So this notebook will be our baseline: we obtain the weights after training our model on the train images, and later we will load these weights in a different notebook to predict on our test images.

**PS:**

The link to prediction notebook is : https://www.kaggle.com/pndeepak/cassava-predict

This prediction notebook explains in a lucid way how to load our model weights and use them to predict labels on test images.

<h2>Please upvote this notebook if you learnt at least something from it and share it with others as well.</h2>