Add tests for results in script vs eager mode #1430
Merged
Changes from all 23 commits:
- 49e8bf7 Add tests for results in script vs eager mode
- ccfd00b Fix inception, use PYTORCH_TEST_WITH_SLOW
- 9299fe2 Update
- fed5ce5 Remove assertNestedTensorObjectsEqual
- c842557 Add PYTORCH_TEST_WITH_SLOW to CircleCI config
- d4ca330 Add MaskRCNN unwrapper
- 557a15c fix prec args
- b3471f9 Remove CI changes
- 44f02f1 Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- 7cc2897 update
- 8922bef Update
- 0ce7102 remove expect changes
- fda9a92 Fix tolerance bug
- 9e60684 Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- 4e84e0f (fmassa) Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- 7578d46 (fmassa) Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- 2da2d19 Fix breakages
- f00dcd1 Fix quantized resnet
- a440686 (fmassa) Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- c87eac4 (fmassa) Merge branch 'master' of github.com:pytorch/vision into driazati/chec…
- 6e7bc18 (fmassa) Fix merge errors and simplify code
- 305b8d7 (fmassa) DeepLabV3 has been fixed
- 6e07772 (fmassa) Temporarily disable jit compilation
```diff
@@ -5,10 +5,15 @@
 import unittest
 import argparse
 import sys
+import io
 import torch
 import errno
 import __main__

+from numbers import Number
+from torch._six import string_classes, inf
+from collections import OrderedDict
+

 @contextlib.contextmanager
 def get_tmp_dir(src=None, **kwargs):
```
```diff
@@ -23,6 +28,9 @@ def get_tmp_dir(src=None, **kwargs):


 ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
+TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
+# TEST_WITH_SLOW = True  # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job
+

 parser = argparse.ArgumentParser(add_help=False)
 parser.add_argument('--accept', action='store_true')
```
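The hunk above gates the slow script-vs-eager checks behind an environment variable. A minimal stdlib sketch of the same gate (the extra `environ` parameter is added here only to make it easy to exercise):

```python
import os

# Mirror the gate used in the diff: the slow tests only run when
# PYTORCH_TEST_WITH_SLOW=1 is set in the environment.
def slow_tests_enabled(environ=os.environ):
    return environ.get('PYTORCH_TEST_WITH_SLOW', '0') == '1'

print(slow_tests_enabled({'PYTORCH_TEST_WITH_SLOW': '1'}))  # True
print(slow_tests_enabled({'PYTORCH_TEST_WITH_SLOW': '0'}))  # False
print(slow_tests_enabled({}))                               # False
```

Defaulting to `'0'` means an unset variable disables the slow tests, which is why the CI config needs to export the variable explicitly.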
```diff
@@ -64,10 +72,20 @@ def map_nested_tensor_object(object, tensor_map_fn):
     return impl(object)


+def is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
+
 # adapted from TestCase in torch/test/common_utils to accept non-string
 # inputs and set maximum binary size
 class TestCase(unittest.TestCase):
-    def assertExpected(self, output, subname=None, rtol=None, atol=None):
+    precision = 1e-5
+
+    def assertExpected(self, output, subname=None, prec=None):
         r"""
         Test that a python value matches the recorded contents of a file
         derived from the name of this test and subname. The value must be
```
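The `is_iterable` helper added in this hunk feeds the generic fallback branch of the new `assertEqual`. A standalone copy shows its duck-typing behaviour: anything `iter()` accepts counts, including strings and generators, but not plain numbers:

```python
# Same duck-typing check as the diff's is_iterable helper.
def is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False

print(is_iterable([1, 2]))            # True
print(is_iterable("abc"))             # True (strings are iterable)
print(is_iterable(x for x in range(3)))  # True (generators too)
print(is_iterable(3))                 # False
```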
```diff
@@ -123,31 +141,182 @@ def accept_output(update_type):
         if ACCEPT:
             equal = False
             try:
-                equal = self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
+                equal = self.assertEqual(output, expected, prec=prec)
             except Exception:
                 equal = False
             if not equal:
                 return accept_output("updated output")
         else:
-            self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
+            self.assertEqual(output, expected, prec=prec)

-    def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None):
-        self.assertEqual(type(a), type(b))
-
-        if isinstance(a, torch.Tensor):
-            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
-        elif isinstance(a, dict):
-            self.assertEqual(len(a), len(b))
-            for key, value in a.items():
-                self.assertTrue(key in b, "key: " + str(key))
-                self.assertNestedTensorObjectsEqual(value, b[key], rtol=rtol, atol=atol)
-        elif isinstance(a, (list, tuple)):
-            self.assertEqual(len(a), len(b))
-            for val1, val2 in zip(a, b):
-                self.assertNestedTensorObjectsEqual(val1, val2, rtol=rtol, atol=atol)
-        else:
-            self.assertEqual(a, b)
+    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
+        """
+        This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
+        """
+        if isinstance(prec, str) and message == '':
+            message = prec
+            prec = None
+        if prec is None:
+            prec = self.precision
+
+        if isinstance(x, torch.Tensor) and isinstance(y, Number):
+            self.assertEqual(x.item(), y, prec=prec, message=message,
+                             allow_inf=allow_inf)
+        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
+            self.assertEqual(x, y.item(), prec=prec, message=message,
+                             allow_inf=allow_inf)
+        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+            def assertTensorsEqual(a, b):
+                super(TestCase, self).assertEqual(a.size(), b.size(), message)
+                if a.numel() > 0:
+                    if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
+                        # CPU half and bfloat16 tensors don't have the methods we need below
+                        a = a.to(torch.float32)
+                    b = b.to(a)
+
+                    if (a.dtype == torch.bool) != (b.dtype == torch.bool):
+                        raise TypeError("Was expecting both tensors to be bool type.")
+                    else:
+                        if a.dtype == torch.bool and b.dtype == torch.bool:
+                            # we want to respect precision but as bool doesn't support subtraction,
+                            # boolean tensor has to be converted to int
+                            a = a.to(torch.int)
+                            b = b.to(torch.int)
+
+                        diff = a - b
+                        if a.is_floating_point():
+                            # check that NaNs are in the same locations
+                            nan_mask = torch.isnan(a)
+                            self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
+                            diff[nan_mask] = 0
+                            # inf check if allow_inf=True
+                            if allow_inf:
+                                inf_mask = torch.isinf(a)
+                                inf_sign = inf_mask.sign()
+                                self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
+                                diff[inf_mask] = 0
+                        # TODO: implement abs on CharTensor (int8)
+                        if diff.is_signed() and diff.dtype != torch.int8:
+                            diff = diff.abs()
+                        max_err = diff.max()
+                        tolerance = prec + prec * abs(a.max())
+                        self.assertLessEqual(max_err, tolerance, message)
+            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
+            super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
+            if x.is_sparse:
+                x = self.safeCoalesce(x)
+                y = self.safeCoalesce(y)
+                assertTensorsEqual(x._indices(), y._indices())
+                assertTensorsEqual(x._values(), y._values())
+            elif x.is_quantized and y.is_quantized:
+                self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
+                                 message=message, allow_inf=allow_inf)
+                if x.qscheme() == torch.per_tensor_affine:
+                    self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
+                                     message=message, allow_inf=allow_inf)
+                    self.assertEqual(x.q_zero_point(), y.q_zero_point(),
+                                     prec=prec, message=message,
+                                     allow_inf=allow_inf)
+                elif x.qscheme() == torch.per_channel_affine:
+                    self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
+                                     message=message, allow_inf=allow_inf)
+                    self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
+                                     prec=prec, message=message,
+                                     allow_inf=allow_inf)
+                    self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
+                                     prec=prec, message=message)
+                self.assertEqual(x.dtype, y.dtype)
+                self.assertEqual(x.int_repr().to(torch.int32),
+                                 y.int_repr().to(torch.int32), prec=prec,
+                                 message=message, allow_inf=allow_inf)
+            else:
+                assertTensorsEqual(x, y)
+        elif isinstance(x, string_classes) and isinstance(y, string_classes):
+            super(TestCase, self).assertEqual(x, y, message)
+        elif type(x) == set and type(y) == set:
+            super(TestCase, self).assertEqual(x, y, message)
+        elif isinstance(x, dict) and isinstance(y, dict):
+            if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
+                self.assertEqual(x.items(), y.items(), prec=prec,
+                                 message=message, allow_inf=allow_inf)
+            else:
+                self.assertEqual(set(x.keys()), set(y.keys()), prec=prec,
+                                 message=message, allow_inf=allow_inf)
+                key_list = list(x.keys())
+                self.assertEqual([x[k] for k in key_list],
+                                 [y[k] for k in key_list],
+                                 prec=prec, message=message,
+                                 allow_inf=allow_inf)
+        elif is_iterable(x) and is_iterable(y):
+            super(TestCase, self).assertEqual(len(x), len(y), message)
+            for x_, y_ in zip(x, y):
+                self.assertEqual(x_, y_, prec=prec, message=message,
+                                 allow_inf=allow_inf)
+        elif isinstance(x, bool) and isinstance(y, bool):
+            super(TestCase, self).assertEqual(x, y, message)
+        elif isinstance(x, Number) and isinstance(y, Number):
+            if abs(x) == inf or abs(y) == inf:
+                if allow_inf:
+                    super(TestCase, self).assertEqual(x, y, message)
+                else:
+                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
+                return
+            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
+        else:
+            super(TestCase, self).assertEqual(x, y, message)
+
+    def checkModule(self, nn_module, args, unwrapper=None, skip=False):
+        """
+        Check that a nn.Module's results in TorchScript match eager and that it
+        can be exported
+        """
+        if not TEST_WITH_SLOW or skip:
+            # TorchScript is not enabled, skip these tests
+            return
+
+        sm = torch.jit.script(nn_module)
+
+        with freeze_rng_state():
+            eager_out = nn_module(*args)
+
+        with freeze_rng_state():
+            script_out = sm(*args)
+            if unwrapper:
+                script_out = unwrapper(script_out)
+
+        self.assertEqual(eager_out, script_out)
+        self.assertExportImportModule(sm, args)
+
+        return sm
+
+    def getExportImportCopy(self, m):
+        """
+        Save and load a TorchScript model
+        """
+        buffer = io.BytesIO()
+        torch.jit.save(m, buffer)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer)
+        return imported
+
+    def assertExportImportModule(self, m, args):
+        """
+        Check that the results of a model are the same after saving and loading
+        """
+        m_import = self.getExportImportCopy(m)
+        with freeze_rng_state():
+            results = m(*args)
+        with freeze_rng_state():
+            results_from_imported = m_import(*args)
+        self.assertEqual(results, results_from_imported)
+
+
+@contextlib.contextmanager
+def freeze_rng_state():
+    rng_state = torch.get_rng_state()
+    if torch.cuda.is_available():
+        cuda_rng_state = torch.cuda.get_rng_state()
+    yield
+    if torch.cuda.is_available():
+        torch.cuda.set_rng_state(cuda_rng_state)
+    torch.set_rng_state(rng_state)
```
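The diff's `freeze_rng_state` saves and restores torch's RNG state around each forward pass, so the eager and scripted runs see identical randomness (dropout, sampling, etc.). The same save/restore pattern can be sketched with the stdlib `random` module instead of torch, so it runs without PyTorch; unlike the diff's version, this sketch restores in a `finally` so an exception inside the block cannot leak state:

```python
import contextlib
import random

@contextlib.contextmanager
def freeze_rng_state():
    # Save the generator state on entry and restore it on exit,
    # so code inside the block does not advance the global RNG.
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)

random.seed(0)
with freeze_rng_state():
    first = random.random()
with freeze_rng_state():
    second = random.random()

print(first == second)  # True: both draws started from the same state
```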
Review discussion on the copied `assertEqual`:

- "This is all copied from PyTorch's `common_utils`."
- "This is maybe more general than it needs to be; would `assertNestedTensorObjectsEqual` suffice? This function e.g. does coercion between numbers and tensor results that you wouldn't want to allow for testing model equality. It's also a good amount of code."
- "`assertEqual` has been pretty stable and widely used in PyTorch; I think it's better to just copy the entire thing and avoid having to move it over piece by piece later on if some functionality is missing from `assertNestedTensorObjectsEqual`."
- "Okay, but we shouldn't have both in this file; we should only have one."
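The reviewer's generality concern centers on `assertTensorsEqual`'s tolerance rule in the diff, `tolerance = prec + prec * abs(a.max())`: an absolute term plus a relative term, both scaled by `prec`. A plain-Python sketch of that check, using lists instead of tensors and a hypothetical helper name:

```python
# Hypothetical plain-Python version of the diff's tolerance check:
# the max elementwise error must not exceed prec + prec * abs(max(a)).
def within_tolerance(a, b, prec=1e-5):
    max_err = max(abs(x - y) for x, y in zip(a, b))
    tolerance = prec + prec * abs(max(a))
    return max_err <= tolerance

# With prec=1e-5 and max(a)=2.0, tolerance is 3e-5.
print(within_tolerance([1.0, 2.0], [1.0, 2.000001]))  # True  (err 1e-6)
print(within_tolerance([1.0, 2.0], [1.0, 2.1]))       # False (err 0.1)
```

Scaling the relative term by the largest element of `a` means the permitted error grows with the magnitude of the reference output, which is what makes a single `prec` usable across models of very different scales.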