Skip to content

Commit

Permalink
[ONNX] Introduce Input/Output formatter; Switch to 'DynamoExporter'
Browse files Browse the repository at this point in the history
ghstack-source-id: 417f3a993252b55823ff6597dd6817a0c663b5f0
Pull Request resolved: #98421
  • Loading branch information
BowenBao committed Apr 6, 2023
1 parent f98c180 commit 1678e12
Show file tree
Hide file tree
Showing 9 changed files with 911 additions and 196 deletions.
4 changes: 3 additions & 1 deletion test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def serialize(
def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(onnx.ModelProto())
export_output = ExportOutput(
onnx.ModelProto(), torch.onnx.InputFormatter(), torch.onnx.OutputFormatter()
)
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
export_output.model_proto
Expand Down
6 changes: 3 additions & 3 deletions test/onnx/test_fx_dynamic_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compare_pytorch_onnx_with_ort(
# Note that the exporter should flatten kwargs into positional args of the exported model,
# since ONNX doesn't represent kwargs.

onnx_model = torch.onnx.dynamo_export(
export_output = torch.onnx.dynamo_export(
model,
*input_args,
**input_kwargs,
Expand All @@ -121,13 +121,13 @@ def compare_pytorch_onnx_with_ort(
),
)

compare_pytorch_onnx_with_ort(onnx_model, input_args)
compare_pytorch_onnx_with_ort(export_output, input_args)

# This confirms the exported model accepts different input shapes
# when dynamic shape is enabled.
if additional_test_inputs:
for additional_input_args in additional_test_inputs:
compare_pytorch_onnx_with_ort(onnx_model, additional_input_args)
compare_pytorch_onnx_with_ort(export_output, additional_input_args)


class TestFxDynamicWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
Expand Down
235 changes: 171 additions & 64 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import tempfile
import unittest

from typing import Any, Callable, Generator, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import onnx_test_common

import onnxruntime # type: ignore[import]
import torch
import torch.onnx
Expand All @@ -23,18 +22,17 @@
from torch.onnx._internal.fx.fx_symbolic_exporter import FXSymbolicTraceExporter
from torch.testing._internal import common_utils
from torch.types import Number
from torch.utils import _pytree as pytree

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_InputArgsType = Optional[Union[torch.Tensor, Sequence[Any], Mapping[str, Any]]]
_OutputsType = Sequence[_NumericType]


@_beartype.beartype
def _run_ort(
onnx_model: Union[str, torch.onnx.ExportOutput],
pytorch_inputs: Union[_InputArgsType, Generator],
pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
if isinstance(onnx_model, torch.onnx.ExportOutput):
buffer = io.BytesIO()
Expand All @@ -46,48 +44,80 @@ def _run_ort(
ort_model, providers=["CPUExecutionProvider"]
)
input_names = [ort_input.name for ort_input in session.get_inputs()]
if len(input_names) != len(pytorch_inputs):
raise AssertionError(
f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
)
return session.run(
None, {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
)


# Check that an exported ONNX model produces the same outputs as the original model.
#
# Args:
#   export_output: result of an export; its `input_formatter` / `output_formatter`
#     are used to adapt PyTorch-side inputs and outputs, and it is handed to `_run_ort`.
#   model: the original callable, invoked as `model(*input_args, **input_kwargs)`.
#   input_args / input_kwargs: positional and keyword inputs for `model`.
#   atol / rtol: tolerances forwarded to `torch.testing.assert_close`.
#
# Raises:
#   AssertionError: when the number of reference outputs differs from the number of
#     outputs returned by `_run_ort`, or (via `assert_close`) when values differ.
@_beartype.beartype
def _validate_export_output(
export_output: torch.onnx.ExportOutput,
model: _ModelType,
input_args: Sequence[_InputArgsType],
input_kwargs: Mapping[str, _InputArgsType],
atol: float,
rtol: float,
):
# Format original model inputs into the format expected by exported ONNX model.
onnx_format_args = export_output.input_formatter.to_onnx(
*input_args, **input_kwargs
)

# Reference outputs: call the original model, then pass its result through the
# output formatter so it can be compared element-by-element with the ORT outputs.
ref_outputs = export_output.output_formatter.to_onnx(
model(*input_args, **input_kwargs)
)
ort_outputs = _run_ort(export_output, onnx_format_args)
# zip() below stops at the shorter sequence, so a count mismatch must be
# reported explicitly instead of being silently truncated.
if len(ref_outputs) != len(ort_outputs):
raise AssertionError(
f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
)
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
# Each ORT output is wrapped in a tensor before comparison with the reference.
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
)


@_beartype.beartype
def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model: _ModelType,
input_args: Sequence[_InputArgsType],
rtol: float = 1e-3,
atol: float = 1e-7,
opset_version: int = 18,
dynamic_shapes: bool = True,
**input_kwargs,
):
# Feed args and kwargs into exporter.
# Note that the exporter should flatten kwargs into positional args of the exported model,
# since ONNX doesn't represent kwargs.
exporter = DynamoOptimizeExporter(
options=torch.onnx.ExportOptions(
opset_version=opset_version, dynamic_shapes=True
export_output = torch.onnx.dynamo_export(
model,
*input_args,
**input_kwargs,
export_options=torch.onnx.ExportOptions(
opset_version=opset_version, dynamic_shapes=dynamic_shapes
),
model=model,
model_args=input_args,
model_kwargs=input_kwargs,
)

export_output = exporter.export()
_validate_export_output(export_output, model, input_args, input_kwargs, atol, rtol)

# Bind args and kwargs to the model's signature to
# flatten kwargs into positional args since ONNX
# model cannot be called with kwargs.
bound = exporter.model_signature.bind(*input_args, **input_kwargs)
# Fill optional inputs.
bound.apply_defaults()
assert not bound.kwargs
# NOTE: Temporarily run `DynamoOptimizeExporter` here as well to ensure coverage.
# Remove after `DynamoOptimizeExporter` is removed. Or refactor with parameterization.

ref_outputs, _ = pytree.tree_flatten(model(*input_args, **input_kwargs))
ort_outputs = _run_ort(export_output, bound.args)
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
)
export_output = DynamoOptimizeExporter(
torch.onnx.ExportOptions(
opset_version=opset_version, dynamic_shapes=dynamic_shapes
),
model=model,
model_args=input_args,
model_kwargs=input_kwargs,
).export()

_validate_export_output(export_output, model, input_args, input_kwargs, atol, rtol)


class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
Expand Down Expand Up @@ -116,20 +146,31 @@ def func(x):

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,))

def test_func_with_args_and_kwargs(self):
# AssertionError: Dynamo input/output is not consistent with traced input/output
# https://github.com/pytorch/pytorch/issues/96379
# TODO: `DynamoOptimizeExporter` works for this test case. Re-enable for that
# after parameterization.
@unittest.expectedFailure
def test_func_with_args_and_tensor_kwargs(self):
# Non-tensor optional kwargs are always folded into constant and
# removed from input list in Dynamo-traced graph, so we can't
# define a function like
# removed from input list in Dynamo-traced graph, if its value is not provided
# to tracer. So for a function like
# def func(x, b=1.0)
# here. E.g., if you change the `b` to 1.0 below, it will complain
# here. E.g., if you first Dynamo-trace the model with arguments (x,),
# and then call the traced graph with arguments (x, b=2.0), it will complain
# somewhere that model is called with extra args because the modified
# function is traced into
# def forward(self, x : torch.Tensor):
# add = x + 1.0; x = None
# relu = add.relu()
# return (add, relu)
# To summarize, optional kwargs must be tensors; otherwise, they are
# treated as in-graph constants in Dynamo.
# To summarize, in order to be traced as graph input, the value of optional kwarg
# must be provided. Otherwise, they are treated as in-graph constants in Dynamo.
# Tensor optional kwargs are an exception. They are always traced as inputs.
# It is unclear if this behavior is intended or not. But in general it is bad
# practice to set mutable default values.
# `DynamoOptimizeExporter` applies a workaround by binding args and kwargs to the
# model signature and filling in the default values of unprovided optional arguments.
def func(x, b=torch.tensor(1.0)):
y = x + b
z = y.relu()
Expand All @@ -148,6 +189,75 @@ def func(x, b=torch.tensor(1.0)):
func, (tensor_x,), b=torch.tensor(5.0)
)

# beartype.roar.BeartypeCallHintParamViolation:
# @beartyped onnxscript.function_libs.torch_aten.graph_building.TorchScriptGraph.add_input()
# parameter input_value=8.0 violates type hint typing.Union[torch.Tensor, NoneType],
# as float 8.0 not <class "builtins.NoneType"> or <protocol "torch.Tensor">.
# Export/run round trip for a function whose optional argument `b` is a plain
# Python float (not a tensor). Marked as an expected failure.
@unittest.expectedFailure
def test_func_with_args_and_kwargs(self):
def func(x, b=1.0):
y = x + b
z = y.relu()
return (y, z)

tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)

# Test relying on the default value of `b`.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,))
# Test with only positional args.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x, 8.0))
# Test while specifying optional kwarg.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (tensor_x,), b=5.0)

# Export/run round trip for a function whose positional inputs are nested
# containers: a dict of tensors, a tuple containing a nested tuple, and a
# list of lists of tensors.
def test_func_with_nested_input_structure(self):
def func(
x_dict: Dict[str, torch.Tensor],
y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
z_list: List[List[torch.Tensor]],
):
# Pick the tensor stored under "a", else "b", else fall back to a random one.
if "a" in x_dict:
x = x_dict["a"]
elif "b" in x_dict:
x = x_dict["b"]
else:
x = torch.randn(3)

# Unpack the nested tuple.
y1, (y2, y3) = y_tuple

z = x + y1 + y2 + y3
# Each inner list is stacked and reduced to a scalar sum that is added to z.
for z_sub_list in z_list:
z = z + torch.stack(z_sub_list).sum()

return z

# NOTE: `DynamoOptimizeExporter` fails if unused argument 'c' is passed in.
x_dict = {"a": torch.randn(3)} # , "c": torch.randn(3)}
y_tuple = (torch.randn(3), (torch.randn(3), torch.randn(3)))
# The inner lists deliberately have different lengths (2 and 3).
z_list = [
[torch.randn(3), torch.randn(3)],
[torch.randn(3), torch.randn(3), torch.randn(3)],
]
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
func, (x_dict, y_tuple, z_list)
)

# Export/run round trip for a function that returns a nested tuple
# `(x, (y, z))` instead of a flat sequence of tensors.
def test_func_with_nested_output_structure(self):
def func(x, y, z):
x = x + y
y = y + z
z = x + y
out1 = (x, (y, z))
# out2 and out3 are built but currently not returned; see the NOTE below.
out2 = [[x, y], [y, z]]
out3 = {"z": z, "x": x}
return out1 # , out2, out3

x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
# NOTE: `DynamoOptimizeExporter` fails if `, out2, out3` is uncommented and returned.
# It does not capture the output structure, which is the non computation part of
# the graph. It only sets `(x, y, z)` as output.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x, y, z))

@unittest.skip("ORT segfaults")
def test_mnist(self):
class MNISTModel(nn.Module):
Expand Down Expand Up @@ -202,6 +312,19 @@ def forward(self, x):

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(SigmoidAddModel(), (x,))

# Export/run round trip for a module that is called with `None` for its
# Optional tensor argument.
def test_none_input(self):
class NoneInputModel(torch.nn.Module):
def forward(
self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor
):
# With y omitted (None) only x and z contribute to the result.
if y is None:
return x + z
return x + y + z

# `None` is passed for `y`, so the `y is None` branch of forward is the one exercised.
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
NoneInputModel(), (torch.randn(1, 2), None, torch.randn(1, 2))
)

def test_gpt2_tiny(self):
model_name = "sshleifer/tiny-gpt2"
# Download pytorch model
Expand All @@ -210,26 +333,12 @@ def test_gpt2_tiny(self):

# Transform input tokens
inputs = tokenizer("Hello world!", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# FIXME(titaiwang): SegFault when symbolic tracing is used
# https://github.com/microsoft/onnx-script/issues/523
onnx_model = DynamoOptimizeExporter(
options=torch.onnx.ExportOptions(
opset_version=self.opset_version, dynamic_shapes=False
),
model=model,
model_args=[],
model_kwargs=inputs,
).export()

ref_outputs, _ = pytree.tree_flatten(model(**inputs, return_dict=False))
ort_outputs = _run_ort(onnx_model, (input_ids, attention_mask))
assert len(ref_outputs) == len(ort_outputs)
assert len(ref_outputs) == 5
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model, [], **inputs, dynamic_shapes=False
)

@_beartype.beartype
def _test_large_scale_exporter(
Expand Down Expand Up @@ -294,19 +403,17 @@ def _test_large_scale_exporter(
# Export ONNX model without initializers while ctx.paths records
# all files that contains real initializers.

onnx_model = (
FXSymbolicTraceExporter(
options=torch.onnx.ExportOptions(
opset_version=self.opset_version,
dynamic_shapes=enable_dynamic_axes,
),
model=fake_model,
model_args=fake_args,
model_kwargs={},
)
.export()
.model_proto
)
export_output = FXSymbolicTraceExporter(
options=torch.onnx.ExportOptions(
opset_version=self.opset_version,
dynamic_shapes=enable_dynamic_axes,
),
model=fake_model,
model_args=fake_args,
model_kwargs={},
).export()

onnx_model = export_output.model_proto

# Tasks done by the following block.
# 1. Iterate through all tensors stored in ctx.paths (the file content is loaded torch.load)
Expand All @@ -331,9 +438,9 @@ def _test_large_scale_exporter(
args = create_args()
kwargs = create_pytorch_only_kwargs()
# Original outputs.
ref_outputs, _ = pytree.tree_flatten(model(*args, **kwargs))
ref_outputs = export_output.output_formatter.to_onnx(model(*args, **kwargs))
# ORT outputs.
args_not_none = (arg for arg in args if arg is not None)
args_not_none = export_output.input_formatter.to_onnx(*args)
ort_outputs = _run_ort(
os.path.join(tmp_folder, onnx_model_location),
args_not_none,
Expand Down Expand Up @@ -361,7 +468,7 @@ def forward(self, tensor_x: torch.Tensor):
tensor_x = self.fc2(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
output = self.fc3(tensor_x)
return output
return (output, (output, output))

def create_model() -> nn.Module:
return MLPModel()
Expand Down

0 comments on commit 1678e12

Please sign in to comment.