# CurricularFace

In [None]:
class CurricularFace(tf.keras.layers.Layer):
    """CurricularFace margin head implemented as a Keras layer.

    The layer is called on ``[X, y]``: ``X`` are embeddings (rank 2, since they
    are matmul'ed with ``W`` of shape ``(dim, n_classes)``) and ``y`` are class
    indices (cast to int32 and one-hot encoded).

    * Training: returns ``s``-scaled logits in which the target class gets the
      additive angular margin ``cos(theta + m)`` and "hard" non-target classes
      (cosine above the target's margin cosine) are re-weighted by
      ``cos * (t + cos)``, where ``t`` is an exponential moving average of the
      mean target cosine.
    * Inference: returns the plain, unscaled cosine similarities between the
      L2-normalised embeddings and the L2-normalised class weights.

    Args:
        n_classes: number of output classes.
        s: scale applied to the training logits.
        m: additive angular margin in radians.
        easy_margin: stored and serialised in the config; ``call`` does not
            use it.
        ls_eps: label-smoothing factor blended into the target mask when > 0.
    """

    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):
        super(CurricularFace, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Plain Python floats (not eager float32 tensors) so the constants
        # adopt the dtype of the tensors they are combined with in call().
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # When theta > pi - m, cos(theta + m) stops decreasing in theta; below
        # this threshold the fallback `cosine - mm` is used instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self._USE_V2_BEHAVIOR = True

    def _assign_new_value(self, variable, value):
        """Assign ``value`` to ``variable`` in eager, tf.function or v1 graph mode."""
        with backend.name_scope('AssignNewValue') as scope:
            if tf.compat.v1.executing_eagerly_outside_functions():
                return variable.assign(value, name=scope)
            else:
                with tf.compat.v1.colocate_with(variable):
                    return tf.compat.v1.assign(variable, value, name=scope)

    def _get_training_value(self, training=None):
        """Resolve the effective training flag.

        ``None`` falls back to ``backend.learning_phase()``. Python ints are
        coerced to bool. A non-trainable layer always runs in inference mode,
        overriding what the model passed.
        """
        if training is None:
            training = backend.learning_phase()
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
            if not self.trainable:
                training = False
        return training

    @staticmethod
    def _safe_sqrt(x, eps=1e-7):
        """Square root with its argument floored at ``eps``.

        Keeps both the value and the gradient finite when ``x`` is 0 (or
        slightly negative from rounding), which happens whenever a cosine hits
        +/-1. ``sqrt`` at 0 has an infinite derivative, and multiplying that
        by a zero clip gradient gives NaN.
        """
        return tf.math.sqrt(tf.maximum(x, eps))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        """Create the class-weight matrix ``W`` and the curriculum scalar ``t``.

        ``input_shape`` is the pair of shapes for ``[X, y]``; only the last
        dimension of ``X`` is used.
        """
        super(CurricularFace, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

        # Non-trainable running statistic. ON_READ + MEAN means each replica
        # keeps its own copy, and cross-replica reads average them.
        self.t = self.add_weight(
            name='t',
            shape=(1,),
            initializer=tf.zeros_initializer(),
            dtype='float32',
            trainable=False,
            regularizer=None,
            aggregation=tf.VariableAggregation.MEAN,
            experimental_autocast=False,
            synchronization=tf.VariableSynchronization.ON_READ)

    def call(self, inputs, training=None):
        """Compute CurricularFace logits (training) or raw cosines (inference)."""
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)

        do_training = self._get_training_value(training)

        # Cosine similarity between every embedding and every class weight.
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )

        if not do_training:
            return cosine

        # Rounding in the matmul can push values slightly outside [-1, 1],
        # which would make 1 - cos^2 negative and the sqrt below NaN.
        cosine = tf.clip_by_value(cosine, -1.0, 1.0)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )

        # cos(theta + m) for every class, with the fallback beyond pi - m.
        sine = self._safe_sqrt(1.0 - tf.math.square(cosine))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = tf.where(cosine > self.th, phi, cosine - self.mm)

        # Per-example target cosine and its margin-penalised counterpart.
        target_logit = tf.reduce_sum(cosine * one_hot, axis=-1)
        sin_theta = self._safe_sqrt(1.0 - tf.math.square(target_logit))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m

        # Update t as an EMA (momentum 0.99) of the mean target cosine. The
        # freshly computed value is used below rather than re-reading the
        # variable, so the result does not depend on read-after-assign
        # ordering. It is held constant for back-propagation.
        t = tf.stop_gradient(
            tf.reduce_mean(target_logit) * 0.01 + (1 - 0.01) * self.t)
        self._assign_new_value(self.t, t)

        # Re-weight hard negatives: classes whose cosine exceeds the target's
        # margin cosine.
        is_hard = cosine > tf.expand_dims(cos_theta_m, axis=-1)
        cosine = tf.where(is_hard, cosine * (t + cosine), cosine)

        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

# Gradient Accumulation

In [None]:
EPOCHS = 8
N_TRAIN = 1_500_000         # training-set size used to derive STEPS_PER_EPOCH
STEPS_PER_TPU_CALL = 1
AVG_N_BATCHES = 16          # micro-batches accumulated per optimizer update
BATCH_SIZE = REPLICAS * 8   # global micro-batch size (8 per replica)
# Number of train_step calls per epoch; each call consumes AVG_N_BATCHES
# micro-batches of BATCH_SIZE examples.
STEPS_PER_EPOCH = N_TRAIN // (STEPS_PER_TPU_CALL * AVG_N_BATCHES * BATCH_SIZE)

# ---------------------------------------------------------------------------
# Objects that may own variables are created under the distribution strategy
# scope so `strategy` controls where those variables live.
# ---------------------------------------------------------------------------
with strategy.scope():
    model = build_model()
    loss_fn = tf.keras.losses.sparse_categorical_crossentropy
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

###########################################
##### Function to calculate gradients #####
###########################################

@tf.function
def get_gradients(I, O):
    """Run one forward/backward pass on a single replica's micro-batch.

    Args:
        I: model inputs for this replica, passed unchanged to ``model``.
        O: integer targets consumed by ``loss_fn`` and ``train_accuracy``.

    Returns:
        Gradients ordered like ``model.trainable_variables``.

    Side effect: updates the ``train_accuracy`` metric with this batch.
    """
    with tf.GradientTape() as tape:
        probabilities = model(I, training=True)
        # NOTE(review): sparse_categorical_crossentropy returns one loss per
        # example and nothing reduces it here, so tape.gradient differentiates
        # the *sum* over the micro-batch. Presumably acceptable with Adam
        # (roughly scale-invariant); revisit if the optimizer changes.
        loss = loss_fn(O, probabilities)
    grads = tape.gradient(loss, model.trainable_variables)
    train_accuracy.update_state(O, probabilities)
    return grads

###########################################
####### Function to apply gradients #######
###########################################

@tf.function
def apply_gradients(grads):
    """Perform one optimizer update with already-accumulated gradients.

    ``grads`` must be ordered exactly like ``model.trainable_variables``.
    """
    variables = model.trainable_variables
    grads_and_vars = zip(grads, variables)
    optimizer.apply_gradients(grads_and_vars)


def _minimum_control_deps(outputs):
    """Pick the smallest set of control dependencies that proves a step ran.

    Eagerly there is nothing to wait on, so an empty list is returned. In a
    graph, ``outputs`` is flattened (composites expanded) and the first element
    that is not a ``tf.Variable`` is returned as a one-element list. If every
    element is a variable, the result is empty.
    """
    if tf.executing_eagerly():
        return []
    for candidate in tf.nest.flatten(outputs, expand_composites=True):
        if isinstance(candidate, tf.Variable):
            continue  # variables cannot serve as control dependencies
        return [candidate]
    return []

############################################
########### Train Step Function ############
############################################

@tf.function
def train_step(train_data_iter):
    """Accumulate AVG_N_BATCHES micro-batches of gradients, then apply them once.

    For every micro-batch pulled from ``train_data_iter``, ``get_gradients``
    runs on each replica. The per-replica gradients are mean-reduced across
    replicas, divided by ``AVG_N_BATCHES`` and summed into ``grads``, so the
    applied update is the average over all micro-batches.

    The first micro-batch initialises ``grads`` and the second is also unrolled
    in Python. The remaining ``AVG_N_BATCHES - 2`` run inside a ``tf.range``
    loop, which autograph turns into a graph loop.

    Raises:
        ValueError: at trace time if ``AVG_N_BATCHES < 2``. The two unrolled
            micro-batches would otherwise still be consumed and summed with
            weight ``1 / AVG_N_BATCHES`` each.
    """
    if AVG_N_BATCHES < 2:
        raise ValueError(
            f'train_step requires AVG_N_BATCHES >= 2, got {AVG_N_BATCHES}')

    # Micro-batch 1: initialise the running average.
    grads = strategy.run(get_gradients, next(train_data_iter))

    with tf.device('/TPU:0'):
        grads = strategy.reduce('mean', grads, axis=None)
        grads = [g/(AVG_N_BATCHES*1.0) for g in grads]

    # Micro-batch 2.
    grads_0 = strategy.run(get_gradients, next(train_data_iter))

    with tf.device('/TPU:0'):
        grads_0 = strategy.reduce('mean', grads_0, axis=None)
        grads = [g0+g1/(AVG_N_BATCHES*1.0) for g0,g1 in zip(grads,grads_0)]

    # Micro-batches 3..AVG_N_BATCHES.
    for _ in tf.range(AVG_N_BATCHES-2):
        # Ops created here wait on the first tensor of the running `grads`,
        # as chosen by _minimum_control_deps.
        with tf.control_dependencies(_minimum_control_deps(grads)):
            grads_0 = strategy.run(get_gradients, next(train_data_iter))
            with tf.device('/TPU:0'):
                grads_0 = strategy.reduce('mean', grads_0, axis=None)
                grads = [g0+g1/(AVG_N_BATCHES*1.0) for g0,g1 in zip(grads,grads_0)]

    # Apply the update only once the final accumulation has been produced.
    with tf.control_dependencies(_minimum_control_deps(grads)):
        strategy.run(apply_gradients, args = (grads,))

# Distributed training dataset. A single iterator is shared across epochs, so
# each epoch continues from where the previous one stopped.
train_dist_ds = strategy.experimental_distribute_dataset(
    get_dataset(files_train, batch_size=BATCH_SIZE, mode='train')
)
train_data_iter = iter(train_dist_ds)

# Custom training loop: STEPS_PER_EPOCH + 1 accumulated updates per epoch,
# then validation, metric reset and a weights checkpoint.
start = time.time()
for epoch in range(EPOCHS):
    pbar = tqdm(range(STEPS_PER_EPOCH + 1))
    for steps in pbar:
        train_step(train_data_iter)
        train_acc = train_accuracy.result().numpy()
        # str() (not an f-string) keeps numpy's short float formatting.
        pbar.set_description('Train Accuracy: ' + str(train_acc))

    print(f'\n\nEVALUATING EPOCH: {epoch}')
    model.evaluate(get_dataset(files_valid, mode='val'))
    print(' \n\n')
    train_accuracy.reset_states()
    checkpoint_suffix = f'/l-512-custom-epoch-{epoch:02d}.h5'
    model.save_weights(BASE_SAVE_DIR + checkpoint_suffix)