Skip to content

Commit

Permalink
Comments and simple code structure changes
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Oct 10, 2017
1 parent e552688 commit c5e0ddb
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 94 deletions.
17 changes: 7 additions & 10 deletions luminoth/models/fasterrcnn/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,24 @@


class ROIPoolingLayer(snt.AbstractModule):
"""ROIPoolingLayer which applies ROI pooling (or tf.crop_and_resize).
"""ROIPoolingLayer applies ROI Pooling (or tf.crop_and_resize).
RoI pooling or RoI extraction is used to extract fixed size features from a
feature map using variabled sized values for extraction. Since we have
plently of proposals of different shapes and sizes, we need a way to use
all the information available in the pretrained feature map.
variable sized feature map using variabled sized bounding boxes. Since we
have proposals of different shapes and sizes, we need a way to transform
them into a fixed size Tensor for using FC layers.
There are two basic ways to do this, the original one in the FasterRCNN's
paper is RoI Pooling, which as the name suggests, it maxpools directly from
the region of interest, or proposal, into a fixed sized Tensor.
the region of interest, or proposal, into a fixed size Tensor.
The alternative way uses TensorFlow's image utility operation called,
`crop_and_resize` which first crops an Tensor using a normalized proposal,
and then applies extrapolationt to resize it to the desired size,
generating a fixed sized Tensor.
and then applies extrapolation to resize it to the desired size,
generating a fixed size Tensor.
Since there isn't a std support implemenation of RoIPooling, we apply the
easier but still proven alternatve way.
TODO: Should not be called ROIPoolingLayer, since it doesn't always apply
RoI pooling.
"""
def __init__(self, config, debug=False, name='roi_pooling'):
super(ROIPoolingLayer, self).__init__(name=name)
Expand Down
126 changes: 53 additions & 73 deletions luminoth/models/fasterrcnn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ def _instantiate_layers(self):
padding='VALID', name='bbox_conv'
)

def _build(self, conv_feature_map, image_shape, all_anchors,
def _build(self, conv_feature_map, im_shape, all_anchors,
gt_boxes=None):
"""Builds the RPN model subgraph.
Args:
conv_feature_map: A Tensor with the output of some pretrained
network. Its dimensions should be
`[feature_map_height, feature_map_width, depth]` where depth is
512 for the default layer in VGG and 1024 for the default layer
in ResNet.
image_shape: A Tensor with the shape of the original image.
`[1, feature_map_height, feature_map_width, depth]` where depth
is 512 for the default layer in VGG and 1024 for the default
layer in ResNet.
im_shape: A Tensor with the shape of the original image.
all_anchors: A Tensor with all the anchor bounding boxes. Its shape
should be
[feature_map_height * feature_map_width * total_anchors, 4]
Expand Down Expand Up @@ -136,30 +136,39 @@ def _build(self, conv_feature_map, image_shape, all_anchors,
# rpn_bbox_pred_original has shape (1, H, W, num_anchors * 4)
# where H, W are height and width of the pretrained feature map.

# Convert `rpn_cls_score` which has two scalars per anchor per location
# to be able to apply softmax.
# Convert (flatten) `rpn_cls_score_original` which has two scalars per
# anchor per location to be able to apply softmax.
rpn_cls_score = tf.reshape(rpn_cls_score_original, [-1, 2])
# Now that `rpn_cls_score` has shape (H * W * num_anchors, 2), we apply
# softmax to the last dim.
rpn_cls_prob = tf.nn.softmax(rpn_cls_score)

prediction_dict['rpn_cls_prob'] = rpn_cls_prob
prediction_dict['rpn_cls_score'] = rpn_cls_score

# Flatten bounding box delta prediction for easy manipulation.
# We end up with `rpn_bbox_pred` having shape (H * W * num_anchors, 4).
rpn_bbox_pred = tf.reshape(rpn_bbox_pred_original, [-1, 4])

prediction_dict['rpn_bbox_pred'] = rpn_bbox_pred

# We have to convert bbox deltas to usable bounding boxes and remove
# redudant bbox using non maximum suppression.
# redundant ones using Non Maximum Suppression (NMS).
proposal_prediction = self._proposal(
rpn_cls_prob, rpn_bbox_pred, all_anchors, image_shape)
rpn_cls_prob, rpn_bbox_pred, all_anchors, im_shape)

prediction_dict['proposals'] = proposal_prediction['nms_proposals']
prediction_dict['scores'] = proposal_prediction['nms_proposals_scores']

if self._debug:
prediction_dict['proposal_prediction'] = proposal_prediction

if gt_boxes is not None:
# When training we use a separate module to calculate the target
# values we want to output.
(rpn_cls_target, rpn_bbox_target,
rpn_max_overlap) = self._anchor_target(
all_anchors, gt_boxes, image_shape
all_anchors, gt_boxes, im_shape
)

prediction_dict['rpn_cls_target'] = rpn_cls_target
Expand All @@ -169,8 +178,8 @@ def _build(self, conv_feature_map, image_shape, all_anchors,
prediction_dict['rpn_max_overlap'] = rpn_max_overlap

variable_summaries(rpn_bbox_target, 'rpn_bbox_target', ['rpn'])
variable_summaries(rpn_bbox_target, 'rpn_bbox_target', ['rpn'])

# Variables summaries.
variable_summaries(
proposal_prediction['nms_proposals_scores'], 'rpn_scores', ['rpn'])
variable_summaries(rpn_cls_prob, 'rpn_cls_prob', ['rpn'])
Expand All @@ -181,108 +190,65 @@ def _build(self, conv_feature_map, image_shape, all_anchors,
variable_summaries(
rpn_bbox_pred_original, 'rpn_bbox_pred_original', ['rpn'])

# Layer summaries.
layer_summaries(self._rpn, ['rpn'])
layer_summaries(self._rpn_cls, ['rpn'])
layer_summaries(self._rpn_bbox, ['rpn'])

prediction_dict['proposals'] = proposal_prediction['nms_proposals']
prediction_dict['scores'] = proposal_prediction['nms_proposals_scores']

if self._debug:
prediction_dict['proposal_prediction'] = proposal_prediction

return prediction_dict

def loss(self, prediction_dict):
"""
Returns cost for Region Proposal Network based on:
Args:
rpn_cls_prob: Probability of for being an object for each anchor
in the image. Shape -> (num_anchors, 2)
rpn_cls_score: Score for being an object or not for each anchor
in the image. Shape: (num_anchors, 2)
rpn_cls_target: Ground truth labeling for each anchor. Should be
1: for positive labels
0: for negative labels
-1: for labels we should ignore.
Shape -> (num_anchors, 4)
* 1: for positive labels
* 0: for negative labels
* -1: for labels we should ignore.
Shape: (num_anchors, 4)
rpn_bbox_target: Bounding box output delta target for rpn.
Shape -> (num_anchors, 4)
Shape: (num_anchors, 4)
rpn_bbox_pred: Bounding box output delta prediction for rpn.
Shape -> (num_anchors, 4)
Shape: (num_anchors, 4)
Returns:
Multiloss between cls probability and bbox target.
"""

# rpn_cls_prob = prediction_dict['rpn_cls_prob']
rpn_cls_score = prediction_dict['rpn_cls_score']
rpn_cls_target = prediction_dict['rpn_cls_target']

rpn_bbox_target = prediction_dict['rpn_bbox_target']
rpn_bbox_pred = prediction_dict['rpn_bbox_pred']

# First, we need to calculate classification loss over `rpn_cls_prob`
# and `rpn_cls_target`. Ignoring all anchors where `rpn_cls_target =
# -1`.

# For classification loss we use log loss of 2 classes. So we need to:
# - filter `rpn_cls_prob` that are ignored. We need to reshape both
# labels and prob
# - transform positive and negative `rpn_cls_target` to same shape as
# `rpn_cls_prob`.
# - then we can use `tf.losses.log_loss` which returns a tensor.

with tf.variable_scope('RPNLoss'):
# Flatten already flat Tensor for usage as boolean mask filter.
rpn_cls_target = tf.cast(tf.reshape(
rpn_cls_target, [-1]), tf.int32, name='rpn_cls_target')
# Transform to boolean tensor with True only when != -1 (else
# == -1 -> False)
# Transform to boolean tensor mask for not ignored.
labels_not_ignored = tf.not_equal(
rpn_cls_target, -1, name='labels_not_ignored')

# Now we only have the labels we are going to compare with the
# cls probability.
labels = tf.boolean_mask(rpn_cls_target, labels_not_ignored)
# cls_prob = tf.boolean_mask(rpn_cls_prob, labels_not_ignored)
cls_score = tf.boolean_mask(rpn_cls_score, labels_not_ignored)

tf.summary.scalar(
'batch_size',
tf.shape(labels)[0], ['rpn']
)

# We need to transform `labels` to `cls_prob` shape.
# convert [1, 0] to [[0, 1], [1, 0]]
# We need to transform `labels` to `cls_score` shape.
# convert [1, 0] to [[0, 1], [1, 0]] for ce with logits.
cls_target = tf.one_hot(labels, depth=2)

# Equivalent to log loss
ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
labels=cls_target, logits=cls_score
)

foreground_cls_loss = tf.boolean_mask(
ce_per_anchor, tf.equal(labels, 1)
)
background_cls_loss = tf.boolean_mask(
ce_per_anchor, tf.equal(labels, 0)
)

tf.summary.scalar(
'foreground_cls_loss',
tf.reduce_mean(foreground_cls_loss), ['rpn'])
tf.summary.histogram(
'foreground_cls_loss', foreground_cls_loss, ['rpn'])
tf.summary.scalar(
'background_cls_loss',
tf.reduce_mean(background_cls_loss), ['rpn'])
tf.summary.histogram(
'background_cls_loss', background_cls_loss, ['rpn'])

prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor

# Finally, we need to calculate the regression loss over
# `rpn_bbox_target` and `rpn_bbox_pred`.
# Since `rpn_bbox_target` is obtained from RPNTarget then we
# just need to apply SmoothL1Loss.
# We use SmoothL1Loss.
rpn_bbox_target = tf.reshape(rpn_bbox_target, [-1, 4])
rpn_bbox_pred = tf.reshape(rpn_bbox_pred, [-1, 4])

Expand All @@ -292,18 +258,32 @@ def loss(self, prediction_dict):
rpn_bbox_target = tf.boolean_mask(rpn_bbox_target, positive_labels)
rpn_bbox_pred = tf.boolean_mask(rpn_bbox_pred, positive_labels)

tf.summary.scalar(
'foreground_samples',
tf.shape(rpn_bbox_target)[0], ['rpn']
)

# We apply smooth l1 loss as described by the Fast R-CNN paper.
reg_loss_per_anchor = smooth_l1_loss(
rpn_bbox_pred, rpn_bbox_target
)

prediction_dict['reg_loss_per_anchor'] = reg_loss_per_anchor

# Loss summaries.
tf.summary.scalar('batch_size', tf.shape(labels)[0], ['rpn'])
foreground_cls_loss = tf.boolean_mask(
ce_per_anchor, tf.equal(labels, 1))
background_cls_loss = tf.boolean_mask(
ce_per_anchor, tf.equal(labels, 0))
tf.summary.scalar(
'foreground_cls_loss',
tf.reduce_mean(foreground_cls_loss), ['rpn'])
tf.summary.histogram(
'foreground_cls_loss', foreground_cls_loss, ['rpn'])
tf.summary.scalar(
'background_cls_loss',
tf.reduce_mean(background_cls_loss), ['rpn'])
tf.summary.histogram(
'background_cls_loss', background_cls_loss, ['rpn'])
tf.summary.scalar(
'foreground_samples', tf.shape(rpn_bbox_target)[0], ['rpn'])

return {
'rpn_cls_loss': tf.reduce_mean(ce_per_anchor),
'rpn_reg_loss': tf.reduce_mean(reg_loss_per_anchor),
Expand Down
8 changes: 3 additions & 5 deletions luminoth/models/fasterrcnn/rpn_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class RPNProposal(snt.AbstractModule):
sorted by relevance.
Besides applying the transformations (or adjustments) from the prediction,
it also tries to get rid of duplicate proposals by using non maximum
supression (NMS).
it tries to get rid of duplicate proposals by using non maximum supression
(NMS).
"""
def __init__(self, num_anchors, config, name='proposal_layer'):
super(RPNProposal, self).__init__(name=name)
Expand All @@ -29,7 +29,7 @@ def __init__(self, num_anchors, config, name='proposal_layer'):
self._post_nms_top_n = config.post_nms_top_n
# Threshold to use for NMS.
self._nms_threshold = config.nms_threshold
# TODO: Currently we do not filter out proposals by size.
# Currently we do not filter out proposals by size.
self._min_size = config.min_size

def _build(self, rpn_cls_prob, rpn_bbox_pred, all_anchors, im_shape):
Expand Down Expand Up @@ -102,8 +102,6 @@ def _build(self, rpn_cls_prob, rpn_bbox_pred, all_anchors, im_shape):
proposals = clip_boxes(proposals, im_shape)

# Filter proposals with negative area.
# TODO: Optional, is not done in paper, maybe we should make it
# configurable.
(x_min, y_min, x_max, y_max) = tf.unstack(proposals, axis=1)
proposal_filter = tf.greater_equal(
(x_max - x_min) * (y_max - y_min), 0)
Expand Down
13 changes: 7 additions & 6 deletions luminoth/models/fasterrcnn/rpn_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, num_anchors, config, seed=None, name='anchor_target'):
# When choosing random targets use `seed` to replicate behaviour.
self._seed = seed

def _build(self, all_anchors, gt_boxes, im_size):
def _build(self, all_anchors, gt_boxes, im_shape):
"""
We compare anchors to GT and using the minibatch size and the different
config settings (clobber, foreground fraction, etc), we end up with
Expand All @@ -85,11 +85,12 @@ def _build(self, all_anchors, gt_boxes, im_size):
Args:
all_anchors:
A Tensor with all the bounding boxes coords of the anchors.
Its shape should be (num_anchors, 4).
gt_boxes:
A Tensor with the ground truth bounding boxes of the image of
the batch being processed. Its dimensions should be
(num_gt, 5). The last dimension is used for the label.
im_size:
the batch being processed. Its shape should be (num_gt, 5).
The last dimension is used for the label.
im_shape:
Shape of original image (height, width) in order to define
anchor targers in respect with gt_boxes.
Expand All @@ -116,8 +117,8 @@ def _build(self, all_anchors, gt_boxes, im_size):
tf.greater_equal(y_min_anchor, -self._allowed_border)
),
tf.logical_and(
tf.less(x_max_anchor, im_size[1] + self._allowed_border),
tf.less(y_max_anchor, im_size[0] + self._allowed_border)
tf.less(x_max_anchor, im_shape[1] + self._allowed_border),
tf.less(y_max_anchor, im_shape[0] + self._allowed_border)
)
)

Expand Down

0 comments on commit c5e0ddb

Please sign in to comment.