Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 38 additions & 34 deletions test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead

import unittest
import pytest
from _assert_utils import assert_equal


class Tester(unittest.TestCase):
class TestModelsDetectionNegativeSamples:

def _make_empty_sample(self, add_masks=False, add_keypoints=False):
images = [torch.rand((3, 100, 100), dtype=torch.float32)]
Expand Down Expand Up @@ -48,13 +49,13 @@ def test_targets_to_anchors(self):

labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)

self.assertEqual(labels[0].sum(), 0)
self.assertEqual(labels[0].shape, torch.Size([anchors[0].shape[0]]))
self.assertEqual(labels[0].dtype, torch.float32)
assert labels[0].sum() == 0
assert labels[0].shape == torch.Size([anchors[0].shape[0]])
assert labels[0].dtype == torch.float32

self.assertEqual(matched_gt_boxes[0].sum(), 0)
self.assertEqual(matched_gt_boxes[0].shape, anchors[0].shape)
self.assertEqual(matched_gt_boxes[0].dtype, torch.float32)
assert matched_gt_boxes[0].sum() == 0
assert matched_gt_boxes[0].shape == anchors[0].shape
assert matched_gt_boxes[0].dtype == torch.float32

def test_assign_targets_to_proposals(self):

Expand Down Expand Up @@ -88,25 +89,28 @@ def test_assign_targets_to_proposals(self):

matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

self.assertEqual(matched_idxs[0].sum(), 0)
self.assertEqual(matched_idxs[0].shape, torch.Size([proposals[0].shape[0]]))
self.assertEqual(matched_idxs[0].dtype, torch.int64)

self.assertEqual(labels[0].sum(), 0)
self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
self.assertEqual(labels[0].dtype, torch.int64)

def test_forward_negative_sample_frcnn(self):
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn"]:
model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100)
assert matched_idxs[0].sum() == 0
assert matched_idxs[0].shape == torch.Size([proposals[0].shape[0]])
assert matched_idxs[0].dtype == torch.int64

assert labels[0].sum() == 0
assert labels[0].shape == torch.Size([proposals[0].shape[0]])
assert labels[0].dtype == torch.int64

@pytest.mark.parametrize('name', [
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
])
def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))

def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
Expand All @@ -115,9 +119,9 @@ def test_forward_negative_sample_mrcnn(self):
images, targets = self._make_empty_sample(add_masks=True)
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_mask"], torch.tensor(0.))
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_mask"], torch.tensor(0.))

def test_forward_negative_sample_krcnn(self):
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
Expand All @@ -126,9 +130,9 @@ def test_forward_negative_sample_krcnn(self):
images, targets = self._make_empty_sample(add_keypoints=True)
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.))

def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn(
Expand All @@ -137,7 +141,7 @@ def test_forward_negative_sample_retinanet(self):
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.))

def test_forward_negative_sample_ssd(self):
model = torchvision.models.detection.ssd300_vgg16(
Expand All @@ -146,8 +150,8 @@ def test_forward_negative_sample_ssd(self):
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.))


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])