As the amount of data and the complexity of deep learning models increase, so does the need for more capable hardware. **Tensor Processing Units (TPUs)** are custom-designed ASICs developed by **Google** specifically for deep learning. The TPU v3-8 available on Kaggle has 8 cores and 128 GB of memory. Each core contains vector processing (VPU) and matrix multiply (MXU) units.

A TPU processes data very quickly, so data must also be fed to it quickly. The TFRecord format is well suited to this. If many files have to be read from local disk or over the network, and feeding data to the model is the bottleneck of your training process, try TFRecord. Instead of locating, opening, reading and closing thousands of files again and again, TFRecord stores the data serialized into a few files called shards.

Sharding a dataset into multiple files is good practice for the following reasons:

* tf.data.Dataset API can read input examples in parallel
* tf.data.Dataset API can shuffle the examples better with sharded files

Shard files are connected to the tf.data pipeline. Data is read from all shards in parallel, which makes data consumption very fast.

Consider an image dataset that stores just a label and the image itself. To convert it to TFRecord format, the image is converted to tf.train.BytesList and the label to tf.train.Int64List. Each is then wrapped in a tf.train.Feature. Finally, the features are combined into a tf.train.Example, which is serialized and written to .tfrec files sequentially.
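The conversion steps can be sketched as follows. This is a minimal illustration, not the code used to build the competition files: the helper `make_example`, the toy 8x8 image and the shard filename are assumptions, while the feature keys `image` and `class` are the ones this notebook parses when reading the records.

```python
import os
import tempfile

import tensorflow as tf

def make_example(image_bytes, label):
    # wrap the raw values in BytesList / Int64List, then in Feature
    feature = {
        "image": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_bytes])),
        "class": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    # combine the features into a single Example
    return tf.train.Example(features=tf.train.Features(feature=feature))

# toy data: one black 8x8 JPEG reused for three labels (illustrative only)
image_bytes = tf.io.encode_jpeg(tf.zeros([8, 8, 3], dtype=tf.uint8)).numpy()

# hypothetical shard name; examples are serialized and written sequentially
shard_path = os.path.join(tempfile.gettempdir(), "toy-00.tfrec")
with tf.io.TFRecordWriter(shard_path) as writer:
    for label in range(3):
        writer.write(make_example(image_bytes, label).SerializeToString())
```

A real conversion script would loop over the image files, read their encoded bytes, and start a new shard every few hundred examples.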

In [None]:
!pip install -q efficientnet

In [None]:
import numpy as np
import pandas as pd

import re
import os

import efficientnet.tfkeras as efn
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import LearningRateScheduler

from kaggle_datasets import KaggleDatasets

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.metrics import precision_score, recall_score

import seaborn as sea
import matplotlib.pyplot as plt

In [None]:
sea.set_style("darkgrid")
np.random.seed(3)
tf.random.set_seed(6)

## Initialize TPU

Since the TPU is a network-connected accelerator, we first need to locate it and then instantiate a TPU strategy. The strategy lets us place copies of our model on the different cores. Data is shared among the cores and training runs in parallel.

In [None]:
# detect and init the TPU
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

# instantiate a distribution strategy
# (TensorFlow 2.3+ also exposes this as tf.distribute.TPUStrategy)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("Number of TPU cores\t: ", tpu_strategy.num_replicas_in_sync)
print('Running on TPU\t\t: ', tpu.master())

## Load Dataset

The dataset includes 104 types of flowers drawn from five different public datasets. It is stored in a **Google Cloud Storage (GCS)** bucket and is available in several resolutions; this notebook uses 512x512. First we get the GCS location.

In [None]:
gcs_path = KaggleDatasets().get_gcs_path("tpu-getting-started")
data_path = os.path.join(gcs_path, "tfrecords-jpeg-512x512/")

train_path = os.path.join(gcs_path, "tfrecords-jpeg-512x512/train/")
val_path = os.path.join(gcs_path, "tfrecords-jpeg-512x512/val/")
test_path = os.path.join(gcs_path, "tfrecords-jpeg-512x512/test/")

The names of the 104 flower species are listed below.

In [None]:
classes = ['pink primrose',       'hard-leaved pocket orchid',
           'canterbury bells',    'sweet pea', 
           'wild geranium',       'tiger lily',
           'moon orchid',         'bird of paradise',
           'monkshood',           'globe thistle', 
           'snapdragon',          "colt's foot",
           'king protea',         'spear thistle',
           'yellow iris',         'globe-flower',
           'purple coneflower',   'peruvian lily',
           'balloon flower',      'giant white arum lily',
           'fire lily',           'pincushion flower',
           'fritillary',          'red ginger',
           'grape hyacinth',      'corn poppy',
           'prince of wales feathers', 'stemless gentian',
           'artichoke',           'sweet william',         
           'carnation',           'garden phlox', 
           'love in the mist',    'cosmos',
           'alpine sea holly',    'ruby-lipped cattleya',
           'cape flower',         'great masterwort',
           'siam tulip',          'lenten rose',
           'barberton daisy',     'daffodil', 
           'sword lily',          'poinsettia',
           'bolero deep blue',    'wallflower',
           'marigold',            'buttercup',
           'daisy',               'common dandelion', 
           'petunia',             'wild pansy',
           'primula',             'sunflower',
           'lilac hibiscus',      'bishop of llandaff',
           'gaura',               'geranium',
           'orange dahlia',       'pink-yellow dahlia',  
           'cautleya spicata',    'japanese anemone',
           'black-eyed susan',    'silverbush',
           'californian poppy',   'osteospermum',
           'spring crocus',       'iris',
           'windflower',          'tree poppy',
           'gazania',             'azalea',
           'water lily',          'rose',
           'thorn apple',         'morning glory',
           'passion flower',      'lotus',
           'toad lily',           'anthurium',
           'frangipani',          'clematis',
           'hibiscus',            'columbine',
           'desert-rose',         'tree mallow',
           'magnolia',            'cyclamen',
           'watercress',          'canna lily',
           'hippeastrum',         'bee balm',
           'pink quill',          'foxglove',
           'bougainvillea',       'camellia',
           'mallow',              'mexican petunia',
           'bromelia',            'blanket flower',
           'trumpet creeper',     'blackberry lily',
           'common tulip',        'wild rose']

## Create Data Pipeline

Each example in the training and validation sets includes an image and its label. Test examples carry an image and its id instead of a label.

In [None]:
def read_record(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, features)
    
    image = tf.image.decode_jpeg(example["image"], channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.reshape(image, [512,512,3])   
    label = tf.cast(example["class"], tf.int32)
    
    return image, label

def read_record_test(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, features)
    
    image = tf.image.decode_jpeg(example["image"], channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.reshape(image, [512,512,3])
    image_id = example["id"]
    
    return image, image_id

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

def get_size(filenames):
    # each shard name ends with "-<number of examples>.tfrec",
    # so the dataset size is the sum of those numbers
    data_size = 0
    for filename in filenames:
        count = re.search(r"-(\d+)\.tfrec$", filename).group(1)
        data_size += int(count)

    return data_size

def prepare_dataset(flag, order=False):
    # flag is the split name: "train", "val" or "test"
    filenames = tf.io.gfile.glob(os.path.join(data_path, flag) + "/*.tfrec")
    data_size = get_size(filenames)
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)

    # order=False lets tf.data read the shards in any order (faster);
    # order=True keeps the read order deterministic
    options = tf.data.Options()
    options.experimental_deterministic = order
    dataset = dataset.with_options(options)

    if flag == "test":
        dataset = dataset.map(read_record_test, num_parallel_calls=AUTO)
    else:
        dataset = dataset.map(read_record, num_parallel_calls=AUTO)

    return dataset, data_size
    
train_dataset, train_size = prepare_dataset("train")
val_dataset, val_size = prepare_dataset("val")
test_dataset, test_size = prepare_dataset("test")

print("Size of Training Dataset\t: ", train_size)
print("Size of Validation Dataset\t: ", val_size)
print("Size of Test Dataset\t\t: ", test_size)
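get_size relies on each shard filename carrying the number of examples it holds. The sketch below shows the idea on made-up filenames. The names and counts are hypothetical, and `count_examples` is an illustrative helper rather than part of the notebook.

```python
import re

def count_examples(filenames):
    # sum the trailing "-<count>.tfrec" number of every shard name
    total = 0
    for name in filenames:
        match = re.search(r"-(\d+)\.tfrec$", name)
        total += int(match.group(1))
    return total

# hypothetical shard names following an "<index>-<size>-<count>.tfrec" pattern
shards = ["00-512x512-798.tfrec",
          "01-512x512-798.tfrec",
          "15-512x512-783.tfrec"]
total = count_examples(shards)
print(total)  # 2379
```

Reading the count from the name avoids a full pass over the records just to learn the dataset size, which is needed to compute the steps per epoch.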

**train_dataset** is repeated and shuffled. Repeating is necessary because training runs for more than one epoch. Shuffling is applied to a subset of the dataset held in a buffer, whose size is adjustable. Shuffling is not needed for the validation and test sets.
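Buffer-based shuffling can be pictured with a plain Python sketch. It mimics the idea behind shuffle(buffer_size) (fill a buffer, emit a random element, refill from the stream) and is not TensorFlow's implementation; `buffered_shuffle` and its arguments are illustrative names.

```python
import random

def buffered_shuffle(stream, buffer_size, seed=0):
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # buffer is full: emit one random element to make room
            yield buffer.pop(rng.randrange(len(buffer)))
    # stream exhausted: drain what is left in random order
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(10), buffer_size=3))
print(shuffled)
```

With a buffer much smaller than the dataset, an element can only move a limited distance ahead of its original position, which is why a larger buffer gives a more thorough shuffle.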

Prefetching prepares the next batch of examples while the current batch is in use. Remember that a copy of our model is placed on each TPU core and each input batch is split among the cores. This is called data parallelism, so the batch size should be scaled with the number of available cores.
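As a quick arithmetic sketch of data parallelism: with the per-core batch size of 16 used in this notebook and the 8 cores of a TPU v3-8, the global batch holds 128 examples and each core receives 16 of them. The training-set size below is a made-up number used only to show how the steps per epoch follow.

```python
per_replica_batch = 16      # examples per core (value used in this notebook)
num_replicas = 8            # cores on a TPU v3-8
global_batch = per_replica_batch * num_replicas

train_size = 12000          # hypothetical dataset size
steps_per_epoch = train_size // global_batch

print(global_batch, steps_per_epoch)  # 128 93
```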

Augmentation can be incorporated in our data pipeline. For this notebook we will use random contrast adjustment. If you are interested in more advanced augmentation with tf.data, you can refer to my notebook on **Pneumonia Detection** where I describe the augmentation process and transformation functions in detail.

In [None]:
def augment(image, label):    
    
    # random contrast parameters
    cont_low = 0.8
    cont_high = 1.2
    
    cont_factor = tf.random.uniform([], minval=cont_low,
                                        maxval=cont_high,
                                        dtype=tf.float32)

    trn_image = tf.image.adjust_contrast(image, cont_factor)    
    
    return trn_image, label    

In [None]:
# batch_size is scaled with the number of TPU cores
batch_size = 16 * tpu_strategy.num_replicas_in_sync

train_dataset = train_dataset.map(augment,
                num_parallel_calls = AUTO)

train_dataset = train_dataset.repeat().shuffle(3000) \
                .batch(batch_size).prefetch(AUTO)
    
val_dataset = val_dataset.batch(batch_size).prefetch(AUTO)
test_dataset = test_dataset.batch(batch_size).prefetch(AUTO)

Some of the training images are depicted below

In [None]:
show = train_dataset.unbatch().batch(9)
image, label = next(iter(show))

fig, axes = plt.subplots(constrained_layout = True,
                         nrows=3, ncols=3, figsize=(10, 10))

for i in range(3):
    for j in range(3):       
        axes[i][j].imshow(image[i*3+j], aspect="auto")
        axes[i][j].axis("off")
        axes[i][j].title.set_text(classes[label[i*3+j]])

## Training

We will use **EfficientNetB5** with ImageNet weights. The model is defined inside **tpu_strategy.scope()**, so that a separate copy of the model is created on each core.

In [None]:
with tpu_strategy.scope():
    
    base_model = efn.EfficientNetB5(include_top=False,
                                    input_shape=(512,512,3),
                                    weights='imagenet')
    base_model.trainable = True  
    
    model = Sequential(name="Flower_Detector")
    model.add(base_model)
    model.add(GlobalAveragePooling2D(name="GAP"))
    model.add(Dense(104, activation="softmax", name="Probs"))
    
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                 loss="sparse_categorical_crossentropy", 
                 metrics=["sparse_categorical_accuracy"])
    
model.summary()

We need to define a learning rate schedule function starting with a very small learning rate. Since we are fine-tuning the parameters of the base model, starting with a very small value is necessary for the adaptation of the pretrained coefficients. If we start with a high learning rate, the pretrained coefficients may change abruptly deteriorating the performance of the model.

The maximum learning rate should also be scaled with the batch size. During backpropagation, the gradient estimate can be noisy if the batch size is small. Larger batches give more reliable gradients, so the learning rate can be increased as the batch size grows.

In [None]:
epoch = 20

# Learning rate scheduler
def decay(inp):   
    lr_init = 0.00005
    # max learning rate is scaled with the number of TPU cores
    lr_max = 0.000125 * tpu_strategy.num_replicas_in_sync
    lin_lr = 5
    if inp <= lin_lr:
        lr = inp*(lr_max - lr_init) / lin_lr + lr_init
    else:
        lr = lr_max * np.exp(-0.1*(inp - lin_lr))
        
    return lr

lrs = LearningRateScheduler(decay)

x = np.linspace(0,epoch,200)
y = [decay(i) for i in x]

plt.xticks(range(0,epoch+1,2))
plt.plot(x,y);
plt.ylabel("Learning Rate");
plt.xlabel("epoch");
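To make the shape of the schedule concrete, the sketch below re-evaluates the same formula at a few epochs. It assumes 8 replicas (a TPU v3-8) instead of querying tpu_strategy, so it can run without an accelerator; `decay_for` is an illustrative stand-in for decay.

```python
import math

def decay_for(epoch_idx, num_replicas=8):
    lr_init = 0.00005
    # max learning rate scaled with the (assumed) number of cores
    lr_max = 0.000125 * num_replicas
    lin_lr = 5
    if epoch_idx <= lin_lr:
        # linear warm-up from lr_init to lr_max
        return epoch_idx * (lr_max - lr_init) / lin_lr + lr_init
    # exponential decay after the warm-up
    return lr_max * math.exp(-0.1 * (epoch_idx - lin_lr))

for ep in (0, 5, 10, 19):
    print(ep, round(decay_for(ep), 6))
```

Under that assumption the rate ramps linearly from 5e-5 to 1e-3 over the first five epochs and then decays exponentially.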

In [None]:
step_per_epoch = train_size // batch_size

history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    steps_per_epoch=step_per_epoch,
                    epochs=epoch,
                    callbacks=[lrs])

Loss and sparse categorical accuracy plots are depicted below for train and validation datasets.

In [None]:
e = np.linspace(1, epoch, epoch)

fig, axes = plt.subplots(constrained_layout = True,
                         nrows=1, ncols=2, figsize=(12, 5))

sea.lineplot(x = e, y = history.history['loss'],
             ax=axes[0], label="train");
sea.lineplot(x = e, y = history.history['val_loss'],
             ax=axes[0], label="val");
axes[0].set_ylabel("Loss")
axes[0].set_xlabel("epoch")
axes[0].set_xticks(range(0,epoch+1,2))

sea.lineplot(x = e, y = history.history['sparse_categorical_accuracy'],
             ax=axes[1], label="train");
sea.lineplot(x = e, y = history.history['val_sparse_categorical_accuracy'],
             ax=axes[1], label="val");
axes[1].set_ylabel("Sparse Categorical Accuracy")
axes[1].set_xlabel("epoch")
axes[1].set_xticks(range(0,epoch+1,2));

## Performance Details on Validation Set

During training, the read order of the .tfrec files isn't important. For validation, however, we separate images and labels into two datasets, so their order must be preserved to keep the predictions aligned with the labels.

In [None]:
val_dataset, val_size = prepare_dataset("val", True)
val_dataset = val_dataset.batch(batch_size).prefetch(AUTO)

val_image_dataset = val_dataset.map(lambda image,label: image)
val_label_dataset = val_dataset.map(lambda image,label: label).unbatch()
val_labels = next(iter(val_label_dataset.batch(val_size))).numpy()

val_probs = model.predict(val_image_dataset)
val_preds = np.argmax(val_probs, axis=-1)

Compute **confusion matrix**, **precision**, **recall** and **f1-score** for validation dataset.

In [None]:
con_mat = confusion_matrix(val_labels, val_preds,
                           labels=range(len(classes)))

prec = precision_score(val_labels, val_preds,
                       labels=range(len(classes)),
                       average='macro')

rec = recall_score(val_labels, val_preds,
                   labels=range(len(classes)),
                   average='macro')

f1 = f1_score(val_labels, val_preds,
              labels=range(len(classes)),
              average='macro')


plt.figure(figsize=(10,10))
ax = plt.gca()
ax.matshow(con_mat, cmap='Greens')
ax.set_xticks(range(len(classes)))
ax.set_xticklabels(classes, fontdict={'fontsize': 6})
plt.setp(ax.get_xticklabels(), rotation=65, ha="left",
                                rotation_mode="anchor")
ax.set_yticks(range(len(classes)))
ax.set_yticklabels(classes, fontdict={'fontsize': 6})
plt.setp(ax.get_yticklabels(), rotation=25, ha="right",
                                rotation_mode="anchor")
plt.show()

print('Precision \t: {:.4f}'.format(prec))
print('Recall \t\t: {:.4f}'.format(rec))
print('F1 score \t: {:.4f}'.format(f1))
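Macro averaging computes each metric per class and then takes the unweighted mean, so rare classes count as much as common ones. The sketch below derives macro precision and recall from a small made-up 3-class confusion matrix with NumPy; the matrix values are illustrative, not results from this model.

```python
import numpy as np

# rows: true class, columns: predicted class (same layout as sklearn)
toy_con_mat = np.array([[8, 1, 1],
                        [2, 6, 2],
                        [0, 0, 10]])

true_pos = np.diag(toy_con_mat)
# column sums = number of predictions per class
precision_per_class = true_pos / toy_con_mat.sum(axis=0)
# row sums = number of true examples per class
recall_per_class = true_pos / toy_con_mat.sum(axis=1)

macro_precision = precision_per_class.mean()
macro_recall = recall_per_class.mean()
print(round(macro_precision, 4), round(macro_recall, 4))
```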

## Prediction on Test Set

Make predictions on test images

In [None]:
test_dataset, test_size = prepare_dataset("test", True)
test_dataset = test_dataset.batch(batch_size).prefetch(AUTO)

test_image_dataset = test_dataset.map(lambda image,idnum: image)
test_id_dataset = test_dataset.map(lambda image,idnum: idnum).unbatch()
test_ids = next(iter(test_id_dataset.batch(test_size))).numpy().astype('U')

test_probs = model.predict(test_image_dataset)
test_preds = np.argmax(test_probs, axis=-1)

## Submission

Prepare and save csv file

In [None]:
# write one "id,label" row per test image
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, test_preds]),
           fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')