# Import Necessary Libraries

In [None]:
import math, re, os
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa
import numpy as np
import pandas as pd
import json
import random
import PIL
from tqdm import tqdm
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

# Sentinel value telling tf.data to tune parallelism and prefetch buffer sizes
# at runtime; passed as num_parallel_reads / num_parallel_calls / prefetch size
# in the input pipeline below.
AUTO = tf.data.experimental.AUTOTUNE

# Setup and Detect TPU 

Please read the Kaggle documentation on using TPUs: 
[Tensor Processing Units (TPUs) Documentation | Kaggle](https://www.kaggle.com/docs/tpu)

In [None]:
# Detect the available hardware and build a matching distribution strategy.
#
# TPUClusterResolver.connect() (TF >= 2.4) locates the TPU, connects to the
# cluster and initializes the TPU system in a single call. The older sequence
# (TPUClusterResolver() + experimental_connect_to_cluster +
# initialize_tpu_system + experimental.TPUStrategy) is therefore not needed;
# running both would initialize the TPU system twice and build two strategies.
try:  # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Running on TPU ", tpu.cluster_spec().as_dict()["worker"])
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:  # no TPU found: fall back to GPU(s) / CPU
    print("Not connected to a TPU runtime. Using CPU/GPU strategy")
    strategy = tf.distribute.MirroredStrategy()  # for GPU or multi-GPU machines
    # Alternatives:
    # strategy = tf.distribute.get_strategy()  # default strategy that works on CPU and single GPU
    # strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()  # for clusters of multi-GPU machines

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Training hyper-parameters.
image_size = (512, 512)  # (height, width) of the decoded images
epochs = 12

# 16 images per replica: the global batch scales with the accelerator count.
replicas = strategy.num_replicas_in_sync
batch_size = 16 * replicas
print('Batch size:', batch_size)

# Read Data From Google Cloud Storage (GCS)

TPUs read data directly from Google Cloud Storage (GCS). This Kaggle utility will copy the dataset to a GCS bucket co-located with the TPU. If you have multiple datasets attached to the notebook, you can pass the name of a specific dataset to the get_gcs_path function. The name of the dataset is the name of the directory it is mounted in. Use !ls /kaggle/input/ to list attached datasets.

> from kaggle_datasets import KaggleDatasets

In [None]:
# Locate the attached dataset in GCS (the TPU reads data directly from there).
gcs_path = KaggleDatasets().get_gcs_path()
print(gcs_path)

train_pattern = gcs_path + '/train_tfrecords/*.tfrec'
test_pattern = gcs_path + '/test_tfrecords/*.tfrec'
training_files = tf.io.gfile.glob(train_pattern)
test_files = tf.io.gfile.glob(test_pattern)

n_train_files, n_test_files = len(training_files), len(test_files)
print('Training tfrecords: ' + str(n_train_files))
print('Test tfrecords: ' + str(n_test_files))


def count_data_items(filenames):
    """Return the total number of records stored in ``filenames``.

    Each TFRecord file name encodes its own record count in its base name as
    ``-<count>.``, e.g. ``flowers00-230.tfrec`` holds 230 items.

    Args:
        filenames: iterable of TFRecord paths (local or ``gs://``).

    Returns:
        The summed record count as a Python ``int`` (``0`` for no files), so
        it can be used directly for integer step arithmetic.

    Raises:
        ValueError: if a file name does not contain a ``-<digits>.`` count.
    """
    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Inspect only the base name so that dashes/dots in bucket or
        # directory names cannot be mistaken for the record count.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read record count from file name: {filename!r}")
        total += int(match.group(1))
    return total

# Record totals, decoded from the counts embedded in the tfrec file names.
NUM_TRAINING_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(training_files),
    count_data_items(test_files),
)

print(
    'Dataset: {} training images, {} unlabeled test images'.format(
        NUM_TRAINING_IMAGES, NUM_TEST_IMAGES
    )
)

# The train.csv & label_num_to_disease_map.json

In [None]:
# Class-id -> disease-name map. JSON object keys are always strings, so later
# lookups index it with classes[str(label)].
with open('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json', encoding='utf-8') as js:
    classes = json.load(js)
print(classes)

# One row per training image; the 'image_id' and 'label' columns are used later.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
print('Number of entries:', len(train))

label_counts = train['label'].value_counts()
print('Label Frequencies:')
# Print the counts themselves; printing the return value of .plot.bar()
# would only show the matplotlib Axes repr.
print(label_counts)
ax = label_counts.plot.bar()
ax.set_xlabel('label')
ax.set_ylabel('count')

# Functions for reading tfrecord files


**decode_image** - For converting bytestring images into arrays.

**read_labeled_tfrecord** - Returns image & label from the tfrecords.

**read_labeled_tfrecord_with_imageid** - Returns image, label & image id from the tfrecords.

**read_unlabeled_tfrecord** - Returns image & image id.

The **keys of the dictionaries** (*i.e. LABELED_TFREC_FORMAT, UNLABELED_TFREC_FORMAT*) need to match the **keys in the tfrecords**. If the keys don't match, parsing throws an **InvalidArgumentError**:
> InvalidArgumentError: Feature: (data type: string) is required but could not be found

*For training we use **read_labeled_tfrecord**, because the image ids are not needed during training.*

In [None]:
def decode_image(image_data):
    """Turn a JPEG byte string into an RGB image tensor with a static shape.

    The result is uint8 with values in [0, 255] and shape ``[*image_size, 3]``.
    TPUs need this explicit, statically known shape. This assumes the encoded
    JPEGs are already ``image_size``; tf.reshape only checks the element count.
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    return tf.reshape(rgb, [*image_size, 3])


def read_labeled_tfrecord(example):
    """Parse one serialized labeled Example into ``(image, label)``.

    ``image`` comes from ``decode_image`` and ``label`` is an int32 scalar.
    The feature keys below must match the keys written into the TFRecords;
    otherwise parsing raises ``InvalidArgumentError``.
    """
    feature_spec = {
        "target": tf.io.FixedLenFeature([], tf.int64),  # scalar class id
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG-encoded bytes
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['target'], tf.int32)


def read_labeled_tfrecord_with_imageid(example):
    """Parse one serialized labeled Example into ``(image, label, image_name)``.

    This is the same as ``read_labeled_tfrecord`` but also returns the raw
    ``image_name`` string tensor, e.g. for titling images when displaying them.
    """
    feature_spec = {
        "target": tf.io.FixedLenFeature([], tf.int64),  # scalar class id
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG-encoded bytes
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    decoded = decode_image(parsed['image'])
    class_id = tf.cast(parsed['target'], tf.int32)
    return decoded, class_id, parsed['image_name']

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled Example into ``(image, image_name)``.

    Used for the test records, which carry no ``target`` feature.
    """
    feature_spec = {
        'image_name' : tf.io.FixedLenFeature([], tf.string),
        'image' : tf.io.FixedLenFeature([], tf.string)
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['image_name']

# Visualize Some Images
Display 1 image from each tfrecord with corresponding image_id & disease. 

Display the test image.

In [None]:
# Custom function for visualization
def show_im(fig, row, col, index, path=None, image=None, title=None, title_color='white'):
    """Draw one image into cell ``index`` of a ``row`` x ``col`` grid on ``fig``.

    Args:
        fig: matplotlib Figure to draw on.
        row, col: grid dimensions passed to ``fig.add_subplot``.
        index: 1-based position of the subplot in the grid.
        path: image file to load when ``image`` is not given.
        image: anything ``Axes.imshow`` accepts (array, tensor, PIL image);
            takes precedence over ``path``.
        title: optional subplot title.
        title_color: colour of the title text.

    Raises:
        ValueError: if neither ``image`` nor ``path`` is provided.
    """
    if image is None:
        if path is None:
            raise ValueError("show_im needs either `image` or `path`")
        # Read the pixels eagerly so the file handle is closed right away;
        # PIL.Image.open is lazy and would otherwise keep the file open.
        with PIL.Image.open(path) as opened:
            image = np.asarray(opened)

    ax = fig.add_subplot(row, col, index)
    ax.set_xticks([])  # hide tick values on the X axis
    ax.set_yticks([])  # hide tick values on the Y axis
    ax.imshow(image)

    if title:
        # Title the subplot just created rather than pyplot's "current" axes,
        # which may belong to a different figure.
        ax.set_title(title, color=title_color)

    fig.tight_layout(pad=0.02)

In [None]:
# Display 1 image from each tfrecord with corresponding image_id & disease
n_files = len(training_files)
# Smallest square grid holding one image per file (4 x 4 for 16 files);
# a fixed 4 x 4 grid would make add_subplot fail for more than 16 files.
grid = max(1, math.ceil(math.sqrt(n_files)))
fig1 = plt.figure(figsize=(20,20))

for position, tfrec_file in enumerate(training_files, start=1):
    for raw_record in tf.data.TFRecordDataset(tfrec_file).take(1):
        image, label, image_name = read_labeled_tfrecord_with_imageid(raw_record)
        label = str(int(label))  # `classes` is keyed by the label as a string
        image_name = image_name.numpy().decode('utf-8')

        show_im(fig1, grid, grid, position, image=image, title=f'{image_name}/{classes[label]}')

In [None]:
# Display the test data (only 1 image): the first record of the first test file.
fig1 = plt.figure(figsize=(8, 8))

raw_dataset = tf.data.TFRecordDataset(test_files[0])
for raw_record in raw_dataset.take(1):
    image, image_name = read_unlabeled_tfrecord(raw_record)
    # image_name arrives as a scalar bytes tensor; decode it for the title.
    image_name = image_name.numpy().decode('utf-8')
    show_im(fig1, 1, 1, 1, image=image, title=image_name)

# Functions for Loading Dataset

Read from TFRecords. For optimal performance, read from multiple files at once.

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of parsed records from TFRecord ``filenames``.

    Args:
        filenames: TFRecord file path(s) to read.
        labeled: when True parse with ``read_labeled_tfrecord`` and yield
            ``(image, label)``; otherwise parse with ``read_unlabeled_tfrecord``
            and yield ``(image, image_name)``.
        ordered: when False, records are yielded as soon as they stream in
            instead of in their original order. This is faster, and order
            does not matter for data that is shuffled anyway.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parser = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    # num_parallel_reads interleaves reads across several files at once.
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .map(parser, num_parallel_calls=AUTO)
    )

def get_training_dataset():
    """Return the endless, shuffled, batched ``(image, label)`` training pipeline.

    The dataset repeats forever, so callers must bound it themselves
    (e.g. with ``steps_per_epoch`` or ``.take``).
    """
    return (
        load_dataset(training_files, labeled=True)
        .repeat()  # keep producing data for as many epochs as requested
        .shuffle(2048)
        .batch(batch_size)
        .prefetch(AUTO)  # prepare the next batch while the current one trains
    )

def get_test_dataset(ordered=False):
    """Return batched ``(image, image_name)`` pairs read from ``test_files``.

    Pass ``ordered=True`` to keep the records in deterministic order.
    """
    test_ds = load_dataset(test_files, labeled=False, ordered=ordered)
    test_ds = test_ds.batch(batch_size)
    return test_ds.prefetch(AUTO)  # prepare the next batch while the current one is consumed

In [None]:
# Sanity-check both pipelines by printing the shapes of a few batches.
print("Training data shape:")
train_sample = get_training_dataset().take(3)
for image, label in train_sample:
    image_shape, label_shape = image.numpy().shape, label.numpy().shape
    print(image_shape, label_shape)
# `label` still holds the last batch seen in the loop above.
print("Training data label examples:", label.numpy())

print("Test data shape:")
test_sample = get_test_dataset().take(3)
for image, image_name in test_sample:
    print(image.numpy().shape, image_name.numpy().shape)
print("Test data IDs:", image_name.numpy().astype('U'))  # U=unicode string

# Define & Compile Model within strategy.scope():

In [None]:
seed = 1200

# Model variables must be created inside the strategy scope so they are
# placed/replicated on the accelerators.
with strategy.scope():
    # NOTE(review): no input normalization is applied. decode_image yields
    # uint8 pixels in [0, 255], while the ImageNet weights of InceptionResNetV2
    # are meant to be used with
    # keras.applications.inception_resnet_v2.preprocess_input. Confirm that
    # skipping it is intentional.
    pretrained = keras.applications.InceptionResNetV2(
        include_top=False, weights="imagenet", input_shape=(*image_size, 3)
    )
    pretrained.trainable = True  # fine-tune the whole backbone

    model = keras.Sequential([
        pretrained,
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(
            len(classes),
            kernel_initializer=keras.initializers.RandomUniform(seed=seed),
            bias_initializer=keras.initializers.Zeros(),
            name='dense_top',
            activation='softmax',
        ),
    ])

    model.compile(
        # Integer class ids + softmax probabilities -> sparse categorical loss.
        loss=keras.losses.SparseCategoricalCrossentropy(),
        # `learning_rate` is the supported argument name; `lr` is a deprecated
        # alias that newer Keras versions reject.
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        metrics=['sparse_categorical_accuracy'],
        steps_per_execution=16,  # batches run per tf.function call
    )

# summary() prints itself and returns None, so it is not wrapped in print().
model.summary()

# Training

In [None]:
# The training dataset repeats forever, so fit() needs an explicit epoch
# length: the number of full batches in the training set.
steps_per_epoch = NUM_TRAINING_IMAGES // batch_size
print(steps_per_epoch)

train_ds = get_training_dataset()
model.fit(train_ds, steps_per_epoch=steps_per_epoch, epochs=epochs)

# Let's Validate on Training Data

We validate on 10 batches of 1000 images each; the resulting accuracy for each batch is printed below.

In [None]:
# "Validation" on batches drawn from the shuffled, repeating training pipeline.
# These images were seen during training, so the accuracies are optimistic.
val_batch_size = 1000
num_val_batches = 10

val_dataset = get_training_dataset().unbatch().batch(val_batch_size)

for i, (val_images, val_labels) in enumerate(val_dataset.take(num_val_batches), start=1):
    probabilities = model.predict(val_images)
    predictions = np.argmax(probabilities, axis=-1)

    # Vectorized comparison: materialize the labels once instead of calling
    # .numpy() for every element. Divide by the actual batch length rather
    # than a hard-coded 1000.
    accuracy = np.mean(predictions == val_labels.numpy()) * 100
    print(f'Validation accuracy of batch {i}: ', accuracy)

# These are great scores! 
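However, there is a catch ;) — these batches were drawn from the same training data the model was fit on, so the scores overstate how well it will do on unseen images.

I am not going to explore that further in this notebook. 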

# Visualization
Visualize the model's performance on the original image files (not the TFRecords), using 25 randomly chosen training images.

In [None]:
grid = 5  # 5 x 5 subplot grid -> 25 images

# Sample distinct row positions of train.csv. range(len(train)) starts at 0,
# so every row (including the first) can be picked, and the bound matches
# the table that is actually indexed below.
random_images = random.sample(range(len(train)), grid * grid)

fig1 = plt.figure(figsize=(35,35))

for j, i in enumerate(tqdm(random_images), start=1):
    image_path = '../input/cassava-leaf-disease-classification/train_images/' + train['image_id'].iloc[i]
    with PIL.Image.open(image_path) as opened:
        # PIL sizes are (width, height); image_size is (height, width) as used
        # by decode_image, so reverse it. resize() returns an independent copy,
        # so the file can be closed afterwards.
        img = opened.resize(image_size[::-1])

    array = keras.preprocessing.image.img_to_array(img)  # (H, W, 3) array
    array = np.expand_dims(array, axis=0)  # add the batch dimension -> (1, H, W, 3)

    output = model.predict(array)
    index = str(np.argmax(output))  # `classes` is keyed by string ids

    prediction = classes[index]
    truth = classes[str(train['label'].iloc[i])]
    title_color = 'green' if prediction == truth else 'red'

    show_im(fig1, grid, grid, j, image=img, title=f'{prediction}/{truth}', title_color=title_color)
    