Skip to content

Commit

Permalink
Add tests for negative samples for Mask R-CNN and Keypoint R-CNN (#2069)
Browse files Browse the repository at this point in the history
* Add tests for negative samples for Mask R-CNN and Keypoint R-CNN

* Fix lint
  • Loading branch information
fmassa committed Apr 8, 2020
1 parent e61538c commit 1ae7f5c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
53 changes: 38 additions & 15 deletions test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,27 @@

class Tester(unittest.TestCase):

def test_targets_to_anchors(self):
def _make_empty_sample(self, add_masks=False, add_keypoints=False):
images = [torch.rand((3, 100, 100), dtype=torch.float32)]
boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {"boxes": boxes,
"labels": torch.zeros((1, 1), dtype=torch.int64),
"labels": torch.zeros(0, dtype=torch.int64),
"image_id": 4,
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
"iscrowd": torch.zeros((0,), dtype=torch.int64)}

anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
if add_masks:
negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)

if add_keypoints:
negative_target["keypoints"] = torch.zeros(17, 0, 3, dtype=torch.float32)

targets = [negative_target]
return images, targets

def test_targets_to_anchors(self):
_, targets = self._make_empty_sample()
anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]

anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
Expand Down Expand Up @@ -85,25 +96,37 @@ def test_assign_targets_to_proposals(self):
self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
self.assertEqual(labels[0].dtype, torch.int64)

def test_forward_negative_sample(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
def test_forward_negative_sample_frcnn(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)

images = [torch.rand((3, 100, 100), dtype=torch.float32)]
boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {"boxes": boxes,
"labels": torch.zeros((1, 1), dtype=torch.int64),
"image_id": 4,
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
"iscrowd": torch.zeros((0,), dtype=torch.int64)}
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

targets = [negative_target]
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))

def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample(add_masks=True)
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_mask"], torch.tensor(0.))

def test_forward_negative_sample_krcnn(self):
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample(add_keypoints=True)
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class can go away.
)

output_shape = _output_size(2, input, size, scale_factor)
output_shape = input.shape[:-2] + output_shape
output_shape = list(input.shape[:-2]) + output_shape
return _new_empty_tensor(input, output_shape)


Expand Down

0 comments on commit 1ae7f5c

Please sign in to comment.