From 49e8bf754c31fc6bb5e1bd002cc26d216ef6f902 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 7 Oct 2019 18:31:00 -0700 Subject: [PATCH 01/17] Add tests for results in script vs eager mode This copies some logic from `test_jit.py` to check that a TorchScript'ed model's outputs are the same as outputs from the model in eager mode. To support differences in TorchScript / eager mode outputs, an `unwrapper` function can be provided per-model. --- test/common_utils.py | 186 ++++++++++++++++++++++ test/test_models.py | 66 ++++---- torchvision/models/segmentation/_utils.py | 1 + 3 files changed, 224 insertions(+), 29 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 9c0c3175ef1..a4a2e62dd47 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -5,10 +5,15 @@ import unittest import argparse import sys +import io import torch import errno import __main__ +from numbers import Number +from torch._six import string_classes, inf +from collections import OrderedDict + @contextlib.contextmanager def get_tmp_dir(src=None, **kwargs): @@ -64,9 +69,19 @@ def map_nested_tensor_object(object, tensor_map_fn): return impl(object) +def is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + # adapted from TestCase in torch/test/common_utils to accept non-string # inputs and set maximum binary size class TestCase(unittest.TestCase): + precision = 1e-5 + def assertExpected(self, output, subname=None, rtol=None, atol=None): r""" Test that a python value matches the recorded contents of a file @@ -131,6 +146,121 @@ def accept_output(update_type): else: self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol) + def assertEqual(self, x, y, prec=None, message='', allow_inf=False): + """ + This is copied from pytorch/test/common_utils.py's TestCase.assertEqual + """ + if isinstance(prec, str) and message == '': + message = prec + prec = None + if prec is None: + prec = self.precision + + if isinstance(x, torch.Tensor) and isinstance(y, Number): + self.assertEqual(x.item(), y, prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(y, torch.Tensor) and isinstance(x, Number): + self.assertEqual(x, y.item(), prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + def assertTensorsEqual(a, b): + super(TestCase, self).assertEqual(a.size(), b.size(), message) + if a.numel() > 0: + if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)): + # CPU half and bfloat16 tensors don't have the methods we need below + a = a.to(torch.float32) + b = b.to(a) + + if (a.dtype == torch.bool) != (b.dtype == torch.bool): + raise TypeError("Was expecting both tensors to be bool type.") + else: + if a.dtype == torch.bool and b.dtype == torch.bool: + # we want to respect precision but as bool doesn't support substraction, + # boolean tensor has to be converted to int + a = a.to(torch.int) + b = b.to(torch.int) + + diff = a - b + if a.is_floating_point(): + # check that NaNs are in the same locations + nan_mask = torch.isnan(a) + self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message) + diff[nan_mask] = 0 + # inf check if allow_inf=True + if allow_inf: + inf_mask = torch.isinf(a) + inf_sign = inf_mask.sign() + self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message) + diff[inf_mask] = 0 + # TODO: implement abs on CharTensor (int8) + if diff.is_signed() and diff.dtype != torch.int8: + diff = diff.abs() + max_err = diff.max() + self.assertLessEqual(max_err, prec, message) + super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message) + super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message) + if x.is_sparse: + x = self.safeCoalesce(x) + y = self.safeCoalesce(y) + assertTensorsEqual(x._indices(), y._indices()) + assertTensorsEqual(x._values(), y._values()) + elif x.is_quantized and y.is_quantized: + self.assertEqual(x.qscheme(), y.qscheme(), prec=prec, + message=message, allow_inf=allow_inf) + if x.qscheme() == torch.per_tensor_affine: + self.assertEqual(x.q_scale(), y.q_scale(), prec=prec, + message=message, allow_inf=allow_inf) + self.assertEqual(x.q_zero_point(), y.q_zero_point(), + prec=prec, message=message, + allow_inf=allow_inf) + elif x.qscheme() == torch.per_channel_affine: + self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec, + message=message, allow_inf=allow_inf) + self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(), + prec=prec, message=message, + allow_inf=allow_inf) + self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(), + prec=prec, message=message) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.int_repr().to(torch.int32), + y.int_repr().to(torch.int32), prec=prec, + message=message, allow_inf=allow_inf) + else: + assertTensorsEqual(x, y) + elif isinstance(x, string_classes) and isinstance(y, string_classes): + super(TestCase, self).assertEqual(x, y, message) + elif type(x) == set and type(y) == set: + super(TestCase, self).assertEqual(x, y, message) + elif isinstance(x, dict) and isinstance(y, dict): + if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): + self.assertEqual(x.items(), y.items(), prec=prec, + message=message, allow_inf=allow_inf) + else: + self.assertEqual(set(x.keys()), set(y.keys()), prec=prec, + message=message, allow_inf=allow_inf) + key_list = list(x.keys()) + self.assertEqual([x[k] for k in key_list], + [y[k] for k in key_list], + prec=prec, message=message, + allow_inf=allow_inf) + elif is_iterable(x) and is_iterable(y): + super(TestCase, self).assertEqual(len(x), len(y), message) + for x_, y_ in zip(x, y): + self.assertEqual(x_, y_, prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(x, bool) and isinstance(y, bool): + super(TestCase, self).assertEqual(x, y, message) + elif isinstance(x, Number) and isinstance(y, Number): + if abs(x) == inf or abs(y) == inf: + if allow_inf: + super(TestCase, self).assertEqual(x, y, message) + else: + self.fail("Expected finite numeric values - x={}, y={}".format(x, y)) + return + super(TestCase, self).assertLessEqual(abs(x - y), prec, message) + else: + super(TestCase, self).assertEqual(x, y, message) + def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None): self.assertEqual(type(a), type(b)) @@ -151,3 +281,59 @@ def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None): else: self.assertEqual(a, b) + + def checkModule(self, nn_module, args, unwrapper=None, skip=False): + """ + Check that a nn.Module's results in TorchScript match eager and that it + can be exported + """ + if not torch.jit._enabled or skip: + # TorchScript is not enabled, skip these tests + return + + sm = torch.jit.script(nn_module) + + with freeze_rng_state(): + eager_out = nn_module(*args) + + with freeze_rng_state(): + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) + + self.assertEqual(eager_out, script_out) + self.assertExportImportModule(sm, args) + + return sm + + def getExportImportCopy(self, m): + """ + Save and load a TorchScript model + """ + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + imported = torch.jit.load(buffer) + return imported + + def assertExportImportModule(self, m, args): + """ + Check that the results of a model are the same after saving and loading + """ + m_import = self.getExportImportCopy(m) + with freeze_rng_state(): + results = m(*args) + with freeze_rng_state(): + results_from_imported = m_import(*args) + self.assertEqual(results, results_from_imported) + + +@contextlib.contextmanager +def freeze_rng_state(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + yield + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) diff --git a/test/test_models.py b/test/test_models.py index 1864d233772..3c48b269c4f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,4 +1,4 @@ -from common_utils import TestCase, map_nested_tensor_object +from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state from collections import OrderedDict from itertools import product import torch @@ -38,44 +38,51 @@ def get_available_video_models(): # models that are in torch hub, as well as r3d_18. we tried testing all models # but the test was too slow. not included are detection models, because # they are not yet supported in JIT. -script_test_models = [ - "deeplabv3_resnet101", - "mobilenet_v2", - "resnext50_32x4d", - "fcn_resnet101", - "googlenet", - "densenet121", - "resnet18", - "alexnet", - "shufflenet_v2_x1_0", - "squeezenet1_0", - "vgg11", - "inception_v3", - 'r3d_18', +# If 'unwrapper' is provided it will be called with the script model outputs +# before they are compared to the eager model outputs. This is useful if the +# model outputs are different between TorchScript / Eager mode +script_test_models = { + "deeplabv3_resnet101": {}, + "mobilenet_v2": {}, + "resnext50_32x4d": {}, + "fcn_resnet101": {}, + "googlenet": { + "unwrapper": lambda x: x.logits + }, + "densenet121": {}, + "resnet18": {}, + "alexnet": {}, + "shufflenet_v2_x1_0": {}, + "squeezenet1_0": {}, + "vgg11": {}, + "inception_v3": {}, + 'r3d_18': {}, +} + + +# These models don't work with checkModule, this list should be deleted as soon +# as possible +SCRIPT_MODELS_TO_FIX = [ + 'test_inception_v3', + 'test_fcn_resnet101', + 'test_deeplabv3_resnet101', ] class ModelTester(TestCase): - def check_script(self, model, name): + def checkModule(self, model, name, args): if name not in script_test_models: return - scriptable = True - msg = "" - try: - torch.jit.script(model) - except Exception as e: - tb = traceback.format_exc() - scriptable = False - msg = str(e) + str(tb) - self.assertTrue(scriptable, msg) + unwrapper = script_test_models[name].get('unwrapper', None) + return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=name in SCRIPT_MODELS_TO_FIX) def _test_classification_model(self, name, input_shape): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature model = models.__dict__[name](num_classes=50) - self.check_script(model, name) model.eval() x = torch.rand(input_shape) + self.checkModule(model, name, (x,)) out = model(x) self.assertEqual(out.shape[-1], 50) @@ -83,20 +90,20 @@ def _test_segmentation_model(self, name): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False) - self.check_script(model, name) model.eval() input_shape = (1, 3, 300, 300) x = torch.rand(input_shape) + self.checkModule(model, name, (x,)) out = model(x) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) def _test_detection_model(self, name): set_rng_seed(0) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) - self.check_script(model, name) model.eval() input_shape = (3, 300, 300) x = torch.rand(input_shape) + self.checkModule(model, name, (x,)) model_input = [x] out = model(model_input) self.assertIs(model_input[0], x) @@ -138,9 +145,10 @@ def _test_video_model(self, name): input_shape = (1, 3, 4, 112, 112) # test both basicblock and Bottleneck model = models.video.__dict__[name](num_classes=50) - self.check_script(model, name) + model.eval() x = torch.rand(input_shape) out = model(x) + self.checkModule(model, name, (x,)) self.assertEqual(out.shape[-1], 50) def _make_sliced_model(self, model, stop_layer): diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index c5a7ae99e43..ffd655b016e 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -10,6 +10,7 @@ class _SimpleSegmentationModel(nn.Module): def __init__(self, backbone, classifier, aux_classifier=None): super(_SimpleSegmentationModel, self).__init__() + print('ssm', type(backbone), type(classifier), type(aux_classifier)) self.backbone = backbone self.classifier = classifier self.aux_classifier = aux_classifier From ccfd00b7e8081affdcf2369de9cc34c1386c2587 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 10:22:55 -0700 Subject: [PATCH 02/17] Fix inception, use PYTORCH_TEST_WITH_SLOW --- test/common_utils.py | 4 +++- test/test_models.py | 5 +++-- torchvision/models/segmentation/_utils.py | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a4a2e62dd47..2110c38dd4d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -28,6 +28,8 @@ def get_tmp_dir(src=None, **kwargs): ACCEPT = os.getenv('EXPECTTEST_ACCEPT') +TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' + parser = argparse.ArgumentParser(add_help=False) parser.add_argument('--accept', action='store_true') @@ -287,7 +289,7 @@ def checkModule(self, nn_module, args, unwrapper=None, skip=False): Check that a nn.Module's results in TorchScript match eager and that it can be exported """ - if not torch.jit._enabled or skip: + if not TEST_WITH_SLOW or skip: # TorchScript is not enabled, skip these tests return diff --git a/test/test_models.py b/test/test_models.py index 3c48b269c4f..cf39f9ab890 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -55,7 +55,9 @@ def get_available_video_models(): "shufflenet_v2_x1_0": {}, "squeezenet1_0": {}, "vgg11": {}, - "inception_v3": {}, + "inception_v3": { + "unwrapper": lambda x: x.logits + }, 'r3d_18': {}, } @@ -63,7 +65,6 @@ def get_available_video_models(): # These models don't work with checkModule, this list should be deleted as soon # as possible SCRIPT_MODELS_TO_FIX = [ - 'test_inception_v3', 'test_fcn_resnet101', 'test_deeplabv3_resnet101', ] diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index ffd655b016e..c5a7ae99e43 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -10,7 +10,6 @@ class _SimpleSegmentationModel(nn.Module): def __init__(self, backbone, classifier, aux_classifier=None): super(_SimpleSegmentationModel, self).__init__() - print('ssm', type(backbone), type(classifier), type(aux_classifier)) self.backbone = backbone self.classifier = classifier self.aux_classifier = aux_classifier From 9299fe211cbad8ce0907eca307eb05d1e826bd4d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 10:44:39 -0700 Subject: [PATCH 03/17] Update --- test/test_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index cf39f9ab890..63852026496 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -62,10 +62,10 @@ def get_available_video_models(): } -# These models don't work with checkModule, this list should be deleted as soon -# as possible SCRIPT_MODELS_TO_FIX = [ - 'test_fcn_resnet101', + # This model fails in the TorchScript interpreter, see + # https://github.com/pytorch/pytorch/issues/27549. Delete this list when + # that issue is closed. 'test_deeplabv3_resnet101', ] From fed5ce59865378085d54d811bb110a29274ac677 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 10:48:03 -0700 Subject: [PATCH 04/17] Remove assertNestedTensorObjectsEqual --- test/common_utils.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 2110c38dd4d..a00c3691947 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -140,13 +140,13 @@ def accept_output(update_type): if ACCEPT: equal = False try: - equal = self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol) + equal = self.assertEqual(output, expected, rtol=rtol, atol=atol) except Exception: equal = False if not equal: return accept_output("updated output") else: - self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol) + self.assertEqual(output, expected, rtol=rtol, atol=atol) def assertEqual(self, x, y, prec=None, message='', allow_inf=False): """ @@ -263,27 +263,6 @@ def assertTensorsEqual(a, b): else: super(TestCase, self).assertEqual(x, y, message) - def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None): - self.assertEqual(type(a), type(b)) - - if isinstance(a, torch.Tensor): - torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - - elif isinstance(a, dict): - self.assertEqual(len(a), len(b)) - for key, value in a.items(): - self.assertTrue(key in b, "key: " + str(key)) - - self.assertNestedTensorObjectsEqual(value, b[key], rtol=rtol, atol=atol) - elif isinstance(a, (list, tuple)): - self.assertEqual(len(a), len(b)) - - for val1, val2 in zip(a, b): - self.assertNestedTensorObjectsEqual(val1, val2, rtol=rtol, atol=atol) - - else: - self.assertEqual(a, b) - def checkModule(self, nn_module, args, unwrapper=None, skip=False): """ Check that a nn.Module's results in TorchScript match eager and that it From c8425579e5d94a9d2a3f9d8eb3d2da1c3b4e6ce2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 11:09:54 -0700 Subject: [PATCH 05/17] Add PYTORCH_TEST_WITH_SLOW to CircleCI config --- .circleci/config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 8a09c45e507..77d98f6bb63 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -164,7 +164,8 @@ jobs: cd ${HOME}/project/ export DOCKER_IMAGE=soumith/conda-cuda - export VARS_TO_PASS="-e PYTHON_VERSION -e BUILD_VERSION -e PYTORCH_VERSION -e UNICODE_ABI -e CU_VERSION" + export PYTORCH_TEST_WITH_SLOW=1 + export VARS_TO_PASS="-e PYTHON_VERSION -e BUILD_VERSION -e PYTORCH_VERSION -e UNICODE_ABI -e CU_VERSION -e PYTORCH_TEST_WITH_SLOW" docker run --gpus all --ipc=host -v $(pwd):/remote -w /remote ${VARS_TO_PASS} ${DOCKER_IMAGE} ./packaging/build_conda.sh From d4ca330732ca56c74f0a9b68434a7150d3eb571f Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 13:10:01 -0700 Subject: [PATCH 06/17] Add MaskRCNN unwrapper --- test/test_models.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 63852026496..422ebc6162b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -42,23 +42,32 @@ def get_available_video_models(): # before they are compared to the eager model outputs. This is useful if the # model outputs are different between TorchScript / Eager mode script_test_models = { - "deeplabv3_resnet101": {}, - "mobilenet_v2": {}, - "resnext50_32x4d": {}, - "fcn_resnet101": {}, - "googlenet": { - "unwrapper": lambda x: x.logits + 'deeplabv3_resnet101': {}, + 'mobilenet_v2': {}, + 'resnext50_32x4d': {}, + 'fcn_resnet101': {}, + 'googlenet': { + 'unwrapper': lambda x: x.logits }, - "densenet121": {}, - "resnet18": {}, - "alexnet": {}, - "shufflenet_v2_x1_0": {}, - "squeezenet1_0": {}, - "vgg11": {}, - "inception_v3": { - "unwrapper": lambda x: x.logits + 'densenet121': {}, + 'resnet18': {}, + 'alexnet': {}, + 'shufflenet_v2_x1_0': {}, + 'squeezenet1_0': {}, + 'vgg11': {}, + 'inception_v3': { + 'unwrapper': lambda x: x.logits }, 'r3d_18': {}, + 'fasterrcnn_resnet50_fpn': { + 'unwrapper': lambda x: x[1] + }, + 'maskrcnn_resnet50_fpn': { + 'unwrapper': lambda x: x[1] + }, + 'keypointrcnn_resnet50_fpn': { + 'unwrapper': lambda x: x[1] + }, } From 557a15c2eb043c0f5e6a9eff7ebb7c3b9183351e Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 8 Oct 2019 15:47:09 -0700 Subject: [PATCH 07/17] fix prec args --- test/common_utils.py | 6 +++--- test/test_models.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a00c3691947..d378aa2d285 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -84,7 +84,7 @@ def is_iterable(obj): class TestCase(unittest.TestCase): precision = 1e-5 - def assertExpected(self, output, subname=None, rtol=None, atol=None): + def assertExpected(self, output, subname=None, prec=None): r""" Test that a python value matches the recorded contents of a file derived from the name of this test and subname. The value must be @@ -140,13 +140,13 @@ def accept_output(update_type): if ACCEPT: equal = False try: - equal = self.assertEqual(output, expected, rtol=rtol, atol=atol) + equal = self.assertEqual(output, expected, prec=prec) except Exception: equal = False if not equal: return accept_output("updated output") else: - self.assertEqual(output, expected, rtol=rtol, atol=atol) + self.assertEqual(output, expected, prec=prec) def assertEqual(self, x, y, prec=None, message='', allow_inf=False): """ diff --git a/test/test_models.py b/test/test_models.py index 422ebc6162b..55bcfe50501 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -140,8 +140,8 @@ def compute_mean_std(tensor): # compare results with mean and std if name == "maskrcnn_resnet50_fpn": test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std) - # mean values are small, use large rtol - self.assertExpected(test_value, rtol=.01, atol=.01) + # mean values are small, use large prec + self.assertExpected(test_value, prec=.01) else: self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor)) From b3471f9c6bb62bb679af30ac90bc03fc517efe0a Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 15 Oct 2019 13:13:26 -0700 Subject: [PATCH 08/17] Remove CI changes --- .circleci/config.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 77d98f6bb63..8a09c45e507 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -164,8 +164,7 @@ jobs: cd ${HOME}/project/ export DOCKER_IMAGE=soumith/conda-cuda - export PYTORCH_TEST_WITH_SLOW=1 - export VARS_TO_PASS="-e PYTHON_VERSION -e BUILD_VERSION -e PYTORCH_VERSION -e UNICODE_ABI -e CU_VERSION -e PYTORCH_TEST_WITH_SLOW" + export VARS_TO_PASS="-e PYTHON_VERSION -e BUILD_VERSION -e PYTORCH_VERSION -e UNICODE_ABI -e CU_VERSION" docker run --gpus all --ipc=host -v $(pwd):/remote -w /remote ${VARS_TO_PASS} ${DOCKER_IMAGE} ./packaging/build_conda.sh From 7cc28974cffa6b8b5cad2b87aeff15202c2ade76 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 23 Oct 2019 13:16:18 -0700 Subject: [PATCH 09/17] update --- test/common_utils.py | 8 +++++++- test/test_models.py | 22 +++++++--------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index d378aa2d285..ab96be02061 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -29,6 +29,7 @@ def get_tmp_dir(src=None, **kwargs): ACCEPT = os.getenv('EXPECTTEST_ACCEPT') TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' +TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job parser = argparse.ArgumentParser(add_help=False) @@ -310,7 +311,12 @@ def assertExportImportModule(self, m, args): @contextlib.contextmanager -def freeze_rng_state(): +def freeze_rng_state(seed=None): + if seed: + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + rng_state = torch.get_rng_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() diff --git a/test/test_models.py b/test/test_models.py index a3ecc4fdaa7..99d3f91e771 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -59,22 +59,13 @@ def get_available_video_models(): 'unwrapper': lambda x: x.logits }, 'r3d_18': {}, - 'fasterrcnn_resnet50_fpn': { - 'unwrapper': lambda x: x[1] - }, - 'maskrcnn_resnet50_fpn': { - 'unwrapper': lambda x: x[1] - }, - 'keypointrcnn_resnet50_fpn': { - 'unwrapper': lambda x: x[1] - }, } SCRIPT_MODELS_TO_FIX = [ # This model fails in the TorchScript interpreter, see - # https://github.com/pytorch/pytorch/issues/27549. Delete this list when - # that issue is closed. + # https://github.com/pytorch/vision/pull/1436. Delete this list when + # that PR is closed. 'test_deeplabv3_resnet101', ] @@ -93,10 +84,11 @@ def _test_classification_model(self, name, input_shape): model = models.__dict__[name](num_classes=50) model.eval() x = torch.rand(input_shape) - self.checkModule(model, name, (x,)) out = model(x) - self.assertExpected(out, rtol=1e-2, atol=0.) + self.assertExpected(out) + # self.assertExpected(out, prec=1e-2) self.assertEqual(out.shape[-1], 50) + self.checkModule(model, name, (x,)) def _test_segmentation_model(self, name): # passing num_class equal to a number other than 1000 helps in making the test @@ -105,9 +97,9 @@ def _test_segmentation_model(self, name): model.eval() input_shape = (1, 3, 300, 300) x = torch.rand(input_shape) - self.checkModule(model, name, (x,)) out = model(x) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) + self.checkModule(model, name, (x,)) def _test_detection_model(self, name): set_rng_seed(0) @@ -115,11 +107,11 @@ def _test_detection_model(self, name): model.eval() input_shape = (3, 300, 300) x = torch.rand(input_shape) - self.checkModule(model, name, (x,)) model_input = [x] out = model(model_input) self.assertIs(model_input[0], x) self.assertEqual(len(out), 1) + self.checkModule(model, name, (x,)) def subsample_tensor(tensor): num_elems = tensor.numel() From 8922bef6b129c722a20ab887fe0ceee19424ad4b Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 23 Oct 2019 13:46:30 -0700 Subject: [PATCH 10/17] Update --- test/common_utils.py | 7 +------ .../ModelTester.test_inception_v3_expect.pkl | Bin 543 -> 541 bytes .../ModelTester.test_resnet152_expect.pkl | Bin 543 -> 541 bytes test/test_models.py | 9 ++++----- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index ab96be02061..1a7ba925385 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -311,12 +311,7 @@ def assertExportImportModule(self, m, args): @contextlib.contextmanager -def freeze_rng_state(seed=None): - if seed: - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - +def freeze_rng_state(): rng_state = torch.get_rng_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() diff --git a/test/expect/ModelTester.test_inception_v3_expect.pkl b/test/expect/ModelTester.test_inception_v3_expect.pkl index 38574f498a5db5800efe91f4d77d4e2a78179f14..f325937f64d5f97e6bf098a89e27add990b47775 100644 GIT binary patch delta 260 zcmbQwGM8n-RUtkG1_n!G3nNnlV&=4R*4 zBr~5AUb^X=X!>+t->2^9_Gq`Cec*V``|v?FzhK^j=Pc^K_{7^d`o%tLK5Mn*_PM4B ztmiDlC;Boc#QClEzvs(2`;GVfS08+f47U3E@lHL<>#E}ScJZGxEa@rN&Ik!UI_q#g z@SKX^!Lw%1W}cHy={Q#qGxMCsc8)Wlp90S{%z1g%%(~aNqW-O~@4cI64@5mWdpu&l z?>8rxbAObT{1WW=&b_pK>&w49$?tKlfuF>N#B;$62Ygh{UOxL}i`>~jX+b{#r2UTBV!A5qlwQ0>ahqHGA8O70ab&*f{pg4ea+3z zok(UrC$MzWIkEKVzCKUg&+XD~KfA*5ocF}X>VMDIV9p!w&Q~9Niww5^>`SzGH~-`4uKz9IK+o_!Sc=g>D@8$a^Wc-8MUW&W#^+0XdBI=)t&_PuRz%5~+JQ&RsHdD+zeKJ`$^-h0CG zji-(WdwZx%nCaQLD(dvT9ji}Iaol(MuI#;&t!t-vwLj$cUd6=X&69pc?R4fx_S1|U z%TH&VTzKl5t<&55OR_v$0G%K8ryYk$NFVMtWvw5O-%*#A)lW*rwRW~x7mQs;8 zvwz!^Q(v#LduK+Ko<8!m>9maGW^cWzpHEJ{&FQW1gzvPGWsUdw`7WpPS`T{x0IDr< AdH?_b delta 261 zcmbQsGM{C_RUv)`1_nbD15-;&6LT|5LnA}WiO&M-F-46E858x4fQmstDowmAdaLQ@rms4v07I|6J|339h$=-YV z@{OkgL%cnbCd~A#TNQQs){fPur#S9Ay+`)m$3mA+H_h Date: Wed, 23 Oct 2019 14:04:33 -0700 Subject: [PATCH 11/17] remove expect changes --- .../ModelTester.test_inception_v3_expect.pkl | Bin 541 -> 543 bytes .../ModelTester.test_resnet152_expect.pkl | Bin 541 -> 543 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/test/expect/ModelTester.test_inception_v3_expect.pkl b/test/expect/ModelTester.test_inception_v3_expect.pkl index f325937f64d5f97e6bf098a89e27add990b47775..38574f498a5db5800efe91f4d77d4e2a78179f14 100644 GIT binary patch delta 262 zcmbQsGM{C_RbhSx1_nbD12Y37V>2UTBV!A5qlwQ0>ahqHGA8O70ab&*f{pg4ea+3z zok(UrC$MzWIkEKVzCKUg&+XD~KfA*5ocF}X>VMDIV9p!w&Q~9Niww5^>`SzGH~-`4uKz9IK+o_!Sc=&=4R*4 zBr~5AUb^X=X!>+t->2^9_Gq`Cec*V``|v?FzhK^j=Pc^K_{7^d`o%tLK5Mn*_PM4B ztmiDlC;Boc#QClEzvs(2`;GVfS08+f47U3E@lHL<>#E}ScJZGxEa@rN&Ik!UI_q#g z@SKX^!Lw%1W}cHy={Q#qGxMCsc8)Wlp90S{%z1g%%(~aNqW-O~@4cI64@5mWdpu&l z?>8rxbAObT{1WW=&b_pK>&w49$?tKlfuF>N#B;$62Ygh{UOxL}i`>~jX+b{#rDowmAdaLQ@rms4v07I|6J|339h$=-YV z@{OkgL%cnbCd~A#TNQQs){fPur#S9Ay+`)m$3mA+H_hg>D@8$a^Wc-8MUW&W#^+0XdBI=)t&_PuRz%5~+JQ&RsHdD+zeKJ`$^-h0CG zji-(WdwZx%nCaQLD(dvT9ji}Iaol(MuI#;&t!t-vwLj$cUd6=X&69pc?R4fx_S1|U z%TH&VTzKl5t<&55OR_v$0G%K8ryYk$NFVMtWvw5O-%*#A)lW*rwRW~x7mQs;8 zvwz!^Q(v#LduK+Ko<8!m>9maGW^cWzpHEJ{&FQW1gzvPGWsUdw`7WpPS`T{x0IDr< AdH?_b From fda9a924070b8c96039bae1e22197e51f5525b0b Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 24 Oct 2019 18:07:15 -0700 Subject: [PATCH 12/17] Fix tolerance bug --- test/common_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/common_utils.py b/test/common_utils.py index 1a7ba925385..e9acd0450e2 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -199,7 +199,8 @@ def assertTensorsEqual(a, b): if diff.is_signed() and diff.dtype != torch.int8: diff = diff.abs() max_err = diff.max() - self.assertLessEqual(max_err, prec, message) + tolerance = prec + prec * abs(a.max()) + self.assertLessEqual(max_err, tolerance, message) super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message) super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message) if x.is_sparse: From 2da2d19fadad52f0305be2bab2390a50c5be5ec1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Nov 2019 16:36:36 -0800 Subject: [PATCH 13/17] Fix breakages --- test/common_utils.py | 2 +- torchvision/models/mobilenet.py | 8 +++++--- torchvision/models/quantization/mobilenet.py | 2 +- torchvision/models/quantization/shufflenetv2.py | 2 +- torchvision/models/resnet.py | 7 ++++--- torchvision/models/shufflenetv2.py | 6 ++++-- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e9acd0450e2..7d3b18684de 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -29,7 +29,7 @@ def get_tmp_dir(src=None, **kwargs): ACCEPT = os.getenv('EXPECTTEST_ACCEPT') TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' -TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job +TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job parser = argparse.ArgumentParser(add_help=False) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index b3ba049a4c6..6d10610b633 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -147,14 +147,16 @@ def __init__(self, nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) - def _forward(self, x): + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass x = self.features(x) x = x.mean([2, 3]) x = self.classifier(x) return x - # Allow for accessing forward method in a inherited class - forward = _forward + def forward(self, x): + return self._forward_impl(x) def mobilenet_v2(pretrained=False, progress=True, **kwargs): diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index e665f121234..1d14410f376 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs): def forward(self, x): x = self.quant(x) - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 9826e61f679..c9aeb2368ad 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): def forward(self, x): x = self.quant(x) - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 223159ccae6..527eab8ff05 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -194,7 +194,8 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) - def _forward(self, x): + def _forward_impl(self, x): + # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -211,8 +212,8 @@ def _forward(self, x): return x - # Allow for accessing forward method in a inherited class - forward = _forward + def forward(self, x): + return self._forward_impl(x) def _resnet(arch, block, layers, pretrained, progress, **kwargs): diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 7817e8aa1c1..14f9521886c 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -122,7 +122,8 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert self.fc = nn.Linear(output_channels, num_classes) - def _forward(self, x): + def _forward_impl(self, x): + # See note [TorchScript super()] x = self.conv1(x) x = self.maxpool(x) x = self.stage2(x) @@ -133,7 +134,8 @@ def _forward(self, x): x = self.fc(x) return x - forward = _forward + def forward(self, x): + return self._forward_impl(x) def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): From f00dcd1036d5740cfef714c0aa4ced88fd7e7944 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 6 Nov 2019 11:02:10 +0100 Subject: [PATCH 14/17] Fix quantized resnet --- torchvision/models/quantization/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index f00b7ed46d3..5fd3c039299 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -95,7 +95,7 @@ def forward(self, x): # Ensure scriptability # super(QuantizableResNet,self).forward(x) # is not scriptable - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x From 6e7bc18ead7af3eeaed732e32982eef95537d191 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 26 Nov 2019 17:17:39 +0100 Subject: [PATCH 15/17] Fix merge errors and simplify code --- test/test_models.py | 8 ++++---- torchvision/models/detection/rpn.py | 5 +---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 6873606a165..611349e432e 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -119,7 +119,6 @@ def _test_detection_model(self, name): out = model(model_input) self.assertIs(model_input[0], x) self.assertEqual(len(out), 1) - self.checkModule(model, name, (x,)) def subsample_tensor(tensor): num_elems = tensor.numel() @@ -150,16 +149,17 @@ def compute_mean_std(tensor): scripted_model = torch.jit.script(model) scripted_model.eval() scripted_out = scripted_model(model_input)[1] - self.assertNestedTensorObjectsEqual(scripted_out[0]["boxes"], out[0]["boxes"]) - self.assertNestedTensorObjectsEqual(scripted_out[0]["scores"], out[0]["scores"]) + self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"]) + self.assertEqual(scripted_out[0]["scores"], out[0]["scores"]) # labels currently float in script: need to investigate (though same result) - self.assertNestedTensorObjectsEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"]) + self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"]) self.assertTrue("boxes" in out[0]) self.assertTrue("scores" in out[0]) self.assertTrue("labels" in out[0]) # don't check script because we are compiling it here: # TODO: refactor tests # self.check_script(model, name) + self.checkModule(model, name, ([x],)) def _test_video_model(self, name): # the default input shape is diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 6dfcf788d0b..4b767f8209b 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -216,7 +216,6 @@ def concat_box_prediction_layers(box_cls, box_regression): # same format as the labels. Note that the labels are computed for # all feature levels concatenated, so we keep the same representation # for the objectness and the box_regression - last_C = torch.jit.annotate(Optional[int], None) for box_cls_per_level, box_regression_per_level in zip( box_cls, box_regression ): @@ -229,16 +228,14 @@ def concat_box_prediction_layers(box_cls, box_regression): ) box_cls_flattened.append(box_cls_per_level) - last_C = C box_regression_per_level = permute_and_flatten( box_regression_per_level, N, A, 4, H, W ) box_regression_flattened.append(box_regression_per_level) - assert last_C is not None # concatenate on the first dimension (representing the feature levels), to # take into account the way the labels were generated (with all feature maps # being concatenated as well) - box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, last_C) + box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) return box_cls, box_regression From 305b8d7d9c5707f8a8de425e781cd280e4a9a4ef Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 26 Nov 2019 17:20:43 +0100 Subject: [PATCH 16/17] DeepLabV3 has been fixed --- test/test_models.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 611349e432e..14c70175dc0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -71,20 +71,12 @@ def get_available_video_models(): } -SCRIPT_MODELS_TO_FIX = [ - # This model fails in the TorchScript interpreter, see - # https://github.com/pytorch/vision/pull/1436. Delete this list when - # that PR is closed. - 'deeplabv3_resnet101', -] - - class ModelTester(TestCase): def checkModule(self, model, name, args): if name not in script_test_models: return unwrapper = script_test_models[name].get('unwrapper', None) - return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=name in SCRIPT_MODELS_TO_FIX) + return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False) def _test_classification_model(self, name, input_shape): set_rng_seed(0) From 6e077723588887d256ca7bf972cf3a7d669acb86 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 26 Nov 2019 18:17:26 +0100 Subject: [PATCH 17/17] Temporarily disable jit compilation --- test/common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/common_utils.py b/test/common_utils.py index 7d3b18684de..b0a8fbe1c97 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -29,7 +29,7 @@ def get_tmp_dir(src=None, **kwargs): ACCEPT = os.getenv('EXPECTTEST_ACCEPT') TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' -TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job +# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job parser = argparse.ArgumentParser(add_help=False)