### This notebook consists of two parts: the first demonstrates how to create stratified TFRecords, and the second explains how to set up a training pipeline with those TFRecords.

In Development (For 2021 Data)

# 1. Creating Stratified TFRecords For Google Landmark Recognition 2021
This notebook discusses how you can create stratified TFRecords to train TensorFlow models for Google Landmark Recognition 2021. Some parts of the code are taken from this excellent notebook by Chris Deotte: https://www.kaggle.com/cdeotte/how-to-create-tfrecords/output. The folds are created by stratifying on landmark_id. Groups are then created on the basis of the density of the landmark id to which each image belongs (i.e. how many images that landmark has). This grouping makes it easy to train models only on landmark ids whose sample size is greater than a certain threshold. You are encouraged to try your own groups based on your training strategy. 

For each (fold,group), there is a single TFRecord. 
For example **train-0-1.tfrec** corresponds to Fold 0 & Group 1

Further the notebook discusses how you can read these TFRecords & prepare dataset which can directly be passed to keras.fit method.

**Suggestion:** If you train only single model, choose fold 0 or 1 as validation set to ensure all landmark ids have at least one sample in validation set

Last but not the least, I have prepared all the 60 tfrecords and the link for Kaggle Dataset is mentioned at the end.

## Load Libraries

In [None]:
import os, json, random, cv2
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf, re, math
from tqdm import tqdm

## Login to Kaggle API

The Kaggle API allows you to perform basic dataset operations (creation, versioning, etc.). We are going to create the TFRecords in the /tmp/ directory, which can store up to 100 GB of data, and then create a single Kaggle dataset using the Kaggle API. For more details regarding the Kaggle API, please refer to https://www.kaggle.com/docs/api or https://github.com/Kaggle/kaggle-api

In [None]:
%%time

### Create the Kaggle dataset if it does not exist yet

# Dataset slug; it is also used as the name of the staging folder under /tmp.
DATASET_NAME = 'landmark-recognition-2021-tfrecords-subset'

# Start from a clean staging folder (rm reports an error when the folder is absent).
!rm -r /tmp/{DATASET_NAME}

os.makedirs(f'/tmp/{DATASET_NAME}', exist_ok=True)

# Credentials are read from an attached input file rather than hard-coded in the notebook.
with open('../input/kaggle-api-creds/kaggle.json') as f:
    kaggle_creds = json.load(f)
    
# Exported as environment variables — presumably what the `kaggle` CLI below reads; confirm against the CLI docs.
os.environ['KAGGLE_USERNAME'] = kaggle_creds['username']
os.environ['KAGGLE_KEY'] = kaggle_creds['key']

# Writes a template dataset-metadata.json into the staging folder (it is loaded right below).
!kaggle datasets init -p /tmp/{DATASET_NAME}

# Fill in the template's id and title.
with open(f'/tmp/{DATASET_NAME}/dataset-metadata.json') as f:
    dataset_meta = json.load(f)
# NOTE(review): the owner 'ks2019' is hard-coded and is independent of kaggle_creds['username'];
# anyone else running this cell has to change it.
dataset_meta['id'] = f'ks2019/{DATASET_NAME}'
dataset_meta['title'] = DATASET_NAME
with open(f'/tmp/{DATASET_NAME}/dataset-metadata.json', "w") as outfile:
    json.dump(dataset_meta, outfile)
print(dataset_meta)

# Copy the metadata to meta.json so the staging folder contains at least one data file.
!cp /tmp/{DATASET_NAME}/dataset-metadata.json /tmp/{DATASET_NAME}/meta.json
!ls /tmp/{DATASET_NAME}

# NOTE(review): `-u` presumably creates the dataset as public — confirm against the Kaggle CLI docs.
!kaggle datasets create -u -p /tmp/{DATASET_NAME} 

## Configurations

In [None]:
# --- Dataset layout -------------------------------------------------------
RESIZE = True            # not referenced by the code visible in this notebook
N_LABELS = 81313         # number of landmark ids in the full training set
N_LABELS_SUBSET = 100    # number of landmark ids kept when SUBSET is True
IMAGE_SIZE = 224         # side length (pixels) images are resized to when read back
N_GROUPS = 12            # landmark-density groups per fold
N_FOLDS = 5
N_TFRs = N_GROUPS*N_FOLDS  # one TFRecord file per (fold, group) pair
SUBSET = True  # Keep SUBSET=True while debugging (Faster Execution)

# --- Reading / sanity-check settings ----------------------------------------
BATCH_SIZE = 32
FOLDS = [0]
GROUPS = [11]

# Fold/group indices are zero-based, so every element must be strictly below the count.
assert max(FOLDS)<N_FOLDS, "ELEMENTS OF FOLDS must be less than N_FOLDS"
assert max(GROUPS)<N_GROUPS, "ELEMENTS OF GROUPS must be less than N_GROUPS"

## Preparing Folds and Groups
### You may change heuristics for Groups as per your requirements

In [None]:
train_df = pd.read_csv('../input/landmark-recognition-2021/train.csv')
if SUBSET:
    # Debug mode: keep only N_LABELS_SUBSET randomly chosen landmarks.
    landmarks = random.sample(list(train_df.landmark_id.unique()), N_LABELS_SUBSET)
    train_df = train_df[train_df.landmark_id.isin(landmarks)].reset_index(drop=True)
# Derive the label count from the data actually loaded, so the label map below
# can never get out of sync with a hard-coded constant.
N_LABELS = train_df.landmark_id.nunique()
train_df['original_landmark_id'] = train_df.landmark_id
print(train_df.shape)

# 'order' = position of each image within its landmark (0, 1, 2, ...).
train_df['order'] = np.arange(train_df.shape[0])
train_df['order'] = train_df.groupby('landmark_id').order.rank()-1

# Number of images of the landmark each row belongs to.
landmark_counts = train_df.landmark_id.value_counts()
train_df['landmark_counts'] = train_df.landmark_id.map(landmark_counts)

# Round-robin over the within-landmark position spreads every landmark across the folds.
train_df['fold'] = (train_df['order']%N_FOLDS).astype(int)

# Groups: quantile thresholds on landmark_counts. Thresholds are visited in increasing
# order, so each row ends up with the index of the highest threshold it reaches.
all_groups = [(1/N_GROUPS)*x for x in range(N_GROUPS)]
for i, partition_val in enumerate(train_df.landmark_counts.quantile(all_groups).values):
    train_df.loc[train_df.landmark_counts>=partition_val, 'group'] = i

# Re-label landmarks as 0..N_LABELS-1; landmark_map is sorted by ascending count,
# so the most frequent landmark receives label 0.
landmark_map = train_df.sort_values(by='landmark_counts').landmark_id.drop_duplicates().reset_index(drop=True)
landmark_dict = {landmark: N_LABELS-position-1 for position, landmark in enumerate(landmark_map)}
train_df['landmark_id'] = train_df.original_landmark_id.map(landmark_dict)

# Shuffle rows, then persist the metadata next to the TFRecords.
train_df = train_df.sample(frac=1).reset_index(drop=True)
train_df.to_csv(f'/tmp/{DATASET_NAME}/train_meta_data.csv',index=False)
train_df.sample(5)

In [None]:
#Checking Null values
train_df.isna().sum().sum()

## Group Partitions

In [None]:
train_df.groupby('group').landmark_counts.agg(['min','max'])

## Some Statistics

In [None]:
#Landmark Counts
train_df.landmark_id.value_counts()

In [None]:
#No of images GroupBy landmark counts
train_df.landmark_counts.value_counts()

In [None]:
#No of Images in Each Folds
train_df.fold.value_counts()

In [None]:
#No of Images in Each Group
train_df.group.value_counts()

In [None]:
# No of Landmark in each Fold
train_df.groupby('fold').landmark_id.nunique()

## No of images in each (fold,group)
### Each row corresponds to single tf-records 

In [None]:
pd.pivot_table(train_df.groupby(['fold','group']).id.count().reset_index(),index='fold',columns='group')

## Creating TF-Records

In [None]:
def _bytes_feature(value):
    """Wrap a string / bytes value (or an eager tensor holding one) in a bytes_list Feature."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList cannot unpack an EagerTensor, so pull the raw value out first.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
    """Wrap a single float / double in a float_list Feature."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list Feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def serialize_example(image,image_name,label):
    """Serialize one sample into a ``tf.train.Example`` byte string.

    The record carries three features: 'image' and 'image_name' (bytes) and
    'target' (int64).
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(image),
        'image_name': _bytes_feature(image_name),
        'target': _int64_feature(label),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
def create_tf_records(fold  = 0, group = 0,
                      image_dir = "../input/landmark-recognition-2020/train"):
    """Write all images of one (fold, group) cell of ``train_df`` into a single TFRecord.

    The output file is ``/tmp/{DATASET_NAME}/landmark-2021-train-{fold}-{group}-{n}.tfrec``,
    where ``n`` is the number of records written; ``count_data_items`` later recovers
    that count from the file name.

    Args:
        fold: fold index to export (matched against ``train_df.fold``).
        group: group index to export (matched against ``train_df.group``).
        image_dir: root folder of the training JPEGs, laid out as
            ``{image_dir}/{id[0]}/{id[1]}/{id[2]}/{id}.jpg``. The default keeps the
            original 2020 location. NOTE(review): ``train.csv`` is read from the 2021
            competition folder — pass that dataset's ``train`` folder here if the 2020
            dataset is not attached.
    """
    df = train_df[(train_df.fold==fold) & (train_df.group==group)]
    tfr_filename = f'/tmp/{DATASET_NAME}/landmark-2021-train-{fold}-{group}-{df.shape[0]}.tfrec'
    with tf.io.TFRecordWriter(tfr_filename) as writer:
        # Only two columns are needed, so iterate them directly instead of
        # materialising a Series per row with iterrows().
        for image_id, target in zip(df.id, df.landmark_id):
            image_path = "{}/{}/{}/{}/{}.jpg".format(image_dir,image_id[0],image_id[1],image_id[2],image_id)
            # The JPEG bytes are stored as-is (no decoding or resizing at write time).
            image_encoded = tf.io.read_file(image_path)
            image_name = str.encode(image_id)
            example = serialize_example(image_encoded,image_name,target)
            writer.write(example)

In [None]:
import joblib

# One TFRecord per (fold, group): the groups of a fold are written in parallel
# by 8 workers, and folds are processed one after another.
for fold in range(N_FOLDS):
    group_jobs = (
        joblib.delayed(create_tf_records)(fold,group) for group in tqdm(range(N_GROUPS))
    )
    _ = joblib.Parallel(n_jobs=8)(group_jobs)

## Upload Dataset

In [None]:
from datetime import datetime
# Timestamp string, passed as the version message (`-m`) of the dataset upload below.
version_name = datetime.now().strftime("%Y%m%d-%H%M%S")
print(version_name)

In [None]:
!ls /tmp/{DATASET_NAME}

In [None]:
!kaggle datasets version -m {version_name} -p /tmp/{DATASET_NAME} -r zip -q

## Reading TFRecords & preparing Tensorflow Dataset for model training

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 RGB image scaled to [0, 1] and resized to IMAGE_SIZE_."""
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0  # floats in the [0, 1] range
    return tf.image.resize(scaled,IMAGE_SIZE_)

def read_labeled_tfrecord(example):
    """Parse one serialized record and return a single ``(image, label)`` pair.

    'image_name' is part of the record schema but is not returned.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "image_name": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
        'target': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['target']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a ``tf.data`` dataset of ``(image, label)`` pairs from TFRecord files.

    Args:
        filenames: TFRecord paths to read; several files are read in parallel.
        labeled: accepted for call compatibility but not used — records are always
            parsed with ``read_labeled_tfrecord``.
        ordered: when False (default) the pipeline may yield elements out of order,
            which is faster; order does not matter when the data is shuffled anyway.

    Returns:
        An unbatched dataset of decoded ``(image, label)`` pairs.
    """
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    # Decode in parallel: JPEG decoding and resizing dominate the cost of this pipeline,
    # and a sequential map would throttle the parallel file reads above.
    dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    return dataset

def get_training_dataset():
    """Return the endlessly repeating, shuffled and batched training dataset."""
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .repeat()            # the training dataset must repeat for several epochs
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)      # prefetch the next batch while training (autotuned buffer size)
    )

def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The number of data items is written in the name of each .tfrec file as
    ``-<count>.`` (e.g. ``flowers00-230.tfrec`` holds 230 items).

    Args:
        filenames: iterable of TFRecord paths.

    Returns:
        Total number of items as a Python int (0 for an empty list).

    Raises:
        ValueError: if a file name does not contain a ``-<count>.`` part.
    """
    count_pattern = re.compile(r"-([0-9]+)\.")  # compiled once, not once per file
    total = 0
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"Cannot read the item count from TFRecord filename: {filename!r}")
        total += int(match.group(1))
    return total

In [None]:
IMAGE_SIZE_ = [IMAGE_SIZE,IMAGE_SIZE]  # (height, width) target used by decode_image
AUTO = tf.data.experimental.AUTOTUNE
TRAINING_FILENAMES = tf.io.gfile.glob(f'/tmp/{DATASET_NAME}/landmark-2021-train*.tfrec')
print(len(TRAINING_FILENAMES))
# load -> repeat -> shuffle(2048) -> batch(BATCH_SIZE) -> prefetch, as implemented by the helper.
dataset = get_training_dataset() #This dataset can directly be passed to keras.fit method
count_data_items(TRAINING_FILENAMES)

In [None]:
# Pull exactly one batch to check the image and label shapes.
for x,y in dataset.take(1):
    print(x.shape,y.shape)

## Verifying Dataset is written in correct format

In [None]:
# numpy print defaults: summarize arrays with more than 15 elements, wrap lines at 80 characters
np.set_printoptions(threshold=15, linewidth=80)
# NOTE(review): CLASSES is not referenced by any helper visible in this notebook.
CLASSES = [0,1]

def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` batch of tensors into a pair of numpy values.

    Both elements are converted with ``.numpy()``; labels are always returned
    (there is no unlabeled / test-data handling here).
    """
    image_batch, label_batch = data
    return image_batch.numpy(), label_batch.numpy()

def display_single_sample(image, label, subplot, red=False, titlesize=16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image accepted by ``plt.imshow``.
        label: shown as the title via ``str(label)``; an empty string draws no title.
        subplot: ``(rows, cols, index)`` triple passed to ``plt.subplot``.
        red: draw the title in red and slightly smaller instead of black.
        titlesize: base font size of the title.

    Returns:
        ``(rows, cols, index + 1)``.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    caption = str(label)
    if caption:
        if red:
            font_size, colour = int(titlesize/1.2), 'red'
        else:
            font_size, colour = int(titlesize), 'black'
        plt.title(caption, fontsize=font_size, color=colour, fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    n_rows, n_cols, position = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, position+1)
    
def display_batch_of_images(databatch):
    """
    Display a single batch of images in a square-ish grid, titled with their labels.

    Args:
        databatch: ``(images, labels)`` batch as produced by the dataset.

    Images that do not fit into the ``rows x cols`` grid are not shown.
    An empty batch draws nothing.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    labels_missing = labels is None
    if labels_missing:
        labels = [None] * len(images)
    if len(images) == 0:
        return  # nothing to draw; also avoids dividing by rows == 0 below

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # display
    dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
    n_shown = rows*cols
    for image, label in zip(images[:n_shown], labels[:n_shown]):
        # There are no predictions to compare against, so titles are never drawn in red.
        subplot = display_single_sample(image, label, subplot, False, titlesize=dynamic_titlesize)

    # layout: no gaps when there are no labels to show, a small gap otherwise
    plt.tight_layout()
    if labels_missing:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
# Displaying single batch of TFRecord
train_batch = iter(dataset)
display_batch_of_images(next(train_batch))

# Great! Now you have to run this notebook for the different folds and prepare all the TFRecords. The good part is that I have already done this, and you can find all the TFRecords and the metadata here:
1. Fold 0: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold0
2. Fold 1: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold1
3. Fold 2: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold2
4. Fold 3: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold3
5. Fold 4: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold4
6. Metadata: https://www.kaggle.com/ks2019/landmark-recognition-2021-tfrecords-fold1?select=train_meta_data.csv

Please note for training models, you don't need to add any dataset, you can directly use the gcs paths. The GCS path for a dataset can be found by using the following commands: 

```from kaggle_datasets import KaggleDatasets
KaggleDatasets().get_gcs_path(f'landmark-recognition-2021-tfrecords-fold{fold}')```


# 2. Setting up Training Pipeline
This section discusses how to train an EfficientNet model using the TFRecords prepared above. Most of this section is taken from this excellent notebook by Chris Deotte: https://www.kaggle.com/cdeotte/triple-stratified-kfold-with-tfrecords 

This only demonstrates training pipeline for a subset of training data!! You may want to play with "GROUPS" argument to add more data!! 
### The arcface implementation is taken from:
1. https://www.kaggle.com/akensert/glrec-resnet50-arcface-tf2-2
2. https://www.kaggle.com/ragnar123/shopee-efficientnetb3-arcmarginproduct

In [None]:
!pip install -q efficientnet
!pip install tensorflow_addons
import re
import os
import numpy as np
import pandas as pd
import random
import math
import tensorflow as tf
import efficientnet.tfkeras as efn
from sklearn import metrics
from sklearn.model_selection import KFold, train_test_split
from tensorflow.keras import backend as K
import tensorflow_addons as tfa
from tqdm.notebook import tqdm
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt

In [None]:
# TPU detection. No parameters are necessary if the TPU_NAME environment variable
# is set: this is always the case on Kaggle.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Autotune sentinel used throughout the tf.data pipeline
AUTO = tf.data.experimental.AUTOTUNE

### Note that: These GCS paths are updated every week. To get updated GCS path, follow the instructions mentioned above.

# Data access: one GCS bucket per fold
GCS_PATHS = {
    0: 'gs://kds-c430f4dc931cb05decf924854e81afc21a367188a0feb7718872ebf1',
    1: 'gs://kds-47fd736d084594eec1e33a8bba3c79c0539e21279a3207233c12dcb6',
    2: 'gs://kds-de7798fc5a5e4670e6421cd846bf3dd1dcc9c6340acef9fc5a85f247',
    3: 'gs://kds-89b8e7f8f9fe4836d0092e6b753505a82af130e9c270d91c1ee092c7',
    4: 'gs://kds-697a86388d216aba3e3119019447f5053586208639b827f5bb0fb53a'
}

# --- Training schedule ---
EPOCHS = 4
BATCH_SIZE = 16 * strategy.num_replicas_in_sync   # global batch: 16 per replica
LR = 0.001                                        # learning rate handed to the optimizer
VERBOSE = 2                                       # Keras verbosity
SEED = 42                                         # seed for seed_everything

# --- Data ---
IMAGE_SIZE = [256, 256]
N_CLASSES = 81313                                 # number of classes
FOLDS = 5                                         # number of folds
FOLD_TO_RUN = [0,1,2,3,4]

# --- Model ---
EFF_NET = 0                                       # index into EFNS (0 -> EfficientNetB0)
FREEZE_BATCH_NORM = False                         # freeze BatchNorm layers when True
SNAPSHOT_THRESHOLD = 0                            # first epoch at which weights are saved

# Training files: every .tfrec file found in every fold bucket
TRAINING_FILENAMES = []
for fold in GCS_PATHS:
    TRAINING_FILENAMES.extend(tf.io.gfile.glob(GCS_PATHS[fold] + '/*.tfrec'))
    

In [None]:
len(TRAINING_FILENAMES)

In [None]:
# Function to get our f1 score
def f1_score(y_true, y_pred):
    """Row-wise F1 between two series of whitespace-separated token strings.

    Each string is split into a set of tokens; per row the score is
    ``2 * |true & pred| / (|true| + |pred|)``.

    Returns:
        numpy array with one F1 value per row.
    """
    true_sets = [set(tokens.split()) for tokens in y_true]
    pred_sets = [set(tokens.split()) for tokens in y_pred]
    overlap = np.array([len(t & p) for t, p in zip(true_sets, pred_sets)])
    pred_sizes = np.array([len(p) for p in pred_sets])
    true_sizes = np.array([len(t) for t in true_sets])
    return 2 * overlap / (pred_sizes + true_sizes)

# Function to seed everything
def seed_everything(seed):
    """Seed Python's ``random``, numpy and TensorFlow, and set PYTHONHASHSEED."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    
def arcface_format(posting_id, image, label_group, matches):
    """Pack image and label into the model's input dict ('inp1', 'inp2').

    The label is passed both as a model input and, unchanged, as the target.
    """
    model_inputs = dict(inp1=image, inp2=label_group)
    return posting_id, model_inputs, label_group, matches

# Data augmentation function
def data_augment(posting_id, image, label_group, matches):
    """Apply random flip, hue, saturation, contrast and brightness jitter to the image.

    Only the image is changed; the other three elements pass through untouched.
    """
    augmented = tf.image.random_flip_left_right(image)
    augmented = tf.image.random_hue(augmented, 0.01)
    augmented = tf.image.random_saturation(augmented, 0.70, 1.30)
    augmented = tf.image.random_contrast(augmented, 0.80, 1.20)
    augmented = tf.image.random_brightness(augmented, 0.10)
    return posting_id, augmented, label_group, matches

# Function to decode our images
def decode_image(image_data):
    """Decode JPEG bytes, resize to IMAGE_SIZE, then scale to float32 in [0, 1]."""
    pixels = tf.image.decode_jpeg(image_data, channels = 3)
    resized = tf.image.resize(pixels, IMAGE_SIZE)
    return tf.cast(resized, tf.float32) / 255.0

# This function parse our images and also get the target variable
def read_labeled_tfrecord(example):
    """Parse one serialized record into ``(posting_id, image, label_group, matches)``.

    ``posting_id`` is the record's 'image_name', ``label_group`` its 'target' cast to
    int32, and ``matches`` is a constant placeholder (1) because the records carry
    no 'matches' feature.
    """
    feature_spec = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    posting_id = parsed['image_name']
    image = decode_image(parsed['image'])
    label_group = tf.cast(parsed['target'], tf.int32)
    matches = 1
    return posting_id, image, label_group, matches

# This function loads TF Records and parse them into tensors
def load_dataset(filenames, ordered = False, cache=False):
    """Read TFRecord files and parse every record with ``read_labeled_tfrecord``.

    Args:
        filenames: TFRecord paths, read in parallel.
        ordered: when False the pipeline is allowed to yield elements out of order.
        cache: when True the raw (still serialized) records are cached before parsing.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    records = records.cache() if cache else records
    return records.with_options(options).map(read_labeled_tfrecord, num_parallel_calls = AUTO)

# This function is to get our training tensors
def get_training_dataset(filenames, ordered = False):
    """Build the training pipeline yielding ``({'inp1': image, 'inp2': label}, label)`` batches.

    Records are augmented, packed for the ArcFace model, reduced to (inputs, target)
    pairs, then repeated, shuffled (buffer 2048), batched and prefetched.
    """
    def _inputs_and_target(posting_id, image, label_group, matches):
        # After arcface_format, ``image`` is the input dict; ids and matches are dropped.
        return image, label_group

    return (
        load_dataset(filenames, ordered = ordered)
        .map(data_augment, num_parallel_calls = AUTO)
        .map(arcface_format, num_parallel_calls = AUTO)
        .map(_inputs_and_target)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

# This function is to get our validation tensors
def get_validation_dataset(filenames, ordered = True):
    """Build the validation pipeline: no augmentation, no shuffle, no repeat.

    Elements keep all four fields — ``(posting_id, inputs_dict, label_group, matches)``.
    """
    return (
        load_dataset(filenames, ordered = ordered)
        .map(arcface_format, num_parallel_calls = AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

# Function to count how many photos we have in
def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The number of data items is written in the name of each .tfrec file as
    ``-<count>.`` (e.g. ``flowers00-230.tfrec`` holds 230 items).

    Args:
        filenames: iterable of TFRecord paths.

    Returns:
        Total number of items as a Python int (0 for an empty list).

    Raises:
        ValueError: if a file name does not contain a ``-<count>.`` part.
    """
    count_pattern = re.compile(r"-([0-9]+)\.")  # compiled once, not once per file
    total = 0
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"Cannot read the item count from TFRecord filename: {filename!r}")
        total += int(match.group(1))
    return total

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
print(f'Dataset: {NUM_TRAINING_IMAGES} training images')

In [None]:
def get_lr_callback(plot=False):
    """Create the ramp-up / sustain / exponential-decay learning-rate scheduler.

    The rate rises linearly from ``lr_start`` to ``lr_max`` over ``lr_ramp_ep`` epochs,
    stays at ``lr_max`` for ``lr_sus_ep`` epochs, then decays towards ``lr_min`` by a
    factor ``lr_decay`` per epoch.

    Args:
        plot: when True, scatter-plot the schedule for epochs ``0..EPOCHS-1``.

    Returns:
        A ``tf.keras.callbacks.LearningRateScheduler``.
    """
    lr_start   = 0.000001
    lr_max     = 0.000005 * BATCH_SIZE  # peak rate scales with the global batch size
    lr_min     = 0.000001
    lr_ramp_ep = 5
    lr_sus_ep  = 0
    lr_decay   = 0.8

    def lrfn(epoch):
        if epoch >= lr_ramp_ep + lr_sus_ep:
            # decay phase
            return (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
        if epoch >= lr_ramp_ep:
            # sustain phase
            return lr_max
        # linear warm-up phase
        return (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start

    if plot:
        epochs = list(range(EPOCHS))
        plt.scatter(epochs, [lrfn(ep) for ep in epochs])
        plt.show()

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)

get_lr_callback(plot=True)

In [None]:
# Arcmarginproduct class keras layer
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    The layer is called with ``[embeddings, labels]`` and returns, for every sample,
    the cosine similarity to each class weight vector scaled by ``s``; the entry of
    the sample's own class has the angular margin ``m`` applied.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py

    Args:
        n_classes: number of classes (columns of the weight matrix).
        s: scale applied to the output.
        m: angular margin, in radians.
        easy_margin: apply the margin only where the cosine is positive.
        ls_eps: label-smoothing amount applied to the one-hot targets (0 disables it).
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Constants of cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # Threshold / fallback offset used when easy_margin is off (see call()).
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def get_config(self):
        '''Return the layer configuration, including the constructor arguments.'''
        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        '''Create the class weight matrix W of shape (embedding_dim, n_classes).'''
        # input_shape is a pair: [embedding shape, label shape]
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        '''Return the scaled margin logits for ``inputs = [embeddings, labels]``.'''
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        # Cosine between every (normalized) embedding and every (normalized) class vector.
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Rounding can push cosine**2 marginally above 1, which would make the sqrt
        # return NaN; the NaN would then leak into the logits through `one_hot * phi`.
        # Clamp to the mathematically valid range first.
        sine = tf.math.sqrt(
            tf.clip_by_value(1.0 - tf.math.pow(cosine, 2), 0.0, 1.0)
        )
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            # Outside the threshold, fall back to a linear penalty instead of cos(theta + m).
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Margin-adjusted value for the labelled class, plain cosine for all others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

EFNS = [efn.EfficientNetB0, efn.EfficientNetB1, efn.EfficientNetB2, efn.EfficientNetB3, 
        efn.EfficientNetB4, efn.EfficientNetB5, efn.EfficientNetB6, efn.EfficientNetB7]

def freeze_BN(model):
    """Make every layer trainable except BatchNormalization layers, which are frozen.

    Nested models are traversed recursively. ``get_model`` embeds the EfficientNet
    backbone as a single nested model, so looking only at ``model.layers`` would
    never reach the BatchNormalization layers inside it.

    Args:
        model: a Keras model (anything exposing ``.layers``).
    """
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
            continue
        layer.trainable = True
        if isinstance(layer, tf.keras.Model):
            # Setting trainable on a model applies to all of its sub-layers, so the
            # BatchNorm layers inside must be re-frozen afterwards.
            freeze_BN(layer)

# Function to create our EfficientNet (variant selected by EFF_NET) + ArcFace model
def get_model():
    """Build and compile the EfficientNet + ArcMarginProduct classifier.

    The model takes two inputs, 'inp1' (image) and 'inp2' (label), and outputs
    softmax probabilities over N_CLASSES. Everything is created inside
    ``strategy.scope()``.
    """
    with strategy.scope():

        arc_head = ArcMarginProduct(
            n_classes = N_CLASSES,
            s = 30,
            m = 0.5,
            name='head/arc_margin',
            dtype='float32'
            )

        image_input = tf.keras.layers.Input(shape = (*IMAGE_SIZE, 3), name = 'inp1')
        label_input = tf.keras.layers.Input(shape = (), name = 'inp2')

        backbone = EFNS[EFF_NET](weights = 'imagenet', include_top = False)
        embedding = tf.keras.layers.GlobalAveragePooling2D()(backbone(image_input))
        logits = arc_head([embedding, label_input])

        probabilities = tf.keras.layers.Softmax(dtype='float32')(logits)

        model = tf.keras.models.Model(inputs = [image_input, label_input], outputs = [probabilities])

        optimizer = tf.keras.optimizers.Adam(learning_rate = LR)
        if FREEZE_BATCH_NORM:
            freeze_BN(model)

        model.compile(
            optimizer = optimizer,
            loss = [tf.keras.losses.SparseCategoricalCrossentropy()],
            metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
            )

        return model
get_model().summary()

In [None]:
# Preview grid: 2 columns and at most 10 rows, limited by the batch size.
col = 2
row = min(10, BATCH_SIZE//col)
N_TRAIN = count_data_items(TRAINING_FILENAMES)
print(N_TRAIN)
ds = get_training_dataset(TRAINING_FILENAMES, ordered = False)

# Plot the first training batch only.
for (sample,label) in ds.take(1):
    img = sample['inp1']
    plt.figure(figsize=(25,int(25*row/col)))
    for j in range(row*col):
        plt.subplot(row,col,j+1)
        plt.title(label[j].numpy())
        plt.axis('off')
        plt.imshow(img[j,])
    plt.show()
print(img.shape)

In [None]:
sample.keys()

In [None]:
def is_interactive():
    """Return True when the kernel's connection file path contains 'runtime'."""
    connection_file = get_ipython().config.IPKernelApp.connection_file
    return 'runtime' in connection_file

IS_INTERACTIVE = is_interactive()
print(IS_INTERACTIVE)

In [None]:
class Snapshot(tf.keras.callbacks.Callback):
    """Save the model weights at the end of every epoch from ``snapshot_min_epoch`` on.

    Files are named ``EF{EFF_NET}_fold{fold}_epoch{epoch}.h5`` and written to the
    working directory.
    """

    def __init__(self,snapshot_min_epoch,fold):
        super().__init__()
        self.snapshot_min_epoch = snapshot_min_epoch
        self.fold = fold

    def on_epoch_end(self, epoch, logs=None):
        # Skip epochs before the snapshot threshold.
        if epoch < self.snapshot_min_epoch:
            return
        weights_path = f"EF{EFF_NET}_fold{self.fold}_epoch{epoch}.h5"
        self.model.save_weights(weights_path)

In [None]:
VERBOSE = 1
seed_everything(SEED)

# Train on every available TFRecord; the dataset repeats, so an epoch is defined
# by STEPS_PER_EPOCH.
train = np.array(TRAINING_FILENAMES)
STEPS_PER_EPOCH = count_data_items(train) // BATCH_SIZE
train_dataset = get_training_dataset(train, ordered = False)

model = get_model()
snap = Snapshot(snapshot_min_epoch=SNAPSHOT_THRESHOLD,fold=0)
training_callbacks = [snap,get_lr_callback()]
history = model.fit(
    train_dataset,
    steps_per_epoch = STEPS_PER_EPOCH,
    epochs = EPOCHS,
    callbacks = training_callbacks,
    verbose = VERBOSE,
)