Hello Fellow Kagglers,

This notebook demonstrates the training process for the Happy Whale 2022 competition using the EfficientNetV2-XL model with a DOLG head and ArcFace classifier.

This model architecture is based on the 1st place solution of the [Google Landmark Recognition 2021](https://www.kaggle.com/c/landmark-recognition-2021) competition by [Christof Henkel](https://www.kaggle.com/christofhenkel), who published his solution, written in PyTorch, on [GitHub](https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/main/models/ch_mdl_dolg_efficientnet.py).

This solution is a TensorFlow implementation of his model architecture with some tweaks. Among others, this solution uses conventional global average pooling instead of generalized mean pooling and does not have a batch normalization and PReLU layer in the head.

The implementation of Christof Henkel is based on DOLG: Single-Stage Image Retrieval with Deep Orthogonal Fusion of Local and Global Features ([paper](https://arxiv.org/pdf/2108.02927.pdf)).

This notebook should contribute to this competition by showing a TensorFlow implementation of DOLG and, to the best of my knowledge, setting a new benchmark for a single-fold model.

[Inference Notebook](https://www.kaggle.com/markwijkhuizen/happy-whale-2022-efficientnetv2-xl-dolg-inference)

In [None]:
import sys

# Make the bundled google/automl EfficientNetV2 sources importable so that
# `import effnetv2_model` further down resolves.
sys.path.extend([
    '/kaggle/input/efficientnetv2-pretrained-imagenet21k-weights/brain_automl/',
    '/kaggle/input/efficientnetv2-pretrained-imagenet21k-weights/brain_automl/efficientnetv2/',
])

In [None]:
import warnings
warnings.simplefilter('ignore')

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt

from tensorflow.keras.mixed_precision import experimental as mixed_precision
from kaggle_datasets import KaggleDatasets
from tqdm.notebook import tqdm

import re
import os
import io
import time
import pickle
import math
import random
import sys
import imageio
import effnetv2_model

# Report library versions for reproducibility
print(f'tensorflow version: {tf.__version__}')
print(f'tensorflow keras version: {tf.keras.__version__}')
# (the previous format string had a stray 'P' before the version)
print(f'python version: {sys.version}')

In [None]:
def seed_everything(seed):
    """Seed Python's `random`, NumPy and TensorFlow, and export PYTHONHASHSEED.

    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    affects processes started afterwards, not the current kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)

seed_everything(42)

In [None]:
# Detect the accelerator and build the matching distribution strategy.
# TPUClusterResolver raises ValueError when no TPU is attached; on Kaggle the
# TPU_NAME environment variable is set, so no arguments are needed.
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', TPU.master())
except ValueError:
    print('Running on GPU')
    TPU = None

if not TPU:
    # Default (no-op) strategy: single GPU or CPU
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)

# Number of devices the global batch is split across
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# Load Train

In [None]:
# Training metadata; its `individual_id` column defines the classes
train = pd.read_csv(filepath_or_buffer='/kaggle/input/happy-whale-and-dolphin/train.csv')
display(train.head(n=5))

In [None]:
# Flags (DEBUG and CROP are not referenced by the cells in this notebook)
DEBUG = False
CROP = True

# Input image geometry
IMG_SIZE, N_CHANNELS = 640, 3
INPUT_SHAPE = (IMG_SIZE, IMG_SIZE, N_CHANNELS)

# Dataset size and training length
N_SAMPLES = len(train)
N_EPOCHS = 20

# Model configuration: EfficientNetV2 variant, DOLG width, descriptor width
EFNV2_SIZE = 'xl'
DOLG_SIZE = 1024
EMBEDDING_SIZE = 2048

# Per-replica batch size; the XL backbone is too large to fit 16 images per replica.
# The global batch is split across all replicas.
BATCH_SIZE_BASE = 12
BATCH_SIZE = REPLICAS * BATCH_SIZE_BASE

# ImageNet channel statistics used for input normalization
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

AUTO = tf.data.experimental.AUTOTUNE
EPS = tf.keras.backend.epsilon()

print(f'N_SAMPLES: {N_SAMPLES}, BATCH_SIZE: {BATCH_SIZE}')

# Number of Labels

In [None]:
# Number of distinct individuals, i.e. the number of ArcFace classes
N_INDIVIDUAL_IDS = train.individual_id.nunique()
print(f'N_INDIVIDUAL_IDS: {N_INDIVIDUAL_IDS}')

# Dataset

The amazing cropped [backfintfrecords](https://www.kaggle.com/datasets/jpbremer/backfintfrecords) by [Jan Bre](https://www.kaggle.com/jpbremer) is used.

In [None]:
# Function to Decode TFRecords and augment the image
def decode_tfrecord(record_bytes):
    """Parse one serialized training example and return a model-ready (inputs, labels) pair.

    Args:
        record_bytes: scalar string tensor with a serialized ``tf.train.Example`` holding
            ``image_name`` (string), ``image`` (JPEG bytes) and ``target`` (int64 label).

    Returns:
        ``({'image': float32 INPUT_SHAPE tensor, 'individual_id_input': int32},
        {'individual_id': int32})``. The label is fed both as an input (the ArcFace head
        uses it to apply the target-class margin) and as the training target.
    """
    features = tf.io.parse_single_example(record_bytes, {
        'image_name': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.int64),
    })

    target = tf.cast(features['target'], tf.int32)

    # Force N_CHANNELS so grayscale/RGBA JPEGs do not break the reshape below
    image = tf.io.decode_jpeg(features['image'], channels=N_CHANNELS)
    # Resize Image (returns float32 in the original [0, 255] range)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    # Explicit reshape needed for TPU: tells the compiler the static image dimensions
    image = tf.reshape(image, INPUT_SHAPE)

    # Scale to [0, 1] *before* augmenting: random_brightness adds an absolute delta,
    # so a delta of 0.10 is negligible on [0, 255] pixels.
    image = tf.cast(image, tf.float32) / 255.0

    # Image Augmentations: retrieved from:
    # https://www.kaggle.com/aikhmelnytskyy/happywhale-arcface-baseline-eff7-tpu-768-concat?scriptVersionId=88596800&cellId=15
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_hue(image, 0.01)
    image = tf.image.random_saturation(image, 0.70, 1.30)
    image = tf.image.random_contrast(image, 0.80, 1.20)
    image = tf.image.random_brightness(image, 0.10)
    # Contrast/brightness can push pixels outside [0, 1]; clip back to the valid range
    image = tf.clip_by_value(image, 0.0, 1.0)

    # ImageNet Normalization
    image = (image - IMAGENET_MEAN) / IMAGENET_STD

    return { 'image': image, 'individual_id_input': target }, { 'individual_id': target }

In [None]:
# Simple Function to benchmark the dataset to make sure the data loader won't form a bottleneck
def benchmark_dataset(dataset, num_epochs=3, n_steps_per_epoch=25, bs=BATCH_SIZE):
    """Measure how quickly ``dataset`` yields batches.

    Each of ``num_epochs`` passes pulls ``n_steps_per_epoch + 1`` batches; the first one
    is a warm-up and is excluded from the timing.

    Args:
        dataset: ``tf.data.Dataset`` yielding ``(inputs, labels)`` with ``inputs['image']``.
        num_epochs: number of timed passes.
        n_steps_per_epoch: number of timed batches per pass.
        bs: batch size, used to convert the step time into images per second.
    """
    for epoch_num in range(num_epochs):
        epoch_start = None
        n_timed_steps = 0
        for idx, (inputs, labels) in enumerate(dataset.take(n_steps_per_epoch + 1)):
            if idx == 0:
                # Start the clock only after the warm-up batch has been produced
                epoch_start = time.perf_counter()
                continue
            n_timed_steps += 1
            if idx == 1 and epoch_num == 0:
                images = inputs['image']
                print(f'image shape: {images.shape}, image dtype: {images.dtype}')

        if n_timed_steps == 0:
            print(f'epoch {epoch_num}: dataset yielded fewer than 2 batches, nothing to time')
            continue

        epoch_t = time.perf_counter() - epoch_start
        # Divide by the number of steps actually timed, and derive throughput from the
        # unrounded step time (rounding first could yield 0 and a ZeroDivisionError)
        step_t = epoch_t / n_timed_steps
        mean_step_t = round(step_t * 1000, 1)
        n_imgs_per_s = int(bs / step_t) if step_t > 0 else 0
        print(f'epoch {epoch_num} took: {round(epoch_t, 2)} sec, mean step duration: {mean_step_t}ms, images/s: {n_imgs_per_s}')

In [None]:
# Function to show a batch of images
def show_batch(dataset, rows=5, cols=4):
    """Plot the first ``rows * cols`` images of one batch, titled with their individual_id.

    Args:
        dataset: yields ``(inputs, labels)`` with ``inputs['image']`` and
            ``labels['individual_id']``; one batch must contain at least ``rows * cols`` images.
        rows: number of grid rows.
        cols: number of grid columns.
    """
    inputs, lbls = next(iter(dataset))
    imgs = inputs['image']
    # squeeze=False keeps `axes` 2-D even when rows or cols is 1
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*6), squeeze=False)
    for r in range(rows):
        for c in range(cols):
            idx = r * cols + c
            img = np.asarray(imgs[idx], dtype=np.float32)
            # Min-max rescale to [0, 1] for imshow. Subtracting the minimum also works for
            # all-positive images; the floor avoids division by zero on constant images.
            img_min, img_max = img.min(), img.max()
            img = (img - img_min) / max(img_max - img_min, 1e-12)
            axes[r, c].imshow(img)
            # int() turns the scalar tensor into a plain number for a readable title
            individual_id = int(lbls['individual_id'][idx])
            axes[r, c].set_title(f'individual_id: {individual_id}')

# Train Dataset

In [None]:
# TPUs read input data from Google Cloud Storage, so look up the GCS bucket
# that backs the Kaggle dataset 'backfintfrecords'.
GCS_DS_PATH = KaggleDatasets().get_gcs_path(
    'backfintfrecords',
)

In [None]:
# Train Dataset
def get_train_dataset(bs=BATCH_SIZE, center_cutout=False, dr=True, sr=True):
    """Build the batched training ``tf.data`` pipeline from the TFRecords under ``GCS_DS_PATH``.

    Args:
        bs: batch size.
        center_cutout: unused; kept so existing callers keep working.
        dr: drop the last incomplete batch (static batch shapes, needed on TPU).
        sr: shuffle and repeat indefinitely, with non-deterministic ordering for speed.

    Returns:
        A prefetched dataset of batches produced by ``decode_tfrecord``.

    Raises:
        FileNotFoundError: if no ``*train*.tfrec`` file is found.
    """
    fnames_train_tfrecords = tf.io.gfile.glob(f'{GCS_DS_PATH}/*train*.tfrec')
    if not fnames_train_tfrecords:
        raise FileNotFoundError(f'No *train*.tfrec files found in {GCS_DS_PATH}')

    # Shuffle TFRecord file order
    random.shuffle(fnames_train_tfrecords)

    train_dataset = tf.data.TFRecordDataset(fnames_train_tfrecords, num_parallel_reads=AUTO)

    # Shuffle and Repeat dataset
    if sr:
        # Allow parallel reads/maps to emit elements as soon as they are ready
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False

        train_dataset = train_dataset.with_options(ignore_order)
        # Shuffles the still-serialized records, before decoding
        train_dataset = train_dataset.shuffle(N_SAMPLES if TPU else 1024)
        train_dataset = train_dataset.repeat()

    train_dataset = train_dataset.map(decode_tfrecord, num_parallel_calls=AUTO)
    train_dataset = train_dataset.batch(bs, drop_remainder=dr)
    return train_dataset.prefetch(AUTO)

In [None]:
benchmark_dataset(dataset=get_train_dataset(center_cutout=True))

In [None]:
# Check the ImageNet normalization on the first image of one training batch
inputs, labels = next(iter(get_train_dataset()))
imgs = inputs['image']
print(f'inputs keys: {inputs.keys()}, labels keys: {labels.keys()}')
print(f'imgs shape: {imgs.shape}, imgs dtype: {imgs.dtype}')
img0 = np.asarray(imgs[0]).astype(np.float32)
train_imgs_info = tuple(stat(img0) for stat in (np.mean, np.std, np.min, np.max))
print('train img 0 mean: %.3f, 0 std: %.3f, min: %.3f, max: %.3f' % train_imgs_info)

In [None]:
show_batch(dataset=get_train_dataset(bs=32))

# Augmentation Test

Visualization of augmentations

In [None]:
def get_demo_image():
    """Load one training image, center-crop it to a square, and return it with its size.

    Returns:
        ``(image, height, width)``: the cropped image as a tensor and its height and
        width as int64 scalar tensors (equal after the square crop).
    """
    demo_image = imageio.imread('/kaggle/input/happy-whale-and-dolphin/train_images/000562241d384d.jpg')
    h, w = demo_image.shape[:2]
    # Center square crop along the longer side; works for landscape and portrait
    # images (the previous slice assumed w >= h) and for 2-D grayscale arrays.
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    demo_image = demo_image[top:top + side, left:left + side]
    demo_image = tf.constant(demo_image)
    demo_image_h = tf.constant(demo_image.shape[0], dtype=tf.int64)
    demo_image_w = tf.constant(demo_image.shape[1], dtype=tf.int64)

    return demo_image, demo_image_h, demo_image_w

demo_image, demo_image_h, demo_image_w = get_demo_image()

In [None]:
def decode_tfrecord_demo(record_bytes):
    """Return a randomly augmented, ImageNet-normalized copy of the global ``demo_image``.

    ``record_bytes`` is ignored; mapping this function over dummy elements produces
    independent augmentations of the same image for visualization.

    Returns:
        ``({'image': image, 'individual_id': 42}, {'individual_id': 42})`` with a
        placeholder label.
    """
    # Resize to IMG_SIZE (float32, still in the [0, 255] range)
    image = tf.image.resize(demo_image, [IMG_SIZE, IMG_SIZE], method=tf.image.ResizeMethod.BICUBIC)

    # Scale to [0, 1] before augmenting: random_brightness adds an absolute delta,
    # which is negligible on [0, 255] pixels.
    image = tf.cast(image, tf.float32) / 255.0

    # Image Augmentations
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_hue(image, 0.01)
    image = tf.image.random_saturation(image, 0.70, 1.30)
    image = tf.image.random_contrast(image, 0.80, 1.20)
    image = tf.image.random_brightness(image, 0.10)
    # BICUBIC resizing and the augmentations can overshoot; clip to the valid range
    image = tf.clip_by_value(image, 0.0, 1.0)

    # ImageNet Normalization
    image = (image - IMAGENET_MEAN) / IMAGENET_STD

    return { 'image': image, 'individual_id': 42 }, { 'individual_id': 42 }

In [None]:
def get_train_dataset_augmentation(labels=('individual_id', 'species', 'family'), n_images=32):
    """Dataset holding one batch of ``n_images`` augmented copies of the demo image.

    Args:
        labels: unused; kept (as an immutable default) so existing callers keep working.
        n_images: number of augmented copies, also used as the batch size.
    """
    # Dummy elements: decode_tfrecord_demo ignores its input, so each element just
    # triggers one fresh random augmentation of the demo image.
    dataset = tf.data.Dataset.from_tensor_slices(np.zeros(n_images))
    dataset = dataset.map(decode_tfrecord_demo, num_parallel_calls=AUTO)
    return dataset.batch(n_images)

dataset = get_train_dataset_augmentation()

In [None]:
show_batch(dataset=get_train_dataset_augmentation())

# Dynamic Margins

Based on the [2nd place solution](https://arxiv.org/pdf/2010.05350.pdf) of the Google Landmark Recognition 2021 competition. Margins are computed based on class occurrence, making the descriptors of classes with few examples more unique.

There is a large class imbalance, as shown below, with the number of samples per class ranging from 1 to 400.

In [None]:
display(train.individual_id.value_counts().describe().to_frame(name='Value'))

In [None]:
# Histogram of class sizes: for each number of images per individual (x),
# how many individuals have exactly that many images (y); first 25 sizes only.
plt.figure(figsize=(15,8))
plt.title('Individual Id Sample Count', size=24)
train['individual_id'].value_counts().value_counts().sort_index().head(25).plot(kind='bar')
plt.xticks(size=16)
plt.yticks(size=16)
# The axis labels were swapped: x is the class size, y is the number of classes
plt.xlabel('Class Size', size=18)
plt.ylabel('Number of Classes', size=18)
plt.show()

In [None]:
def get_dynamic_margins(a=0.40, b=0.05, power=-0.25):
    """Compute one ArcFace margin per individual from its number of training images.

    margin = a * class_size ** power + b. With the defaults, rarer classes get larger
    margins, and a single-image class gets a + b.

    NOTE(review): margins are ordered by sorted individual_id, so position i must match
    the integer ``target`` stored in the TFRecords (true if the labels were encoded in
    sorted order, e.g. with LabelEncoder) — confirm.

    Args:
        a: scale of the size-dependent term.
        b: constant margin offset.
        power: exponent applied to the class size (default -0.25).

    Returns:
        dynamic_margins: float32 tensor with one margin per individual.
        class_size2margin: pd.Series mapping class size -> margin, one row per distinct
            size, sorted by size.
    """
    # Individual Id Value Counts, ordered by individual_id
    value_counts = train['individual_id'].value_counts().sort_index()
    class_sizes = value_counts.to_numpy()

    # Compute Dynamic Margins in NumPy (float32 to match the tensor fed to the model)
    margins = (a * class_sizes.astype(np.float64) ** power + b).astype(np.float32)
    dynamic_margins = tf.constant(margins, dtype=tf.float32)

    # Sanity Check, class size to margin mapping (built from NumPy, not from the tensor)
    class_size2margin = pd.Series(data=margins, index=class_sizes)
    class_size2margin = class_size2margin.drop_duplicates().sort_index()

    return dynamic_margins, class_size2margin

dynamic_margins, class_size2margin = get_dynamic_margins()

In [None]:
# Summary statistics of the per-class margins
display(pd.Series(data=dynamic_margins).describe())

In [None]:
# Margins sorted from largest to smallest, one point per class.
# Over half of the classes have a single sample and get the maximum margin (a + b = 0.45).
fig, ax = plt.subplots(figsize=(15, 8))
ax.set_title('Dynamic Margins Distribution', size=24)
pd.Series(np.sort(dynamic_margins)[::-1]).plot(ax=ax)
ax.tick_params(axis='both', labelsize=16)
ax.set_xlabel('Class Count', size=18)
ax.set_ylabel('Margin', size=18)
ax.grid()
plt.show()

In [None]:
# Plot the class size -> margin mapping on a logarithmic class-size axis
plt.figure(figsize=(15, 8))
plt.title('Class Size to Margin Mapping', size=24)
plt.xscale('log')
class_size2margin.plot()
for set_ticks in (plt.xticks, plt.yticks):
    set_ticks(size=16)
plt.xlabel('Class Size', size=18)
plt.ylabel('Margin', size=18)
plt.grid()
plt.show()

In [None]:
# Margins of the ten smallest class sizes
display(class_size2margin.iloc[:10])

# ArcMargin Product

In [None]:
class ArcMarginPenaltyLogists(tf.keras.layers.Layer):
    """ArcFace (additive angular margin) head with per-class margins.

    For each sample with label y, the target logit is ``s * cos(theta_y + m_y)`` and every
    other logit is ``s * cos(theta_j)``, where theta is the angle between the L2-normalized
    embedding and the L2-normalized class weight vector.

    Args:
        num_classes: number of classes.
        dynamic_margins: a single number (same margin for every class) or a 1-D
            sequence/array/tensor with one margin per class, indexed by label.
        logist_scale: logit scale ``s``.
        k: number of sub-centers per class; the maximum cosine over sub-centers is used.
    """
    def __init__(self, num_classes, dynamic_margins, logist_scale=30, k=1, **kwargs):
        super(ArcMarginPenaltyLogists, self).__init__(**kwargs)
        self.num_classes = num_classes
        # isinstance (rather than type() ==) also accepts ints and NumPy scalars
        if isinstance(dynamic_margins, (int, float, np.integer, np.floating)):
            print(f'Using Static Margin: {dynamic_margins}')
            self.dynamic_margins = tf.fill(dims=[num_classes], value=float(dynamic_margins))
        else:
            print(f'Using Dynamic Margins, first 5: {dynamic_margins[:5]}')
            # float32 so cos/sin of the margins match the float32 cosine logits
            self.dynamic_margins = tf.convert_to_tensor(dynamic_margins, dtype=tf.float32)
        self.logist_scale = logist_scale
        self.k = k

    def build(self, input_shape):
        # input_shape is the embedding shape (first call argument)
        initializer_amplitude = 1.0 / tf.math.sqrt(float(self.num_classes))
        initializer = tf.keras.initializers.random_uniform(-initializer_amplitude, initializer_amplitude)
        # add_weight replaces the deprecated add_variable alias (same name and initializer)
        self.w = self.add_weight(name="weights", shape=[int(input_shape[-1]), self.num_classes * self.k], initializer=initializer)

    def call(self, embds, labels, debug=False):
        """Return scaled ArcFace logits of shape [batch, num_classes].

        Args:
            embds: [batch, dim] embeddings; cast to float32.
            labels: [batch] integer labels selecting each sample's margin and target column.
            debug: print intermediate shapes.
        """
        # Dynamic Margins: margin of each sample's own class, plus derived constants
        margins = tf.gather(self.dynamic_margins, labels)
        cos_m = tf.identity(tf.math.cos(margins), name='cos_m')
        sin_m = tf.identity(tf.math.sin(margins), name='sin_m')
        # Threshold/fallback so the target logit keeps decreasing once theta + m exceeds pi
        th = tf.identity(tf.math.cos(math.pi - margins), name='th')
        mm = tf.multiply(sin_m, margins, name='mm')
        if debug:
            print(f'margins: {margins}')
            print(f'cos shape: {cos_m.shape}, sin_m shape: {sin_m.shape}, th shape: {th.shape}, mm shape: {mm.shape}')

        embds = tf.cast(embds, tf.float32)

        normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd')
        normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights')

        # Cosine similarity to every sub-center, then the max over the k sub-centers per class
        cos_t = tf.matmul(normed_embds, normed_w, name='cos_t')
        cos_t = tf.reshape(cos_t, shape=[-1, self.num_classes, self.k])
        cos_t = tf.math.reduce_max(cos_t, axis=2)
        if debug:
            print(f'cos_t shape: {cos_t.shape}')
        # Rounding can push |cos_t| slightly above 1; flooring at epsilon avoids the NaN of
        # sqrt(negative) and the infinite gradient of sqrt at 0.
        sin_t = tf.sqrt(tf.maximum(1. - cos_t ** 2, tf.keras.backend.epsilon()), name='sin_t')

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m), per-sample margins broadcast over classes
        cos_mt = tf.subtract(cos_t * tf.reshape(cos_m, [-1, 1]), sin_t * tf.reshape(sin_m, [-1, 1]), name='cos_mt')

        cos_mt = tf.where(cos_t > tf.reshape(th, [-1, 1]), cos_mt, cos_t - tf.reshape(mm, [-1, 1]))

        # Apply the margin only to each sample's target class
        mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes, name='one_hot_mask')

        logists = tf.where(mask == 1., cos_mt, cos_t)
        logists = tf.multiply(logists, self.logist_scale, 'arcface_logist')

        return logists

# DOLG

DOLG: Single-Stage Image Retrieval with Deep Orthogonal Fusion of Local and Global Features ([paper](https://arxiv.org/pdf/2108.02927.pdf))

Implementation based on [Pytorch implementation](https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/main/models/ch_mdl_dolg_efficientnet.py) published by Christof Henkel on GitHub.

In [None]:
class GeM(tf.keras.layers.Layer):
    """Generalized-mean (GeM) pooling over the spatial axes with a learnable exponent.

    output = GAP(max(x, EPS) ** p) ** (1 / p), where ``p`` is a trainable vector with one
    exponent per input channel, initialized to ``init_norm``. p = 1 gives average pooling;
    larger p moves toward max pooling. The model built in this notebook pools with GAP
    instead, so this layer is not used by ``get_model``.
    """
    def __init__(self, init_norm=3.0, **kwargs):
        super(GeM, self).__init__(**kwargs)
        self.init_norm = init_norm
        self.gap2d = tf.keras.layers.GlobalAveragePooling2D()

    def build(self, input_shape):
        super(GeM, self).build(input_shape)
        feature_size = input_shape[-1]
        self.p = self.add_weight(
                name = "norms",
                shape = (feature_size,),
                initializer = tf.keras.initializers.constant(self.init_norm),
                trainable = True,
            )

    def call(self, inputs):
        # Clamp to EPS so the fractional power is well defined (no zeros or negatives)
        x = tf.math.maximum(inputs, EPS)
        x = tf.pow(x, self.p)

        x = self.gap2d(x)
        x = tf.pow(x, 1.0 / self.p)

        return x

    def get_config(self):
        # Lets Keras re-create the layer from its config (saving, cloning)
        config = super(GeM, self).get_config()
        config.update({'init_norm': self.init_norm})
        return config

In [None]:
# Multi-Atrous Branch
class MultiAtrous(tf.keras.layers.Layer):
    """Multi-atrous block of the DOLG local branch.

    Three parallel dilated convolutions (``dolg_s // 2`` filters each) are concatenated
    on the channel axis, fused by a 1x1 convolution to ``dolg_s`` channels, then ReLU.

    Args:
        dolg_s: number of output channels.
        upsampling: stored in the config only; not used in the computation.
        kernel_size: kernel size of the dilated convolutions (default 3, as before).
        padding: padding of the dilated convolutions (default "same", as before).
        dilation_rates: exactly three dilation rates, one per parallel convolution.
    """
    def __init__(self, dolg_s, upsampling=1, kernel_size=3, padding="same", dilation_rates=(3, 6, 9), **kwargs):
        super(MultiAtrous, self).__init__(**kwargs)
        dilation_rates = tuple(dilation_rates)
        if len(dilation_rates) != 3:
            raise ValueError(f'MultiAtrous expects exactly 3 dilation rates, got {dilation_rates}')
        # Store the constructor arguments; get_config() previously read attributes that
        # were never set and raised AttributeError.
        self.dolg_s = dolg_s
        self.upsampling = upsampling
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation_rates = dilation_rates
        # Same creation order as before (d0, d1, d2, conv1) so saved weight order is unchanged
        r0, r1, r2 = dilation_rates
        self.d0 = tf.keras.layers.Conv2D(dolg_s // 2, kernel_size, dilation_rate=(r0, r0), padding=padding)
        self.d1 = tf.keras.layers.Conv2D(dolg_s // 2, kernel_size, dilation_rate=(r1, r1), padding=padding)
        self.d2 = tf.keras.layers.Conv2D(dolg_s // 2, kernel_size, dilation_rate=(r2, r2), padding=padding)
        self.conv1 = tf.keras.layers.Conv2D(dolg_s, kernel_size=1)

    @tf.function()
    def call(self, inputs, training=None, **kwargs):
        x0 = self.d0(inputs)
        x1 = self.d1(inputs)
        x2 = self.d2(inputs)
        # tf.concat instead of instantiating a new Concatenate layer on every call
        x = tf.concat([x0, x1, x2], axis=3)
        x = self.conv1(x)
        x = tf.keras.activations.relu(x)
        return x

    def get_config(self):
        # Keys match __init__ arguments so from_config(get_config()) round-trips
        config = {
            'dolg_s'        : self.dolg_s,
            'dilation_rates': self.dilation_rates,
            'kernel_size'   : self.kernel_size,
            'padding'       : self.padding,
            'upsampling'    : self.upsampling
        }
        base_config = super(MultiAtrous, self).get_config()
        return {**base_config, **config}

In [None]:
class SpatialAttention2d(tf.keras.layers.Layer):
    """Spatial attention of the DOLG local branch.

    A 1x1 convolution to ``dolg_s`` channels followed by BatchNorm produces features. The
    channel vector at each position is L2-normalized and scaled by a softplus attention
    score predicted from the same features (ReLU -> 1x1 conv to one channel -> softplus).
    """
    def __init__(self, dolg_s, **kwargs):
        super(SpatialAttention2d, self).__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(dolg_s, 1)
        self.bn = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(1, 1)

    @tf.function()
    def call(self, x):
        x = self.conv1(x)
        x = self.bn(x)

        # Unit L2 norm along channels. l2_normalize floors the squared norm at 1e-12, so an
        # all-zero position gives zeros instead of the NaN that tf.linalg.normalize (0 / 0) gives.
        feature_map_norm = tf.math.l2_normalize(x, axis=3)

        x = tf.keras.activations.relu(x)
        x = self.conv2(x)

        att_score = tf.keras.activations.softplus(x)

        # Single-channel scores broadcast over the channel axis
        x = att_score * feature_map_norm

        return x

In [None]:
class OrthogonalFusion(tf.keras.layers.Layer):
    """Orthogonal fusion of a local feature map with a global descriptor (DOLG).

    Inputs ``[fl, fg]``: ``fl`` is a channels-last feature map (batch, w, h, c) and ``fg`` is
    a (batch, c) global descriptor with the same channel count. Subtracts from every local
    feature its projection onto ``fg``, then concatenates the result with ``fg`` tiled over
    all positions. Output: (batch, w, h, 2c).

    Args:
        squared_norm: if True, project with ``<fl, fg> / ||fg||^2``, which makes the residual
            exactly orthogonal to ``fg`` (as in the DOLG paper). The default False divides by
            ``||fg||``, reproducing the PyTorch reference this notebook ports, so previously
            trained weights and the companion inference code stay consistent.
    """
    def __init__(self, squared_norm=False, **kwargs):
        super().__init__(**kwargs)
        self.squared_norm = squared_norm

    @tf.function()
    def call(self, inputs):
        fl, fg = inputs
        # channels-last -> channels-first: (batch, c, w, h)
        fl = tf.transpose(fl, [0,3,1,2])

        _, c, w, h = fl.shape

        # Dot product of fg with the local feature at every position: (batch, 1, w, h)
        fl_b = tf.reshape(fl, [tf.shape(fl)[0],c,w*h])
        fl_dot_fg = tf.matmul(fg[:,tf.newaxis,:] ,fl_b)
        fl_dot_fg = tf.reshape(fl_dot_fg, [tf.shape(fl_dot_fg)[0],1,w,h])

        fg_norm = tf.norm(fg, ord=2, axis=1)
        if self.squared_norm:
            fg_norm = tf.square(fg_norm)

        fl_proj = (fl_dot_fg / fg_norm[:,tf.newaxis,tf.newaxis,tf.newaxis]) * fg[:,:,tf.newaxis,tf.newaxis]
        fl_orth = fl - fl_proj

        # Tile fg over all positions and concatenate on the channel axis
        fg_rep = tf.tile(fg[:,:,tf.newaxis,tf.newaxis], multiples=(1,1,w,h))
        f_fused = tf.concat([fl_orth, fg_rep], axis=1)

        # Back to channels-last
        f_fused = tf.transpose(f_fused, [0,2,3,1])

        return f_fused

    def get_config(self):
        config = super().get_config()
        config.update({'squared_norm': self.squared_norm})
        return config

In [None]:
class GlobalBranch(tf.keras.layers.Layer):
    """Global branch of DOLG: 1x1 conv -> BatchNorm -> SiLU -> global average pooling.

    Maps a 4-D feature map to a (batch, dolg_s) global descriptor.
    """
    def __init__(self, dolg_s, **kwargs):
        super().__init__(**kwargs)
        # Creation order kept (conv, bn, pool) so nested weight order is unchanged
        self.conv2d = tf.keras.layers.Conv2D(dolg_s, 1, name='global_conv2d')
        self.bn = tf.keras.layers.BatchNormalization()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()

    @tf.function()
    def call(self, inputs):
        projected = self.bn(self.conv2d(inputs))
        activated = tf.nn.silu(projected)
        return self.pool(activated)

In [None]:
class DolgBranch(tf.keras.layers.Layer):
    """Complete DOLG head, named ``dolg_branch_{idx}``.

    Call with ``[inputs_l, inputs_g]``. The local feature map passes through multi-atrous
    convolutions and spatial attention; the global feature map passes through the global
    branch. The two are orthogonally fused, and the result is average-pooled into one
    descriptor per image.
    """
    def __init__(self, dolg_s, idx, **kwargs):
        super().__init__(name=f'dolg_branch_{idx}', **kwargs)
        width = int(dolg_s)
        # Sub-layers are created in the original order (local, global, fusion, pooling),
        # so the nested weight order stays the same for saved .h5 weights.
        self.mam = MultiAtrous(width, name=f'mam_{idx}')
        self.sa2d = SpatialAttention2d(width, name=f'sa2d_{idx}')
        self.global_branch = GlobalBranch(width, name=f'g_{idx}')
        self.orthogonal_fusion = OrthogonalFusion()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()

    @tf.function()
    def call(self, inputs):
        local_input, global_input = inputs
        local_features = self.sa2d(self.mam(local_input))
        global_descriptor = self.global_branch(global_input)
        fused = self.orthogonal_fusion([local_features, global_descriptor])
        return self.pool(fused)

# Model

EfficientNetV2 models published by Google on [GitHub](https://github.com/google/automl/tree/master/efficientnetv2)

In [None]:
# GCS location of the pretrained EfficientNetV2 checkpoints (readable from TPUs)
GCS_WEIGHTS_PATH = KaggleDatasets().get_gcs_path(
    'efficientnetv2-pretrained-imagenet21k-weights',
)

In [None]:
def get_model():
    """Build and compile the EfficientNetV2 + DOLG + ArcFace training model.

    Inputs are ``image`` and ``individual_id_input`` (the label, which the ArcFace head
    needs to apply the target-class margin). The single output ``individual_id`` holds the
    scaled ArcFace logits. Clears the Keras session, enables XLA, and builds and compiles
    everything inside ``strategy.scope()``.

    Raises:
        FileNotFoundError: on TPU, if no pretrained backbone checkpoint is found.
    """
    tf.keras.backend.clear_session()
    # enable XLA optimizations
    tf.config.optimizer.set_jit(True)

    with strategy.scope():
        # Input
        image = tf.keras.layers.Input(INPUT_SHAPE, name='image', dtype=tf.float32)
        individual_id = tf.keras.layers.Input([], name='individual_id_input', dtype=tf.int32)

        # EfficientNetV2 CNN
        cnn = effnetv2_model.get_model(f'efficientnetv2-{EFNV2_SIZE}', include_top=False, weights=None)

        # Load Pretrained ImageNet21K Finetuned Imagenet1K Weights.
        # TPUs can only read from GCS. Elsewhere, read the same Kaggle dataset from its local
        # mount instead of silently skipping the pretrained weights.
        # NOTE(review): the local layout is assumed to mirror the GCS bucket — confirm.
        weights_root = GCS_WEIGHTS_PATH if TPU else '/kaggle/input/efficientnetv2-pretrained-imagenet21k-weights'
        WEIGHT_PATH = f'{weights_root}/efficientnetv2-{EFNV2_SIZE}-21k-ft1k'
        ckpt = tf.train.latest_checkpoint(WEIGHT_PATH)
        if ckpt is not None:
            cnn.load_weights(ckpt)
        elif TPU:
            raise FileNotFoundError(f'No pretrained checkpoint found in {WEIGHT_PATH}')
        else:
            print(f'WARNING: no pretrained checkpoint found in {WEIGHT_PATH}, backbone is randomly initialized')

        # CNN Outputs: pooled embedding plus intermediate feature maps
        embedding, fm5, fm4, fm3, fm2, fm1 = cnn(image, with_endpoints=True)
        print(f'embedding: {embedding.shape}, fm5: {fm5.shape}, fm4: {fm4.shape}, fm3: {fm3.shape}, fm2: {fm2.shape}, fm1: {fm1.shape}')

        # DOLG Branches: fm2 feeds the local branch, fm1 the global branch
        descriptor = DolgBranch(DOLG_SIZE, 1)([fm2, fm1])

        # Dense Layer (Dropout rate 0 is a no-op; kept so the layer list is unchanged)
        descriptor = tf.keras.layers.Dropout(0.00)(descriptor)
        descriptor = tf.keras.layers.Dense(EMBEDDING_SIZE, name='descriptor_dense')(descriptor)

        # ArcMarginProduct
        outputs = ArcMarginPenaltyLogists(N_INDIVIDUAL_IDS, dynamic_margins=dynamic_margins, k=1, name='individual_id')(descriptor, individual_id)

        model = tf.keras.models.Model(inputs=[image, individual_id], outputs=outputs)

        # OPTIMIZER (its learning rate is set per epoch by the LearningRateScheduler)
        optimizer = tf.optimizers.Adam()

        # LOSS
        loss = {
            'individual_id': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        }
        # METRICS
        metrics =[
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top5acc'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name='top1acc'),
        ]

        # Compile Model
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        return model

In [None]:
# Build and compile the training model (get_model also clears the Keras session first)
model = get_model()

In [None]:
# Layer-by-layer parameter counts; the EfficientNetV2-XL backbone alone has over 200M parameters
model.summary(line_length=None)

In [None]:
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=False,
    show_layer_names=True,
    expand_nested=False,
)

# Learning Rate Scheduler

In [None]:
# Whole batches per epoch (integer division; a partial last batch is not counted)
TRAIN_STEPS_PER_EPOCH = divmod(N_SAMPLES, BATCH_SIZE)[0]
print(f'N_EPOCHS: {N_EPOCHS}, TRAIN_STEPS_PER_EPOCH: {TRAIN_STEPS_PER_EPOCH}')

In [None]:
def lrfn(current_step, num_warmup_steps, lr_max, num_cycles=0.50, num_training_steps=N_EPOCHS):
    """Learning rate for one step: exponential warm-up, then cosine decay.

    During warm-up the rate is halved once for every step remaining until warm-up ends.
    After that it follows ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, clamped at 0,
    where progress runs from 0 at the end of warm-up toward 1 at ``num_training_steps``.
    """
    if current_step >= num_warmup_steps:
        decay_span = max(1, num_training_steps - num_warmup_steps)
        progress = float(current_step - num_warmup_steps) / float(decay_span)
        cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        return lr_max * max(0.0, cosine)

    steps_left = num_warmup_steps - current_step
    return lr_max * 0.5 ** steps_left

In [None]:
def plot_lr_schedule(lr_schedule, name):
    """Plot a per-epoch learning-rate schedule and annotate every value.

    Args:
        lr_schedule: non-empty sequence of learning rates; element i is the rate of epoch i + 1.
        name: label included in the plot title.

    Raises:
        ValueError: if ``lr_schedule`` is empty.
    """
    lr_schedule = list(lr_schedule)
    if not lr_schedule:
        raise ValueError('lr_schedule must contain at least one learning rate')
    # Axis extent comes from the schedule itself rather than the global N_EPOCHS
    n_epochs = len(lr_schedule)

    plt.figure(figsize=(20,8))
    # Pad with None so epoch i is drawn at x = i and both ends get an empty margin
    plt.plot([None] + lr_schedule + [None])
    # X Labels
    x_ticks = np.arange(n_epochs + 2)
    x_axis_labels = [None] + list(map(str, np.arange(1, n_epochs + 1))) + [None]
    plt.xlim([0, n_epochs + 1])
    plt.xticks(x_ticks, x_axis_labels, size=12) # set tick step to 1 and let x axis start at 1
    plt.yticks(size=12)

    # Increase y-limit for better readability
    plt.ylim([0, max(lr_schedule) * 1.1])

    # Title
    schedule_info = f'start: {lr_schedule[0]:.1E}, max: {max(lr_schedule):.1E}, final: {lr_schedule[-1]:.1E}'
    plt.title(f'Step Learning Rate Schedule {name}, {schedule_info}', size=18, pad=12)

    # Plot Learning Rates; the annotation offset is loop-invariant
    offset_y = (max(lr_schedule) - min(lr_schedule)) * 0.02
    for epoch_idx, val in enumerate(lr_schedule):
        # The first point and rising points put their label left of the marker (ha='right');
        # falling points and the last point put it to the right. Checking epoch_idx == 0 first
        # avoids comparing the first value with lr_schedule[-1].
        if epoch_idx == 0 or (epoch_idx < n_epochs - 1 and lr_schedule[epoch_idx - 1] < val):
            ha = 'right'
        else:
            ha = 'left'
        plt.plot(epoch_idx + 1, val, 'o', color='black')
        plt.annotate(f'{val:.1E}', xy=(epoch_idx + 1, val + offset_y), size=12, ha=ha)

    plt.xlabel('Epoch', size=16, labelpad=5)
    plt.ylabel('Learning Rate', size=16, labelpad=5)
    plt.grid()
    plt.show()

# Per-epoch learning rates for the whole model
LR_SCHEDULE = [lrfn(step, num_warmup_steps=3, lr_max=5e-4, num_cycles=0.50) for step in range(N_EPOCHS)]
plot_lr_schedule(LR_SCHEDULE, 'Model')

# Callbacks

In [None]:
# Keep the weights of the epoch with the lowest training loss (fit has no validation data)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='model_best.h5',
    monitor='loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
model_checkpoint_callback.set_model(model)

# Apply the precomputed per-epoch learning rate from LR_SCHEDULE
learning_rate_callback = tf.keras.callbacks.LearningRateScheduler(
    schedule=lambda epoch: LR_SCHEDULE[epoch],
    verbose=1,
)

In [None]:
# Train on the infinitely repeated dataset; steps_per_epoch defines an epoch
history = model.fit(
    x=get_train_dataset(),
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    epochs=N_EPOCHS,
    verbose=1,
    callbacks=[learning_rate_callback, model_checkpoint_callback],
)

In [None]:
# Restore the lowest-training-loss weights written by the checkpoint callback
model.load_weights(filepath='model_best.h5')

# Training History

In [None]:
def plot_history_metric(metric, f_best=np.argmax, yscale='linear', hist=None):
    """Plot one Keras training metric per epoch and mark its best value.

    The validation curve and its best marker are added when ``val_{metric}`` is recorded.

    Args:
        metric: key of ``History.history`` to plot, e.g. ``'loss'`` or ``'top5acc'``.
        f_best: returns the index of the best value (``np.argmax`` or ``np.argmin``).
        yscale: matplotlib y-axis scale.
        hist: ``History`` object to plot; defaults to the global ``history`` from ``model.fit``.
    """
    if hist is None:
        hist = history
    y_train = hist.history[metric]
    # Number of recorded epochs (may be fewer than N_EPOCHS if training stopped early)
    n_epochs = len(y_train)
    x = np.arange(1, n_epochs + 1)
    plt.figure(figsize=(20, 8))
    # TRAIN
    plt.plot(x, y_train, color='tab:blue', lw=3, label='train')
    plt.title(f'Training {metric}', fontsize=24, pad=10)
    plt.ylabel(metric, fontsize=20, labelpad=10)
    plt.xlabel('epoch', fontsize=20, labelpad=10)
    # Ticks at epoch 1 and every 5th recorded epoch
    plt.xticks([1] + np.arange(5, n_epochs + 1, 5).tolist(), fontsize=16)
    plt.yticks(fontsize=16)
    plt.yscale(yscale)

    # Train Best Marker (f_best returns a 0-based index, epochs are 1-based)
    x_best = f_best(y_train)
    y_best = y_train[x_best]
    plt.scatter(x_best + 1, y_best, color='purple', s=100, marker='o', label=f'train best: {y_best:.4f}')

    if f'val_{metric}' in hist.history:
        y_val = hist.history[f'val_{metric}']
        # VALIDATION
        plt.plot(x, y_val, color='tab:orange', lw=3, label='validation')
        # Validation Best Marker
        x_best = f_best(y_val)
        y_best = y_val[x_best]
        plt.scatter(x_best + 1, y_best, color='red', s=100, marker='o', label=f'validation best: {y_best:.4f}')

    plt.grid()
    plt.legend(prop={'size': 18})
    plt.show()

In [None]:
plot_history_metric(metric='loss', f_best=np.argmin)

In [None]:
plot_history_metric(metric='top5acc', f_best=np.argmax)

In [None]:
plot_history_metric(metric='top1acc', f_best=np.argmax)

# Create Embeddings Model

Creates the embedding model where the output is not the classifier, but the descriptor.

In [None]:
# Print the index and name of every top-level layer of the training model
for layer_idx, layer in enumerate(model.layers):
    print(f'{layer_idx} | \t{layer.name}')

In [None]:
# Build an inference model that maps an image to its descriptor (no ArcFace head),
# reusing the trained layers of `model`.
with strategy.scope():
    # Input
    image = tf.keras.layers.Input(INPUT_SHAPE, name='image', dtype=tf.float32)

    # EfficientNet backbone. Its layer name is not fixed in get_model, so it is still taken
    # by index. NOTE(review): assumes it is model.layers[1] — check the layer listing above.
    embedding, fm5, fm4, fm3, fm2, fm1 = model.layers[1](image, with_endpoints=True)

    # DOLG Branches and Descriptor: looked up by the names get_model assigns
    # ('dolg_branch_1' from DolgBranch(DOLG_SIZE, 1), 'descriptor_dense'), which does not
    # depend on the order of model.layers. The rate-0 Dropout in between is skipped.
    descriptor = model.get_layer('dolg_branch_1')([fm2, fm1])
    outputs = model.get_layer('descriptor_dense')(descriptor)

    model_embedding = tf.keras.Model(inputs=image, outputs=outputs)
    model_embedding.trainable = False

In [None]:
model_embedding.summary(line_length=None)

In [None]:
# Architecture of the embedding model; its output is the EMBEDDING_SIZE (2048) descriptor
tf.keras.utils.plot_model(
    model_embedding,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    expand_nested=False,
)

In [None]:
model_embedding.save_weights(filepath='model_embedding.h5')