# Introduction

[Object Detection API](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/) is a TensorFlow-based library for object detection tasks. Following [the docs](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html), you should be able to train your own detector without problems. Furthermore, there is a well-written [public notebook](https://www.kaggle.com/sreevishnudamodaran/vbd-efficientdet-tf2-object-detection-api) that uses the OD API for this challenge.

The API itself provides a Python script for both training and evaluation of a detection model, but this approach is not well suited to IPython notebooks and is not easy to customize. For this reason I've chosen a slightly different approach: I looked at the codebase and rewrote the training loop myself. Most of the code is copied from the OD API, with minor modifications to make it work in Kaggle notebooks.

I'm sorry if the code isn't very clear; this is my first attempt with the Object Detection API too.

# Install Object Detection API

**NOTE:** I decided to use commit `3f6fe2aa410d901aae8829597a65d084bffc20d3` because it does not require TensorFlow `2.4.0`, which more recent commits depend on and which causes a CUDA version mismatch in this environment.

In [None]:
%%capture
!git clone https://github.com/tensorflow/models.git

%cd models/research/
!git reset --hard 3f6fe2aa410d901aae8829597a65d084bffc20d3

!protoc object_detection/protos/*.proto --python_out=.

!cp object_detection/packages/tf2/setup.py .
!python -m pip install . 
%cd /kaggle/working

# Setup directories

* Organise workspace/training files: the standard OD API approach requires you to set up your working directories following a specific tree. For convenience, I did the same here:

```
./
├─ annotations/
    └─ label_map.pbtxt
├─ models/
    └─ pipeline.config
└─ pre-trained-models/*/
    ├─ checkpoint/
    ├─ saved_model/
    └─ pipeline.config
```

* Download pre-trained EfficientDet-D0 from [TF Model Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md)
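The tree above can also be created from Python instead of shell commands. The following is a minimal sketch, not the code used in this notebook: `make_workspace` is a hypothetical helper, and the demo runs in a temporary directory so it does not touch the real workspace.

```python
import tempfile
from pathlib import Path


def make_workspace(base_dir):
    """Create the OD API-style directory tree and return the key paths."""
    base = Path(base_dir)
    paths = {
        "annotations": base / "annotations",
        "models": base / "models" / "efficientdet",
        "pretrained": base / "pre-trained-models",
    }
    for path in paths.values():
        # parents=True creates intermediate folders; exist_ok=True makes reruns safe.
        path.mkdir(parents=True, exist_ok=True)
    return paths


# Demo in a throwaway directory so nothing in the real workspace is touched.
with tempfile.TemporaryDirectory() as tmp:
    workspace = make_workspace(Path(tmp) / "chest-x-ray-detection")
    created = sorted(p.relative_to(tmp).as_posix() for p in workspace.values())
    all_exist = all(p.is_dir() for p in workspace.values())
    print(created)
```

Because `mkdir` is called with `exist_ok=True`, re-running the cell does not fail when the folders are already there.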

In [None]:
BASE_DIR = 'chest-x-ray-detection'
MODEL_PATH = 'efficientdet_d0_coco17_tpu-32'

In [None]:
%%capture
!rm -r {BASE_DIR}
!mkdir {BASE_DIR}
!mkdir {BASE_DIR}/pre-trained-models/
!mkdir {BASE_DIR}/annotations
!mkdir {BASE_DIR}/models
!mkdir {BASE_DIR}/models/efficientdet/

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/{MODEL_PATH}.tar.gz

!tar -xvzf {MODEL_PATH}.tar.gz
!rm {MODEL_PATH}.tar.gz
!mv {MODEL_PATH} {BASE_DIR}/pre-trained-models/
!mv {BASE_DIR}/pre-trained-models/{MODEL_PATH}/pipeline.config {BASE_DIR}/models/efficientdet/pipeline.config

# Import libraries

In [None]:
import pathlib, cv2, os, time, functools
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf

from google.protobuf import text_format

## Object Detection API internal libraries

In [None]:
from object_detection import inputs

from object_detection.model_lib_v2 import eager_train_step
from object_detection.model_lib_v2 import eager_eval_loop
from object_detection.model_lib_v2 import load_fine_tune_checkpoint
from object_detection.model_lib_v2 import get_filepath
from object_detection.model_lib_v2 import clean_temporary_directories

from object_detection.protos import pipeline_pb2

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import config_util

from object_detection.builders import dataset_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import model_builder
from object_detection.builders import preprocessor_builder

from object_detection.core import standard_fields as fields

from object_detection.exporter_lib_v2 import DetectionInferenceModule

# Setup notebook

In [None]:
MODEL_DIR = BASE_DIR + '/models/efficientdet/'
PIPELINE_PATH = MODEL_DIR + 'pipeline.config'
LABEL_MAP_PATH = BASE_DIR + '/annotations/label_map.pbtxt'
OUTPUT_MODEL_DIR = '/kaggle/working/saved_model'

In [None]:
input_path = pathlib.Path('/kaggle/input/chest-xray-detection-512x512-groupkfold-tfrec')

!cp {input_path}/label_map.pbtxt {LABEL_MAP_PATH}

DS_PATH = str(input_path)
os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)

In [None]:
plt.rcParams['axes.grid'] = False
plt.rcParams['xtick.labelsize'] = False
plt.rcParams['ytick.labelsize'] = False
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['figure.figsize'] = [12, 12]

In [None]:
def seed_everything(seed=0):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

seed = 2020
seed_everything(seed)

In [None]:
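gpus = tf.config.experimental.list_physical_devices('GPU')
strategy = None
if gpus:
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
        strategy = tf.distribute.MirroredStrategy(devices=["GPU:0"])
    except RuntimeError as e:
        # Visible devices must be set before the GPUs are initialized.
        print(e)

if strategy is None:
    # Fall back to the default strategy so `strategy` is always defined below.
    strategy = tf.distribute.get_strategy()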

In [None]:
NUM_CLASSES = 14
PER_REPLICA_BATCH_SIZE = 2
try:
    REPLICAS = strategy.num_replicas_in_sync
except:
    REPLICAS = 1
    
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * REPLICAS

fold = 0
N_FOLDS = 5

SCORE_THRESHOLD = 0.5

# Read data

I created a custom version of the dataset stored in TFRecord files (private at the moment). I'm not planning to make the dataset public, but it only applies these basic pre-processing steps:

* voi-lut
* monochrome correction
* resizing to `512x512`
* histogram equalization

I've also split the data by patient id into `5 folds` using (Multi-Class Stratified) GroupKFold.
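To illustrate what splitting by patient id means, here is a minimal sketch of a group-wise fold assignment. It is a simplification and not the method used to build the dataset: it keeps every group inside a single fold and balances fold sizes greedily, but it ignores class stratification. The function name and the patient ids are hypothetical.

```python
from collections import defaultdict


def assign_group_folds(groups, n_folds):
    """Assign each group to one fold, balancing the number of rows per fold.

    `groups` holds one group key (e.g. a patient id) per row. Returns a dict
    mapping group key -> fold number in 1..n_folds.
    """
    sizes = defaultdict(int)
    for group in groups:
        sizes[group] += 1

    fold_load = [0] * n_folds
    assignment = {}
    # Largest groups first, each going to the currently lightest fold.
    for group, size in sorted(sizes.items(), key=lambda kv: (-kv[1], kv[0])):
        lightest = fold_load.index(min(fold_load))
        assignment[group] = lightest + 1
        fold_load[lightest] += size
    return assignment


# Hypothetical patient ids, one per image row.
patients = ["p1", "p1", "p1", "p2", "p2", "p3", "p4", "p4", "p5", "p6"]
folds = assign_group_folds(patients, n_folds=3)
row_folds = [folds[p] for p in patients]
print(folds)
```

Since the fold is looked up per patient, all images of the same patient end up on the same side of the train/validation split, which is the point of a grouped split.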

In [None]:
train_df = pd.read_csv(input_path / 'train.csv')
train_df.head()

In [None]:
category_index = label_map_util.create_category_index_from_labelmap(
    LABEL_MAP_PATH,
    use_display_name=True
)

In [None]:
TRAIN_DATASET = tf.io.gfile.glob(DS_PATH + f'/fold_[^{fold + 1}].tfrecord')
TEST_DATASET = tf.io.gfile.glob(DS_PATH + f'/fold_{fold + 1}.tfrecord')    

ct_train = len(train_df['image_id'][train_df['fold'] != fold + 1].unique())  / BATCH_SIZE
ct_test = len(train_df['image_id'][train_df['fold'] == fold + 1].unique()) / BATCH_SIZE
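The cell above selects the training files with a bracket-negation glob and derives the step counts by dividing by the batch size. As an alternative sketch, the split can be made explicit by parsing the fold number from each file name, which does not depend on the glob dialect and still works with two-digit fold numbers. The helper name and the file listing below are hypothetical, and `math.ceil` is shown as an option for counting the last partial batch (the notebook itself truncates with `int`).

```python
import math
import re


def split_fold_files(paths, val_fold):
    """Split TFRecord paths into (train, val) by the fold number in the name."""
    pattern = re.compile(r"fold_(\d+)\.tfrecord$")
    train, val = [], []
    for path in sorted(paths):
        match = pattern.search(path)
        if match is None:
            continue  # Not a fold file, e.g. label_map.pbtxt.
        (val if int(match.group(1)) == val_fold else train).append(path)
    return train, val


# Hypothetical listing; the real file names in the private dataset may differ.
listing = [f"/data/fold_{i}.tfrecord" for i in range(1, 6)] + ["/data/label_map.pbtxt"]
train_files, val_files = split_fold_files(listing, val_fold=1)

# Rounding up keeps the last partial batch inside the epoch.
steps_per_epoch = math.ceil(1001 / 2)
print(train_files, val_files, steps_per_epoch)
```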

## Show samples

In [None]:
def plot_img_with_boxes(image, classes, boxes, scores=None, axis=None, plot=True):
    if scores is None:
        scores = np.ones(len(classes))
        
    image_with_detections = image.copy()
    
    viz_utils.visualize_boxes_and_labels_on_image_array(
          image_with_detections,
          boxes,
          classes,
          scores,
          category_index,
          use_normalized_coordinates=True,
          max_boxes_to_draw=100,
          min_score_thresh=SCORE_THRESHOLD,
          agnostic_mode=False)
    
    if plot:
        if axis is None:
            plt.figure(figsize=(12,12))
            plt.imshow(image_with_detections)
            plt.show()
        else:
            axis.imshow(image_with_detections)
    else:
        return image_with_detections

In [None]:
feature_description = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.FixedLenSequenceFeature([], tf.float32, True),
    'image/object/bbox/xmax': tf.io.FixedLenSequenceFeature([], tf.float32, True),
    'image/object/bbox/ymin': tf.io.FixedLenSequenceFeature([], tf.float32, True),
    'image/object/bbox/ymax': tf.io.FixedLenSequenceFeature([], tf.float32, True),
    'image/object/class/text': tf.io.FixedLenSequenceFeature([], tf.string, True),
    'image/object/class/label': tf.io.FixedLenSequenceFeature([], tf.int64, True)
}

def parse_image_sample(example_proto):
    return tf.io.parse_single_example(example_proto,
                                      feature_description)

raw_image_dataset = tf.data.TFRecordDataset(TEST_DATASET)
parsed_image_dataset = raw_image_dataset.map(parse_image_sample)
iterator = iter(parsed_image_dataset)

In [None]:
%matplotlib inline

N = 16
fig, ax = plt.subplots(int(np.sqrt(N)), int(np.sqrt(N)), figsize=(12,12))
ax = ax.flatten()

for idx in range(N):
    image_features = next(iterator)
    image_raw = image_features['image/encoded']
    image = tf.image.decode_jpeg(image_raw).numpy()
    classes = image_features['image/object/class/label'].numpy()
    boxes = np.stack([
        image_features['image/object/bbox/xmin'],
        image_features['image/object/bbox/ymin'],
        image_features['image/object/bbox/xmax'],
        image_features['image/object/bbox/ymax'],
    ], -1)

    plot_img_with_boxes(image, 
                        classes, 
                        boxes,
                        axis=ax[idx])
    
fig.show()

# Model configuration

Here I override the default configuration of the pre-trained model, which is specific to the COCO dataset. This step is equivalent to editing the `pipeline.config` file by hand or patching strings in it with regular expressions, but I find this way cleaner and more straightforward.

Notice that, unlike the standard way of using the OD API, I'm not using the `pipeline.config` file directly: I load the configuration into a dictionary that I use later.
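For contrast, this is roughly what the regular-expression route looks like. It is a minimal sketch on a toy excerpt of the text format, not the real `pipeline.config`, and `set_scalar_field` is a hypothetical helper. It only handles integer fields and replaces the first match, which is one reason the dictionary-based approach is easier to trust.

```python
import re

# Toy excerpt in pipeline.config text format; the real file is much longer.
config_text = """model {
  ssd {
    num_classes: 90
  }
}
train_config {
  batch_size: 128
}
"""


def set_scalar_field(text, field, value):
    """Replace the first `field: <number>` occurrence in a text-format config."""
    pattern = rf"(\b{re.escape(field)}:\s*)\d+"
    new_text, count = re.subn(pattern, rf"\g<1>{value}", text, count=1)
    if count == 0:
        raise KeyError(f"field {field!r} not found")
    return new_text


patched = set_scalar_field(config_text, "num_classes", 14)
patched = set_scalar_field(patched, "batch_size", 2)
print(patched)
```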

In [None]:
configs = config_util.get_configs_from_pipeline_file(PIPELINE_PATH)

configs['model'].ssd.num_classes = NUM_CLASSES

configs['train_config'].sync_replicas = True if REPLICAS > 1 else False
configs['train_config'].replicas_to_aggregate = REPLICAS
configs['train_config'].batch_size = BATCH_SIZE
configs['train_config'].data_augmentation_options.pop(1)

configs['train_config'].fine_tune_checkpoint = (
    BASE_DIR + f'/pre-trained-models/{MODEL_PATH}/checkpoint/ckpt-0'
)
configs['train_config'].fine_tune_checkpoint_type = "detection"

configs['train_input_config'].label_map_path = LABEL_MAP_PATH
configs['train_input_config'].tf_record_input_reader.input_path[:] = TRAIN_DATASET
configs['train_input_config'].load_multiclass_scores = True

configs['eval_config'].batch_size = 1
configs['eval_config'].metrics_set[:] = ''
configs['eval_config'].metrics_set.append('pascal_voc_detection_metrics')

configs['eval_input_config'].label_map_path = LABEL_MAP_PATH
configs['eval_input_config'].tf_record_input_reader.input_path[:] = TEST_DATASET
configs['eval_input_config'].load_multiclass_scores = True

config_util.save_pipeline_config(config_util.create_pipeline_proto_from_configs(configs),
                                 PIPELINE_PATH.replace('pipeline.config', ''))

In [None]:
model_config = configs['model']
train_config = configs['train_config']
train_input_config = configs['train_input_config']
eval_config = configs['eval_config']
eval_input_config = configs['eval_input_configs'][0]

## Build model

Here I build my detector using the `model_builder` utilities. Notice that the `build_model` function also overrides the preprocessing function used by the feature extractor backbone (`feature_extractor.preprocess`). This function is used during both training and inference to pre-process the data before feeding it to the detector. The [default pre-processing](https://github.com/tensorflow/models/blob/master/research/object_detection/models/ssd_efficientnet_bifpn_feature_extractor.py#L185) provided by the OD API is channel-wise normalization by ImageNet mean/std; here it is replaced by a plain division by 255.
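To make the difference between the two pre-processing schemes concrete, here is a small sketch on a single RGB pixel. The mean/std constants are the commonly used ImageNet statistics; treating them as the exact values in the linked code is an assumption.

```python
import math

# Commonly used ImageNet channel statistics (RGB, for inputs scaled to [0, 1]).
# Assumption: these are the constants applied by the linked default preprocessing.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def scale_only(pixel):
    """Replacement preprocessing from this notebook: map [0, 255] to [0, 1]."""
    return tuple(channel / 255.0 for channel in pixel)


def imagenet_normalize(pixel):
    """Channel-wise normalization: scale to [0, 1], subtract mean, divide by std."""
    return tuple(
        (channel / 255.0 - mean) / std
        for channel, mean, std in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)
    )


white = (255, 255, 255)
black = (0, 0, 0)
print(scale_only(white))
print(imagenet_normalize(white))
print(imagenet_normalize(black))
```

The two schemes give inputs in different ranges, so a model trained with one should also be served with the same one.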

In [None]:
def preprocess_fn(inputs):
    return inputs / 255.0

def build_model():
    detection_model = model_builder._build_ssd_model(ssd_config=model_config.ssd,
                                                     is_training=True,
                                                     add_summaries=False)

    detection_model._feature_extractor.preprocess = preprocess_fn
    
    return detection_model

try:
    with strategy.scope():
        detection_model = build_model()
except:
    detection_model = build_model()

# Training utils

In [None]:
SHAPE = (512,512)
learning_rate = 1e-4

EPOCHS = 1
STEPS_PER_EPOCH = int(ct_train)
NUM_TRAIN_STEPS = int(STEPS_PER_EPOCH * EPOCHS)
train_steps = NUM_TRAIN_STEPS

RUN_EVAL = True
MONITOR_METRIC = 'PascalBoxes_Precision/mAP@0.5IOU'
ES_PATIENCE = 5

best_metric_value = 0.0
not_improved = 0
steps_per_sec_list = []

unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
add_regularization_loss = train_config.add_regularization_loss

clip_gradients_value = None
if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

config_util.update_fine_tune_checkpoint_type(train_config)
fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

## Dataset functions

I used `dataset_builder.build` ([code](https://github.com/tensorflow/models/blob/master/research/object_detection/builders/dataset_builder.py#L166)), as is done in the OD API codebase. From the docs:
```
Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.
  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch size is None, no batching is performed.
    transform_input_data_fn: Function to apply transformation to all records,
      or None if no extra decoding is required.
    input_context: optional, A tf.distribute.InputContext object used to
      shard filenames and compute per-replica batch_size when this function
      is being called per-replica.
    reduce_to_frame_fn: Function that extracts frames from tf.SequenceExample
      type input data.
  Returns:
    A tf.data.Dataset based on the input_reader_config.
```
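The "padded batch" mentioned in the docstring matters because images have a different number of boxes, while a batch needs tensors of equal shape. The sketch below shows the idea in plain Python; it is not the OD API implementation, and `pad_boxes` is a hypothetical helper. The count of real boxes is kept next to the padded list, which mirrors how `num_groundtruth_boxes` is used later in this notebook to slice away the padding.

```python
def pad_boxes(boxes, max_num_boxes):
    """Pad (or truncate) a list of boxes to a fixed length.

    Returns the padded list and the number of real boxes it contains.
    """
    kept = boxes[:max_num_boxes]
    padding = [[0.0, 0.0, 0.0, 0.0] for _ in range(max_num_boxes - len(kept))]
    return kept + padding, len(kept)


# Two images with a different number of boxes, in [ymin, xmin, ymax, xmax] order.
image_a = [[0.1, 0.2, 0.5, 0.6]]
image_b = [[0.0, 0.0, 0.3, 0.3], [0.4, 0.4, 0.9, 0.8], [0.2, 0.1, 0.7, 0.5]]

batch = [pad_boxes(image, max_num_boxes=4) for image in (image_a, image_b)]
for padded, num_real in batch:
    # Slicing by the stored count recovers only the real boxes.
    print(num_real, padded[:num_real])
```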

### Training dataset

Here you can edit `train_config.data_augmentation_options` in the `pipeline.config` file to change the data augmentation applied during training. The available options are listed [here](https://github.com/tensorflow/models/blob/master/research/object_detection/protos/preprocessor.proto#L8).

In [None]:
def train_dataset_fn(input_context):
    def transform_input_data_fn(tensor_dict):
        data_augmentation_options = [
            preprocessor_builder.build(step)
            for step in train_config.data_augmentation_options
        ]
        data_augmentation_fn = functools.partial(
            inputs.augment_input_data,
            data_augmentation_options=data_augmentation_options
        )

        image_resizer_config = model_config.ssd.image_resizer
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)
        transform_data_fn = functools.partial(
            inputs.transform_input_data, 
            model_preprocess_fn=detection_model.preprocess,
            image_resizer_fn=image_resizer_fn,
            num_classes=NUM_CLASSES,
            data_augmentation_fn=data_augmentation_fn,
            merge_multiple_boxes=False,
            use_multiclass_scores=False
        )

        tensor_dict = inputs.pad_input_data_to_static_shapes(
            tensor_dict=transform_data_fn(tensor_dict),
            max_num_boxes=train_input_config.max_number_of_boxes,
            num_classes=NUM_CLASSES,
            spatial_image_shape=SHAPE
        )

        return (inputs._get_features_dict(tensor_dict, False),
                inputs._get_labels_dict(tensor_dict))

    train_input = dataset_builder.build(
        train_input_config,
        transform_input_data_fn=transform_input_data_fn,
        batch_size=train_config.batch_size,
        input_context=input_context,
    )
    train_input = train_input.repeat()    

    return train_input

### Validation dataset

In [None]:
def eval_dataset_fn(input_context):
    def transform_input_data_fn(tensor_dict):
        image_resizer_config = model_config.ssd.image_resizer
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)

        transform_data_fn = functools.partial(
            inputs.transform_input_data, 
            model_preprocess_fn=detection_model.preprocess,
            image_resizer_fn=image_resizer_fn,
            num_classes=NUM_CLASSES,
            merge_multiple_boxes=False,
            use_multiclass_scores=False,
            retain_original_image=eval_config.retain_original_images,
            retain_original_image_additional_channels=eval_config.retain_original_image_additional_channels
        )

        tensor_dict = inputs.pad_input_data_to_static_shapes(
            tensor_dict=transform_data_fn(tensor_dict),
            max_num_boxes=eval_input_config.max_number_of_boxes,
            num_classes=NUM_CLASSES,
            spatial_image_shape=SHAPE
        )

        return (inputs._get_features_dict(tensor_dict, False),
                inputs._get_labels_dict(tensor_dict))

    eval_input = dataset_builder.build(
        eval_input_config,
        transform_input_data_fn=transform_input_data_fn,
        batch_size=eval_config.batch_size,
        input_context=input_context,
    )

    return eval_input

In [None]:
train_input = strategy.experimental_distribute_datasets_from_function(
    train_dataset_fn
)

train_input_iter = iter(train_input)

eval_input = strategy.experimental_distribute_datasets_from_function(
    eval_dataset_fn
)

### Show training samples

In [None]:
%matplotlib inline

N = 16
fig, ax = plt.subplots(int(np.sqrt(N)), int(np.sqrt(N)), figsize=(12,12))
ax = ax.flatten()

train_dataset = train_dataset_fn(None).unbatch().batch(1)
train_iter = iter(train_dataset)

for idx in range(N):
    features, labels = next(train_iter)
    image = features['image'][0].numpy()
    n_boxes = labels['num_groundtruth_boxes'][0].numpy()
    boxes = labels['groundtruth_boxes'][0, :n_boxes, :].numpy()
    classes = labels['groundtruth_classes'][0, :n_boxes, :].numpy()
    classes = np.argmax(classes, axis=-1) + 1
    
    plot_img_with_boxes((image*255).astype('uint8'), 
                        classes, 
                        boxes,
                        axis=ax[idx])
    
fig.show()

# Training

## Load weights

First I load the fine-tuning checkpoint specified in the `pipeline.config` file (i.e. the model pre-trained on the COCO dataset).

In [None]:
with strategy.scope():
    global_step = tf.Variable(0,
                              trainable=False,
                              dtype=tf.compat.v2.dtypes.int64,
                              name='global_step',
                              aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    
    checkpointed_step = int(global_step.value())
    logged_step = int(global_step.value())
    total_loss = 0

    if train_config.fine_tune_checkpoint:
        load_fine_tune_checkpoint(detection_model,
                                  train_config.fine_tune_checkpoint,
                                  fine_tune_checkpoint_type,
                                  fine_tune_checkpoint_version,
                                  train_input,
                                  unpad_groundtruth_tensors)

Then I restore the latest checkpoint saved in `MODEL_DIR` (if any), in order to resume a previous training run.

In [None]:
with strategy.scope():
    if callable(learning_rate):
        learning_rate_fn = learning_rate
    else:
        learning_rate_fn = lambda: learning_rate

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    
    ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                         model=detection_model,
                                         optimizer=optimizer)

    manager_dir = get_filepath(strategy, MODEL_DIR)

    manager = tf.compat.v2.train.CheckpointManager(ckpt,
                                                   manager_dir,
                                                   max_to_keep=1)

    latest_checkpoint = tf.train.latest_checkpoint(MODEL_DIR)
    ckpt.restore(latest_checkpoint).expect_partial()

## Custom training loop

To implement the training loop I used `eager_train_step` from `model_lib_v2` ([code](https://github.com/tensorflow/models/blob/31e86e86c1e7f4154819e1c52ea0c51b287c2c70/research/object_detection/model_lib_v2.py#L146)), but you can rewrite it on your own, as it is essentially a standard [TF-2 custom training loop](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch), except that I use the OD API evaluator (Pascal VOC metrics, as set in `eval_config`) to compute mean average precision after each epoch.

In [None]:
with strategy.scope():
    def train_step_fn(features, labels):
        loss = eager_train_step(detection_model,
                                features,
                                labels,
                                unpad_groundtruth_tensors,
                                optimizer,
                                learning_rate=learning_rate_fn(),
                                add_regularization_loss=add_regularization_loss,
                                clip_gradients_value=clip_gradients_value,
                                global_step=global_step,
                                num_replicas=REPLICAS)
        global_step.assign_add(1)
        return loss

    def _sample_and_train(strategy, train_step_fn, data_iterator):
        features, labels = data_iterator.next()
        per_replica_losses = strategy.run(train_step_fn, 
                                          args=(features, labels))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses, axis=None)

    @tf.function
    def _dist_train_step(data_iterator):
        return _sample_and_train(strategy, 
                                 train_step_fn, 
                                 data_iterator)

Notice that in the following cell the training steps are performed within `strategy.scope()`, while the evaluation loop is performed outside it. This is intended: unlike the training step, for evaluation I call `eager_eval_loop` directly. If you look at the [code](https://github.com/tensorflow/models/blob/31e86e86c1e7f4154819e1c52ea0c51b287c2c70/research/object_detection/model_lib_v2.py#L775), you will notice that this function uses an approach similar to the one in the previous cell. However, GPU usage during evaluation is very low compared with the training loop. A more efficient, time-saving approach would therefore be to write evaluation functions similar to those used for training, at the cost of monitoring the loss (or other TensorFlow metrics) instead of the mAP score.
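Stripped of the TensorFlow specifics, the control flow of the loop is: run training steps, evaluate at every epoch boundary, keep the best score, and stop after a number of epochs without improvement. The sketch below shows only that skeleton. `train_step` and `evaluate` are hypothetical stand-ins for the distributed training step and `eager_eval_loop`, and the scores are scripted for the demo.

```python
def run_training(train_step, evaluate, num_steps, steps_per_epoch, patience):
    """Run `train_step` repeatedly, evaluating and early-stopping once per epoch.

    `train_step()` returns a loss and `evaluate()` returns a metric where
    higher is better; both are hypothetical stand-ins for the real calls.
    """
    best_metric, not_improved, epoch_loss = 0.0, 0, 0.0
    history = []
    for step in range(1, num_steps + 1):
        epoch_loss += train_step()
        if step % steps_per_epoch != 0:
            continue
        metric = evaluate()
        history.append((step // steps_per_epoch, epoch_loss / steps_per_epoch, metric))
        if metric > best_metric:
            best_metric, not_improved = metric, 0  # A checkpoint would be saved here.
        else:
            not_improved += 1
        if not_improved >= patience:
            break
        epoch_loss = 0.0
    return best_metric, history


# Stand-ins: a constant loss and a scripted sequence of validation scores.
scores = iter([0.10, 0.20, 0.15, 0.18, 0.30])
best, history = run_training(train_step=lambda: 1.0,
                             evaluate=lambda: next(scores),
                             num_steps=50, steps_per_epoch=10, patience=2)
print(best, history)
```

With a patience of 2, the run stops after the fourth epoch because the score did not beat the best value for two epochs in a row, so the later, higher score is never reached. That is the usual trade-off when choosing a small patience.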

In [None]:
last_step_time = time.time()

for _ in range(global_step.value(), train_steps):
    with strategy.scope():
        loss = _dist_train_step(train_input_iter)
        time_taken = time.time() - last_step_time
        last_step_time = time.time()
        steps_per_sec = 1.0 / time_taken
        steps_per_sec_list.append(steps_per_sec)
        total_loss += loss
        
    if int(global_step.value()) % STEPS_PER_EPOCH == 0:
        if not RUN_EVAL:
            print('Epoch {} [ETA {:.2f}s] loss={:.3f}'.format(
                  int(global_step.value()) // STEPS_PER_EPOCH,
                  time_taken * STEPS_PER_EPOCH,
                  total_loss / STEPS_PER_EPOCH))
        else:
            eval_global_step = tf.compat.v2.Variable(0, 
                                                     trainable=False,
                                                     dtype=tf.compat.v2.dtypes.int64)

            eval_metrics = eager_eval_loop(detection_model,
                                           configs,
                                           eval_input,
                                           global_step=eval_global_step)

            print('Epoch {} [ETA {:.2f}s] loss={:.3f} mAP@.5={:.3f}'.format(
                  int(global_step.value()) // STEPS_PER_EPOCH,
                  time_taken * STEPS_PER_EPOCH,
                  total_loss / STEPS_PER_EPOCH,
                  eval_metrics[MONITOR_METRIC]))

            if eval_metrics[MONITOR_METRIC] > best_metric_value:
                best_metric_value = eval_metrics[MONITOR_METRIC]
                manager.save()
                not_improved = 0
            else:
                not_improved += 1

            if not_improved >= ES_PATIENCE:
                print(f"Early stopping at epoch {int(global_step.value()) // STEPS_PER_EPOCH}")
                break
            
        total_loss = 0

clean_temporary_directories(strategy, manager_dir)

# Export model for inference

We first load the latest checkpoint saved during training, then we export the whole detection model (pre-processing and post-processing included) in TF2 OD-API style.

In [None]:
class DetectionFromImageModule(DetectionInferenceModule):
    def __init__(self, detection_model):
        
        sig = [tf.TensorSpec(shape=[1, None, None, 3],
                             dtype=tf.uint8,
                             name='input_tensor')]

        def call_func(input_tensor):
            return self._run_inference_on_images(input_tensor)

        self.__call__ = tf.function(call_func, input_signature=sig)

        super(DetectionFromImageModule, self).__init__(detection_model)
        
    def _run_inference_on_images(self, image, **kwargs):
        label_id_offset = 1
        image = tf.cast(image, tf.float32)
        image, shapes = self._model.preprocess(image)
        prediction_dict = self._model.predict(image, shapes, **kwargs)
        detections = self._model.postprocess(prediction_dict, shapes)
        classes_field = fields.DetectionResultFields.detection_classes
        classes = tf.cast(detections[classes_field], tf.float32)
        detections[classes_field] = (classes + label_id_offset)

        for key, val in detections.items():
            detections[key] = tf.cast(val, tf.float32)

        return detections

In [None]:
ckpt = tf.train.Checkpoint(model=detection_model)
manager = tf.train.CheckpointManager(ckpt, 
                                     MODEL_DIR,
                                     max_to_keep=1)

status = ckpt.restore(manager.latest_checkpoint).expect_partial()

detection_module = DetectionFromImageModule(detection_model)
concrete_function = detection_module.__call__.get_concrete_function()
status.assert_existing_objects_matched()

exported_checkpoint_manager = tf.train.CheckpointManager(ckpt, 
                                                         OUTPUT_MODEL_DIR, 
                                                         max_to_keep=1)

exported_checkpoint_manager.save(checkpoint_number=0)
tf.saved_model.save(detection_module,
                    OUTPUT_MODEL_DIR + '/saved_model',
                    signatures=concrete_function)

# Inference from saved model

Here we test the saved model by running it on a few examples and comparing the detected boxes (coloured) with the ground-truth boxes (in black).

In [None]:
detector = tf.saved_model.load('/kaggle/working/saved_model/saved_model/')

## Plot detected boxes

In [None]:
%matplotlib inline

N = 16
fig, ax = plt.subplots(int(np.sqrt(N)), int(np.sqrt(N)), figsize=(12,12))
ax = ax.flatten()

raw_image_dataset = tf.data.TFRecordDataset(TEST_DATASET)
parsed_image_dataset = raw_image_dataset.map(parse_image_sample)
iterator = iter(parsed_image_dataset)

for idx in range(N):
    image_features = next(iterator)
    image_raw = image_features['image/encoded']
    image = tf.image.decode_jpeg(image_raw)
    gt_boxes = np.stack([
        image_features['image/object/bbox/xmin'],
        image_features['image/object/bbox/ymin'],
        image_features['image/object/bbox/xmax'],
        image_features['image/object/bbox/ymax'],
    ], -1)
    
    out = detector(tf.expand_dims(image, 0))
    
    classes = out['detection_classes'].numpy()[0].astype('int')
    scores = out['detection_scores'].numpy()[0]
    boxes = out['detection_boxes'].numpy()[0]
    boxes = np.stack([
        boxes[:,1],
        boxes[:,0],
        boxes[:,3],
        boxes[:,2]        
    ], -1)
    
    image_with_gt = image.numpy()
    
    viz_utils.draw_bounding_boxes_on_image_array(image_with_gt, 
                                                 gt_boxes,
                                                 color='black')
    
    plot_img_with_boxes(image_with_gt, 
                        classes, 
                        boxes,
                        scores,
                        axis=ax[idx])
    
fig.show()
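The OD API returns `detection_boxes` as normalized `[ymin, xmin, ymax, xmax]`, and the exported module casts the classes to `float32`. If the detections are needed outside the OD API visualization utilities, for instance as pixel coordinates, a post-processing step along these lines can be used. This is a minimal sketch with a hypothetical helper name and made-up detector output.

```python
def to_pixel_boxes(boxes, scores, classes, image_size, score_threshold=0.5):
    """Keep confident detections and convert their boxes to pixel coordinates.

    `boxes` are normalized [ymin, xmin, ymax, xmax]; the result uses
    [xmin, ymin, xmax, ymax] in pixels for an image of (height, width).
    """
    height, width = image_size
    kept = []
    for (ymin, xmin, ymax, xmax), score, label in zip(boxes, scores, classes):
        if score < score_threshold:
            continue
        pixel_box = [round(xmin * width), round(ymin * height),
                     round(xmax * width), round(ymax * height)]
        kept.append((int(label), float(score), pixel_box))
    return kept


# Hypothetical detector output for one 512x512 image.
boxes = [[0.25, 0.5, 0.75, 1.0], [0.0, 0.0, 0.5, 0.5]]
scores = [0.9, 0.3]
classes = [3.0, 7.0]
detections = to_pixel_boxes(boxes, scores, classes, image_size=(512, 512))
print(detections)
```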

# Clean Environment

In [None]:
!rm -r /kaggle/working/models
!rm -r /kaggle/working/chest-x-ray-detection