From 9cbd37e0d12920960d2c3c77a386f91439336f8e Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Thu, 10 Jun 2021 17:35:55 +0530 Subject: [PATCH] port test_models_detection_utils.py to pytest --- test/test_models_detection_utils.py | 44 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index a20e0abc965..bb50e237544 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -2,12 +2,13 @@ import torch from torchvision.models.detection import _utils from torchvision.models.detection.transform import GeneralizedRCNNTransform -import unittest +import pytest from torchvision.models.detection import backbone_utils from _assert_utils import assert_equal -class Tester(unittest.TestCase): +class TestModelsDetectionUtils: + def test_balanced_positive_negative_sampler(self): sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25) # keep all 6 negatives first, then add 3 positives, last two are ignore @@ -16,39 +17,40 @@ def test_balanced_positive_negative_sampler(self): # we know the number of elements that should be sampled for the positive (1) # and the negative (3), and their location. Let's make sure that they are # there - self.assertEqual(pos[0].sum(), 1) - self.assertEqual(pos[0][6:9].sum(), 1) - self.assertEqual(neg[0].sum(), 3) - self.assertEqual(neg[0][0:6].sum(), 3) + assert pos[0].sum() == 1 + assert pos[0][6:9].sum() == 1 + assert neg[0].sum() == 3 + assert neg[0][0:6].sum() == 3 - def test_resnet_fpn_backbone_frozen_layers(self): + @pytest.mark.parametrize('train_layers, exp_froz_params', [ + (0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0) + ]) + def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # we know how many initial layers and parameters of the network should # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if trainable_backbone_layers=2 - expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0} - for train_layers, exp_froz_params in expected_frozen_params.items(): - model = backbone_utils.resnet_fpn_backbone( - 'resnet50', pretrained=False, trainable_layers=train_layers) - # boolean list that is true if the param at that index is frozen - is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] - # check that expected initial number of layers are frozen - self.assertTrue(all(is_frozen[:exp_froz_params])) + model = backbone_utils.resnet_fpn_backbone( + 'resnet50', pretrained=False, trainable_layers=train_layers) + # boolean list that is true if the param at that index is frozen + is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] + # check that expected initial number of layers are frozen + assert all(is_frozen[:exp_froz_params]) def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train ret = backbone_utils._validate_trainable_layers( pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) - self.assertEqual(ret, 3) + assert ret == 3 # can't go beyond 5 - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): ret = backbone_utils._validate_trainable_layers( pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) # if not pretrained, should use all trainable layers and warn - with self.assertWarns(UserWarning): + with pytest.warns(UserWarning): ret = backbone_utils._validate_trainable_layers( pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) - self.assertEqual(ret, 5) + assert ret == 5 def test_transform_copy_targets(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) @@ -63,9 +65,9 @@ def test_not_float_normalize(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] targets = [{'boxes': torch.rand(3, 4)}] - with self.assertRaises(TypeError): + with pytest.raises(TypeError): out = transform(image, targets) # noqa: F841 if __name__ == '__main__': - unittest.main() + pytest.main([__file__])