<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>1. Cassava Leaf Disease Classification - Introduction
 </center></h2>
    
<img src = "https://i2.wp.com/agrihomegh.com/wp-content/uploads/2016/07/CassavaBB-source-www.plantwise.org_-e1490126745230.jpg?resize=640%2C468&ssl=1">
<br><br>
As the second-largest provider of carbohydrates in Africa, cassava is a key food security crop grown by smallholder farmers because it can withstand harsh conditions. 
    
At least 80% of household farms in Sub-Saharan Africa grow this starchy root, but viral diseases are major sources of poor yields. 
    
With the help of data science, it may be possible to identify common diseases so they can be treated.

Existing methods of disease detection require farmers to solicit the help of government-funded agricultural experts to visually inspect and diagnose the plants. 
    
This approach is labor-intensive, costly, and limited by the small number of available experts.
    
As an added challenge, effective solutions for farmers must perform well under significant constraints, since African farmers may only have access to mobile-quality cameras and low-bandwidth connections.

In this competition, we are given a dataset of 21,367 labeled images collected during a regular survey in Uganda. 
    
Most images were crowdsourced from farmers taking photos of their gardens, and annotated by experts at the National Crops Resources Research Institute (NaCRRI) in collaboration with the AI lab at Makerere University, Kampala. This is in a format that most realistically represents what farmers would need to diagnose in real life.

Our goal is to classify each cassava image into four disease categories or a fifth category indicating a healthy leaf. 
    
With our help, farmers may be able to quickly identify diseased plants, potentially saving their crops before they inflict irreparable damage.

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>2. Import Libraries
 </center></h2>

In [None]:
import numpy as np
import pandas as pd

import tensorflow as tf

import re
import math

from matplotlib import pyplot as plt
import seaborn as sns
from colorama import Fore, Back, Style
import cv2

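import tensorflow.keras.backend as K
# use tf.keras throughout; mixing standalone `keras` with `tensorflow` objects can cause errors
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Flatten,Conv2D,MaxPooling2D,Dropout
from tensorflow.keras import optimizers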

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>3. Files Available for the Competition</center></h2>

1) `[train/test]_images` - The full set of test images will only be available to your notebook when it is submitted for scoring. Expect to see roughly 15,000 images in the test set.

2) `train.csv`
* `image_id` - the image file name.
* `label` - the ID code for the disease.

3) `sample_submission.csv` -  A properly formatted sample submission, given the disclosed test set content.
* `image_id` - the image file name.
* `label` - the predicted ID code for the disease.

4) `[train/test]_tfrecords` -  The image files in tfrecord format.

5) `label_num_to_disease_map.json` - The mapping between each disease code and the real disease name.

In [None]:
# load train.csv and check
train = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")
train.head()

In [None]:
# check how many unique diseases are present in the dataset
n_classes = train.label.nunique()
n_classes

In [None]:
# let's take a look at the submission file
sub = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
sub.head()

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>4. Exploratory Data Analysis
 </center></h2>

In [None]:
print("List of unique classes(disease):",train.label.unique())

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.1 Class Mapping</h3>

In [None]:
print(Fore.BLUE + "            Disease-Label Mapping",Style.RESET_ALL)
print("----------------------------------------------")
print(Fore.YELLOW + "Label              Disease",Style.RESET_ALL)
print("  0        Cassava Bacterial Blight (CBB)\n")
print("  1        Cassava Brown Streak Disease (CBSD)\n")
print("  2        Cassava Green Mottle (CGM)\n")
print("  3        Cassava Mosaic Disease (CMD)\n")
print("  4        Healthy")
print("----------------------------------------------")
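
The mapping printed above could also be read from `label_num_to_disease_map.json` instead of being typed by hand. Here is a minimal sketch, assuming the file holds a flat JSON object from label code to disease name; the string `raw` below is a stand-in for the file contents and is an assumption, not a copy of the real file.

```python
import json

# Assumed contents of label_num_to_disease_map.json (illustrative stand-in,
# the real file may differ).
raw = """{
  "0": "Cassava Bacterial Blight (CBB)",
  "1": "Cassava Brown Streak Disease (CBSD)",
  "2": "Cassava Green Mottle (CGM)",
  "3": "Cassava Mosaic Disease (CMD)",
  "4": "Healthy"
}"""

# JSON object keys are always strings, so convert them to int
# to match the integer codes in train.label.
label_map = {int(code): name for code, name in json.loads(raw).items()}

for code in sorted(label_map):
    print(code, label_map[code])
```

With the real file, `json.loads(raw)` would be replaced by `json.load(f)` on the opened file; the key conversion stays the same.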

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.2 Visualize Class Distribution</h3>

In [None]:
# let's check class distribution
fig = plt.figure(constrained_layout=True, figsize=(8,6))

# pass the column as x= ; recent seaborn versions treat the first positional argument as `data`
sns.countplot(x=train.label,
              alpha=0.9,
              order=train.label.value_counts().sort_values(ascending=False).index
             )
plt.xlabel("Label")
plt.ylabel("Count")
plt.title('Class Distribution')

plt.show()

The dataset is imbalanced: most of the images belong to Cassava Mosaic Disease (CMD).
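
One common way to react to an imbalance like this is to weight each class inversely to its frequency. The sketch below is not part of the training pipeline in this notebook; it only shows the arithmetic on a hypothetical toy label list, with `balanced_class_weights` being an illustrative helper name.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}

# Hypothetical toy labels: class 3 dominates, as CMD does in the plot above.
toy_labels = [3] * 6 + [0, 1, 2, 4]

weights = balanced_class_weights(toy_labels)
print(weights)
```

A dictionary of this shape is what the `class_weight` argument of Keras `model.fit` expects, although whether it helps here would have to be checked on the validation set.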

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.3 Add Image Path to Training Dataset</h3>

In [None]:
# add image path to train
train_dir = "../input/cassava-leaf-disease-classification/train_images"

# update image names with the whole path
def append_ext(fn):
    return train_dir+"/"+fn

train["image_id"]= train["image_id"].apply(append_ext)

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.4 Size of first 10 Images</h3>

In [None]:
### Let's check the size of top 10 images
files = train.image_id[:10]
print(Fore.BLUE + "Shape of files from training dataset",Style.RESET_ALL)
for i in range(10):
    im = cv2.imread(files[i])    
    print(im.shape)

The first 10 images all have the same size, which is a good sign (the rest of the dataset is not checked here).

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.5 Few Images from Cassava Bacterial Blight (CBB)</h3>

In [None]:
# visualize few images belonging to disease: Cassava Bacterial Blight (CBB)
class_0 = train[train.label == 0]

images = []

for i in range(1,11):
    img = cv2.imread(class_0.image_id.iloc[i])
    image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
    images.append(image)  # keep the converted RGB image, not the raw BGR one

f, ax = plt.subplots(5,2, figsize=(20,15))
for i, img in enumerate(images):        
        ax[i//2, i%2].imshow(img)
        ax[i//2, i%2].axis('off')

These images show groups of leaves rather than single leaves, so it won't be easy for a model to extract patterns and make predictions.

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.6 Few Images from Cassava Brown Streak Disease (CBSD)</h3>

In [None]:
# visualize few images belonging to disease Cassava Brown Streak Disease (CBSD)
class_1 = train[train.label == 1]

images = []

for i in range(1,11):
    img = cv2.imread(class_1.image_id.iloc[i])
    image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
    images.append(image)  # keep the converted RGB image, not the raw BGR one

f, ax = plt.subplots(5,2, figsize=(20,15))
for i, img in enumerate(images):        
        ax[i//2, i%2].imshow(img)
        ax[i//2, i%2].axis('off')

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.7 Cassava Green Mottle (CGM)</h3>

In [None]:
# visualize few images belonging to disease Cassava Green Mottle (CGM)
class_2 = train[train.label == 2]

images = []

for i in range(1,11):
    img = cv2.imread(class_2.image_id.iloc[i])
    image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
    images.append(image)  # keep the converted RGB image, not the raw BGR one

f, ax = plt.subplots(5,2, figsize=(20,15))
for i, img in enumerate(images):        
        ax[i//2, i%2].imshow(img)
        ax[i//2, i%2].axis('off')

The same observation holds for this set of images.

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.8 Cassava Mosaic Disease (CMD)</h3>

In [None]:
# visualize few images belonging to disease Cassava Mosaic Disease (CMD)
class_3 = train[train.label == 3]

images = []

for i in range(1,11):
    img = cv2.imread(class_3.image_id.iloc[i])
    image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
    images.append(image)  # keep the converted RGB image, not the raw BGR one

f, ax = plt.subplots(5,2, figsize=(20,15))
for i, img in enumerate(images):        
        ax[i//2, i%2].imshow(img)
        ax[i//2, i%2].axis('off')

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">4.9 Healthy</h3>

In [None]:
# visualize few images from class "Healthy"
class_4 = train[train.label == 4]

images = []

for i in range(1,11):
    img = cv2.imread(class_4.image_id.iloc[i])
    image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB
    images.append(image)  # keep the converted RGB image, not the raw BGR one

f, ax = plt.subplots(5,2, figsize=(20,15))
for i, img in enumerate(images):        
        ax[i//2, i%2].imshow(img)
        ax[i//2, i%2].axis('off')

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>5. Model Training & Helper Functions </center></h2>

Submissions must run with the internet option off, so a TPU cannot be used for the submission itself. We can still train the model on a TPU, save it, load it in a separate notebook, and make predictions on the test data there.

In [None]:
# TPU detection  
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None

# TPUStrategy for distributed training
if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else: # default strategy that works on CPU and single GPU
    strategy = tf.distribute.get_strategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
from kaggle_datasets import KaggleDatasets

GCS_DS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
print(GCS_DS_PATH) # what do gcs paths look like?

GCS_PATTERN_TRAIN = GCS_DS_PATH + "/train_tfrecords/*.tfrec"
GCS_PATTERN_TEST  = GCS_DS_PATH + "/test_tfrecords/*.tfrec"

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.1. Declare Necessary Variables/Parameters</h3>


In [None]:
# define parameters
IMAGE_SIZE = [512, 512] # image size

HEIGHT = 512
WIDTH = 512
CHANNELS = 3

EPOCHS = 30 # no. of epochs to train the model

VALIDATION_SPLIT = 0.19 # split ratio for training & validation datasets

AUTO = tf.data.experimental.AUTOTUNE

filenames = tf.io.gfile.glob(GCS_PATTERN_TRAIN)
split = int(len(filenames) * VALIDATION_SPLIT)

TRAINING_FILENAMES = filenames[split:]
VALIDATION_FILENAMES = filenames[:split]
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATTERN_TEST) 

# classes of disease
CLASSES = ['Cassava Bacterial Blight (CBB)',
           'Cassava Brown Streak Disease (CBSD)',
           'Cassava Green Mottle (CGM)',
           'Cassava Mosaic Disease (CMD)',
           'Healthy']
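
Note that the split above is taken over TFRecord files, not over individual images, so the effective validation share depends on how many shards there are. A small sketch of the arithmetic, assuming 16 shards (the shard count and file names are hypothetical; the real list comes from `tf.io.gfile.glob`):

```python
VALIDATION_SPLIT = 0.19

# Hypothetical shard names standing in for the globbed TFRecord files.
filenames = [f"shard{i:02d}.tfrec" for i in range(16)]

# int() truncates, so 16 * 0.19 = 3.04 becomes 3 validation shards.
split = int(len(filenames) * VALIDATION_SPLIT)

validation_files = filenames[:split]
training_files = filenames[split:]

print(len(training_files), "training shards,", len(validation_files), "validation shards")
print("effective validation share:", len(validation_files) / len(filenames))
```

Under this assumption the validation set is 3 of 16 shards, slightly less than the nominal 19%.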

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.2 Helper function - Set # 1</h3>


In [None]:
def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    return image

def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "target": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['target'], tf.int32)
    return image, label # returns one (image, label) pair

def read_unlabeled_tfrecord(example):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "image_name": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    idnum = example['image_name']
    return image, idnum # returns one (image, image_name) pair

def load_dataset(filenames, labeled=True, ordered=False):   
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label) pairs if labeled=True or (image, id) pairs if labeled=False
    return dataset

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.3. Data Augmentation</h3>

In [None]:
def data_augment(image, label):
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)    
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Shear
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
    # Rotation
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270°
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180°
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90°
    # Pixel-level transforms
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=.7, upper=1.3)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=1.2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.1)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)
    # Crops
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
            
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    return image, label
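
Each augmentation above is chosen by comparing one uniform draw against fixed thresholds. To make the resulting branch probabilities easier to see, here is a plain-Python sketch of two of those decisions. The helper names `rot90_turns` and `shear_angle` are illustrative and do not exist in the notebook.

```python
import random

def rot90_turns(p_rotate):
    """Number of 90-degree turns selected for a draw p_rotate in [0, 1)."""
    if p_rotate > 0.75:
        return 3
    elif p_rotate > 0.5:
        return 2
    elif p_rotate > 0.25:
        return 1
    return 0  # no rotation

def shear_angle(p_shear):
    """Maximum shear angle in degrees, or None when shear is skipped."""
    if p_shear > 0.2:
        return 20.0 if p_shear > 0.6 else -20.0
    return None

# Estimate how often each rot90 branch is taken.
rng = random.Random(0)
draws = [rng.random() for _ in range(10000)]
share = {k: sum(rot90_turns(p) == k for p in draws) / len(draws) for k in range(4)}
print(share)
```

Each of the four `rot90` outcomes should come out close to 25%, while shear is skipped for roughly 20% of the draws and otherwise split evenly between the positive and negative angle.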

In [None]:
# data augmentation @cdeotte kernel: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def transform_rotation(image, height, rotation):
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_shear(image, height, shear):
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # SHEAR DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])
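
Both transforms above use the same idea: list every destination pixel, map it back through a 3x3 matrix to a source pixel, and gather the source values. The NumPy sketch below mirrors the index arithmetic of `transform_rotation` so it can be tried on a tiny array without TensorFlow. It is an illustration only: `rotate_nearest` is a hypothetical name, and it uses a fixed angle where the function above scales the angle by a random factor.

```python
import numpy as np

def rotate_nearest(image, degrees):
    """Rotate a square [dim, dim, channels] array by inverse index mapping."""
    dim = image.shape[0]
    xdim = dim % 2
    theta = np.pi * degrees / 180.0
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]])

    # Destination pixel coordinates, centred on the middle of the image.
    x = np.repeat(np.arange(dim // 2, -dim // 2, -1), dim)
    y = np.tile(np.arange(-dim // 2, dim // 2), dim)
    z = np.ones(dim * dim, dtype=np.int64)
    idx = np.stack([x, y, z])

    # Map destination coordinates onto source coordinates, truncate to
    # integers, and clip so every lookup stays inside the image.
    idx2 = (rot @ idx.astype(np.float64)).astype(np.int64)
    idx2 = np.clip(idx2, -dim // 2 + xdim + 1, dim // 2)

    # Convert centred coordinates back to row/column indices and gather.
    rows = dim // 2 - idx2[0]
    cols = dim // 2 - 1 + idx2[1]
    return image[rows, cols].reshape(dim, dim, image.shape[2])

toy = np.arange(4 * 4 * 3).reshape(4, 4, 3)
print(rotate_nearest(toy, 45.0).shape)
```

Because of the clipping bounds, every gathered index is valid and the output always has the input shape. Pixels that would fall outside the image are filled with edge values rather than zeros.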

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.4. Helper Functions - Set # 2 </h3>


In [None]:
# another set of helper functions
def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(ordered=False):
    dataset = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_test_dataset(ordered=False):
    dataset = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def count_data_items(filenames):    
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
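
`count_data_items` relies on each TFRecord file name carrying its image count between the last `-` and the extension. Here is a standalone sketch of that parsing; the file names are hypothetical examples of that naming pattern, not the actual list.

```python
import re

def count_data_items(filenames):
    """Sum the image counts encoded in names like 'prefix-1338.tfrec'."""
    pattern = re.compile(r"-([0-9]*)\.")
    return sum(int(pattern.search(name).group(1)) for name in filenames)

# Hypothetical shard names following the '<prefix>-<count>.tfrec' pattern.
names = ["ld_train00-1338.tfrec", "ld_train01-1338.tfrec", "ld_train15-1327.tfrec"]

total = count_data_items(names)
print(total)
```

If a file name did not follow this pattern, `pattern.search` would return `None` and the call would fail, so the approach only works as long as the naming convention holds.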

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.5 Get Datasets</h3>

In [None]:
# Define the batch size. This will be 16 with TPU off and 128 (=16*8) with TPU on
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

ds_train = get_training_dataset()
ds_valid = get_validation_dataset()
ds_test = get_test_dataset()

print("Training:", ds_train)
print ("Validation:", ds_valid)
print("Test:", ds_test)

In [None]:
y_true = []
for image,label in ds_valid:
    y_true.append(label)

In [None]:
np.set_printoptions(threshold=15, linewidth=80)

print("Training data shapes:")
for image, label in ds_train.take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Training data label examples:", label.numpy())

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.6. Visualizing Images again</h3>

In [None]:
def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case,
                                     # these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is
    # the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
    
def display_batch_of_images(databatch, predictions=None):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: this will drop data that does not fit into square
    # or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
        
    # size and spacing
    FIGSIZE = 30.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
# visualize 20 images from ds_train
ds_iter = iter(ds_train.unbatch().batch(20))
one_batch = next(ds_iter)
display_batch_of_images(one_batch)

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.7. Train the Model - VGG19 with the top 4 convolutional layers unfrozen</h3>

In [None]:
#  Adding a densely connected classifier on top of the convolutional base
with strategy.scope():
    conv_base = VGG19(weights='imagenet',include_top=False,input_shape=[*IMAGE_SIZE, 3])
    
    
    # Fine-tune only the last convolutional block (assumed intent):
    # freeze every layer before 'block5_conv1' and train the rest.
    # The base must stay trainable as a whole, otherwise Keras ignores
    # the per-layer flags set below.
    conv_base.trainable = True
    set_trainable = False

    for layer in conv_base.layers:
        # `name == 'a' or 'b'` is always truthy, so compare against one name only
        if layer.name == 'block5_conv1':
            set_trainable = True
        layer.trainable = set_trainable
            
        
    model = Sequential()
    model.add(conv_base) # adding pre-trained conv base model
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(5, activation='softmax'))    
    
    
    model.compile(
    optimizer=optimizers.RMSprop(),
    loss = 'sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'])


    model.summary()

In [None]:
# steps per epoch calculation
steps_per_epoch = NUM_TRAINING_IMAGES // BATCH_SIZE
validation_steps = NUM_VALIDATION_IMAGES //BATCH_SIZE

# fit the model on training data and validate on validation dataset 
history = model.fit(ds_train, steps_per_epoch=steps_per_epoch, epochs=EPOCHS,
                validation_data=ds_valid, validation_steps=validation_steps)
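
Because the training dataset repeats forever, `steps_per_epoch` is what defines an epoch. The sketch below shows how the batch size and step count change with the number of replicas. The image count of 17,000 is a made-up round number, not the real value of `NUM_TRAINING_IMAGES`, and `steps_for` is an illustrative helper.

```python
def steps_for(num_images, per_replica_batch, num_replicas):
    """Return (global batch size, full steps per epoch, images left over)."""
    batch_size = per_replica_batch * num_replicas
    steps = num_images // batch_size          # integer division drops the partial batch
    leftover = num_images - steps * batch_size
    return batch_size, steps, leftover

# Hypothetical image count; 8 replicas corresponds to a TPU, 1 to CPU or a single GPU.
tpu_batch, tpu_steps, tpu_leftover = steps_for(17000, 16, 8)
cpu_batch, cpu_steps, cpu_leftover = steps_for(17000, 16, 1)

print("TPU:", tpu_batch, tpu_steps, tpu_leftover)
print("CPU/GPU:", cpu_batch, cpu_steps, cpu_leftover)
```

The leftover images are not lost for training, since the repeated and shuffled dataset simply carries them into the next epoch. For validation, which does not repeat, a partial final batch is skipped by this step count.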

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.8 Training Curves</h3>

In [None]:
# plotting the training results
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
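
Besides reading the curves by eye, the best epoch can be picked out of the history directly. A minimal sketch, using a hypothetical dictionary shaped like `history.history` (the numbers are invented for illustration) and an illustrative helper `best_epoch`:

```python
# Hypothetical values shaped like history.history; not real results.
history_dict = {
    "val_loss": [1.20, 0.95, 0.88, 0.91, 0.97],
    "val_sparse_categorical_accuracy": [0.61, 0.66, 0.70, 0.69, 0.68],
}

def best_epoch(values, mode="min"):
    """Return the 1-based epoch with the lowest (or highest) value."""
    pick = min if mode == "min" else max
    index = pick(range(len(values)), key=values.__getitem__)
    return index + 1

loss_epoch = best_epoch(history_dict["val_loss"], mode="min")
acc_epoch = best_epoch(history_dict["val_sparse_categorical_accuracy"], mode="max")
print("lowest validation loss at epoch", loss_epoch)
print("highest validation accuracy at epoch", acc_epoch)
```

A validation loss that starts rising while the training loss keeps falling, as in the toy numbers after epoch 3, is the usual sign of overfitting.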

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h3 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home">5.9 Save the model</h3>

In [None]:
#model.save('model.h5')
model.save_weights('model.h5')

<div class="list-group" id="list-tab" role="tablist">
<a id="10"></a>
<h2 class="list-group-item list-group-item-action active" data-toggle="list" style='background:coral; border:0; color:blue' role="tab" aria-controls="home"><center>6. Model Evaluation </center></h2>

In [None]:
# predictions and evaluation on validation dataset
dataset = get_validation_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

images, labels = next(batch)
probabilities = model.predict(images)
predictions = np.argmax(probabilities, axis=-1)
display_batch_of_images((images, labels), predictions)
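
Looking at 20 images gives a feel for the errors, but with an imbalanced dataset a per-class summary is more informative. The sketch below shows how accuracy and a confusion matrix could be derived from predicted probabilities. The probabilities and labels are made up, and `confusion_matrix` is a small illustrative helper, not a library call.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    return cm

# Hypothetical softmax outputs for four images and their true labels.
probs = np.array([
    [0.10, 0.10, 0.10, 0.60, 0.10],
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.20, 0.20, 0.20, 0.30, 0.10],
    [0.05, 0.05, 0.10, 0.10, 0.70],
])
y_true = np.array([3, 0, 2, 4])

y_pred = np.argmax(probs, axis=-1)           # most likely class per image
accuracy = float(np.mean(y_pred == y_true))  # share of correct predictions
cm = confusion_matrix(y_true, y_pred, n_classes=5)

print("accuracy:", accuracy)
print(cm)
```

In the toy data the only mistake is a class 2 image predicted as class 3. That is the kind of confusion one might expect when one class dominates the training set.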