The aim of this kernel is to create a baseline for multi-label classification in the [Human Protein Atlas - Single Cell Classification](https://www.kaggle.com/c/hpa-single-cell-image-classification) challenge, using TensorFlow on a TPUv3-8 (refer to [this link](https://www.kaggle.com/docs/tpu) for detailed documentation on using TPUv3-8 in Kaggle Kernels).

## Dependencies

In [None]:
!pip install -qq visualkeras

In [None]:
import os
import visualkeras
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
from PIL import ImageFont
from typing import List, Tuple
from collections import Counter
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
from kaggle_datasets import KaggleDatasets

## Basic Exploratory Analysis

In [None]:
# Human-readable names of the competition's label indices: LABEL_NAMES[i] is
# the name of label ``i`` as it appears in the 'Label' column of train.csv.
# Refer https://www.kaggle.com/c/hpa-single-cell-image-classification/data

LABEL_NAMES = [
    "Nucleoplasm",
    "Nuclear Membrane",
    "Nucleoli",
    "Nucleoli Fibrillar Center",
    "Nuclear Speckles",
    "Nuclear Bodies",
    "Endoplasmic Reticulum",
    "Golgi Apparatus",
    "Intermediate Filaments",
    "Actin Filaments",
    "Microtubules",
    "Mitotic Spindle",
    "Centrosome",
    "Plasma Membrane",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Vesicles",
    "Negative",
]

In [None]:
# PNG files are globbed from the local Kaggle input mount, while the TFRecord
# shards are read from the GCS bucket backing the same dataset (TPUs read
# their input data from GCS).
GCS_PATH = KaggleDatasets().get_gcs_path('hpa-single-cell-image-classification')

TRAIN_IMAGE_FILES = glob('../input/hpa-single-cell-image-classification/train/*.png')
TEST_IMAGE_FILES = glob('../input/hpa-single-cell-image-classification/test/*.png')

TRAIN_TFRECORDS = tf.io.gfile.glob(os.path.join(GCS_PATH, 'train_tfrecords/*.tfrec'))
# The test shards live in 'test_tfrecords/'. Globbing 'train_tfrecords/' here
# made TEST_TFRECORDS a duplicate of TRAIN_TFRECORDS, so any "test" metric was
# really measured on the training data.
# NOTE(review): TFRecordLoader parses a 'target' feature; confirm the test
# shards carry one before evaluating on them (competition test labels may not
# be included).
TEST_TFRECORDS = tf.io.gfile.glob(os.path.join(GCS_PATH, 'test_tfrecords/*.tfrec'))

print('Number of Train Images:', len(TRAIN_IMAGE_FILES))
print('Number of Test Images:', len(TEST_IMAGE_FILES))
print('Number of Train TFRecord Files:', len(TRAIN_TFRECORDS))
print('Number of Test TFRecord Files:', len(TEST_TFRECORDS))

In [None]:
# Training metadata: the 'Label' column holds '|'-separated label indices.
dataframe = pd.read_csv(
    '../input/hpa-single-cell-image-classification/train.csv'
)
dataframe.head()

The dataframe shows that a single image can carry multiple labels, separated by `|`. Next, we check the distribution of classes across the dataset, and we use a simple class-weighting strategy to compensate for the class imbalance.

In [None]:
# Ref: https://www.kaggle.com/dschettler8845/hpa-xai-ig-tfrecords-tpu-training

# How often each label index occurs across all images. Keys are the label
# strings from the CSV, so an image labelled "1|4" adds one to "1" and to "4".
LABEL_COUNTS = Counter(
    single_label
    for image_labels in dataframe['Label'].str.split('|')
    for single_label in image_labels
)

# Classes whose count is at most the 4th-smallest count keep weight 1.0;
# more frequent classes are down-weighted by threshold / count.
threshold = sorted(LABEL_COUNTS.values())[3]
CLASS_WEIGHTS = {
    int(label_str): min(1.0, threshold / occurrences)
    for label_str, occurrences in LABEL_COUNTS.items()
}

# Visualization of label imbalance and class weights, both indexed by class
# 0..len(LABEL_NAMES)-1. Weights (at most 1.0) are multiplied by 2000,
# presumably so they are visible next to the raw counts on the same axis.
fig = go.Figure(data=[
    go.Bar(
        name='Class Distributions',
        x=[str(i) for i in range(len(LABEL_NAMES))],
        y=[LABEL_COUNTS[str(i)] for i in range(len(LABEL_NAMES))]
    ),
    go.Bar(
        name='Class Weights',
        x=[str(i) for i in range(len(LABEL_NAMES))],
        # Look weights up by class index: CLASS_WEIGHTS is in first-seen order
        # from the CSV, so iterating its keys would put bars on wrong classes.
        y=[CLASS_WEIGHTS.get(i, 0.0) * 2000 for i in range(len(LABEL_NAMES))]
    )
])
fig.update_layout(barmode='stack', uniformtext_minsize=8, uniformtext_mode='hide')
fig.update_traces(textposition='outside')
fig.show()

## Tensorflow Dataloader from TFRecords

Since TPUs are very fast, many models ported to TPU end up with a data bottleneck: the TPU sits idle, waiting for data, for most of each training epoch. TPUs read training data exclusively from GCS buckets, and data for TPU training typically comes sharded across an appropriate number of larger TFRecord files. We therefore create a TFRecord data-loader class to read the data from the TFRecord files.

In [None]:
class TFRecordLoader:
    """Build a ``tf.data`` pipeline of multi-channel images from TFRecord shards.

    Each serialized example is expected to hold one PNG (feature ``image``,
    decoded as a single channel), its file name (``image_name``) and a
    ``'|'``-separated label string (``target``). Records are split into
    red/green/blue/yellow streams by matching the colour in ``image_name`` and
    zipped back together into ``(image, multi_hot_label)`` pairs.

    Args:
        image_size: ``[height, width]`` every channel is resized to.
        n_classes: Length of the multi-hot label vector.
        include_yellow_channel: Stack the yellow channel as a 4th channel.
    """

    def __init__(self, image_size: List[int], n_classes: int, include_yellow_channel: bool):
        self.image_size = image_size
        self.n_classes = n_classes
        self.include_yellow_channel = include_yellow_channel

    def _parse_image(self, image):
        """Decode a PNG byte string into a resized float32 ``(H, W, 1)`` tensor."""
        image = tf.image.decode_png(image, channels=1)
        image = tf.image.resize(image, self.image_size)
        image = tf.cast(image, dtype=tf.float32)
        return image

    def _parse_label(self, label):
        """Convert a label string such as ``b"1|4"`` into a multi-hot vector."""
        indices = tf.strings.to_number(
            tf.strings.split(label, sep='|'),
            out_type=tf.int32
        )
        # one_hot gives one row per label index; the max over those rows
        # merges them into a single multi-hot vector of length n_classes.
        return tf.reduce_max(
            tf.one_hot(indices, depth=self.n_classes), axis=-2
        )

    def _make_example(self, example):
        """Parse one serialized example into ``(image, image_name, label)``."""
        feature_format = {
            'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
            'image_name': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
            'target': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
        }
        features = tf.io.parse_single_example(example, features=feature_format)
        image = self._parse_image(features['image'])
        image_name = features['image_name']
        label = self._parse_label(features['target'])
        return image, image_name, label

    def _combine_channels(self, red, green, blue, yellow):
        """Stack the per-colour images of one sample into ``(image, label)``.

        Each argument is an ``(image, image_name, label)`` triple produced by
        ``_make_example``; the label of the red record is used for the sample.
        """
        # Ref: https://www.kaggle.com/dschettler8845/hpa-xai-ig-tfrecords-tpu-training
        red_image, _, label = red
        channels = [red_image[..., 0], green[0][..., 0], blue[0][..., 0]]
        if self.include_yellow_channel:
            channels.append(yellow[0][..., 0])
        return tf.stack(channels, axis=-1), label

    def _preprocess(self, dataset):
        """Zip the red/green/blue/yellow records of each image into one sample.

        The dataset is filtered once per colour and the four streams are
        zipped by position. This assumes every image contributes exactly one
        record per colour and that those records appear in the same relative
        order in every stream.
        """
        # Ref: https://www.kaggle.com/dschettler8845/hpa-xai-ig-tfrecords-tpu-training
        def colour_stream(colour):
            pattern = ".*{}.*".format(colour)
            return dataset.filter(
                lambda image, image_name, label: tf.strings.regex_full_match(image_name, pattern)
            )

        streams = tuple(
            colour_stream(colour) for colour in ('red', 'green', 'blue', 'yellow')
        )
        combined = tf.data.Dataset.zip(streams)
        return combined.map(
            map_func=self._combine_channels,
            num_parallel_calls=tf.data.AUTOTUNE
        )

    def get_dataset(self, train_tfrecord_files: List[str], ignore_order: bool = False):
        """Return a dataset of ``(image, multi_hot_label)`` read from the shards.

        Args:
            train_tfrecord_files: TFRecord shard paths to read (any split).
            ignore_order: Accepted for backward compatibility; it no longer
                changes anything. Each colour stream built in ``_preprocess``
                iterates the shards independently and the streams are zipped
                by position. Options set with ``with_options`` apply to the
                whole pipeline, so a non-deterministic order could pair the
                red, green, blue and yellow records (and the label) of
                different images. Read order is therefore always kept
                deterministic; shuffle downstream instead.
        """
        del ignore_order  # intentionally unused, see docstring
        dataset = tf.data.TFRecordDataset(
            train_tfrecord_files, num_parallel_reads=tf.data.AUTOTUNE)
        dataset = dataset.map(
            map_func=self._make_example, num_parallel_calls=tf.data.AUTOTUNE)
        return self._preprocess(dataset)

In [None]:
def plot_result(
    images, captions: List[str], title: str, figsize: Tuple[int, int]):
    fig = plt.figure(figsize=figsize)
    plt.suptitle(
        'Label: ' + title[0], fontsize=20, fontweight='bold')
    for index in range(len(captions)):
        fig.add_subplot(
            1, len(captions), index + 1
        ).set_title(captions[index])
        _ = plt.imshow(images[index])
        plt.axis(False)
    plt.show()

In [None]:
# Preview four 512x512 training samples: each of the three stacked channels
# separately, then the stacked composite itself.
loader = TFRecordLoader(
    image_size=[512, 512], n_classes=19, include_yellow_channel=False
)
dataset = loader.get_dataset(TRAIN_TFRECORDS)

for image, label in dataset.take(4):
    active_label_ids = np.where(label.numpy() == 1)[0]
    plot_result(
        [image[..., channel] for channel in range(3)] + [image],
        ['red channel', 'green channel', 'blue channel', 'combined image'],
        [LABEL_NAMES[label_id] for label_id in active_label_ids],
        (20, 6),
    )

In [None]:
# Same preview with the yellow channel stacked as a 4th channel; all four
# channels are shown individually.
loader = TFRecordLoader(
    image_size=[512, 512], n_classes=19, include_yellow_channel=True
)
dataset = loader.get_dataset(TRAIN_TFRECORDS)

for image, label in dataset.take(4):
    active_label_ids = np.where(label.numpy() == 1)[0]
    plot_result(
        [image[..., channel] for channel in range(4)],
        ['red channel', 'green channel', 'blue channel', 'yellow channel'],
        [LABEL_NAMES[label_id] for label_id in active_label_ids],
        (20, 6),
    )

## Data Augmentation

In [None]:
class AugmentationFactory:
    """Random geometric and colour augmentations for an ``(image, label)`` dataset.

    Args:
        include_flips: Randomly flip left-right and, independently, up-down.
        include_rotation: Rotate by a random multiple of 90 degrees.
        include_jitter: Randomly perturb saturation, brightness and contrast.
            ``tf.image`` saturation ops require exactly 3 channels, so this
            presumably only works on RGB images (no yellow channel).
    """

    def __init__(self, include_flips: bool, include_rotation: bool, include_jitter: bool):
        self.include_flips = include_flips
        self.include_rotation = include_rotation
        self.include_jitter = include_jitter

    @staticmethod
    def _flip_horizontal(image, seed):
        """Flip left-right with probability 0.5, decided by the stateless *seed*."""
        image = tf.image.stateless_random_flip_left_right(image, seed)
        return image

    @staticmethod
    def _flip_vertical(image, seed):
        """Flip up-down with probability 0.5, decided by the stateless *seed*."""
        image = tf.image.stateless_random_flip_up_down(image, seed)
        return image

    @staticmethod
    def _rotate(image):
        """Rotate by k * 90 degrees, with k drawn uniformly from {0, 1, 2, 3}."""
        rotation_k = tf.random.uniform((1,), minval=0, maxval=4, dtype=tf.int32)[0]
        image = tf.image.rot90(image, k=rotation_k)
        return image

    @staticmethod
    def _random_jitter(image, seed):
        """Randomly jitter saturation, brightness and contrast.

        Each op gets its own sub-seed split from *seed*. Reusing one seed
        would draw the same uniform sample for every op and tie the three
        perturbations together.
        """
        saturation_seed, brightness_seed, contrast_seed = tf.unstack(
            tf.random.experimental.stateless_split(seed, num=3)
        )
        image = tf.image.stateless_random_saturation(image, 0.9, 1.1, saturation_seed)
        image = tf.image.stateless_random_brightness(image, 0.075, brightness_seed)
        image = tf.image.stateless_random_contrast(image, 0.9, 1.1, contrast_seed)
        return image

    def _map_augmentations(self, image, label):
        # One independent [2]-shaped seed per random op. With a single shared
        # seed both stateless flips drew the same sample and made the same
        # decision, so an image was either untouched or rotated 180 degrees
        # and never flipped along just one axis.
        seeds = tf.random.uniform((3, 2), minval=0, maxval=tf.int32.max, dtype=tf.int32)
        if self.include_flips:
            image = self._flip_horizontal(image=image, seed=seeds[0])
            image = self._flip_vertical(image=image, seed=seeds[1])
        if self.include_rotation:
            image = self._rotate(image=image)
        if self.include_jitter:
            image = self._random_jitter(image=image, seed=seeds[2])
        return image, label

    def augment_dataset(self, dataset):
        """Apply the configured augmentations to every ``(image, label)`` element."""
        return dataset.map(
            map_func=self._map_augmentations,
            num_parallel_calls=tf.data.AUTOTUNE
        )

In [None]:
# Preview the full augmentation pipeline (flips, rotation, colour jitter) on
# four 512x512 RGB training samples.
loader = TFRecordLoader(
    image_size=[512, 512], n_classes=19, include_yellow_channel=False
)
dataset = loader.get_dataset(TRAIN_TFRECORDS)

augmentation_factory = AugmentationFactory(
    include_flips=True, include_rotation=True, include_jitter=True
)
dataset = augmentation_factory.augment_dataset(dataset)

for image, label in dataset.take(4):
    active_label_names = [
        LABEL_NAMES[label_id] for label_id in np.where(label.numpy() == 1)[0]
    ]
    plot_result(
        images=[image[..., 0], image[..., 1], image[..., 2], image],
        captions=['red channel', 'green channel', 'blue channel', 'Combined Image'],
        title=active_label_names,
        figsize=(20, 6),
    )

## Simple Tensorflow Model

In [None]:
def get_backbone(backbone_name: str, input_shape: List[int], weights: str = 'imagenet'):
    """Instantiate a headless Keras EfficientNet backbone selected by name.

    Args:
        backbone_name: A name containing "efficientnet" and a variant tag
            "b0".."b7", e.g. "efficientnetb0". Tags are checked in order b0 to
            b7 and the first match wins.
        input_shape: Input shape forwarded to the Keras application.
        weights: Weights argument forwarded to the Keras application
            (default 'imagenet').

    Returns:
        The ``tf.keras.applications.EfficientNetB<n>`` model built with
        ``include_top=False``.

    Raises:
        ValueError: If the name is not a supported EfficientNet variant. This
            previously surfaced as "'NoneType' object is not callable".
    """
    backbone_class = None
    if 'efficientnet' in backbone_name:
        for variant in range(8):
            if 'b{}'.format(variant) in backbone_name:
                backbone_class = getattr(
                    tf.keras.applications, 'EfficientNetB{}'.format(variant)
                )
                break
    if backbone_class is None:
        raise ValueError(
            'Unsupported backbone {!r}: expected an EfficientNet name with a '
            'variant tag b0-b7, e.g. "efficientnetb0".'.format(backbone_name)
        )
    return backbone_class(
        include_top=False, weights=weights, input_shape=input_shape
    )


In [None]:
# Ref: https://www.kaggle.com/ayuraj/hpa-multi-label-classification-with-tf-and-w-b

def simple_model(input_shape: List[int], backbone_name: str, dropout: float, n_classes: int):
    """Build a multi-label classifier on top of an ImageNet-pretrained backbone.

    Architecture: backbone -> global average pooling -> dropout (only when
    ``dropout > 0``) -> dense layer with one sigmoid output per class.

    Args:
        input_shape: Model input shape, e.g. ``[224, 224, 3]``.
        backbone_name: Backbone name understood by ``get_backbone``.
        dropout: Dropout rate before the classifier; 0 disables dropout.
        n_classes: Number of output classes.

    Returns:
        A ``tf.keras.Model`` named ``"<backbone_name>_transfer_learning_model"``.
    """
    backbone = get_backbone(
        backbone_name=backbone_name, weights='imagenet', input_shape=input_shape
    )
    backbone.trainable = True
    input_tensor = tf.keras.Input(input_shape, name='inputs')
    # Let Keras propagate the training flag (True in fit, False in
    # evaluate/predict). Hard-coding training=True kept BatchNorm on per-batch
    # statistics at inference time, so predictions depended on the batch.
    backbone_features = backbone(input_tensor)
    x = tf.keras.layers.GlobalAveragePooling2D(name='global_average_pool_2d')(backbone_features)
    if dropout > 0:
        x = tf.keras.layers.Dropout(rate=dropout, name='dropout_{}'.format(dropout))(x)
    # Targets are multi-hot (an image can have several classes), so each class
    # needs an independent probability. Sigmoid pairs with BinaryCrossentropy;
    # softmax would force the class scores to sum to 1.
    output_tensor = tf.keras.layers.Dense(n_classes, activation='sigmoid', name='outputs')(x)
    return tf.keras.Model(
        input_tensor, output_tensor,
        name='{}_transfer_learning_model'.format(backbone_name)
    )

In [None]:
def get_strategy():
    """Return a TPU strategy when a TPU is attached, else a MirroredStrategy.

    ``TPUClusterResolver`` raises ``ValueError`` when no TPU is found; that is
    the only error treated as "no TPU". Failures while connecting to or
    initializing a detected TPU propagate instead of silently falling back.
    """
    try:  # detect TPUs
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:  # no TPU: use all local GPUs (or the CPU)
        strategy = tf.distribute.MirroredStrategy()
    else:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # Stable replacement for the deprecated
        # tf.distribute.experimental.TPUStrategy.
        strategy = tf.distribute.TPUStrategy(tpu)
    print("Number of accelerators: ", strategy.num_replicas_in_sync)
    return strategy

In [None]:
strategy = get_strategy()

# Create the model and compile it (which creates the optimizer and metric
# objects) inside the strategy scope, so their variables are created under the
# distribution strategy, as the TensorFlow distribution/TPU guides recommend.
with strategy.scope():
    model = simple_model(
        input_shape=[224, 224, 3], backbone_name='efficientnetb0',
        dropout=0.5, n_classes=len(LABEL_NAMES)
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.0),
        metrics=[tf.keras.metrics.AUC(multi_label=True)]
    )
model.summary()

In [None]:
# Render a layered diagram of the model's architecture with visualkeras
# (spacing=100 is visualkeras' spacing between drawn layers).
visualkeras.layered_view(model, spacing=100)

For the baseline, I created a simple model that performs transfer learning from EfficientNet backbones pre-trained on ImageNet. Unlike the baseline kernel by [Ayush Thakur](https://www.kaggle.com/ayuraj), which used Sigmoid Focal Cross-Entropy loss, I used plain Binary Cross-Entropy. Focal loss seems to converge faster, so to achieve similar results with Binary Cross-Entropy I experimented with a custom learning-rate scheduling strategy, which we discuss shortly.

## Training

### Configure Dataset

In [None]:
def configure_dataset(augmented_dataset, shuffle_buffer: int = 128, batch_size: int = 16,
                      repeat: bool = False):
    """Shuffle, optionally repeat, batch and prefetch a ``tf.data`` dataset.

    Args:
        augmented_dataset: Source dataset of ``(image, label)`` elements.
        shuffle_buffer: Shuffle buffer size.
        batch_size: Elements per batch.
        repeat: Repeat the data indefinitely. Defaults to False, which matches
            what this function effectively did before: it called ``repeat()``
            but then shuffled the original dataset, discarding the repeat. An
            infinite dataset needs ``steps_per_epoch`` / ``steps`` in Keras
            ``fit`` / ``evaluate``.

    Returns:
        The configured dataset.
    """
    dataset = augmented_dataset.shuffle(shuffle_buffer)
    if repeat:
        # Repeat after shuffling so each pass over the data is a full epoch.
        dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Training pipeline: 224x224 RGB samples (no yellow channel) -> random
# augmentations -> shuffle/batch/prefetch with a global batch of 64 per replica.
# NOTE(review): TFRecordLoader pairs colour channels by position; confirm that
# ignore_order=True cannot reorder the records it zips together.
loader = TFRecordLoader(
    image_size=[224, 224], n_classes=19, include_yellow_channel=False
)
augmentation_factory = AugmentationFactory(
    include_flips=True, include_rotation=True, include_jitter=True
)
dataset = augmentation_factory.augment_dataset(
    loader.get_dataset(TRAIN_TFRECORDS, ignore_order=True)
)

BATCH_SIZE = 64 * strategy.num_replicas_in_sync
train_dataset = configure_dataset(dataset, batch_size=BATCH_SIZE)
train_dataset

### Custom Learning Rate Scheduling

In [None]:
def custom_lr_scheduler(epoch, warmup_epochs=3, sustain_epochs=2, initial_lr=1e-5, max_lr=1e-4, epsilon=0.9):
    """Return the learning rate for epoch index *epoch*: warm up, hold, decay.

    * ``epoch < warmup_epochs``: linear ramp from ``initial_lr`` toward
      ``max_lr``.
    * next ``sustain_epochs`` epochs: constant ``max_lr``.
    * afterwards: ``(max_lr - initial_lr) * epsilon ** k + initial_lr``, where
      ``k`` counts epochs since the hold phase ended, so the decay starts at
      ``max_lr`` and approaches ``initial_lr``. The exponent previously
      counted from the end of warm-up only, so the rate dropped abruptly by a
      factor ``epsilon ** sustain_epochs`` (of the range) when the hold ended.

    Pass only the epoch positionally. ``tf.keras.callbacks.LearningRateScheduler``
    calls ``schedule(epoch, lr)``, which would bind the current learning rate
    to ``warmup_epochs``; wrap it, e.g.
    ``lambda epoch, lr: custom_lr_scheduler(epoch)``.
    """
    decay_start = warmup_epochs + sustain_epochs
    if epoch < warmup_epochs:
        return ((max_lr - initial_lr) / warmup_epochs * epoch) + initial_lr
    if epoch < decay_start:
        return max_lr
    return ((max_lr - initial_lr) * epsilon ** (epoch - decay_start)) + initial_lr

The custom learning-rate scheduling strategy was inspired by the kernel [HPA - XAI & IG [TFRECORDS+TPU][TRAINING]](https://www.kaggle.com/dschettler8845/hpa-xai-ig-tfrecords-tpu-training) by [Darien Schettler](https://www.kaggle.com/dschettler8845). The idea is to increase the learning rate linearly during the initial warm-up epochs, hold it at its maximum for a few epochs, and then decay it exponentially. The warm-up and the sustained high learning rate help the model converge faster, while the subsequent exponential decay helps to avoid overfitting, especially when training for a large number of epochs.

In [None]:
# Visualise the learning rate produced by custom_lr_scheduler for epochs 1-15.
plotted_epochs = list(range(1, 16))
scheduled_lrs = list(map(custom_lr_scheduler, plotted_epochs))
fig = go.Figure(data=go.Scatter(x=plotted_epochs, y=scheduled_lrs))
fig.update_layout(title='Custom LR Scheduling Policy')
fig.show()

### Training the Model

In [None]:
# Training callbacks.
callbacks = [
    # LearningRateScheduler calls schedule(epoch, current_lr). Passing
    # custom_lr_scheduler directly would bind current_lr to its second
    # parameter (warmup_epochs), so only the epoch index is forwarded.
    tf.keras.callbacks.LearningRateScheduler(
        lambda epoch, lr: custom_lr_scheduler(epoch)
    ),
    tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10, restore_best_weights=True),
    # No ReduceLROnPlateau: the scheduler above resets the learning rate at
    # the start of every epoch, which would override any reduction it made.
]

In [None]:
EPOCHS = 150

# Pass the callbacks list (learning-rate schedule and early stopping); without
# it, fit() trained every epoch at the Adam optimizer's default learning rate
# and never stopped early.
# NOTE(review): Keras applies class_weight as one weight per sample; with
# multi-hot targets it presumably keys on a single (argmax) class — confirm
# this is the intended weighting for multi-label data.
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    class_weight=CLASS_WEIGHTS,
    callbacks=callbacks,
)

### Save Model

In [None]:
# Save the fine-tuned model. experimental_io_device routes file-system access
# through the local host ('/job:localhost'), so the model can be written to a
# local directory when running in a distributed (TPU) setting.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save('./weights/efficientnetb0_fine_tuned_classification', options=save_options)

### Training History

In [None]:
# Plot the per-epoch training loss (top) and AUC (bottom) recorded by fit().
# Training may stop before EPOCHS (e.g. via an EarlyStopping callback), so the
# x-axis is sized from the recorded history instead of EPOCHS.
epochs_run = list(range(len(history.history['loss'])))
# The AUC metric is logged as 'auc' by default; Keras can add a numeric suffix
# (e.g. 'auc_1') when another AUC metric was created earlier in the session.
auc_key = next(key for key in history.history if key.startswith('auc'))

fig = make_subplots(rows=2, cols=1)
fig.add_trace(
    go.Scatter(x=epochs_run, y=history.history['loss'], name='loss'),
    row=1, col=1
)
fig.add_trace(
    go.Scatter(x=epochs_run, y=history.history[auc_key], name='AUC'),
    row=2, col=1
)
fig.update_layout(title='Training History')
fig.show()

### Evaluating the Model

In [None]:
# Evaluate on the TEST_TFRECORDS shards.
# NOTE(review): the numbers below are a held-out estimate only if
# TEST_TFRECORDS points at labelled shards that are disjoint from training.
loader = TFRecordLoader(
    image_size=[224, 224], n_classes=19, include_yellow_channel=False
)
dataset = loader.get_dataset(TEST_TFRECORDS)
# Evaluation needs no shuffling: batch and prefetch only.
test_dataset = dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

test_loss, test_auc = model.evaluate(test_dataset)

print('Loss on Test Data:', test_loss)
print('AUC on Test Data:', test_auc)