[ONNX] Add dtype check in onnx verification (#79263)
Summary:
Currently we don't have a dtype check when verifying the consistency between PyTorch and ONNX outputs. As a result, several dtype inconsistencies have been found and reported: #77842 #77845

This is a POC.
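To make the failure mode concrete, here is a minimal sketch (values and dtypes are invented for illustration) of the kind of mismatch the new check catches, using the same `torch.testing.assert_close` call that the verification code switches to:

```python
import numpy as np
import torch

# Same values, different dtypes: the previous np.testing.assert_allclose-based
# comparison passed silently; torch.testing.assert_close with check_dtype=True
# raises an AssertionError instead.
ort_out = np.ones(3, dtype=np.float32)  # e.g. what ONNX Runtime returned
pt_out = np.ones(3, dtype=np.float64)   # e.g. what PyTorch returned

torch.testing.assert_close(ort_out, pt_out, check_dtype=True)  # raises
```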

Failed workflows:
- linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)
  - inconsistent shape
    - TestONNXRuntime_opset10.test_all (#79371)
    - TestONNXRuntime_opset10.test_any (#79371)
    - TestONNXRuntime_opset10.test_argmin_argmax (#79503)
    - TestONNXRuntime_opset10.test_hardshrink (#79695)
    - TestONNXRuntime_opset10.test_linalg_norm (#79506)
    - TestONNXRuntime_opset10.test_linalg_vector_norm (#79506)
    - TestONNXRuntime_opset10.test_prelu_scalar (#79846)
    - TestONNXRuntime_opset10.test_softshrink (#79695)
    - TestONNXRuntime_opset10.test_sum_empty_tensor (skipped)
    - TestONNXRuntime_opset10.test_tolist (skipped)
  - inconsistent dtype
    - test_arithmetic_prim_bool (skipped)
    - test_arithmeticOps_with_low_precision (skipped)
    - test_arithmetic_prim_float (skipped)
    - test_logical_and (#79339)
    - test_logical_or (#79339)
    - test_logical_xor (#79339)
    - test_pow (skipped)
    - test_primitive_input_floating (skipped)
    - test_quantize_per_tensor (#79690)
    - test_quantized_adaptive_avg_pool2d (#79690)
    - test_quantized_arithmetic (#79690)
    - test_quantized_arithmetic_qfunctional (#79690)
    - test_quantized_conv2d (#79690)
    - test_quantized_conv2d_relu (#79690)
    - test_quantized_flatten (#79690)
    - test_quantized_hardsigmoid (#79690)
    - test_quantized_hardswish (#79690)
    - test_quantized_linear (#79690)
    - test_quantized_sigmoid (#79690)
    - test_item (skipped)
    - test_full_like_value (skipped)
    - TestONNXRuntime_opset7.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset8.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9_IRv4.test_div_rounding_mode (skipped)
    - test_outer (skipped)
    - test_symbolic_shape_inference_arange_2 (skipped)

Pull Request resolved: #79263
Approved by: https://github.com/justinchuby, https://github.com/BowenBao

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d9a7e93aaf3166e639ea413123bd6c38b9144adc

Reviewed By: seemethere

Differential Revision: D38585848

fbshipit-source-id: 9da98581ceec51142ae31d3f8a06f9f296a16b23
qqaatw authored and facebook-github-bot committed Aug 10, 2022
1 parent a4153aa commit dbbf9be
Showing 5 changed files with 83 additions and 5 deletions.
6 changes: 6 additions & 0 deletions test/onnx/onnx_test_common.py
@@ -37,6 +37,10 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
kwargs["ort_providers"] = _ORT_PROVIDERS
kwargs["opset_version"] = test_suite.opset_version
kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
if hasattr(test_suite, "check_shape"):
kwargs["check_shape"] = test_suite.check_shape
if hasattr(test_suite, "check_dtype"):
kwargs["check_dtype"] = test_suite.check_dtype
return verification.verify(*args, **kwargs)


@@ -60,6 +64,8 @@ class _TestONNXRuntime(common_utils.TestCase):
     opset_version = _constants.onnx_default_opset
     keep_initializers_as_inputs = True  # For IR version 3 type export.
     is_script = False
+    check_shape = True
+    check_dtype = True

     def setUp(self):
         set_rng_seed(0)
18 changes: 18 additions & 0 deletions test/onnx/pytorch_test_common.py
@@ -139,5 +139,23 @@ def wrapper(self, *args, **kwargs):
     return skip_dec


+def skipShapeChecking(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        self.check_shape = False
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+def skipDtypeChecking(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        self.check_dtype = False
+        return func(self, *args, **kwargs)
+
+    return wrapper


 def flatten(x):
     return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
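These decorators flip the corresponding flag on the test instance just before the test body runs; since unittest builds a fresh instance per test, nothing needs to be reset afterwards. A minimal sketch of the intended usage (the test body here is illustrative, mirroring the changes to test_pytorch_onnx_onnxruntime.py below):

```python
# Inside a _TestONNXRuntime subclass, e.g. in test_pytorch_onnx_onnxruntime.py:

@skipDtypeChecking  # ORT output dtype known to diverge; see the list above
def test_pow(self):
    class PowModule(torch.nn.Module):
        def forward(self, x, y):
            return x.pow(y)

    self.run_test(PowModule(), (torch.randn(2, 3), torch.randn(2, 3)))
```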
9 changes: 8 additions & 1 deletion test/onnx/test_pytorch_jit_onnx.py
@@ -50,6 +50,8 @@ class _TestJITIRToONNX:

     opset_version = -1  # Sub-classes must override
     ort_providers = ["CPUExecutionProvider"]
+    check_shape = True
+    check_dtype = True

     def run_test(self, graph_ir, example_inputs):
         graph = torch._C.parse_ir(graph_ir)
@@ -64,7 +66,12 @@ def run_test(self, graph_ir, example_inputs):
         ort_outs = verification._run_ort(ort_sess, example_inputs)

         verification._compare_ort_pytorch_outputs(
-            ort_outs, jit_outs, rtol=1e-3, atol=1e-7
+            ort_outs,
+            jit_outs,
+            rtol=1e-3,
+            atol=1e-7,
+            check_shape=self.check_shape,
+            check_dtype=self.check_dtype,
         )

     def test_example_ir(self):
14 changes: 14 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -26,11 +26,13 @@
     RNN_HIDDEN_SIZE,
     RNN_INPUT_SIZE,
     RNN_SEQUENCE_LENGTH,
+    skipDtypeChecking,
     skipForAllOpsetVersions,
     skipIfUnsupportedMaxOpsetVersion,
     skipIfUnsupportedMinOpsetVersion,
     skipIfUnsupportedOpsetVersion,
     skipScriptTest,
+    skipShapeChecking,
     skipTraceTest,
 )
 from torch import Tensor
@@ -827,6 +829,7 @@ def forward(self, x: int, y):
         y = torch.randint(10, (2, 3, 4))
         self.run_test(Model(), (x, y))

+    @skipDtypeChecking
     def test_primitive_input_floating(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -1531,6 +1534,7 @@ def forward(self, x):
         x = torch.randn(2, 3, 4)
         self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])

+    @skipDtypeChecking
     def test_arithmetic_prim_float(self):
         class ArithmeticModule(torch.nn.Module):
             def forward(self, x, y: float):
@@ -1553,6 +1557,7 @@ def forward(self, x):
         x = torch.randn(2, 3, 4)
         self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])

+    @skipDtypeChecking
     def test_arithmetic_prim_bool(self):
         class ArithmeticModule(torch.nn.Module):
             def forward(self, x, y: int, z: bool, t: float):
@@ -1720,6 +1725,7 @@ def forward(self, x, y):
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
         self.run_test(torch.jit.script(DivModule()), (x, y))

+    @skipDtypeChecking
     def test_div_rounding_mode(self):
         class TrueDivModule(torch.nn.Module):
             def forward(self, x, y):
@@ -2940,6 +2946,7 @@ def forward(self, x):
             torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
         )

+    @skipDtypeChecking
     def test_pow(self):
         class PowModule(torch.nn.Module):
             def forward(self, x, y):
@@ -2986,6 +2993,7 @@ def forward(self, x, y):
     # add to(dtype=torch.long) to avoid ORT output type does not match expected type.
     # will be fixed in ONNX version 14.
     @skipIfUnsupportedMaxOpsetVersion(13)
+    @skipDtypeChecking
     def test_arithmeticOps_with_low_precision(self):
         class AddModule(torch.nn.Module):
             def forward(self, x, y):
@@ -5279,6 +5287,7 @@ def forward(self, x, y, z, ind):
         ind = torch.tensor(-2, dtype=torch.long)
         self.run_test(GetItemModel(), (x, y, z, ind))

+    @skipDtypeChecking
     def test_item(self):
         class M(torch.nn.Module):
             def forward(self, x, y, i: int):
@@ -6085,6 +6094,7 @@ def forward(self, x):
         self.run_test(ZeroAndOnes(), (x,))

     @skipIfUnsupportedMinOpsetVersion(9)
+    @skipShapeChecking
     def test_tolist(self):
         class List(torch.jit.ScriptModule):
             @torch.jit.script_method
@@ -6626,6 +6636,7 @@ def forward(self, x):
         self.run_test(FullLikeModel(), x)

     @skipIfUnsupportedMinOpsetVersion(9)
+    @skipDtypeChecking
     def test_full_like_value(self):
         class FullLikeModel(torch.nn.Module):
             def forward(self, x, y):
@@ -7892,6 +7903,7 @@ def forward(self, poses):
         self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})

     @skipIfUnsupportedMinOpsetVersion(12)
+    @skipDtypeChecking
     def test_outer(self):
         class Outer(torch.nn.Module):
             def forward(self, x, y):
@@ -11060,6 +11072,7 @@ def forward(self, boxes, scores):
         self.run_test(model, (boxes, scores))

     @skipIfUnsupportedMinOpsetVersion(11)
+    @skipDtypeChecking
     def test_symbolic_shape_inference_arange_2(self):
         # test Range
         class ArangeModel(torch.nn.Module):
@@ -11516,6 +11529,7 @@ def forward(self, x):
         x = torch.ones(12, 3)
         self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})

+    @skipShapeChecking
     def test_sum_empty_tensor(self):
         class M(torch.nn.Module):
             def forward(self, x):
41 changes: 37 additions & 4 deletions torch/onnx/verification.py
@@ -115,14 +115,34 @@ def _ort_session(
     return ort_session


-def _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol):
+def _compare_ort_pytorch_outputs(
+    ort_outs: Sequence[np.ndarray],
+    pt_outs: Sequence[torch.Tensor],
+    rtol: float,
+    atol: float,
+    check_shape: bool,
+    check_dtype: bool,
+):
     pt_outs, _ = torch.jit._flatten(pt_outs)
     pt_outs = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)

-    assert len(pt_outs) == len(ort_outs), "number of outputs differ"
+    assert len(ort_outs) == len(
+        pt_outs
+    ), f"Number of outputs differ ONNX runtime: ({len(ort_outs)}) PyTorch: ({len(pt_outs)})"

     for ort_out, pt_out in zip(ort_outs, pt_outs):
-        np.testing.assert_allclose(ort_out, pt_out, rtol=rtol, atol=atol)
+        # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
+        if not check_shape:
+            # Allow different but broadcastable output shapes.
+            ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
+        torch.testing.assert_close(
+            ort_out,
+            pt_out,
+            rtol=rtol,
+            atol=atol,
+            check_dtype=check_dtype,
+            equal_nan=True,
+        )
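For reference, a standalone sketch of what the relaxed path above does when `check_shape=False` (array values invented for illustration): `np.broadcast_arrays` reconciles shapes such as `()` vs. `(1,)` before the value comparison runs:

```python
import numpy as np
import torch

ort_out = np.array(1.0)   # scalar-shaped ONNX Runtime output, shape ()
pt_out = np.array([1.0])  # PyTorch output, shape (1,)

# Broadcasting maps both to shape (1,), so only values (and, if requested,
# dtypes) are compared; with check_shape=True this pair would fail instead.
ort_b, pt_b = np.broadcast_arrays(ort_out, pt_out)
torch.testing.assert_close(
    ort_b, pt_b, rtol=1e-3, atol=1e-7, check_dtype=True, equal_nan=True
)  # passes
```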


 def _prepare_input_for_pytorch(args, kwargs):
@@ -221,6 +241,8 @@ def _compare_ort_pytorch_model(
     flatten,
     rtol,
     atol,
+    check_shape,
+    check_dtype,
 ):
"""Compare outputs from ONNX model runs with outputs from PyTorch model runs.
@@ -242,7 +264,9 @@ def compare_ort_pytorch_model_with_input(input_args, input_kwargs):
         )
         ort_outs = _run_ort(ort_session, ort_inputs)

-        _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol)
+        _compare_ort_pytorch_outputs(
+            ort_outs, pt_outs, rtol, atol, check_shape, check_dtype
+        )

     compare_ort_pytorch_model_with_input(input_args, input_kwargs)

@@ -519,6 +543,8 @@ def verify(
     additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
     remained_onnx_input_idx: Optional[Sequence[int]] = None,
     flatten: bool = True,
+    check_shape: bool = True,
+    check_dtype: bool = True,
     ort_providers: Sequence[str] = _ORT_PROVIDERS,
     rtol: float = 0.001,
     atol: float = 1e-7,
@@ -552,6 +578,11 @@
         inputs into a flattened list of Tensors for ONNX. Set this to False if nested
         structures are to be preserved for ONNX, which is usually the case with
         exporting ScriptModules.
+        check_shape (bool, optional): Default True. If True, check that the shapes of
+            the PyTorch and ONNX Runtime outputs are exactly the same. Set this to
+            False to allow output shape broadcasting.
+        check_dtype (bool, optional): Default True. If True, check that the dtypes of
+            the PyTorch and ONNX Runtime outputs are consistent.
         ort_providers (sequence, optional): ONNX Runtime providers to use.
         rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
         atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
@@ -601,4 +632,6 @@ def verify(
         flatten,
         rtol,
         atol,
+        check_shape,
+        check_dtype,
     )
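Putting it together, a hedged end-to-end sketch of how a caller might exercise the new flags (assuming `verify`'s leading parameters are the model and its example inputs, as elsewhere in this file):

```python
import torch
from torch.onnx import verification

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

# Default: strict shape and dtype agreement between PyTorch and ONNX Runtime.
verification.verify(SumModule(), (torch.ones(2, 3),))

# A known dtype divergence can be tolerated while a fix is pending upstream.
verification.verify(SumModule(), (torch.ones(2, 3),), check_dtype=False)
```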
