Import Libraries

In [None]:
# Install the `efficientnet` package (quietly); it is imported below as `efn`.
!pip install -q efficientnet

In [None]:
import os
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from functools import partial
warnings.filterwarnings('ignore')
import json

from sklearn import model_selection

import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.applications import InceptionV3
import efficientnet.tfkeras as efn
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Flatten
from tensorflow.keras.models import Model,Sequential
from kaggle_datasets import KaggleDatasets
import re
import math

import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

In [None]:
# Print the installed TensorFlow version so it is recorded in the notebook output.
print(tf.__version__)

In [None]:
# Root directory of the competition input data; list its entries.
TOP_DIR = '../input/cassava-leaf-disease-classification/'
print(os.listdir(TOP_DIR))

Creating Run strategy

In [None]:
# Try to resolve a TPU cluster. A ValueError from the resolver is treated as
# "no TPU" (presumably raised when none is attached -- confirm against TF docs).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU resolved: fall back to whatever strategy TensorFlow currently has.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the cluster and initialise the TPU before creating the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

Set Cloud Path

In [None]:
# Look up the GCS path of the competition dataset through KaggleDatasets.
# The TFRecord globs further down are built on top of this path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
GCS_DS_PATH

Initialize the required settings

In [None]:
# Image side length (pixels) that every decoded image is resized to.
IMAGE_SIZE = 600
# Upper bound on training epochs passed to model.fit.
EPOCHS = 50
# Global batch size: 16 samples for each replica of the strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

# Image counts taken from the raw image folders. The folder paths are built
# from TOP_DIR so the dataset location is spelled out in only one place.
# NOTE: NUM_TRAINING_IMAGES / NUM_TEST_IMAGES are re-assigned later in this
# notebook from the TFRecord file names.
NUM_TRAINING_IMAGES = len(os.listdir(os.path.join(TOP_DIR, 'train_images')))
NUM_TEST_IMAGES = len(os.listdir(os.path.join(TOP_DIR, 'test_images')))
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
# Let tf.data pick parallelism / prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE

Creating train, validation, and test splits

In [None]:
# All training TFRecord files found under the dataset's GCS path.
train_tfrec_files = tf.io.gfile.glob(GCS_DS_PATH + '/train_tfrecords/*.tfrec')

# Hold out 20% of the training *files* (not individual images) for validation;
# the fixed random_state seeds the split.
TRAINING_SAMPLE, VAL_SAMPLE = model_selection.train_test_split(
    train_tfrec_files,
    test_size=0.2,
    random_state=2021,
)

# Test TFRecord files.
TEST_SAMPLE = tf.io.gfile.glob(GCS_DS_PATH + '/test_tfrecords/*.tfrec')

Check the size of each split

In [None]:
# Compiled once: the record count is the run of digits between a '-' and the
# following '.' in a file name. '+' (not '*') so an empty "-." is never taken
# as a count.
_TFREC_COUNT_PATTERN = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total record count encoded in a collection of file names.

    Each name must contain a ``-<digits>.`` segment (for example
    ``part00-128.tfrec`` encodes 128); the first such segment is used.

    Args:
        filenames: iterable of file names or paths (strings).

    Returns:
        int: sum of the counts of all names, ``0`` for an empty iterable.

    Raises:
        ValueError: if a name has no ``-<digits>.`` segment.
    """
    total = 0
    for filename in filenames:
        match = _TFREC_COUNT_PATTERN.search(filename)
        if match is None:
            raise ValueError(
                "cannot read an item count from file name {!r}".format(filename)
            )
        total += int(match.group(1))
    return total

# Per-split image counts, read from the TFRecord file names.
NUM_TRAINING_IMAGES = count_data_items(TRAINING_SAMPLE)
NUM_VALIDATION_IMAGES = count_data_items(VAL_SAMPLE)
NUM_TEST_IMAGES = count_data_items(TEST_SAMPLE)
# STEPS_PER_EPOCH was first derived from the full train_images folder, which
# also contains the images now held out for validation. Recompute it from the
# training split so one epoch corresponds to one pass over the training data.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

Reading the CSV

In [None]:
# Load the training CSV and the sample-submission CSV into DataFrames.
train_df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
sample_df = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')

In [None]:
# Show the first rows of the training table as the cell output.
train_df.head()

Initialize the required data pipeline

In [None]:
def data_augment(image, label):
    """Randomly flip and rotate one image; the label passes through unchanged.

    Args:
        image: image tensor as produced by ``decode_image``.
        label: label associated with the image.

    Returns:
        Tuple ``(augmented_image, label)``.
    """
    static_shape = image.shape
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Draw the number of quarter turns with a TensorFlow op. This function is
    # used in Dataset.map, which traces it into a graph: a NumPy/Python random
    # call would run once at trace time and bake the same rotation into every
    # sample.
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    # rot90 may not keep the static height/width; for square inputs a rotation
    # cannot change the shape, so restore the known static shape.
    if static_shape.rank == 3 and static_shape[0] == static_shape[1]:
        image.set_shape(static_shape)
    return image, label

def decode_image(image_data):
    """Decode JPEG bytes into a float32 tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3).

    The image is decoded with 3 channels, cast to float32 and divided by 255,
    then resized to IMAGE_SIZE x IMAGE_SIZE. The final reshape pins the static
    shape.
    """
    target_hw = [IMAGE_SIZE, IMAGE_SIZE]
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    resized = tf.image.resize(scaled, target_hw)
    return tf.reshape(resized, target_hw + [3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into ``(image, label)``.

    The record must hold a scalar string feature ``image`` (JPEG bytes) and a
    scalar int64 feature ``target``; the label is returned as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['target'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into ``(image, id)``.

    The record must hold two scalar string features: ``image`` (JPEG bytes)
    and ``id``. The id is returned as-is.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a parsed (unbatched) dataset from TFRecord files.

    Args:
        filenames: TFRecord file names to read.
        labeled: parse records with ``read_labeled_tfrecord`` when True,
            otherwise with ``read_unlabeled_tfrecord``.
        ordered: when False, deterministic ordering is switched off on the
            dataset options.

    Returns:
        A ``tf.data.Dataset`` of parsed examples.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .map(parse_fn, num_parallel_calls=AUTO)
    )

def get_training_dataset():
    """Return the augmented, endlessly repeating, shuffled and batched training set."""
    train_files = tf.io.gfile.glob(TRAINING_SAMPLE)
    return (
        load_dataset(train_files, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()  # the training dataset must repeat for several epochs
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=False):
    """Return the batched, cached validation dataset.

    Args:
        ordered: forwarded to ``load_dataset``. Pass True when element order
            must be reproducible, e.g. when images and labels are read in two
            separate passes. Defaults to False, which keeps the previous
            behaviour for existing callers.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches.
    """
    dataset = load_dataset(tf.io.gfile.glob(VAL_SAMPLE), labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_test_dataset(ordered=False):
    """Return the batched test dataset of ``(images, ids)``.

    ``ordered`` is forwarded to ``load_dataset``.
    """
    test_files = tf.io.gfile.glob(TEST_SAMPLE)
    return (
        load_dataset(test_files, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

Creating Train and validation dataset by calling the data pipeline

In [None]:
# Build the input pipelines once; both are passed to model.fit further down.
training_dataset = get_training_dataset()
validation_dataset = get_validation_dataset()

Read json file to check the label mapping

In [None]:
# Load the label-number -> disease-name mapping shipped with the dataset and print it.
labels_map_path = '../input/cassava-leaf-disease-classification/label_num_to_disease_map.json'
with open(labels_map_path) as f:
    labels_map = json.load(f)
print(labels_map)

In [None]:
# Class names indexed by integer label (CLASSES[label] is used for plot titles).
# NOTE(review): presumably mirrors the mapping printed from
# label_num_to_disease_map.json -- confirm the order matches.
CLASSES = ['Cassava Bacterial Blight (CBB)','Cassava Brown Streak Disease (CBSD)','Cassava Green Mottle (CGM)',
           'Cassava Mosaic Disease (CMD)','Healthy']

View Data from the folders

In [None]:
def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` pair of tensors to NumPy.

    Args:
        data: pair of objects exposing ``.numpy()``.

    Returns:
        ``(numpy_images, numpy_labels)``. When the second array has object
        dtype (binary strings, i.e. image IDs as found in test data) the
        labels are replaced by a list of ``None``, one per image.
    """
    image_tensor, label_tensor = data
    numpy_images = image_tensor.numpy()
    numpy_labels = label_tensor.numpy()
    if numpy_labels.dtype != object:
        return numpy_images, numpy_labels
    # Object dtype means ID strings rather than class labels: no labels to show.
    return numpy_images, [None] * len(numpy_images)

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and say whether it is correct.

    Args:
        label: predicted class index.
        correct_label: true class index, or None when unknown.

    Returns:
        ``(title, correct)``. With no true label the title is just the
        predicted class name and ``correct`` is True. Otherwise the title is
        the predicted name followed by ``[OK]`` or ``[NO→<true class name>]``.
    """
    if correct_label is None:
        return CLASSES[label], True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected = 'OK', '', ''
    else:
        verdict, arrow, expected = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(CLASSES[label], verdict, arrow, expected), correct

def display_one_flower(image, title, subplot, red=False, titlesize=1):
    """Draw one image in the given subplot slot and return the next slot.

    Args:
        image: image accepted by ``plt.imshow``.
        title: title text; nothing is drawn for an empty title.
        subplot: ``(rows, cols, index)`` triple passed to ``plt.subplot``.
        red: draw the title in red (at ``titlesize / 1.2``) instead of black.
        titlesize: base font size of the title.

    Returns:
        ``(rows, cols, index + 1)``.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            fontsize, color = int(titlesize/1.2), 'red'
        else:
            fontsize, color = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=color, fontdict={'verticalalignment':'center'})
    n_rows, n_cols, index = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, index + 1)


def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images on a roughly square grid.

    Supported calls:
        display_batch_of_images((images, labels))
        display_batch_of_images((images, labels), predictions)

    ``databatch`` is unpacked by ``batch_to_numpy_images_and_labels``, so it
    must be a pair of tensors. The second element may hold image IDs instead
    of labels, in which case no class titles are shown.

    Args:
        databatch: ``(images, labels)`` pair of tensors.
        predictions: optional sequence of predicted class indices, aligned
            with the images. When given, titles show the prediction and
            wrong predictions are drawn in red.

    Returns:
        None. An empty batch is a no-op.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # An empty batch would make rows == 0 and the column computation below
    # divide by zero; there is nothing to draw anyway.
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square
    # or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows
    shown = rows * cols

    # size and spacing
    FIGSIZE = 20.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))

    # Loop-invariant title size.
    # magic formula tested to work from 1x1 to 10x10 images
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3

    # display
    for i, (image, label) in enumerate(zip(images[:shown], labels[:shown])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: no gaps when nothing is titled, otherwise leave room for titles.
    # Looked up explicitly rather than through the leaked loop variable.
    last_label = labels[shown - 1]
    plt.tight_layout()
    if last_label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Plot a training and a validation curve in one subplot.

    Args:
        training: per-epoch values for the training set.
        validation: per-epoch values for the validation set.
        title: metric name, used for the subplot title and the y-axis label.
        subplot: three-digit subplot code passed to ``plt.subplot`` (e.g. 211).
            A code ending in 1 also creates the figure first.
    """
    is_first_panel = (subplot % 10 == 1)
    if is_first_panel:  # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()

    axes = plt.subplot(subplot)
    axes.set_facecolor('#F8F8F8')
    for series in (training, validation):
        axes.plot(series)
    axes.set_title('model '+ title)
    axes.set_ylabel(title)
    axes.set_xlabel('epoch')
    axes.legend(['train', 'valid.'])

In [None]:
# Re-batch the training stream into groups of 20 and keep an iterator over it.
ds_iter = iter(training_dataset.unbatch().batch(20))

In [None]:
# Pull the next group of 20 images and plot it; re-run the cell for another group.
one_batch = next(ds_iter)
display_batch_of_images(one_batch)

There are issues in the dataset: sometimes healthy leaves also look diseased.

Defining the callbacks for training

In [None]:
# All three callbacks watch the validation loss, where lower is better.
monitor_kwargs = dict(monitor='val_loss', mode='min')

# Write the model to disk, keeping only the best one seen so far.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "EfficientNetB7.h5",
    save_best_only=True,
    **monitor_kwargs
)

# Multiply the learning rate by 0.3 after 3 epochs without improvement,
# never going below 1e-5.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.3,
    patience=3,
    min_lr=1e-5,
    verbose=1,
    **monitor_kwargs
)

# Stop after 5 epochs without improvement and restore the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    verbose=1,
    **monitor_kwargs
)

callbacks = [checkpoint_cb, reduce_lr, early_stopping_cb]

Training the model inside the strategy scope

In [None]:
# Build and compile the model inside the strategy scope so that its
# variables are created under the distribution strategy.
with strategy.scope():
    pretrained_model = efn.EfficientNetB7(weights=None, include_top=False, input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3])
    # Load the pretrained weights into the backbone. `add_weight` only creates
    # a new variable and never reads a file, so using it would leave the
    # frozen backbone randomly initialised.
    # NOTE(review): assumes this weights file matches the `efn` EfficientNetB7
    # no-top architecture -- confirm.
    pretrained_model.load_weights('../input/keras-pretrained-models/EfficientNetB7_NoTop_ImageNet.h5')
    # Freeze the backbone: only the layers added below are trained.
    pretrained_model.trainable = False

    # NOTE(review): Dense(1024) has no activation and is applied before the
    # global pooling, i.e. on the spatial feature map -- confirm this is intended.
    model = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Dense(1024),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Dense(5, activation='softmax')
    ])

    # Labels are integer class ids, hence the sparse loss and metric.
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )

# The training dataset repeats forever, so the epoch length is set explicitly.
historical = model.fit(
    training_dataset,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

In [None]:
# Top panel (subplot code 211): training vs. validation loss per epoch.
display_training_curves(
    historical.history['loss'],
    historical.history['val_loss'],
    'loss',
    211,
)
# Bottom panel (subplot code 212): training vs. validation accuracy per epoch.
display_training_curves(
    historical.history['sparse_categorical_accuracy'],
    historical.history['val_sparse_categorical_accuracy'],
    'accuracy',
    212,
)

In [None]:
def display_confusion_matrix(cmat, score, precision, recall):
    """Show a confusion matrix with CLASSES as tick labels and a metric summary.

    Args:
        cmat: matrix to draw with ``matshow``.
        score: F1 score, or None to leave it out.
        precision: precision, or None to leave it out.
        recall: recall, or None to leave it out.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    tick_positions = range(len(CLASSES))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Collect one text fragment for each metric that was supplied.
    fragments = []
    if score is not None:
        fragments.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        fragments.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        fragments.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(fragments)

    if len(titlestring) > 0:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Plot a training and a validation curve in one subplot.

    Args:
        training: per-epoch values for the training set.
        validation: per-epoch values for the validation set.
        title: metric name, used for the subplot title and the y-axis label.
        subplot: three-digit subplot code passed to ``plt.subplot`` (e.g. 211).
            A code ending in 1 also creates the figure first.
    """
    starts_new_figure = (subplot % 10 == 1)
    if starts_new_figure:  # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()

    panel = plt.subplot(subplot)
    panel.set_facecolor('#F8F8F8')
    for curve in (training, validation):
        panel.plot(curve)
    panel.set_title('model '+ title)
    panel.set_ylabel(title)
    panel.set_xlabel('epoch')
    panel.legend(['train', 'valid.'])

In [None]:
# Images and labels are read from the dataset in two separate passes below,
# so the element order must be the same in both. Build the validation
# pipeline with ordered=True; an unordered dataset could pair predictions
# with the wrong labels.
cmdataset = (
    load_dataset(VAL_SAMPLE, labeled=True, ordered=True)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(AUTO)
)
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()

# One batch sized to the whole validation split holds every true label.
cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy()
cm_probabilities = model.predict(images_ds)
cm_predictions = np.argmax(cm_probabilities, axis=-1)

labels = range(len(CLASSES))
cmat = confusion_matrix(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
)
# Divide each row by its sum, so rows (true classes) add up to 1.
cmat = (cmat.T / cmat.sum(axis=1)).T # normalize

In [None]:
# Macro-averaged metrics over the full class list, computed on the
# validation predictions.
macro_kwargs = dict(labels=labels, average='macro')

score = f1_score(cm_correct_labels, cm_predictions, **macro_kwargs)
precision = precision_score(cm_correct_labels, cm_predictions, **macro_kwargs)
recall = recall_score(cm_correct_labels, cm_predictions, **macro_kwargs)

display_confusion_matrix(cmat, score, precision, recall)

In [None]:
# Re-batch the validation data into groups of 20 for visual inspection.
# Images and labels stay paired inside each element, so ordering does not matter here.
dataset = get_validation_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

In [None]:
# Predict on the next group of 20 validation images and plot predictions
# against the true labels; re-run the cell for another group.
# NOTE: this re-binds `labels`, which an earlier cell set to range(len(CLASSES)).
images, labels = next(batch)
probabilities = model.predict(images)
predictions = np.argmax(probabilities, axis=-1)
display_batch_of_images((images, labels), predictions)