In [1]:
import os
import sys
import tempfile
from ast import literal_eval

import numpy as np
import pandas as pd
import tensorflow as tf

import fastestimator as fe
from fastestimator.architecture.retinanet import RetinaNet, get_fpn_anchor_box, get_target
from fastestimator.dataset.mscoco import load_data
from fastestimator.op import NumpyOp, TensorOp
from fastestimator.op.numpyop import ImageReader, ResizeImageAndBbox, TypeConverter
from fastestimator.op.tensorop import Loss, ModelOp, Pad, Rescale
from fastestimator.trace import ModelSaver

In [2]:
# Locate the MSCOCO dataset under the given root directory.
# `path` is reused later as the image parent directory and as the base of the TFRecord save directory.
# NOTE(review): train_csv / val_csv are presumably the annotation CSV paths — confirm against load_data.
train_csv, val_csv, path = load_data(path='/data/hsiming/dataset/')

In [3]:
class String2List(NumpyOp):
    """Parse string-encoded Python literals into Python objects, e.g. '[1, 2, 3]' -> [1, 2, 3]."""
    def forward(self, data, state):
        """Evaluate every element of `data` with `ast.literal_eval`.

        Args:
            data: Iterable of strings, each holding a Python literal (one entry per input key).
            state: Unused.

        Returns:
            list: The parsed values, in the same order as `data`.
        """
        # Build a real list rather than returning a lazy `map` object: a map iterator can only be
        # consumed once and supports neither len() nor indexing.
        return [literal_eval(elem) for elem in data]
    

class GenerateTarget(NumpyOp):
    """Turn ground-truth labels and boxes into anchor-based training targets via `get_target`.

    Args:
        inputs: Input key(s) of the op.
        outputs: Output key(s) of the op.
        mode: Mode(s) in which the op runs.
        input_shape (tuple): Image shape used to generate the FPN anchor boxes. Defaults to
            (512, 512, 3), the value that was previously hard-coded.
    """
    def __init__(self, inputs=None, outputs=None, mode=None, input_shape=(512, 512, 3)):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.input_shape = input_shape
        # Anchors depend only on the input shape, so they are generated once here.
        # The second return value of get_fpn_anchor_box is not needed by this op.
        self.anchorbox, _ = get_fpn_anchor_box(input_shape=input_shape)

    def forward(self, data, state):
        """Compute the targets for one sample.

        Args:
            data: Sequence of (obj_label, x1, y1, width, height).
            state: Unused.

        Returns:
            tuple: (cls_gt, x1_gt, y1_gt, w_gt, h_gt) as produced by `get_target`.
        """
        obj_label, x1, y1, width, height = data
        cls_gt, x1_gt, y1_gt, w_gt, h_gt = get_target(self.anchorbox, obj_label, x1, y1, width, height)
        return cls_gt, x1_gt, y1_gt, w_gt, h_gt


class RetinaLoss(Loss):
    """RetinaNet loss op: per-example focal loss (classification) plus smooth L1 loss (box regression)."""
    def focal_loss(self, cls_gt_example, cls_pred_example, alpha=0.25, gamma=2.0):
        """Return the focal loss of a single example.

        Args:
            cls_gt_example (Tensor): Class target per anchor, shape [A]. Values >= 0 are treated as
                objects, -1 as background; anything below -1 is discarded.
            cls_pred_example (Tensor): Predicted class scores, shape [A, K].
            alpha (float): Weight applied to positive entries; negative entries get 1 - alpha.
            gamma (float): Exponent applied to the focal weight.

        Returns:
            tuple: (scalar focal loss tensor, indices of the object anchors as returned by tf.where).
        """
        num_classes = cls_pred_example.shape[-1]
        # gather the objects and background, discard the rest
        anchor_obj_idx = tf.where(tf.greater_equal(cls_gt_example, 0))
        anchor_obj_bg_idx = tf.where(tf.greater_equal(cls_gt_example, -1))
        anchor_obj_count = tf.cast(tf.shape(anchor_obj_idx)[0], tf.float32)
        # Clamp the normalizer so an example without any object anchor yields a finite loss
        # instead of dividing by zero.
        normalizer = tf.maximum(anchor_obj_count, 1.0)
        cls_gt_example = tf.one_hot(cls_gt_example, num_classes)
        cls_gt_example = tf.gather_nd(cls_gt_example, anchor_obj_bg_idx)
        cls_pred_example = tf.gather_nd(cls_pred_example, anchor_obj_bg_idx)
        # flatten to one binary decision per (anchor, class) pair
        cls_gt_example = tf.reshape(cls_gt_example, (-1, 1))
        cls_pred_example = tf.reshape(cls_pred_example, (-1, 1))
        # compute the focal weight on each selected anchor box
        alpha_factor = tf.ones_like(cls_gt_example) * alpha
        alpha_factor = tf.where(tf.equal(cls_gt_example, 1), alpha_factor, 1 - alpha_factor)
        focal_weight = tf.where(tf.equal(cls_gt_example, 1), 1 - cls_pred_example, cls_pred_example)
        focal_weight = alpha_factor * focal_weight**gamma / normalizer
        cls_loss = tf.losses.BinaryCrossentropy(reduction='sum')(cls_gt_example,
                                                                 cls_pred_example,
                                                                 sample_weight=focal_weight)
        return cls_loss, anchor_obj_idx

    def smooth_l1(self, loc_gt_example, loc_pred_example, anchor_obj_idx, beta=0.1):
        """Return the smooth L1 loss for box regression of a single example.

        Args:
            loc_gt_example (Tensor): Box targets of shape (padded_length, 4); only the first
                `anchor_obj_count` rows are used, the remaining rows are treated as padding.
            loc_pred_example (Tensor): Box predictions of shape (num_anchors, 4).
            anchor_obj_idx (Tensor): Indices of the object anchors (from `focal_loss`).
            beta (float): Absolute difference below which the quadratic branch is used.

        Returns:
            Tensor: Scalar smooth L1 loss, normalized by the number of object anchors.
        """
        loc_pred = tf.gather_nd(loc_pred_example, anchor_obj_idx)  # anchor_obj_count x 4
        anchor_obj_count = tf.shape(loc_pred)[0]
        loc_gt = loc_gt_example[:anchor_obj_count]  # anchor_obj_count x 4
        loc_gt = tf.reshape(loc_gt, (-1, 1))
        loc_pred = tf.reshape(loc_pred, (-1, 1))
        loc_diff = tf.abs(loc_gt - loc_pred)
        cond = tf.less(loc_diff, beta)
        smooth_l1_loss = tf.where(cond, 0.5 * loc_diff**2 / beta, loc_diff - 0.5 * beta)
        # Clamp the normalizer: with zero object anchors the sum is 0 and 0 / 0 would be NaN.
        normalizer = tf.maximum(tf.cast(anchor_obj_count, tf.float32), 1.0)
        smooth_l1_loss = tf.reduce_sum(smooth_l1_loss) / normalizer
        return smooth_l1_loss

    def forward(self, data, state):
        """Compute the losses for every example of the local batch.

        Args:
            data: Sequence of (cls_gt, x1_gt, y1_gt, w_gt, h_gt, cls_pred, loc_pred), each batched
                along axis 0.
            state (dict): Must provide "local_batch_size", the number of examples to iterate over.

        Returns:
            tuple: (total_loss, focal_loss, l1_loss), each stacked to one value per example.
        """
        cls_gt, x1_gt, y1_gt, w_gt, h_gt, cls_pred, loc_pred = data
        local_batch_size = state["local_batch_size"]
        focal_loss = []
        l1_loss = []
        for idx in range(local_batch_size):
            # stack gives (4, N); transpose to (N, 4) so each row is one box target
            loc_gt_example = tf.transpose(tf.stack([x1_gt[idx], y1_gt[idx], w_gt[idx], h_gt[idx]]))
            focal_loss_example, anchor_obj_idx = self.focal_loss(cls_gt[idx], cls_pred[idx])
            smooth_l1_loss_example = self.smooth_l1(loc_gt_example, loc_pred[idx], anchor_obj_idx)
            focal_loss.append(focal_loss_example)
            l1_loss.append(smooth_l1_loss_example)
        focal_loss = tf.stack(focal_loss)
        l1_loss = tf.stack(l1_loss)
        total_loss = focal_loss + l1_loss

        return total_loss, focal_loss, l1_loss

In [4]:
class PredictBox(TensorOp):
    """Decode network outputs into boxes: per-level top-k, score threshold, delta decoding and NMS.

    Args:
        inputs: Input key(s) of the op.
        outputs: Output key(s) of the op.
        mode: Mode(s) in which the op runs.
        input_shape (tuple): Image shape used to generate anchors and to clip the final boxes.
        select_top_k (int): Number of highest-scoring anchors kept per pyramid level.
        nms_max_outputs (int): Maximum number of boxes kept by non-max suppression per image.
        score_threshold (float): Anchors scoring at or below this value are dropped. Defaults to
            0.05, the value that was previously hard-coded.
    """
    def __init__(self,
                 inputs=None,
                 outputs=None,
                 mode=None,
                 input_shape=(512, 512, 3),
                 select_top_k=1000,
                 nms_max_outputs=100,
                 score_threshold=0.05):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.input_shape = input_shape
        self.select_top_k = tf.cast(select_top_k, dtype=tf.int32)
        self.nms_max_outputs = nms_max_outputs
        self.score_threshold = score_threshold

        all_anchors, num_anchors_per_level = get_fpn_anchor_box(input_shape=input_shape)
        self.all_anchors = tf.convert_to_tensor(all_anchors)
        self.num_anchors_per_level = tf.convert_to_tensor(num_anchors_per_level, dtype=tf.int32)

    def index_to_bool(self, indices, length):
        """Convert a 1-D tensor of indices into a boolean mask of size `length`.

        Args:
            indices (Tensor): Positions to mark as selected.
            length: Size of the resulting mask (Python int or scalar tensor).

        Returns:
            Tensor: Boolean tensor of shape [length], True at the given indices.
        """
        updates = tf.ones_like(indices, dtype=tf.bool)
        shape = tf.expand_dims(length, 0)
        is_selected = tf.scatter_nd(tf.cast(tf.expand_dims(indices, axis=-1), dtype=tf.int32), updates, shape)
        return is_selected

    def forward(self, data, state):
        """Produce predicted and ground-truth boxes for every image of the local batch.

        Args:
            data: Sequence of (cls_pred, deltas, label_gt, x1_gt, y1_gt, w_gt, h_gt), batched along axis 0.
            state (dict): Must provide 'local_batch_size', the number of images to iterate over.

        Returns:
            tuple: (pred, gt), both concatenated over all images along axis 0.
                pred rows are [image index, x1, y1, w, h, label, score];
                gt rows are [image index, x1, y1, w, h, label].
        """
        pred = []
        gt = []

        # extract max score and its class label
        cls_pred, deltas, label_gt, x1_gt, y1_gt, w_gt, h_gt = data
        labels = tf.cast(tf.argmax(cls_pred, axis=2), dtype=tf.int32)
        scores = tf.reduce_max(cls_pred, axis=2)

        # iterate over image
        for i in range(state['local_batch_size']):
            labels_per_image = labels[i]
            scores_per_image = scores[i]
            deltas_per_image = deltas[i]

            # keep only ground-truth entries with a positive label
            # (assumes non-positive labels are padding — TODO confirm)
            keep_gt = label_gt[i] > 0
            label_gt_per_image = label_gt[i][keep_gt]
            x1_gt_per_image = x1_gt[i][keep_gt]
            y1_gt_per_image = y1_gt[i][keep_gt]
            w_gt_per_image = w_gt[i][keep_gt]
            h_gt_per_image = h_gt[i][keep_gt]

            # empty accumulators, extended once per pyramid level
            selected_deltas_per_image = tf.constant([], shape=(0, 4))
            selected_labels_per_image = tf.constant([], dtype=tf.int32)
            selected_scores_per_image = tf.constant([])
            selected_anchor_indices_per_image = tf.constant([], dtype=tf.int32)

            end_index = 0
            # iterate over each pyramid level
            for j in range(self.num_anchors_per_level.shape[0]):
                start_index = end_index
                end_index += self.num_anchors_per_level[j]
                anchor_indices = tf.range(start_index, end_index, dtype=tf.int32)

                level_scores = scores_per_image[start_index:end_index]
                level_deltas = deltas_per_image[start_index:end_index]
                level_labels = labels_per_image[start_index:end_index]

                # select top k
                if self.num_anchors_per_level[j] >= self.select_top_k:
                    # won't work without the tf.minimum
                    top_k = tf.math.top_k(level_scores, tf.minimum(self.num_anchors_per_level[j], self.select_top_k))
                    top_k_scores = top_k.values
                    top_k_indices = tf.add(top_k.indices, [start_index])
                else:
                    top_k_scores = level_scores
                    top_k_indices = anchor_indices

                # filter out low score
                is_high_score = tf.greater(top_k_scores, self.score_threshold)
                selected_indices = tf.boolean_mask(top_k_indices, is_high_score)
                # the mask is relative to this level, hence the shift back by start_index
                is_selected = self.index_to_bool(tf.subtract(selected_indices, [start_index]),
                                                 self.num_anchors_per_level[j])

                # combine all pyramid levels
                selected_deltas_per_image = tf.concat(
                    [selected_deltas_per_image, tf.boolean_mask(level_deltas, is_selected)], axis=0)
                selected_scores_per_image = tf.concat(
                    [selected_scores_per_image, tf.boolean_mask(level_scores, is_selected)], axis=0)
                selected_labels_per_image = tf.concat(
                    [selected_labels_per_image, tf.boolean_mask(level_labels, is_selected)], axis=0)
                selected_anchor_indices_per_image = tf.concat(
                    [selected_anchor_indices_per_image, tf.boolean_mask(anchor_indices, is_selected)], axis=0)

            # delta -> (x1, y1, w, h)
            anchor_mask = self.index_to_bool(selected_anchor_indices_per_image, self.all_anchors.shape[0])
            # gather the selected anchors once instead of re-masking for every coordinate;
            # anchor columns 0..3 are used as x1, y1, width, height
            selected_anchors = tf.boolean_mask(self.all_anchors, anchor_mask)
            x1 = selected_deltas_per_image[:, 0] * selected_anchors[:, 2] + selected_anchors[:, 0]
            y1 = selected_deltas_per_image[:, 1] * selected_anchors[:, 3] + selected_anchors[:, 1]
            w = tf.math.exp(selected_deltas_per_image[:, 2]) * selected_anchors[:, 2]
            h = tf.math.exp(selected_deltas_per_image[:, 3]) * selected_anchors[:, 3]
            x2 = x1 + w
            y2 = y1 + h

            # nms (boxes are passed in [y1, x1, y2, x2] order)
            boxes_per_image = tf.stack([y1, x1, y2, x2], axis=1)
            nms_indices = tf.image.non_max_suppression(boxes_per_image, selected_scores_per_image, self.nms_max_outputs)

            nms_boxes = tf.gather(boxes_per_image, nms_indices)
            final_scores = tf.gather(selected_scores_per_image, nms_indices)
            final_labels = tf.gather(selected_labels_per_image, nms_indices)

            # clip corners to the image, then convert back to (x1, y1, w, h)
            x1 = tf.clip_by_value(nms_boxes[:, 1], clip_value_min=0, clip_value_max=self.input_shape[1])
            y1 = tf.clip_by_value(nms_boxes[:, 0], clip_value_min=0, clip_value_max=self.input_shape[0])
            w = tf.clip_by_value(nms_boxes[:, 3], clip_value_min=0, clip_value_max=self.input_shape[1]) - x1
            h = tf.clip_by_value(nms_boxes[:, 2], clip_value_min=0, clip_value_max=self.input_shape[0]) - y1

            final_boxes = tf.stack([x1, y1, w, h], axis=1)

            # combine image results into batch; the left-padded column holds the image index i
            image_results = tf.concat([
                tf.pad(final_boxes, [[0, 0], [1, 0]], constant_values=i),
                tf.cast(tf.expand_dims(final_labels, axis=1), dtype=tf.float32),
                tf.expand_dims(final_scores, axis=1)
            ],
                                      axis=1)

            # rows are built as (6, N) and transposed to (N, 6)
            image_gt = tf.transpose(
                tf.concat([
                    tf.stack([i * tf.ones_like(x1_gt_per_image), x1_gt_per_image]),
                    tf.expand_dims(y1_gt_per_image, axis=0),
                    tf.expand_dims(w_gt_per_image, axis=0),
                    tf.expand_dims(h_gt_per_image, axis=0),
                    tf.expand_dims(label_gt_per_image, axis=0)
                ],
                          axis=0))
            pred.append(image_results)
            gt.append(image_gt)

        return tf.concat(pred, axis=0), tf.concat(gt, axis=0)

In [5]:
# Directory passed to ModelSaver below as the checkpoint location (machine-specific absolute path).
model_dir = '/data/hsiming/mscoco_model/'
# Offline preprocessing: read images, parse the string-encoded annotations, resize, build anchor targets,
# cast dtypes and write the listed features as GZIP-compressed records under <path>/retinanet_coco_all.
# NOTE(review): the CSV paths below are hard-coded although load_data already returned train_csv / val_csv
# in an earlier cell — confirm they point to the same files and prefer the variables.
writer = fe.RecordWriter(
    save_dir=os.path.join(path, "retinanet_coco_all"),
    train_data='/data/hsiming/dataset/MSCOCO2017/train_object.csv',
    validation_data='/data/hsiming/dataset/MSCOCO2017/val_object.csv',
    ops=[
        ImageReader(inputs="image", parent_path=path, outputs="image"),
        String2List(inputs=["x1", "y1", "width", "height", "obj_label"],
                    outputs=["x1", "y1", "width", "height", "obj_label"]),
        # target_size matches the (512, 512, 3) input shape used for the anchors and the model
        ResizeImageAndBbox(target_size=(512, 512),
                           keep_ratio=True,
                           inputs=["image", "x1", "y1", "width", "height"],
                           outputs=["image", "x1", "y1", "width", "height"]),
        GenerateTarget(inputs=("obj_label", "x1", "y1", "width", "height"),
                       outputs=("cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt")),
        TypeConverter(target_type='int32', inputs=["id", "cls_gt"], outputs=["id", "cls_gt"]),
        TypeConverter(target_type='float32',
                      inputs=["x1_gt", "y1_gt", "w_gt", "h_gt"],
                      outputs=["x1_gt", "y1_gt", "w_gt", "h_gt"])
    ],
    compression="GZIP",
    write_feature=[
        "image", "id", "cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"
    ])

# prepare pipeline
pipeline = fe.Pipeline(
    batch_size=8,
    data=writer,
    ops=[
        Rescale(inputs="image", outputs="image"),
        # Pad the variable-length per-image features to a common length so they can be batched.
        # NOTE(review): 190 is a magic number — presumably an upper bound on entries per image; confirm.
        Pad(padded_shape=[190],
            inputs=["x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"],
            outputs=["x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"])
    ])

# prepare network
model = fe.build(model_def=lambda: RetinaNet(input_shape=(512, 512, 3), num_classes=90),
                 model_name="retinanet",
                 optimizer=tf.optimizers.Adam(learning_rate=0.0002),
                 loss_name="total_loss")
# PredictBox is restricted to mode="eval"; RetinaLoss has no mode argument and provides the
# "total_loss" output named as loss_name above.
network = fe.Network(ops=[
    ModelOp(inputs="image", model=model, outputs=["cls_pred", "loc_pred"]),
    PredictBox(inputs=["cls_pred", "loc_pred", "obj_label", "x1", "y1", "width", "height"],
               outputs=("pred", "gt"),
               mode="eval"),
    RetinaLoss(inputs=("cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt", "cls_pred", "loc_pred"),
               outputs=("total_loss", "focal_loss", "l1_loss"))
])

# prepare estimator
# The commented-out keyword arguments below are left disabled, as in the original run.
estimator = fe.Estimator(
    network=network,
    pipeline=pipeline,
    epochs=80,
    #steps_per_epoch=2,
    #log_steps=1,
    #validation_steps=2,
    traces=ModelSaver(model_name="retinanet", save_dir=model_dir, save_best=True))

In [None]:
# Start the run; the captured log shows the TFRecord conversion first, followed by the training steps.
estimator.fit()

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator: Saving tfrecord to /data/hsiming/dataset/MSCOCO2017/retinanet_coco_all
FastEstimator: Converting Train TFRecords 0.0%, Speed: 0.00 record/sec
FastEstimator: Converting Train TFRecords 5.0%, Speed: 72.80 record/sec
FastEstimator: Converting Train TFRecords 10.0%, Speed: 75.88 record/sec
FastEstimator: Converting Train TFRecords 15.0%, Speed: 73.63 record/sec
FastEstimator: Converting Train TFRecords 20.0%, Speed: 74.70 record/sec
FastEstimator: Converting Train TFRecords 25.0%, Speed: 73.34 record/sec
FastEstimator: Converting Train TFRecord

FastEstimator-Train: step: 3200; focal_loss: 0.8798137; l1_loss: 0.7627349; total_loss: 1.6425486; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3300; focal_loss: 0.8194082; l1_loss: 0.4866818; total_loss: 1.3060899; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3400; focal_loss: 0.795442; l1_loss: 0.5267133; total_loss: 1.3221552; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3500; focal_loss: 0.8672877; l1_loss: 0.5424381; total_loss: 1.4097258; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3600; focal_loss: 0.8193113; l1_loss: 0.3956013; total_loss: 1.2149127; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3700; focal_loss: 0.833902; l1_loss: 0.4563558; total_loss: 1.2902578; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3800; focal_loss: 0.8186311; l1_loss: 0.5018157; total_loss: 1.3204467; examples/sec: 33.3; progress: 0.3%; 
FastEstimator-Train: step: 3900; focal_loss: 0.769