Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions object_detection/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ py_library(
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:shape_utils",
"//tensorflow_models/object_detection/utils:static_shape",
],
)
Expand Down
35 changes: 18 additions & 17 deletions object_detection/core/box_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from abc import abstractmethod
import tensorflow as tf
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import static_shape

slim = tf.contrib.slim
Expand Down Expand Up @@ -316,6 +317,8 @@ def __init__(self,
self._predict_instance_masks = predict_instance_masks
self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._predict_keypoints = predict_keypoints
if self._predict_instance_masks:
raise ValueError('Mask prediction is unimplemented.')
if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.')
if ((self._predict_instance_masks or self._predict_keypoints) and
Expand Down Expand Up @@ -524,23 +527,21 @@ def _predict(self, image_features, num_predictions_per_location):
class_predictions_with_background = tf.sigmoid(
class_predictions_with_background)

batch_size = static_shape.get_batch_size(image_features.get_shape())
if batch_size is None:
features_height = static_shape.get_height(image_features.get_shape())
features_width = static_shape.get_width(image_features.get_shape())
flattened_predictions_size = (features_height * features_width *
num_predictions_per_location)
box_encodings = tf.reshape(
box_encodings,
[-1, flattened_predictions_size, 1, self._box_code_size])
class_predictions_with_background = tf.reshape(
class_predictions_with_background,
[-1, flattened_predictions_size, num_class_slots])
else:
box_encodings = tf.reshape(
box_encodings, [batch_size, -1, 1, self._box_code_size])
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [batch_size, -1, num_class_slots])
combined_feature_map_shape = shape_utils.combined_static_and_dynamic_shape(
image_features)
box_encodings = tf.reshape(
box_encodings, tf.stack([combined_feature_map_shape[0],
combined_feature_map_shape[1] *
combined_feature_map_shape[2] *
num_predictions_per_location,
1, self._box_code_size]))
class_predictions_with_background = tf.reshape(
class_predictions_with_background,
tf.stack([combined_feature_map_shape[0],
combined_feature_map_shape[1] *
combined_feature_map_shape[2] *
num_predictions_per_location,
num_class_slots]))
return {BOX_ENCODINGS: box_encodings,
CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background}
15 changes: 7 additions & 8 deletions object_detection/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,24 @@ def provide_groundtruth(self,
fields.BoxListFields.keypoints] = groundtruth_keypoints_list

@abstractmethod
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a foreign checkpoint into tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.

Loads variables from a different tensorflow graph (typically feature
extractor variables). This enables the model to initialize based on weights
from another task. For example, the feature extractor variables from a
Returns a map of variable names to load from a checkpoint to variables in
the model graph. This enables the model to initialize based on weights from
another task. For example, the feature extractor variables from a
classification model can be used to bootstrap training of an object
detector. When loading from an object detection model, the checkpoint model
should have the same parameters as this detection model with exception of
the num_classes parameter.

Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.

Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
pass
128 changes: 71 additions & 57 deletions object_detection/core/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes,
change_coordinate_frame=False,
num_valid_boxes=None,
masks=None,
scope=None):
scope=None,
parallel_iterations=32):
"""Multi-class version of non maximum suppression that operates on a batch.

This op is similar to `multiclass_non_max_suppression` but operates on a batch
Expand Down Expand Up @@ -208,63 +209,74 @@ def batch_multiclass_non_max_suppression(boxes,
float32 tensor containing box masks. `q` can be either number of classes
or 1 depending on whether a separate mask is predicted per class.
scope: tf scope name.
parallel_iterations: (optional) number of batch items to process in
parallel.

Returns:
A dictionary containing the following entries:
'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
containing the non-max suppressed boxes.
'detection_scores': A [bath_size, max_detections] float32 tensor containing
'nmsed_scores': A [batch_size, max_detections] float32 tensor containing
the scores for the boxes.
'detection_classes': A [batch_size, max_detections] float32 tensor
'nmsed_classes': A [batch_size, max_detections] float32 tensor
containing the class for boxes.
'num_detections': A [batchsize] float32 tensor indicating the number of
'nmsed_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box. This is set to None if input
`masks` is None.
'num_detections': A [batch_size] int32 tensor indicating the number of
valid detections per batch item. Only the top num_detections[i] entries in
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
entries are zero paddings.
'detection_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box.

Raises:
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
a valid scores field.
ValueError: if `q` in boxes.shape is not 1 or not equal to number of
classes as inferred from scores.shape.
"""
q = boxes.shape[2].value
num_classes = scores.shape[2].value
if q != 1 and q != num_classes:
raise ValueError('third dimension of boxes must be either 1 or equal '
'to the third dimension of scores')

original_masks = masks
with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
per_image_boxes_list = tf.unstack(boxes)
per_image_scores_list = tf.unstack(scores)
num_valid_boxes_list = len(per_image_boxes_list) * [None]
per_image_masks_list = len(per_image_boxes_list) * [None]
if num_valid_boxes is not None:
num_valid_boxes_list = tf.unstack(num_valid_boxes)
if masks is not None:
per_image_masks_list = tf.unstack(masks)
boxes_shape = boxes.shape
batch_size = boxes_shape[0].value
num_anchors = boxes_shape[1].value

if batch_size is None:
batch_size = tf.shape(boxes)[0]
if num_anchors is None:
num_anchors = tf.shape(boxes)[1]

# If num valid boxes aren't provided, create one and mark all boxes as
# valid.
if num_valid_boxes is None:
num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors

detection_boxes_list = []
detection_scores_list = []
detection_classes_list = []
num_detections_list = []
detection_masks_list = []
for (per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes
) in zip(per_image_boxes_list, per_image_scores_list,
per_image_masks_list, num_valid_boxes_list):
if num_valid_boxes is not None:
per_image_boxes = tf.reshape(
tf.slice(per_image_boxes, 3*[0],
tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4])
per_image_scores = tf.reshape(
tf.slice(per_image_scores, [0, 0],
tf.stack([num_valid_boxes, -1])), [-1, num_classes])
if masks is not None:
per_image_masks = tf.reshape(
tf.slice(per_image_masks, 4*[0],
tf.stack([num_valid_boxes, -1, -1, -1])),
[-1, q, masks.shape[3].value, masks.shape[4].value])
# If masks aren't provided, create dummy masks so we can only have one copy
# of single_image_nms_fn and discard the dummy masks after map_fn.
if masks is None:
masks_shape = tf.stack([batch_size, num_anchors, 1, 0, 0])
masks = tf.zeros(masks_shape)

def single_image_nms_fn(args):
"""Runs NMS on a single image and returns padded output."""
(per_image_boxes, per_image_scores, per_image_masks,
per_image_num_valid_boxes) = args
per_image_boxes = tf.reshape(
tf.slice(per_image_boxes, 3 * [0],
tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
per_image_scores = tf.reshape(
tf.slice(per_image_scores, [0, 0],
tf.stack([per_image_num_valid_boxes, -1])),
[-1, num_classes])

per_image_masks = tf.reshape(
tf.slice(per_image_masks, 4 * [0],
tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
[-1, q, per_image_masks.shape[2].value,
per_image_masks.shape[3].value])
nmsed_boxlist = multiclass_non_max_suppression(
per_image_boxes,
per_image_scores,
Expand All @@ -275,24 +287,26 @@ def batch_multiclass_non_max_suppression(boxes,
masks=per_image_masks,
clip_window=clip_window,
change_coordinate_frame=change_coordinate_frame)
num_detections_list.append(tf.to_float(nmsed_boxlist.num_boxes()))
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
max_total_size)
detection_boxes_list.append(padded_boxlist.get())
detection_scores_list.append(
padded_boxlist.get_field(fields.BoxListFields.scores))
detection_classes_list.append(
padded_boxlist.get_field(fields.BoxListFields.classes))
if masks is not None:
detection_masks_list.append(
padded_boxlist.get_field(fields.BoxListFields.masks))
num_detections = nmsed_boxlist.num_boxes()
nmsed_boxes = padded_boxlist.get()
nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections]

nms_dict = {
'detection_boxes': tf.stack(detection_boxes_list),
'detection_scores': tf.stack(detection_scores_list),
'detection_classes': tf.stack(detection_classes_list),
'num_detections': tf.stack(num_detections_list)
}
if masks is not None:
nms_dict['detection_masks'] = tf.stack(detection_masks_list)
return nms_dict
(batch_nmsed_boxes, batch_nmsed_scores,
batch_nmsed_classes, batch_nmsed_masks,
batch_num_detections) = tf.map_fn(
single_image_nms_fn,
elems=[boxes, scores, masks, num_valid_boxes],
dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
parallel_iterations=parallel_iterations)

if original_masks is None:
batch_nmsed_masks = None

return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
batch_nmsed_masks, batch_num_detections)
Loading