diff --git a/test/common_utils.py b/test/common_utils.py index 9c0c3175ef1..b0a8fbe1c97 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -5,10 +5,15 @@ import unittest import argparse import sys +import io import torch import errno import __main__ +from numbers import Number +from torch._six import string_classes, inf +from collections import OrderedDict + @contextlib.contextmanager def get_tmp_dir(src=None, **kwargs): @@ -23,6 +28,9 @@ def get_tmp_dir(src=None, **kwargs): ACCEPT = os.getenv('EXPECTTEST_ACCEPT') +TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' +# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job + parser = argparse.ArgumentParser(add_help=False) parser.add_argument('--accept', action='store_true') @@ -64,10 +72,20 @@ def map_nested_tensor_object(object, tensor_map_fn): return impl(object) +def is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + # adapted from TestCase in torch/test/common_utils to accept non-string # inputs and set maximum binary size class TestCase(unittest.TestCase): - def assertExpected(self, output, subname=None, rtol=None, atol=None): + precision = 1e-5 + + def assertExpected(self, output, subname=None, prec=None): r""" Test that a python value matches the recorded contents of a file derived from the name of this test and subname. The value must be @@ -123,31 +141,182 @@ def accept_output(update_type): if ACCEPT: equal = False try: - equal = self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol) + equal = self.assertEqual(output, expected, prec=prec) except Exception: equal = False if not equal: return accept_output("updated output") else: - self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol) + self.assertEqual(output, expected, prec=prec) - def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None): - self.assertEqual(type(a), type(b)) + def assertEqual(self, x, y, prec=None, message='', allow_inf=False): + """ + This is copied from pytorch/test/common_utils.py's TestCase.assertEqual + """ + if isinstance(prec, str) and message == '': + message = prec + prec = None + if prec is None: + prec = self.precision - if isinstance(a, torch.Tensor): - torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + if isinstance(x, torch.Tensor) and isinstance(y, Number): + self.assertEqual(x.item(), y, prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(y, torch.Tensor) and isinstance(x, Number): + self.assertEqual(x, y.item(), prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + def assertTensorsEqual(a, b): + super(TestCase, self).assertEqual(a.size(), b.size(), message) + if a.numel() > 0: + if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)): + # CPU half and bfloat16 tensors don't have the methods we need below + a = a.to(torch.float32) + b = b.to(a) - elif isinstance(a, dict): - self.assertEqual(len(a), len(b)) - for key, value in a.items(): - self.assertTrue(key in b, "key: " + str(key)) + if (a.dtype == torch.bool) != (b.dtype == torch.bool): + raise TypeError("Was expecting both tensors to be bool type.") + else: + if a.dtype == torch.bool and b.dtype == torch.bool: + # we want to respect precision but as bool doesn't support substraction, + # boolean tensor has to be converted to int + a = a.to(torch.int) + b = b.to(torch.int) - self.assertNestedTensorObjectsEqual(value, b[key], rtol=rtol, atol=atol) - elif isinstance(a, (list, tuple)): - self.assertEqual(len(a), len(b)) + diff = a - b + if a.is_floating_point(): + # check that NaNs are in the same locations + nan_mask = torch.isnan(a) + self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message) + diff[nan_mask] = 0 + # inf check if allow_inf=True + if allow_inf: + inf_mask = torch.isinf(a) + inf_sign = inf_mask.sign() + self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message) + diff[inf_mask] = 0 + # TODO: implement abs on CharTensor (int8) + if diff.is_signed() and diff.dtype != torch.int8: + diff = diff.abs() + max_err = diff.max() + tolerance = prec + prec * abs(a.max()) + self.assertLessEqual(max_err, tolerance, message) + super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message) + super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message) + if x.is_sparse: + x = self.safeCoalesce(x) + y = self.safeCoalesce(y) + assertTensorsEqual(x._indices(), y._indices()) + assertTensorsEqual(x._values(), y._values()) + elif x.is_quantized and y.is_quantized: + self.assertEqual(x.qscheme(), y.qscheme(), prec=prec, + message=message, allow_inf=allow_inf) + if x.qscheme() == torch.per_tensor_affine: + self.assertEqual(x.q_scale(), y.q_scale(), prec=prec, + message=message, allow_inf=allow_inf) + self.assertEqual(x.q_zero_point(), y.q_zero_point(), + prec=prec, message=message, + allow_inf=allow_inf) + elif x.qscheme() == torch.per_channel_affine: + self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec, + message=message, allow_inf=allow_inf) + self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(), + prec=prec, message=message, + allow_inf=allow_inf) + self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(), + prec=prec, message=message) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.int_repr().to(torch.int32), + y.int_repr().to(torch.int32), prec=prec, + message=message, allow_inf=allow_inf) + else: + assertTensorsEqual(x, y) + elif isinstance(x, string_classes) and isinstance(y, string_classes): + super(TestCase, self).assertEqual(x, y, message) + elif type(x) == set and type(y) == set: + super(TestCase, self).assertEqual(x, y, message) + elif isinstance(x, dict) and isinstance(y, dict): + if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): + self.assertEqual(x.items(), y.items(), prec=prec, + message=message, allow_inf=allow_inf) + else: + self.assertEqual(set(x.keys()), set(y.keys()), prec=prec, + message=message, allow_inf=allow_inf) + key_list = list(x.keys()) + self.assertEqual([x[k] for k in key_list], + [y[k] for k in key_list], + prec=prec, message=message, + allow_inf=allow_inf) + elif is_iterable(x) and is_iterable(y): + super(TestCase, self).assertEqual(len(x), len(y), message) + for x_, y_ in zip(x, y): + self.assertEqual(x_, y_, prec=prec, message=message, + allow_inf=allow_inf) + elif isinstance(x, bool) and isinstance(y, bool): + super(TestCase, self).assertEqual(x, y, message) + elif isinstance(x, Number) and isinstance(y, Number): + if abs(x) == inf or abs(y) == inf: + if allow_inf: + super(TestCase, self).assertEqual(x, y, message) + else: + self.fail("Expected finite numeric values - x={}, y={}".format(x, y)) + return + super(TestCase, self).assertLessEqual(abs(x - y), prec, message) + else: + super(TestCase, self).assertEqual(x, y, message) - for val1, val2 in zip(a, b): - self.assertNestedTensorObjectsEqual(val1, val2, rtol=rtol, atol=atol) + def checkModule(self, nn_module, args, unwrapper=None, skip=False): + """ + Check that a nn.Module's results in TorchScript match eager and that it + can be exported + """ + if not TEST_WITH_SLOW or skip: + # TorchScript is not enabled, skip these tests + return - else: - self.assertEqual(a, b) + sm = torch.jit.script(nn_module) + + with freeze_rng_state(): + eager_out = nn_module(*args) + + with freeze_rng_state(): + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) + + self.assertEqual(eager_out, script_out) + self.assertExportImportModule(sm, args) + + return sm + + def getExportImportCopy(self, m): + """ + Save and load a TorchScript model + """ + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + imported = torch.jit.load(buffer) + return imported + + def assertExportImportModule(self, m, args): + """ + Check that the results of a model are the same after saving and loading + """ + m_import = self.getExportImportCopy(m) + with freeze_rng_state(): + results = m(*args) + with freeze_rng_state(): + results_from_imported = m_import(*args) + self.assertEqual(results, results_from_imported) + + +@contextlib.contextmanager +def freeze_rng_state(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + yield + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) diff --git a/test/test_models.py b/test/test_models.py index 0102824ac70..14c70175dc0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,4 +1,4 @@ -from common_utils import TestCase, map_nested_tensor_object +from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state from collections import OrderedDict from itertools import product import torch @@ -38,62 +38,68 @@ def get_available_video_models(): # models that are in torch hub, as well as r3d_18. we tried testing all models # but the test was too slow. not included are detection models, because # they are not yet supported in JIT. -script_test_models = [ - "deeplabv3_resnet101", - "mobilenet_v2", - "resnext50_32x4d", - "fcn_resnet101", - "googlenet", - "densenet121", - "resnet18", - "alexnet", - "shufflenet_v2_x1_0", - "squeezenet1_0", - "vgg11", - "inception_v3", - "r3d_18", - "fasterrcnn_resnet50_fpn", - "maskrcnn_resnet50_fpn", - "keypointrcnn_resnet50_fpn", -] +# If 'unwrapper' is provided it will be called with the script model outputs +# before they are compared to the eager model outputs. This is useful if the +# model outputs are different between TorchScript / Eager mode +script_test_models = { + 'deeplabv3_resnet101': {}, + 'mobilenet_v2': {}, + 'resnext50_32x4d': {}, + 'fcn_resnet101': {}, + 'googlenet': { + 'unwrapper': lambda x: x.logits + }, + 'densenet121': {}, + 'resnet18': {}, + 'alexnet': {}, + 'shufflenet_v2_x1_0': {}, + 'squeezenet1_0': {}, + 'vgg11': {}, + 'inception_v3': { + 'unwrapper': lambda x: x.logits + }, + 'r3d_18': {}, + "fasterrcnn_resnet50_fpn": { + 'unwrapper': lambda x: x[1] + }, + "maskrcnn_resnet50_fpn": { + 'unwrapper': lambda x: x[1] + }, + "keypointrcnn_resnet50_fpn": { + 'unwrapper': lambda x: x[1] + }, +} class ModelTester(TestCase): - def check_script(self, model, name): + def checkModule(self, model, name, args): if name not in script_test_models: return - scriptable = True - msg = "" - try: - torch.jit.script(model) - except Exception as e: - tb = traceback.format_exc() - scriptable = False - msg = str(e) + str(tb) - self.assertTrue(scriptable, msg) + unwrapper = script_test_models[name].get('unwrapper', None) + return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False) def _test_classification_model(self, name, input_shape): + set_rng_seed(0) # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature - set_rng_seed(0) model = models.__dict__[name](num_classes=50) - self.check_script(model, name) model.eval() x = torch.rand(input_shape) out = model(x) - self.assertExpected(out, rtol=1e-2, atol=0.) + self.assertExpected(out, prec=0.1) self.assertEqual(out.shape[-1], 50) + self.checkModule(model, name, (x,)) def _test_segmentation_model(self, name): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False) - self.check_script(model, name) model.eval() input_shape = (1, 3, 300, 300) x = torch.rand(input_shape) out = model(x) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) + self.checkModule(model, name, (x,)) def _test_detection_model(self, name): set_rng_seed(0) @@ -127,24 +133,25 @@ def compute_mean_std(tensor): # compare results with mean and std if name == "maskrcnn_resnet50_fpn": test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std) - # mean values are small, use large rtol - self.assertExpected(test_value, rtol=.01, atol=.01) + # mean values are small, use large prec + self.assertExpected(test_value, prec=.01) else: - self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor)) + self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01) scripted_model = torch.jit.script(model) scripted_model.eval() scripted_out = scripted_model(model_input)[1] - self.assertNestedTensorObjectsEqual(scripted_out[0]["boxes"], out[0]["boxes"]) - self.assertNestedTensorObjectsEqual(scripted_out[0]["scores"], out[0]["scores"]) + self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"]) + self.assertEqual(scripted_out[0]["scores"], out[0]["scores"]) # labels currently float in script: need to investigate (though same result) - self.assertNestedTensorObjectsEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"]) + self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"]) self.assertTrue("boxes" in out[0]) self.assertTrue("scores" in out[0]) self.assertTrue("labels" in out[0]) # don't check script because we are compiling it here: # TODO: refactor tests # self.check_script(model, name) + self.checkModule(model, name, ([x],)) def _test_video_model(self, name): # the default input shape is @@ -152,9 +159,10 @@ def _test_video_model(self, name): input_shape = (1, 3, 4, 112, 112) # test both basicblock and Bottleneck model = models.video.__dict__[name](num_classes=50) - self.check_script(model, name) + model.eval() x = torch.rand(input_shape) out = model(x) + self.checkModule(model, name, (x,)) self.assertEqual(out.shape[-1], 50) def _make_sliced_model(self, model, stop_layer): diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 6dfcf788d0b..4b767f8209b 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -216,7 +216,6 @@ def concat_box_prediction_layers(box_cls, box_regression): # same format as the labels. Note that the labels are computed for # all feature levels concatenated, so we keep the same representation # for the objectness and the box_regression - last_C = torch.jit.annotate(Optional[int], None) for box_cls_per_level, box_regression_per_level in zip( box_cls, box_regression ): @@ -229,16 +228,14 @@ def concat_box_prediction_layers(box_cls, box_regression): ) box_cls_flattened.append(box_cls_per_level) - last_C = C box_regression_per_level = permute_and_flatten( box_regression_per_level, N, A, 4, H, W ) box_regression_flattened.append(box_regression_per_level) - assert last_C is not None # concatenate on the first dimension (representing the feature levels), to # take into account the way the labels were generated (with all feature maps # being concatenated as well) - box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, last_C) + box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) return box_cls, box_regression diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index b3ba049a4c6..6d10610b633 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -147,14 +147,16 @@ def __init__(self, nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) - def _forward(self, x): + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass x = self.features(x) x = x.mean([2, 3]) x = self.classifier(x) return x - # Allow for accessing forward method in a inherited class - forward = _forward + def forward(self, x): + return self._forward_impl(x) def mobilenet_v2(pretrained=False, progress=True, **kwargs): diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index e665f121234..1d14410f376 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs): def forward(self, x): x = self.quant(x) - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index f00b7ed46d3..5fd3c039299 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -95,7 +95,7 @@ def forward(self, x): # Ensure scriptability # super(QuantizableResNet,self).forward(x) # is not scriptable - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 9826e61f679..c9aeb2368ad 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): def forward(self, x): x = self.quant(x) - x = self._forward(x) + x = self._forward_impl(x) x = self.dequant(x) return x diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 223159ccae6..527eab8ff05 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -194,7 +194,8 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) - def _forward(self, x): + def _forward_impl(self, x): + # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -211,8 +212,8 @@ def _forward(self, x): return x - # Allow for accessing forward method in a inherited class - forward = _forward + def forward(self, x): + return self._forward_impl(x) def _resnet(arch, block, layers, pretrained, progress, **kwargs): diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 7817e8aa1c1..14f9521886c 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -122,7 +122,8 @@ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, invert self.fc = nn.Linear(output_channels, num_classes) - def _forward(self, x): + def _forward_impl(self, x): + # See note [TorchScript super()] x = self.conv1(x) x = self.maxpool(x) x = self.stage2(x) @@ -133,7 +134,8 @@ def _forward(self, x): x = self.fc(x) return x - forward = _forward + def forward(self, x): + return self._forward_impl(x) def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):