From 559de279b11bd1708a8aace546ae9c96f379a453 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Jun 2021 13:45:19 +0100 Subject: [PATCH] Port test_backbone_utils.py to pytest --- test/test_backbone_utils.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 7ee1aed1459..712dccf11a8 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,25 +1,11 @@ -import unittest - - import torch from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +import pytest -class ResnetFPNBackboneTester(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.dtype = torch.float32 - - def test_resnet18_fpn_backbone(self): - device = torch.device('cpu') - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) - resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False) - y = resnet18_fpn(x) - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) - def test_resnet50_fpn_backbone(self): - device = torch.device('cpu') - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) - resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False) - y = resnet50_fpn(x) - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) +@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) +def test_resnet_fpn_backbone(backbone_name): + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') + y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) + assert list(y.keys()) == ['0', '1', '2', '3', 'pool']