Skip to content

Commit

Permalink
Plumb various exporter strategies through new API
Browse files Browse the repository at this point in the history
We have a few experiments for exporting Dynamo and FX to ONNX
that this PR rationalizes through the new Exporter API and
adjusts tests to use the new API.

- A base FXGraphModuleExporter exporter from which all derive:
- DynamoExportExporter: uses dynamo.export to acquire FX graph
- DynamoOptimizeExporter: uses dynamo.optimize to acquire FX graph
- FXSymbolicTraceExporter: uses FX symbolic tracing

The dynamo_export API currently uses DynamoOptimizeExporter.
  • Loading branch information
abock committed Apr 4, 2023
1 parent 969fa79 commit 521f11c
Show file tree
Hide file tree
Showing 18 changed files with 356 additions and 418 deletions.
24 changes: 12 additions & 12 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
_resolve_export_options,
ExportOutputSerializer,
ProtobufExportOutputSerializer,
ResolvedExportOptions,
)

from torch.testing._internal.common_utils import TemporaryFileName
Expand All @@ -27,11 +27,11 @@ def forward(self, x):

class TestExportOptionsAPI(unittest.TestCase):
def test_opset_version_default(self):
options = _resolve_export_options(None)
options = ResolvedExportOptions(None)
self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)

def test_opset_version_explicit(self):
options = _resolve_export_options(ExportOptions(opset_version=3000))
options = ResolvedExportOptions(ExportOptions(opset_version=3000))
self.assertEquals(options.opset_version, 3000)

def test_raise_on_invalid_argument_type(self):
Expand All @@ -43,26 +43,26 @@ def test_raise_on_invalid_argument_type(self):
with self.assertRaises(expected_exception_type):
ExportOptions(logger="DEBUG") # type: ignore[arg-type]
with self.assertRaises(expected_exception_type):
_resolve_export_options(options=12) # type: ignore[arg-type]
ResolvedExportOptions(options=12) # type: ignore[arg-type]

def test_dynamic_shapes_default(self):
options = _resolve_export_options(None)
self.assertIsNone(options.dynamic_shapes)
options = ResolvedExportOptions(None)
self.assertFalse(options.dynamic_shapes)

def test_dynamic_shapes_explicit(self):
options = _resolve_export_options(ExportOptions(dynamic_shapes=None))
self.assertIsNone(options.dynamic_shapes)
options = _resolve_export_options(ExportOptions(dynamic_shapes=True))
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=None))
self.assertFalse(options.dynamic_shapes)
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=True))
self.assertTrue(options.dynamic_shapes)
options = _resolve_export_options(ExportOptions(dynamic_shapes=False))
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=False))
self.assertFalse(options.dynamic_shapes)

def test_logger_default(self):
options = _resolve_export_options(None)
options = ResolvedExportOptions(None)
self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx"))

def test_logger_explicit(self):
options = _resolve_export_options(ExportOptions(logger=logging.getLogger()))
options = ResolvedExportOptions(ExportOptions(logger=logging.getLogger()))
self.assertEquals(options.logger, logging.getLogger())
self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))

Expand Down
32 changes: 20 additions & 12 deletions test/onnx/test_fx_dynamic_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@

import numpy as np

import onnx.reference
import onnx_test_common

import onnxruntime # type: ignore[import]

import torch
import torchvision
from torch.onnx._internal import _beartype, diagnostics, fx as fx_onnx
from torch.onnx import ExportOptions, ExportOutput
from torch.onnx._internal import _beartype, diagnostics
from torch.testing._internal import common_utils
from torch.types import Number
from torch.utils import _pytree as pytree

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_ONNXModelType = Union["onnx.ModelProto", bytes, str, io.BytesIO]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_OutputsType = Sequence[_NumericType]


@_beartype.beartype
def _run_ort(
onnx_model: _ONNXModelType, pytorch_inputs: _InputArgsType
export_output: ExportOutput, pytorch_inputs: _InputArgsType
) -> _OutputsType:
buffer = io.BytesIO()
export_output.save(buffer)
session = onnxruntime.InferenceSession(
onnx_model, providers=["CPUExecutionProvider"]
buffer.getvalue(), providers=["CPUExecutionProvider"]
)
input_names = [ort_input.name for ort_input in session.get_inputs()]
return session.run(
Expand Down Expand Up @@ -81,7 +81,7 @@ def _try_clone_model(model: _ModelType) -> _ModelType:

@_beartype.beartype
def compare_pytorch_onnx_with_ort(
onnx_model: Union["onnx.ModelProto", bytes],
export_output: ExportOutput,
model_input_args: _InputArgsType,
):
# Inspect the model's signature. It will be used
Expand All @@ -101,7 +101,7 @@ def compare_pytorch_onnx_with_ort(

pt_cloned_model = _try_clone_model(model)
ref_outputs, _ = pytree.tree_flatten(pt_cloned_model(*model_input_args))
ort_outputs = _run_ort(onnx_model, bound.args)
ort_outputs = _run_ort(export_output, bound.args)
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
Expand All @@ -110,13 +110,15 @@ def compare_pytorch_onnx_with_ort(
# Feed args and kwargs into exporter.
# Note that exporter should flatten kwargs into positional args the exported model;
# since ONNX doesn't represent kwargs.
onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(

onnx_model = torch.onnx.dynamo_export(
model,
*input_args,
opset_version=opset_version,
use_binary_format=True,
enable_dynamic_axes=True, # export models with dynamic shapes
**input_kwargs,
export_options=ExportOptions(
opset_version=opset_version,
dynamic_shapes=True,
),
)

compare_pytorch_onnx_with_ort(onnx_model, input_args)
Expand Down Expand Up @@ -148,7 +150,11 @@ def tearDown(self):
"typing.Union[float, int, str, bytes, typing.Sequence[float],"
" typing.Sequence[int], torch.Tensor], as [None, None]:"
)
# When the skip reason above is addressed, annotate this test with
# @skipIfNoTorchVision
def test_shufflenet_v2_dynamic_axes(self):
import torchvision

model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
Expand Down Expand Up @@ -349,6 +355,7 @@ def forward(self, x):
additional_test_inputs=[(x2,)],
)

@unittest.skip("ORT segfault")
def test_expand_as_fill_seperate_tensor(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -377,6 +384,7 @@ def forward(self, input):
additional_test_inputs=[(another_x,)],
)

@unittest.skip("ORT segfault")
def test_flatten_dynamic_axes(self):
class MyModule(torch.nn.Module):
def forward(self, x):
Expand Down
18 changes: 10 additions & 8 deletions test/onnx/test_fx_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.onnx._internal import fx as fx_onnx
from torch.onnx import dynamo_export, ExportOptions
from torch.testing._internal import common_utils


class TestFxToOnnx(pytorch_test_common.ExportTestCase):
def setUp(self):
super().setUp()
self.opset_version = torch.onnx._constants.ONNX_DEFAULT_OPSET
self.export_options = ExportOptions()

def test_simple_function(self):
def func(x):
y = x + 1
z = y.relu()
return (y, z)

_ = fx_onnx.export(func, torch.randn(1, 1, 2), opset_version=self.opset_version)
_ = dynamo_export(
func, torch.randn(1, 1, 2), export_options=self.export_options
)

@unittest.skip(
"Conv Op is not supported at the time. https://github.com/microsoft/onnx-script/issues/397"
"max_pool2d is not supported in ATen Lib: https://github.com/microsoft/onnx-script/issues/585"
)
def test_mnist(self):
class MNISTModel(nn.Module):
Expand All @@ -48,7 +50,7 @@ def forward(self, tensor_x: torch.Tensor):
return output

tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
_ = fx_onnx.export(MNISTModel(), tensor_x, opset_version=self.opset_version)
_ = dynamo_export(MNISTModel(), tensor_x, export_options=self.export_options)

def test_trace_only_op_with_evaluator(self):
model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
Expand All @@ -64,8 +66,8 @@ def forward(self, input):
torch.argmax(input, dim=1, keepdim=True),
)

_ = fx_onnx.export(
ArgminArgmaxModel(), model_input, opset_version=self.opset_version
_ = dynamo_export(
ArgminArgmaxModel(), model_input, export_options=self.export_options
)

def test_multiple_outputs_op_with_evaluator(self):
Expand All @@ -75,7 +77,7 @@ def forward(self, x):
return torch.sum(values)

x = torch.arange(1.0, 6.0, requires_grad=True)
_ = fx_onnx.export(TopKModel(), x, opset_version=self.opset_version)
_ = dynamo_export(TopKModel(), x, export_options=self.export_options)


if __name__ == "__main__":
Expand Down
80 changes: 44 additions & 36 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations

import inspect

import io
import os
import tempfile
import unittest

from typing import Any, Callable, Generator, Sequence, Tuple, Union

import numpy as np

import onnx.reference
import onnx_test_common

import onnxruntime # type: ignore[import]
Expand All @@ -21,23 +18,32 @@

from torch._subclasses import fake_tensor
from torch.onnx._internal import _beartype, diagnostics, fx as fx_onnx
from torch.onnx._internal.exporter import ExportOptions, ExportOutput
from torch.onnx._internal.exporters.dynamo_optimize import DynamoOptimizeExporter
from torch.onnx._internal.exporters.fx_symbolic import FXSymbolicTraceExporter
from torch.testing._internal import common_utils
from torch.types import Number
from torch.utils import _pytree as pytree

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_ONNXModelType = Union["onnx.ModelProto", bytes, str, io.BytesIO]
_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
_OutputsType = Sequence[_NumericType]


@_beartype.beartype
def _run_ort(
onnx_model: _ONNXModelType, pytorch_inputs: Union[_InputArgsType, Generator]
onnx_model: Union[str, ExportOutput],
pytorch_inputs: Union[_InputArgsType, Generator],
) -> _OutputsType:
if isinstance(onnx_model, ExportOutput):
buffer = io.BytesIO()
onnx_model.save(buffer)
ort_model = buffer.getvalue()
else:
ort_model = onnx_model
session = onnxruntime.InferenceSession(
onnx_model, providers=["CPUExecutionProvider"]
ort_model, providers=["CPUExecutionProvider"]
)
input_names = [ort_input.name for ort_input in session.get_inputs()]
return session.run(
Expand All @@ -48,7 +54,7 @@ def _run_ort(
@_beartype.beartype
def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model: _ModelType,
input_args: _InputArgsType,
input_args: Sequence[_InputArgsType],
rtol: float = 1e-3,
atol: float = 1e-7,
opset_version: int = 18,
Expand All @@ -57,32 +63,25 @@ def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
# Feed args and kwargs into exporter.
# Note that exporter should flatten kwargs into positional args the exported model;
# since ONNX doesn't represent kwargs.
onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(
model,
*input_args,
opset_version=opset_version,
use_binary_format=True,
enable_dynamic_axes=True,
**input_kwargs,
exporter = DynamoOptimizeExporter(
options=ExportOptions(opset_version=opset_version, dynamic_shapes=True),
model=model,
model_args=input_args,
model_kwargs=input_kwargs,
)

# Inspect the model's signature. It will be used
# to flatten kwargs.
if isinstance(model, torch.nn.Module):
signature = inspect.signature(model.forward)
else:
signature = inspect.signature(model)
export_output = exporter.export()

# Bind args and kwargs to the model's signature to
# flatten kwargs into positional args since ONNX
# model cannot be called with kwargs.
bound = signature.bind(*input_args, **input_kwargs)
bound = exporter.model_signature.bind(*input_args, **input_kwargs)
# Fill optional inputs.
bound.apply_defaults()
assert not bound.kwargs

ref_outputs, _ = pytree.tree_flatten(model(*input_args, **input_kwargs))
ort_outputs = _run_ort(onnx_model, bound.args)
ort_outputs = _run_ort(export_output, bound.args)
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
Expand Down Expand Up @@ -147,6 +146,7 @@ def func(x, b=torch.tensor(1.0)):
func, (tensor_x,), b=torch.tensor(5.0)
)

@unittest.skip("ORT segfaults")
def test_mnist(self):
class MNISTModel(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -213,13 +213,14 @@ def test_gpt2_tiny(self):

# FIXME(titaiwang): SegFault when symbolic tracing is used
# https://github.com/microsoft/onnx-script/issues/523
onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(
model,
use_binary_format=True,
opset_version=self.opset_version,
enable_dynamic_axes=False,
**inputs,
)
onnx_model = DynamoOptimizeExporter(
options=ExportOptions(
opset_version=self.opset_version, dynamic_shapes=False
),
model=model,
model_args=[],
model_kwargs=inputs,
).export()

ref_outputs, _ = pytree.tree_flatten(model(**inputs, return_dict=False))
ort_outputs = _run_ort(onnx_model, (input_ids, attention_mask))
Expand Down Expand Up @@ -290,12 +291,19 @@ def _test_large_scale_exporter(
fake_args = create_args()
# Export ONNX model without initializers while ctx.paths records
# all files that contains real initializers.
(onnx_model, _, _, _) = fx_onnx.export_without_parameters_and_buffers(
fake_model,
*fake_args,
use_binary_format=False,
opset_version=self.opset_version,
enable_dynamic_axes=enable_dynamic_axes,

onnx_model = (
FXSymbolicTraceExporter(
options=ExportOptions(
opset_version=self.opset_version,
dynamic_shapes=enable_dynamic_axes,
),
model=fake_model,
model_args=fake_args,
model_kwargs={},
)
.export()
.model_proto
)

# Tasks done by the following block.
Expand Down

0 comments on commit 521f11c

Please sign in to comment.