diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 7ee1aed1459..712dccf11a8 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,25 +1,11 @@ -import unittest - - import torch from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +import pytest -class ResnetFPNBackboneTester(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dtype = torch.float32 - - def test_resnet18_fpn_backbone(self): - device = torch.device('cpu') - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) - resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False) - y = resnet18_fpn(x) - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) - def test_resnet50_fpn_backbone(self): - device = torch.device('cpu') - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) - resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False) - y = resnet50_fpn(x) - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) +@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) +def test_resnet_fpn_backbone(backbone_name): + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') + y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) + assert list(y.keys()) == ['0', '1', '2', '3', 'pool']