Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ONNX] Add additional_test_kwargs into test_fx_to_onnx_with_onnxruntime.py #99434

Closed
wants to merge 5 commits into from
Closed
77 changes: 32 additions & 45 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_InputArgsType = Optional[Union[torch.Tensor, Sequence[Any], Mapping[str, Any]]]
_InputArgsType = Optional[
Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
]
_OutputsType = Sequence[_NumericType]


Expand Down Expand Up @@ -332,7 +334,6 @@ def func(x, b=torch.tensor(1.0)):
return (y, z)

tensor_x = torch.randn(1, 2, 3, dtype=torch.float32)
another_x = torch.randn(2, 4, 3, dtype=torch.float32)

# Test without providing optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x,))
Expand All @@ -345,17 +346,11 @@ def func(x, b=torch.tensor(1.0)):
self, func, (tensor_x,), {"b": torch.tensor(5.0)}
)

# beartype.roar.BeartypeCallHintParamViolation:
# @beartyped onnxscript.function_libs.torch_aten.graph_building.TorchScriptGraph.add_input()
# parameter input_value=8.0 violates type hint typing.Union[torch.Tensor, NoneType],
# as float 8.0 not <class "builtins.NoneType"> or <protocol "torch.Tensor">.
# @unittest.expectedFailure
@pytorch_test_common.xfail(
"beartype.roar.BeartypeCallHintReturnViolation: @beartyped "
"torch.onnx._internal.exporter.ExportOutput.adapt_torch_inputs_to_onnx() "
"return (tensor([[[ 1.5410, -0.2934]]]), 8.0) violates type hint "
"typing.Sequence[torch.Tensor], as tuple index 1 item float 8.0 not "
"instance of <protocol 'torch.Tensor'>."
"AssertionError: Expected 1 inputs, got 2. "
"Captured fx graph does not have any information of the constant input (arg1). "
"The constant input is inserted into op.target args directly. "
"This might be a bug in fx tracer regarding potential break in dynamic shapes."
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
Expand All @@ -369,13 +364,20 @@ def func(x, b=1.0):
return (y, z)

tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
another_x = torch.randn(2, 2, 4, dtype=torch.float32)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x,))
# Test with only positional args.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, func, (tensor_x, 8.0))
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, func, (tensor_x, 8.0), additional_test_inputs=[[(another_x, 9.0)]]
)
# Test while specifying optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self, func, (tensor_x,), b=5.0
self,
func,
(tensor_x,),
{"b": 5.0},
additional_test_inputs=[[(another_x,), {"b": 6.0}]],
)

@pytorch_test_common.skip_min_ort_version(
Expand Down Expand Up @@ -464,25 +466,6 @@ def forward(self, tensor_x: torch.Tensor):
self, MNISTModel(), (tensor_x,)
)

@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
# test single op with no kwargs
def test_sigmoid(self):
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
x = torch.randn(1, 4, 2, 3)

class SigmoidModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()

def forward(self, x):
return self.sigmoid(x)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, SigmoidModel(), (x,))

@pytorch_test_common.xfail(
"RuntimeError: false INTERNAL ASSERT FAILED at "
"'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
Expand Down Expand Up @@ -557,6 +540,7 @@ def forward(self, x, y):
"[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned "
"while running Expand node. Name:'_0x55b501ebaf10_n2' "
"Status Message: invalid expand shape. "
"https://github.com/pytorch/pytorch/issues/99360"
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
Expand All @@ -577,8 +561,8 @@ def forward(self, x, y):
self, DynamicMatMul(), (x, y), additional_test_inputs=[[(input_x, input_y)]]
)

@pytorch_test_common.xfail(
"RuntimeError: Unknown call_function target: aten.scalar_tensor.default"
@pytorch_test_common.skip_dynamic_fx_test(
"fx graph does not capture symbolic value for aten::scalar_tensor."
)
def test_scalar_tensor(self):
class test(torch.nn.Module):
Expand Down Expand Up @@ -633,9 +617,7 @@ def forward(self, d1, d2):
)

@pytorch_test_common.skip_dynamic_fx_test(
"AssertionError: The values for attribute 'shape' do not match:"
" torch.Size([5, 6, 2]) != torch.Size([4, 4, 2]). Even symbolic "
"fx.graph can't get dynamic arguments from this Module."
"https://github.com/pytorch/pytorch/issues/99360"
)
def test_slice(self):
class DynamicSliceExportMod(torch.nn.Module):
Expand Down Expand Up @@ -671,7 +653,10 @@ def forward(self, x):
self, MutationModel(), (torch.randn(12),), has_mutation=True
)

@pytorch_test_common.xfail("TypeError: missing a required argument: 'end'")
# TODO(justinchuby): A known limitation in aten::arange support.
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
@pytorch_test_common.xfail(
"arange overload does not support positional 'end' argument"
)
def test_arange(self):
class ArangeModel(torch.nn.Module):
def forward(self, input):
Expand Down Expand Up @@ -710,7 +695,7 @@ def forward(self, x):
)

@pytorch_test_common.xfail(
"RuntimeError: Unknown call_function target: aten.copy.default"
"RuntimeError: Unknown call_function target: aten.lift_fresh_copy.default"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 is this a new op in the family

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it's just not implemented yet.

)
def test_expand_as_fill_tensor(self):
class Model(torch.nn.Module):
Expand All @@ -727,11 +712,6 @@ def forward(self, x):
additional_test_inputs=[[(x2,)]],
)

@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
@pytorch_test_common.xfail(
"Unknown call_function target: aten.lift_fresh_copy.default"
)
Expand Down Expand Up @@ -814,6 +794,13 @@ def forward(
version="1.15",
dynamic_only=True,
)
@pytorch_test_common.skip_dynamic_fx_test(
"onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, "
"onnxruntime::TensorShapeVector&, bool) size != 0 && "
"(input_shape.Size() % size) == 0 was false. The input tensor cannot be "
"reshaped to the requested shape. Input shape:{1,4}, requested shape:{-1,3}\n"
"fx graph captures static graph."
)
def test_gpt2_tiny(self):
model_name = "sshleifer/tiny-gpt2"
# Download pytorch model
Expand Down
6 changes: 4 additions & 2 deletions torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def append_step(self, step: InputAdaptStep) -> None:
self._input_adapt_steps.append(step)

@_beartype.beartype
def apply(self, *model_args, **model_kwargs) -> Sequence[torch.Tensor]:
def apply(
self, *model_args, **model_kwargs
) -> Sequence[Union[torch.Tensor, int, float, bool]]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.

Args:
Expand Down Expand Up @@ -288,7 +290,7 @@ def model_proto(self) -> onnx.ModelProto:
@_beartype.beartype
def adapt_torch_inputs_to_onnx(
self, *model_args, **model_kwargs
) -> Sequence[torch.Tensor]:
) -> Sequence[Union[torch.Tensor, int, float, bool]]:
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
"""Converts the PyTorch model inputs to exported ONNX model inputs format.

Due to design differences, input/output format between PyTorch model and exported
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/_internal/fx/function_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def aten_getitem(self, i):
"aten::clamp": ops.core.aten_clamp,
"aten::clone": ops.core.aten_clone,
"aten::convolution": ops.core.aten_convolution,
"aten::copy": ops.core.aten_copy,
"aten::cos": ops.core.aten_cos,
"aten::cosh": ops.core.aten_cosh,
"aten::cumsum": ops.core.aten_cumsum,
Expand Down Expand Up @@ -122,6 +123,7 @@ def aten_getitem(self, i):
"aten::round": ops.core.aten_round,
"aten::rsqrt": ops.core.aten_rsqrt,
"aten::rsub": ops.core.aten_rsub,
"aten::scalar_tensor": ops.core.aten_scalar_tensor,
"aten::select": ops.core.aten_select,
"aten::selu": ops.core.aten_selu,
"aten::sigmoid": ops.core.aten_sigmoid,
Expand Down
4 changes: 3 additions & 1 deletion torch/onnx/_internal/fx/passes/fx_to_onnxscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def filter_incompatible_and_dtype_convert_kwargs(kwargs):
continue
if key == "dtype":
if value is None:
filtered["dtype"] = -1
# We omit if dtype is not provided, because onnxscript handles the
# default case.
continue
else:
filtered["dtype"] = int(
_type_utils.JitScalarType.from_dtype(value).onnx_type()
Expand Down