Skip to content

Commit

Permalink
[ONNX] Introduce Input/Ouptut adapter; Switch to 'DynamoExporter' (#9…
Browse files Browse the repository at this point in the history
…8421)

Summary
* Introduce input/output adapter. Due to design differences, input/output format
between PyTorch model and exported ONNX model are often not the same. E.g., `None`
inputs are allowed for PyTorch model, but are not supported by ONNX. Nested constructs
of tensors are allowed for PyTorch model, but only flattened tensors are supported by ONNX,
etc. The new input/output adapter is exported with the model. Providing an interface to
automatically convert and validate inputs/outputs format.
* As suggested by #98251,
provide extension for unwrapping user defined python classes for `dynamo.export` based
exporter. Unblock huggingface models.
* Re-wire tests to run through `DynamoExporter` w/ `dynamo_export` api. Kept
`DynamoOptimizeExporter` in the tests for now for coverage of this change.
Pull Request resolved: #98421
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/thiagocrepaldi
  • Loading branch information
BowenBao authored and ZainRizvi committed Apr 19, 2023
1 parent 6d10d95 commit 094ae84
Show file tree
Hide file tree
Showing 8 changed files with 1,193 additions and 249 deletions.
2 changes: 1 addition & 1 deletion docs/source/onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,4 @@ Preview: torch.onnx TorchDynamo Exporter

torch.onnx.ExportOptions
torch.onnx.ExportOutput
torch.onnx.ExportOutputSerializer
torch.onnx.ExportOutputSerializer
5 changes: 4 additions & 1 deletion test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal import exporter
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
ExportOutputSerializer,
Expand Down Expand Up @@ -135,7 +136,9 @@ def serialize(
def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(onnx.ModelProto())
export_output = ExportOutput(
onnx.ModelProto(), exporter.InputAdapter(), exporter.OutputAdapter()
)
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
export_output.model_proto
Expand Down
37 changes: 35 additions & 2 deletions test/onnx/pytorch_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import random
import sys
import unittest
from typing import Optional
from typing import Mapping, Optional, Type

import numpy as np
import packaging.version

import torch
from torch.autograd import function
from torch.onnx._internal import diagnostics
from torch.onnx._internal import diagnostics, exporter
from torch.testing._internal import common_utils

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
Expand Down Expand Up @@ -206,6 +206,39 @@ def wrapper(self, *args, **kwargs):
return skip_dec


def skip_fx_exporters(
exporter_cls_and_reason: Mapping[Optional[Type[exporter.Exporter]], str]
):
"""Skip exporting test for selected FX exporters.
Args:
exporter_cls_and_reason: Mapping from FX exporter class to skip the test to the
reason for skipping.
Returns:
A decorator for skipping exporting test for FX exporters.
"""

def skip_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
for exporter_cls, reason in exporter_cls_and_reason.items():
if exporter_cls == self.exporter_cls:
exporter_name = (
exporter_cls.__name__
if exporter_cls is not None
else "dynamo_export"
)
raise unittest.SkipTest(
f"Skip verify test for '{exporter_name}'. {reason}"
)
return func(self, *args, **kwargs)

return wrapper

return skip_dec


# skips tests for opset_versions listed in unsupported_opset_versions.
# if the caffe2 test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in caffe2)
Expand Down

0 comments on commit 094ae84

Please sign in to comment.