Skip to content

Commit

Permalink
[ONNX] Introduce Input/Ouptut formatter; Switch to 'DynamoExporter'
Browse files Browse the repository at this point in the history
ghstack-source-id: c89f1cda4c217bb1cd593c13993eedf0c1e72b6a
Pull Request resolved: #98421
  • Loading branch information
BowenBao committed Apr 14, 2023
1 parent 0962114 commit cb3824d
Show file tree
Hide file tree
Showing 8 changed files with 1,188 additions and 257 deletions.
2 changes: 1 addition & 1 deletion docs/source/onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,4 @@ Preview: torch.onnx TorchDynamo Exporter

torch.onnx.ExportOptions
torch.onnx.ExportOutput
torch.onnx.ExportOutputSerializer
torch.onnx.ExportOutputSerializer
5 changes: 4 additions & 1 deletion test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal import exporter
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
ExportOutputSerializer,
Expand Down Expand Up @@ -135,7 +136,9 @@ def serialize(
def test_raise_on_invalid_save_argument_type(self):
with self.assertRaises(roar.BeartypeException):
ExportOutput(torch.nn.Linear(2, 3)) # type: ignore[arg-type]
export_output = ExportOutput(onnx.ModelProto())
export_output = ExportOutput(
onnx.ModelProto(), exporter.InputFormatter(), exporter.OutputFormatter()
)
with self.assertRaises(roar.BeartypeException):
export_output.save(None) # type: ignore[arg-type]
export_output.model_proto
Expand Down
36 changes: 34 additions & 2 deletions test/onnx/pytorch_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import random
import sys
import unittest
from typing import Optional
from typing import Mapping, Optional, Type

import numpy as np
import packaging.version

import torch
from torch.autograd import function
from torch.onnx._internal import diagnostics
from torch.onnx._internal import diagnostics, exporter
from torch.testing._internal import common_utils

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
Expand Down Expand Up @@ -206,6 +206,38 @@ def wrapper(self, *args, **kwargs):
return skip_dec


def skip_fx_exporters(
exporter_cls_and_reason: Mapping[Optional[Type[exporter.Exporter]], str]
):
"""Skip exporting test for selected FX exporters.
Args:
exporter_cls_and_reason: Mapping from FX exporter class to skip the test to the
reason for skipping.
Returns:
A decorator for skipping exporting test for FX exporters.
"""

def skip_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
for exporter_cls, reason in exporter_cls_and_reason.items():
if exporter_cls == self.exporter_cls:
exporter_name = (
exporter_cls.__name__
if exporter_cls is not None
else "dynamo_export"
)
raise unittest.SkipTest(
f"Skip verify test for '{exporter_name}'. {reason}"
)

return wrapper

return skip_dec


# skips tests for opset_versions listed in unsupported_opset_versions.
# if the caffe2 test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in caffe2)
Expand Down

0 comments on commit cb3824d

Please sign in to comment.