In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.applications import inception_v3 as inc_net
from keras.preprocessing import image
from skimage.segmentation import mark_boundaries
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input

from pathlib import Path
from rembg import remove, new_session
from lime import lime_image

import sys
sys.path.append('../functions')  # make the project's helper modules in ../functions importable

RSEED = 42
DATASET_PATH = '../data/images/' # Path to the parent folder where the original data is stored
TRAINING_IMAGES = ''  # TODO: set before running; folder with one sub-folder per class
TESTING_IMAGES = ''   # TODO: set before running (not used in the cells shown here)

# Seed Python's `random`, NumPy and TensorFlow from RSEED in one call so that
# shuffling, augmentation and weight initialisation repeat across kernel restarts.
tf.keras.utils.set_random_seed(RSEED)

2024-04-07 20:13:17.032484: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-07 20:13:17.034693: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-07 20:13:17.109380: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-07 20:13:17.391102: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


ModuleNotFoundError: No module named 'skimage'

# 1. Consolidating classes

Upon closer inspection it became apparent that some classes contained very poor images (i.e. images the model did not learn well from), and that other classes should either be excluded or merged, based on domain knowledge of crop diseases. Since the aim of this project is to recommend treatments to farmers, we also excluded diseases that are too broad to recommend any specific treatment for.

Removed classes:
- scab (bad training data)
- green_mottle (bad training data)
- gray_spot_rust (not specific enough for remedy)
- yellow_leaf (condition not disease)
- leaf_curl (condition not disease)
- leaf_blight
- leaf_scorch
- pests (too unspecific)
- nematode
- virus (too unspecific)

Merged classes:
- septoria has been merged with brown_spot
- phytophora has been merged with late_blight
- mosaic_disease has been merged with mosaic_virus

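This reduced the final set to 25 classes (24 diseases and 1 healthy class):
> alternaria_leaf_spot, bacterial_blight, bacterial_spot, bacterial_wilt, black_measles, black_rot, blast, brown_spot, brown_streak_disease, citrus_greening, common_rust, early_blight, gray_leaf_spot, healthy, isariopsis_leaf_spot, late_blight, leaf_curl, leaf_mold, mosaic_disease, northern_leaf_blight, powdery_mildew, red_rot, spider_mites, target_spot, tungro

**TODO:** `leaf_curl` is listed above as removed but still appears in this list (and in `CLASSES` further down), and `mosaic_disease` is kept although the merge note names `mosaic_virus` — confirm which is correct.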

*We've removed and copied the classes manually, so there's no one-step solution to reproduce this step.*

# 2. Image augmentation

Image augmentation can make a model more robust and reduce overfitting: randomly altering the training images (small rotations, shifts and horizontal flips) introduces more variance into the training data.

In [None]:
# Define the augmentations for the training data; the validation data is only rescaled.
# Both generators read the same folder with the same `validation_split`, so the
# 'training' and 'validation' subsets are disjoint slices of each class folder.

IMG_SIZE = (224, 224)    # input resolution expected by the model
VALIDATION_SPLIT = 0.2   # fraction of TRAINING_IMAGES held out for validation

batch_size = 32
train_datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=10,
                                   width_shift_range=0.1,
                                   height_shift_range=0.1,
                                   horizontal_flip=True,
                                   validation_split=VALIDATION_SPLIT)

val_datagen = ImageDataGenerator(rescale=1./255, validation_split=VALIDATION_SPLIT)

# `seed` makes the shuffling order (and therefore the augmented batches) reproducible.
train_data = train_datagen.flow_from_directory(TRAINING_IMAGES,
                                               target_size=IMG_SIZE,
                                               color_mode='rgb',
                                               batch_size=batch_size,
                                               class_mode='categorical',
                                               shuffle=True,
                                               seed=RSEED,
                                               subset='training')

val_data = val_datagen.flow_from_directory(TRAINING_IMAGES,
                                           target_size=IMG_SIZE,
                                           color_mode='rgb',
                                           batch_size=batch_size,
                                           class_mode='categorical',
                                           shuffle=False,
                                           seed=RSEED,
                                           subset='validation')

In [None]:
# Display a few images from each generator to sanity-check the augmentation.

def show_batch(generator, n_rows, n_cols, figsize=(10, 10)):
    """Plot up to ``n_rows * n_cols`` images from the next batch of ``generator``.

    Args:
        generator: Keras image iterator yielding ``(images, labels)`` batches.
        n_rows, n_cols: grid layout of the figure.
        figsize: matplotlib figure size.

    Note: calling ``next()`` advances the generator by one batch.
    """
    images = next(generator)[0]  # class_mode='categorical' -> batch is (images, labels)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    # zip stops at the shorter sequence, so a final batch smaller than the grid is fine.
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
    plt.show()


show_batch(train_data, n_rows=4, n_cols=5)  # augmented training images
show_batch(val_data, n_rows=2, n_cols=2)    # validation images (rescaled only)

This did not improve our model, so next we used the LIME explainer to look at which image regions the model had learned to treat as important.

# 3. Background removal

Let's have a look at what the model bases its predictions on, or which parts of the images are most relevant to the model.

## 3.1 Image explainer (LIME)

In [None]:
# Configuration for the explainability section: the image to explain, the trained
# model, and the model's class names.

IMAGE = '../data/external_test_data/bacterial_spot/bacterial-symptoms-pepper.jpg' # the image to be tested
MODEL = keras.models.load_model('../models/model_filtered.h5') # the model to be used for making predictions
CLASSES = [
    'alternaria_leaf_spot',
    'bacterial_blight',
    'bacterial_spot',
    'bacterial_wilt',
    'black_measles',
    'black_rot',
    'blast',
    'brown_spot',
    'brown_streak_disease',
    'citrus_greening',
    'common_rust',
    'early_blight',
    'gray_leaf_spot',
    'healthy',
    'isariopsis_leaf_spot',
    'late_blight',
    'leaf_curl',
    'leaf_mold',
    'mosaic_disease',
    'northern_leaf_blight',
    'powdery_mildew',
    'red_rot',
    'spider_mites',
    'target_spot',
    'tungro',
    ]

# Directory-based Keras loaders assign label indices in alphanumeric order of the
# class folders. Assuming MODEL was trained that way, CLASSES must stay sorted and
# have one entry per output unit, otherwise predictions get the wrong names.
assert CLASSES == sorted(CLASSES), "CLASSES must be in alphanumeric order"
assert len(CLASSES) == MODEL.output_shape[-1], (
    f"{len(CLASSES)} class names but the model has {MODEL.output_shape[-1]} outputs"
)

In [None]:
def transform_img_fn(path_list, preprocess=inc_net.preprocess_input):
    """Load image files and stack them into one preprocessed batch.

    Args:
        path_list: iterable of image file paths; each is loaded as RGB and resized
            to 224x224.
        preprocess: function applied to each single-image array of shape
            (1, 224, 224, 3). The InceptionV3 default is kept for backward
            compatibility; pass the preprocessing the model was trained with.

    Returns:
        np.ndarray of shape (len(path_list), 224, 224, 3).
    """
    batch = []
    for img_path in path_list:
        img = image.load_img(img_path, target_size=(224, 224))
        x = np.expand_dims(image.img_to_array(img), axis=0)
        batch.append(preprocess(x))
    return np.vstack(batch)


def _to_display(img):
    """Min-max scale an array into [0, 1] so imshow renders it regardless of its range."""
    img = np.asarray(img, dtype='float64')
    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)


# Use the same (EfficientNet) preprocessing that `display_class_probabilities` uses,
# so LIME explains the model on the same inputs it is scored on. InceptionV3's
# preprocess_input rescales to [-1, 1], which this model is not fed elsewhere.
# NOTE(review): confirm this matches the preprocessing MODEL was trained with.
images = transform_img_fn([IMAGE], preprocess=preprocess_input)

explainer = lime_image.LimeImageExplainer(random_state=RSEED)  # reproducible perturbations
explanation = explainer.explain_instance(
    images[0].astype('double'),
    # LIME calls this once per batch of perturbed samples; silence the progress bars.
    lambda batch: MODEL.predict(batch, verbose=0),
    top_labels=5,
    hide_color=0,
    num_samples=1000,
)
top_label = explanation.top_labels[0]

# One panel per view: repeated plt.imshow calls in a single cell draw onto the same
# axes, so only the last image would be visible.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(_to_display(images[0]))
axes[0].set_title(f"Input (top class: {CLASSES[top_label]})")
for ax, positive_only, title in (
    (axes[1], True, "Regions supporting the top class"),
    (axes[2], False, "Regions for and against the top class"),
):
    temp, mask = explanation.get_image_and_mask(
        top_label, positive_only=positive_only, num_features=4, hide_rest=False
    )
    ax.imshow(mark_boundaries(_to_display(temp), mask))
    ax.set_title(title)
for ax in axes:
    ax.axis('off')
plt.show()

def display_class_probabilities(model, img_path, class_names, top_k=3):
    """Print the ``top_k`` highest-scoring classes predicted by ``model`` for one image.

    Args:
        model: Keras model whose output row is indexed like ``class_names``
            (presumably softmax probabilities -- confirm against the model head).
        img_path: path to the image file; loaded as RGB and resized to 224x224.
        class_names: class names in the order of the model's output units.
        top_k: number of classes to print (default 3); clamped to the number of outputs.
    """
    # Load and preprocess the input data
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = np.expand_dims(image.img_to_array(img), axis=0)
    img_array = preprocess_input(img_array)

    # Get class scores for the single image in the batch
    probabilities = model.predict(img_array, verbose=0)[0]
    top_k = min(top_k, len(probabilities))

    # argsort is ascending, so reverse it to rank the highest score first.
    top_indices = np.argsort(probabilities)[::-1][:top_k]

    print(f"Top {top_k} Predicted Classes:")
    for rank, idx in enumerate(top_indices, start=1):
        print(f"{rank}. {class_names[idx]}: {probabilities[idx]}")

display_class_probabilities(model=MODEL, img_path=IMAGE, class_names=CLASSES)  # top predictions for IMAGE

The explanation showed that the model bases its predictions on large portions of the background, so we tried removing the background to see how the model behaves without it.

## 3.2 Background removal

In [None]:
session = new_session()

# Define the input directory (one sub-folder per class containing the images)
input_directory = ''

# Define the output directory (the class sub-folders are mirrored here)
output_directory = ''


def remove_backgrounds(input_dir, output_dir, session):
    """Remove the background of every image in the class sub-folders of ``input_dir``.

    Each result is written to ``output_dir/<class>/<stem>.png``: rembg returns
    PNG-encoded bytes with an alpha channel, so the input's extension (e.g. .jpg)
    would mislabel the file. Outputs that already exist are skipped, so the
    function can be re-run to resume an interrupted job. Two inputs with the same
    stem in one class folder map to the same output; the second is skipped.

    Raises:
        ValueError: if either directory is not set.
        NotADirectoryError: if ``input_dir`` does not exist.
    """
    if not input_dir or not output_dir:
        # An empty path means the current directory, which would be processed silently.
        raise ValueError("Set input_directory and output_directory before running this cell.")
    input_root = Path(input_dir)
    if not input_root.is_dir():
        raise NotADirectoryError(f"Input directory {input_root} does not exist")
    output_root = Path(output_dir)

    for class_dir in sorted(p for p in input_root.iterdir() if p.is_dir()):
        target_dir = output_root / class_dir.name
        target_dir.mkdir(parents=True, exist_ok=True)
        for input_path in sorted(class_dir.iterdir()):
            # Skip nested folders and hidden files (e.g. .DS_Store) that are not images.
            if not input_path.is_file() or input_path.name.startswith('.'):
                continue
            output_path = target_dir / f"{input_path.stem}.png"
            if output_path.is_file():
                print(f"Output file {output_path} already exists. Skipping...")
                continue
            output_path.write_bytes(remove(input_path.read_bytes(), session=session))


remove_backgrounds(input_directory, output_directory, session)

The process of removing backgrounds and converting the resulting images to .png files with an alpha channel proved to be rather time-consuming. We tried it on a small test sample, and since the explainer still marked large parts of the (now transparent) background as relevant, we set this approach aside for the time being. Another reason we abandoned this technique is that automatically removing the background rendered some images almost empty and thus unrecognizable; sorting those images out would take a lot of manual effort.

# 4. Model customization

The next step to further fine-tune the model was to unfreeze more layers of the pre-trained model and retrain them on our own data. This led to a better model, so we included this step in the training of the final model.

In [None]:
# Example: clone a trained model, unfreeze its top layers and fine-tune them.

def unfreeze_model_and_clone(model, num_layers=10, learning_rate=1e-5):
    """Return a compiled copy of ``model`` whose last ``num_layers`` layers are trainable.

    The original model is not modified. BatchNormalization layers among those
    layers stay frozen, as recommended when fine-tuning pre-trained Keras models.

    Args:
        model: trained Keras model to copy.
        num_layers: number of top-level layers, counted from the output, to
            unfreeze (default 10); 0 unfreezes none.
        learning_rate: Adam learning rate used for fine-tuning (default 1e-5).

    Returns:
        The compiled clone, ready for ``fit``.
    """
    unfrozen_model = tf.keras.models.clone_model(model)
    unfrozen_model.set_weights(model.get_weights())  # clone_model copies the architecture only

    # Slice from an explicit start index: layers[-0:] would select every layer.
    # NOTE(review): if the pretrained backbone is nested as a single layer of `model`,
    # it counts as one layer here and would be made trainable as a whole -- confirm.
    start = max(len(unfrozen_model.layers) - num_layers, 0)
    for layer in unfrozen_model.layers[start:]:
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True

    # tf.keras.optimizers.legacy is not available under Keras 3 (TensorFlow >= 2.16).
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    unfrozen_model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )
    return unfrozen_model


def plot_hist(hist):
    """Plot training vs. validation accuracy per epoch from a Keras ``History``."""
    fig, ax = plt.subplots()
    ax.plot(hist.history["accuracy"], label="train")
    ax.plot(hist.history["val_accuracy"], label="validation")
    ax.set(title="model accuracy", xlabel="epoch", ylabel="accuracy")
    ax.legend(loc="upper left")
    plt.show()


# Create a new model with unfrozen layers, using the objects defined in this notebook.
# NOTE(review): train_data/val_data rescale pixels to [0, 1]; confirm this matches
# the preprocessing MODEL was trained with before relying on these results.
unfrozen_model = unfreeze_model_and_clone(MODEL)

epochs = 2
hist = unfrozen_model.fit(train_data, epochs=epochs, validation_data=val_data)

# Save the model to disk
unfrozen_model.save("unfrozen_model.h5")

# Check performance of the unfrozen model
plot_hist(hist)
