From ffa59769be09bd22b6761669a37d06848256e04a Mon Sep 17 00:00:00 2001
From: Anirudh Dagar
Date: Sun, 6 Jun 2021 15:54:18 +0530
Subject: [PATCH 1/3] Refactor test_models to pytest

---
 test/test_models.py | 716 ++++++++++++++++++++++----------------------
 1 file changed, 352 insertions(+), 364 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index 3a335f6c3a6..2a4e6279acc 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,7 +1,7 @@
 import os
 import io
 import sys
-from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
+from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
 from _utils_internal import get_relative_path
 from collections import OrderedDict
 from itertools import product
@@ -10,10 +10,8 @@
 import torch
 import torch.nn as nn
 from torchvision import models
-import unittest
-import warnings
-
 import pytest
+import warnings
 
 ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
@@ -220,399 +218,389 @@ def get_export_import_copy(m):
 }
 
 
-class ModelTester(TestCase):
-    def _test_classification_model(self, name, dev):
-        set_rng_seed(0)
-        defaults = {
-            'num_classes': 50,
-            'input_shape': (1, 3, 224, 224),
-        }
-        kwargs = {**defaults, **_model_params.get(name, {})}
-        input_shape = kwargs.pop('input_shape')
-
-        model = models.__dict__[name](**kwargs)
-        model.eval().to(device=dev)
-        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-        x = torch.rand(input_shape).to(device=dev)
-        out = model(x)
-        _assert_expected(out.cpu(), name, prec=0.1)
-        self.assertEqual(out.shape[-1], 50)
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
-
-        if dev == torch.device("cuda"):
-            with torch.cuda.amp.autocast():
-                out = model(x)
-                # See autocast_flaky_numerics comment at top of file.
-                if name not in autocast_flaky_numerics:
-                    _assert_expected(out.cpu(), name, prec=0.1)
-                self.assertEqual(out.shape[-1], 50)
-
-    def _test_segmentation_model(self, name, dev):
-        set_rng_seed(0)
-        defaults = {
-            'num_classes': 10,
-            'pretrained_backbone': False,
-            'input_shape': (1, 3, 32, 32),
-        }
-        kwargs = {**defaults, **_model_params.get(name, {})}
-        input_shape = kwargs.pop('input_shape')
-
-        model = models.segmentation.__dict__[name](**kwargs)
-        model.eval().to(device=dev)
-        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-        x = torch.rand(input_shape).to(device=dev)
-        out = model(x)["out"]
-
-        def check_out(out):
-            prec = 0.01
-            try:
-                # We first try to assert the entire output if possible. This is not
-                # only the best way to assert results but also handles the cases
-                # where we need to create a new expected result.
-                _assert_expected(out.cpu(), name, prec=prec)
-            except AssertionError:
-                # Unfortunately some segmentation models are flaky with autocast
-                # so instead of validating the probability scores, check that the class
-                # predictions match.
-                expected_file = _get_expected_file(name)
-                expected = torch.load(expected_file)
-                torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec)
-                return False  # Partial validation performed
-
-            return True  # Full validation performed
-
-        full_validation = check_out(out)
-
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
-
-        if dev == torch.device("cuda"):
-            with torch.cuda.amp.autocast():
-                out = model(x)["out"]
-                # See autocast_flaky_numerics comment at top of file.
-                if name not in autocast_flaky_numerics:
-                    full_validation &= check_out(out)
-
-        if not full_validation:
-            msg = "The output of {} could only be partially validated. " \
-                  "This is likely due to unit-test flakiness, but you may " \
-                  "want to do additional manual checks if you made " \
-                  "significant changes to the codebase.".format(self._testMethodName)
-            warnings.warn(msg, RuntimeWarning)
-            raise unittest.SkipTest(msg)
-
-    def _test_detection_model(self, name, dev):
-        set_rng_seed(0)
-        defaults = {
-            'num_classes': 50,
-            'pretrained_backbone': False,
-            'input_shape': (3, 300, 300),
-        }
-        kwargs = {**defaults, **_model_params.get(name, {})}
-        input_shape = kwargs.pop('input_shape')
-
-        model = models.detection.__dict__[name](**kwargs)
-        model.eval().to(device=dev)
-        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-        x = torch.rand(input_shape).to(device=dev)
-        model_input = [x]
-        out = model(model_input)
-        self.assertIs(model_input[0], x)
-
-        def check_out(out):
-            self.assertEqual(len(out), 1)
-
-            def compact(tensor):
-                size = tensor.size()
-                elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
-                if elements_per_sample > 30:
-                    return compute_mean_std(tensor)
-                else:
-                    return subsample_tensor(tensor)
-
-            def subsample_tensor(tensor):
-                num_elems = tensor.size(0)
-                num_samples = 20
-                if num_elems <= num_samples:
-                    return tensor
-
-                ith_index = num_elems // num_samples
-                return tensor[ith_index - 1::ith_index]
-
-            def compute_mean_std(tensor):
-                # can't compute mean of integral tensor
-                tensor = tensor.to(torch.double)
-                mean = torch.mean(tensor)
-                std = torch.std(tensor)
-                return {"mean": mean, "std": std}
-
-            output = map_nested_tensor_object(out, tensor_map_fn=compact)
-            prec = 0.01
-            try:
-                # We first try to assert the entire output if possible. This is not
-                # only the best way to assert results but also handles the cases
-                # where we need to create a new expected result.
-                _assert_expected(output, name, prec=prec)
-            except AssertionError:
-                # Unfortunately detection models are flaky due to the unstable sort
-                # in NMS. If matching across all outputs fails, use the same approach
-                # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
-                # scores.
-                expected_file = _get_expected_file(name)
-                expected = torch.load(expected_file)
-                torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec,
-                                           check_device=False, check_dtype=False)
-
-                # Note: Fmassa proposed turning off NMS by adapting the threshold
-                # and then using the Hungarian algorithm as in DETR to find the
-                # best match between output and expected boxes and eliminate some
-                # of the flakiness. Worth exploring.
-                return False  # Partial validation performed
-
-            return True  # Full validation performed
-
-        full_validation = check_out(out)
-        _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))
-
-        if dev == torch.device("cuda"):
-            with torch.cuda.amp.autocast():
-                out = model(model_input)
-                # See autocast_flaky_numerics comment at top of file.
-                if name not in autocast_flaky_numerics:
-                    full_validation &= check_out(out)
-
-        if not full_validation:
-            msg = "The output of {} could only be partially validated. " \
-                  "This is likely due to unit-test flakiness, but you may " \
-                  "want to do additional manual checks if you made " \
-                  "significant changes to the codebase.".format(self._testMethodName)
-            warnings.warn(msg, RuntimeWarning)
-            raise unittest.SkipTest(msg)
-
-    def _test_detection_model_validation(self, name):
-        set_rng_seed(0)
-        model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
-        input_shape = (3, 300, 300)
-        x = [torch.rand(input_shape)]
-
-        # validate that targets are present in training
-        self.assertRaises(ValueError, model, x)
-
-        # validate type
-        targets = [{'boxes': 0.}]
-        self.assertRaises(ValueError, model, x, targets=targets)
-
-        # validate boxes shape
-        for boxes in (torch.rand((4,)), torch.rand((1, 5))):
-            targets = [{'boxes': boxes}]
-            self.assertRaises(ValueError, model, x, targets=targets)
-
-        # validate that no degenerate boxes are present
-        boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
-        targets = [{'boxes': boxes}]
-        self.assertRaises(ValueError, model, x, targets=targets)
-
-    def _test_video_model(self, name, dev):
-        # the default input shape is
-        # bs * num_channels * clip_len * h *w
-        input_shape = (1, 3, 4, 112, 112)
-        # test both basicblock and Bottleneck
-        model = models.video.__dict__[name](num_classes=50)
-        model.eval().to(device=dev)
-        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-        x = torch.rand(input_shape).to(device=dev)
-        out = model(x)
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
-        self.assertEqual(out.shape[-1], 50)
-
-        if dev == torch.device("cuda"):
-            with torch.cuda.amp.autocast():
-                out = model(x)
-                self.assertEqual(out.shape[-1], 50)
-
-    def _make_sliced_model(self, model, stop_layer):
-        layers = OrderedDict()
-        for name, layer in model.named_children():
-            layers[name] = layer
-            if name == stop_layer:
-                break
-        new_model = torch.nn.Sequential(layers)
-        return new_model
-
-    def test_memory_efficient_densenet(self):
-        input_shape = (1, 3, 300, 300)
-        x = torch.rand(input_shape)
-
-        for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
-            model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
-            params = model1.state_dict()
-            num_params = sum([x.numel() for x in model1.parameters()])
-            model1.eval()
-            out1 = model1(x)
-            out1.sum().backward()
-            num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None])
-
-            model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
-            model2.load_state_dict(params)
-            model2.eval()
-            out2 = model2(x)
-
-            self.assertTrue(num_params == num_grad)
-            torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
-
-    def test_resnet_dilation(self):
-        # TODO improve tests to also check that each layer has the right dimensionality
-        for i in product([False, True], [False, True], [False, True]):
-            model = models.__dict__["resnet50"](replace_stride_with_dilation=i)
-            model = self._make_sliced_model(model, stop_layer="layer4")
-            model.eval()
-            x = torch.rand(1, 3, 224, 224)
-            out = model(x)
-            f = 2 ** sum(i)
-            self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))
+def _make_sliced_model(model, stop_layer):
+    layers = OrderedDict()
+    for name, layer in model.named_children():
+        layers[name] = layer
+        if name == stop_layer:
+            break
+    new_model = torch.nn.Sequential(layers)
+    return new_model
+
+
+def test_memory_efficient_densenet():
+    input_shape = (1, 3, 300, 300)
+    x = torch.rand(input_shape)
+
+    for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
+        model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
+        params = model1.state_dict()
+        num_params = sum([x.numel() for x in model1.parameters()])
+        model1.eval()
+        out1 = model1(x)
+        out1.sum().backward()
+        num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None])
+
+        model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
+        model2.load_state_dict(params)
+        model2.eval()
+        out2 = model2(x)
 
-    def test_mobilenet_v2_residual_setting(self):
-        model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
+        assert num_params == num_grad
+        torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
+
+
+def test_resnet_dilation():
+    # TODO improve tests to also check that each layer has the right dimensionality
+    for i in product([False, True], [False, True], [False, True]):
+        model = models.__dict__["resnet50"](replace_stride_with_dilation=i)
+        model = _make_sliced_model(model, stop_layer="layer4")
         model.eval()
         x = torch.rand(1, 3, 224, 224)
         out = model(x)
-        self.assertEqual(out.shape[-1], 1000)
-
-    def test_mobilenet_norm_layer(self):
-        for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
-            model = models.__dict__[name]()
-            self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-
-            def get_gn(num_channels):
-                return nn.GroupNorm(32, num_channels)
-
-            model = models.__dict__[name](norm_layer=get_gn)
-            self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-            self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
-
-    def test_inception_v3_eval(self):
-        # replacement for models.inception_v3(pretrained=True) that does not download weights
-        kwargs = {}
-        kwargs['transform_input'] = True
-        kwargs['aux_logits'] = True
-        kwargs['init_weights'] = False
-        name = "inception_v3"
-        model = models.Inception3(**kwargs)
-        model.aux_logits = False
-        model.AuxLogits = None
-        model = model.eval()
-        x = torch.rand(1, 3, 299, 299)
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
-
-    def test_fasterrcnn_double(self):
-        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
-        model.double()
-        model.eval()
-        input_shape = (3, 300, 300)
-        x = torch.rand(input_shape, dtype=torch.float64)
-        model_input = [x]
-        out = model(model_input)
-        self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)
-        self.assertTrue("boxes" in out[0])
-        self.assertTrue("scores" in out[0])
-        self.assertTrue("labels" in out[0])
-
-    def test_googlenet_eval(self):
-        # replacement for models.googlenet(pretrained=True) that does not download weights
-        kwargs = {}
-        kwargs['transform_input'] = True
-        kwargs['aux_logits'] = True
-        kwargs['init_weights'] = False
-        name = "googlenet"
-        model = models.GoogLeNet(**kwargs)
-        model.aux_logits = False
-        model.aux1 = None
-        model.aux2 = None
-        model = model.eval()
-        x = torch.rand(1, 3, 224, 224)
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
-
-    @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
-    def test_fasterrcnn_switch_devices(self):
-        def checkOut(out):
-            self.assertEqual(len(out), 1)
-            self.assertTrue("boxes" in out[0])
-            self.assertTrue("scores" in out[0])
-            self.assertTrue("labels" in out[0])
-
-        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
-        model.cuda()
-        model.eval()
-        input_shape = (3, 300, 300)
-        x = torch.rand(input_shape, device='cuda')
-        model_input = [x]
+        f = 2 ** sum(i)
+        assert out.shape == (1, 2048, 7 * f, 7 * f)
+
+
+def test_mobilenet_v2_residual_setting():
+    model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
+    model.eval()
+    x = torch.rand(1, 3, 224, 224)
+    out = model(x)
+    assert out.shape[-1] == 1000
+
+
+def test_mobilenet_norm_layer():
+    for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
+        model = models.__dict__[name]()
+        assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
+
+        def get_gn(num_channels):
+            return nn.GroupNorm(32, num_channels)
+
+        model = models.__dict__[name](norm_layer=get_gn)
+        assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+        assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
+
+
+def test_inception_v3_eval():
+    # replacement for models.inception_v3(pretrained=True) that does not download weights
+    kwargs = {}
+    kwargs['transform_input'] = True
+    kwargs['aux_logits'] = True
+    kwargs['init_weights'] = False
+    name = "inception_v3"
+    model = models.Inception3(**kwargs)
+    model.aux_logits = False
+    model.AuxLogits = None
+    model = model.eval()
+    x = torch.rand(1, 3, 299, 299)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+
+
+def test_fasterrcnn_double():
+    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+    model.double()
+    model.eval()
+    input_shape = (3, 300, 300)
+    x = torch.rand(input_shape, dtype=torch.float64)
+    model_input = [x]
+    out = model(model_input)
+    assert model_input[0] is x
+    assert len(out) == 1
+    assert "boxes" in out[0]
+    assert "scores" in out[0]
+    assert "labels" in out[0]
+
+
+def test_googlenet_eval():
+    # replacement for models.googlenet(pretrained=True) that does not download weights
+    kwargs = {}
+    kwargs['transform_input'] = True
+    kwargs['aux_logits'] = True
+    kwargs['init_weights'] = False
+    name = "googlenet"
+    model = models.GoogLeNet(**kwargs)
+    model.aux_logits = False
+    model.aux1 = None
+    model.aux2 = None
+    model = model.eval()
+    x = torch.rand(1, 3, 224, 224)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+
+
+@needs_cuda
+def test_fasterrcnn_switch_devices():
+    def checkOut(out):
+        assert len(out) == 1
+        assert "boxes" in out[0]
+        assert "scores" in out[0]
+        assert "labels" in out[0]
+
+    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+    model.cuda()
+    model.eval()
+    input_shape = (3, 300, 300)
+    x = torch.rand(input_shape, device='cuda')
+    model_input = [x]
+    out = model(model_input)
+    assert model_input[0] is x
+
+    checkOut(out)
+
+    with torch.cuda.amp.autocast():
         out = model(model_input)
-        self.assertIs(model_input[0], x)
-
-        checkOut(out)
-
-        with torch.cuda.amp.autocast():
-            out = model(model_input)
-
-        checkOut(out)
 
-        # now switch to cpu and make sure it works
-        model.cpu()
-        x = x.cpu()
-        out_cpu = model([x])
+    checkOut(out)
 
-        checkOut(out_cpu)
+    # now switch to cpu and make sure it works
+    model.cpu()
+    x = x.cpu()
+    out_cpu = model([x])
 
-    def test_generalizedrcnn_transform_repr(self):
+    checkOut(out_cpu)
 
-        min_size, max_size = 224, 299
-        image_mean = [0.485, 0.456, 0.406]
-        image_std = [0.229, 0.224, 0.225]
-        t = models.detection.transform.GeneralizedRCNNTransform(min_size=min_size,
-                                                                max_size=max_size,
-                                                                image_mean=image_mean,
-                                                                image_std=image_std)
+def test_generalizedrcnn_transform_repr():
 
-        # Check integrity of object __repr__ attribute
-        expected_string = 'GeneralizedRCNNTransform('
-        _indent = '\n    '
-        expected_string += '{0}Normalize(mean={1}, std={2})'.format(_indent, image_mean, image_std)
-        expected_string += '{0}Resize(min_size=({1},), max_size={2}, '.format(_indent, min_size, max_size)
-        expected_string += "mode='bilinear')\n)"
-        self.assertEqual(t.__repr__(), expected_string)
+    min_size, max_size = 224, 299
+    image_mean = [0.485, 0.456, 0.406]
+    image_std = [0.229, 0.224, 0.225]
+    t = models.detection.transform.GeneralizedRCNNTransform(min_size=min_size,
+                                                            max_size=max_size,
+                                                            image_mean=image_mean,
+                                                            image_std=image_std)
 
-_devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]
+    # Check integrity of object __repr__ attribute
+    expected_string = 'GeneralizedRCNNTransform('
+    _indent = '\n    '
+    expected_string += '{0}Normalize(mean={1}, std={2})'.format(_indent, image_mean, image_std)
+    expected_string += '{0}Resize(min_size=({1},), max_size={2}, '.format(_indent, min_size, max_size)
+    expected_string += "mode='bilinear')\n)"
+    assert t.__repr__() == expected_string
 
 
 @pytest.mark.parametrize('model_name', get_available_classification_models())
-@pytest.mark.parametrize('dev', _devs)
+@pytest.mark.parametrize('dev', cpu_and_gpu())
 def test_classification_model(model_name, dev):
-    ModelTester()._test_classification_model(model_name, dev)
+    set_rng_seed(0)
+    defaults = {
+        'num_classes': 50,
+        'input_shape': (1, 3, 224, 224),
+    }
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    input_shape = kwargs.pop('input_shape')
+
+    model = models.__dict__[model_name](**kwargs)
+    model.eval().to(device=dev)
+    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+    x = torch.rand(input_shape).to(device=dev)
+    out = model(x)
+    _assert_expected(out.cpu(), model_name, prec=0.1)
+    assert out.shape[-1] == 50
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
+
+    if dev == torch.device("cuda"):
+        with torch.cuda.amp.autocast():
+            out = model(x)
+            # See autocast_flaky_numerics comment at top of file.
+            if model_name not in autocast_flaky_numerics:
+                _assert_expected(out.cpu(), model_name, prec=0.1)
+            assert out.shape[-1] == 50
 
 
 @pytest.mark.parametrize('model_name', get_available_segmentation_models())
-@pytest.mark.parametrize('dev', _devs)
+@pytest.mark.parametrize('dev', cpu_and_gpu())
 def test_segmentation_model(model_name, dev):
-    ModelTester()._test_segmentation_model(model_name, dev)
+    set_rng_seed(0)
+    defaults = {
+        'num_classes': 10,
+        'pretrained_backbone': False,
+        'input_shape': (1, 3, 32, 32),
+    }
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    input_shape = kwargs.pop('input_shape')
+
+    model = models.segmentation.__dict__[model_name](**kwargs)
+    model.eval().to(device=dev)
+    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+    x = torch.rand(input_shape).to(device=dev)
+    out = model(x)["out"]
+
+    def check_out(out):
+        prec = 0.01
+        try:
+            # We first try to assert the entire output if possible. This is not
+            # only the best way to assert results but also handles the cases
+            # where we need to create a new expected result.
+            _assert_expected(out.cpu(), model_name, prec=prec)
+        except AssertionError:
+            # Unfortunately some segmentation models are flaky with autocast
+            # so instead of validating the probability scores, check that the class
+            # predictions match.
+            expected_file = _get_expected_file(model_name)
+            expected = torch.load(expected_file)
+            torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec)
+            return False  # Partial validation performed
+
+        return True  # Full validation performed
+
+    full_validation = check_out(out)
+
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
+
+    if dev == torch.device("cuda"):
+        with torch.cuda.amp.autocast():
+            out = model(x)["out"]
+            # See autocast_flaky_numerics comment at top of file.
+            if model_name not in autocast_flaky_numerics:
+                full_validation &= check_out(out)
+
+    if not full_validation:
+        msg = "The output of {} could only be partially validated. " \
+              "This is likely due to unit-test flakiness, but you may " \
+              "want to do additional manual checks if you made " \
+              "significant changes to the codebase.".format(test_segmentation_model.__name__)
+        warnings.warn(msg, RuntimeWarning)
+        pytest.skip(msg)
 
 
 @pytest.mark.parametrize('model_name', get_available_detection_models())
-@pytest.mark.parametrize('dev', _devs)
+@pytest.mark.parametrize('dev', cpu_and_gpu())
 def test_detection_model(model_name, dev):
-    ModelTester()._test_detection_model(model_name, dev)
+    set_rng_seed(0)
+    defaults = {
+        'num_classes': 50,
+        'pretrained_backbone': False,
+        'input_shape': (3, 300, 300),
+    }
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    input_shape = kwargs.pop('input_shape')
+
+    model = models.detection.__dict__[model_name](**kwargs)
+    model.eval().to(device=dev)
+    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+    x = torch.rand(input_shape).to(device=dev)
+    model_input = [x]
+    out = model(model_input)
+    assert model_input[0] is x
+
+    def check_out(out):
+        assert len(out) == 1
+
+        def compact(tensor):
+            size = tensor.size()
+            elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
+            if elements_per_sample > 30:
+                return compute_mean_std(tensor)
+            else:
+                return subsample_tensor(tensor)
+
+        def subsample_tensor(tensor):
+            num_elems = tensor.size(0)
+            num_samples = 20
+            if num_elems <= num_samples:
+                return tensor
+
+            ith_index = num_elems // num_samples
+            return tensor[ith_index - 1::ith_index]
+
+        def compute_mean_std(tensor):
+            # can't compute mean of integral tensor
+            tensor = tensor.to(torch.double)
+            mean = torch.mean(tensor)
+            std = torch.std(tensor)
+            return {"mean": mean, "std": std}
+
+        output = map_nested_tensor_object(out, tensor_map_fn=compact)
+        prec = 0.01
+        try:
+            # We first try to assert the entire output if possible. This is not
+            # only the best way to assert results but also handles the cases
+            # where we need to create a new expected result.
+            _assert_expected(output, model_name, prec=prec)
+        except AssertionError:
+            # Unfortunately detection models are flaky due to the unstable sort
+            # in NMS. If matching across all outputs fails, use the same approach
+            # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
+            # scores.
+            expected_file = _get_expected_file(model_name)
+            expected = torch.load(expected_file)
+            torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec,
+                                       check_device=False, check_dtype=False)
+
+            # Note: Fmassa proposed turning off NMS by adapting the threshold
+            # and then using the Hungarian algorithm as in DETR to find the
+            # best match between output and expected boxes and eliminate some
+            # of the flakiness. Worth exploring.
+            return False  # Partial validation performed
+
+        return True  # Full validation performed
+
+    full_validation = check_out(out)
+    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None))
+
+    if dev == torch.device("cuda"):
+        with torch.cuda.amp.autocast():
+            out = model(model_input)
+            # See autocast_flaky_numerics comment at top of file.
+            if model_name not in autocast_flaky_numerics:
+                full_validation &= check_out(out)
+
+    if not full_validation:
+        msg = "The output of {} could only be partially validated. " \
+              "This is likely due to unit-test flakiness, but you may " \
+              "want to do additional manual checks if you made " \
+              "significant changes to the codebase.".format(test_detection_model.__name__)
+        warnings.warn(msg, RuntimeWarning)
+        pytest.skip(msg)
 
 
 @pytest.mark.parametrize('model_name', get_available_detection_models())
 def test_detection_model_validation(model_name):
-    ModelTester()._test_detection_model_validation(model_name)
+    set_rng_seed(0)
+    model = models.detection.__dict__[model_name](num_classes=50, pretrained_backbone=False)
+    input_shape = (3, 300, 300)
+    x = [torch.rand(input_shape)]
+
+    # validate that targets are present in training
+    pytest.raises(ValueError, model, x)
+
+    # validate type
+    targets = [{'boxes': 0.}]
+    pytest.raises(ValueError, model, x, targets=targets)
+
+    # validate boxes shape
+    for boxes in (torch.rand((4,)), torch.rand((1, 5))):
+        targets = [{'boxes': boxes}]
+        pytest.raises(ValueError, model, x, targets=targets)
+
+    # validate that no degenerate boxes are present
+    boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
+    targets = [{'boxes': boxes}]
+    pytest.raises(ValueError, model, x, targets=targets)
 
 
 @pytest.mark.parametrize('model_name', get_available_video_models())
-@pytest.mark.parametrize('dev', _devs)
+@pytest.mark.parametrize('dev', cpu_and_gpu())
 def test_video_model(model_name, dev):
-    ModelTester()._test_video_model(model_name, dev)
+    # the default input shape is
+    # bs * num_channels * clip_len * h *w
+    input_shape = (1, 3, 4, 112, 112)
+    # test both basicblock and Bottleneck
+    model = models.video.__dict__[model_name](num_classes=50)
+    model.eval().to(device=dev)
+    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+    x = torch.rand(input_shape).to(device=dev)
+    out = model(x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
+    assert out.shape[-1] == 50
+
+    if dev == torch.device("cuda"):
+        with torch.cuda.amp.autocast():
+            out = model(x)
+            assert out.shape[-1] == 50
 
 
 if __name__ == '__main__':

From 5acb52655555ead8cbf66f98a4f741d75605f571 Mon Sep 17 00:00:00 2001
From: Anirudh Dagar
Date: Sun, 6 Jun 2021 23:51:56 +0530
Subject: [PATCH 2/3] parametrize test_resnet_dilation, test_mobilenet_norm_layer & add @cpu_only

---
 test/test_models.py | 92 ++++++++++++++++++++++++++-------------------
 1 file changed, 54 insertions(+), 38 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index 2a4e6279acc..c9c6e1f2e5d 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,7 +1,8 @@
 import os
 import io
 import sys
-from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
+from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
+from common_utils import cpu_and_gpu, needs_cuda, cpu_only
 from _utils_internal import get_relative_path
 from collections import OrderedDict
 from itertools import product
@@ -228,40 +229,45 @@ def _make_sliced_model(model, stop_layer):
     return new_model
 
 
-def test_memory_efficient_densenet():
+@cpu_only
+@pytest.mark.parametrize('model_name', ['densenet121', 'densenet169', 'densenet201', 'densenet161'])
+def test_memory_efficient_densenet(model_name):
     input_shape = (1, 3, 300, 300)
     x = torch.rand(input_shape)
 
-    for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
-        model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
-        params = model1.state_dict()
-        num_params = sum([x.numel() for x in model1.parameters()])
-        model1.eval()
-        out1 = model1(x)
-        out1.sum().backward()
-        num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None])
+    model1 = models.__dict__[model_name](num_classes=50, memory_efficient=True)
+    params = model1.state_dict()
+    num_params = sum([x.numel() for x in model1.parameters()])
+    model1.eval()
+    out1 = model1(x)
+    out1.sum().backward()
+    num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None])
 
-        model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
-        model2.load_state_dict(params)
-        model2.eval()
-        out2 = model2(x)
+    model2 = models.__dict__[model_name](num_classes=50, memory_efficient=False)
+    model2.load_state_dict(params)
+    model2.eval()
+    out2 = model2(x)
 
-        assert num_params == num_grad
-        torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
+    assert num_params == num_grad
+    torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
 
 
-def test_resnet_dilation():
+@cpu_only
+@pytest.mark.parametrize('dilate_layer_2', (True, False))
+@pytest.mark.parametrize('dilate_layer_3', (True, False))
+@pytest.mark.parametrize('dilate_layer_4', (True, False))
+def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
     # TODO improve tests to also check that each layer has the right dimensionality
-    for i in product([False, True], [False, True], [False, True]):
-        model = models.__dict__["resnet50"](replace_stride_with_dilation=i)
-        model = _make_sliced_model(model, stop_layer="layer4")
-        model.eval()
-        x = torch.rand(1, 3, 224, 224)
-        out = model(x)
-        f = 2 ** sum(i)
-        assert out.shape == (1, 2048, 7 * f, 7 * f)
+    model = models.__dict__["resnet50"](replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4))
+    model = _make_sliced_model(model, stop_layer="layer4")
+    model.eval()
+    x = torch.rand(1, 3, 224, 224)
+    out = model(x)
+    f = 2 ** sum((dilate_layer_2, dilate_layer_3, dilate_layer_4))
+    assert out.shape == (1, 2048, 7 * f, 7 * f)
 
 
+@cpu_only
 def test_mobilenet_v2_residual_setting():
     model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
     model.eval()
@@ -270,19 +276,21 @@ def test_mobilenet_v2_residual_setting():
     assert out.shape[-1] == 1000
 
 
-def test_mobilenet_norm_layer():
-    for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
-        model = models.__dict__[name]()
-        assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
+@cpu_only
+@pytest.mark.parametrize('model_name', ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"])
+def test_mobilenet_norm_layer(model_name):
+    model = models.__dict__[model_name]()
+    assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
 
-        def get_gn(num_channels):
-            return nn.GroupNorm(32, num_channels)
+    def get_gn(num_channels):
+        return nn.GroupNorm(32, num_channels)
 
-        model = models.__dict__[name](norm_layer=get_gn)
-        assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-        assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
+    model = models.__dict__[model_name](norm_layer=get_gn)
+    assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+    assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
 
 
+@cpu_only
 def test_inception_v3_eval():
     # replacement for models.inception_v3(pretrained=True) that does not download weights
     kwargs = {}
@@ -298,6 +306,7 @@ def test_inception_v3_eval():
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
 
 
+@cpu_only
 def test_fasterrcnn_double():
     model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
     model.double()
@@ -313,6 +322,7 @@ def test_fasterrcnn_double():
     assert "labels" in out[0]
 
 
+@cpu_only
 def test_googlenet_eval():
     # replacement for models.googlenet(pretrained=True) that does not download weights
     kwargs = {}
@@ -361,6 +371,7 @@ def checkOut(out):
     checkOut(out_cpu)
 
 
+@cpu_only
 def test_generalizedrcnn_transform_repr():
 
     min_size, max_size = 224, 299
@@ -557,6 +568,7 @@ def compute_mean_std(tensor):
     pytest.skip(msg)
 
 
+@cpu_only
 @pytest.mark.parametrize('model_name', get_available_detection_models())
 def test_detection_model_validation(model_name):
     set_rng_seed(0)
@@ -565,21 +577,25 @@ def test_detection_model_validation(model_name):
     x = [torch.rand(input_shape)]
 
     # validate that targets are present in training
-    pytest.raises(ValueError, model, x)
+    with pytest.raises(ValueError):
+        model(x)
 
     # validate type
     targets = [{'boxes': 0.}]
-    pytest.raises(ValueError, model, x, targets=targets)
+    with pytest.raises(ValueError):
+        model(x, targets=targets)
 
     # validate boxes shape
     for boxes in (torch.rand((4,)), torch.rand((1, 5))):
         targets = [{'boxes': boxes}]
-        pytest.raises(ValueError, model, x, targets=targets)
+        with pytest.raises(ValueError):
+            model(x, targets=targets)
 
     # validate that no degenerate boxes are present
     boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
     targets = [{'boxes': boxes}]
-    pytest.raises(ValueError, model, x, targets=targets)
+    with pytest.raises(ValueError):
+        model(x, targets=targets)
 
 
 @pytest.mark.parametrize('model_name', get_available_video_models())

From 36250c278814642d5f4db6b413a0eda47510b1e2 Mon Sep 17 00:00:00 2001
From: Anirudh Dagar
Date: Sun, 6 Jun 2021 23:58:10 +0530
Subject: [PATCH 3/3] remove unused imports

---
 test/test_models.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index c9c6e1f2e5d..4f021d323b2 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,11 +1,9 @@
 import os
 import io
 import sys
-from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
-from common_utils import cpu_and_gpu, needs_cuda, cpu_only
+from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda, cpu_only
 from _utils_internal import get_relative_path
 from collections import OrderedDict
-from itertools import product
 import functools
 import operator
 import torch
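
Editor's note (not part of the patch series): the refactor assumes that test/common_utils.py already provides the cpu_and_gpu, needs_cuda, and cpu_only helpers it imports; their implementation is not shown in these patches. As a rough illustration only, such helpers could look like the sketch below -- the names mirror the imports above, but everything else is an assumption rather than torchvision's actual code.

    # common_utils.py (hypothetical sketch -- NOT the real torchvision helpers)
    import pytest
    import torch

    # Skip a test outright on machines without a CUDA device.
    needs_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason='needs GPU')

    # Plain marker so CPU-only tests can be deselected on GPU CI jobs,
    # e.g. `pytest -m "not cpu_only"`. The marker name is an assumption.
    cpu_only = pytest.mark.cpu_only


    def cpu_and_gpu():
        # Device list for @pytest.mark.parametrize('dev', cpu_and_gpu()).
        # The CUDA entry carries the needs_cuda mark, so that case is
        # skipped (not failed) when no GPU is present.
        return (torch.device('cpu'),
                pytest.param(torch.device('cuda'), marks=needs_cuda))

Returning torch.device objects rather than strings would keep comparisons such as `dev == torch.device("cuda")` in the refactored tests working unchanged, and pytest.param(..., marks=...) attaches the skip condition to just the CUDA case of each parametrized test.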