## This notebook is inspired by [Getting Started: TPUs + Cassava Leaf Disease](https://www.kaggle.com/jessemostipak/getting-started-tpus-cassava-leaf-disease)

## Import modules

In [None]:
import re
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from functools import partial
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split
print("Tensorflow version " + tf.__version__)

## Initialize TPU

In [None]:
# Try to attach to a TPU; if any step fails, fall back to the default
# distribution strategy so the rest of the notebook still runs.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `except Exception` (instead of a bare `except:`) keeps the best-effort
    # fallback but no longer swallows KeyboardInterrupt / SystemExit, and the
    # reason for the fallback is reported instead of being silently dropped.
    print('TPU not available, falling back to the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()

print('Number of replicas in sync:', strategy.num_replicas_in_sync)

## Set up some constant variables

In [None]:
# Sentinel passed to tf.data calls (parallel reads, map, prefetch) below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Path prefix under which the train/test TFRecord files are globbed further down.
GCS_PATH = KaggleDatasets().get_gcs_path()
# Global batch size: 16 examples for each replica kept in sync by the strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Height and width that decode_image reshapes every decoded image to.
IMAGE_SIZE = [512, 512]
# Class ids as strings; indexed by the integer label when building plot titles.
CLASSES = ['0', '1', '2', '3', '4']
# Human-readable class names -- presumably index-aligned with CLASSES (TODO confirm).
CLASS_NAMES = ['Cassava Bacterial Blight', 
               'Cassava Brown Streak Disease', 
               'Cassava Green Mottle', 
               'Cassava Mosaic Disease', 
               'Healthy']
# Number of training epochs; also the range over which the LR schedule is plotted.
EPOCHS = 7

## Data decoding

In [None]:
def decode_image(image):
    """Turn an encoded JPEG string tensor into a float32 RGB image tensor.

    The decoded pixels are cast to float32, divided by 255.0 and reshaped to
    ``[*IMAGE_SIZE, 3]`` so that downstream ops see a static shape.
    """
    rgb_pixels = tf.image.decode_jpeg(image, channels = 3)
    scaled_pixels = tf.cast(rgb_pixels, tf.float32) / 255.0
    return tf.reshape(scaled_pixels, [*IMAGE_SIZE, 3])

## Parse single example from TFRecord format

In [None]:
def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example into an ``(image, second)`` pair.

    Args:
        example: a serialized example, as yielded by ``tf.data.TFRecordDataset``.
        labeled: when true the record is read with an int64 ``target`` feature
            and the pair is ``(image, int32 label)``; otherwise it is read with
            a string ``image_name`` feature and the pair is ``(image, image_name)``.
    """
    # The two record layouts differ only in their second feature.
    if labeled:
        extra_key, extra_dtype = 'target', tf.int64
    else:
        extra_key, extra_dtype = 'image_name', tf.string
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string),
                    extra_key: tf.io.FixedLenFeature([], extra_dtype)}

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

## Read TFRecords dataset

In [None]:
def load_dataset(filenames, labeled = True, ordered = False):
    """Build a dataset of parsed examples from a list of TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: forwarded to ``read_tfrecord`` to select the record layout.
        ordered: when false, deterministic ordering is switched off so that
            records can be consumed as soon as they stream in.
    """
    options = tf.data.Options()
    if not ordered:
        # trade reproducible ordering for throughput
        options.experimental_deterministic = False

    parse_example = partial(read_tfrecord, labeled = labeled)

    # num_parallel_reads makes the reader pull from several files at once
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTOTUNE)
    return records.with_options(options).map(parse_example, num_parallel_calls = AUTOTUNE)

## Train validation split

In [None]:
# Split at file level (not example level): 20% of the training TFRecord files
# are held out for validation. random_state is fixed, so the split repeats as
# long as the glob returns the files in the same order.
TRAIN_FILENAMES, VAL_FILENAMES = train_test_split(
    tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/ld_train*.tfrec'),
    test_size = 0.2, random_state = 5)

# All test TFRecord files; these are read without labels later on.
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/ld_test*.tfrec')

## Data Augmentation

In [None]:
def data_augment(image, label):
    """Apply a random left/right flip to the image; the label passes through unchanged."""
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

## Training data pipeline

In [None]:
def get_training_dataset():
    """Return the training pipeline: augmented, repeated, shuffled, batched and prefetched."""
    examples = load_dataset(TRAIN_FILENAMES, labeled = True)
    augmented = examples.map(data_augment, num_parallel_calls = AUTOTUNE)
    # repeat() makes the stream endless, so callers must bound an epoch themselves
    endless = augmented.repeat()
    shuffled = endless.shuffle(2048)
    batched = shuffled.batch(BATCH_SIZE)
    return batched.prefetch(AUTOTUNE)

## Validation data pipeline

In [None]:
def get_validation_dataset(ordered = False):
    """Return the validation pipeline: labeled examples, batched, cached and prefetched.

    Args:
        ordered: forwarded to ``load_dataset`` to keep deterministic ordering.
    """
    examples = load_dataset(VAL_FILENAMES, labeled = True, ordered = ordered)
    batched = examples.batch(BATCH_SIZE)
    # cache after batching so later passes reuse already-built batches
    cached = batched.cache()
    return cached.prefetch(AUTOTUNE)

## Test data pipeline

In [None]:
def get_test_dataset(ordered = False):
    """Return the test pipeline: unlabeled ``(image, image_name)`` examples, batched and prefetched.

    Args:
        ordered: forwarded to ``load_dataset`` to keep deterministic ordering.
    """
    examples = load_dataset(TEST_FILENAMES, labeled = False, ordered = ordered)
    batched = examples.batch(BATCH_SIZE)
    return batched.prefetch(AUTOTUNE)

In [None]:
def count_data_items(filenames):
    """Return the total example count encoded in a list of TFRecord file names.

    Each name is expected to contain its example count as a run of digits
    between a hyphen and a dot (a name ending in ``-1338.tfrec`` contributes
    1338). The first such run found in the name is used.

    Args:
        filenames: iterable of file name / path strings.

    Returns:
        The sum of all counts as a NumPy int64 (0 for an empty list).

    Raises:
        ValueError: if a name contains no ``-<digits>.`` segment.
    """
    # Compile once instead of once per file; `+` (not `*`) so that a bare
    # "-." in a path cannot match with an empty count.
    count_pattern = re.compile(r"-([0-9]+)\.")

    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                "Cannot read the example count from file name: {!r}".format(filename))
        counts.append(int(match.group(1)))

    # An explicit integer dtype keeps the result integral even for an empty
    # list, where np.sum would otherwise return the float 0.0.
    return np.sum(counts, dtype = np.int64)

In [None]:
# Derive the number of training, validation and test examples from the counts
# embedded in the TFRecord file names (no data is read here).
NUM_TRAINING_IMAGES = count_data_items(TRAIN_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VAL_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)

print('Dataset: {} training images, {} validation images, {} (unlabeled) test images'.format(
    NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

## Check the dataset output shapes

In [None]:
# Pull three batches from each pipeline and print the shapes of both components.
print("Training data shapes:")
for image, label in get_training_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print('-------')

print("Validation data shapes:")
for image, label in get_validation_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print('------')

print("Test data shapes:")
for image, idnum in get_test_dataset().take(3):
    print(image.numpy().shape, idnum.numpy().shape)
# idnum still refers to the last test batch iterated above, so only that
# batch's ids are printed here.
print("Test data IDs:", idnum.numpy().astype('U')) # U = unicode string

## Display a few images

In [None]:
# numpy print defaults: threshold of 15 elements and a line width of 80 characters
np.set_printoptions(threshold = 15, linewidth = 80)

def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` batch into NumPy images plus usable labels.

    Both elements of ``data`` must expose ``.numpy()``. When the second element
    converts to an object-dtype array (byte-string image ids, as in the test
    data), there are no real labels, so a list with one ``None`` per image is
    returned in its place.
    """
    image_batch, label_batch = data
    np_images = image_batch.numpy()
    np_labels = label_batch.numpy()

    # Numeric labels are passed back as they are.
    if np_labels.dtype != object:
        return np_images, np_labels
    return np_images, [None] * len(np_images)


def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and say whether it was correct.

    Args:
        label: predicted class index.
        correct_label: true class index, or ``None`` when it is unknown.

    Returns:
        ``(title, correct)``. Without a true label the title is just the
        predicted class and ``correct`` is ``True``; otherwise the title is
        tagged ``OK`` or ``NO`` followed by an arrow and the true class.
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected_name = 'OK', '', ''
    else:
        verdict, arrow, expected_name = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected_name), correct


def display_one_plant(image, title, subplot, red = False, titlesize = 16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image data accepted by ``plt.imshow``.
        title: title text; an empty title draws none.
        subplot: ``(rows, cols, index)`` triple for ``plt.subplot``.
        red: draw the title in red at a slightly smaller size.
        titlesize: base font size from which the title size and padding derive.

    Returns:
        ``(rows, cols, index + 1)``, the slot for the next image.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)

    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize / 1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(title, fontsize = font_size, color = font_color,
                  fontdict = {'verticalalignment':'center'}, pad = int(titlesize / 1.5))

    n_rows, n_cols, slot = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, slot + 1)


def display_batch_of_images(databatch, predictions = None):
    """Plot a batch of images in a square-ish grid, titled with labels or predictions.

    This will work with:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    display_batch_of_images((images, ids))               # unlabeled batch
    display_batch_of_images((images, ids), predictions)

    Args:
        databatch: a two-element batch whose elements expose ``.numpy()``;
            the second element holds either integer labels or image ids.
        predictions: optional sequence of predicted class indices, one per
            image. When given, each title shows the prediction and, if the
            true label is known, whether it was right.

    Images that do not fit into the ``rows x cols`` grid are dropped. An empty
    batch draws nothing.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)

    # Nothing to draw; also avoids a division by zero below.
    # (len() rather than truthiness because images is a NumPy array.)
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows
    shown = rows * cols

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize = (FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize = (FIGSIZE / rows * cols, FIGSIZE))

    # magic formula tested to work from 1x1 to 10x10 images
    # (depends only on the grid size, so it is computed once, outside the loop)
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3

    # Titles are drawn when predictions are given or any shown image has a label.
    has_titles = predictions is not None

    # display
    for i, (image, label) in enumerate(zip(images[:shown], labels[:shown])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if label is not None:
            has_titles = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_plant(image, title, subplot, not correct, titlesize = dynamic_titlesize)

    # layout: leave room between the images only when there are titles to show
    plt.tight_layout()
    if has_titles:
        plt.subplots_adjust(wspace = SPACING, hspace = SPACING)
    else:
        plt.subplots_adjust(wspace = 0, hspace = 0)
    plt.show()

In [None]:
# load the training dataset and regroup it into batches of 20 images for display
training_dataset = get_training_dataset()
training_dataset = training_dataset.unbatch().batch(20)
# keep one iterator so that every next() call yields a fresh batch
train_batch = iter(training_dataset)

In [None]:
# run this cell again for the next (shuffled) batch of 20 training images
display_batch_of_images(next(train_batch))

In [None]:
# load the validation dataset and regroup it into batches of 20 images for display
validation_dataset = get_validation_dataset()
validation_dataset = validation_dataset.unbatch().batch(20)
# keep one iterator so that every next() call yields a fresh batch
valid_batch = iter(validation_dataset)

In [None]:
# run this cell again for the next batch of 20 validation images
# (the validation pipeline is not shuffled; it is only read without a fixed order)
display_batch_of_images(next(valid_batch))

## Learning rate scheduler

In [None]:
# Learning Rate Schedule for Fine Tuning 
def exponential_lr(epoch,
                   start_lr = 0.00001, min_lr = 0.00001, max_lr = 0.00005,
                   rampup_epochs = 5, sustain_epochs = 0,
                   exp_decay = 0.8):
    """Return the learning rate for ``epoch`` in a ramp-up / sustain / decay schedule.

    Args:
        epoch: zero-based epoch index.
        start_lr: learning rate at epoch 0.
        min_lr: value the decay phase converges towards.
        max_lr: peak learning rate, reached at the end of the ramp-up.
        rampup_epochs: number of epochs of linear increase from ``start_lr``.
        sustain_epochs: number of epochs held at ``max_lr`` after the ramp-up.
        exp_decay: per-epoch multiplier applied to ``max_lr - min_lr`` once the
            sustain phase is over.

    Returns:
        The learning rate as a float.
    """
    # Phase 1: linear increase from start_lr towards max_lr.
    if epoch < rampup_epochs:
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr

    # Phase 2: hold at max_lr.
    sustain_end = rampup_epochs + sustain_epochs
    if epoch < sustain_end:
        return max_lr

    # Phase 3: exponential decay towards min_lr.
    epochs_into_decay = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay ** epochs_into_decay + min_lr

# Keras invokes the schedule as schedule(epoch, current_lr). Passing
# exponential_lr directly would bind the optimizer's current learning rate to
# its second positional parameter (start_lr), so the schedule used in training
# would differ from the one plotted below. The wrapper accepts the second
# argument and ignores it, so only the epoch index drives the schedule.
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch, _current_lr = None: exponential_lr(epoch), verbose = True)

# Plot the schedule over the epochs that will be trained.
rng = list(range(EPOCHS))
y = [exponential_lr(x) for x in rng]
plt.plot(rng, y)

print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

## Build model


In [None]:
# Build and compile inside the strategy scope so the model's variables are
# created under the selected distribution strategy.
with strategy.scope():
    # ImageNet-pretrained DenseNet201 backbone without its classification head.
    # The input shape is derived from IMAGE_SIZE so that it cannot drift from
    # the shape decode_image produces.
    base_model = tf.keras.applications.DenseNet201(weights = 'imagenet', include_top = False,
                                                   input_shape = (*IMAGE_SIZE, 3))
    # The whole backbone is fine-tuned, not just the new head.
    base_model.trainable = True

    # New head: global average pooling -> dropout -> one softmax unit per class.
    pooled_out = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    out = tf.keras.layers.Dropout(0.5)(pooled_out)
    out = tf.keras.layers.Dense(len(CLASSES), activation = 'softmax')(out)

    model = tf.keras.Model(base_model.input, out)

    # Labels are integer class ids, hence the sparse loss and metric.
    model.compile(
        optimizer = tf.keras.optimizers.Adam(),
        loss = 'sparse_categorical_crossentropy',
        metrics = ['sparse_categorical_accuracy'])

## Train the model

In [None]:
# build the training and validation pipelines that feed the DenseNet201 model
train_dataset = get_training_dataset()
valid_dataset = get_validation_dataset()

# The training dataset repeats endlessly, so an epoch is bounded by a step count.
# Floor division: a final partial batch is not counted in either step count.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALID_STEPS = NUM_VALIDATION_IMAGES // BATCH_SIZE

In [None]:
# train the model for EPOCHS epochs of STEPS_PER_EPOCH batches each, validating on
# VALID_STEPS batches; lr_callback is passed as a callback for the learning rate
history = model.fit(train_dataset, steps_per_epoch = STEPS_PER_EPOCH, 
                    epochs = EPOCHS, validation_data = valid_dataset, 
                    validation_steps = VALID_STEPS, callbacks = [lr_callback])

## Save the model for inference

In [None]:
# write the trained model to model.h5 in the working directory
model.save('model.h5')

## Evaluate model

In [None]:
# list the names of the values recorded in the training history
print(history.history.keys())

In [None]:
# create learning curves to evaluate model performance:
# one plot for training vs. validation loss, one for training vs. validation accuracy
history_frame = pd.DataFrame(history.history)
history_frame.loc[:, ['loss', 'val_loss']].plot()
# the trailing semicolon suppresses the notebook's echo of the returned plot object
history_frame.loc[:, ['sparse_categorical_accuracy', 'val_sparse_categorical_accuracy']].plot();

## Make predictions on test data

In [None]:
# cast the image of an (image, label) pair to float32; the label passes through unchanged
def to_float32(image, label):
    """Return ``(image cast to float32, label)``."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, label

In [None]:
# ordered = True keeps a deterministic record order, so these predictions line
# up with the image ids that are read from the same dataset afterwards.
test_ds = get_test_dataset(ordered = True)
test_ds = test_ds.map(to_float32)

print('Computing predictions...')
# Keep only the image component; the model takes no ids.
# (A redundant `test_images_ds = test_ds` that was immediately overwritten
# has been removed.)
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = model.predict(test_images_ds)
# Most probable class index per image.
predictions = np.argmax(probabilities, axis = -1)
print(predictions)

In [None]:
print('Generating submission.csv file...')
# Re-read the ids from the same ordered test dataset that produced the predictions.
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
# Write one "<id>,<predicted class>" row per test image below a header line.
# NOTE(review): confirm the 'id,label' header matches the column names of the
# competition's sample submission file.
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt = ['%s', '%d'], delimiter =',', 
           header = 'id,label', comments = '')
!head submission.csv