Commit
[ONNX] Add additional_test_kwargs into test_fx_to_onnx_with_onnxruntime.py

ghstack-source-id: 01716dab3990b5845ff82d2a4e6bd25d88084436
Pull Request resolved: #99434
titaiwangms committed Apr 18, 2023
1 parent 24d20ea commit 81a5cd9
Showing 4 changed files with 88 additions and 74 deletions.
150 changes: 79 additions & 71 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -41,7 +41,9 @@

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_InputArgsType = Optional[Union[torch.Tensor, Sequence[Any], Mapping[str, Any]]]
_InputArgsType = Optional[
Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
]
_OutputsType = Sequence[_NumericType]


@@ -141,27 +143,33 @@ def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
test_suite: TestFxToOnnxWithOnnxRuntime,
model: _ModelType,
input_args: Sequence[_InputArgsType],
input_kwargs: Mapping[str, _InputArgsType] = {},
rtol: float = 1e-3,
atol: float = 1e-7,
opset_version: int = 18,
has_mutation: bool = False,
additional_test_inputs: Optional[Sequence[Sequence[_InputArgsType]]] = None,
**input_kwargs,
additional_test_inputs: Optional[
List[
List[
Union[Sequence[_InputArgsType], Optional[Mapping[str, _InputArgsType]]]
]
]
] = None,
):
"""Compare the results of PyTorch model with exported ONNX model
Args:
model (_ModelType): PyTorch model
input_args (_InputArgsType): torch input arguments
input_args (Sequence[_InputArgsType]): torch input arguments
input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs
rtol (float, optional): relative tolerance. Defaults to 1e-3.
atol (float, optional): absolute tolerance. Defaults to 1e-7.
opset_version (int, optional): ONNX opset version. Defaults to 18.
has_mutation (bool, optional): Whether the model mutates its input or state.
`mutation` as `True` incurs extra overhead of cloning the inputs and model.
Defaults to False.
additional_test_inputs (Optional[Sequence[_InputArgsType]], optional):
Test the models with another dataset, which is designed for dynamic axes
testing. Defaults to None.
additional_test_inputs: Test the models with another dataset input, which
is designed for dynamic axes testing. Defaults to None.
"""

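For reference, a minimal sketch (not part of the commit) of the entry format the new additional_test_inputs parameter expects per the docstring above: each entry is a list holding an args tuple and, optionally, a kwargs mapping. The sample shapes and the "b" keyword below are illustrative only.

import torch

# Each entry is [args] or [args, kwargs]; the kwargs mapping may be omitted (illustrative).
extra_inputs = [
    [(torch.randn(2, 3),)],              # positional args only
    [(torch.randn(4, 3),), {"b": 6.0}],  # positional args plus keyword args
]

for entry in extra_inputs:
    args = entry[0]
    kwargs = entry[1] if len(entry) == 2 else {}
    print(len(args), sorted(kwargs))  # -> "1 []" then "1 ['b']"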
@@ -198,14 +206,22 @@ def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
)
# This confirms the exported model accepts different input shapes
# when dynamic shape is enabled.
# TODO(titaiwangms): additional input kwargs.
if additional_test_inputs and test_suite.dynamic_shapes:
for additional_input_args in additional_test_inputs:
for another_input in additional_test_inputs:
if len(another_input) == 2:
additional_input_args, additional_input_kwargs = another_input
elif len(another_input) == 1:
additional_input_args = another_input[0]
additional_input_kwargs = {}
else:
raise ValueError(
f"test_inputs should only have tuple args and dictionary kwargs. But receives: {len(another_input)}"
)
_compare_pytorch_onnx_with_ort(
export_output,
model,
additional_input_args,
{},
additional_input_kwargs,
atol,
rtol,
has_mutation=has_mutation,
@@ -317,7 +333,7 @@ def func(x, b=torch.tensor(1.0)):
z = y.relu()
return (y, z)

tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
tensor_x = torch.randn(1, 2, 3, dtype=torch.float32)

# Test without providing optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x,))
@@ -327,20 +343,14 @@ def func(x, b=torch.tensor(1.0)):
)
# Test while specifying optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, func, (tensor_x,), b=torch.tensor(5.0)
self, func, (tensor_x,), {"b": torch.tensor(5.0)}
)

# beartype.roar.BeartypeCallHintParamViolation:
# @beartyped onnxscript.function_libs.torch_aten.graph_building.TorchScriptGraph.add_input()
# parameter input_value=8.0 violates type hint typing.Union[torch.Tensor, NoneType],
# as float 8.0 not <class "builtins.NoneType"> or <protocol "torch.Tensor">.
# @unittest.expectedFailure
@pytorch_test_common.xfail(
"beartype.roar.BeartypeCallHintReturnViolation: @beartyped "
"torch.onnx._internal.exporter.ExportOutput.adapt_torch_inputs_to_onnx() "
"return (tensor([[[ 1.5410, -0.2934]]]), 8.0) violates type hint "
"typing.Sequence[torch.Tensor], as tuple index 1 item float 8.0 not "
"instance of <protocol 'torch.Tensor'>."
"AssertionError: Expected 1 inputs, got 2"
"Captured fx graph does not have any information of the constant input (arg1). "
"The constant input is inserted into op.target args directly."
"This might be a bug in fx tracer regarding potential break in dynamic shapes."
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
@@ -354,13 +364,20 @@ def func(x, b=1.0):
return (y, z)

tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
another_x = torch.randn(2, 2, 4, dtype=torch.float32)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x,))
# Test with only positional args.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x, 8.0))
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, func, (tensor_x, 8.0), additional_test_inputs=[[(another_x, 9.0)]]
)
# Test while specifying optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, func, (tensor_x,), b=5.0
self,
func,
(tensor_x,),
{"b": 5.0},
additional_test_inputs=[[(another_x,), {"b": 6.0}]],
)

@pytorch_test_common.skip_min_ort_version(
@@ -449,25 +466,6 @@ def forward(self, tensor_x: torch.Tensor):
self, MNISTModel(), (tensor_x,)
)

@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
# test single op with no kwargs
def test_sigmoid(self):
x = torch.randn(1, 4, 2, 3)

class SigmoidModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()

def forward(self, x):
return self.sigmoid(x)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, SigmoidModel(), (x,))

@pytorch_test_common.xfail(
"RuntimeError: false INTERNAL ASSERT FAILED at "
"'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
@@ -485,7 +483,7 @@ def test_shufflenet_v2(self):
self,
model,
(dummy_input,),
additional_test_inputs=[(dummy_input,), (test_inputs,)],
additional_test_inputs=[[(test_inputs,)]],
rtol=1e-3,
atol=1e-5,
)
@@ -506,7 +504,10 @@ def forward(self, x, y):
another_y = torch.randn(3, 4)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, DynamicAdd(), (x, y), additional_test_inputs=[(another_x, another_y)]
self,
DynamicAdd(),
(x, y),
additional_test_inputs=[[(another_x, another_y)]],
)

@pytorch_test_common.skip_min_ort_version(
@@ -532,13 +533,14 @@ def forward(self, x, y):
input_y = torch.randn(1, 4)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, DynamicAdd(), (x, y), additional_test_inputs=[(input_x, input_y)]
self, DynamicAdd(), (x, y), additional_test_inputs=[[(input_x, input_y)]]
)

@pytorch_test_common.skip_dynamic_fx_test(
"[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned "
"while running Expand node. Name:'_0x55b501ebaf10_n2' "
"Status Message: invalid expand shape"
"https://github.com/pytorch/pytorch/issues/99360"
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
@@ -556,11 +558,11 @@ def forward(self, x, y):
input_y = torch.randn(2, 4, 4)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, DynamicMatMul(), (x, y), additional_test_inputs=[(input_x, input_y)]
self, DynamicMatMul(), (x, y), additional_test_inputs=[[(input_x, input_y)]]
)

@pytorch_test_common.xfail(
"RuntimeError: Unknown call_function target: aten.scalar_tensor.default"
@pytorch_test_common.skip_dynamic_fx_test(
"fx graph does not capture symbolic value for aten::scalar_tensor."
)
def test_scalar_tensor(self):
class test(torch.nn.Module):
@@ -575,7 +577,7 @@ def forward(self, x):
self,
test(),
(x,),
additional_test_inputs=[(y,)],
additional_test_inputs=[[(y,)]],
)

def test_transpose_infer_shape(self):
@@ -594,7 +596,7 @@ def forward(self, x):
self,
TransposeModule(),
(x,),
additional_test_inputs=[(y,)],
additional_test_inputs=[[(y,)]],
)

@pytorch_test_common.xfail("torch._dynamo.exc.TorchRuntimeError")
@@ -608,16 +610,14 @@ def forward(self, d1, d2):
d3 = torch.tensor([3])
d4 = torch.tensor([4])
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, Squeeze(), (d1, d4), additional_test_inputs=[(d3, d4)]
self, Squeeze(), (d1, d4), additional_test_inputs=[[(d3, d4)]]
)
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, Squeeze(), (d3, d4), additional_test_inputs=[(d1, d3)]
self, Squeeze(), (d3, d4), additional_test_inputs=[[(d1, d3)]]
)

@pytorch_test_common.skip_dynamic_fx_test(
"AssertionError: The values for attribute 'shape' do not match:"
" torch.Size([5, 6, 2]) != torch.Size([4, 4, 2]). Even symbolic "
"fx.graph can't get dynamic arguments from this Module."
"https://github.com/pytorch/pytorch/issues/99360"
)
def test_slice(self):
class DynamicSliceExportMod(torch.nn.Module):
@@ -633,7 +633,7 @@ def forward(self, x):
self,
DynamicSliceExportMod(),
(x,),
additional_test_inputs=[(y,)],
additional_test_inputs=[[(y,)]],
)

# TODO(titaiwang): This is also detected flaky in static shape:
@@ -653,7 +653,10 @@ def forward(self, x):
self, MutationModel(), (torch.randn(12),), has_mutation=True
)

@pytorch_test_common.xfail("TypeError: missing a required argument: 'end'")
# TODO(justinchuby): A known limitation in aten::arange support.
@pytorch_test_common.xfail(
"arange overload does not support positional 'end' argument"
)
def test_arange(self):
class ArangeModel(torch.nn.Module):
def forward(self, input):
@@ -669,7 +672,7 @@ def forward(self, input):
self,
ArangeModel(),
(x,),
additional_test_inputs=[(y,)],
additional_test_inputs=[[(y,)]],
)

@pytorch_test_common.xfail(
@@ -688,11 +691,11 @@ def forward(self, x):
self,
Model(),
(x,),
additional_test_inputs=[(x2,)],
additional_test_inputs=[[(x2,)]],
)

@pytorch_test_common.xfail(
"RuntimeError: Unknown call_function target: aten.copy.default"
"RuntimeError: Unknown call_function target: aten.lift_fresh_copy.default"
)
def test_expand_as_fill_tensor(self):
class Model(torch.nn.Module):
@@ -706,14 +709,9 @@ def forward(self, x):
self,
Model(),
(x,),
additional_test_inputs=[(x2,)],
additional_test_inputs=[[(x2,)]],
)

@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
@pytorch_test_common.xfail(
"Unknown call_function target: aten.lift_fresh_copy.default"
)
@@ -729,7 +727,7 @@ def forward(self, x):
self,
Model(),
(x,),
additional_test_inputs=[(x2,)],
additional_test_inputs=[[(x2,)]],
)

@pytorch_test_common.skip_min_ort_version(
@@ -749,7 +747,7 @@ def forward(self, input):
self,
ViewModel(),
(x,),
additional_test_inputs=[(another_x,)],
additional_test_inputs=[[(another_x,)]],
)

@pytorch_test_common.skip_min_ort_version(
@@ -770,7 +768,7 @@ def forward(self, x):
y = torch.randn(5, 5, 4, 5)
model = MyModule()
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, model, (x,), additional_test_inputs=[(y,)]
self, model, (x,), additional_test_inputs=[[(y,)]]
)

@pytorch_test_common.skip_min_ort_version(
@@ -796,6 +794,13 @@ def forward(
version="1.15",
dynamic_only=True,
)
@pytorch_test_common.skip_dynamic_fx_test(
"onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, "
"onnxruntime::TensorShapeVector&, bool) size != 0 && "
"(input_shape.Size() % size) == 0 was false. The input tensor cannot be "
"reshaped to the requested shape. Input shape:{1,4}, requested shape:{-1,3}\n"
"fx graph captures static graph."
)
def test_gpt2_tiny(self):
model_name = "sshleifer/tiny-gpt2"
# Download pytorch model
@@ -804,8 +809,11 @@ def test_gpt2_tiny(self):

# Transform input tokens
inputs = tokenizer("Hello world!", return_tensors="pt")
another_inputs = tokenizer("Another Hello world!", return_tensors="pt")

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, model, [], **inputs)
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, model, [], inputs, additional_test_inputs=[[(), another_inputs]]
)

@_beartype.beartype
def _test_large_scale_exporter(
6 changes: 4 additions & 2 deletions torch/onnx/_internal/exporter.py
@@ -194,7 +194,9 @@ def append_step(self, step: InputAdaptStep) -> None:
self._input_adapt_steps.append(step)

@_beartype.beartype
def apply(self, *model_args, **model_kwargs) -> Sequence[torch.Tensor]:
def apply(
self, *model_args, **model_kwargs
) -> Sequence[Union[torch.Tensor, int, float, bool]]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
Args:
@@ -288,7 +290,7 @@ def model_proto(self) -> onnx.ModelProto:
@_beartype.beartype
def adapt_torch_inputs_to_onnx(
self, *model_args, **model_kwargs
) -> Sequence[torch.Tensor]:
) -> Sequence[Union[torch.Tensor, int, float, bool]]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
Due to design differences, input/output format between PyTorch model and exported
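For context, a minimal sketch (not part of the commit) of what the widened return annotation allows: the adapted inputs may now mix tensors with Python scalars. The helper below is purely illustrative.

from typing import Sequence, Union

import torch

def describe_adapted_inputs(
    adapted: Sequence[Union[torch.Tensor, int, float, bool]]
) -> None:
    # Report the kind of each adapted input; plain scalars are now legal entries.
    for item in adapted:
        kind = "tensor" if isinstance(item, torch.Tensor) else type(item).__name__
        print(kind)

describe_adapted_inputs((torch.randn(1, 2), 8.0, True))  # tensor, float, bool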
