<a href="https://colab.research.google.com/github/tcivie/Bone_Marrow_Cells_Classification/blob/main/AI_Project_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Start by converting the jpeg files to TFRecords

---



##Importing necessary modules and libraries
We first import the necessary modules and libraries, such as os, tensorflow, PIL, numpy, and keras. These modules and libraries are used throughout the notebook for various tasks such as reading and manipulating images, file management, and data preprocessing.

In [None]:
import os
import tensorflow as tf
from PIL import Image
import PIL
import re
import numpy as np
import keras
import cv2

##Mounting the Google Drive folder
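We then use the `google.colab.drive` module to mount Google Drive at `/content/drive`. The images are stored in `/content/drive/MyDrive/BM_cytomorphology_data/`, and the TFRecord files will be written to `/content/drive/MyDrive/BM_cytomorphology_data_tf`. Mounting the drive lets the notebook read the dataset and write its outputs. The variable `images_in_tf` sets how many images go into each TFRecord file.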

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
images_folder = '/content/drive/MyDrive/BM_cytomorphology_data/'
tfrecords_folder = '/content/drive/MyDrive/BM_cytomorphology_data_tf'
images_in_tf = 230

##Defining the classes/folders
In this code snippet we define a list of `CLASSES`, which contains the names of all the classes/folders in the dataset. We also define a list of `SMALL_CLASSES`, which contains the names of the underrepresented classes (those with too few images). These lists are used to access the specific folders and perform operations on them.

In [None]:
CLASSES = ["ABE", "ART", "BAS", "BLA", "EBO", "EOS", "FGC", "HAC", "KSC", "LYI", "LYT", "MMZ", "MON", "MYB", "NGB",
              "NGS", "NIF", "OTH", "PEB", "PLM", "PMO"]

In [None]:
SMALL_CLASSES = ['ABE','BAS','FGC','HAC','KSC','LYI','OTH']

##Generating additional data
Here we use the `keras.preprocessing.image.ImageDataGenerator` class to create more data for the `SMALL_CLASSES`. The generator applies random rotation (up to 15 degrees), zoom (up to 20%), and horizontal and vertical flips to the original images; it is also configured to rescale pixel values by 1/255. The augmented images are saved to the same folder as the originals with the prefix "augmented_". Each original image produces at most as many augmented copies as the class has original images, and generation stops once the class folder holds 4040 images. This balances the data by creating more samples for the underrepresented classes.

In [None]:
datagen = keras.preprocessing.image.ImageDataGenerator(rotation_range =15,
                         rescale=1./255,
                         zoom_range=0.2, 
                         horizontal_flip = True,
                         vertical_flip = True,
                         fill_mode = 'nearest') 

for class_name in SMALL_CLASSES:
  class_path = os.path.join(images_folder,class_name)
  images_list = os.listdir(class_path)

  number_of_images = len(images_list)  # original images in this class
  print(number_of_images)
  total_images = number_of_images
  for image in images_list:
    if total_images >= 4040:  # target number of images per class
      break
    i = 0  # augmented copies generated from the current image
    image = os.path.join(class_path,image)
    print(image)
    # load the image as a float array of shape (height, width, channels)
    image = keras.preprocessing.image.image_utils.img_to_array(Image.open(image))
    image = image.reshape((1,) + image.shape)

    for batch in datagen.flow(image, batch_size=1, save_prefix='augmented_', save_to_dir =class_path, save_format='jpg'):
      total_images += 1
      i += 1
      if i == number_of_images or total_images >= 4040:
        break
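The stopping rule of the loop above can be checked without running Keras. The sketch below (the function name `planned_total` is hypothetical) replays the same counting logic: every original image contributes at most `number_of_images` augmented copies, and generation stops at 4040. It suggests that very small classes cannot reach the target: a class with 8 images ends up with only 8 + 8 × 8 = 72 images.

```python
def planned_total(number_of_images: int, target: int = 4040) -> int:
    """Replay the counting logic of the augmentation loop (hypothetical helper).

    Each original image yields at most `number_of_images` augmented copies,
    and generation stops as soon as the folder holds `target` images.
    """
    total = number_of_images
    for _ in range(number_of_images):  # one pass per original image
        if total >= target:
            break
        # the inner Keras loop stops after number_of_images copies or at the target
        total += min(number_of_images, target - total)
    return total


for n in (8, 63, 64, 500):
    print(n, "->", planned_total(n))
```

In other words, a class reaches 4040 images only if n + n² ≥ 4040, that is, if it starts with at least 64 images.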

##Counting the number of files in each folder
We then count the number of files in each folder to check that the data is balanced. This is done by iterating through the list of `CLASSES` and counting the number of files in each folder using the `os.listdir()` function.

In [None]:
for class_name in CLASSES:
  class_path = os.path.join(images_folder,class_name)
  images_list = os.listdir(class_path)
  print(class_name + ":" + str(len(images_list)))

##Renaming the newly generated files
Finally, the newly generated files are renamed to follow the naming scheme of the original ones. For each class, a regular expression finds the original files (named `<CLASS>_<5-digit index>.jpg`) and extracts their indexes; every augmented file is then renamed to `<CLASS>_<next index>.jpg`, continuing from the highest existing index. This ensures that the newly generated files have the same format as the original files and can be used for training the CNN in the same way.

In [None]:
for class_name in SMALL_CLASSES:
  class_path = os.path.join(images_folder,class_name)
  images_list = os.listdir(class_path)

  # Original files are named like "ABE_00001.jpg"
  reg = re.compile(r'[A-Z]{3}_[0-9]{5}')
  previous_files = list(filter(reg.search, images_list))

  # Extract the 5-digit index of every original file and continue after the highest one
  max_index = list()
  for file_name in previous_files:
    max_index.append(int(re.findall(r'[0-9]{5}',file_name)[0]))
  max_index = max(max_index) + 1

  # Every file that does not follow the naming scheme is a newly augmented image
  list_of_files_to_rename = list(filter(lambda x: not reg.search(x), images_list))

  print(class_name,max_index,list_of_files_to_rename)

  for file_name in list_of_files_to_rename:
    os.rename(os.path.join(class_path,file_name),os.path.join(class_path,class_name + '_' + str(max_index).rjust(5,'0') + '.jpg'))
    max_index += 1
  print(class_name,max_index,list_of_files_to_rename)
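The renaming rule can be sketched as a pure function that plans the renames before touching the disk, which makes it easy to inspect first. The name `plan_renames` and the sample file names are hypothetical; the augmented names only mimic the `augmented_` prefix used above. Unlike the cell above, the sketch sorts the augmented names so the new indexes are assigned deterministically.

```python
import re

# Original files look like "ABE_00001.jpg"
ORIGINAL_NAME = re.compile(r'[A-Z]{3}_[0-9]{5}')


def plan_renames(class_name, file_names):
    """Map every augmented file name to the next free "<CLASS>_<index>.jpg" name."""
    originals = [f for f in file_names if ORIGINAL_NAME.search(f)]
    next_index = max(int(re.findall(r'[0-9]{5}', f)[0]) for f in originals) + 1
    renames = {}
    # sorted() makes the assignment deterministic; os.listdir order is arbitrary
    for f in sorted(f for f in file_names if not ORIGINAL_NAME.search(f)):
        renames[f] = f"{class_name}_{next_index:05d}.jpg"
        next_index += 1
    return renames


print(plan_renames("ABE", ["ABE_00001.jpg", "ABE_00007.jpg",
                           "augmented__0_123.jpg", "augmented__1_456.jpg"]))
```

The returned dictionary can then be applied with `os.rename`, after a dry run has confirmed the mapping.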

##Encoding data to TFRecords
The following part is used to encode the images in the dataset into the TFRecords format. The TFRecords format is a binary file format used to store data for TensorFlow.

###Defining helper functions
The first part of the code defines three helper functions:

* `_bytestring_feature`: This function takes a list of bytestrings as an input and returns a `tf.train.Feature` object containing the bytestrings.
* `_int_feature`: This function takes a list of integers as an input and returns a `tf.train.Feature` object containing the integers (stored as int64).
* `_float_feature`: This function takes a list of floats as an input and returns a `tf.train.Feature` object containing the floats (stored as float32).

In [None]:
def _bytestring_feature(list_of_bytestrings):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))

def _int_feature(list_of_ints): # int64
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))

def _float_feature(list_of_floats): # float32
  return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))

###Create example function
The `create_example` function takes an image (the raw JPEG bytes) and its label as input, and returns a `tf.train.Example` object. The function first looks up the index of the label in `CLASSES` and converts it into a one-hot encoded array (all zeros except for a 1 at that index). It then creates a feature dictionary containing the following keys:

* "image": a bytestring feature containing the image data.
* "class": an int feature containing the class number.
* "one_hot_class": a float feature containing the one-hot encoded class label.

In [None]:
def create_example(image, label):
    class_num = np.argmax(np.array(CLASSES)==label) # index of the label in CLASSES, e.g. 'ART' => 1
    one_hot_class = np.eye(len(CLASSES))[class_num]     # e.g. [0, 1, 0, ..., 0] for 'ART'
    feature = {
      "image": _bytestring_feature([image]), # one encoded JPEG image in the list
      "class": _int_feature([class_num]),        # one class index in the list
      "one_hot_class": _float_feature(one_hot_class.tolist()) # variable-length list of floats, n=len(CLASSES)
    }
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
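Outside TensorFlow, the label encoding in `create_example` boils down to an index lookup plus a one-hot list. Below is a minimal pure-Python sketch (the helper `encode_label` is hypothetical). One difference is deliberate: `np.argmax(np.array(CLASSES) == label)` returns 0 (that is, `ABE`) for a label that is not in `CLASSES`, whereas `list.index` raises a `ValueError`, so a misspelled folder name cannot silently become class 0.

```python
CLASSES = ["ABE", "ART", "BAS", "BLA", "EBO", "EOS", "FGC", "HAC", "KSC", "LYI", "LYT",
           "MMZ", "MON", "MYB", "NGB", "NGS", "NIF", "OTH", "PEB", "PLM", "PMO"]


def encode_label(label, classes=CLASSES):
    """Return (class index, one-hot list of floats) for a class name."""
    class_num = classes.index(label)  # raises ValueError for unknown labels
    one_hot = [1.0 if i == class_num else 0.0 for i in range(len(classes))]
    return class_num, one_hot


class_num, one_hot = encode_label("ART")
print(class_num, one_hot[:3], len(one_hot))  # 1 [0.0, 1.0, 0.0] 21
```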

###Writing the data to TFRecords
The rest of the code iterates through the list of `CLASSES`, and for each class:

* It lists the image files in the class folder, using the `re` module to skip file names that consist only of a five-digit index.
* It groups the image file names into chunks of at most `images_in_tf` images.
* It writes each chunk to its own TFRecords file using `tf.io.TFRecordWriter`.

Splitting the images across many files keeps the size of each file manageable.

In [None]:
record_file_number = 0  # running number used in the TFRecord file names

for class_name in CLASSES:
  print(class_name)
  class_path = os.path.join(images_folder,class_name)

  images_list = os.listdir(class_path)
  # Skip files whose names consist only of a 5-digit index (e.g. "00001.jpg")
  regex = re.compile(r'[0-9]{5}\.jpg$')
  images_list = [i for i in images_list if not regex.match(i)]
  print(images_list)

  # Group the file names into chunks of at most images_in_tf images, one chunk per file
  samples = list()
  for i in range(0, len(images_list), images_in_tf):
    samples.append(images_list[i:i+images_in_tf])
  for sample_list in samples:
    # Put the number of records in the file name (e.g. file000_230.tfrec):
    # count_data_items() reads it back later with the regex r"_([0-9]*)\."
    record_path = tfrecords_folder + "/file%.3i_%i.tfrec" % (record_file_number, len(sample_list))
    with tf.io.TFRecordWriter(record_path) as writer:
        print(record_path)
        record_file_number += 1
        for sample in sample_list:
            image_path = os.path.join(class_path,sample)
            with open(image_path,'rb') as f:  # raw JPEG bytes, decoded later by tf.io.decode_jpeg
                image = f.read()
            example = create_example(image, class_name)
            writer.write(example.SerializeToString())
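The grouping step can be checked in isolation. The sketch below (hypothetical helpers `shard` and `record_file_name`) splits a list of file names into chunks of `images_in_tf` = 230 and builds a file name that embeds the record count, which is the convention `count_data_items` relies on later when it reads the count from the file name. A class with 4040 images, for example, is written to 18 files: 17 full files and a last one with 130 images.

```python
def shard(file_names, images_per_file=230):
    """Split file_names into consecutive chunks of at most images_per_file items."""
    return [file_names[i:i + images_per_file]
            for i in range(0, len(file_names), images_per_file)]


def record_file_name(file_number, record_count):
    """Hypothetical naming scheme: zero-padded file number, then the record count."""
    return "file%.3i_%i.tfrec" % (file_number, record_count)


chunks = shard([f"ABE_{i:05d}.jpg" for i in range(4040)])
print(len(chunks), len(chunks[-1]))           # 18 130
print(record_file_name(0, len(chunks[0])))    # file000_230.tfrec
```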

## File uploading to the google data storage
Here we upload the TFRecords files to a Google Cloud Storage bucket.

###Authenticating the user
Here we use the `google.colab.auth` module to authenticate. This is necessary to give the notebook access to the Google Cloud Storage bucket.

In [None]:
from google.colab import auth
auth.authenticate_user()

###Uploading files to the bucket
Here we use the `gsutil` command-line tool (run as a shell command from the notebook) to upload the TFRecords files to the bucket. The `-m` flag performs a parallel upload, which speeds up the process; `cp` copies the files; and `-r` copies recursively, so all files in the specified directory and its subdirectories are uploaded. The source directory is `/content/drive/MyDrive/BM_cytomorphology_data_tf/` and the destination is the bucket `gs://bm_cytomorphology_data_tf/`, so the files end up under `gs://bm_cytomorphology_data_tf/BM_cytomorphology_data_tf/`.

Once uploaded, the TFRecords files are stored in Google Cloud Storage, which is where the training code below reads them from (`DATASET` points to the bucket); any other machine learning task that needs them can access them there as well.

In [None]:
!gsutil -m cp -r /content/drive/MyDrive/BM_cytomorphology_data_tf/ gs://bm_cytomorphology_data_tf/

#Create CNN using Google TPU
This code snippet is used to create and train a convolutional neural network (CNN) using Google TPU and TensorFlow.

##Imports
First, the necessary libraries and modules are imported, including TensorFlow, NumPy, Matplotlib, and Plotly Express. The TensorFlow version is printed so that it can be checked, and `AUTOTUNE` is defined for the `tf.data` pipelines below.

In [None]:
import re, time
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import plotly.express as px
print("Tensorflow version " + tf.__version__)
AUTOTUNE = tf.data.AUTOTUNE

## TPU or GPU detection
The code then detects whether a TPU or GPU is available for training the model. If a TPU is available, it connects to it and creates a `tf.distribute.TPUStrategy` object. If a TPU is not available, it creates a `tf.distribute.MirroredStrategy` object for training on a GPU or multiple GPUs. The number of available accelerators (TPUs or GPUs) is also printed.

In [None]:
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # detect GPUs
    tpu = None # no TPU found; the mixed precision cell below checks this variable
    strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    # strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    # strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

print("Number of accelerators: ", strategy.num_replicas_in_sync)

## Configuration
We then set some configuration variables: the number of training epochs, the image size, the Cloud Storage location of the TFRecords files, and the list of classes.

In [None]:
EPOCHS = 12
IMAGE_SIZE = [250, 250]

DATASET = 'gs://bm_cytomorphology_data_tf/BM_cytomorphology_data_tf/*'
CLASSES = ["ABE", "ART", "BAS", "BLA", "EBO", "EOS", "FGC", "HAC", "KSC", "LYI", "LYT", "MMZ", "MON", "MYB", "NGB",
              "NGS", "NIF", "OTH", "PEB", "PLM", "PMO"]

###Mixed precision
Mixed precision training is a technique that uses lower-precision data types (such as bfloat16) for some model variables and computations to save memory and improve computation speed. If the `MIXED_PRECISION` variable is set to True, the code enables mixed precision training by setting a global policy for TensorFlow.

In [None]:
MIXED_PRECISION = False
if MIXED_PRECISION:
    if tpu: 
        policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
    else:
        policy = tf.keras.mixed_precision.Policy('mixed_float16')
        tf.config.optimizer.set_jit(True) # XLA compilation
    tf.keras.mixed_precision.set_global_policy(policy)
    print('Mixed precision enabled')

###Batch and learning rate settings
The code then sets batch size and learning rate settings based on the number of available accelerators. These settings are used during the training process to control the number of examples processed at a time and the rate at which the model learns from the data.

In [None]:
if strategy.num_replicas_in_sync == 8: # TPU or 8xGPU
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync
    VALIDATION_BATCH_SIZE = 16 * strategy.num_replicas_in_sync
    start_lr = 0.00001
    min_lr = 0.00001
    max_lr = 0.00005 * strategy.num_replicas_in_sync
    rampup_epochs = 5
    sustain_epochs = 0
    exp_decay = .8
elif strategy.num_replicas_in_sync == 1: # single GPU
    BATCH_SIZE = 16
    VALIDATION_BATCH_SIZE = 16
    start_lr = 0.00001
    min_lr = 0.00001
    max_lr = 0.0002
    rampup_epochs = 5
    sustain_epochs = 0
    exp_decay = .8
else: # TPU pod
    BATCH_SIZE = 8 * strategy.num_replicas_in_sync
    VALIDATION_BATCH_SIZE = 8 * strategy.num_replicas_in_sync
    start_lr = 0.00001
    min_lr = 0.00001
    max_lr = 0.00002 * strategy.num_replicas_in_sync
    rampup_epochs = 7
    sustain_epochs = 0
    exp_decay = .8

###Define learning function and plot the learning curve
The code defines a function `lrfn` that takes an epoch number and returns the learning rate for that epoch, based on the starting, minimum, and maximum learning rates, the number of ramp-up and sustain epochs, and an exponential decay rate. During the first `rampup_epochs` epochs, the learning rate increases linearly from the starting learning rate towards the maximum learning rate. For the next `sustain_epochs` epochs, it stays at the maximum value (since `sustain_epochs = 0` in every configuration above, this phase is skipped). From epoch `rampup_epochs + sustain_epochs` onwards, the learning rate decays exponentially towards the minimum learning rate: the distance between the learning rate and `min_lr` is multiplied by `exp_decay` every epoch.

The code also creates a callback function 'lr_callback' which wraps the 'lrfn' function and is used to schedule the learning rate during training. Finally, the code plots the learning curve by generating an array of the learning rates for each epoch and plotting it using matplotlib.
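Written as a formula, `lrfn` computes the following, where $e$ is the epoch index (starting at 0), $\eta_0$ = `start_lr`, $\eta_{\min}$ = `min_lr`, $\eta_{\max}$ = `max_lr`, $R$ = `rampup_epochs`, $S$ = `sustain_epochs`, and $d$ = `exp_decay`. This is a direct restatement of the code:

$$
\eta(e) =
\begin{cases}
\eta_0 + (\eta_{\max} - \eta_0)\,\dfrac{e}{R}, & e < R \\
\eta_{\max}, & R \le e < R + S \\
\eta_{\min} + (\eta_{\max} - \eta_{\min})\, d^{\,e - R - S}, & e \ge R + S
\end{cases}
$$

Because the ramp uses $e/R$ with $e < R$, the last ramp-up epoch is still below $\eta_{\max}$; with $S = 0$ the peak is reached at epoch $R$, the first epoch of the decay phase.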

In [None]:
def lrfn(epoch): # Learning function
    def lr(epoch, start_lr, min_lr, max_lr, rampup_epochs, sustain_epochs, exp_decay):
        if epoch < rampup_epochs:
            lr = (max_lr - start_lr)/rampup_epochs * epoch + start_lr
        elif epoch < rampup_epochs + sustain_epochs:
            lr = max_lr
        else:
            lr = (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr
        return lr
    return lr(epoch, start_lr, min_lr, max_lr, rampup_epochs, sustain_epochs, exp_decay)
    
lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch: lrfn(epoch), verbose=True)

rng = [i for i in range(EPOCHS)]
y = [lrfn(x) for x in rng]
plt.plot(rng, [lrfn(x) for x in rng])
print(y[0], y[-1])
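A standalone restatement of `lrfn`, with the settings of the 8-replica branch passed explicitly (the name `lr_schedule` is hypothetical), makes it easy to inspect the numbers without TensorFlow: the learning rate starts at 1e-5, climbs linearly, peaks at 4e-4 in epoch 5, and then the gap to `min_lr` shrinks by a factor of 0.8 per epoch.

```python
def lr_schedule(epoch, start_lr, min_lr, max_lr, rampup_epochs, sustain_epochs, exp_decay):
    """Standalone restatement of lrfn with every setting passed explicitly."""
    if epoch < rampup_epochs:  # linear warm-up
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:  # plateau (empty when sustain_epochs == 0)
        return max_lr
    # exponential decay of the gap between max_lr and min_lr
    return (max_lr - min_lr) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + min_lr


# Settings of the 8-replica branch (num_replicas_in_sync == 8)
eight_replicas = dict(start_lr=0.00001, min_lr=0.00001, max_lr=0.00005 * 8,
                      rampup_epochs=5, sustain_epochs=0, exp_decay=0.8)

schedule = [lr_schedule(e, **eight_replicas) for e in range(12)]
peak_epoch = max(range(12), key=schedule.__getitem__)
print(peak_epoch, schedule[peak_epoch])
```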

###Display utilities
The code defines several utility functions for displaying images and image data.

The function `dataset_to_numpy_util` takes a dataset and a number of images as input and returns the first N images and labels from the dataset as numpy arrays.

The function `title_from_label_and_target` takes a predicted label and the correct label, converts both from one-hot encoding (or probability vectors) to class numbers, and returns a title string containing the predicted class, whether the prediction is correct, and, if it is not, the correct class, together with the boolean result.

The function `display_one_cell` takes an image, a title, a subplot index, and an optional `red` flag, displays the image in the specified subplot with the specified title (in red when `red` is True), and returns the next subplot index.

The function `display_9_images_from_dataset` takes a dataset as input and displays the first 9 images and labels from the dataset in a 3x3 grid.

The function `display_9_images_with_predictions` takes a list of images, a list of predictions, and a list of labels as input and displays the images in a 3x3 grid with their corresponding predictions and labels.

The function `display_training_curves` plots the per-epoch training and validation values of one metric (for example accuracy or loss) in the given subplot.

In [None]:
def dataset_to_numpy_util(dataset, N):
  # Re-batch so that a single batch holds the first N images, then convert it to numpy
  dataset = dataset.unbatch().batch(N)
  for images, labels in dataset:
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    break
  return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
  label = np.argmax(label, axis=-1)  # one-hot to class number
  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number
  correct = (label == correct_label)
  return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                              CLASSES[correct_label] if not correct else ''), correct

def display_one_cell(image, title, subplot, red=False):
  plt.subplot(subplot)
  plt.axis('off')
  plt.imshow(image)
  plt.title(title, fontsize=16, color='red' if red else 'black')
  return subplot+1  # next subplot index, e.g. 331 -> 332
  
def display_9_images_from_dataset(dataset):
  # print("display_9_images_from_dataset: ")
  # print(dataset)
  subplot=331
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)
  for i, image in enumerate(images):
    title = CLASSES[np.argmax(labels[i], axis=-1)]
    subplot = display_one_cell(image, title, subplot)
    if i >= 8:
      break;
              
  #plt.tight_layout() # bug in tight layout in this version of matplotlib
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_9_images_with_predictions(images, predictions, labels):
  # print("display_9_images_with_predictions: ")
  # print(images)
  subplot=331
  plt.figure(figsize=(13,13))
  for i, image in enumerate(images):
    title, correct = title_from_label_and_target(predictions[i], labels[i])
    subplot = display_one_cell(image, title, subplot, not correct)
    if i >= 8:
      break;
              
  #plt.tight_layout() # bug in tight layout in this version of matplotlib
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    #plt.tight_layout() # bug in tight layout in this version of matplotlib
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  #ax.set_ylim(0.28,1.05)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
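The title logic works the same on softmax outputs as on one-hot vectors, because only the position of the largest value matters. A framework-free sketch with a three-class toy list (the names `argmax` and `title_from_prediction` are hypothetical):

```python
def argmax(values):
    """Index of the largest value (first one on ties, like np.argmax)."""
    return max(range(len(values)), key=values.__getitem__)


def title_from_prediction(classes, prediction, correct_label):
    """Build the '<predicted> [<correct?>, should be <true>]' title used in the plots."""
    label, true_label = argmax(prediction), argmax(correct_label)
    correct = label == true_label
    title = "{} [{}{}{}]".format(classes[label], str(correct),
                                 ', should be ' if not correct else '',
                                 classes[true_label] if not correct else '')
    return title, correct


classes = ["ABE", "ART", "BAS"]
print(title_from_prediction(classes, [0.1, 0.7, 0.2], [0, 0, 1]))  # ('ART [False, should be BAS]', False)
print(title_from_prediction(classes, [0.1, 0.2, 0.7], [0, 0, 1]))  # ('BAS [True]', True)
```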

#Data Loading and Training
We start by importing the `google.colab.auth` module and authenticating the user. This is necessary for reading the TFRecords files from the Google Cloud Storage bucket.

In [None]:
from google.colab import auth
auth.authenticate_user()

##Splitting the data into train and validation sets
We then define a function `count_data_items(filenames)` to count the number of data items in the dataset. The number of data items is written in the name of the .tfrec files.

A validation split of 20% is defined and the filenames are shuffled randomly. Then, the dataset is split into training and validation sets using the validation split defined earlier. We also define the number of steps per epoch for the training dataset and print the number of training and validation images.

In [None]:
def count_data_items(filenames):
    # the number of data items is written in the name of the .tfrec files
    n = [int(re.compile(r"_([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

validation_split = 0.20
filenames = tf.io.gfile.glob(DATASET)

import random
random.shuffle(filenames)

split = len(filenames) - int(len(filenames) * validation_split)
TRAIN_FILENAMES = filenames[:split]
VALID_FILENAMES = filenames[split:]
TRAIN_STEPS = count_data_items(TRAIN_FILENAMES) // BATCH_SIZE
print("TRAINING IMAGES: ", count_data_items(TRAIN_FILENAMES), ", STEPS PER EPOCH: ", TRAIN_STEPS)
print("VALIDATION IMAGES: ", count_data_items(VALID_FILENAMES))
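Two details of this cell are easy to miss. First, `count_data_items` only works if each file name carries its record count right after an underscore; for a name that carries only a file index, such as `file_007.tfrec`, the same regex returns the index instead. Second, the split is done per file, not per image; since every TFRecords file written earlier holds images of a single class, a class stored in only one or two files may end up entirely in the training set or entirely in the validation set. A sketch with hypothetical helper names and sample paths:

```python
import re

COUNT_IN_NAME = re.compile(r"_([0-9]*)\.")  # same pattern as count_data_items


def count_from_name(path):
    """Number read from the first '_<digits>.' in the path."""
    return int(COUNT_IN_NAME.search(path).group(1))


def split_files(filenames, validation_split=0.20):
    """File-level train/validation split, as in the cell above (without shuffling)."""
    split = len(filenames) - int(len(filenames) * validation_split)
    return filenames[:split], filenames[split:]


bucket = "gs://bm_cytomorphology_data_tf/BM_cytomorphology_data_tf/"
print(count_from_name(bucket + "file000_230.tfrec"))  # 230: a record count
print(count_from_name(bucket + "file_007.tfrec"))     # 7: only a file index

train, valid = split_files([f"f{i}" for i in range(10)])
print(len(train), len(valid))  # 8 2
```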

##Reading data from TFRecords
We define a function `read_tfrecord(example)` to read the data from the TFRecords. The function receives a serialized example and parses it using the feature description, whose keys are "image", "class", and "one_hot_class". It then decodes the JPEG image, casts it to float, and scales it to the range [0, 1]. Finally, it returns the image together with the dense one-hot class vector (the integer "class" feature is parsed but not used).

In [None]:
def read_tfrecord(example):
  features = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        "one_hot_class": tf.io.VarLenFeature(tf.float32),
    }
  example = tf.io.parse_single_example(example, features)
  image = tf.io.decode_jpeg(example['image'], channels=3)
  image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
  one_hot_class = tf.sparse.to_dense(example['one_hot_class'])  # VarLenFeature is parsed as a SparseTensor
  one_hot_class = tf.reshape(one_hot_class, [len(CLASSES)])     # give the label a static shape
  return image, one_hot_class

##Resizing images
TPUs require statically known tensor shapes, so the image size must be set explicitly.

We define a function `force_image_sizes(dataset, image_size)` that reshapes the images to the image size defined in the argument.

In [None]:
def force_image_sizes(dataset, image_size):
    print("force_image_sizes:")
    print(dataset)
    reshape_images = lambda image, label: (tf.reshape(image, [*image_size, 3]), label)
    dataset = dataset.map(reshape_images, num_parallel_calls=AUTOTUNE)
    return dataset

##Loading dataset
We define a function `load_dataset(filenames)` that loads the dataset from the filenames passed as an argument. The function reads the dataset using TFRecordDataset and applies the read_tfrecord function to it. Then the dataset is reshaped using the force_image_sizes function.

In [None]:
def load_dataset(filenames):
    print("load_dataset:")
    print(filenames)
    # Allow order-altering optimizations: records may be returned in a
    # non-deterministic order, which speeds up reading.
    opt = tf.data.Options()
    opt.experimental_deterministic = False

    # num_parallel_reads automatically interleaves reads from multiple files
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    dataset = dataset.with_options(opt)  # apply the options to the dataset that is actually used
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    dataset = force_image_sizes(dataset, IMAGE_SIZE)
    print(dataset)
    return dataset

##Data augmentation
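We define a function `data_augment(image, one_hot_class)` that performs data augmentation on the input image. The function applies a random horizontal flip, a random crop to 250×250×3, a random brightness change (up to ±0.5), and a random contrast change (factor between 0.5 and 1.5), and returns the augmented image together with the unchanged one_hot_class. Since the images have already been reshaped to 250×250, the crop leaves them unchanged.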

In [None]:
def data_augment(image, one_hot_class):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, size=[250,250,3])
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    return image, one_hot_class

##Function to create a training dataset
###Loading the training dataset

We first load the training dataset by calling the `load_dataset()` function and passing in the `TRAIN_FILENAMES` variable.
###Data Augmentation

We apply data augmentation by mapping the `data_augment()` function over the dataset with `dataset.map()`. This randomly flips the images and changes their brightness and contrast.
###Repeat and Shuffle

We repeat the dataset indefinitely and shuffle it to ensure randomness in the images during training.
###Batching

We batch the dataset with a batch size of `BATCH_SIZE` so that the model can process multiple images at once.
###Prefetching

We prefetch the next batch of images while the current one is being processed, using `dataset.prefetch(AUTOTUNE)`. This speeds up training by overlapping data loading with model computation.

In [None]:
def get_training_dataset():
    dataset = load_dataset(TRAIN_FILENAMES)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    print("get_training_dataset:")
    print(dataset)
    return dataset

##Function to create a validation dataset
###Loading the validation dataset

We first load the validation dataset by calling the `load_dataset()` function and passing in the `VALID_FILENAMES` variable.
###Batching

We batch the dataset with a batch size of `VALIDATION_BATCH_SIZE` so that the model can process multiple images at once.
###Prefetching

We prefetch the next batch of images while the current one is being processed, using `dataset.prefetch(AUTOTUNE)`. This speeds up validation by overlapping data loading with model computation.
###TPU Sharding

We set the auto-sharding policy to `tf.data.experimental.AutoShardPolicy.DATA`, which disables file-based sharding. This is needed on TPU pods (for example 32-core pods), where there can be fewer validation files than workers.

In [None]:
def get_validation_dataset():
    dataset = load_dataset(VALID_FILENAMES)
    dataset = dataset.batch(VALIDATION_BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    
    # needed for TPU 32-core pod: the test dataset has only 3 files but there are 4 TPUs. FILE sharding policy must be disabled.
    opt = tf.data.Options()
    opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    dataset = dataset.with_options(opt)
    print("get_validation_dataset:")
    print(dataset)
    return dataset

##Creating training and validation datasets
We use the `get_training_dataset()` and `get_validation_dataset()` functions to create the training and validation datasets, respectively. These datasets are built by loading the data from the TFRecords files, applying data augmentation and shuffling (training dataset only), and batching the data. We also prefetch the next batch of data to improve performance.

In [None]:
training_dataset = get_training_dataset()
validation_dataset = get_validation_dataset()

###Displaying a sample of the validation dataset
We use the `display_9_images_from_dataset()` function to display a sample of the validation dataset, allowing us to visually inspect the data.

In [None]:
display_9_images_from_dataset(validation_dataset)

## Model
This cell creates a model based on the DenseNet201 architecture with weights pre-trained on ImageNet (the commented-out lines show other architectures that can be used instead). The pre-trained model is followed by a global average pooling layer and a dense output layer with softmax activation and one unit per class. All layers are trainable, so the whole network is fine-tuned. The model is compiled with the Adam optimizer, categorical crossentropy loss, and the accuracy metric; during training, `lr_callback` sets the learning rate at the start of each epoch.

In [None]:
def create_model():
    #pretrained_model = tf.keras.applications.MobileNetV2(input_shape=[*IMAGE_SIZE, 3], include_top=False)
    #pretrained_model = tf.keras.applications.Xception(input_shape=[*IMAGE_SIZE, 3], include_top=False)
    #pretrained_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])
    #pretrained_model = tf.keras.applications.MobileNet(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    #pretrained_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    # pretrained_model = tf.keras.applications.InceptionV3(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model = tf.keras.applications.DenseNet201(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model.trainable = True

    model = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        #tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax', dtype=tf.float32) # the float32 is needed on softmax layer when using mixed precision
    ])

    model.compile(
        optimizer='adam',
        loss = 'categorical_crossentropy',
        metrics=['accuracy']
    )

    return model
    

In [None]:
with strategy.scope(): # creating the model in the TPUStrategy scope places the model on the TPU
    model = create_model()
model.summary()
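The size of the new classification head follows from simple arithmetic: DenseNet201's last feature map has 1920 channels, global average pooling has no weights, and the softmax layer has one weight per (channel, class) pair plus one bias per class. A quick check of the number `model.summary()` should report for the final `Dense` layer (the helper name is hypothetical):

```python
def dense_params(in_features, units):
    """Weights (in_features x units) plus one bias per unit."""
    return in_features * units + units


DENSENET201_FEATURES = 1920  # channels of DenseNet201's final feature map
NUM_CLASSES = 21             # len(CLASSES)
print(dense_params(DENSENET201_FEATURES, NUM_CLASSES))  # 40341
```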

## Training
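This cell trains the model created in the distribution strategy scope, using the training and validation datasets and the learning rate callback. It reports the mean validation accuracy over the last 5 epochs and the total training time; the next cell plots the training curves for accuracy and loss (skipping the first epoch).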

In [None]:
start_time = time.time()
history = model.fit(training_dataset, validation_data=validation_dataset,
                    steps_per_epoch=TRAIN_STEPS, epochs=EPOCHS, callbacks=[lr_callback])

final_accuracy = history.history["val_accuracy"][-5:]
print("FINAL ACCURACY MEAN-5: ", np.mean(final_accuracy))
print("TRAINING TIME: ", time.time() - start_time, " sec")

In [None]:
print(history.history.keys())
display_training_curves(history.history['accuracy'][1:], history.history['val_accuracy'][1:], 'accuracy', 211)
display_training_curves(history.history['loss'][1:], history.history['val_loss'][1:], 'loss', 212)

## Predictions
These cells take the first 160 images of the validation dataset, shuffle their order (so that re-running the cell shows different images), and predict their classes. They also evaluate the model on these 160 images and display 9 of them with their predicted and true labels.

In [None]:
# a couple of images to test predictions too
some_cells, some_labels = dataset_to_numpy_util(validation_dataset, 160)

In [None]:
# randomize the input so that you can execute multiple times to change results
permutation = np.random.permutation(len(some_cells))  # one index per loaded image (160)
some_cells, some_labels = (some_cells[permutation], some_labels[permutation])

predictions = model.predict(some_cells, batch_size=16)
evaluations = model.evaluate(some_cells, some_labels, batch_size=16)
  
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())
print('[val_loss, val_acc]', evaluations)

display_9_images_with_predictions(some_cells, predictions, some_labels)
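The accuracy printed by `model.evaluate` here is categorical accuracy: a prediction counts as correct when the index of its largest probability matches the index of the 1 in the one-hot label. A pure-Python sketch of that computation (the function name and toy data are hypothetical):

```python
def categorical_accuracy(predictions, one_hot_labels):
    """Fraction of rows whose argmax prediction equals the argmax of the label."""
    def argmax(row):
        return max(range(len(row)), key=row.__getitem__)

    hits = sum(argmax(p) == argmax(y) for p, y in zip(predictions, one_hot_labels))
    return hits / len(one_hot_labels)


predictions = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1], [0.2, 0.2, 0.6]]
labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]
print(categorical_accuracy(predictions, labels))  # 0.75
```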

## Save the model
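These cells authenticate the user and save the model to a Google Cloud Storage bucket. The model is first saved locally in HDF5 format as `DenseNet201.h5`, then uploaded to `gs://bm_cytomorphology_data_tf/models/` with `gsutil`.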

In [None]:
from google.colab import auth
auth.authenticate_user()

In [None]:
import os
SAVE_PATH = 'gs://bm_cytomorphology_data_tf/models/'
MODEL_NAME = 'DenseNet201.h5'
model.save(MODEL_NAME)

In [None]:
!gsutil -m cp -r /content/DenseNet201.h5 gs://bm_cytomorphology_data_tf/models/

## Reload the model
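This cell reloads the saved model from the bucket and repeats the prediction and evaluation on the same sample of validation images, to check that the saved model behaves like the trained one.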

In [None]:
reload_model = tf.keras.models.load_model(os.path.join(SAVE_PATH,MODEL_NAME))

predictions = reload_model.predict(some_cells, batch_size=16)
evaluations = reload_model.evaluate(some_cells, some_labels, batch_size=16)
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())
print('[val_loss, val_acc]', evaluations)
display_9_images_with_predictions(some_cells, predictions, some_labels)