### Define the imports
---

In [None]:
import io
import logging
import os
import pathlib
import shutil
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from mpl_toolkits.axes_grid1 import ImageGrid
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.applications import resnet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

%load_ext tensorboard


In [None]:
# Quieten TensorFlow's own logger (errors only), then configure the root
# logger that the training code below writes its progress messages to.
tf.get_logger().setLevel('ERROR')
logging.basicConfig(format="%(asctime)-15s %(message)s", level=logging.INFO)
logger = logging.getLogger()

### Define the class to extract the raw data and convert them into TFRecords
---

Converting the data to TFRecords helps to optimise the data pipeline, which makes the overall training process faster.

In [None]:
class DataExtractor():

    '''
    Extract the raw image zip and convert the images into one TFRecord file per split.

    The zip is expected to contain one top-level sub-folder per class; the folder
    name is used as the class label.
    '''

    def __init__(self, random_state=42, test_size=0.1, val_size=0.1):
        # Seed for train_test_split, and split fractions relative to the full per-class set
        self.random_state = random_state
        self.test_size = test_size
        self.val_size = val_size

    @staticmethod
    def _bytes_feature(value):
        '''Wrap a bytes value (or an eager tensor holding one) in a tf.train.Feature.'''
        if isinstance(value, type(tf.constant(0))):
            # BytesList cannot take an EagerTensor directly, so unwrap it to bytes
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        '''Wrap an integer (label or image dimension) in a tf.train.Feature.'''
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    @staticmethod
    def serialize_example(image, label):

        '''
        Read one image file and serialize it as a tf.train.Example.

        Args:
            image: path to an encoded image file.
            label: integer class index.

        Returns:
            The serialized example (bytes) holding the raw encoded image, the label
            and the decoded height/width/depth.
        '''

        # Context manager so the file handle is closed instead of leaked per image
        with open(image, 'rb') as image_file:
            image_string = image_file.read()

        # Decode only to record the shape; the stored payload is the original encoded bytes
        image_shape = tf.io.decode_image(image_string, channels=3, expand_animations=False).shape

        feature = {
            'image_raw': DataExtractor._bytes_feature(image_string),
            'label': DataExtractor._int64_feature(label),
            'height': DataExtractor._int64_feature(image_shape[0]),
            'width': DataExtractor._int64_feature(image_shape[1]),
            'depth': DataExtractor._int64_feature(image_shape[2]),
        }

        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))

        return example_proto.SerializeToString()

    def extract_data(self, rawpath, datafolder) -> list:

        '''
        Extract the zip at `rawpath` into `datafolder` and return the class labels.

        Labels are the names of the top-level sub-directories, sorted so that the
        label -> integer index mapping written into the TFRecords is the same on
        every run and every filesystem.
        '''

        self.rawpath = rawpath
        self.datafolder = pathlib.Path(datafolder)

        with ZipFile(self.rawpath, 'r') as file:
            file.extractall(path=self.datafolder)

        # glob() order is filesystem-dependent; sort for a stable class index
        self.labels = sorted(path.name for path in self.datafolder.glob('*') if path.is_dir())

        return self.labels

    def generate_tfrecords(self, tfrfolder):

        '''
        Split every class into train/val/test and write `<split>.tfrecord` files into `tfrfolder`.

        Must be called after extract_data(), which sets self.labels and self.datafolder.
        '''

        self.tfrfolder = pathlib.Path(tfrfolder)
        self.tfrfolder.mkdir(parents=True, exist_ok=True)

        # split name -> {label -> list of image paths}
        segregated_data = {
                            'train': {},
                            'val': {},
                            'test': {}
                        }

        for label in self.labels:
            # iterdir() order is filesystem-dependent; sort so random_state yields a reproducible split
            image_list = sorted(self.datafolder.joinpath(label).iterdir())

            train_img, test_img = train_test_split(
                                        image_list,
                                        test_size=self.test_size,
                                        shuffle=True,
                                        random_state=self.random_state)

            # Rescale so val_size is a fraction of the full class set, not of the remaining train part
            train_img, val_img = train_test_split(
                                        train_img,
                                        test_size=self.val_size/(1-self.test_size),
                                        shuffle=True,
                                        random_state=self.random_state)

            segregated_data['train'][label] = train_img
            segregated_data['val'][label] = val_img
            segregated_data['test'][label] = test_img

        for segment, images_by_label in segregated_data.items():
            recordfile = os.path.join(self.tfrfolder, f'{segment}.tfrecord')

            with tf.io.TFRecordWriter(recordfile) as writer:
                for label, img_list in images_by_label.items():
                    # Same index for every image of this label, so look it up once
                    label_cat = self.labels.index(label)
                    for img in img_list:
                        writer.write(DataExtractor.serialize_example(img, label_cat))

### Define the class to load the data and perform data augmentation
---
Data augmentation is only applied to training data, hence we should not include it in the model definition.  
We can make use of prefetching to speed up the data pipeline as well.

In [None]:
class DataLoader():

    '''
    Build batched tf.data pipelines from the train/val/test TFRecords.

    Data augmentation is applied to the training split only. A grid of training
    images can also be logged to TensorBoard for inspection.
    '''

    def __init__(self, img_size, batch_size, buffer_size, drop_remainder=True, random_seed=42):
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.drop_remainder = drop_remainder
        self.img_size = img_size
        self.random_seed = random_seed
        # Stateful generator used to draw a fresh seed for the stateless augmentation ops
        self.rng = tf.random.Generator.from_seed(self.random_seed, alg='philox')

        logdir = "logs/train_im/"
        self.file_writer = tf.summary.create_file_writer(logdir)

    @staticmethod
    def parse_image(example_proto, img_size=(224, 224)):

        '''
        Parse one serialized tf.train.Example into a resized (image, label) pair.

        Args:
            example_proto: scalar string tensor holding a serialized example.
            img_size: (height, width) to resize to. Defaults to (224, 224), the
                ResNet50 input shape used by the model.

        Returns:
            (image, label): a float image of shape (*img_size, 3) and an int32 label.
        '''

        feature_description = {
            'image_raw': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
        }

        features = tf.io.parse_single_example(example_proto, feature_description)

        image = tf.io.decode_image(features['image_raw'], channels=3, expand_animations = False)
        image = tf.image.resize(image, img_size)

        label = tf.cast(features['label'], tf.int32)

        return (image, label)

    def _parse(self, example_proto):
        '''Parse one record using this loader's configured img_size.'''
        return DataLoader.parse_image(example_proto, self.img_size)

    @staticmethod
    def plot_to_image(figure):

        '''Render a matplotlib figure to a (1, H, W, 4) PNG-decoded tensor for tf.summary.image.'''

        buf = io.BytesIO()
        # Save the figure we were given, not whatever pyplot considers "current"
        figure.savefig(buf, format='png')
        # Closing the figure prevents it from being displayed directly inside the notebook.
        plt.close(figure)
        buf.seek(0)
        # Convert PNG buffer to TF image
        image = tf.image.decode_image(buf.getvalue(), channels=4, expand_animations=False)
        # Add the batch dimension
        image = tf.expand_dims(image, 0)
        return image

    def wrapper_image_augment(self, image, label):

        '''Draw a per-example seed from self.rng and apply the (optional) augmentation ops.'''

        seed = self.rng.make_seeds(2)[0]

        def image_augment(image, label, seed):

            # Uncomment these lines to experiment with data augmentation

            # image = tf.image.stateless_random_flip_left_right(image, seed=seed)
            # image = tf.image.stateless_random_flip_up_down(image, seed=seed)
            # image = tf.image.rot90(image)
            # image = tf.image.stateless_random_brightness(image, max_delta=0.1, seed=seed)
            # image = tf.image.stateless_random_hue(image, max_delta=0.2, seed=seed)
            # image = tf.image.stateless_random_saturation(image, lower=0.5, upper=0.7, seed=seed)
            # image = tf.image.stateless_random_contrast(image, lower=0.4, upper=0.6, seed=seed)

            return (image, label)

        return image_augment(image, label, seed)

    def preprocess_data(self, data_path):

        '''
        Build the val/test pipeline (no augmentation).

        Records are shuffled while still serialized, so the shuffle buffer holds
        compact encoded bytes rather than decoded float images.
        '''

        dataset = (
            tf.data.TFRecordDataset(data_path)
            .shuffle(self.buffer_size, self.random_seed)
            .map(self._parse, num_parallel_calls=AUTOTUNE)
            .batch(self.batch_size, drop_remainder=self.drop_remainder)
            # Repeat indefinitely; the consumer decides how many batches make an epoch
            .repeat()
            .prefetch(buffer_size=AUTOTUNE)
        )

        return dataset

    def preprocess_train(self, train_path):

        '''Build the training pipeline: shuffle, parse, augment, batch, repeat, prefetch.'''

        dataset = (
            tf.data.TFRecordDataset(train_path)
            .shuffle(self.buffer_size, self.random_seed)
            .map(self._parse, num_parallel_calls=AUTOTUNE)
            .map(self.wrapper_image_augment, num_parallel_calls=AUTOTUNE)
            .batch(self.batch_size, drop_remainder=self.drop_remainder)
            .repeat()
            .prefetch(buffer_size=AUTOTUNE)
        )

        return dataset

    def read_tfrecords(self, train_path, val_path, test_path):

        '''Build and store the train/val/test pipelines; returns (train, val, test).'''

        self.train_data = self.preprocess_train(train_path)
        self.val_data = self.preprocess_data(val_path)
        self.test_data = self.preprocess_data(test_path)

        return (self.train_data, self.val_data, self.test_data)

    def plot_train_images(self, steps, labels):

        '''
        Log a grid of training images (after augmentation) to TensorBoard.

        Args:
            steps: number of batches to draw from the training pipeline.
            labels: class names indexed by the integer label.

        Requires read_tfrecords() to have been called first (reads self.train_data).
        Returns None.
        '''

        images, targets = [], []
        for X, y in self.train_data.take(steps).unbatch().as_numpy_iterator():
            images.append(X)
            targets.append(y)

        # ImageGrid cannot be built with a 0x0 layout, and there is nothing to show
        if not images:
            return

        grid_size = int(np.ceil(np.sqrt(len(images))))

        fig = plt.figure(figsize=(25, 25))

        grid = ImageGrid(fig, 111, nrows_ncols=(grid_size, grid_size), axes_pad=0.3)

        for idx, (X, y) in enumerate(zip(images, targets)):
            grid[idx].set_axis_off()
            grid[idx].set_title(labels[y])
            # Images are not normalised by the pipeline; scale to [0, 1] floats for imshow
            grid[idx].imshow(X / 255.0)

        img = DataLoader.plot_to_image(fig)

        with self.file_writer.as_default():
            tf.summary.image("Training Images", img, step=0)


### Define the model
---
ResNet50 is used as the feature extractor, and a simple Dense layer on top of a GlobalAveragePooling2D layer is used as the classifier.

In [None]:
class MyModel(Model):

    '''
    Image classifier: an ImageNet-pretrained ResNet50 feature extractor (frozen at
    construction) followed by global average pooling and a Dense layer.
    The Dense layer has no activation, so the model outputs raw logits.
    Subclasses tensorflow.keras.models.Model.
    '''

    def __init__(self, n_classes, **kwargs):
        super().__init__(**kwargs)

        # ResNet50 backbone without its ImageNet classification head
        self.base = resnet50.ResNet50(
                    input_shape=(224, 224, 3),
                    include_top=False,
                    weights='imagenet')
        # ResNet50's own input preprocessing, applied inside call()
        self.preprocess = resnet50.preprocess_input

        # Freeze the base model; layers may be selectively unfrozen later for fine-tuning
        for layer in self.base.layers:
            layer.trainable = False

        # Define the classifier head
        self.global_pool = GlobalAveragePooling2D()
        self.classifier = Dense(n_classes)

    def call(self, input_tensor, training=None):
        '''
        Args:
            input_tensor: batch of RGB images — assumed (batch, 224, 224, 3) with
                un-normalised pixel values; TODO confirm against the data pipeline.
            training: accepted so Keras can pass it through. It is deliberately not
                forwarded to the backbone (see below).

        Returns:
            Logits of shape (batch, n_classes).
        '''
        x = self.preprocess(input_tensor)
        # Always run the backbone in inference mode. Its BatchNormalization layers
        # then keep using the ImageNet moving statistics, even after some backbone
        # layers are unfrozen for fine-tuning with small batches.
        x = self.base(x, training=False)
        x = self.global_pool(x)

        return self.classifier(x)


### Define the training steps in a ModelTrainer Class
---
This is the most complicated step in the entire workflow.  

Here we have defined `generate_confusion_matrix` and `error_analysis` to produce visualisations at the end of each epoch.  
More details are given further down in the notebook.  

In [None]:
class ModelTrainer():

    '''
    Custom training loop for `model`.

    Each epoch trains for `steps_per_epoch_train` batches and validates for
    `steps_per_epoch_val` batches. It then logs loss/accuracy, a confusion matrix
    and the misclassified validation images to TensorBoard. The model is
    checkpointed whenever validation accuracy improves, and training stops early
    after `patience` epochs without improvement.
    '''

    def __init__(self, model, train_data, val_data, optimizer, epochs, steps_per_epoch_train, steps_per_epoch_val, labels):
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.optimizer = optimizer
        self.epochs = epochs
        self.steps_per_epoch_train = steps_per_epoch_train
        self.steps_per_epoch_val = steps_per_epoch_val
        self.labels = labels

        # Used to save the model whenever validation accuracy improves.
        # NOTE(review): the checkpoint captures the optimizer object passed here. Assigning a
        # different object to self.optimizer later does not change what gets saved or restored.
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, './tf_ckpts', max_to_keep = 5)

        # Loss expects raw logits (from_logits=True); separate metrics for train and validation
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        self.val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        self.train_loss = tf.keras.metrics.Mean(name='loss')
        self.val_loss = tf.keras.metrics.Mean(name='val_loss')

        # Define the directory and summary writer for TensorBoard logs
        log_dir_train = 'logs/train/'
        log_dir_val = 'logs/val/'
        log_dir_cm = 'logs/confusion_matrix/'
        log_dir_ea = 'logs/error_analysis/'

        self.train_summary_writer = tf.summary.create_file_writer(log_dir_train)
        self.val_summary_writer = tf.summary.create_file_writer(log_dir_val)
        self.confusion_matrix_writer = tf.summary.create_file_writer(log_dir_cm)
        self.error_analysis_writer = tf.summary.create_file_writer(log_dir_ea)

        # Epoch counters used as the TensorBoard x-axis so graphs continue when training restarts
        self.epoch_counter = 1
        self.best_epoch_counter = 1

    @staticmethod
    def plot_to_image(figure):

        '''Render a matplotlib figure to a (1, H, W, 4) PNG-decoded tensor for tf.summary.image.'''

        # Save the figure we were given (not pyplot's "current" figure) to a PNG in memory
        buf = io.BytesIO()
        figure.savefig(buf, format='png')

        # Closing the figure prevents it from being displayed directly inside the notebook
        plt.close(figure)
        buf.seek(0)

        # Convert PNG buffer to TF image and add the batch dimension
        image = tf.image.decode_image(buf.getvalue(), channels=4, expand_animations=False)
        return tf.expand_dims(image, 0)

    def generate_confusion_matrix(self, val_y, logits, step):

        '''
        Log a confusion matrix of the validation predictions to TensorBoard.

        Args:
            val_y: sequence of integer ground-truth labels.
            logits: per-example logits, aligned with val_y.
            step: TensorBoard step (the epoch counter).
        '''

        val_pred = np.argmax(tf.nn.softmax(logits), axis=1)

        # Pin the class indices. Otherwise sklearn infers them from y_true/y_pred, and a
        # class missing from both makes the matrix smaller than display_labels, which raises.
        display = ConfusionMatrixDisplay.from_predictions(
            val_y,
            val_pred,
            labels=list(range(len(self.labels))),
            display_labels=self.labels)
        display.ax_.tick_params('x', labelrotation=45.0)
        fig = display.figure_
        fig.tight_layout(pad=3.0)

        image = ModelTrainer.plot_to_image(fig)

        with self.confusion_matrix_writer.as_default():
            tf.summary.image("Confusion Matrix", image, step=step)

    def error_analysis(self, val_X, val_y, logits, step):

        '''
        Log a grid of every misclassified validation image, titled with its predicted label.

        Args:
            val_X: sequence of validation images.
            val_y: sequence of integer ground-truth labels, aligned with val_X.
            logits: per-example logits, aligned with val_X.
            step: TensorBoard step (the epoch counter).
        '''

        val_pred = np.argmax(tf.nn.softmax(logits), axis=1)

        # (image, predicted class index) for every wrong prediction
        misclassified = [
            (image, int(pred))
            for image, truth, pred in zip(val_X, val_y, val_pred)
            if int(truth) != int(pred)
        ]

        # A perfect validation pass leaves nothing to plot, and ImageGrid rejects a 0x0 layout
        if not misclassified:
            logger.info(f'[INFO] No misclassified validation images at step: {step}')
            return

        grid_size = int(np.ceil(np.sqrt(len(misclassified))))

        fig = plt.figure(figsize=(15, 15))
        grid = ImageGrid(fig, 111, nrows_ncols=(grid_size, grid_size), axes_pad=0.3)

        for idx, (img, pred) in enumerate(misclassified):
            # Images are not normalised by the pipeline; scale to [0, 1] floats for imshow
            grid[idx].imshow(img / 255.0)
            grid[idx].set_axis_off()
            grid[idx].set_title(self.labels[pred])

        fig.tight_layout(pad=3.0)
        summary_image = ModelTrainer.plot_to_image(fig)

        with self.error_analysis_writer.as_default():
            tf.summary.image("Error Analysis", summary_image, step=step)

    def train_step(self, data):

        '''Train on one batch `(inputs, labels)`; returns (loss, logits).'''

        inputs, labels = data

        with tf.GradientTape() as tape:
            logits = self.model(inputs, training=True)
            step_loss = self.loss_fn(labels, logits)

        # Gradients and the variables they are applied to must come from the same model
        gradients = tape.gradient(step_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_acc_metric.update_state(labels, logits)
        self.train_loss.update_state(step_loss)

        return (step_loss, logits)

    def val_step(self, data):

        '''Evaluate one validation batch `(inputs, labels)` and update val metrics; returns (loss, logits).'''

        inputs, labels = data

        logits = self.model(inputs, training=False)
        step_loss = self.loss_fn(labels, logits)

        self.val_acc_metric.update_state(labels, logits)
        self.val_loss.update_state(step_loss)

        return (step_loss, logits)

    def predict_step(self, data):

        '''Predict on one batch `(inputs, labels)` without touching any metric; returns (loss, logits).'''

        inputs, labels = data

        logits = self.model(inputs, training=False)
        step_loss = self.loss_fn(labels, logits)

        return (step_loss, logits)

    def train_and_evaluate(self):

        '''
        Train for up to self.epochs epochs with validation, checkpointing and early stopping.

        If a checkpoint exists in ./tf_ckpts, the model resumes from the latest one. That is
        the best epoch seen by the previous run, because checkpoints are only written when
        validation accuracy improves.
        '''

        logger.info('[INFO] Start Training')

        # Check if checkpoints exists, and restore training from checkpoints if available.
        self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
        if self.checkpoint_manager.latest_checkpoint:

            # Continue the TensorBoard x-axis from the best epoch.
            # NOTE(review): best_epoch_counter lives only on this object (it is not stored in the
            # checkpoint), so a freshly constructed ModelTrainer restarts the count at 1.
            self.epoch_counter = self.best_epoch_counter
            logger.info(f'[INFO] Restoring training from epoch: {self.epoch_counter} from ckpt: {self.checkpoint_manager.latest_checkpoint}')
        else:
            logger.info(f'[INFO] Training from scratch')

            # Reset epoch counter if training from scratch
            self.epoch_counter = 1
            self.best_epoch_counter = 1

        patience = 5 # wait 5 epochs to see if there is any improvement before terminating training
        wait = 0
        best = 0

        for _ in range(self.epochs):

            logger.info(f'[INFO] Performing training on epoch: {self.epoch_counter}')
            # Bound the epoch explicitly; the input pipeline may repeat indefinitely
            for batch in self.train_data.take(self.steps_per_epoch_train):
                self.train_step(batch)

            logger.info(f'[INFO] Performing validation on epoch: {self.epoch_counter}')

            # Keep every validation image/label/logit for the end-of-epoch plots
            val_X = []
            val_y = []
            logit_list = []

            for batch in self.val_data.take(self.steps_per_epoch_val):
                _, logits = self.val_step(batch)

                X, y = batch
                val_X.extend(X)
                val_y.extend(y)
                logit_list.extend(logits)

            epoch_train_acc = self.train_acc_metric.result()
            epoch_train_loss = self.train_loss.result()
            epoch_val_acc = self.val_acc_metric.result()
            epoch_val_loss = self.val_loss.result()

            # Generate confusion matrix and error analysis at end of each epoch
            self.generate_confusion_matrix(val_y, logit_list, self.epoch_counter)
            self.error_analysis(val_X, val_y, logit_list, self.epoch_counter)

            logger.info(f'[INFO] Epoch: {self.epoch_counter} | loss: {epoch_train_loss:0.2f} | accuracy: {epoch_train_acc:0.2f} | val_loss: {epoch_val_loss:0.2f} | val_accuracy: {epoch_val_acc:0.2f}')

            # Train and val scalars share tag names so TensorBoard overlays them
            with self.train_summary_writer.as_default():
                tf.summary.scalar('loss', epoch_train_loss, step=self.epoch_counter)
                tf.summary.scalar('accuracy', epoch_train_acc, step=self.epoch_counter)

            with self.val_summary_writer.as_default():
                tf.summary.scalar('loss', epoch_val_loss, step=self.epoch_counter)
                tf.summary.scalar('accuracy', epoch_val_acc, step=self.epoch_counter)

            # Reset the training and validation metrics after every epoch
            self.train_acc_metric.reset_states()
            self.train_loss.reset_states()
            self.val_acc_metric.reset_states()
            self.val_loss.reset_states()

            # Early stopping: count epochs since the last validation-accuracy improvement
            wait += 1
            if epoch_val_acc > best:

                self.best_epoch_counter = self.epoch_counter
                best = epoch_val_acc
                wait = 0

                # Save the checkpoints only if the validation accuracy is the best observed
                save_path = self.checkpoint_manager.save()
                logger.info(f'[INFO] Saved checkpoint on epoch {self.epoch_counter} at: {save_path}')

            if wait >= patience:

                logger.info(f'[INFO] Executing early stopping at epoch {self.epoch_counter}, best val_accuracy seen: {best:0.2f} at epoch: {self.best_epoch_counter}')
                break

            self.epoch_counter += 1

    def test(self, test_dataset, steps_per_epoch_test):

        '''
        Evaluate the model's current weights on `steps_per_epoch_test` batches of test_dataset.

        Returns:
            (accuracy, loss) as scalar tensors.
        '''

        test_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        test_loss = tf.keras.metrics.Mean(name='loss')

        for batch in test_dataset.take(steps_per_epoch_test):
            step_loss, logits = self.predict_step(batch)
            _, labels = batch

            test_acc_metric.update_state(labels, logits)
            test_loss.update_state(step_loss)

        accuracy = test_acc_metric.result()
        loss = test_loss.result()

        test_acc_metric.reset_states()
        test_loss.reset_states()

        return (accuracy, loss)

### Define the parameters and paths
---

In [None]:
# Define the folder paths containing the raw, extracted and tfrecords
rawpath = r'../data/raw/pokemon.zip'
datafolder = r'../data/extracted'
tfrfolder = r'../data/tfrecords'

# TFRecord files written by DataExtractor.generate_tfrecords, one per split.
# Derived from `tfrfolder` so the write and read locations cannot drift apart.
train_path = os.path.join(tfrfolder, 'train.tfrecord')
val_path = os.path.join(tfrfolder, 'val.tfrecord')
test_path = os.path.join(tfrfolder, 'test.tfrecord')

# Define the training parameters
RANDOM_STATE = 42
BATCHSIZE = 16
BUFFERSIZE = 250
EPOCHS = 50 # Just set this as high as you like - there are implementations to stop training when appropriate
LEARNING_RATE = 0.001

# Seed TensorFlow's global RNG so weight initialisation and tf.data randomness are reproducible
tf.random.set_seed(RANDOM_STATE)

TRAIN_SIZE = 0.6
TEST_SIZE = 0.2
VAL_SIZE = 0.2
assert abs(TRAIN_SIZE + VAL_SIZE + TEST_SIZE - 1.0) < 1e-9, 'split fractions must sum to 1'

# NOTE(review): total number of images that the per-epoch step counts are based on.
# This was previously an unnamed literal 250; confirm it matches the extracted dataset.
N_IMAGES = 250

n_train = int(TRAIN_SIZE * N_IMAGES)
n_val = int(VAL_SIZE * N_IMAGES)
n_test = int(TEST_SIZE * N_IMAGES)

# Batches per epoch for each split (any trailing partial batch is not counted)
steps_per_epoch_val = n_val // BATCHSIZE
steps_per_epoch_test = n_test // BATCHSIZE
steps_per_epoch_train = n_train // BATCHSIZE

IMG_SIZE = [224, 224]

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Define the model save path
save_path = '../models/'

### Initialise and execute the classes created
---

In [None]:
# Run this cell to reset the files created in previous runs.
# Pure-Python equivalent of `rm -rf`: works on any OS and reuses the configured paths.
for stale_dir in (datafolder, tfrfolder, './logs', './tf_ckpts'):
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# Unzip the raw images, then write one TFRecord file per split (train/val/test)
dataextractor = DataExtractor(
    random_state=RANDOM_STATE,
    test_size=TEST_SIZE,
    val_size=VAL_SIZE,
)
labels = dataextractor.extract_data(rawpath, datafolder)  # class names, one per sub-folder
dataextractor.generate_tfrecords(tfrfolder)

In [None]:
# Build the tf.data pipelines from the TFRecords (augmentation applies to the training split only)
dataloader = DataLoader(img_size=IMG_SIZE, batch_size=BATCHSIZE, buffer_size=BUFFERSIZE, drop_remainder=True, random_seed=RANDOM_STATE)
train_dataset, val_dataset, test_dataset = dataloader.read_tfrecords(train_path, val_path, test_path)

# Log a grid of training images to TensorBoard (the method returns nothing)
dataloader.plot_train_images(steps=steps_per_epoch_train, labels=labels)

In [None]:
# Initialise the model with one output per class discovered during extraction
model = MyModel(n_classes=len(labels))

# Initialise the Model Trainer
model_trainer = ModelTrainer(
                          model=model,
                          train_data=train_dataset,
                          val_data=val_dataset,
                          optimizer=optimizer,
                          epochs=EPOCHS,
                          steps_per_epoch_train=steps_per_epoch_train,
                          steps_per_epoch_val=steps_per_epoch_val,
                          labels=labels
                        )

In [None]:
# Run the custom training loop (it resumes from ./tf_ckpts when a checkpoint exists)
model_trainer.train_and_evaluate()

# NOTE:
# TensorFlow prints a number of warnings and unrelated info messages while training.
# Rather than reading the console output, open TensorBoard to see the metrics
# and visualisations that are logged during training.


### Perform error analysis with the use of TensorBoard
---

While the model is training, we can make use of TensorBoard to view the visualisations and metrics generated.  
This includes the confusion matrix and error analysis plot that will be generated at the end of each epoch.  

To run TensorBoard, run the code in any cell:
```
%tensorboard --logdir logs
```
In vscode, there might be some requirements to download extensions.  
Download any extensions required, and once the code is executed a prompt might appear:  
![image info](../unrelated_imgs/tensorboard_vscode.png)  

Choose the option to 'select another folder', and choose the logs directory.  
It might take some time for the outputs to be generated, as the model needs time to complete the first epoch.  

There will be two tabs, one for the 'scalar', and one for the 'images'.  
For this project, the output will be automatically generated at the end of each epoch and they can be refreshed during training by refreshing Tensorboard.  

#### Scalar tab
![image info](../unrelated_imgs/tensorboard_scalartab.png)  

The scalar tab has been configured to show the train/val accuracy and loss at the end of each epoch.  
This was mostly used to observe overfitting/underfitting and also as a sanity check to ensure that the model can actually learn.  

#### Images tab
![image info](../unrelated_imgs/tensorboard_imagetab.png)  

The images tab is perhaps more important than the scalar tab as this was used to understand the model before and during training.  

#### Training Images
![image info](../unrelated_imgs/tensorboard_trainimg.png)  

The training images plot shows us the images that are used in our model training.  
On closer inspection, a couple of images contain multiple pokemon, and these are mostly labelled as `squirtle`.  
This is likely to affect the model as it might learn to classify images with multiple pokemon features as `squirtle`.  

#### Confusion Matrix  
![image info](../unrelated_imgs/tensorboard_cm.png)  

This was mostly used to observe what the model is actually trying to do, and the training can be terminated early if the model starts becoming 'stupid'.  
Initially there was a code bug which forces the model to only predict all images as 'squirtle', and fortunately this was discovered before wasting unnecessary time on training.  


#### Error Analysis
![image info](../unrelated_imgs/tensorboard_erroranalysis.png)  


This is perhaps the most useful plot generated.  
The plot shows all the misclassified images and the corresponding label predicted.  
This is mainly used to get an understanding on where we might need to improve on our data collection.  
For example, in the image above we can tell that the model is classifying images with multiple pokemon features as `squirtle`.  

### Conclusion
1. Observing the confusion matrix, it appears that the model is doing pretty well at classifying `pikachu` and `mewtwo`.  
2. The model also appears to have the tendency to misclassify some `bulbasaur` and `charmander` as `squirtle`, and this might be due to the inclusion of images  
with all three pokemon labelled as `squirtle` in our training data.
3. Looking at the misclassified images, the model appears to classify all images with multiple pokemon as `squirtle`. This is understandable if we look back at  
the training images, where we observe that the labels for images with multiple pokemon are mostly `squirtle` as well.  

One way to test the hypotheses above is to perhaps visualise the feature maps from the model layers.  
One possible remedy is to remove the images with multiple pokemon from our dataset, and to focus on adding in images of  
`squirtle`, `bulbasaur` and `charmander`.

Due to time constraints, I was not able to perform the above two tasks prior to assignment submission.  
Hyperparameter tuning on the model architecture and model hyperparameters was also not performed.  
Only ResNet50 was used; other feature extractors were not explored.  
So these points may be considered as a 'next steps' or a shortcoming of the project.


In [None]:
# Launch TensorBoard inline, reading the summaries written under ./logs
%tensorboard --logdir logs

### Perform fine tuning 
---
Aside from the points highlighted in the write up above, we can try to improve the model further by either: 
1. Data Augmentation (modify data), and/or
2. Unfreezing some layers of the feature extractor and continue training  

Data augmentation can be enabled by uncommenting the code in the `DataLoader` class and rerunning the whole project.  
The model has also been designed to allow for fine tuning by continuing from the best checkpoint observed.  

Here we will explore unfreezing some of the layers in ResNet50 for training.  

In [None]:
# First look at the overall structure of the model.
# (ResNet50 backbone -> global average pooling -> Dense classifier head)
model.summary()

The complete model consists of a resnet50 feature extractor and a dense layer on top of a global average pooling layer.  
We can view the base model summary by explicitly calling on it  

In [None]:
# Inspect the ResNet50 backbone on its own: layer count, then its full summary
base_layers = model.base.layers
print('Total layers in resnet50: {}'.format(len(base_layers)))
model.base.summary()

The text output might not be easy to understand for ResNet50 with residual blocks.  
As such, we can try to plot the model instead.  

There might be some requirements to install extensions in vscode

In [None]:
# Optional: render the backbone graph as an image. Left commented out because
# plot_model may need extra packages (e.g. pydot / Graphviz) that might not be installed.
# plot_model(model.base)

Unfreeze the layers from conv5_block3_1_conv layer onwards

In [None]:
# Unfreeze from the last residual block onwards, located by name instead of a hard-coded index
UNFREEZE_FROM = 'conv5_block3_1_conv'
base_layer_names = [layer.name for layer in model.base.layers]
unfreeze_idx = base_layer_names.index(UNFREEZE_FROM)
print(model.base.layers[unfreeze_idx].name)

for layer in model.base.layers[unfreeze_idx:]:
    # Keep BatchNormalization frozen: with trainable=False it runs in inference mode, so
    # small fine-tuning batches do not overwrite its ImageNet moving statistics.
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        continue
    layer.trainable = True

Reconfigure the optimizer and continue training from where we left off  
Notice that the training will continue from the previous best checkpoint and not from the last epoch trained

In [None]:
# Fine-tune with a 10x smaller learning rate so the newly unfrozen layers only move gently
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE/10)

# Reuse the existing ModelTrainer and point it at the new optimizer (no new trainer is created).
# NOTE(review): model_trainer.checkpoint was built with the original optimizer, so checkpoints
# saved from here on still hold that optimizer's state, not this one's.
model_trainer.optimizer = optimizer

# Resume training; train_and_evaluate first restores the latest (best) checkpoint
model_trainer.train_and_evaluate()

It is possible to keep retraining as many times as needed, because the early stopping mechanism stops training  
once no improvement has been observed for a number of epochs. Retraining also resumes from the point where model performance was highest,  
not from the last epoch trained.

Visualise with TensorBoard

In [None]:
# Reopen TensorBoard to compare the fine-tuning epochs with the initial run
%tensorboard --logdir logs

### Evaluate the model on the test set
---

In [None]:
# Evaluate the best weights, not the last-trained ones. Early stopping can end training
# several epochs after the best one. Checkpoints are only written on improvement, so the
# latest checkpoint is the best epoch of the most recent train_and_evaluate() call.
best_ckpt = model_trainer.checkpoint_manager.latest_checkpoint
if best_ckpt:
    model_trainer.checkpoint.restore(best_ckpt).expect_partial()

test_accuracy, test_loss = model_trainer.test(test_dataset, steps_per_epoch_test)

print(f'Test Accuracy: {test_accuracy:0.2f}')
print(f'Test Loss: {test_loss:0.2f}')


### Generate a prediction on a single data
---
Keep rerunning the cell to generate predictions on different test images

In [None]:
# Draw one example from the test pipeline (it is reshuffled on each pass, so reruns differ)
X, y = next(test_dataset.take(1).unbatch().as_numpy_iterator())

# predict_step expects a batch, so add a leading batch axis of size 1
batched_data = (X[np.newaxis, ...], y[np.newaxis, ...])
_, logits = model_trainer.predict_step(batched_data)

activated_logits = tf.nn.softmax(logits)
pred = int(np.argmax(activated_logits))

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(X / 255.0)
ax.set_title(f'Actual label: {labels[y]} | Predicted: {labels[pred]}')
ax.set_axis_off()
fig.tight_layout()

# Per-class confidence, highest first (one row per label, whatever the number of classes)
df = pd.DataFrame({'pred_class': labels, 'confidence': activated_logits.numpy().ravel()})
df = df.sort_values(by='confidence', ascending=False).reset_index(drop=True)
df
