# Shopee Product Matching - Image Similarity Model

This notebook trains a CNN to measure image similarity with TensorFlow 2, using the following features:
* TPU-accelerated training. For how the TFRecords required for TPU training were built, please refer to [this notebook](https://www.kaggle.com/sandersli/shopee-product-matching-create-tfrecords)
* [EfficientNet Backbone](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/)
* [Generalized Mean (GeM) Pooling](https://arxiv.org/pdf/1711.02512.pdf)
* [ArcFace Loss](https://arxiv.org/pdf/1801.07698.pdf)
* Comprehensive image augmentation on dataset construction

Thanks to [ragnar's fantastic notebook for providing a jumping-off point](https://www.kaggle.com/ragnar123/shopee-efficientnetb3-arcmarginproduct)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import tensorflow as tf
import tensorflow_addons as tfa

from PIL import Image
from sklearn import metrics
from sklearn.model_selection import KFold, train_test_split
from tqdm.notebook import tqdm
from kaggle_datasets import KaggleDatasets

Enable TPU acceleration

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

# Choose a distribution strategy: TPUStrategy when a TPU is attached to the runtime,
# otherwise MirroredStrategy over the local GPU(s)/CPU.
try:
    # A ValueError here is the "no TPU available" signal used for the fallback below.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver().connect()
except ValueError:
    print("Not connected to a TPU runtime. Using CPU/GPU strategy")
    strategy = tf.distribute.MirroredStrategy()
else:
    # Kept outside the try so that a genuine TPU set-up error surfaces instead of
    # being mistaken for "no TPU" and silently falling back to CPU/GPU training.
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    print(f"Running on TPU {tpu.master()} with {strategy.num_replicas_in_sync} replicas")

In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path('shopee-tfrec-224px')
# Sort the glob results so the file order, and therefore the seeded train/validation
# split made from it, does not depend on the order the storage listing returns.
train_filenames = sorted(tf.io.gfile.glob([GCS_PATH + '/train/*.tfrec']))
test_filenames = sorted(tf.io.gfile.glob([GCS_PATH + '/test/*.tfrec']))

In [None]:
# Training metadata: label_group is used below to count the ArcFace classes (N_CLASSES)
# and len(df) to size STEPS_PER_EPOCH.
df = pd.read_csv('../input/shopee-product-matching/train.csv')

In [None]:
# Use correct image size with pretrained model
IMAGE_SIZE = (300, 300)
# Fraction of the data used for training; STEPS_PER_EPOCH below is derived from it
TRAIN_SIZE = 0.8
# Initial learning rate
LR = 0.001
# ArcFace needs a fixed number of classes for its weight matrix; here one per distinct
# label_group in train.csv. May get better results on the test set with a higher N_CLASSES.
N_CLASSES = df['label_group'].nunique()

SEED = 42

EPOCHS = 30
# Global batch size: 32 examples per replica
BATCH_SIZE = 32 * strategy.num_replicas_in_sync
# Approximate number of training batches per epoch. int() keeps this an integer:
# TRAIN_SIZE is a float, so len(df) * TRAIN_SIZE // BATCH_SIZE alone yields a float.
STEPS_PER_EPOCH = int(len(df) * TRAIN_SIZE) // BATCH_SIZE

In [None]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Covers Python's ``random`` module, NumPy, TensorFlow and the
    ``PYTHONHASHSEED`` environment variable.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # affects processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

## Model Construction

In [None]:
# ArcFace hyper-parameters:
#   s - scale applied to the cosine logits (the norm the features are rescaled to)
#   m - additive angular margin in radians. The original paper reports m=0.5 as best;
#       in my experiments this was the parameter with the strongest effect on the final F1 score.
params = dict(
    m=0.3,
    s=30,
)

In [None]:
def data_augment(posting_id, image, label_group, matches):
    """Randomly augment one decoded training image.

    The data was serialized to TFRecords and is read straight into ``tf.data``, so
    Keras' ImageDataGenerator cannot be used; augmentation is done with tf ops here.
    Only ``image`` is modified; the other elements pass through unchanged so the
    function can be used directly in ``Dataset.map``.

    Args:
        image: float image as produced by ``decode_image`` (values in [0, 1]).

    Returns:
        The same 4-tuple with ``image`` augmented, resized to ``IMAGE_SIZE`` and
        clipped back to the [0, 1] range.
    """
    # --- Geometric augmentation ---
    # Random rotation of up to +/- 0.1*pi radians (18 degrees).
    rotate = tf.random.uniform(shape=(), minval=-0.1*np.pi, maxval=0.1*np.pi)
    image = tfa.image.rotate(image, rotate, interpolation='bilinear', fill_mode='constant')
    # Random shear along x and y, expressed as an 8-parameter transform.
    shear_x = tf.random.uniform(shape=(), minval=-0.2, maxval=0.2)
    shear_y = tf.random.uniform(shape=(), minval=-0.2, maxval=0.2)
    image = tfa.image.transform(image, [1.0, shear_x, 0.0, shear_y, 1.0, 0.0, 0.0, 0.0], interpolation='bilinear', fill_mode='constant')
    # Random translation of up to 5% of the image side in each direction.
    translate_vec = tf.random.uniform(shape=(2,), minval=-int(0.05*IMAGE_SIZE[0]), maxval=int(0.05*IMAGE_SIZE[0]))
    image = tfa.image.translate(image, translate_vec, interpolation='bilinear', fill_mode='constant')

    # Random zoom: centre-crop (<100%) or pad (>100%) to 80-120% of the side, then
    # resize back to the model resolution.
    crop_size = tf.random.uniform(shape=(), minval=int(0.8*IMAGE_SIZE[0]), maxval=int(1.2*IMAGE_SIZE[0]), dtype=tf.int32)
    image = tf.image.resize_with_crop_or_pad(image, crop_size, crop_size)
    image = tf.image.resize(image, IMAGE_SIZE)

    # --- Photometric augmentation ---
    image = tf.image.random_brightness(image, 0.10)
    image = tf.image.random_hue(image, 0.01)
    image = tf.image.random_saturation(image, 0.80, 1.20)
    image = tf.image.random_contrast(image, 0.80, 1.20)
    # The jitter above can push pixels outside [0, 1]; clip back to the range that
    # decode_image produces so augmented and unaugmented inputs share one range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    image = tf.image.random_flip_left_right(image)
    return posting_id, image, label_group, matches

def decode_image(image_data):
    """Decode serialized JPEG bytes into a float32 RGB image.

    The image is resized to ``IMAGE_SIZE`` (3 channels) and its pixel values are
    scaled from [0, 255] to [0, 1].
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    rgb = tf.image.resize(rgb, IMAGE_SIZE)
    scaled = tf.cast(rgb, tf.float32) / 255.0
    return scaled

def read_tfrec(example):
    """Parse one serialized TFRecord example.

    Returns:
        ``(posting_id, image, label_group, matches)``. ``image`` is decoded with
        ``decode_image`` and ``label_group`` is cast from int64 to int32.
    """
    feature_spec = {
        "posting_id": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "label_group": tf.io.FixedLenFeature([], tf.int64),
        "matches": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return (
        parsed['posting_id'],
        decode_image(parsed['image']),
        tf.cast(parsed['label_group'], tf.int32),
        parsed['matches'],
    )

def load_dataset(filenames):
    """Read TFRecord files in parallel and parse every record with ``read_tfrec``."""
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(read_tfrec, num_parallel_calls=AUTO)

def arcface_format(posting_id, image, label_group, matches):
    """Pack the image and label into the feature dict the two-input model expects.

    The model has inputs named ``image`` and ``label`` (the ArcMarginProduct head
    needs the label), so both go into one dict. ``label_group`` is also kept as a
    separate element for use as the training target.
    """
    features = dict(image=image, label=label_group)
    return posting_id, features, label_group, matches

def get_dataset(filenames, training=False):
    """Build a batched, prefetched tf.data pipeline from TFRecord files.

    Args:
        filenames: TFRecord file path(s) to read.
        training: if True, augment every image, repeat the data indefinitely,
            shuffle it and allow non-deterministic element order.

    Returns:
        A dataset of batches (up to ``BATCH_SIZE`` elements each) of
        ``(posting_id, {'image': ..., 'label': ...}, label_group, matches)``.
    """
    dataset = load_dataset(filenames)
    if training:
        # Let tf.data emit elements out of order for throughput. The flag has to be
        # set explicitly; a bare Options() object changes nothing.
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(data_augment, num_parallel_calls = AUTO)
        # Infinite stream: callers must pass steps_per_epoch to model.fit.
        dataset = dataset.repeat()
        dataset = dataset.shuffle(2048)
    dataset = dataset.map(arcface_format, num_parallel_calls = AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

In [None]:
# Split the TFRecord files into train and validation sets. Pass TRAIN_SIZE explicitly:
# otherwise train_test_split keeps only 75% of the files for training.
train, valid = train_test_split(train_filenames, train_size = TRAIN_SIZE, shuffle = True, random_state = SEED)
train_dataset = get_dataset(train, training=True)

In [None]:
# Visualize the first rows * cols augmented images of one training batch.
fig = plt.figure(figsize=(24, 24))
rows, cols = 3, 3
ax = fig.subplots(rows, cols)
for example in train_dataset.take(1):
    images = example[1]['image']
    for i in range(rows * cols):
        # Row-major grid position (row = i // cols, column = i % cols); dividing by
        # rows only happens to work when rows == cols.
        subplot = divmod(i, cols)
        ax[subplot].imshow(images[i])
plt.show()

## Construct Model

In [None]:
# Cosine learning-rate schedule starting at LR and decaying over the whole training
# run (EPOCHS * STEPS_PER_EPOCH optimizer steps), so it follows changes to EPOCHS.
cosine_lr_fn = tf.keras.experimental.CosineDecay(LR, EPOCHS * STEPS_PER_EPOCH)

In [None]:
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance (ArcFace head).

    Called on ``[X, y]``, where ``X`` are embeddings and ``y`` are class ids. It
    computes the cosine between each L2-normalized embedding and each L2-normalized
    column of the weight matrix ``W``. For the target class the logit becomes
    ``cos(theta + m)``, with a linear fallback ``cos(theta) - sin(m) * m`` once
    ``theta + m`` would pass pi, or ``cos(theta)`` when ``easy_margin`` and the cosine
    is <= 0. All other classes keep ``cos(theta)``. Every logit is then multiplied
    by ``s``.

    Args:
        n_classes: number of classes (columns of ``W``).
        s: scale applied to the logits.
        m: additive angular margin, in radians.
        easy_margin: if True, only apply the margin where cos(theta) > 0.
        ls_eps: label-smoothing factor applied to the one-hot target mask.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Precomputed constants for cos(theta + m) = cos*cos(m) - sin*sin(m).
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # Threshold cos(pi - m): beyond it theta + m > pi and the fallback is used.
        self.th = tf.math.cos(np.pi - m)
        # Linear fallback offset, sin(pi - m) * m == sin(m) * m.
        self.mm = tf.math.sin(np.pi - m) * m

    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        # input_shape is [embedding_shape, label_shape]; only the embedding size
        # determines the weight matrix.
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        # Cosine similarity between every embedding and every class weight column.
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Rounding can make |cosine| slightly exceed 1 (sqrt of a negative -> NaN), and
        # sqrt has an infinite gradient at 0; clamp the argument to a small positive floor.
        sine = tf.math.sqrt(tf.clip_by_value(1.0 - tf.math.pow(cosine, 2), 1e-7, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Margin-adjusted logit for the target class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

In [None]:
class GeMPoolingLayer(tf.keras.layers.Layer):
    '''
    Implements Generalized-Mean (GeM) pooling over axes 1 and 2.

    Computes ``mean(max(x, eps) ** p) ** (1 / p)`` over axes 1 and 2 (presumably the
    spatial axes of a channels-last feature map - confirm against the backbone).
    With the default p = 1 this is plain global average pooling.

    Args:
        p: pooling exponent.
        eps: lower clamp applied to activations before the power.
        **kwargs: standard Layer arguments (e.g. ``name``, ``dtype``).

    Reference:
        https://arxiv.org/pdf/1711.02512.pdf
    '''
    def __init__(self, p=1., eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.p = p
        self.eps = eps

    def call(self, inputs: tf.Tensor, **kwargs):
        # Element-wise lower clamp; gives the same values as clipping to
        # [eps, max(inputs)] without a reduction over the whole tensor.
        inputs = tf.maximum(inputs, self.eps)
        inputs = tf.pow(inputs, self.p)
        inputs = tf.reduce_mean(inputs, axis=[1, 2], keepdims=False)
        inputs = tf.pow(inputs, 1. / self.p)
        return inputs

    def get_config(self):
        # Include the base Layer config (name, trainable, dtype) so the layer
        # round-trips through from_config with its name preserved.
        config = super().get_config()
        config.update({
            'p': self.p,
            'eps': self.eps
        })
        return config

In [None]:
def get_model(params):
    """Build and compile the EfficientNetB3 + GeM + ArcFace model.

    Args:
        params: dict with the ArcFace scale ``'s'`` and margin ``'m'``.

    Returns:
        A compiled Keras model with inputs ``image`` (``IMAGE_SIZE + (3,)`` float
        pixels in [0, 1]) and ``label`` (class id, cast to int32 inside
        ArcMarginProduct). It outputs softmax probabilities over ``N_CLASSES``.
    """
    backbone = tf.keras.applications.EfficientNetB3(weights = 'imagenet', include_top = False)
    margin = ArcMarginProduct(
        n_classes = N_CLASSES,
        s = params['s'],
        m = params['m'],
        name='arc_margin_product',
        dtype='float32'
        )

    inp = tf.keras.layers.Input(shape = IMAGE_SIZE + (3,), name = 'image')
    label = tf.keras.layers.Input(shape = (), name = 'label')
    # decode_image scales pixels to [0, 1], but Keras' EfficientNet models contain
    # their own rescaling/normalization and expect raw [0, 255] pixels
    # (efficientnet.preprocess_input is a pass-through). Scale back up so the
    # ImageNet weights see inputs in the range they were trained on.
    x = tf.keras.layers.Lambda(lambda t: t * 255.0, name = 'to_pixel_range')(inp)
    x = tf.keras.applications.efficientnet.preprocess_input(x)
    x = backbone(x)
    x = GeMPoolingLayer()(x)
    # 512-d embedding; this is the representation the ArcFace head classifies.
    x = tf.keras.layers.Dense(512, kernel_regularizer=tf.keras.regularizers.l2(), activation=None)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = margin([x, label])

    output = tf.keras.layers.Softmax(dtype='float32')(x)

    model = tf.keras.models.Model(inputs = [inp, label], outputs = [output])

    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate = cosine_lr_fn),
        loss = [tf.keras.losses.SparseCategoricalCrossentropy()],
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
        )

    return model

In [None]:
# Build an instance outside strategy.scope() only to print the architecture;
# the model that is actually trained is created inside train_and_evaluate().
model = get_model(params=params)
model.summary()

## Model Training

In [None]:
def train_and_evaluate():
    """Train the ArcFace model and return ``(history, model)``.

    Seeds all RNGs and splits the TFRecord files into training (``TRAIN_SIZE`` of
    the files) and validation sets. It builds the model under ``strategy.scope()``
    and fits for ``EPOCHS`` epochs of ``STEPS_PER_EPOCH`` steps. The weights with the
    lowest validation loss are saved to an .h5 file named after the image size,
    the seed and the ArcFace m/s parameters.
    """
    seed_everything(SEED)
    # Pass TRAIN_SIZE explicitly: without it train_test_split keeps only 75% of the
    # files for training, which is inconsistent with STEPS_PER_EPOCH.
    train, valid = train_test_split(train_filenames, train_size = TRAIN_SIZE, shuffle = True, random_state = SEED)
    # Keep only (features_dict, target) for Keras; posting_id and matches are dropped.
    train_dataset = get_dataset(train, training=True)
    train_dataset = train_dataset.map(lambda posting_id, image, label_group, matches: (image, label_group))
    val_dataset = get_dataset(valid)
    val_dataset = val_dataset.map(lambda posting_id, image, label_group, matches: (image, label_group))
    tf.keras.backend.clear_session()
    with strategy.scope():
        model = get_model(params)
    # Model checkpoint: keep only the best (lowest val_loss) weights
    checkpoint = tf.keras.callbacks.ModelCheckpoint(f"EfficientNetB3_{IMAGE_SIZE[0]}_{SEED}_m{params['m']}_s{params['s']}.h5",
                                                    monitor = 'val_loss',
                                                    save_best_only = True,
                                                    save_weights_only = True,
                                                    mode = 'min')

    # The training dataset repeats forever, so steps_per_epoch defines an epoch.
    history = model.fit(train_dataset,
                        steps_per_epoch = STEPS_PER_EPOCH,
                        epochs = EPOCHS,
                        callbacks = [checkpoint],
                        validation_data = val_dataset)
    return history, model

hist, model = train_and_evaluate()

In [None]:
import plotly.express as px

# One line per metric Keras recorded (loss, sparse_categorical_accuracy and their
# val_ counterparts), indexed by epoch; also saved as a standalone HTML file.
fig = px.line(hist.history)
display(fig)
fig.write_html(f"arcface_m{params['m']}.html")

The final accuracy/loss values are not a direct measure of how well the model performs. The loss is computed on a classification task (assigning embeddings to known labels), whereas the competition is scored with an F1 score on a clustering task (grouping embeddings of the same product together). A low loss therefore does not guarantee good performance on the unlabeled test data.