From 1877086f313e4849495c7ae68b64ea24a384d0ce Mon Sep 17 00:00:00 2001 From: vivekkumar7089 Date: Fri, 11 Jun 2021 07:59:51 +0530 Subject: [PATCH] Port test_models_detection_anchor_utils.py to pytest --- test/test_models_detection_anchor_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 13c399a0c32..c918d3fc8df 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -3,6 +3,7 @@ from _assert_utils import assert_equal from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator from torchvision.models.detection.image_list import ImageList +import pytest class Tester(TestCase): @@ -13,7 +14,7 @@ def test_incorrect_anchors(self): image1 = torch.randn(3, 800, 800) image_list = ImageList(image1, [(800, 800)]) feature_maps = [torch.randn(1, 50)] - self.assertRaises(ValueError, anc, image_list, feature_maps) + pytest.raises(ValueError, anc, image_list, feature_maps) def _init_test_anchor_generator(self): anchor_sizes = ((10,),) @@ -59,10 +60,10 @@ def test_anchor_generator(self): [0., 5., 10., 15.], [5., 5., 15., 15.]]) - self.assertEqual(num_anchors_estimated, 9) - self.assertEqual(len(anchors), 2) - self.assertEqual(tuple(anchors[0].shape), (9, 4)) - self.assertEqual(tuple(anchors[1].shape), (9, 4)) + assert num_anchors_estimated == 9 + assert len(anchors) == 2 + assert tuple(anchors[0].shape) == (9, 4) + assert tuple(anchors[1].shape) == (9, 4) assert_equal(anchors[0], anchors_output) assert_equal(anchors[1], anchors_output) @@ -83,8 +84,8 @@ def test_defaultbox_generator(self): [6.7045, 5.9090, 8.2955, 9.0910] ]) - self.assertEqual(len(dboxes), 2) - self.assertEqual(tuple(dboxes[0].shape), (4, 4)) - self.assertEqual(tuple(dboxes[1].shape), (4, 4)) + assert len(dboxes) == 2 + assert tuple(dboxes[0].shape) == (4, 4) + assert tuple(dboxes[1].shape) == (4, 4) torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8) torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)