Skip to content

Commit

Permalink
[ONNX] Enclose package info for modules exported as local functions (#…
Browse files Browse the repository at this point in the history
…107409)

Enclose source package of modules that are exported as onnx local function in exported onnx model. GPT2 model example:

<img width="350" alt="image" src="https://github.com/pytorch/pytorch/assets/9376104/5e361bdd-ca24-45e7-a9ba-191c35acf3bb">

Pull Request resolved: #107409
Approved by: https://github.com/justinchuby
ghstack dependencies: #107408
  • Loading branch information
BowenBao authored and pytorchmergebot committed Aug 23, 2023
1 parent 7a8db57 commit c3c1b68
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 14 deletions.
4 changes: 3 additions & 1 deletion .ci/docker/common/install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pip_install \
pip_install onnx-weekly==1.15.0.dev20230717

# TODO: change this when onnx-script is on testPypi
pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps
# pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps
# NOTE: temp change for CI to run on unpublished onnxscript PR.
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@7cf838dcdfef9d06494f21d5a55e0ba32ba548b6" --no-deps

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
Expand Down
10 changes: 0 additions & 10 deletions test/onnx/test_fx_op_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,6 @@
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
),
xfail(
"nn.functional.adaptive_avg_pool1d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten::div.Tensor_mode needs type promotion"),
),
xfail(
"nn.functional.adaptive_avg_pool2d",
reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \
Expand Down Expand Up @@ -493,12 +489,6 @@
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
reason="core dump - cat does not support zero-dim tensors yet",
),
xfail(
"div",
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None
and sample.input.dtype in onnx_test_common.INT_TYPES,
reason="rounding_mode is not yet supported",
),
xfail(
"index_put",
matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def forward(self, input):
additional_test_inputs=[((y,),)],
)

@pytorch_test_common.xfail(
@pytorch_test_common.skip_dynamic_fx_test(
"[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. Name:'n13__5' Status Message:"
"slice.cc:193 FillVectorsFromInput Starts must be a 1-D array"
)
Expand Down
27 changes: 27 additions & 0 deletions torch/onnx/_internal/fx/_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc

import contextlib
import dataclasses
import difflib

import io
Expand All @@ -19,6 +20,32 @@
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher


@dataclasses.dataclass
class PackageInfo:
package_name: str
version: Optional[str]
commit_hash: Optional[str]

def to_onnx_domain_string(self) -> str:
return ".".join(
filter(None, ("pkg", self.package_name, self.version, self.commit_hash))
)

@classmethod
def from_python_class(cls, python_class: type) -> PackageInfo:
package_name = python_class.__module__.split(".")[0]
package = __import__(package_name)
version = getattr(package, "__version__", None)
# TODO: Figure out how to retrieve commit hash.
commit_hash = None
return cls(package_name, version, commit_hash)


@dataclasses.dataclass
class GraphModuleOnnxMeta:
package_info: PackageInfo


@contextlib.contextmanager
def _patch_difflib_sequence_matcher_init():
"""Context patching `difflib.SequenceMatcher` for fx readable graph.
Expand Down
23 changes: 22 additions & 1 deletion torch/onnx/_internal/fx/fx_onnx_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import (
_pass,
diagnostics,
onnxfunction_dispatcher,
op_validation,
Expand Down Expand Up @@ -459,8 +460,28 @@ def run(
`fx_graph_module` is a submodule. If not provided,
`fx_graph_module` is assumed to be the root module.
"""
if parent_onnxscript_graph is not None:
# If parent_onnxscript_graph is provided, we assume fx_graph_module is a
# submodule representing a forward call of an nn.Module.
# Compose package and version where the nn.Module is defined as domain name
# for the local function.

onnx_meta: Optional[_pass.GraphModuleOnnxMeta] = fx_graph_module.meta.get(
"onnx"
)
if onnx_meta is None:
raise RuntimeError(
f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. "
f"Only submodules produced by `Modularize` pass is supported in ONNX export."
)

onnx_domain = onnx_meta.package_info.to_onnx_domain_string()
else:
# Leave as default domain name for the root module.
onnx_domain = None

onnxscript_graph = onnxscript_graph_building.TorchScriptGraph(
parent_onnxscript_graph
parent_onnxscript_graph, domain_name=onnx_domain
)
onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator(
onnxscript_graph
Expand Down
14 changes: 13 additions & 1 deletion torch/onnx/_internal/fx/passes/modularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ def qualified_module_class_name(self) -> str:
"""Returns the qualified module class name of the top module."""
return self.top().qualified_module_class_name

@property
def module_class(self) -> Optional[type]:
"""Returns the module class of the top module."""
return self.top()._module_class


def _module_stack_meta_from_node(node: torch.fx.Node) -> _ModuleStackMeta:
return _ModuleStackMeta(node.meta.get("nn_module_stack"))
Expand Down Expand Up @@ -678,7 +683,14 @@ def _arg_transform(node: torch.fx.Node) -> torch.fx.Node:
new_outputs[0] if len(new_outputs) == 1 else new_outputs
)

return torch.fx.GraphModule(self._reference_module, fx_graph, module_class_name)
graph_module = torch.fx.GraphModule(
self._reference_module, fx_graph, module_class_name
)
if (module_class := self._stack_meta.module_class) is not None:
graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta(
_pass.PackageInfo.from_python_class(module_class)
)
return graph_module


class _LeafNode(_IRNode):
Expand Down

0 comments on commit c3c1b68

Please sign in to comment.