Extend _TestONNXRuntime to reuse all tests for new model format (#112289)

`_TestONNXRuntime` has infrastructure to test models that are either a `Callable` or a `torch.nn.Module`.

After #111497, we want to re-run all those tests for models of type `torch.export.ExportedProgram`.

This PR adds to `self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime` the capability to detect the model type under test and to export the incoming `torch.nn.Module` to `torch.export.ExportedProgram` before running the ONNX export tests; the sketch below illustrates this logic.
Pull Request resolved: #112289
Approved by: https://github.com/titaiwangms
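
For illustration, here is a minimal, self-contained sketch of the detection-and-export behavior described above. The `TorchModelType` enum mirrors the one this PR adds to onnx_test_common.py; `prepare_model` and `Square` are hypothetical names used only in this sketch, not code from the PR:

```python
from enum import auto, Enum

import torch


class TorchModelType(Enum):
    # Mirrors the enum added to onnx_test_common.py in this PR.
    TORCH_NN_MODULE = auto()
    TORCH_EXPORT_EXPORTEDPROGRAM = auto()


class Square(torch.nn.Module):  # hypothetical sample model
    def forward(self, x):
        return x * x


def prepare_model(model, input_args, model_type):
    # Sketch of the new test-suite behavior: export an nn.Module to an
    # ExportedProgram when the test matrix asks for that model type.
    assert isinstance(model, torch.nn.Module) or callable(model)
    if model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
        model = torch.export.export(model, args=input_args)
    return model


prepared = prepare_model(
    Square(), (torch.randn(2, 2),), TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
)
print(type(prepared))  # an instance of torch.export.ExportedProgram
```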
Thiago Crepaldi authored and pytorchmergebot committed Nov 18, 2023
1 parent 2efa89a commit d4189d8
Showing 4 changed files with 143 additions and 5 deletions.
29 changes: 27 additions & 2 deletions test/onnx/onnx_test_common.py
@@ -11,6 +11,7 @@
import os
import unittest
import warnings
from enum import auto, Enum
from typing import (
Any,
Callable,
@@ -63,6 +64,11 @@
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")


class TorchModelType(Enum):
TORCH_NN_MODULE = auto()
TORCH_EXPORT_EXPORTEDPROGRAM = auto()


def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
options = verification.VerificationOptions()

@@ -255,7 +261,10 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
if input_kwargs is None:
input_kwargs = {}

-if has_mutation:
+if (
+    has_mutation
+    and self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+):
ref_model = _try_clone_model(model)
ref_input_args, ref_input_kwargs = _try_clone_inputs(
input_args, input_kwargs
@@ -265,6 +274,19 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
ref_input_args = input_args
ref_input_kwargs = input_kwargs

assert isinstance(ref_model, torch.nn.Module) or callable(
ref_model
), "Model must be a torch.nn.Module or callable"
if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
ref_model = torch.export.export(ref_model, args=ref_input_args)
if (
self.dynamic_shapes
): # TODO: Support dynamic shapes for torch.export.ExportedProgram
# https://github.com/pytorch/pytorch/issues/113705
pytest.xfail(
reason="torch.export.ExportedProgram does not support dynamic shapes"
)

# Feed args and kwargs into exporter.
# Note that the exporter should flatten kwargs into positional args for the exported model,
# since ONNX doesn't represent kwargs.
@@ -373,7 +395,10 @@ def run_ort(
f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
)

-ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
+ort_input = {
+    k: torch.Tensor.numpy(v, force=True)
+    for k, v in zip(input_names, pytorch_inputs)
+}
return session.run(None, ort_input)
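
Context for the `run_ort` change above: plain `Tensor.numpy()` raises on tensors that require grad or live on an accelerator, while `force=True` performs the detach/copy itself. A minimal CPU-only sketch:

```python
import torch

t = torch.randn(2, 2, requires_grad=True)

try:
    t.numpy()  # raises: can't call numpy() on a tensor that requires grad
except RuntimeError as err:
    print(f"plain numpy() failed: {err}")

# force=True is documented as equivalent to
# t.detach().cpu().resolve_conj().resolve_neg().numpy(), so it also covers
# CUDA tensors without the explicit .cpu() call used before this PR.
arr = torch.Tensor.numpy(t, force=True)
print(arr.shape)  # (2, 2)
```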


59 changes: 59 additions & 0 deletions test/onnx/pytorch_test_common.py
@@ -10,6 +10,7 @@

import numpy as np
import packaging.version
import pytest

import torch
from torch.autograd import function
@@ -329,6 +330,64 @@ def wrapper(self, *args, **kwargs):
return wrapper


def xfail_if_model_type_is_exportedprogram(reason: str):
"""xfail test with models using ExportedProgram as input.
Args:
reason: The reason for xfail the ONNX export test.
Returns:
A decorator for xfail tests.
"""

import onnx_test_common

def xfail_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if (
self.model_type
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
):
pytest.xfail(
reason=f"Xfail model_type==torch.export.ExportedProgram. {reason}"
)
return func(self, *args, **kwargs)

return wrapper

return xfail_dec


def xfail_if_model_type_is_not_exportedprogram(reason: str):
"""xfail test without models using ExportedProgram as input.
Args:
reason: The reason for xfail the ONNX export test.
Returns:
A decorator for xfail tests.
"""

import onnx_test_common

def xfail_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if (
self.model_type
!= onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
):
pytest.xfail(
reason=f"Xfail model_type!=torch.export.ExportedProgram. {reason}"
)
return func(self, *args, **kwargs)

return wrapper

return xfail_dec


def flatten(x):
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))

4 changes: 4 additions & 0 deletions test/onnx/test_fx_op_consistency.py
@@ -774,6 +774,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
opset_version = -1
op_level_debug: bool = False
dynamic_shapes: bool = False
# TODO: Should onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM also be tested?
model_type: onnx_test_common.TorchModelType = (
onnx_test_common.TorchModelType.TORCH_NN_MODULE
)

fp16_low_precision_list = [
"nn.functional.batch_norm",
56 changes: 53 additions & 3 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -45,10 +45,14 @@ def _parameterized_class_attrs_and_values():
itertools.product(
(True, False),
(True, False),
(
onnx_test_common.TorchModelType.TORCH_NN_MODULE,
onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
),
)
)
return {
"attrs": ["op_level_debug", "dynamic_shapes"],
"attrs": ["op_level_debug", "dynamic_shapes", "model_type"],
"input_values": input_values,
}

@@ -72,6 +76,7 @@ def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]
class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
op_level_debug: bool
dynamic_shapes: bool
model_type: onnx_test_common.TorchModelType

def setUp(self):
super().setUp()
@@ -131,6 +136,16 @@ def func(x, b=torch.tensor(1.0)):
func, (tensor_x,), input_kwargs={"b": torch.tensor(5.0)}
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"torch._export.verifier.SpecViolationError: Operator '<built-in function pow>' is not an allowed operator type:"
"(<class 'torch._ops.OpOverload'>, <class 'torch._ops.HigherOrderOperator'>)"
"Valid builtin ops: [<built-in function getitem>, <built-in function add>, <built-in function mul>,"
"<built-in function sub>, <built-in function truediv>, <built-in function ge>, <built-in function le>,"
"<built-in function gt>, <built-in function lt>, <built-in function eq>, <built-in function ne>,"
"<built-in function floordiv>, <built-in function mod>, <built-in function and_>, <built-in function or_>,"
"<built-in function not_>]"
" Github issue: https://github.com/pytorch/pytorch/issues/113778"
)
@pytorch_test_common.skip_dynamic_fx_test(
"sympy operation tests don't need dynamic shape"
)
@@ -442,6 +457,12 @@ def forward(self, x):
additional_test_inputs=[((y,),)],
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"RuntimeError:"
" Found following user inputs located at [0] are mutated. This is currently banned in the aot_export workflow."
" If you need this functionality, please file a github issue."
" Github issue: https://github.com/pytorch/pytorch/issues/112429"
)
def test_mutation(self):
class MutationModel(torch.nn.Module):
def forward(self, x):
@@ -469,6 +490,12 @@ def forward(self, input):
additional_test_inputs=[((y,),)],
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"RuntimeError:"
" Found following user inputs located at [0] are mutated. This is currently banned in the aot_export workflow."
" If you need this functionality, please file a github issue."
" Github issue: https://github.com/pytorch/pytorch/issues/112429"
)
@pytorch_test_common.skip_dynamic_fx_test(
"[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. "
"Name:'_inline_aten_slice_scattern13' Status Message: slice.cc:193 "
@@ -510,7 +537,7 @@ def forward(self, x):
"RuntimeError: at::functionalization::impl::isFunctionalTensor(self_) INTERNAL ASSERT FAILED "
"at '/path/to/pytorch/torch/csrc/autograd/python_torch_functions_manual.cpp':514, please report a bug to PyTorch."
)
-def test_expand_as_fill_seperate_tensor(self):
+def test_expand_as_fill_separate_tensor(self):
class Model(torch.nn.Module):
def forward(self, x):
aa = torch.tensor([[0], [1], [2]])
@@ -577,6 +604,10 @@ def func(x):
func, (torch.randn(3, 4),)
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"Unsupported: {'call_function': ['<built-in function ge>', 'aten._assert_async.msg', '<built-in function le>']}."
" Github issue: https://github.com/pytorch/pytorch/issues/112443"
)
def test_operator_with_scalar_output(self):
def func(x, y):
return x.item() + y
@@ -585,6 +616,10 @@ def func(x, y):
func, (torch.tensor([1]), torch.randn(3, 4))
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"Unsupported: Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}."
" Github issue: https://github.com/pytorch/pytorch/issues/112443"
)
def test_operator_with_dynamic_output_shape(self):
def func(x):
return x.nonzero()
@@ -593,6 +628,22 @@ def func(x):
func, (torch.randn(3, 4),)
)

@pytorch_test_common.xfail_if_model_type_is_exportedprogram(
"AssertionError: AssertionError: original output #1 is BaseModelOutputWithPastAndCrossAttentions("
" last_hidden_state=FakeTensor(..., size=(2, 128, 16), grad_fn=<ViewBackward0>),"
" past_key_values=((FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>),"
" FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>)),"
" (FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>),"
" FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>)),"
" (FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>),"
" FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>)),"
" (FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>),"
" FakeTensor(..., size=(2, 2, 128, 8), grad_fn=<PermuteBackward0>))),"
" hidden_states=None, attentions=None, cross_attentions=None),"
" but only the following types are supported:"
" (<class 'torch.Tensor'>, <class 'torch.SymInt'>, <class 'torch.SymFloat'>, <class 'torch.SymBool'>)"
" Github issue: https://github.com/pytorch/pytorch/issues/110100"
)
def test_gpt2_tiny_from_config(self):
# Model
config = transformers.GPT2Config(
@@ -853,7 +904,6 @@ def forward(self, x):
x = torch.randn(1, 1, 2, dtype=torch.float)
exported_program = torch.export.export(Model(), args=(x,))

-# TODO: Support dynamic shape
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
exported_program, (x,), skip_dynamic_shapes_check=True
)
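
For reference, the `_parameterized_class_attrs_and_values` matrix at the top of this file is the kind of input consumed by `parameterized.parameterized_class`, so every (op_level_debug, dynamic_shapes, model_type) combination becomes its own generated test class. Below is a standalone sketch of the same pattern, with illustrative names rather than code from the PR:

```python
import itertools
import unittest

from parameterized import parameterized_class

# Every combination of the three knobs yields its own generated test class,
# analogous to _parameterized_class_attrs_and_values() above.
_combos = list(
    itertools.product(
        (True, False), (True, False), ("nn_module", "exported_program")
    )
)


@parameterized_class(("op_level_debug", "dynamic_shapes", "model_type"), _combos)
class MatrixSketch(unittest.TestCase):  # hypothetical test class
    def test_sees_one_combo(self):
        self.assertIn(self.model_type, ("nn_module", "exported_program"))


if __name__ == "__main__":
    unittest.main()
```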
