Skip to content

Commit

Permalink
Rename AnchorTarget to RPNAnchorTarget
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Jul 5, 2017
1 parent 29dfae3 commit 26e5e5f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 12 deletions.
2 changes: 0 additions & 2 deletions frcnn/rcnn_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
class RCNNTarget(snt.AbstractModule):
"""
Generate RCNN target tensors for both probabilities and bounding boxes.
TODO: We should unify this module with AnchorTarget.
"""
def __init__(self, num_classes, name='rcnn_proposal'):
super(RCNNTarget, self).__init__(name=name)
Expand Down
6 changes: 3 additions & 3 deletions frcnn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sonnet.python.modules.conv import Conv2D

from .anchor_target import AnchorTarget
from .rpn_anchor_target import RPNAnchorTarget
from .rpn_proposal import RPNProposal
from .utils.generate_anchors import generate_anchors
from .utils.losses import smooth_l1_loss
Expand Down Expand Up @@ -63,7 +63,7 @@ def _instantiate_layers(self):


self._proposal = RPNProposal(self._num_anchors)
self._anchor_target = AnchorTarget(self._num_anchors)
self._anchor_target = RPNAnchorTarget(self._num_anchors)

def _build(self, pretrained_feature_map, gt_boxes, image_shape, all_anchors,
is_training=True):
Expand Down Expand Up @@ -188,7 +188,7 @@ def loss(self, prediction_dict):

# Finally, we need to calculate the regression loss over
# `rpn_bbox_target` and `rpn_bbox_pred`.
# Since `rpn_bbox_target` is obtained from AnchorTargetLayer then we
# Since `rpn_bbox_target` is obtained from RPNAnchorTarget then we
# just need to apply SmoothL1Loss.
rpn_bbox_target = tf.reshape(rpn_bbox_target, [-1, 4])
rpn_bbox_pred = tf.reshape(rpn_bbox_pred, [-1, 4])
Expand Down
6 changes: 3 additions & 3 deletions frcnn/anchor_target.py → frcnn/rpn_anchor_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from .utils.bbox_transform import bbox_transform, unmap


class AnchorTarget(snt.AbstractModule):
class RPNAnchorTarget(snt.AbstractModule):
"""
AnchorTarget
RPNAnchorTarget
TODO: (copied) Assign anchors to ground-truth targets. Produces anchor
classification labels and bounding-box regression targets.
Expand Down Expand Up @@ -40,7 +40,7 @@ class AnchorTarget(snt.AbstractModule):
"""
def __init__(self, num_anchors, feat_stride=[16], name='anchor_target'):
super(AnchorTarget, self).__init__(name=name)
super(RPNAnchorTarget, self).__init__(name=name)
self._num_anchors = num_anchors
self._feat_stride = feat_stride

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@

from sonnet.testing.parameterized import parameterized

from .anchor_target import AnchorTarget
from .rpn_anchor_target import RPNAnchorTarget
from .utils.generate_anchors import generate_anchors


class AnchorTargetTest(parameterized.ParameterizedTestCase, tf.test.TestCase):
class RPNnchorTargetTest(parameterized.ParameterizedTestCase, tf.test.TestCase):

def setUp(self):
super(AnchorTargetTest, self).setUp()
super(RPNAnchorTargetTest, self).setUp()
# Setup anchors
self.anchor_scales = np.array([8, 16, 32])
self.anchor_ratios = np.array([0.5, 1, 2])
self.anchors = generate_anchors(
ratios=self.anchor_ratios, scales=self.anchor_scales)

def testBasic(self):
model = AnchorTarget(self.anchors)
model = RPNAnchorTarget(self.anchors)
rpn_cls_score_shape = (1, 32, 32, model._num_anchors * 2)
gt_boxes_shape = (1, 4) # 1 ground truth boxes.
im_info_shape = (2,)
Expand Down

0 comments on commit 26e5e5f

Please sign in to comment.