Hello fellow Kagglers,

This notebook demonstrates how to train a CNN with the ArcFace loss on the Google Landmark Recognition 2021 dataset. The backbone is EfficientNetV2-S, introduced in [this](https://arxiv.org/pdf/2104.00298.pdf) paper. The ArcFace loss trains the network to produce embeddings in which images of the same class lie close to each other, while different classes are separated by an angular margin. The ArcFace paper can be found [here](https://arxiv.org/pdf/1801.07698.pdf).

Since the dataset is huge (~1.5M images), training efficiency is key. To speed up training, the images are converted to TFRecords for fast reading. These datasets can be found here: [Part 1](https://www.kaggle.com/markwijkhuizen/landmark-recognition-2021-tfrecords-384-part-1), [Part 2](https://www.kaggle.com/markwijkhuizen/landmark-recognition-2021-tfrecords-384-part-2), [Part 3](https://www.kaggle.com/markwijkhuizen/landmark-recognition-2021-tfrecords-384-part-3). The last trick is bfloat16 training. bfloat16 is a 16-bit float with lower precision than a conventional 16-bit float (float16), but with the same range as a 32-bit float. This reduces computation time and allows for bigger batch sizes.

In [None]:
import warnings
# Silence all Python warnings to keep the notebook output readable
warnings.simplefilter('ignore')

This next line adds the [EfficientNetV2 GitHub repository](https://github.com/google/automl/tree/master/efficientnetv2) with the corresponding weights to the system.

In [None]:
import sys
# Add the Kaggle dataset holding the EfficientNetV2 repository (and its weights) to the
# import path; presumably this is where `brain_automl.efficientnetv2` is imported from
sys.path.append('/kaggle/input/efficientnetv2-pretrained-imagenet21k-weights')

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.mixed_precision import experimental as mixed_precision
from kaggle_datasets import KaggleDatasets
from tqdm.notebook import tqdm
from brain_automl.efficientnetv2 import effnetv2_model, preprocessing

import re
import os
import io
import time
import pickle
import math
import random
import sys

# Report library / interpreter versions so results can be reproduced later
print(f'tensorflow version: {tf.__version__}')
print(f'tensorflow keras version: {tf.keras.__version__}')
print(f'python version: {sys.version}')

In [None]:
def seed_everything(seed):
    '''
    Seed every random number generator used in this notebook.

    Exports PYTHONHASHSEED, then seeds Python's `random`, NumPy's global RNG and
    TensorFlow's global seed, in that order.

    NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so setting
    it here probably only affects subprocesses started afterwards — confirm.

    Args:
        seed: integer seed shared by all generators.
    '''
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, tf.random.set_seed)
    for apply_seed in seeders:
        apply_seed(seed)

seed_everything(42)

In [None]:
# Detect hardware, return appropriate distribution strategy
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', TPU.master())
except ValueError:
    # No TPU could be resolved: fall back to the default strategy (GPU, or CPU if no GPU exists)
    print('Running on GPU')
    TPU = None

if TPU:
    # Connect to and initialize the TPU system before building the TPU strategy on it
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    # NOTE(review): newer TF releases expose this as tf.distribute.TPUStrategy; the experimental alias may be deprecated — confirm for the installed version
    strategy = tf.distribute.experimental.TPUStrategy(TPU)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

# Number of devices the strategy runs on; the global batch size below is scaled by this
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
DEBUG = False

# --- Image geometry -------------------------------------------------------
IMG_SIZE = 384        # images are resized to IMG_SIZE x IMG_SIZE
N_CHANNELS = 3
INPUT_SHAPE = (IMG_SIZE,) * 2 + (N_CHANNELS,)

# --- Dataset --------------------------------------------------------------
N_SAMPLES = 1580470   # number of training images

# EfficientNetV2 variant suffix (e.g. 's'); must match an available
# 'efficientnetv2-<EFN_SIZE>-21k-ft1k' checkpoint directory
EFN_SIZE = 's'

# --- Batching -------------------------------------------------------------
# Per-replica batch size: 6 in DEBUG mode, otherwise 128 on TPU and 16 on GPU/CPU.
# The global batch is per-replica size x REPLICAS, e.g. 128 x 8 = 1024 on an 8-core TPU.
if DEBUG:
    BATCH_SIZE_BASE = 6
elif TPU:
    BATCH_SIZE_BASE = 128
else:
    BATCH_SIZE_BASE = 16
BATCH_SIZE = BATCH_SIZE_BASE * REPLICAS

# --- Precision ------------------------------------------------------------
MODEL_POLICY = 'mixed_bfloat16' # float32 or mixed_bfloat16
# Images are fed in bfloat16 when the mixed bfloat16 policy is active, float32 otherwise
if MODEL_POLICY == 'mixed_bfloat16':
    IMAGE_DTYPE = tf.bfloat16
else:
    IMAGE_DTYPE = tf.float32
LABEL_DTYPE = tf.int32

# --- Normalization --------------------------------------------------------
# ImageNet per-channel mean / standard deviation (RGB order)
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

# Let tf.data tune parallelism / prefetch buffer sizes automatically
AUTO = tf.data.experimental.AUTOTUNE

print(f'BATCH_SIZE: {BATCH_SIZE}, IMAGE_DTYPE: {IMAGE_DTYPE}, LABEL_DTYPE: {LABEL_DTYPE}')
print(f'MODEL_POLICY: {MODEL_POLICY}')

# Landmark\_id to label mappers

In [None]:
# mappers from landmark_id to label and vice versa
def _load_pickle(path):
    '''Load and return one pickled object from `path`.'''
    # NOTE(review): pickle.load can execute arbitrary code — only use on trusted files
    with open(path, 'rb') as fh:
        return pickle.load(fh)


label2landmark_id = _load_pickle('/kaggle/input/landmark-recognition-2021-tfrecords-384-part-1/label2landmark_id.pkl')
landmark_id2label = _load_pickle('/kaggle/input/landmark-recognition-2021-tfrecords-384-part-1/landmark_id2label.pkl')

In [None]:
# Number of classes = number of entries in the label -> landmark_id mapper
N_CLASSES = len(label2landmark_id)
print(f'N_CLASSES: {N_CLASSES}')

# Dataset

Decode function for the TFRecords. Documentation on TFRecords can be found [here](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). Each TFRecord file contains roughly 3000 images, which can be read in one go. This is much faster than reading all images one by one.

In [None]:
def decode_tfrecord(record_bytes):
    '''
    Parse one serialized example into model inputs and a one-hot target.

    The JPEG is decoded, a random square crop is taken along its longer side,
    the crop is resized to IMG_SIZE x IMG_SIZE and normalized with the ImageNet
    mean / std.

    Args:
        record_bytes: one serialized example with 'image' (JPEG bytes), 'label',
            'width' and 'height' features.

    Returns:
        ``({'image': image, 'label': label}, label_one_hot)``. The dict keys match
        the names of the model's Input layers. ``label_one_hot`` is a uint8 one-hot
        vector of length N_CLASSES, used by the loss and the top-k metrics.
    '''
    # Data the sample contains
    features = tf.io.parse_single_example(record_bytes, {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
    })

    # Force N_CHANNELS output channels so a grayscale JPEG cannot break the
    # 3-channel slice / reshape below
    image = tf.io.decode_jpeg(features['image'], channels=N_CHANNELS)
    label = tf.cast(features['label'], dtype=tf.int32)
    height = features['height']
    width = features['width']

    # Cutout Random Square along the longer side.
    # maxval of tf.random.uniform is exclusive, so "+ 1" makes the last valid
    # offset (long side - short side) reachable.
    if height != width:
        if height > width:
            offset = tf.random.uniform(shape=(), minval=0, maxval=height - width + 1, dtype=tf.int64)
            image = tf.slice(image, [offset, 0, 0], [width, width, N_CHANNELS])
        else:
            offset = tf.random.uniform(shape=(), minval=0, maxval=width - height + 1, dtype=tf.int64)
            image = tf.slice(image, [0, offset, 0], [height, height, N_CHANNELS])

    # Reshape (restores a known rank after the conditional crop) and Normalize
    size = tf.math.reduce_min([height, width])
    image = tf.reshape(image, [size, size, N_CHANNELS])
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = tf.cast(image, tf.float32) / 255.0
    image = (image - IMAGENET_MEAN) / IMAGENET_STD

    # One hot encode the label, required for metrics
    label_one_hot = tf.one_hot(label, N_CLASSES, dtype=tf.uint8)

    # Cast image if not desired dtype (bfloat16 under the mixed_bfloat16 policy)
    if image.dtype != IMAGE_DTYPE:
        image = tf.cast(image, IMAGE_DTYPE)

    # CNN required both image and label, passed as dictionary
    return { 'image': image, 'label': label }, label_one_hot

In [None]:
# Simple function to benchmark the dataset, images will be read with ~6000 images/second!
def benchmark_dataset(dataset, num_epochs=3, n_steps_per_epoch=25, bs=BATCH_SIZE):
    '''
    Measure input-pipeline throughput by iterating over a few batches.

    In every benchmark "epoch" the first batch is a warm-up and is not timed;
    the following `n_steps_per_epoch` batches are timed.

    Args:
        dataset: dataset yielding ``(inputs, labels)`` with ``inputs['image']``.
        num_epochs: number of benchmark rounds.
        n_steps_per_epoch: number of timed batches per round, must be >= 1.
        bs: batch size used to convert time per step into images per second.

    Raises:
        ValueError: if n_steps_per_epoch < 1.
    '''
    if n_steps_per_epoch < 1:
        raise ValueError(f'n_steps_per_epoch must be >= 1, got {n_steps_per_epoch}')

    for epoch_num in range(num_epochs):
        epoch_start = None
        n_timed = 0
        for idx, (inputs, labels) in enumerate(dataset.take(n_steps_per_epoch + 1)):
            if idx == 0:
                # Report the batch layout once, before the timer starts
                if epoch_num == 0:
                    images = inputs['image']
                    print(f'image shape: {images.shape}, image dtype: {images.dtype}')
                epoch_start = time.perf_counter()
            else:
                n_timed += 1

        if n_timed == 0:
            # The dataset was exhausted before any batch could be timed
            print(f'epoch {epoch_num}: not enough batches to benchmark')
            continue

        epoch_t = time.perf_counter() - epoch_start
        mean_step_t = round(epoch_t / n_timed * 1000, 1)
        # Throughput from the unrounded elapsed time and the number of batches actually timed
        n_imgs_per_s = int(n_timed * bs / epoch_t)
        print(f'epoch {epoch_num} took: {round(epoch_t, 2)} sec, mean step duration: {mean_step_t}ms, images/s: {n_imgs_per_s}')

In [None]:
def _to_displayable(img):
    '''Min-max scale a (normalized) image array to [0, 1] for imshow; constant images become 0.'''
    img = img.astype(np.float32)
    img_min, img_max = img.min(), img.max()
    span = img_max - img_min
    if span > 0:
        return (img - img_min) / span
    return np.zeros_like(img)


# Plots a batch of images
def show_batch(dataset, rows=4, cols=3):
    '''
    Plot the first rows*cols images of one batch in a grid, titled with their class index.

    Args:
        dataset: dataset yielding ``(inputs, one_hot_labels)`` where
            ``inputs['image']`` is a batch of ImageNet-normalized images.
        rows, cols: grid size. If the batch holds fewer than rows*cols images
            (e.g. the small DEBUG batch size), the remaining cells are left empty.
    '''
    inputs, lbls = next(iter(dataset))
    imgs = inputs['image']
    n_imgs = imgs.shape[0]
    # squeeze=False keeps `axes` 2-D even when rows == 1 or cols == 1
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*7, rows*4), squeeze=False)
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c
            ax = axes[r, c]
            if idx >= n_imgs:
                ax.axis('off')
                continue
            ax.imshow(_to_displayable(imgs[idx].numpy()))
            # Labels are one-hot, argmax recovers the class index
            ax.set_title(f'Label: {np.argmax(lbls[idx])}')

# Train Dataset

In [None]:
# Google Cloud paths to the TFRecords datasets, required as TPU will read from Google Cloud only
# One path per dataset part; get_train_dataset globs '*.tfrecords' under each of them
GCS_DS_PATH_1 = KaggleDatasets().get_gcs_path('landmark-recognition-2021-tfrecords-384-part-1')
GCS_DS_PATH_2 = KaggleDatasets().get_gcs_path('landmark-recognition-2021-tfrecords-384-part-2')
GCS_DS_PATH_3 = KaggleDatasets().get_gcs_path('landmark-recognition-2021-tfrecords-384-part-3')

In [None]:
def get_train_dataset(bs=BATCH_SIZE):
    '''
    Build the infinite, batched training pipeline over all three TFRecord parts.

    Args:
        bs: global batch size (defaults to BATCH_SIZE).

    Returns:
        A repeating dataset of ``({'image': ..., 'label': ...}, one_hot_label)``
        batches produced by `decode_tfrecord`, with prefetching enabled.
    '''
    # Ignore order, improves performance
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False

    # Use glob to find all TFRecords files
    FNAMES_TRAIN_TFRECORDS = (
        tf.io.gfile.glob(f'{GCS_DS_PATH_1}/*.tfrecords') +
        tf.io.gfile.glob(f'{GCS_DS_PATH_2}/*.tfrecords') +
        tf.io.gfile.glob(f'{GCS_DS_PATH_3}/*.tfrecords')
    )

    # Each file holds roughly 3000 images, so the count is only an estimate
    print(f'Found roughly {len(FNAMES_TRAIN_TFRECORDS) * int(3e3)} images, N_SAMPLES: {N_SAMPLES}')

    # NOTE(review): there is no shuffle() stage; sample order comes only from the
    # interleaved parallel file reads — confirm the TFRecords were written shuffled
    train_dataset = tf.data.TFRecordDataset(FNAMES_TRAIN_TFRECORDS, num_parallel_reads=AUTO)
    train_dataset = train_dataset.with_options(ignore_order)
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.map(decode_tfrecord, num_parallel_calls=AUTO)
    # Batch with the requested `bs`. The dataset repeats forever, so drop_remainder
    # never drops data; it only makes the batch dimension statically known.
    train_dataset = train_dataset.batch(bs, drop_remainder=True)
    train_dataset = train_dataset.prefetch(AUTO)

    return train_dataset

train_dataset = get_train_dataset()

In [None]:
# Benchmark the dataset, close to 6000 images/second can be read using TFRecords!
# (throughput reported by the author; the actual number depends on the hardware)
benchmark_dataset(train_dataset)

In [None]:
# Sanity check, what type and shapes are the images and labels
inputs, lbls_oh = next(iter(train_dataset))
imgs = inputs['image']
print(f'imgs shape: {imgs.shape}, imgs dtype: {imgs.dtype}')
# Statistics of the first image; after ImageNet mean/std normalization values are not limited to [0, 1]
img0 = imgs[0].numpy().astype(np.float32)
train_imgs_info = (img0.mean(), img0.std(), img0.min(), img0.max())
print('train img 0 mean: %.3f, 0 std: %.3f, min: %.3f, max: %.3f' % train_imgs_info)

# Labels (integer class indices fed to the ArcMargin layer)
lbls = inputs['label'].numpy()
print(f'lbls shape: {lbls.shape}, lbls dtype: {lbls.dtype}')
print(f'lbls min: {lbls.min()}, lbls max: {lbls.max()}')

# Labels One Hot Encoded (targets for the loss and metrics)
# NOTE(review): lbls_oh_np is not used anywhere in the visible notebook
lbls_oh_np = lbls_oh.numpy()
print(f'lbls_oh shape: {lbls_oh.shape}, lbls_oh dtype: {lbls_oh.dtype}')

In [None]:
# Visual sanity check of one training batch; titles show the class index of each image
show_batch(train_dataset)

# ArcMargin Product

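ArcMargin product used for the ArcFace loss. The basic idea is to cluster the embeddings of samples belonging to the same class close together and to push different classes apart. In short, the embedding $x$ and every class weight vector $W_j$ are L2-normalized, so their dot product is $\cos\theta_j$, the cosine of the angle between them. For the true class $y$ an additive angular margin $m$ is applied, and all logits are multiplied by a scale $s$:

$$\text{logit}_j = \begin{cases} s\,\cos(\theta_j + m) & \text{if } j = y \\ s\,\cos\theta_j & \text{otherwise} \end{cases}$$

These logits go into a regular softmax cross-entropy loss. The true class is penalized by the margin, so the network has to make $\theta_y$ smaller than plain softmax would require. This pulls the embeddings of a class tightly around its class vector. The difference between conventional softmax and ArcFace loss is nicely illustrated in [this](https://www.kaggle.com/chankhavu/keras-layers-arcface-cosface-adacos) notebook. Further explanations or corrections are very welcome :)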

In [None]:
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance (ArcFace logits).

    Called on ``[X, y]``, where X holds the embeddings (one row per sample) and
    y the integer class labels. theta_j is the angle between the L2-normalized
    embedding and the L2-normalized class weight column ``W[:, j]``. The layer
    returns ``s * cos(theta_j)`` for every class j, except for the true class y,
    which gets the additive angular margin: ``s * cos(theta_y + m)``.

    Args:
        n_classes: number of classes (width of the output).
        s: scale applied to all logits.
        m: additive angular margin, in radians.
        easy_margin: if True, the margin is only applied where cos(theta) > 0
            (plain cos(theta) elsewhere). If False, where theta + m would reach
            pi the margin logit falls back to ``cos(theta) - sin(pi - m) * m``.
        ls_eps: label smoothing of the one-hot mask that selects between the
            margin logit and the plain logit.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False, ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        # input_shape is [embedding_shape, label_shape]; W maps embedding dim -> n_classes
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype=tf.float32,
            trainable=True,
            regularizer=None
        )

    def call(self, inputs):
        X, y = inputs

        # Cosine similarity between normalized embeddings and normalized class weights
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Clamp 1 - cos^2 into [1e-7, 1]: rounding can push |cos| slightly above 1,
        # which would make sqrt return NaN, and sqrt's gradient is infinite at 0.
        sine = tf.math.sqrt(tf.clip_by_value(1.0 - tf.math.square(cosine), 1e-7, 1.0))
        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Margin logit for the true class, plain cosine for all other classes
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

    def get_config(self):
        '''Return the constructor arguments so the layer can be re-created from its config.'''
        config = super(ArcMarginProduct, self).get_config()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'easy_margin': self.easy_margin,
            'ls_eps': self.ls_eps,
        })
        return config

# Model

In [None]:
# EfficientNetV2-S pretrained weights also need to be read from Google Cloud
# get_model looks for '<GCS_WEIGHTS_PATH>/efficientnetv2-<EFN_SIZE>-21k-ft1k' checkpoints under this path
GCS_WEIGHTS_PATH = KaggleDatasets().get_gcs_path('efficientnetv2-pretrained-imagenet21k-weights')

In [None]:
def get_model():
    '''
    Build and compile the EfficientNetV2 + ArcFace model inside the strategy scope.

    The model takes two inputs named 'image' and 'label' (matching the dataset
    dict keys). The label is needed by the ArcMargin layer to apply the margin
    to the true class. The output is ArcFace logits over N_CLASSES classes.

    Returns:
        A compiled tf.keras Model.

    Raises:
        FileNotFoundError: if no pretrained checkpoint is found under the weights path.
    '''
    tf.keras.backend.clear_session()
    # enable XLA optimizations
    tf.config.optimizer.set_jit(True)

    # set half precision policy
    mixed_precision.set_policy(MODEL_POLICY)

    # Print compute and variable dtype, on TPU this will be bfloat16 for compute and float32 for variable
    print(f'Compute dtype: {mixed_precision.global_policy().compute_dtype}')
    print(f'Variable dtype: {mixed_precision.global_policy().variable_dtype}')

    with strategy.scope():
        # Derive both the architecture and the checkpoint directory from EFN_SIZE
        # so the loaded weights always match the built network
        model_name = f'efficientnetv2-{EFN_SIZE}'
        WEIGHT_PATH = f'{GCS_WEIGHTS_PATH}/{model_name}-21k-ft1k'
        cnn = effnetv2_model.EffNetV2Model(model_name=model_name)

        # Inputs, note the names are equal to the dictionary keys in the dataset
        image = tf.keras.layers.Input(INPUT_SHAPE, name='image', dtype=IMAGE_DTYPE)
        label = tf.keras.layers.Input([], name='label', dtype=tf.int32)

        # Build the model with a dummy call, this is required
        cnn(tf.ones([1,*INPUT_SHAPE]), training=False)

        # Get the latest checkpoint from path; None means no checkpoint was found
        ckpt = tf.train.latest_checkpoint(WEIGHT_PATH)
        if ckpt is None:
            raise FileNotFoundError(f'No EfficientNetV2 checkpoint found in {WEIGHT_PATH}')

        # Load the weights
        cnn.load_weights(ckpt)

        # CNN call, we need only the output layer
        x = cnn(image, features_only=True)[0]
        # Global Average Pooling, cast to float32 for ArcMargin product
        x = tf.keras.layers.GlobalAveragePooling2D(name='pooling', dtype=tf.float32)(x)
        # Optional Dropout layer (rate 0.0 makes it a no-op; raise it to regularize)
        x = tf.keras.layers.Dropout(0.00, name='dropout', dtype=tf.float32)(x)
        # ArcMargin product
        output = ArcMarginProduct(n_classes=N_CLASSES, name='arc_margin', dtype=tf.float32)([x, label])

        # We will use the famous Adam optimizer for fast learning
        optimizer = tf.keras.optimizers.Adam()

        # Categorical Cross Entropy loss, from_logits=True so no softmax needed
        loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

        # To track learning progress, top 1/10/100/1000 accuracies will be kept track of
        metrics = [
            tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='accuracy@1'),
            tf.keras.metrics.TopKCategoricalAccuracy(k=10, name='accuracy@10'),
            tf.keras.metrics.TopKCategoricalAccuracy(k=100, name='accuracy@100'),
            tf.keras.metrics.TopKCategoricalAccuracy(k=1000, name='accuracy@1000'),
        ]

        model = tf.keras.models.Model(inputs = [image, label], outputs = [output])
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        return model

In [None]:
# Build, load pretrained backbone weights into, and compile the model
model = get_model()

In [None]:
# Print the Keras model summary (layers, output shapes, parameter counts)
model.summary()

In [None]:
# Plot slightly more detailed model summary (shapes and dtypes; the nested CNN is not expanded)
tf.keras.utils.plot_model(model, show_shapes=True, show_dtype=True, show_layer_names=True, expand_nested=False)

# Learning Rate Scheduler

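Because of the transfer learning approach, the learning rate is first warmed up from a small value (polynomial warmup with power 2.5 over the first 3 epochs). It is then held at its maximum for 2 epochs and finally reduced with a cosine decay to its final value.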

In [None]:
# Number of training epochs (2 in DEBUG mode). On TPU the global batch is large
# (128 per replica x 8 replicas = 1024) and bfloat16 is used, which according to the
# author makes 15 epochs possible in a single run.
EPOCHS = 2 if DEBUG else 15

In [None]:
# returns the learning rate given an epoch number
def lrfn(epoch, *, epochs=None):
    '''
    Learning rate for a 0-based epoch: polynomial warmup, sustain, cosine decay.

    Args:
        epoch: 0-based epoch index.
        epochs: total number of epochs in the schedule; defaults to the global EPOCHS.
            It is keyword-only, so a caller such as Keras' LearningRateScheduler that
            may invoke ``schedule(epoch, lr)`` can never pass the current lr here.

    Returns:
        The learning rate rounded to 8 decimals. The last epoch (epochs - 1)
        lands exactly on LR_FINAL.
    '''
    if epochs is None:
        epochs = EPOCHS

    # Config
    LR_START = 1e-5 # start of learning rate
    LR_MAX = 2e-4 # peak learning rate
    LR_FINAL = 2e-5 # final learning rate
    LR_RAMPUP_EPOCHS = 3 # number of warmup epochs
    LR_SUSTAIN_EPOCHS = 2 # number of epochs at maximum learning rate

    DECAY_EPOCHS = epochs - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS - 1
    # Guard against a zero/negative span if called past a very short schedule
    decay_span = max(DECAY_EPOCHS + 1, 1)

    if epoch < LR_RAMPUP_EPOCHS: # polynomial (power 2.5) warmup from LR_START towards LR_MAX
        lr = LR_START + (LR_MAX - LR_START) * (epoch / LR_RAMPUP_EPOCHS) ** 2.5
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS: # sustain lr
        lr = LR_MAX
    else: # cosine decay from LR_MAX to LR_FINAL
        epoch_diff = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS + 1
        decay_factor = (math.cos(epoch_diff / decay_span * math.pi) + 1) / 2
        lr = LR_FINAL + (LR_MAX - LR_FINAL) * decay_factor

    return round(lr, 8)

In [None]:
# Plots the learning rate schedule
def plot_lr_schedule(lr_schedule, name):
    '''
    Plot a per-epoch learning rate schedule with 1-based epoch tick labels.

    Args:
        lr_schedule: list of learning rates, one per epoch.
        name: label included in the plot title.
    '''
    n_epochs = len(lr_schedule)
    plt.figure(figsize=(15,8))
    plt.plot(lr_schedule)
    # Tick positions are 0..n-1 (list indices) but labelled 1..n so the axis reads as epochs
    plt.xticks(np.arange(n_epochs), [str(e) for e in range(1, n_epochs + 1)])
    schedule_info = f'start: {lr_schedule[0]}, max: {max(lr_schedule)}, final: {lr_schedule[-1]}'
    plt.title(f'Step Learning Rate Schedule {name}, {schedule_info}', size=16)
    plt.grid()
    plt.show()

# Learning rate for encoder
LR_SCHEDULE = [lrfn(step) for step in range(EPOCHS)]
plot_lr_schedule(LR_SCHEDULE, 'Encoder')

# Callbacks

In [None]:
# Learning rate callback: sets the optimizer learning rate from lrfn for each epoch
lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda step: lrfn(step), verbose=1)
# Model checkpoint, saves weights to model.h5 whenever the training loss reaches a new best
# (the training loss is monitored because fit below is called without validation data)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    'model.h5', monitor='loss', verbose=1, save_best_only=True, save_weights_only=True
)

# Training

In [None]:
# Because of the repeating dataset the amount of steps per epoch needs to be defined;
# one "epoch" is one pass over N_SAMPLES images, rounded down to whole batches
STEPS_PER_EPOCH = N_SAMPLES // BATCH_SIZE

print(f'BATCH_SIZE: {BATCH_SIZE}, STEPS_PER_EPOCH: {STEPS_PER_EPOCH}')

In [None]:
# Train the model, each epoch takes just ~23 minutes!
# The dataset repeats forever, so steps_per_epoch defines an epoch; the learning rate
# schedule and best-training-loss checkpointing are applied through the callbacks.
history = model.fit(
    train_dataset,
    steps_per_epoch = STEPS_PER_EPOCH,
    epochs = EPOCHS,
    verbose = 2,
    callbacks = [
        lr_callback,
        model_checkpoint_callback,
    ],
)

# Training History

Plot metric history during training

In [None]:
# Function to plot the metric history
def plot_history_metric(history, metric, f_best):
    '''
    Plot one training metric per epoch and mark its best epoch with a red dot.

    Args:
        history: the History object returned by model.fit.
        metric: key of history.history, e.g. 'loss' or 'accuracy@1'.
        f_best: function returning the index of the best value, e.g. np.argmin
            for a loss or np.argmax for an accuracy.
    '''
    values = history.history[metric]
    n_epochs = len(values)
    epochs = np.arange(1, n_epochs + 1)
    # Label epochs 1, 5 and every 5th epoch from 10 on, but never past the last epoch
    # (otherwise short runs, e.g. DEBUG, get an x axis stretched to epoch 5)
    tick_positions = [t for t in [1, 5, *range(10, n_epochs + 1, 5)] if t <= n_epochs]
    best_idx = int(f_best(values))

    plt.figure(figsize=(15, 8))
    plt.plot(epochs, values)
    # Epochs are 1-based on the x axis, hence best_idx + 1
    plt.scatter(best_idx + 1, values[best_idx], color='red', s=50, marker='o')

    plt.title(f'Model {metric}', fontsize=24, pad=10)
    plt.ylabel(metric, fontsize=20, labelpad=10)
    plt.xlabel('epoch', fontsize=20, labelpad=10)
    plt.tick_params(axis='x', labelsize=8)
    plt.xticks(tick_positions, fontsize=16)
    plt.yticks(fontsize=16)
    plt.legend(['train'],  prop={'size': 18})
    plt.grid()

In [None]:
# Training loss per epoch; the red dot marks the lowest loss
plot_history_metric(history, 'loss', np.argmin)

In [None]:
# Top-1 training accuracy per epoch; the red dot marks the highest value
plot_history_metric(history, 'accuracy@1', np.argmax)

In [None]:
# Top-10 training accuracy per epoch; the red dot marks the highest value
plot_history_metric(history, 'accuracy@10', np.argmax)

In [None]:
# Top-100 training accuracy per epoch; the red dot marks the highest value
plot_history_metric(history, 'accuracy@100', np.argmax)

In [None]:
# Top-1000 training accuracy per epoch; the red dot marks the highest value
plot_history_metric(history, 'accuracy@1000', np.argmax)