# ArcFace Training

This notebook builds and trains an ArcFace-based deep neural network: an EfficientNetB3 backbone with an ArcMargin head, trained on TFRecord shards.

References: 
1. https://www.kaggle.com/ragnar123/shopee-efficientnetb3-arcmarginproduct/notebook
2. https://www.kaggle.com/ragnar123/shopee-tf-records-512

In [1]:
import re
import os
import numpy as np
import pandas as pd
import random
import math
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB3
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
import util

In [2]:
# tfrecords directory
# Sorted so that the seeded train/validation split sees the shard files in a
# stable order, independent of how the filesystem happens to enumerate them.
TFRECORDS = sorted(tf.io.gfile.glob('../../data/tfrecords-new/*.tfrec'))

EPOCHS = 20
BATCH_SIZE = 8
# Height and width passed to util.decode_image; the model input is (*IMAGE_SIZE, 3).
IMAGE_SIZE = [512, 512]
SEED = 42
# Learning rate the optimizer is constructed with. The LearningRateScheduler
# callback passed to fit() sets the rate at the start of every epoch.
LR = 0.001
# Number of classes given to the ArcMargin head.
N_CLASSES = 11014

# Let tf.data choose parallelism / prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE

In [3]:
def set_seed(seed):
    '''
    Seed every random number source this notebook relies on.

    Exports PYTHONHASHSEED into the process environment and seeds Python's
    `random` module, NumPy's global RNG and TensorFlow's global seed with
    the same value.
    '''
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

In [4]:
def arcface_format(posting_id, image, label_group, matches):
    '''
    Repack one parsed example into the structure the ArcFace model consumes.

    The image and its label are bundled into a dict keyed by 'inp1' and
    'inp2'; the label is additionally returned on its own. posting_id and
    matches pass through unchanged.
    '''
    model_inputs = dict(inp1 = image, inp2 = label_group)
    return posting_id, model_inputs, label_group, matches

def data_augment(posting_id, image, label_group, matches):
    '''
    Randomly perturb the image; all other fields are returned untouched.

    Applied in this order: left/right flip, up/down flip, hue shift,
    saturation scaling, contrast scaling, brightness shift.
    '''
    augmented = tf.image.random_flip_left_right(image)
    augmented = tf.image.random_flip_up_down(augmented)
    augmented = tf.image.random_hue(augmented, max_delta = 0.01)
    augmented = tf.image.random_saturation(augmented, lower = 0.70, upper = 1.30)
    augmented = tf.image.random_contrast(augmented, lower = 0.80, upper = 1.20)
    augmented = tf.image.random_brightness(augmented, max_delta = 0.10)
    return posting_id, augmented, label_group, matches


def read_tfrecord(example, labeled=True):
    '''
    Parses a single serialized example from the tfrecords.

    Args:
        example: one serialized example, as yielded by a TFRecordDataset.
        labeled: when True (default) the record must carry an int64
            `label_group` feature, which is returned cast to int32. When
            False the feature is not required and a placeholder label of -1
            is returned, so the result keeps the same 4-tuple structure.

    Returns:
        (posting_id, image, label_group, matches), where image is whatever
        util.decode_image produces for IMAGE_SIZE.
    '''
    tfrecord_format = {
        "posting_id": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "matches": tf.io.FixedLenFeature([], tf.string)
    }
    if labeled:
        # Only demand the label feature when the caller says it is present.
        tfrecord_format["label_group"] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(example, tfrecord_format)
    posting_id = example['posting_id']
    image = util.decode_image(example['image'],IMAGE_SIZE)
    if labeled:
        label_group = tf.cast(example['label_group'], tf.int32)
    else:
        # Without this branch the name was never bound on the unlabeled path
        # and the return statement below raised.
        label_group = tf.constant(-1, dtype=tf.int32)
    matches = example['matches']
    return posting_id, image, label_group, matches

def load_dataset(filenames, ordered = False):
    '''
    Open the given tfrecord files and parse every record with read_tfrecord.

    When `ordered` is False the dataset is allowed to yield elements in a
    non-deterministic order, which trades reproducible ordering for speed.
    '''
    options = tf.data.Options()
    if not ordered:
        # disable order, increase speed
        options.experimental_deterministic = False

    # num_parallel_reads makes the reader pull from several files at once
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    return records.with_options(options).map(read_tfrecord, num_parallel_calls = AUTO)

In [6]:
def get_training_dataset(filenames, ordered = False):
    '''
    Build the training input pipeline.

    Stages, in order: parse -> augment -> ArcFace format -> repeat forever ->
    shuffle (buffer of 2048) -> batch by BATCH_SIZE -> prefetch. Because the
    dataset never ends, the consumer has to bound each epoch itself.
    '''
    return (
        load_dataset(filenames, ordered = ordered)
        .map(data_augment, num_parallel_calls = AUTO)
        .map(arcface_format, num_parallel_calls = AUTO)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_validation_dataset(filenames, ordered = True):
    '''
    Build the validation input pipeline.

    Stages, in order: parse -> ArcFace format -> batch by BATCH_SIZE ->
    prefetch. No augmentation, shuffling or repetition is applied.
    '''
    return (
        load_dataset(filenames, ordered = ordered)
        .map(arcface_format, num_parallel_calls = AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

In [7]:
def count_data_items(filenames):
    '''
    Sum the image counts encoded in tfrecord file names.

    Each file name is expected to contain `-<count>.`, e.g.
    `train00-2071.tfrec` contributes 2071. Only the base name of each path is
    inspected, so dashes and digits in directory names cannot be mistaken for
    the count.

    Args:
        filenames: iterable of tfrecord paths.

    Returns:
        Total number of images as a NumPy int64 (0 for an empty iterable).

    Raises:
        ValueError: if a file name carries no `-<digits>.` count.
    '''
    # Compiled once instead of once per file; `+` rejects an empty count.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        found = count_pattern.search(os.path.basename(filename))
        if found is None:
            raise ValueError(f"cannot read an image count from tfrecord name: {filename!r}")
        counts.append(int(found.group(1)))
    # Explicit dtype keeps the result an integer even when `counts` is empty.
    return np.sum(counts, dtype = np.int64)

In [8]:
def get_lr_callback():
    '''
    Build the per-epoch learning rate schedule as a Keras callback.

    The rate climbs linearly from 1e-6 to 5e-6 * BATCH_SIZE over the first
    five epochs, holds at the peak for `hold_epochs` epochs (currently zero),
    and then decays by a factor of 0.8 per epoch towards a floor of 1e-6.
    '''
    start_lr, floor_lr = 0.000001, 0.000001
    peak_lr = 0.000005 * BATCH_SIZE
    warmup_epochs, hold_epochs = 5, 0
    decay_rate = 0.8

    def lrfn(epoch):
        # linear warm-up
        if epoch < warmup_epochs:
            return (peak_lr - start_lr) / warmup_epochs * epoch + start_lr
        # optional plateau at the peak
        if epoch < warmup_epochs + hold_epochs:
            return peak_lr
        # exponential decay towards the floor
        return (peak_lr - floor_lr) * decay_rate**(epoch - warmup_epochs - hold_epochs) + floor_lr

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose = True)

In [9]:
def get_model():
    '''
    Build and compile the ArcFace classification model.

    The model takes two inputs: 'inp1', an image of shape (*IMAGE_SIZE, 3),
    and 'inp2', one label value per example. The image goes through an
    ImageNet-initialised EfficientNetB3 without its classification top and a
    global average pooling layer; the pooled features and the label are then
    passed together to util.ArcMarginProduct, whose output is fed to a
    softmax. The model is compiled with Adam at learning rate LR, sparse
    categorical cross-entropy loss and sparse categorical accuracy.

    Returns:
        The compiled tf.keras Model.
    '''

    # Input names must match the dict keys produced by arcface_format.
    inp = tf.keras.layers.Input(shape = (*IMAGE_SIZE, 3), name = 'inp1')
    label = tf.keras.layers.Input(shape = (), name = 'inp2')
    x = EfficientNetB3(weights = 'imagenet', include_top = False)(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    
    # NOTE(review): s and m are presumably the ArcFace scale and angular
    # margin -- confirm against util.ArcMarginProduct.
    margin = util.ArcMarginProduct(
        n_classes = N_CLASSES, 
        s = 30, 
        m = 0.5, 
        name='head/arc_margin', 
        dtype='float32'
    )
    
    # The head receives the label alongside the features.
    x = margin([x, label])

    # Softmax is explicitly kept in float32, like the margin head above.
    output = tf.keras.layers.Softmax(dtype='float32')(x)
    model = tf.keras.models.Model(inputs = [inp, label], outputs = [output])

    opt = tf.keras.optimizers.Adam(learning_rate = LR)
    model.compile(
        optimizer = opt,
        loss = ['sparse_categorical_crossentropy'],
        metrics = ['sparse_categorical_accuracy']
        ) 

    return model

In [10]:
# Seed all random number generators before any split, dataset or model is created.
set_seed(SEED)

In [11]:
# Split at the tfrecord-file level: 10% of the shard files are held out for validation.
train, valid = train_test_split(TFRECORDS, shuffle = True, random_state = SEED, test_size=0.1)

In [7]:
# After arcface_format every element is (posting_id, inputs, label_group, matches),
# where `inputs` is the {'inp1': image, 'inp2': label} dict. Keras fit() only
# wants (inputs, target), so the id and matches fields are dropped here.
train_dataset = get_training_dataset(train, ordered = False).map(
    lambda posting_id, inputs, label_group, matches: (inputs, label_group)
)
val_dataset = get_validation_dataset(valid, ordered = True).map(
    lambda posting_id, inputs, label_group, matches: (inputs, label_group)
)

In [14]:
# Reset Keras' global state so a model left over from an earlier run in this
# kernel does not linger when the new one is built.
K.clear_session()

In [8]:
# Build and compile the EfficientNetB3 + ArcMargin model.
model = get_model()

In [9]:
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    f'EfficientNetB3_{IMAGE_SIZE[0]}_{SEED}_new_lr.h5',
    monitor = 'val_loss',
    mode = 'min',
    save_best_only = True,
    save_weights_only = True,
    verbose = 1,
)

In [10]:
# Warm-start from saved weights instead of training the head from scratch.
# NOTE(review): presumably the best checkpoint of an earlier run; the file must
# exist and match the architecture built by get_model().
model.load_weights('./trained/arcface_best_epoch_512_42.h5')

In [13]:
# The training dataset repeats forever, so fit() needs an explicit epoch length:
# the number of full batches in one pass over the training shards.
STEPS_PER_EPOCH = count_data_items(train) // BATCH_SIZE

In [11]:
# Train with best-weights checkpointing and the custom per-epoch learning rate schedule.
history = model.fit(
    train_dataset,
    validation_data = val_dataset,
    epochs = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    callbacks = [checkpoint, get_lr_callback()],
    verbose = 1,
)

Epoch 1/20

Epoch 00001: LearningRateScheduler reducing learning rate to 1e-06.

Epoch 00001: val_loss improved from inf to 9.84911, saving model to EfficientNetB5_512_42_new_lr.h5
Epoch 2/20

Epoch 00002: LearningRateScheduler reducing learning rate to 8.800000000000002e-06.

Epoch 00002: val_loss did not improve from 9.84911
Epoch 3/20

Epoch 00003: LearningRateScheduler reducing learning rate to 1.6600000000000004e-05.

Epoch 00003: val_loss did not improve from 9.84911
Epoch 4/20

Epoch 00004: LearningRateScheduler reducing learning rate to 2.4400000000000007e-05.

Epoch 00004: val_loss did not improve from 9.84911
Epoch 5/20

Epoch 00005: LearningRateScheduler reducing learning rate to 3.2200000000000003e-05.

Epoch 00005: val_loss did not improve from 9.84911
Epoch 6/20

Epoch 00006: LearningRateScheduler reducing learning rate to 4e-05.

Epoch 00006: val_loss did not improve from 9.84911
Epoch 7/20

Epoch 00007: LearningRateScheduler reducing learning rate to 3.2200000000000003e

In [15]:
# Saves the model as it stands after the last epoch. The lowest-val_loss
# weights are written separately by the ModelCheckpoint callback, so this file
# is not necessarily the best-performing state.
model.save('./trained/arcface.h5')

