Skip to content

Commit

Permalink
Minor fixes found when running
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Jun 16, 2017
1 parent 78ff2fe commit 6666e39
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 31 deletions.
39 changes: 21 additions & 18 deletions frcnn/anchor_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,18 @@ def __init__(self, anchors, feat_stride=[16], name='anchor_target'):
self._batch_size = 256
self._bbox_inside_weights = (1.0, 1.0, 1.0, 1.0)

def _build(self, rpn_cls_score, gt_boxes, im_info):
def _build(self, rpn_cls_prob, gt_boxes, im_info):
"""
Args:
rpn_cls_score: A Tensor with the class score for every anchor
rpn_cls_prob: A Tensor with the class probability for every anchor
generated by the Region Proposal Network. Its dimensions
should be TODO: Chequear dimensiones.
should be (1, H, W, num_anchors * 2). We only use it to get the
pretrained shape (H, W).
gt_boxes: A Tensor with the groundtruth bounding boxes of the
image of the batch being processed.
image of the batch being processed. Its dimensions should
be (num_gt, 5).
im_info: Shape of original image (height, width) in order to define
anchor targets with respect to gt_boxes.
We currently use the `anchor_target_layer` code provided in the
original Caffe implementation by Ross Girshick. Ideally we should
Expand All @@ -80,31 +84,30 @@ def _build(self, rpn_cls_score, gt_boxes, im_info):

(
labels, bbox_targets,
bbox_inside_weights, bbox_outside_weights
# bbox_inside_weights, bbox_outside_weights
) = tf.py_func(
self._anchor_target_layer_np,
[rpn_cls_score, gt_boxes, im_info],
[tf.float32, tf.float32, tf.float32, tf.float32]
[rpn_cls_prob, gt_boxes, im_info],
[tf.float32, tf.float32]

)

return labels, bbox_targets # missing bbox_inside_weights, bbox_outside_weights


def _anchor_target_layer(self, rpn_cls_score, gt_boxes, im_info):
def _anchor_target_layer(self, rpn_cls_prob, gt_boxes, im_info):
    """
    Function working with Tensors instead of instances for proper
    computing in the Tensorflow graph.

    This variant is not implemented yet: the arguments are accepted but
    unused, and every call fails immediately.

    Args:
        rpn_cls_prob: Unused placeholder argument.
        gt_boxes: Unused placeholder argument.
        im_info: Unused placeholder argument.

    Raises:
        NotImplementedError: Always.
    """
    # `NotImplemented` is a non-callable constant meant for rich comparison
    # methods; calling it raises TypeError. The exception class for
    # "not written yet" is NotImplementedError.
    raise NotImplementedError()


def _anchor_target_layer_np(self, rpn_cls_score, gt_boxes, im_info):
def _anchor_target_layer_np(self, rpn_cls_prob, gt_boxes, im_info):
"""
Function to be executed with tf.py_func
"""

height, width = rpn_cls_score.shape[1:3]
height, width = rpn_cls_prob.shape[1:3] # TODO(debug): rpn_cls_prob.shape = (1, 23, 279, 2)

# 1. Generate proposals from bbox deltas and shifted anchors
shift_x = np.arange(0, width) * self._feat_stride
Expand Down Expand Up @@ -190,8 +193,8 @@ def _anchor_target_layer_np(self, rpn_cls_score, gt_boxes, im_info):
bg_inds, size=(len(bg_inds) - num_bg), replace=False)
labels[disable_inds] = -1

bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
bbox_targets = self._compute_targets(anchors, gt_boxes[argmax_overlaps, :])
# TODO: Not necessary to define first bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
bbox_targets = self._compute_targets(anchors, gt_boxes[argmax_overlaps, :]).astype(np.float32)

bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
bbox_inside_weights[labels == 1, :] = np.array(self._bbox_inside_weights)
Expand All @@ -211,13 +214,13 @@ def _anchor_target_layer_np(self, rpn_cls_score, gt_boxes, im_info):
bbox_outside_weights[labels == 1, :] = positive_weights
bbox_outside_weights[labels == 0, :] = negative_weights

labels = _unmap(
labels = self._unmap(
labels, total_anchors, inds_inside, fill=-1)
bbox_targets = _unmap(
bbox_targets = self._unmap(
bbox_targets, total_anchors, inds_inside, fill=0)
bbox_inside_weights = _unmap(
bbox_inside_weights = self._unmap(
bbox_inside_weights, total_anchors, inds_inside, fill=0)
bbox_outside_weights = _unmap(
bbox_outside_weights = self._unmap(
bbox_outside_weights, total_anchors, inds_inside, fill=0)

# labels
Expand All @@ -239,7 +242,7 @@ def _anchor_target_layer_np(self, rpn_cls_score, gt_boxes, im_info):
(1, height, width, A * 4)
).transpose(0, 3, 1, 2)

return labels, bbox_targets, bbox_inside_weights, bbox_outside_weights
return labels, bbox_targets

def _bbox_overlaps(self, boxes, gt_boxes):
"""
Expand Down
6 changes: 3 additions & 3 deletions frcnn/anchor_target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self):
def testBasic(self):
model = AnchorTarget(self.anchors)
rpn_cls_score_shape = (1, 32, 32, model._num_anchors * 2)
gt_boxes_shape = (3, 4) # 3 ground truth boxes.
gt_boxes_shape = (1, 4) # 1 ground truth boxes.
im_info_shape = (2,)

rpn_cls_score_ph = tf.placeholder(
Expand All @@ -36,8 +36,8 @@ def testBasic(self):
sess.run(tf.global_variables_initializer())
out_inst = sess.run(out, feed_dict={
rpn_cls_score_ph: np.random.rand(*rpn_cls_score_shape),
gt_boxes_ph: np.random.rand(*gt_boxes_shape),
im_info_ph: np.random.rand(*im_info_shape),
gt_boxes_ph: [[ 45, 42, 455, 342]],
im_info_ph: [375, 500],
})


Expand Down
8 changes: 4 additions & 4 deletions frcnn/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tensorflow as tf

from .anchor_target import AnchorTarget
from .config import Config
from .dataset import TFRecordDataset
from .pretrained import VGG
from .proposal import Proposal
Expand Down Expand Up @@ -52,18 +51,19 @@ def _build(self, image, gt_boxes, is_training=True):
pretrained_output = self._pretrained(image)
rpn_layers = self._rpn(pretrained_output, is_training=is_training)
# TODO: We should rename the
rpn_cls_prob = rpn_layers['rpn_cls_prob_reshape']
rpn_bbox_pred = rpn_layers['rpn_bbox_pred']
rpn_cls_prob = rpn_layers['rpn_cls_prob_reshape'] # TODO(debug): shape (1, 21, 31, 18)
rpn_bbox_pred = rpn_layers['rpn_bbox_pred'] # TODO(debug): shape (1, 21, 31, 36)

rpn_labels, rpn_bbox = self._anchor_target(
rpn_layers['rpn_cls_score_reshape'], gt_boxes, image_shape)
rpn_cls_prob, gt_boxes, image_shape)

blob, scores = self._proposal(
rpn_layers['rpn_cls_prob'], rpn_layers['rpn_bbox_pred'])
roi_pool = self._roi_pool(blob, pretrained_output)

# TODO: Missing mapping classification_bbox to real coordinates.
# (and trimming, and NMS?)
# TODO: Missing gt_boxes labels!
classification_prob, classification_bbox = self._rcnn(roi_pool)

# TODO: We are returning only rpn tensors for training RPN.
Expand Down
12 changes: 6 additions & 6 deletions frcnn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def _build(self, pretrained, is_training=True):
"""
TODO: We don't have BatchNorm yet.
"""
rpn = self._rpn_activation(self._rpn(pretrained))
rpn_cls_score = self._rpn_cls(rpn)
rpn_cls_score_reshape = self.spatial_reshape_layer(rpn_cls_score, 2)
rpn_cls_prob = self.spatial_softmax(rpn_cls_score_reshape)
rpn_cls_prob_reshape = self.spatial_reshape_layer(rpn_cls_prob, self._num_anchors * 2)
rpn_bbox_pred = self._rpn_bbox(rpn)
rpn = self._rpn_activation(self._rpn(pretrained)) # TODO(debug): shape (1, 21, 31, 512)
rpn_cls_score = self._rpn_cls(rpn) # TODO(debug): shape (1, 21, 31, 18)
rpn_cls_score_reshape = self.spatial_reshape_layer(rpn_cls_score, 2) # TODO(debug): shape (1, 21, 279, 2)
rpn_cls_prob = self.spatial_softmax(rpn_cls_score_reshape) # TODO(debug): shape (1, 21, 279, 2)
rpn_cls_prob_reshape = self.spatial_reshape_layer(rpn_cls_prob, self._num_anchors * 2) # TODO(debug): shape (1, 21, 31, 18)
rpn_bbox_pred = self._rpn_bbox(rpn) # TODO(debug): shape (1, 21, 31, 36)

return {
'rpn': rpn,
Expand Down

0 comments on commit 6666e39

Please sign in to comment.