# Hello Kaggler!

In this notebook, we will try to use **TPU** (Tensor Processing Unit) provided by **Google** to train a Convolutional Neural Network on the dataset provided by **Plant Pathology 2021** competition.

The following are the steps we will do here.

1. Create and format a list of image names and labels ready to be used by the input pipeline. Each label is a vector of **1**s and **0**s over the 5 labels other than 'healthy', i.e. an image is considered 'healthy' if none of the other 5 classes is **1**.
2. Split it into training and validation set.
3. Connect to the **TPU** cluster consisting of 8 cores and obtain the **GCS** path for the **Plant Pathology 2021** dataset.
4. Develop the **TPU** strategy.
5. Create an optimized input pipeline using the **TensorFlow** ***tf.data.Dataset*** API.
6. Map preprocessing and augmentation onto the input pipeline.
7. Further optimize the pipeline by incorporating **Caching**, **Prefetching**, and **Mapping Parallelism**.
8. Define the model under the scope of *strategy*.
9. Create a custom training loop for *forward inference* and *backpropagation* using the **Gradient Tape** API provided by **TensorFlow**.
10. Train the network.

Check out the other notebook for GPU [TensorFlow-Custom-Distributed-Training-GPU](https://www.kaggle.com/mohammadasimbluemoon/tensorflow-custom-distributed-training-gpu)

1. **To start, we import all the required libraries.**

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
from matplotlib import image
from matplotlib import pyplot
import os
import cv2
import random
import concurrent.futures
import time
import sklearn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from keras import backend as K
from kaggle_secrets import UserSecretsClient
from kaggle_datasets import KaggleDatasets
from PIL import Image

print(tf.__version__)

2. **In the following code, the training list of names and labels is shown. We can see there are 12 different kinds of combinations. This is a multi-label problem, because the unique labels are the first 6 labels, and the rest are combinations of them.**


    label_names=['healthy', 'scab', 'frog_eye_leaf_spot', 'powdery_mildew', 'rust', 'complex']

In [None]:
# Load the competition's training manifest; every column is read as a string.
train = pd.read_csv("../input/plant-pathology-2021-fgvc8/train.csv", dtype=str)

# Tally each distinct label string once, then show the tally as text and as a bar chart.
label_counts = train['labels'].value_counts()
print(label_counts)
print(label_counts.plot.bar())

# Number of non-null label entries, followed by a preview of the table.
print(train['labels'].count())
train.head()

In [None]:

# The five classes that get a slot in the target vector. 'healthy' has no slot:
# it is represented by the all-zero vector.
label_names = ['scab', 'frog_eye_leaf_spot', 'powdery_mildew', 'rust', 'complex']


def _encode_labels(label_string):
    """Turn a space-separated label string into a multi-hot vector over `label_names`."""
    encoded = np.zeros(len(label_names))
    for token in label_string.split():
        if token == 'healthy':
            continue
        encoded[label_names.index(token)] = 1
    return encoded


# Parallel lists: image file names and their encoded targets, in CSV row order.
names = list(train['image'])
labels = [_encode_labels(label_string) for label_string in train['labels']]
def myfunc():
    """Return the constant stand-in for a random number (0.2) that drives the shuffle below."""
    return 0.2


def _shuffle_with(items, random_fn):
    """Shuffle `items` in place using `random_fn` as the source of numbers in [0, 1).

    This is the Fisher-Yates loop that `random.shuffle(x, random)` used to run.
    The two-argument form of `random.shuffle` was removed in Python 3.11, so the
    loop is spelled out here to keep producing the same permutation on every
    Python version.
    """
    for i in reversed(range(1, len(items))):
        j = int(random_fn() * (i + 1))
        items[i], items[j] = items[j], items[i]


# Reorder names and labels together so each image keeps its own label vector.
c = list(zip(names, labels))
_shuffle_with(c, myfunc)
names, labels = zip(*c)

# Hold out a fraction of the examples for validation. The split is seeded and
# stratified on the label vectors.
VAL_SPLIT = 0.2
split_parts = train_test_split(
    names,
    labels,
    test_size=VAL_SPLIT,
    random_state=42,
    stratify=labels,
)
train_names, val_names, train_labels, val_labels = split_parts

# Optional further 50% stratified subsample of the training split (disabled):
# train_names, _, train_labels, _ = train_test_split(train_names, train_labels, \
#                                                    test_size=0.5, random_state=42,\
#                                                    stratify=train_labels)

train_names, val_names = list(train_names), list(val_names)
print("Length of training set: ", len(train_names))
print("Length of validation set: ", len(val_names))

In [None]:
# Fetch the notebook user's Google Cloud credential from Kaggle secrets and
# register it with TensorFlow.
# NOTE(review): presumably required so TensorFlow can read the dataset bucket — confirm.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

# Look up the Google Cloud Storage location of the competition dataset; the
# image-loading functions build their file paths from this prefix.
GCS_PATH = KaggleDatasets().get_gcs_path('plant-pathology-2021-fgvc8')
print(GCS_PATH)

In [None]:
# Resolve the TPU cluster, connect this runtime to it and initialise the TPU
# system. The three calls are order-dependent: each one takes the resolver
# created on the first line.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# List the logical TPU devices that are now visible.
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [None]:
# Distribution strategy built on the TPU cluster resolved earlier.
strategy = tf.distribute.TPUStrategy(resolver)

# Global batch size = 32 examples per replica, identical for both splits.
TRAIN_BATCH_SIZE = strategy.num_replicas_in_sync * 32
VAL_BATCH_SIZE = strategy.num_replicas_in_sync * 32

# Shuffle-buffer sizes (in examples) used by the input pipelines.
TRAIN_SHUFFLE_BUFFER, VAL_SHUFFLE_BUFFER = 6144, 3584

In [None]:
# Seed shared by all random augmentation layers.
SEED = 10000

# Augmentation layers applied to training images.
# NOTE(review): Keras documents the RandomRotation factor as a fraction of
# 2*pi rather than as radians — confirm that 3.142/2 is the intended range.
_preprocessing = tf.keras.layers.experimental.preprocessing
random_rotation = _preprocessing.RandomRotation(3.142/2, seed=SEED)
random_flip = _preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED)
random_zoom = _preprocessing.RandomZoom((-0.1, 0.35), seed=SEED)
random_translate = _preprocessing.RandomTranslation((-0.2, 0.2), (-0.2, 0.2), seed=SEED)
random_contrast = _preprocessing.RandomContrast((0.2, 1.5), seed=SEED)

# Side length (pixels) every image is resized to, and the number of colour channels.
IMSIZE = 512
CHANNEL = 3

def _decode_image(name):
    """Read one JPEG from the dataset's train_images folder and resize it.

    Args:
        name: scalar string tensor holding the image file name.

    Returns:
        Float tensor of shape (IMSIZE, IMSIZE, CHANNEL) with values on the
        0-255 scale (not yet normalised).
    """
    image_string = tf.io.read_file(GCS_PATH + '/train_images/' + name)
    # Force a fixed channel count so files that are not 3-channel JPEGs still
    # match the fixed-shape reshape and batching that follow.
    image_decoded = tf.image.decode_jpeg(image_string, channels=CHANNEL)
    return tf.image.resize(image_decoded, [IMSIZE, IMSIZE])


def _augment_image(image):
    """Apply the random rotation/flip/zoom/translation/contrast layers to one image."""
    # The preprocessing layers work on batches, so add and later remove a
    # leading batch dimension of size 1.
    imgs = tf.reshape(image, (1, IMSIZE, IMSIZE, CHANNEL))
    for layer in (random_rotation, random_flip, random_zoom,
                  random_translate, random_contrast):
        imgs = layer.call(imgs)
    return tf.reshape(imgs, (IMSIZE, IMSIZE, CHANNEL))


def _load_example(name, label):
    """Deterministic stage of the training pipeline: decode and resize only."""
    return _decode_image(name), label


def _augment_example(image, label):
    """Random stage of the training pipeline: augment, then scale to [0, 1]."""
    return _augment_image(image)/255, label


def _parse_train(name, label):
    """Load, augment and normalise one training example (file name -> image, label)."""
    return _augment_example(*_load_example(name, label))


def _parse_val(name, label):
    """Load and normalise one validation example; no augmentation is applied."""
    return _decode_image(name)/255, label


# Training pipeline. Stage order matters:
#   1. decode/resize    - deterministic, so its output is safe to cache;
#   2. cache            - later epochs skip reading and decoding;
#   3. shuffle          - placed after the cache so the order is redrawn each epoch;
#   4. augment + scale  - placed after the cache so each epoch sees fresh
#                         random transforms rather than the first epoch's;
#   5. batch + prefetch - overlap input preparation with the training step.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((tf.constant(train_names), tf.constant(train_labels)))
    .map(_load_example, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
    .cache()
    .shuffle(TRAIN_SHUFFLE_BUFFER)
    .map(_augment_example, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
    .batch(TRAIN_BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline. There is nothing random after decoding, so the fully
# preprocessed examples are cached as-is. The shuffle runs only while the cache
# is being filled.
val_dataset = (
    tf.data.Dataset.from_tensor_slices((tf.constant(val_names), tf.constant(val_labels)))
    .map(_parse_val, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
    .shuffle(VAL_SHUFFLE_BUFFER)
    .cache()
    .batch(VAL_BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# def IResNet_brain_module(inputx, layercount):
#     n = layercount
#     ############################################################################
#     # Parallel Block 1
#     x1_1 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x1_2 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x1_1)
#     x1 = tf.keras.layers.add([x1_1, x1_2])
#     ############################################################################
#     # Parallel Block 2
#     x2_1 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x2_1_1 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x2_1)
#     x2_2 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x2_1)
#     x2 = tf.keras.layers.add([x2_1, x2_1_1, x2_2])
#     x2 = tf.keras.layers.add([x2, x1])
#     ############################################################################
#     # Parallel Block 3
#     x3_1 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x3_2 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x3 = tf.keras.layers.add([x3_1, x3_2])
#     x3 = tf.keras.layers.add([x3, x2, x1])
#     ############################################################################
#     # Parallel Block 4
#     x4_1 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x4_1_1 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x4_1)
#     x4_1_2 = tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x4_1_1)
#     x4_2 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x4_1)
#     x4 = tf.keras.layers.add([x4_1, x4_1_1, x4_1_2, x4_2])
#     x4 = tf.keras.layers.add([x4, x3, x2, x1])
#     ############################################################################
#     mod = tf.keras.layers.concatenate([x1, x2, x3, x4], axis = -1)
#     mod = tf.keras.layers.BatchNormalization()(mod)
#     return mod
# def IResNet_connection_module(inputx, layercount):
#     n = layercount
#     ############################################################################
#     # Parallel Block 1
#     x1 = tf.keras.layers.Conv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x2 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x1)
#     x3 = tf.keras.layers.add([inoutx, x2])

#     ############################################################################
#     # Parallel Block 2
#     x4 = tf.keras.layers.Conv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x5 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x4)
#     x6 = tf.keras.layers.add([inputx, x5])

#     x7 = tf.keras.layers.Conv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x6)
#     x8 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x7)
#     x9 = tf.keras.layers.add([x6, x8])

#     ############################################################################
#     # Parallel Block 3
#     x10 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     ############################################################################
#     mod = tf.keras.layers.concatenate([x3, x9, x10, inputx], axis = -1)
#     mod = tf.keras.layers.BatchNormalization()(mod)
#     return mod

# def IResNet_reduction_module(inputx, layercount):
#     n = layercount
#     ############################################################################
#     # Reduction module
#     R1= tf.keras.layers.Conv2D(n, (3, 3), activation = 'relu', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     R1 = tf.keras.layers.Conv2D(n, (3, 3), strides = 2, activation = 'relu', kernel_regularizer='l2', bias_regularizer='l1')(R1)
#     R1 = tf.keras.layers.BatchNormalization()(R1) 
#     R1 = tf.keras.layers.Conv2D(n, (3, 3), strides = 2, activation = 'relu', kernel_regularizer='l2', bias_regularizer='l1')(R1)
#     mod = tf.keras.layers.BatchNormalization()(R1)
#     R1 = tf.keras.layers.Conv2D(n, (3, 3), strides = 2, activation = 'relu', kernel_regularizer='l2', bias_regularizer='l1')(R1)
#     mod = tf.keras.layers.BatchNormalization()(R1)
#     return mod
# def IResNet_classifier_module(input, num_classes, activation):
#     ############################################################################
#     # Classifier module
#     R1= tf.keras.layers.Flatten()(input)
#     mod = tf.keras.layers.Dense(num_classes, activation=activation)(R1)
#     return mod

In [None]:
# def IResNet_connection_module(inputx, layercount):
#     n = layercount
#     ############################################################################
#     # Parallel Block 1
#     x1 = tf.keras.layers.SeparableConv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x2 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x1)
#     x3 = tf.keras.layers.add([x1, x2])

#     ############################################################################
#     # Parallel Block 2
#     x4 = tf.keras.layers.SeparableConv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     x5 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x4)
#     x6 = tf.keras.layers.add([x4, x5])

#     x7 = tf.keras.layers.SeparableConv2D(n, (1, 3), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x6)
#     x8 = tf.keras.layers.Conv2D(n, (3, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(x7)
#     x9 = tf.keras.layers.add([x7, x8, x6])

#     ############################################################################
#     # Parallel Block 3
#     x10 = tf.keras.layers.Conv2D(n, (1, 1), activation = 'relu', padding = 'same', kernel_regularizer='l2', bias_regularizer='l1')(inputx)
#     ############################################################################
#     mod = tf.keras.layers.concatenate([x3, x9, x10, inputx], axis = -1)
#     mod = tf.keras.layers.BatchNormalization()(mod)
#     return mod

In [None]:
# with strategy.scope():
#     inp = tf.keras.layers.Input(shape=(IMSIZE, IMSIZE, CHANNEL))
#     connect = IResNet_connection_module(inp, 8)
#     connect = IResNet_connection_module(connect, 8)
#     red = IResNet_reduction_module(connect, 64)
#     connect = IResNet_connection_module(red, 16)
#     connect = IResNet_connection_module(red, 32)
#     red = IResNet_reduction_module(connect, 128)
#     cla = IResNet_classifier_module(red, 5, 'sigmoid')
#     model = tf.keras.models.Model(inputs=inp, outputs=cla, name = "IResNetv1")
#     model.summary()
#     tf.keras.utils.plot_model(model, show_shapes=True,to_file='./img.png')


In [None]:
# Everything that owns variables (model, optimizer, metrics) is created inside
# the strategy scope.
with strategy.scope():
    # Xception backbone without its classification head, with global max
    # pooling, initialised from ImageNet weights and left fully trainable.
    base_model = tf.keras.applications.Xception(include_top=False,
                                                weights='imagenet', pooling='max')
    base_model.trainable = True

    # One independent sigmoid output per class (multi-label target).
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(5, activation='sigmoid')
    ])
    model.summary()
    # Resume from a previously saved checkpoint of this architecture.
    model.load_weights('../input/plantpathology2021kerasmodelsxception/xception-best-loss-aug.h5')

    optimizer = tf.keras.optimizers.SGD(0.001)

    # The model already applies a sigmoid, so its outputs are probabilities.
    # from_logits must therefore be False; with True the loss would apply a
    # second sigmoid. Reduction.NONE keeps per-example losses so the training
    # code can average over the global batch itself.
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction=tf.keras.losses.Reduction.NONE)

    # Separate metric objects for the training and validation passes.
    epoch_auc = tf.keras.metrics.AUC(num_thresholds=200, multi_label=True)
    val_epoch_auc = tf.keras.metrics.AUC(num_thresholds=200, multi_label=True)
    val_loss = tf.keras.metrics.Mean()
    acc = tf.keras.metrics.BinaryAccuracy()
    val_acc = tf.keras.metrics.BinaryAccuracy()
    f1 = tfa.metrics.F1Score(num_classes=5, average='weighted', threshold=0.5)
    val_f1 = tfa.metrics.F1Score(num_classes=5, average='weighted', threshold=0.5)

# One independent, initially empty list per tracked quantity. They are filled
# by the training code and read by the plotting cell.
(train_loss_history, train_acc_history, train_f1_history, train_auc_history,
 val_loss_history, val_acc_history, val_f1_history, val_auc_history,
 lr_list) = ([] for _ in range(9))

# Wrap both input pipelines so the strategy can distribute their batches.
dist_train_dataset, dist_val_dataset = (
    strategy.experimental_distribute_dataset(dataset)
    for dataset in (train_dataset, val_dataset)
)

In [None]:
with strategy.scope():
    def compute_loss(labels, predictions):
        """Return this replica's share of the loss, averaged over the global batch.

        The per-example losses from `loss_object` are summed and divided by
        TRAIN_BATCH_SIZE (the global batch size) rather than by the
        per-replica batch size.
        """
        return tf.nn.compute_average_loss(
            loss_object(labels, predictions),
            global_batch_size=TRAIN_BATCH_SIZE,
        )
    
def train_step(inputs):
    """Run one optimisation step on a single replica.

    Args:
        inputs: `(images, labels)` pair for this replica.

    Returns:
        The replica's loss contribution as computed by `compute_loss`.
    """
    images, labels = inputs
    with tf.GradientTape() as tape:
        # training=True puts mode-dependent layers (e.g. batch normalisation)
        # into training behaviour while the weights are being optimised.
        # The model ends in a sigmoid, so these are probabilities despite the name.
        logits = model(images, training=True)
        loss_value = compute_loss(labels, logits)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    epoch_auc.update_state(labels, logits)
    acc.update_state(labels, logits)
    f1.update_state(labels, logits)
    # NOTE(review): this function is called through a @tf.function, so this
    # Python-level append presumably runs only while the graph is traced and
    # stores symbolic tensors, not per-step values. Anything reading
    # train_loss_history has to skip those non-numeric entries.
    train_loss_history.append(loss_value)
    return loss_value

@tf.function
def distributed_train_step(dist_inputs):
    """Run `train_step` on every replica and return the summed loss."""
    replica_losses = strategy.run(train_step, args=(dist_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, replica_losses, axis=None)

def val_step(inputs):
    """Evaluate one replica batch and fold the results into the validation metrics.

    Nothing is returned; the outcome lives in `val_f1`, `val_loss`, `val_acc`
    and `val_epoch_auc`.
    """
    images, labels = inputs
    predictions = model(images)
    per_example_loss = loss_object(labels, predictions)

    val_f1.update_state(labels, predictions)
    val_loss.update_state(per_example_loss)
    val_acc.update_state(labels, predictions)
    val_epoch_auc.update_state(labels, predictions)
    
@tf.function
def distributed_val_step(dist_inputs):
    """Run `val_step` on every replica; results accumulate in the validation metrics."""
    strategy.run(val_step, args=(dist_inputs,))
    
def _progress_bar(percent, width=20):
    """Render `percent` (0-100) as a fixed-width text bar such as '---->     '."""
    arrow = '-' * int(percent/100 * width - 1) + '>'
    return arrow + ' ' * (width - len(arrow))


def train(epochs, modelname, verbose=1, PATIENCE=4, DECAY=0.9):
    """Run the custom distributed training and validation loop.

    Each epoch makes one pass over `dist_train_dataset` and one over
    `dist_val_dataset`, records the epoch's numbers in the module-level history
    lists, saves the model whenever a validation metric reaches a new best, and
    lowers the learning rate when validation loss stops improving.

    Args:
        epochs: number of epochs to run.
        modelname: prefix for the saved `.h5` files.
        verbose: truthy to print a live progress bar and timings; falsy to
            print only the mean training loss per epoch.
        PATIENCE: number of consecutive epochs without a new best validation
            loss after which the learning rate is multiplied by `DECAY`.
        DECAY: multiplicative learning-rate factor applied on a plateau.

    Raises:
        ValueError: if the training or validation dataset yields no batches.
    """
    # Batch counts are constant for the whole run; look them up once.
    total_train_batches = len(train_dataset)
    total_val_batches = len(val_dataset)
    # Consecutive epochs without a new best validation loss.
    stale_epochs = 0

    ########################## Epoch Loop ##########################
    for epoch in range(epochs):
        lr_list.append(optimizer.learning_rate.numpy())
        lr_reduced = False
        start = time.time()
        print ('\nEpoch {}/{} '.format(epoch+1, epochs))

        ####################### Train Loop #########################
        num_batches = 0
        loss = 0.0
        for data in dist_train_dataset:
            loss += distributed_train_step(data)
            num_batches += 1
            auc = epoch_auc.result()
            accuracy = acc.result()
            f1score = f1.result()
            if(verbose):
                percent = float(num_batches) * 100 / total_train_batches
                print('\rTraining: [%d/%d] [%s] %d %% - Training Loss: %f - Training AUC: %f - Training ACC: %f - Training F1: %f'% (num_batches, total_train_batches, _progress_bar(percent), percent, loss/num_batches, auc, accuracy, f1score), end='', flush=True)

        if num_batches == 0:
            raise ValueError("The training dataset produced no batches; check the batch size and drop_remainder.")

        train_loss_history.append(loss.numpy()/num_batches)
        train_acc_history.append(accuracy.numpy())
        train_f1_history.append(f1score.numpy())
        train_auc_history.append(auc.numpy())
        if(not verbose):
            print(' Epoch Loss: ', loss/num_batches)
        if(verbose):
            print(" -", int(time.time()-start), "s", end="")
            print()
        start = time.time()

        ####################### Validation Loop #########################
        num_batches = 0
        for data in dist_val_dataset:
            num_batches += 1
            distributed_val_step(data)
            auc = val_epoch_auc.result()
            loss = val_loss.result()
            accuracy = val_acc.result()
            f1score = val_f1.result()
            if(verbose):
                percent = float(num_batches) * 100 / total_val_batches
                print('\rValidate: [%d/%d] [%s] %d %% - Validation Loss: %f - Validation AUC: %f - Validation ACC: %f - Validation F1: %f'% (num_batches, total_val_batches, _progress_bar(percent), percent, loss, auc, accuracy, f1score), end='', flush=True)

        if num_batches == 0:
            raise ValueError("The validation dataset produced no batches; check the batch size and drop_remainder.")

        epoch_val_loss = loss.numpy()
        epoch_val_acc = accuracy.numpy()
        epoch_val_f1 = f1score.numpy()

        ####################### Checkpointing #########################
        # The histories hold the previous epochs only. An empty history means
        # this is the first epoch, which counts as the best so far.
        loss_improved = (not val_loss_history) or epoch_val_loss < min(val_loss_history)
        if(loss_improved):
            tf.keras.models.save_model(model, './' + modelname + '-best-loss-aug.h5', save_format='h5', include_optimizer=True, overwrite=True)

        if((not val_acc_history) or epoch_val_acc > max(val_acc_history)):
            tf.keras.models.save_model(model, './' + modelname + '-best-acc-aug.h5', save_format='h5', include_optimizer=True, overwrite=True)

        if((not val_f1_history) or epoch_val_f1 > max(val_f1_history)):
            tf.keras.models.save_model(model, './' + modelname + '-best-f1-aug.h5', save_format='h5', include_optimizer=True, overwrite=True)

        ################### Learning-rate schedule ####################
        # Decay on a plateau: count consecutive epochs without a new best
        # validation loss, and start over after an improvement or a decay.
        if(loss_improved):
            stale_epochs = 0
        else:
            stale_epochs += 1
            if(stale_epochs >= PATIENCE):
                K.set_value(optimizer.learning_rate, optimizer.learning_rate.numpy()*DECAY)
                stale_epochs = 0
                lr_reduced = True

        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)
        val_f1_history.append(epoch_val_f1)
        val_auc_history.append(auc.numpy())
        if(verbose):
            print(" -", int(time.time()-start), "s")
            if(lr_reduced):
                print("\nLearning rate reduced to: ", optimizer.learning_rate.numpy())

        # Metrics are cumulative; clear them so the next epoch starts fresh.
        epoch_auc.reset_states()
        val_epoch_auc.reset_states()
        val_loss.reset_states()
        acc.reset_states()
        val_acc.reset_states()
        f1.reset_states()
        val_f1.reset_states()
    model.save(modelname + '-aug-final-epoch.h5')

In [None]:
# Train for 100 epochs with a live progress bar; on a validation-loss plateau
# (PATIENCE=2) the learning rate is multiplied by 0.8.
train(epochs=100, modelname='xception', verbose=1, PATIENCE=2, DECAY=0.8)

In [None]:
def _numeric_only(values):
    """Return the entries of `values` that are plain or NumPy floats.

    `train_loss_history` can also contain tensor objects appended by
    `train_step`. Filtering by type drops them however many there are, instead
    of assuming a fixed number of leading entries.
    """
    return [value for value in values if isinstance(value, (float, np.floating))]


def plot_training_history():
    """Draw loss, accuracy, F1, AUC and learning rate against epoch in a 3x2 grid."""
    # (title, y-axis label, [(series, legend label), ...]) for each panel.
    panels = [
        ('Loss Profile', 'Loss',
         [(_numeric_only(train_loss_history), "train_loss"), (val_loss_history, "val_loss")]),
        ('Accuracy', 'Accuracy',
         [(train_acc_history, "train_acc"), (val_acc_history, "val_acc")]),
        ('F1 Score', 'F1 Score',
         [(train_f1_history, "train_f1"), (val_f1_history, "val_f1")]),
        ('AUC', 'AUC',
         [(train_auc_history, "train_auc"), (val_auc_history, "val_auc")]),
        ('Learning Rate', 'learning rate',
         [(lr_list, "lr")]),
    ]

    plt.figure(figsize=(20,15))
    for position, (title, ylabel, curves) in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        for series, legend_label in curves:
            plt.plot(series, label=legend_label)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        plt.legend()

    plt.show()


plot_training_history()