In [None]:
import tensorflow as tf
import keras_toolkit as kt
import tensorflow_addons as tfa

from image import augmentation
from model import build_classification_model, build_siamese_model, fit

# Fix TensorFlow's global random seed. The training input pipeline below
# interleaves files with deterministic=False, so element order can still
# differ between runs.
tf.random.set_seed(100)

In [None]:
# Pick a distribution strategy: a TPU strategy when a TPU is attached,
# otherwise whatever keras_toolkit selects for the available GPU/CPU.
try:
    # A ValueError raised anywhere in this block is treated as "no TPU attached".
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print("Running on TPU ", tpu.cluster_spec().as_dict()["worker"])
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.TPUStrategy is the stable replacement for the deprecated
    # tf.distribute.experimental.TPUStrategy.
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    print("Not connected to a TPU runtime. Using CPU/GPU strategy")
    strategy = kt.accelerator.auto_select(verbose=True)

## Classification based on EfficientNetB0 + a new top layer
The model was trained in three phases to adapt it to the dataset:
1. Trained the new top layer
2. Trained the layers from the last block (block 7a)
3. Trained the whole model

None of the Batch Normalization layers originating from EfficientNetB0 were trained in any of these phases.

### Dataset for Phase 1-3

In [None]:
BATCH_SIZE = 128

# ----------------------------------------------------------------------
# Training set
# ----------------------------------------------------------------------
# The images were converted into TFRecords and saved across multiple files.
# Pipeline, in order:
#   1. glob the TFRecord files;
#   2. read and parse them in parallel (element order is not deterministic);
#   3. batch, so that augmentation runs on a whole batch at once;
#   4. unbatch and shuffle individual examples;
#   5. re-batch and prefetch for the training loop.
train_ds = (
    tf.data.TFRecordDataset.list_files(TRAIN_TFR_PATH)
    .interleave(
        lambda tfr_file: tf.data.TFRecordDataset(tfr_file).map(tfr_parser_image_label),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    .batch(BATCH_SIZE, drop_remainder=False)
    .map(augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .shuffle(buffer_size=BATCH_SIZE * 10, reshuffle_each_iteration=False)
    .batch(BATCH_SIZE, drop_remainder=False)
    .prefetch(tf.data.AUTOTUNE)
)

# ----------------------------------------------------------------------
# Validation set
# ----------------------------------------------------------------------
# Read from a single TFRecord path, parsed, batched and prefetched.
# NOTE(review): `augmentation` is applied to the validation batches as well;
# confirm this is intended (e.g. it may carry required preprocessing).
valid_ds = (
    tf.data.TFRecordDataset(VALID_TFR_PATH)
    .map(tfr_parser_image_label, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=False)
    .map(augmentation, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)


### Phase 1

In [None]:
# Phase 1: build a fresh classifier, compile it with a comparatively large
# learning rate (1e-2) and train for a single epoch.
# NOTE(review): per the notes above only the new top layer is trained here;
# presumably build_classification_model freezes the backbone -- confirm.
with strategy.scope():
    model = build_classification_model(top_dropout_rate=0.2)

    # Created inside the strategy scope, in the same order as before.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    adam = tf.keras.optimizers.Adam(learning_rate=1e-2)

    model.compile(loss=loss_fn, metrics=[accuracy_metric], optimizer=adam)

fit(
    model,
    train_ds,
    valid_ds,
    epochs=1,
    initial_epoch=0,
    prefix='trans1',
)

### Phase 2

In [None]:
# Phase 2: rebuild the classifier, load the 'trans1' weights file, unfreeze
# the tail of the network and continue training at learning rate 1e-4.
with strategy.scope():
    model = build_classification_model(top_dropout_rate=0.2)
    model.load_weights('trans1_weights.00001.hdf5')

    # Mark the last 20 layers as trainable, skipping every BatchNormalization
    # layer so its trainable flag is left untouched.
    for tail_layer in model.layers[-20:]:
        if isinstance(tail_layer, tf.keras.layers.BatchNormalization):
            continue
        tail_layer.trainable = True

    # Compile after changing the trainable flags, inside the strategy scope.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    adam = tf.keras.optimizers.Adam(learning_rate=1e-4)

    model.compile(loss=loss_fn, metrics=[accuracy_metric], optimizer=adam)

fit(
    model,
    train_ds,
    valid_ds,
    epochs=2,
    initial_epoch=0,
    prefix='trans2',
)

### Phase 3

In [None]:
# Phase 3: rebuild the classifier, load the 'trans2' weights file, unfreeze
# every non-BatchNormalization layer and fine-tune at learning rate 1e-5.
with strategy.scope():
    model = build_classification_model(top_dropout_rate=0.2)
    model.load_weights('trans2_weights.00002.hdf5')

    # Mark every top-level layer as trainable except BatchNormalization layers.
    # NOTE(review): if the backbone is nested as a single sub-model layer,
    # setting its trainable flag would also unfreeze the BatchNormalization
    # layers inside it -- confirm the model is flat.
    for any_layer in model.layers:
        if isinstance(any_layer, tf.keras.layers.BatchNormalization):
            continue
        any_layer.trainable = True

    # Compile after changing the trainable flags, inside the strategy scope.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    adam = tf.keras.optimizers.Adam(learning_rate=1e-5)

    model.compile(loss=loss_fn, metrics=[accuracy_metric], optimizer=adam)

fit(
    model,
    train_ds,
    valid_ds,
    epochs=5,
    initial_epoch=0,
    prefix='trans3_6',
)

## Siamese (Dense + L2-Norm)


### Dataset for Phase 4

In [None]:
BATCH_SIZE_EMB = 512


def _embedding_dataset(file_pattern):
    """Build a batched, prefetched dataset from the TFRecord files matching ``file_pattern``.

    The matching files are listed without shuffling, read one after another,
    parsed with ``tfr_parser_emb_idx`` and grouped into batches of
    ``BATCH_SIZE_EMB``.
    """
    tfr_files = tf.data.TFRecordDataset.list_files(file_pattern, shuffle=False)
    raw_records = tfr_files.flat_map(tf.data.TFRecordDataset)
    parsed = raw_records.map(tfr_parser_emb_idx, num_parallel_calls=tf.data.AUTOTUNE)
    return parsed.batch(BATCH_SIZE_EMB).prefetch(tf.data.AUTOTUNE)


# Training and validation inputs go through the same pipeline.
train_ds_emb = _embedding_dataset(TRAIN_EMB_PATH)
valid_ds_emb = _embedding_dataset(VALID_EMB_PATH)

### Phase 4

The following configurations were used, one after another:
1. margin: 0.50, learning_rate: 1e-5, trained for 2 epochs
2. margin: 0.75, learning_rate: 1e-5, trained for 1 epoch
3. margin: 0.75, learning_rate: 1e-4, trained for 1 epoch

In [None]:
# Phase 4, configuration 1: margin 0.50, learning rate 1e-5, 2 epochs,
# starting from a freshly built siamese model.
with strategy.scope():
    model = build_siamese_model()

    # Created inside the strategy scope, in the same order as before.
    triplet_loss = tfa.losses.TripletSemiHardLoss(margin=0.50, distance_metric='angular')
    adam = tf.keras.optimizers.Adam(learning_rate=1e-5)

    model.compile(loss=triplet_loss, optimizer=adam)

fit(
    model,
    train_ds_emb,
    valid_ds_emb,
    epochs=2,
    initial_epoch=0,
    prefix='trans4_4',
)

In [None]:
# Phase 4, configuration 2: margin 0.75, learning rate 1e-5, 1 epoch,
# continuing from the 'trans4_4' weights file.
with strategy.scope():
    model = build_siamese_model()
    model.load_weights('trans4_4_weights.00002.hdf5')

    # Created inside the strategy scope, in the same order as before.
    triplet_loss = tfa.losses.TripletSemiHardLoss(margin=0.75, distance_metric='angular')
    adam = tf.keras.optimizers.Adam(learning_rate=1e-5)

    model.compile(loss=triplet_loss, optimizer=adam)

fit(
    model,
    train_ds_emb,
    valid_ds_emb,
    epochs=1,
    initial_epoch=0,
    prefix='trans4_4a',
)

In [None]:
# Phase 4, configuration 3: margin 0.75, learning rate 1e-4, 1 epoch,
# continuing from the 'trans4_4a' weights file.
with strategy.scope():
    model = build_siamese_model()
    model.load_weights('trans4_4a_weights.00001.hdf5')

    # Created inside the strategy scope, in the same order as before.
    triplet_loss = tfa.losses.TripletSemiHardLoss(margin=0.75, distance_metric='angular')
    adam = tf.keras.optimizers.Adam(learning_rate=1e-4)

    model.compile(loss=triplet_loss, optimizer=adam)

# NOTE(review): this prefix starts with an underscore, unlike the earlier
# 'trans4_4' / 'trans4_4a' prefixes -- confirm that is intended.
fit(
    model,
    train_ds_emb,
    valid_ds_emb,
    epochs=1,
    initial_epoch=0,
    prefix='_trans4_4b',
)