Skip to content

Commit

Permalink
Modified RPN Proposal tests to check correct scores
Browse files Browse the repository at this point in the history
  • Loading branch information
gastonrod07 authored and vierja committed Aug 14, 2017
1 parent c287a6a commit c02b2a2
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions luminoth/models/fasterrcnn/rpn_proposal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,51 +75,75 @@ def testNMSThreshold(self):
all_anchors, gt_boxes, rpn_cls_prob, self.config)

# Check we get exactly 2 'nms proposals' because 2 IoU equals to 0.
# Also check that we get the corrects scores.
self.assertEqual(
results['nms_proposals'].shape,
(2, 5)
)

self.assertAllClose(
results['nms_proposals_scores'],
[0.9, 0.8]
)

config['nms_threshold'] = 0.3

results = self._run_rpn_proposal(
all_anchors, gt_boxes, rpn_cls_prob, self.config)

# Check we get exactly 3 'nms proposals' because 3 IoU lowers than 0.3.
# Also check that we get the corrects scores.
self.assertEqual(
results['nms_proposals'].shape,
(3, 5)
)

self.assertAllClose(
results['nms_proposals_scores'],
[0.9, 0.8, 0.6]
)

config['nms_threshold'] = 0.6

results = self._run_rpn_proposal(
all_anchors, gt_boxes, rpn_cls_prob, self.config)

# Check we get exactly 3 'nms proposals' because 3 IoU lowers than 0.3.
# Also check that we get the corrects scores.
self.assertEqual(
results['nms_proposals'].shape,
(3, 5)
)

self.assertAllClose(
results['nms_proposals_scores'],
[0.9, 0.8, 0.6]
)

config['nms_threshold'] = 0.8

results = self._run_rpn_proposal(
all_anchors, gt_boxes, rpn_cls_prob, self.config)

# Check we get exactly 3 'nms proposals' because 3 IoU lowers than 0.8.
# Also check that we get the corrects scores.
self.assertEqual(
results['nms_proposals'].shape,
(3, 5)
)

self.assertAllClose(
results['nms_proposals_scores'],
[0.9, 0.8, 0.6]
)

config['nms_threshold'] = 1.0

results = self._run_rpn_proposal(
all_anchors, gt_boxes, rpn_cls_prob, self.config)

# Check we get 'post_nms_top_n' nms proposals because
# 'nms_threshold' = 1.
# 'nms_threshold' = 1 and this only removes duplicates.
self.assertEqual(
results['nms_proposals'].shape,
(4, 5)
Expand Down Expand Up @@ -163,6 +187,12 @@ def testOutsidersAndTopN(self):
(3, 4)
)

# Also check that we get the corrects scores.
self.assertAllClose(
results['nms_proposals_scores'],
[0.7, 0.6, 0.2]
)

config = self.config
config['post_nms_top_n'] = 2

Expand All @@ -181,6 +211,17 @@ def testOutsidersAndTopN(self):
(3, 4)
)

# Also check that we get the corrects scores.
self.assertAllClose(
results['nms_proposals_scores'],
[0.7, 0.6]
)

self.assertAllClose(
results['scores'],
[0.7, 0.6, 0.2]
)

config['post_nms_top_n'] = 3
config['pre_nms_top_n'] = 2

Expand All @@ -189,6 +230,7 @@ def testOutsidersAndTopN(self):

# Check that with a post_nms_top_n = 3 and pre_nms_top = 2
# we have only 2 'nms proposals' but 3 'proposals'.

self.assertAllEqual(
results['nms_proposals'].shape,
(2, 5)
Expand All @@ -199,6 +241,17 @@ def testOutsidersAndTopN(self):
(3, 4)
)

# Also check that we get the corrects scores.
self.assertAllClose(
results['nms_proposals_scores'],
[0.7, 0.6]
)

self.assertAllClose(
results['scores'],
[0.7, 0.6, 0.2]
)

config['post_nms_top_n'] = 1

results = self._run_rpn_proposal(
Expand All @@ -216,6 +269,17 @@ def testOutsidersAndTopN(self):
(3, 4)
)

# Also check that we get the corrects scores.
self.assertAllClose(
results['nms_proposals_scores'],
[0.7]
)

self.assertAllClose(
results['scores'],
[0.7, 0.6, 0.2]
)

def testNegativeArea(self):
"""
Test negative area filters
Expand Down Expand Up @@ -254,7 +318,7 @@ def testNegativeArea(self):
(2, 4)
)

def testClippingOfProporsals(self):
def testClippingOfProposals(self):
"""
Test clipping of proposals
"""
Expand Down

0 comments on commit c02b2a2

Please sign in to comment.