## About

In this notebook, I'll train a model based on EfficientNetB0, GeM pooling, and ArcFace. 

The whole training pipeline is built with TensorFlow, and training runs on a TPU.

This notebook is based on [stratified-tfrecords-training-pipeline](https://www.kaggle.com/ks2019/stratified-tfrecords-training-pipeline).

In [None]:
!pip install efficientnet tensorflow_addons > /dev/null

In [None]:
import os
import math
import random
import re
import warnings
from pathlib import Path
from typing import Optional, Tuple

import efficientnet.tfkeras as efn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import KFold

In [None]:
# Show the installed TensorFlow version in the cell output.
tf.__version__

## Config

In [None]:
# Number of splits handed to KFold in the training section.
NUM_FOLDS = 4
# Side length images are resized to; also the model's square input size.
IMAGE_SIZE = 256
# Per-replica batch size; load_dataset multiplies it by the replica count.
BATCH_SIZE = 64
# Selects the backbone variant EfficientNetB{EFFICIENTNET_SIZE}.
EFFICIENTNET_SIZE = 0
# Value of the `weights` argument passed to the EfficientNet constructor.
WEIGHTS = "imagenet"
# Number of classes of the ArcFace head.
N_CLASSES = 81313
# Fold indices that are actually trained; any other fold is skipped.
FOLDS = [0, 1, 2, 3]
# Number of epochs passed to model.fit for every fold.
EPOCHS = 20
# Seed used by set_seed and as KFold's random_state.
SEED = 1213

# Directory in which the per-fold checkpoint files are written.
SAVEDIR = Path("./")

## Utilities

In [None]:
def set_seed(seed=42):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and
    TensorFlow's global seed, and writes ``PYTHONHASHSEED`` into the process
    environment.

    Args:
        seed: Seed value applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)


set_seed(SEED)

In [None]:
def auto_select_accelerator():
    """Return a ``tf.distribute`` strategy, using a TPU when setup succeeds.

    The TPU is resolved, connected to and initialised before a TPU strategy
    is created. A ``ValueError`` raised anywhere during that setup is treated
    as "no TPU available" and the current default strategy is used instead.

    Returns:
        The selected distribution strategy.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # TPU setup failed: fall back to whatever strategy is active by default.
        strategy = tf.distribute.get_strategy()

    replica_count = strategy.num_replicas_in_sync
    print(f"Running on {replica_count} replicas")
    return strategy

In [None]:
# Select the distribution strategy once up front; REPLICAS scales the global
# batch size in load_dataset and the steps-per-epoch computation.
strategy = auto_select_accelerator()
REPLICAS = strategy.num_replicas_in_sync
# Shorthand passed to tf.data as num_parallel_reads / num_parallel_calls /
# prefetch buffer size.
AUTO = tf.data.experimental.AUTOTUNE

## Data Loading

In [None]:
# Resolve the GCS location of each of the five TFRecord datasets
# (fold0 .. fold4) and print it.
gcs_paths = []
for i in range(5):
    gcs_path = KaggleDatasets().get_gcs_path(f"landmark-recognition-2021-tfrecords-fold{i}")
    print(gcs_path)
    gcs_paths.append(gcs_path)
    
# Collect every .tfrec file, sorted by name within each dataset.
all_files = []
for path in gcs_paths:
    all_files.extend(np.sort(np.array(tf.io.gfile.glob(path + "/*.tfrec"))))

print("train files: ", len(all_files))

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image tensor.

    The image is decoded with three channels, resized to
    ``(IMAGE_SIZE, IMAGE_SIZE)`` and divided by 255.

    Args:
        image_data: Scalar string tensor holding the encoded JPEG.

    Returns:
        The resized float32 image tensor.
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    resized = tf.image.resize(rgb, size=(IMAGE_SIZE, IMAGE_SIZE))
    return tf.cast(resized, tf.float32) / 255.0


def read_labeled_tfrecord(example):
    """Parse one serialized TFRecord example into training tensors.

    Args:
        example: Serialized example with the scalar features ``image_name``
            (string), ``image`` (string) and ``target`` (int64).

    Returns:
        A 4-tuple ``(image_name, decoded_image, target_as_int32, 1)``; the
        last element is the constant ``1``.
    """
    feature_spec = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    return (
        parsed["image_name"],
        decode_image(parsed["image"]),
        tf.cast(parsed["target"], tf.int32),
        1,
    )


def arcface_format(posting_id, image, label_group, matches):
    """Repack a parsed example so the model receives image and label together.

    The model's two inputs are named ``inp1`` (image) and ``inp2`` (label),
    so both are bundled into one dict; the label is also kept as the target.

    Returns:
        ``(posting_id, {'inp1': image, 'inp2': label_group}, label_group, matches)``.
    """
    model_inputs = {'inp1': image, 'inp2': label_group}
    return posting_id, model_inputs, label_group, matches


def load_dataset(filenames, batch_size=64, cache=False, repeat=False, shuffle=False):
    """Load TFRecord files and parse them into batched tensors.

    Args:
        filenames: TFRecord files to read.
        batch_size: Per-replica batch size; it is multiplied by ``REPLICAS``
            to obtain the batch size actually used.
        cache: Cache the raw records before any shuffling or parsing.
        repeat: Repeat the parsed examples indefinitely.
        shuffle: Shuffle raw records with a 2048-element buffer and turn off
            deterministic ordering.

    Returns:
        A prefetched dataset of ``(posting_id, inputs, label, matches)`` batches.
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    if cache:
        records = records.cache()

    if shuffle:
        relaxed_order = tf.data.Options()
        relaxed_order.experimental_deterministic = False
        records = records.shuffle(2048).with_options(relaxed_order)

    examples = records.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    examples = examples.map(arcface_format, num_parallel_calls=AUTO)
    if repeat:
        examples = examples.repeat()

    global_batch_size = batch_size * REPLICAS
    return examples.batch(global_batch_size).prefetch(AUTO)


def count_data_items(filenames):
    """Sum the per-file item counts encoded in TFRecord file names.

    Each file name is expected to contain ``-<count>.``, e.g.
    ``flowers00-230.tfrec`` holds 230 items. Only the final path component
    is inspected, so digits in a bucket or directory name (for example
    ``gs://kds-123.bucket/...``) are never mistaken for the count.

    Args:
        filenames: Iterable of file paths or names.

    Returns:
        ``np.sum`` over the parsed integer counts.

    Raises:
        ValueError: If a file name carries no ``-<digits>.`` marker.
    """
    # Require at least one digit so int() never sees an empty string.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        basename = os.path.basename(str(filename))
        match = count_pattern.search(basename)
        if match is None:
            raise ValueError(f"Cannot read item count from file name: {filename!r}")
        counts.append(int(match.group(1)))
    return np.sum(counts)

In [None]:
# Total number of items across all TFRecord files, read from the file names.
NUM_TRAINING_IMAGES = count_data_items(all_files)
NUM_TRAINING_IMAGES

## Model

In [None]:
class GeM(tf.keras.layers.Layer):
    """Generalized-mean (GeM) pooling with a learnable exponent per channel.

    For a channels-last input the layer computes
    ``avg_pool(max(x, 1e-6) ** p) ** (1 / p)`` where ``p`` is a trainable
    vector with one entry per channel, initialised to ``init_norm``.

    Args:
        pool_size: Window size, also used as the stride of the average
            pooling (``VALID`` padding).
        init_norm: Initial value of every entry of ``p``.
        normalize: If true, L2-normalise the pooled output over the channel
            (last) axis.
    """

    def __init__(self, pool_size, init_norm=3.0, normalize=False, **kwargs):
        # Initialise the Keras base class before attaching attributes.
        super(GeM, self).__init__(**kwargs)

        self.pool_size = pool_size
        self.init_norm = init_norm
        self.normalize = normalize

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'pool_size': self.pool_size,
            'init_norm': self.init_norm,
            'normalize': self.normalize,
        })
        return config

    def build(self, input_shape):
        """Create the per-channel exponent ``p``."""
        feature_size = input_shape[-1]
        self.p = self.add_weight(name='norms', shape=(feature_size,),
                                 initializer=tf.keras.initializers.constant(self.init_norm),
                                 trainable=True)
        super(GeM, self).build(input_shape)

    def call(self, inputs):
        """Apply GeM pooling; the output keeps the rank of ``inputs``."""
        # Clamp to a small positive value so the fractional power is defined.
        x = tf.math.maximum(inputs, 1e-6)
        x = tf.pow(x, self.p)

        x = tf.nn.avg_pool(x, self.pool_size, self.pool_size, 'VALID')
        x = tf.pow(x, 1.0 / self.p)

        if self.normalize:
            # The pooled tensor is still channels-last, so the descriptor
            # lives on the last axis (axis 1 would be a spatial axis).
            x = tf.nn.l2_normalize(x, axis=-1)
        return x

    def compute_output_shape(self, input_shape):
        """Return the channels-last shape produced by ``call``.

        With ``VALID`` padding and stride equal to the window size, every
        spatial dimension becomes ``dim // window``.
        """
        shape = tf.TensorShape(input_shape).as_list()
        n_spatial = len(shape) - 2

        # Mirror the forms accepted by tf.nn.avg_pool: an int, or a sequence
        # of length 1, N or N + 2 (batch and channel entries included).
        if isinstance(self.pool_size, int):
            windows = [self.pool_size] * n_spatial
        else:
            windows = list(self.pool_size)
            if len(windows) == 1:
                windows = windows * n_spatial
            elif len(windows) == n_spatial + 2:
                windows = windows[1:-1]

        pooled = [
            None if dim is None else dim // window
            for dim, window in zip(shape[1:-1], windows)
        ]
        return tuple([shape[0]] + pooled + [shape[-1]])

In [None]:
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    The layer takes ``[embeddings, labels]`` and returns scaled cosine
    logits in which the entry of the labelled class has the angular margin
    ``m`` added to its angle.

    Args:
        n_classes: Number of classes (columns of the weight matrix).
        s: Scale applied to the final logits.
        m: Angular margin in radians.
        easy_margin: If true, apply the margin only where cosine > 0;
            otherwise apply it where cosine > cos(pi - m) and use
            ``cosine - sin(pi - m) * m`` elsewhere.
        ls_eps: Label-smoothing factor applied to the one-hot mask
            (disabled when 0).

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Constants of cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # Threshold and fallback offset used when easy_margin is off.
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def get_config(self):
        '''Return the layer config including the constructor arguments.'''
        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        '''Create the class-weight matrix of shape (embedding_dim, n_classes).'''
        # input_shape is a pair of shapes: [embeddings, labels].
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        '''Compute the margin-adjusted, scaled cosine logits.'''
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Rounding in the matmul can push |cosine| marginally above 1, which
        # makes 1 - cosine^2 negative and the square root NaN. Clamping the
        # radicand to [epsilon, 1] also keeps the sqrt gradient finite.
        sin_squared = tf.clip_by_value(
            1.0 - tf.math.pow(cosine, 2),
            tf.keras.backend.epsilon(),
            1.0
        )
        sine = tf.math.sqrt(sin_squared)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Margin logit for the labelled class, plain cosine for the rest.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

In [None]:
def build_model(size=256, efficientnet_size=0, weights="imagenet", count=0):
    """Build and compile the EfficientNet + GeM + ArcFace classifier.

    Args:
        size: Side length of the square RGB input image.
        efficientnet_size: Selects the backbone ``EfficientNetB{efficientnet_size}``.
        weights: ``weights`` argument forwarded to the backbone constructor.
        count: Number of optimizer steps over which the cosine schedule for
            the weight decay runs. Values below 1 are treated as 1.

    Returns:
        A compiled ``tf.keras.Model`` with inputs ``inp1`` (image) and
        ``inp2`` (label) and a softmax output over ``N_CLASSES`` classes.
    """
    inp = tf.keras.layers.Input(shape=(size, size, 3), name="inp1")
    label = tf.keras.layers.Input(shape=(), name="inp2")

    backbone_cls = getattr(efn, f"EfficientNetB{efficientnet_size}")
    x = backbone_cls(
        weights=weights, include_top=False, input_shape=(size, size, 3))(inp)
    x = GeM(8)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(512, name="dense_before_arcface", kernel_initializer="he_normal")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    # The ArcFace head needs the label as a second input to place the margin.
    x = ArcMarginProduct(
        n_classes=N_CLASSES,
        s=30,
        m=0.5,
        name="head/arc_margin",
        dtype="float32"
    )([x, label])
    output = tf.keras.layers.Softmax(dtype="float32")(x)
    model = tf.keras.Model(inputs=[inp, label], outputs=[output])

    # CosineDecay divides by its number of decay steps, so the default
    # count=0 must not reach it unchanged.
    decay_steps = max(int(count), 1)
    # The first positional parameter of tfa's AdamW is `weight_decay`, so this
    # schedule controls the decoupled weight decay, not the learning rate.
    # NOTE(review): the original local name suggested a learning-rate
    # schedule; confirm that decaying the weight decay is what is intended.
    weight_decay_schedule = tf.keras.experimental.CosineDecay(1e-3, decay_steps)
    opt = tfa.optimizers.AdamW(weight_decay=weight_decay_schedule, learning_rate=1e-4)
    model.compile(
        optimizer=opt,
        loss=[tf.keras.losses.SparseCategoricalCrossentropy()],
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )
    return model

## Other training utilities

In [None]:
def get_lr_callback(plot=False):
    """Create the per-epoch learning-rate scheduler callback.

    The schedule moves linearly from the start rate to the peak rate over
    the ramp epochs, optionally holds the peak, then decays exponentially
    towards the minimum rate.

    Args:
        plot: If true, scatter-plot the schedule for ``EPOCHS`` epochs.

    Returns:
        A ``tf.keras.callbacks.LearningRateScheduler`` wrapping the schedule.
    """
    start_rate = 1e-3
    peak_rate = 0.00003 * BATCH_SIZE
    floor_rate = 1e-5
    ramp_epochs = 4
    sustain_epochs = 0
    decay_factor = 0.9

    def lrfn(epoch):
        # Ramp phase: linear interpolation from start_rate to peak_rate.
        if epoch < ramp_epochs:
            return (peak_rate - start_rate) / ramp_epochs * epoch + start_rate
        # Sustain phase: hold the peak rate.
        if epoch < ramp_epochs + sustain_epochs:
            return peak_rate
        # Decay phase: exponential approach to floor_rate.
        return (peak_rate - floor_rate) * decay_factor**(epoch - ramp_epochs - sustain_epochs) + floor_rate

    if plot:
        epochs = list(range(EPOCHS))
        plt.scatter(epochs, [lrfn(epoch) for epoch in epochs])
        plt.show()

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)

get_lr_callback(plot=True)

## Training

In [None]:
# Folds are formed over TFRecord files (not individual images): the array
# of file names is what gets split in the training loop.
kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
files_train_all = np.array(all_files)

In [None]:
# Train one model per selected fold; each fold holds out its own subset of
# the TFRecord files for validation.
for fold, (trn_idx, val_idx) in enumerate(kf.split(files_train_all)):
    # Folds not listed in FOLDS are skipped entirely.
    if fold not in FOLDS:
        continue
    files_train = files_train_all[trn_idx]
    files_valid = files_train_all[val_idx]

    print("=" * 120)
    print(f"Fold {fold}")
    print("=" * 120)

    # Image counts are recovered from the TFRecord file names.
    train_image_count = count_data_items(files_train)
    valid_image_count = count_data_items(files_valid)

    # Clear Keras state left by the previous fold and run accelerator setup
    # again before building a fresh model.
    tf.keras.backend.clear_session()
    strategy = auto_select_accelerator()

    with strategy.scope():
        model = build_model(
            size=IMAGE_SIZE,
            efficientnet_size=EFFICIENTNET_SIZE,
            weights=WEIGHTS,
            # Same expression as STEPS_PER_EPOCH below, so the cosine schedule
            # created in build_model spans one epoch's worth of steps.
            count=train_image_count // BATCH_SIZE // REPLICAS // 4
        )

    # Save weights only, and only when the validation loss improves.
    model_ckpt = tf.keras.callbacks.ModelCheckpoint(
        str(SAVEDIR / f"fold{fold}.h5"), monitor="val_loss", verbose=1, save_best_only=True,
        save_weights_only=True, mode="min", save_freq="epoch"
    )

    # load_dataset yields (posting_id, inputs, label, matches); fit() only needs
    # (inputs, label). The lambda's `image` argument is the dict of model inputs.
    train_dataset = load_dataset(files_train, batch_size=BATCH_SIZE, shuffle=True, repeat=True)
    train_dataset = train_dataset.map(lambda posting_id, image, label_group, matches: (image, label_group))

    # Validation uses twice the training batch size and is neither shuffled nor repeated.
    valid_dataset = load_dataset(files_valid, batch_size=BATCH_SIZE * 2, shuffle=False, repeat=False)
    valid_dataset = valid_dataset.map(lambda posting_id, image, label_group, matches: (image, label_group))

    # The training dataset repeats indefinitely, so steps_per_epoch defines the
    # epoch length; the trailing // 4 makes one epoch a quarter of a pass over
    # the training images.
    STEPS_PER_EPOCH = train_image_count // BATCH_SIZE // REPLICAS // 4
    history = model.fit(
        train_dataset,
        epochs=EPOCHS,
        callbacks=[model_ckpt, get_lr_callback()],
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=valid_dataset,
        verbose=1
    )