I implemented the model with reference to the following notebook.  
[Shopee EfficientNetB3 ArcMarginProduct | Kaggle](https://www.kaggle.com/ragnar123/shopee-efficientnetb3-arcmarginproduct)  
Thank you very much, @ragnar123.

The dataset I'm using was created with the following notebook.  
[Shopee data to TFRecord | Kaggle](https://www.kaggle.com/yukiohkawa/shopee-data-to-tfrecord)

# Setting

In [None]:
import math
import sys
import pathlib

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from sklearn.model_selection import train_test_split

from tensorflow.keras import Model, layers
from tensorflow.keras.applications import EfficientNetB3, efficientnet
from tensorflow.keras.layers import Dense, Dropout, Flatten, GlobalAveragePooling2D, Input

# Sentinel handed to tf.data (parallel TFRecord reads, prefetch buffer size)
# so the runtime tunes those values instead of using a fixed number.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# Pick a distribution strategy: TPU when a TPU cluster can be resolved,
# otherwise TensorFlow's default strategy. A ValueError while resolving the
# cluster is treated as "no TPU available".
try:
    # No parameters are necessary if the TPU_NAME environment variable is
    # set, which is always the case on Kaggle.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU workers and initialise them before building the
    # strategy that replicates the model across the TPU cores.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Side length (pixels) that images are resized to and that the model expects.
# The spelling "SIIZE" is kept because other cells reference this exact name.
IMG_SIIZE = 512
# Number of colour channels of the model input.
IMG_CHANNEL = 3
# Global batch size: 8 samples per replica times the number of replicas.
BATCH_SIZE = 8 * strategy.num_replicas_in_sync
# Number of training epochs passed to `model.fit`.
EPOCHS = 1

In [None]:
def get_layer_index(model, layer_name, not_found=None):
    """Return the position of the first layer called ``layer_name``.

    Args:
        model: object exposing a ``layers`` sequence whose items have a
            ``name`` attribute.
        layer_name: name to look for (exact match).
        not_found: value returned when no layer carries that name.

    Returns:
        The zero-based index into ``model.layers`` of the first match, or
        ``not_found`` if there is none.
    """
    matching_positions = (
        position
        for position, candidate in enumerate(model.layers)
        if candidate.name == layer_name
    )
    return next(matching_positions, not_found)

# Data Loading

In [None]:
# Collect every TFRecord shard below the dataset directory. The paths are
# sorted because directory iteration order is not guaranteed; the seeded
# train/validation split of this list is only reproducible when the input
# order is stable.
tfrecords_path = sorted(
    str(shard)
    for shard in pathlib.Path('../input/shoppee-tfrecord/tfrecords').glob('**/*.tfrecord')
)
# Per-sample metadata table; the notebook reads its `label_group` and `fold`
# columns.
df = pd.read_csv('../input/shoppee-tfrecord/tfrecords/train_fold.csv')

In [None]:
# Preview the first rows of the metadata table (notebook display only).
df.head()

In [None]:
# Number of distinct label groups; used as the class count of the ArcFace head.
# dropna=False keeps the count identical to len(unique()), which counts NaN.
N_CLASSES = df['label_group'].nunique(dropna=False)
# Total number of rows in the metadata table.
DATA_SIZE = df.shape[0]

# Build Dataset

In [None]:
def data_augment(image):
    """Apply random flips and mild colour jitter to one image tensor.

    The operations run in a fixed order: horizontal flip, vertical flip, hue,
    saturation, contrast, brightness. In this notebook the function receives
    images already scaled by 1/255; no clipping is applied afterwards.

    Args:
        image: image tensor accepted by the ``tf.image.random_*`` ops.

    Returns:
        The augmented image tensor.
    """
    # Geometric augmentation.
    augmented = tf.image.random_flip_left_right(image)
    augmented = tf.image.random_flip_up_down(augmented)

    # Photometric augmentation.
    augmented = tf.image.random_hue(augmented, max_delta=0.01)
    augmented = tf.image.random_saturation(augmented, lower=0.70, upper=1.30)
    augmented = tf.image.random_contrast(augmented, lower=0.80, upper=1.20)
    augmented = tf.image.random_brightness(augmented, max_delta=0.10)

    return augmented

In [None]:
def load_tfrecords(paths, is_augment=True):
    """Build an unbatched ``tf.data`` pipeline of decoded samples.

    Args:
        paths: TFRecord file path(s) accepted by ``tf.data.TFRecordDataset``.
        is_augment: when true, ``data_augment`` is applied to every image.

    Returns:
        A dataset whose elements are ``((image, label), label)``. The label
        appears twice because the training model takes it as a second input
        (the ArcFace head needs it) and it is also the loss target. ``image``
        is a float32 tensor of shape ``[IMG_SIIZE, IMG_SIIZE, 3]`` scaled by
        1/255; ``label`` is the int64 ``label_group`` feature.
    """

    feature_description = {
        'label_group': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'posting_id': tf.io.FixedLenFeature([], tf.string),
        'title': tf.io.FixedLenFeature([], tf.string),
    }

    def _parse_function(example):
        """Decode one serialized example into ``((image, label), label)``."""
        feature = tf.io.parse_single_example(example, feature_description)

        image = tf.image.decode_jpeg(feature['image'], channels=3)
        image = tf.image.resize(image, [IMG_SIIZE, IMG_SIIZE])
        # NOTE(review): tf.keras EfficientNet models are documented to expect
        # pixel values in [0, 255] (rescaling is built into the network), but
        # this scales to [0, 1]. Left unchanged so existing inference code that
        # mirrors this preprocessing keeps working. Confirm whether intended.
        image = tf.cast(image, tf.float32) / 255.0
        if is_augment:
            image = data_augment(image)
        # NOTE(review): assumes label_group was already encoded to
        # 0..num_classes-1 when the TFRecords were written. Verify.
        label = feature['label_group']

        return (image, label), label

    raw_dataset = tf.data.TFRecordDataset(paths, num_parallel_reads=AUTO)

    # Decode/resize/augment in parallel: a plain `map` processes one element
    # at a time and starves the accelerator on JPEG decoding.
    parsed_dataset = raw_dataset.map(_parse_function, num_parallel_calls=AUTO)

    return parsed_dataset

In [None]:
# Split the list of TFRecord files (not individual samples). An integer
# `test_size` is an absolute count, so 3 files are held out for validation;
# `random_state` fixes the shuffle used for the split.
train_paths, val_paths = train_test_split(tfrecords_path, test_size=3, random_state=1)

In [None]:
# Training samples are augmented; validation samples are only decoded,
# resized and rescaled.
train = load_tfrecords(paths=train_paths, is_augment=True)
val = load_tfrecords(paths=val_paths, is_augment=False)

In [None]:
# Derive an identifier per TFRecord file by removing 'train_' from the file
# stem. These strings are matched against the `fold` column of `df` below, so
# presumably the files are named after their fold. Verify against the dataset.
train_data_index = [
    pathlib.Path(shard_path).stem.replace('train_', '')
    for shard_path in train_paths
]
val_data_index = [
    pathlib.Path(shard_path).stem.replace('train_', '')
    for shard_path in val_paths
]

In [None]:
# Metadata rows belonging to each split; only their counts are used later to
# derive the number of steps per epoch.
# The ids taken from the file names are strings, while `fold` may have been
# parsed from the CSV as integers. Comparing as strings makes the match
# independent of the column dtype and of pandas' version-specific coercion,
# and avoids interpolating a Python list into a query string.
# NOTE(review): assumes the string form of `fold` equals the file-name suffix
# (e.g. no zero padding). Confirm against the dataset.
df_train = df[df['fold'].astype(str).isin(train_data_index)]
df_val = df[df['fold'].astype(str).isin(val_data_index)]

In [None]:
# Training stream: shuffle within a 1012-element buffer, loop forever (the
# epoch length is set by `steps_per_epoch` in `fit`), batch, and prefetch.
train = train.shuffle(1012)
train = train.repeat()
train = train.batch(BATCH_SIZE)
train = train.prefetch(buffer_size=AUTO)

# Validation stream: same as above without shuffling; the number of batches
# consumed per evaluation is set by `validation_steps` in `fit`.
val = val.repeat()
val = val.batch(BATCH_SIZE)
val = val.prefetch(buffer_size=AUTO)

# Build Model

In [None]:
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """Batch normalization that runs in inference mode whenever it is frozen.

    The effective training flag is the incoming ``training`` flag AND-ed with
    ``self.trainable``, so setting ``trainable = False`` also stops the layer
    from being run in training mode.

    ref: https://github.com/zzh8829/yolov3-tf2
    """
    def call(self, x, training=False):
        # A missing flag is interpreted as inference.
        requested = tf.constant(False) if training is None else training
        # Training behaviour is only allowed while the layer is trainable.
        effective = tf.logical_and(requested, self.trainable)
        return super().call(x, effective)

In [None]:
def ArcHead(num_classes, margin=0.5, logist_scale=64, name='ArcHead'):
    """Create a callable that attaches an ArcFace head to an embedding tensor.

    Args:
        num_classes: number of classes of the margin layer.
        margin: angular margin forwarded to ``ArcMarginPenaltyLogists``.
        logist_scale: logit scale forwarded to ``ArcMarginPenaltyLogists``.
        name: name given to the wrapping sub-model.

    Returns:
        A function ``arc_head(x_in, y_in)`` that wraps the margin layer in its
        own ``Model`` (named ``name``) and returns that model applied to the
        embeddings ``x_in`` and the labels ``y_in``.
    """
    def arc_head(x_in, y_in):
        # Fresh inputs with the per-sample shapes of the incoming tensors.
        embd_input = Input(x_in.shape[1:])
        label_input = Input(y_in.shape[1:])

        penalty_layer = ArcMarginPenaltyLogists(num_classes=num_classes,
                                                margin=margin,
                                                logist_scale=logist_scale)
        logits = penalty_layer(embd_input, label_input)

        head_model = Model((embd_input, label_input), logits, name=name)
        return head_model((x_in, y_in))
    return arc_head

In [None]:
class ArcMarginPenaltyLogists(tf.keras.layers.Layer):
    """ArcFace head: cosine logits with an additive angular margin.

    For every sample the layer computes the cosine between the L2-normalised
    embedding and each L2-normalised class weight vector. For the sample's own
    class the cosine ``cos(t)`` is replaced by ``cos(t + margin)``; where
    ``cos(t) <= cos(pi - margin)`` it falls back to
    ``cos(t) - sin(margin) * margin``. All values are then multiplied by
    ``logist_scale``. The output is raw logits (no softmax).

    Args:
        num_classes: number of classes (columns of the weight matrix).
        margin: additive angular margin in radians.
        logist_scale: factor applied to the final cosine values.
    """
    def __init__(self, num_classes, margin=0.5, logist_scale=64, **kwargs):
        super(ArcMarginPenaltyLogists, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.margin = margin
        self.logist_scale = logist_scale

    def build(self, input_shape):
        # `add_weight` is the supported API; `add_variable` was only a
        # deprecated alias for it. Keyword arguments keep the call unambiguous.
        self.w = self.add_weight(
            name="weights", shape=[int(input_shape[-1]), self.num_classes])
        self.cos_m = tf.identity(math.cos(self.margin), name='cos_m')
        self.sin_m = tf.identity(math.sin(self.margin), name='sin_m')
        # Threshold below which t + margin would exceed pi.
        self.th = tf.identity(math.cos(math.pi - self.margin), name='th')
        # Penalty used in the fallback branch.
        self.mm = tf.multiply(self.sin_m, self.margin, name='mm')

    def call(self, embds, labels):
        """Return scaled logits of shape ``[batch, num_classes]``.

        Args:
            embds: embeddings, normalised along axis 1.
            labels: class indices; cast to int32 before one-hot encoding.
        """
        normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd')
        normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights')

        cos_t = tf.matmul(normed_embds, normed_w, name='cos_t')
        # Rounding can push |cos_t| slightly above 1, making 1 - cos_t^2
        # negative (NaN after sqrt), and sqrt has an infinite gradient at 0.
        # Clipping to [epsilon, 1] keeps both the value and gradient finite.
        sin_sq = tf.clip_by_value(
            1. - tf.square(cos_t), tf.keras.backend.epsilon(), 1.)
        sin_t = tf.sqrt(sin_sq, name='sin_t')

        # cos(t + m) = cos(t) cos(m) - sin(t) sin(m)
        cos_mt = tf.subtract(
            cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt')

        cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm)

        mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes,
                          name='one_hot_mask')

        # Margin-penalised cosine for the true class, plain cosine elsewhere.
        logists = tf.where(mask == 1., cos_mt, cos_t)
        logists = tf.multiply(logists, self.logist_scale, 'arcface_logist')

        return logists

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(ArcMarginPenaltyLogists, self).get_config()
        config.update({
            'num_classes': self.num_classes,
            'margin': self.margin,
            'logist_scale': self.logist_scale,
        })
        return config

In [None]:
def _regularizer(weights_decay=5e-4):
    """Return an L2 kernel regularizer with factor ``weights_decay``."""
    l2_penalty = tf.keras.regularizers.l2(weights_decay)
    return l2_penalty

In [None]:
# Input preprocessing hook applied in `build_model`. Per the Keras
# documentation this is a pass-through for EfficientNet, because the
# rescaling is part of the network itself.
preprocess_input = efficientnet.preprocess_input

# Create the backbone inside the strategy scope. Variables created outside the
# scope are not owned by the distribution strategy, which breaks training when
# the strategy is a TPUStrategy. With the default strategy the scope has no
# effect.
with strategy.scope():
    base_model = EfficientNetB3(input_shape=[IMG_SIIZE, IMG_SIIZE, IMG_CHANNEL],
                                include_top=False,
                                weights='imagenet')

# Fine-tune the whole backbone rather than only the new layers.
base_model.trainable = True

def build_model(w_decay=5e-4, embd_shape=256, is_training=True, num_classes=None, margin=0.5, logist_scale=32):
    """Assemble the embedding network, optionally topped with an ArcFace head.

    Pipeline: image input -> ``preprocess_input`` -> ``base_model`` ->
    BatchNormalization -> Dropout(0.3) -> GlobalAveragePooling2D ->
    Dense(``embd_shape``) -> BatchNormalization named ``'embs'``.

    Args:
        w_decay: L2 factor for the Dense kernel regularizer.
        embd_shape: dimensionality of the embedding.
        is_training: when true, add a label input and the ArcFace head.
        num_classes: class count of the head; required when ``is_training``.
        margin: angular margin forwarded to ``ArcHead``.
        logist_scale: logit scale forwarded to ``ArcHead``.

    Returns:
        If ``is_training``: a ``Model`` mapping ``(image, label)`` to the
        head's logits. Otherwise: a ``Model`` mapping ``image`` to the
        embedding.
    """
    with strategy.scope():
        image_in = Input([IMG_SIIZE, IMG_SIIZE, IMG_CHANNEL], name='input_image')

        # Backbone feature map.
        features = base_model(preprocess_input(image_in))
        features = BatchNormalization()(features)
        features = Dropout(rate=0.3)(features)

        # Pool to a vector and project to the embedding size.
        pooled = GlobalAveragePooling2D()(features)
        projected = Dense(embd_shape, kernel_regularizer=_regularizer(w_decay))(pooled)
        embds = BatchNormalization(name='embs')(projected)

        if not is_training:
            return Model(image_in, embds)

        assert num_classes is not None
        label_in = Input([], name='label')
        attach_head = ArcHead(num_classes=num_classes,
                              margin=margin,
                              logist_scale=logist_scale)
        logist = attach_head(embds, label_in)
        return Model((image_in, label_in), logist)

In [None]:
# Training variant of the network: takes (image, label), outputs ArcFace logits.
model = build_model(is_training=True, num_classes=N_CLASSES)

In [None]:
# Print the layer-by-layer overview of the training model (notebook display).
model.summary()

In [None]:
learning_rate = tf.constant(1e-4)

# Compile inside the strategy scope so the optimizer and the metric are
# created under the same distribution strategy as the model. Creating them
# outside the scope is rejected under TPUStrategy.
with strategy.scope():
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        # The ArcFace head returns raw scaled logits (no softmax layer), so
        # the loss must apply the softmax itself. With the default
        # from_logits=False the logits would be misread as probabilities.
        loss=[tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)],
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

In [None]:
# Write the weights to 'EfficientNetB3.h5' whenever the validation loss
# reaches a new minimum; only weights are stored, not the full model.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='EfficientNetB3.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=2,
)

In [None]:
# Both datasets repeat forever, so the epoch lengths have to be given
# explicitly: the number of full batches in each metadata split.
steps_per_epoch = len(df_train) // BATCH_SIZE
val_steps_per_epoch = len(df_val) // BATCH_SIZE

model.fit(
    x=train,
    validation_data=val,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=val_steps_per_epoch,
    callbacks=[checkpoint],
)

In [None]:
# Drop the ArcFace head: keep only the path from the image input to the 'embs'
# embedding layer, reusing the trained layers. Both layers are looked up by
# the names assigned in `build_model` instead of by position, so a missing
# layer raises a clear ValueError from `get_layer` rather than a TypeError
# from indexing with None, and nothing depends on the input layer being
# `layers[0]`.
model = Model(
    inputs=model.get_layer('input_image').input,
    outputs=model.get_layer('embs').output,
)

In [None]:
# Persist the embedding-only model under the path 'model'.
# NOTE(review): this saves the weights as they are after the last epoch. The
# best-val_loss weights written by the ModelCheckpoint callback are not
# reloaded in the visible code. Confirm that is intended.
model.save('model')