In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np
import pathlib
import os
import io
import matplotlib.pyplot as plt
import functools
import time

from object_detection import inputs

from object_detection.model_lib_v2 import eager_train_step
from object_detection.model_lib_v2 import eager_eval_loop
from object_detection.model_lib_v2 import load_fine_tune_checkpoint
from object_detection.model_lib_v2 import get_filepath
from object_detection.model_lib_v2 import clean_temporary_directories

from object_detection.protos import pipeline_pb2

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import config_util

from object_detection.builders import dataset_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import model_builder
from object_detection.builders import preprocessor_builder

from object_detection.core import standard_fields as fields

from object_detection.exporter_lib_v2 import DetectionInferenceModule

from zoobot.tensorflow.estimators import custom_layers, define_model

# Hide TensorFlow's Python-logger messages below ERROR (info/warnings are suppressed).
tf.get_logger().setLevel('ERROR')

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Synchronous data-parallel strategy over the local devices (one replica per device).
strategy = tf.distribute.MirroredStrategy()
# Alternative for single-device debugging: strategy = tf.compat.v2.distribute.get_strategy()

In [None]:
print(f'Number of devices: {strategy.num_replicas_in_sync}')

In [None]:
# Input data and base pipeline config locations.
DATA_PATH = '../RPN_Backbone_GZ2/Data/'
IMAGE_PATH = DATA_PATH + 'real_pngs/'
TFRECORDS_PATH = './Data/tf_records/'
PIPELINE_CONFIG_PATH = './ssd_efficientdet_d0_512x512_coco17_tpu-8.config'

# Full paths written into the pipeline config
LABELS_PATH = './Data/clump_label_map_reduced.pbtxt'
FINE_TUNE_CHECKPOINT_PATH = './pre_trained_models/EfficientDet/ckpt-0'
TRAIN_DATASET_PATH = './Data/tf_records/GZ2_ClumpScout_train.records-?????-of-00006'
EVAL_DATASET_PATH = './Data/tf_records/GZ2_ClumpScout_val.records-?????-of-00001'

NUM_CLASSES = 2

# Global batch size = per-replica batch size x number of replicas.
PER_REPLICA_BATCH_SIZE = 4
try:
    REPLICAS = strategy.num_replicas_in_sync
except NameError:
    # The strategy cell has not been run: fall back to a single replica.
    # (A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.)
    REPLICAS = 1

BATCH_SIZE = PER_REPLICA_BATCH_SIZE * REPLICAS

IMAGE_SIZE = 256
# Boxes scoring at or below this are not drawn by plot_img_with_boxes.
SCORE_THRESHOLD = 0.8

# Training and eval directories
MODEL_DIR = './models/Zoobot_EfficientDetD0/'
OUTPUT_MODEL_DIR = MODEL_DIR + 'saved_model'

In [None]:
# Parsing spec for one serialized tf.train.Example from the ClumpScout TFRecords:
# fixed-length scalar image fields followed by variable-length per-object lists.
feature_description = dict(
    [(key, tf.io.FixedLenFeature([], tf.int64))
     for key in ('image/height', 'image/width', 'image/local_id')]
    + [(key, tf.io.FixedLenFeature([], tf.string))
       for key in ('image/source_id', 'image/encoded', 'image/format')]
    + [(key, tf.io.VarLenFeature(tf.float32))
       for key in ('image/object/bbox/xmin', 'image/object/bbox/xmax',
                   'image/object/bbox/ymin', 'image/object/bbox/ymax')]
    + [('image/object/class/text', tf.io.VarLenFeature(tf.string)),
       ('image/object/class/label', tf.io.VarLenFeature(tf.int64))]
)


def _parse_image_function(example_proto):
    """Decode one serialized tf.train.Example using `feature_description`.

    Returns a dict keyed like `feature_description`; the VarLen entries come
    back as tf.SparseTensor values.
    """
    parsed_example = tf.io.parse_single_example(example_proto, feature_description)
    return parsed_example


def plot_img_with_boxes(image, classes, boxes, scores=None, axis=None, plot=True):
    """Draw labelled boxes on a copy of `image`, then show it or return it.

    Args:
        image: image array; it is copied first, so the caller's array is untouched.
        classes: class ids, labelled via the module-level `category_index`.
        boxes: normalized box coordinates (the sample-plot cell passes them as
            [ymin, xmin, ymax, xmax]).
        scores: per-box scores; every box gets 1.0 when omitted.
        axis: matplotlib Axes to draw into; when None a new 12x12 figure is shown.
        plot: when False nothing is displayed and the annotated array is returned.

    Returns:
        The annotated array if `plot` is False, otherwise None.
    """
    box_scores = np.ones(len(classes)) if scores is None else scores
    annotated = image.copy()

    # The visualizer draws into `annotated`; boxes scoring at or below
    # SCORE_THRESHOLD are skipped.
    viz_utils.visualize_boxes_and_labels_on_image_array(
        annotated,
        boxes,
        classes,
        box_scores,
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=100,
        min_score_thresh=SCORE_THRESHOLD,
        agnostic_mode=False,
        line_thickness=1,
    )

    if not plot:
        return annotated

    if axis is not None:
        axis.imshow(annotated)
        return None

    plt.figure(figsize=(12, 12))
    plt.imshow(annotated)
    plt.show()
    return None

In [None]:
def _sorted_records(pattern):
    """Return the sorted string paths of the TFRecord shards in TFRECORDS_PATH matching `pattern`."""
    return sorted(map(str, pathlib.Path(TFRECORDS_PATH).glob(pattern)))


train_set = _sorted_records('GZ2_ClumpScout_train.records*')
valid_set = _sorted_records('GZ2_ClumpScout_val.records*')
test_set = _sorted_records('GZ2_ClumpScout_test.records*')

# One raw (still serialized) record dataset per split.
train_image_dataset, validation_image_dataset, test_image_dataset = (
    tf.data.TFRecordDataset(shards) for shards in (train_set, valid_set, test_set)
)

# Map from class id to its category entry, labelled with the display names in the label map.
category_index = label_map_util.create_category_index_from_labelmap(
    LABELS_PATH, use_display_name=True
)

In [None]:
# Plot a grid of N test-set samples with their ground-truth boxes.
N = 16
# Ceil-based grid so any N fits (int(sqrt(N)) alone only works for perfect squares).
n_cols = int(np.ceil(np.sqrt(N)))
n_rows = int(np.ceil(N / n_cols))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
ax = ax.flatten()

n_plotted = 0
# take(N) stops cleanly if the test split has fewer than N records
# (a bare next() on an iterator would raise StopIteration instead).
for axis, image_features_raw in zip(ax, test_image_dataset.take(N)):
    feature = tf.train.Example.FromString(image_features_raw.numpy()).features.feature
    image = tf.image.decode_png(
        feature['image/encoded'].bytes_list.value[0], channels=3
    ).numpy()
    classes = feature['image/object/class/label'].int64_list.value[:]
    # [ymin, xmin, ymax, xmax]
    boxes = np.stack([
        feature['image/object/bbox/ymin'].float_list.value[:],
        feature['image/object/bbox/xmin'].float_list.value[:],
        feature['image/object/bbox/ymax'].float_list.value[:],
        feature['image/object/bbox/xmax'].float_list.value[:],
    ], -1)

    plot_img_with_boxes(image, classes, boxes, axis=axis)
    n_plotted += 1

# Blank out grid cells that received no image.
for axis in ax[n_plotted:]:
    axis.axis('off')

In [None]:
# Model configuration: load the base pipeline config and override it for this dataset.
configs = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG_PATH)

configs['model'].ssd.num_classes = NUM_CLASSES

# Both min and max dimension set to IMAGE_SIZE.
configs['model'].ssd.image_resizer.keep_aspect_ratio_resizer.min_dimension = IMAGE_SIZE
configs['model'].ssd.image_resizer.keep_aspect_ratio_resizer.max_dimension = IMAGE_SIZE

configs['train_config'].sync_replicas = REPLICAS > 1
configs['train_config'].replicas_to_aggregate = REPLICAS
configs['train_config'].batch_size = BATCH_SIZE

configs['train_config'].fine_tune_checkpoint = FINE_TUNE_CHECKPOINT_PATH
configs['train_config'].fine_tune_checkpoint_type = "detection"

# `input_path` is a repeated string field, so assign a one-element list. Slice-assigning
# a bare string would iterate it and store every character as a separate path.
configs['train_input_config'].label_map_path = LABELS_PATH
configs['train_input_config'].tf_record_input_reader.input_path[:] = [TRAIN_DATASET_PATH]

configs['eval_config'].batch_size = 1
# Replace whatever metrics the base config lists with COCO detection metrics only.
configs['eval_config'].ClearField('metrics_set')
configs['eval_config'].metrics_set.append('coco_detection_metrics')

configs['eval_input_config'].label_map_path = LABELS_PATH
configs['eval_input_config'].tf_record_input_reader.input_path[:] = [EVAL_DATASET_PATH]

In [None]:
# Short aliases for the pipeline-config sections used by the cells below.
model_config, train_config, train_input_config, eval_config = (
    configs[section]
    for section in ('model', 'train_config', 'train_input_config', 'eval_config')
)
# Eval readers are stored as a list; only the first one is used.
eval_input_config = configs['eval_input_configs'][0]

In [None]:
# Build model
def preprocess_fn(inputs):
    """Scale pixel values by 1/255 (0..255 -> 0..1).

    Installed as the feature extractor's preprocessing by `build_model`.
    """
    scaled = inputs / 255.0
    return scaled

def build_model():
    """Create the SSD detection model described by `model_config.ssd`.

    The model is built in training mode without summaries, and its feature
    extractor's `preprocess` is replaced with `preprocess_fn`.
    """
    model = model_builder._build_ssd_model(
        ssd_config=model_config.ssd, is_training=True, add_summaries=False
    )
    # NOTE(review): plain /255 scaling replaces the extractor's own preprocessing,
    # presumably to match what the Zoobot backbone weights expect — confirm.
    model._feature_extractor.preprocess = preprocess_fn
    return model

# Create the detection model under the distribution strategy so its variables are
# mirrored across replicas. A failure here is surfaced, rather than silently rebuilding
# the model outside the scope (which would give non-distributed variables and hide the error).
with strategy.scope():
    detection_model = build_model()

In [None]:
# Training settings
# Spatial shape the input pipelines pad images to; kept in sync with the resizer size.
SHAPE = (IMAGE_SIZE, IMAGE_SIZE)
learning_rate = 1e-4

EPOCHS = 1
STEPS_PER_EPOCH = 15834  # NOTE(review): presumably train examples / BATCH_SIZE — confirm
NUM_TRAIN_STEPS = STEPS_PER_EPOCH * EPOCHS
train_steps = NUM_TRAIN_STEPS

RUN_EVAL = True
# Key looked up in the dict returned by eager_eval_loop. The eval config requests
# 'coco_detection_metrics', which reports mAP at IoU 0.5 under this name. The
# 'PascalBoxes_Precision/mAP@0.5IOU' key only exists with 'pascal_voc_detection_metrics'.
MONITOR_METRIC = 'DetectionBoxes_Precision/mAP@.50IOU'
ES_PATIENCE = 5  # evaluations without improvement before early stopping

best_metric_value = 0.0
not_improved = 0
steps_per_sec_list = []

unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
add_regularization_loss = train_config.add_regularization_loss

# A non-positive gradient_clipping_by_norm means no gradient clipping.
clip_gradients_value = None
if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

config_util.update_fine_tune_checkpoint_type(train_config)
fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

In [None]:
with strategy.scope():
    def train_dataset_fn(input_context):
        """Build the training tf.data pipeline for one input worker.

        Passed to the strategy's datasets-from-function API, which supplies
        `input_context`. dataset_builder.build reads the records described by
        train_input_config and maps `transform_input_data_fn` over them; the
        result is repeated indefinitely.
        NOTE(review): train_config.batch_size is the global BATCH_SIZE; presumably
        dataset_builder derives the per-replica size from input_context — confirm.
        """
    
        def transform_input_data_fn(tensor_dict):
            """Turn one example's tensor dict into a (features, labels) pair.

            Steps: the train_config data augmentations, then
            inputs.transform_input_data (detection_model.preprocess and the
            config's image resizer), then padding to static shapes
            (max_number_of_boxes boxes, SHAPE images).
            """
            data_augmentation_options = [
                preprocessor_builder.build(step)
                for step in train_config.data_augmentation_options
            ]
            data_augmentation_fn = functools.partial(
                inputs.augment_input_data,
                data_augmentation_options=data_augmentation_options
            )
    
            image_resizer_config = model_config.ssd.image_resizer
            image_resizer_fn = image_resizer_builder.build(image_resizer_config)
            
            transform_data_fn = functools.partial(
                inputs.transform_input_data, 
                model_preprocess_fn=detection_model.preprocess,
                image_resizer_fn=image_resizer_fn,
                num_classes=NUM_CLASSES,
                data_augmentation_fn=data_augmentation_fn,
                merge_multiple_boxes=False,
                use_multiclass_scores=False
            )
    
            # Static shapes are needed so examples can be batched together.
            tensor_dict = inputs.pad_input_data_to_static_shapes(
                tensor_dict=transform_data_fn(tensor_dict),
                max_num_boxes=train_input_config.max_number_of_boxes,
                num_classes=NUM_CLASSES,
                spatial_image_shape=SHAPE
            )
    
            # Second argument: presumably include_source_id=False — confirm for the installed version.
            return (inputs._get_features_dict(tensor_dict, False),
                    inputs._get_labels_dict(tensor_dict))
    
        train_input = dataset_builder.build(
            input_reader_config=train_input_config,
            transform_input_data_fn=transform_input_data_fn,
            batch_size=train_config.batch_size,
            input_context=input_context,
        )
        # Endless stream; the training loop decides when to stop.
        train_input = train_input.repeat()    
    
        return train_input

In [None]:
with strategy.scope():
    def eval_dataset_fn(input_context):
        """Build the evaluation tf.data pipeline for one input worker.

        Same structure as the training pipeline but without data augmentation.
        The original image is optionally retained, per eval_config. No explicit
        .repeat() is applied here.
        """
    
        def transform_input_data_fn(tensor_dict):
            """Resize/preprocess one example and pad it to static shapes; return (features, labels)."""
            image_resizer_config = model_config.ssd.image_resizer
            image_resizer_fn = image_resizer_builder.build(image_resizer_config)
    
            transform_data_fn = functools.partial(
                inputs.transform_input_data, 
                model_preprocess_fn=detection_model.preprocess,
                image_resizer_fn=image_resizer_fn,
                num_classes=NUM_CLASSES,
                merge_multiple_boxes=False,
                use_multiclass_scores=False,
                retain_original_image=eval_config.retain_original_images,
                retain_original_image_additional_channels=eval_config.retain_original_image_additional_channels
            )
    
            tensor_dict = inputs.pad_input_data_to_static_shapes(
                tensor_dict=transform_data_fn(tensor_dict),
                max_num_boxes=eval_input_config.max_number_of_boxes,
                num_classes=NUM_CLASSES,
                spatial_image_shape=SHAPE
            )
    
            # Second argument: presumably include_source_id=False — confirm for the installed version.
            return (inputs._get_features_dict(tensor_dict, False),
                    inputs._get_labels_dict(tensor_dict))
    
        eval_input = dataset_builder.build(
            eval_input_config,
            transform_input_data_fn=transform_input_data_fn,
            batch_size=eval_config.batch_size,
            input_context=input_context,
        )
    
        return eval_input

In [None]:
# Wrap the per-worker dataset functions into distributed datasets.
# `distribute_datasets_from_function` is the stable name; the `experimental_` alias is deprecated.
with strategy.scope():
    train_input = strategy.distribute_datasets_from_function(train_dataset_fn)
    eval_input = strategy.distribute_datasets_from_function(eval_dataset_fn)

In [None]:
# Single long-lived iterator that every training step pulls its next batch from
# (the training dataset repeats indefinitely).
with strategy.scope():
    train_input_iter = iter(train_input)

In [None]:
# Create the global step counter and restore the pre-trained COCO detection checkpoint.
with strategy.scope():
    # ONLY_FIRST_REPLICA: when every replica updates the counter, only one update is kept.
    global_step = tf.Variable(
        0,
        name='global_step',
        dtype=tf.compat.v2.dtypes.int64,
        trainable=False,
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA,
    )

    initial_step = int(global_step.value())
    checkpointed_step = logged_step = initial_step
    total_loss = 0  # running loss sum, reset by the training loop at each epoch boundary

    fine_tune_path = train_config.fine_tune_checkpoint
    if fine_tune_path:
        load_fine_tune_checkpoint(
            model=detection_model,
            checkpoint_path=fine_tune_path,
            checkpoint_type=fine_tune_checkpoint_type,
            checkpoint_version=fine_tune_checkpoint_version,
            run_model_on_dummy_input=False,
            input_dataset=train_input,
            unpad_groundtruth_tensors=unpad_groundtruth_tensors,
        )

In [None]:
# Load Zoobot (EfficientNet-B0, colour); only its backbone weights are used further down.
INITIAL_SIZE = 300
CROP_SIZE = int(INITIAL_SIZE * 0.75)
RESIZE_SIZE = 224   # Zoobot, as pretrained, expects 224x224 images

checkpoint_dir = './pre_trained_models/Zoobot_EfficientnetB0_colour/'
checkpoint_loc = os.path.join(checkpoint_dir, 'checkpoint')

conv_base = define_model.load_model(
    checkpoint_loc,  # pretrained Zoobot checkpoint
    expect_partial=True,  # ignores some optimizer warnings
    include_top=False,  # drop the GZ DECaLS head; a new head is added below
    input_size=INITIAL_SIZE,  # size of the images fed to the model
    crop_size=CROP_SIZE,  # model augmentation layers apply a crop...
    resize_size=RESIZE_SIZE,  # ...and then apply a resize
    output_dim=None,
    channels=3
)

# Deliberately NOT named `inputs`: that global is the `object_detection.inputs` module,
# which the dataset functions look up (`inputs.transform_input_data`, ...) when they are
# called. Rebinding it to a Keras tensor breaks any later (re)build of those pipelines.
zoobot_inputs = tf.keras.Input(shape=(INITIAL_SIZE, INITIAL_SIZE, 3))
x = conv_base(zoobot_inputs)
zoobot_outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)

zoobot = tf.keras.Model(zoobot_inputs, zoobot_outputs)

zoobot.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
# Locate the EfficientNet-B0 backbone nested inside the Zoobot model.
# NOTE(review): 'sequential' / 'sequential_1' look like Keras auto-generated names that
# depend on how many models were created earlier in the session — confirm them
# (e.g. via zoobot.summary()) if this lookup fails.
outer_sequential = zoobot.get_layer('sequential')
inner_sequential = outer_sequential.get_layer('sequential_1')
efficientnet_zoobot = inner_sequential.get_layer('efficientnet-b0')

In [None]:
# Copy the Zoobot EfficientNet-B0 weights into the detector's backbone, position by position.
# This relies on both models listing their weights in the same order. `.weights` is read
# once per model (the property rebuilds the list on every access). A count mismatch is
# reported explicitly instead of copying only a prefix or failing with an IndexError,
# and `assign` itself raises on any per-variable shape mismatch.
backbone_weights = detection_model.feature_extractor.classification_backbone.weights
zoobot_weights = efficientnet_zoobot.weights
if len(zoobot_weights) != len(backbone_weights):
    raise ValueError(
        f'Weight count mismatch: Zoobot backbone has {len(zoobot_weights)} weights, '
        f'detection backbone has {len(backbone_weights)}.'
    )
for target_weight, source_weight in zip(backbone_weights, zoobot_weights):
    target_weight.assign(source_weight)

In [None]:
# Optimizer and checkpointing, created under the strategy scope.
with strategy.scope():
    # Zero-argument learning-rate callable (a callable schedule is used as-is).
    # NOTE(review): not used below — the optimizer is given `learning_rate` directly.
    learning_rate_fn = learning_rate if callable(learning_rate) else (lambda: learning_rate)

    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate)
    # Track step counter, model and optimizer together so training state can be restored.
    ckpt = tf.compat.v2.train.Checkpoint(
        step=global_step, model=detection_model, optimizer=optimizer
    )
    manager_dir = get_filepath(strategy, MODEL_DIR)
    # max_to_keep=None: every saved checkpoint is kept.
    manager = tf.compat.v2.train.CheckpointManager(ckpt, manager_dir, max_to_keep=None)

In [None]:
def reduce_dict(strategy, reduction_dict, reduction_op):
    """Reduce every per-replica value in `reduction_dict` across replicas.

    Args:
        strategy: object exposing `reduce(op, value, axis=...)` (a tf.distribute strategy).
        reduction_dict: mapping of name -> per-replica value (e.g. losses).
        reduction_op: the reduction to apply (e.g. tf.distribute.ReduceOp.SUM).

    Returns:
        A new dict with the same keys, in the same order, holding the reduced values.
    """
    reduced = {}
    for name, per_replica_value in reduction_dict.items():
        reduced[name] = strategy.reduce(reduction_op, per_replica_value, axis=None)
    return reduced


# Training step functions, defined under the strategy scope.
with strategy.scope():
    def train_step_fn(features, labels):
        """Run one training step on a single replica's batch.

        Delegates to object_detection's eager_train_step (presumably computing the
        losses and applying gradients with `optimizer`), then advances `global_step`.

        Returns:
            The per-replica dict of losses returned by eager_train_step.
        """
        losses_dict = eager_train_step(
            detection_model,
            features,
            labels,
            unpad_groundtruth_tensors,
            optimizer,
            training_step=global_step,
            #learning_rate=learning_rate_fn(),  NOTE(review): not passed; the optimizer's own rate is used
            add_regularization_loss=add_regularization_loss,
            clip_gradients_value=clip_gradients_value,
            num_replicas=REPLICAS
        )
        # global_step was created with ONLY_FIRST_REPLICA aggregation, so this counts
        # one increment per distributed step rather than one per replica.
        global_step.assign_add(1)
        
        return losses_dict


    def _sample_and_train(strategy, train_step_fn, data_iterator):
        """Pull the next distributed batch, run train_step_fn on every replica, SUM each loss across replicas."""
        features, labels = data_iterator.next()
        per_replica_losses_dict = strategy.run(train_step_fn, args=(features, labels))
        
        return reduce_dict(strategy, per_replica_losses_dict, tf.distribute.ReduceOp.SUM)


    @tf.function
    def _dist_train_step(data_iterator):
        """Compiled (tf.function) distributed training step; returns the replica-reduced loss dict."""
        return _sample_and_train(strategy, train_step_fn, data_iterator)

In [None]:
# Training loop: one distributed step per iteration. Every STEPS_PER_EPOCH steps, report
# the mean loss and either checkpoint (RUN_EVAL off) or evaluate, checkpoint on
# improvement of MONITOR_METRIC, and apply early stopping.
last_step_time = time.time()

for _ in range(int(global_step.value()), train_steps):
    with strategy.scope():
        loss = _dist_train_step(train_input_iter)
    now = time.time()
    time_taken = now - last_step_time
    last_step_time = now
    steps_per_sec = 1.0 / time_taken
    steps_per_sec_list.append(steps_per_sec)
    # `loss` is a dict of replica-SUMmed losses (adding the dict itself raises TypeError).
    # The normalized total is divided by num_replicas inside eager_train_step, so its
    # sum over replicas is the per-replica mean loss.
    total_loss += float(loss['Loss/normalized_total_loss'])

    current_step = int(global_step.value())
    if current_step % STEPS_PER_EPOCH != 0:
        continue

    epoch = current_step // STEPS_PER_EPOCH
    # Latest per-step time extrapolated to a whole epoch (printed as "ETA").
    epoch_eta = time_taken * STEPS_PER_EPOCH
    mean_loss = total_loss / STEPS_PER_EPOCH
    total_loss = 0

    if not RUN_EVAL:
        print('Epoch {} [ETA {:.2f}s] loss={:.3f}'.format(epoch, epoch_eta, mean_loss))
        # No metric to select on: keep a checkpoint every epoch so training is not lost.
        manager.save()
        continue

    eval_global_step = tf.compat.v2.Variable(0, trainable=False, dtype=tf.compat.v2.dtypes.int64)
    eval_metrics = eager_eval_loop(
        detection_model,
        configs,
        eval_input,
        global_step=eval_global_step
    )
    if MONITOR_METRIC not in eval_metrics:
        raise KeyError(
            f'MONITOR_METRIC {MONITOR_METRIC!r} not in eval metrics; '
            f'available: {sorted(eval_metrics)}'
        )
    metric_value = float(eval_metrics[MONITOR_METRIC])

    print('Epoch {} [ETA {:.2f}s] loss={:.3f} mAP@.5={:.3f}'.format(
          epoch, epoch_eta, mean_loss, metric_value))

    if metric_value > best_metric_value:
        best_metric_value = metric_value
        manager.save()
        not_improved = 0
    else:
        not_improved += 1

    if not_improved >= ES_PATIENCE:
        print(f"Early stopping at epoch {epoch}")
        break

clean_temporary_directories(strategy, manager_dir)