Introduce FXGraphExtractor into torch.onnx.dynamo_export #98893
Conversation
Force-pushed from 3d56ab1 to e433956
I prefer landing #98421 first, since it's in a much more ready state. Please let me know if you have more concerns. We'd need a discussion session for many points in this one. Aaron and I were talking about how we should expose
I will look into #98421 one more time. The aspect I was concerned about is that we were using this adapter for internal and external purposes through the same mechanism/API. The external use causes some clutter on

Regarding this one: while we are in the experimental stage, having more than one FX extractor shouldn't be a problem; in fact, there are some things only the symbolic trace can do, and we don't know how long it will take to make dynamo export do the same. When we can get rid of all of them and just stick to
Force-pushed from 36343b7 to 976fb4c
Force-pushed from 5c0e01e to 1be2507
Force-pushed from 1be2507 to 65a10c8
Thanks for finishing what #95651 could not!
@@ -654,6 +721,15 @@ def forward(self, x):
)

@pytorch_test_common.xfail("TypeError: missing a required argument: 'end'")
Many more `xfail`s need to be removed and replaced by `skip_fx_tracer`.

One downside of parameterizing over the tracer is using `skip` instead of `xfail`, which does not catch unexpected successes.

I wonder if it is necessary to maintain test coverage for `DynamoOptimize`. We are not doing it for `FxSymbolicTracer`. How do we decide?
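For context, a minimal sketch of what such a tracer-keyed skip decorator could look like (illustrative only; `skip_fx_tracer` and the `fx_tracer` attribute follow the names used in this thread, the body is not the PR's implementation). As noted above, the trade-off versus `xfail` is that a skipped test can never report an unexpected success.

```python
# Illustrative sketch only: skip a test for specific FX tracer classes, with a
# per-tracer reason. Tracers not listed in the mapping still run the test.
import functools
import unittest


def skip_fx_tracer(reasons):
    """reasons: dict mapping an FX tracer class to the reason for skipping."""

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(self, *args, **kwargs):
            # self.fx_tracer is assumed to be the tracer class chosen by test parameterization.
            reason = reasons.get(self.fx_tracer)
            if reason is not None:
                raise unittest.SkipTest(reason)
            return test_func(self, *args, **kwargs)

        return wrapper

    return decorator
```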
When it becomes a heavy burden, we can decide whether these tests bring enough benefit to compensate for the overhead. Right now it is too early to determine either.
We are already losing the functionality of `xfail`. Is it worth it?
Can you elaborate on the issue and why it cannot be fixed any other way?
import torch._dynamo  # NOTE: Do not import at top level.

# Even torch/__init__.py does it internally, only
# Causes circular when torch._dynamo.* surfaces public facing API during `import torch`
True, but that's not necessary for files that aren't intended to be imported during `import torch`, right?
Why does this file need to be imported so early?
Not sure. Before this PR, it was imported globally at the beginning of the module. This change delays the import to the point where it is actually used.
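A minimal sketch of the resulting pattern (illustrative only; the class name follows this PR's shape-inference pass, the body is not the actual code): the module-level import is replaced with one inside the method that needs it, so importing the module during `import torch` no longer pulls in `torch._dynamo`.

```python
import torch.fx


class ShapeInferenceWithFakeTensor:
    def _run(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Deferred import: torch._dynamo is only loaded when the pass actually runs,
        # avoiding circular imports during `import torch`.
        import torch._dynamo  # noqa: F401

        # ... run shape inference with fake tensors here (omitted in this sketch) ...
        return graph_module
```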
How can we mark which `Extractor` we are actively maintaining, and which `Extractor` is "almost immutable code"?
That is what `ResolvedExportOptions` does, right? It sets the default (most used and preferred) values.
Hmm, but is it OK that such information is not revealed when looking at these classes alone?
That was a design decision made when we decided that `ExportOptions` would default everything to `None` and only the private `ResolvedExportOptions` would actually assign the values. I am open to changes, but it seems unrelated to this task.
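A condensed sketch of that design (not the actual class definitions; `dynamic_shapes` is used only as an example field): the public options class defaults everything to `None`, and the private resolved class assigns the concrete defaults the exporter actually uses.

```python
from typing import Optional


class ExportOptions:
    """Public options: every field defaults to None, meaning 'use the exporter default'."""

    def __init__(self, *, dynamic_shapes: Optional[bool] = None):
        self.dynamic_shapes = dynamic_shapes


class ResolvedExportOptions(ExportOptions):
    """Private counterpart: fills in the concrete defaults right before export."""

    def __init__(self, options: ExportOptions):
        super().__init__(dynamic_shapes=options.dynamic_shapes)
        if self.dynamic_shapes is None:
            self.dynamic_shapes = False  # resolved default actually used by the exporter
```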
That was not what I meant. Simply put, should we state in the docstrings of `DynamoOptimize` and `DynamoExport` whether each is maintained?
I will split this PR in two: one for layering the fx_tracer into the exporter and another for re-adding optimize.
The latter is delaying the former, which adds more value :)
Heads up, PRs that change
I suggest using ghstack to separate this kind of big PR into multiple sub-PRs, so we can discuss different topics in different ones; it's also easier for you since the sub-PRs can be merged one by one.
return graph_module


class DynamoOptimize(exporter.FXGraphExtractor):
Could you put in the PR description the reason why we need this back after #99202?
# NOTE: Do not import at top level.
# Even torch/__init__.py does it internally, only
# Causes circular when torch._dynamo.* surfaces public facing API during `import torch`
import torch._dynamo
duplicate?
From where? Give more details :)
@BowenBao @titaiwangms please review #99940 instead
The current API architecture can be seen as 3 independent exporters, as shown below. The public API `dynamo_export()` defaults to one of the 3 variants, and the other 2 must be used by instantiating private classes.

This PR refactors the API into a single public API that can use different implementations of a `FXGraphExtractor` interface, as shown below.

Summary of changes:

- `dynamo_export` API: `ExportOptions` was expanded to allow `fx_tracer: FXGraphExtractor` to be specified, selecting which FX graph extractor to use, according to the design proposal (see the usage sketch after this list)
- `torch.onnx._internal.exporter.Exporter` does not have to internally specialize for each type of FX API the exporter might use. This leads to a single `Exporter` with many FX graph extractors
- `Exporter` subclasses were demoted to `FXGraphExtractor` subclasses, where they are actually consumed
  - `Exporter` is a [data]class that holds export options, model and input data in a single cohesive object. Specializing it means creating different exporters instead of having one exporter capable of exporting models through different options
  - `Exporter` doesn't consume the `model_args` that caused it to specialize
- `import torch.onnx` was delayed until after all dynamo subcomponents, preventing `torch.onnx` from having circular dependencies when `torch.XXXX` is imported during initialization
- `decomposition_table` was removed from `ExportOptions`; `Exporter` initializes that list itself, not allowing customization from users
- `Exporter.model_signature` was demoted to a simple standalone helper: plain `inspect.signature` usage without any state
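A rough usage sketch of the unified API described above (illustrative only: the `fx_tracer` option on `ExportOptions` and the extractor import path follow this PR's description and may not match the final public surface):

```python
import torch
from torch.onnx import ExportOptions, dynamo_export


class MLP(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


model, x = MLP(), torch.randn(2, 3)

# Default path: dynamo_export picks the default FX graph extractor.
export_output = dynamo_export(model, x)

# Selecting a specific FXGraphExtractor implementation via ExportOptions(fx_tracer=...),
# per the design described above. The import path below is an assumption for illustration.
from torch.onnx._internal.fx.fx_symbolic_graph_extractor import FXSymbolicTracer

export_output = dynamo_export(
    model, x, export_options=ExportOptions(fx_tracer=FXSymbolicTracer())
)
```

The intent of the design is that swapping the `FXGraphExtractor` changes only how the FX graph is obtained; the rest of the export pipeline stays the same.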
Possible next steps are:

- Decouple `passes` and `dispatching` from the cluttered `export_fx_to_onnx`
- `FXGraphExtractor` public API + helper for unit testing

**COPILOT SUMMARY**
🤖 Generated by Copilot at 1be2507

**Summary**

📝🛠️🚀

This pull request refactors the ONNX exporter code to use the FX graph extractor engine and the new `io_adapter` module, which improve the modularity, readability, and performance of the export process. It also updates the test files and the `utils.py` file to reflect the changes in the exporter API and the export options. Additionally, it fixes some import issues, typos, and style inconsistencies.

**Walkthrough**
- Refactor the `exporter` module and its subclasses to use the `FXGraphExtractor` abstract class and the `io_adapter` module for input and output adaptation
- Add the `decomposition_table` attribute and parameter to the `ExportOptions` and `_ResolvedExportOptions` classes, which allow customizing the decomposition of ATen operators into ONNX-friendly subgraphs
- Add the `skip_fx_tracer` decorator to the `pytorch_test_common` module, which allows skipping exporting tests for selected FX tracers based on a mapping from FX tracer class to skip reason
- Add the `fx_tracer` attribute and parameter to the `TestFxToOnnxWithOnnxRuntime` class and its subclasses, which allow parameterizing the FX tracer class and running tests with different FX graph extractors
- Add the `concrete_args` attribute and parameter to the `FXSymbolicTracer` class, which allow partially specializing the model inputs and removing control flow or data structures during tracing (see the `concrete_args` sketch after this walkthrough)
- Add the `model_signature` function to the `onnx_utils` module and the `utils` module, which returns the signature of a PyTorch model or function
- Rename the `FXSymbolicTraceExporter` class to `FXSymbolicTracer`, which is a more descriptive name for the subclass of `FXGraphExtractor` that uses the `torch.fx.symbolic_trace` API to generate FX graphs
- Rename the `FXGraphModuleExporter` class to `IOAdapter`, which is a more descriptive name for the class that adapts the PyTorch model inputs and outputs to the exported ONNX model inputs and outputs format
- Rename the `_ONNX_FRIENDLY_DECOMPOSITION_TABLE` variable to `_DEFAULT_ONNX_EXPORTER_DECOMPOSITION_TABLE`, which is a more descriptive name for the subset of PyTorch's built-in aten-to-aten decompositions that is compatible with ONNX export
- Rename the `torch/onnx/_internal/fx/fx_exporter.py` file to `torch/onnx/_internal/fx/io_adapter.py`, which is a more descriptive name for the file that contains the classes for input and output adaptation
- Move the `InputAdaptStep`, `InputAdapter`, `OutputAdaptStep`, and `OutputAdapter` classes from the `exporter` module to the `io_adapter` module, which is a more logical place for them
- Move the `DynamoOptimize` and `DynamoExport` classes from the `dynamo_exporter` module to the `dynamo_graph_extractor` module, which is a more logical place for them
- Move the `FXGraphModuleExporter` class from the `fx_exporter` module to the `io_adapter` module, which is a more logical place for it
- Move the `export_fx_to_onnx` method from the `FXSymbolicTraceExporter` class to the `FXGraphExtractor` abstract class, which is common logic that can be shared by different FX graph extractors
- Move the `model_signature` function from the `fx_serialization` module to the `onnx_utils` module, which is a more logical place for it
- Delete the `torch/onnx/_internal/fx/__init__.py` file, which is not needed
- Delete the `torch/onnx/_internal/fx/dynamo_exporter.py` file, which is not needed
- Use the `io_adapter` module instead of the `exporter` module to access the `InputAdapter` and `OutputAdapter` classes in various places, which reflects the refactoring of the input and output adapter logic
- Use the `fx_context` module instead of the `fx` module to access the `FxToOnnxContext` class in the `_test_large_scale_exporter` function, which is a more specific import that avoids importing the whole `fx` module
- Use the `fx_serialization` module instead of the `fx` module to access the `save_model_with_external_data` function in the `_test_large_scale_exporter` function, which is a more specific import that avoids importing the whole `fx` module
- Use the `torch_ops` module instead of the `torch` module to access the `OpOverload` and `OpOverloadPacket` classes in various places, which are more specific imports that avoid importing the whole `torch` module
- Use the `export_options` attribute instead of creating a new `ExportOptions` object in various places, which avoids duplicating the export options and uses the user-provided values
- Use the `_ResolvedExportOptions` class instead of the `ResolvedExportOptions` class in various places, which reflects the renaming and subclassing of the export options class
- Use the `adapt_input` and `adapt_output` methods instead of the `_apply_input_adapt_step` and `_apply_output_adapt_step` methods in various places, which reflect the renaming and refactoring of the input and output adaptation logic
- Use the `FXSymbolicTracer` class instead of the `DynamoExporter` class in the `_run_test_with_fx_to_onnx_exporter_and_onnx_runtime` function, which reflects the parameterization of the FX tracer class
- Use the `FXGraphExtractor` class instead of the `Exporter` class in the `dynamo_export` function, which reflects the renaming and refactoring of the exporter class
- Use the `FXGraphExtractor` class instead of the `Exporter` class in the `export_options` parameter of the `__init__` method of the `_ResolvedExportOptions` class, which reflects the renaming and refactoring of the exporter class
- Use the `FXGraphExtractor` class instead of the `Exporter` class in the `fx_tracer` attribute and parameter of the `TestFxToOnnxWithOnnxRuntime` class and its subclasses, which reflects the renaming and refactoring of the exporter class
- Use the `generate_fx` method instead of the `export` method in various places, which reflects the renaming and refactoring of the FX graph extractor interface
- Use the `skip_fx_tracer` decorator instead of the `xfail` decorator in various places, which allows skipping tests only for selected FX tracers and providing reasons for skipping
- Use the `skip_fx_tracer` decorator to skip tests for the `DynamoExport` and `DynamoOptimize` FX tracers in various places, and provide the reasons for skipping
- Add the `# type: ignore[arg-type]` comment to suppress the type-checking error for the `opset` argument of the `onnxscript.script` decorator in various places, which is a false positive due to the dynamic nature of the `onnxscript` module
- Adjust the `onnx_model` parameter of the `save_model_with_external_data` function, which avoids importing `onnx` at the top level and causing conflicts with the user-installed `onnx` package
- Adjust the imports in `torch/onnx/__init__.py` to follow the PEP 8 style guide for imports
- Remove the `inspect` module from the import statement in `torch/onnx/_internal/exporter.py`
- Remove the `List` annotation and add the `Dict` annotation in the import statement in `torch/onnx/_internal/exporter.py`
- Remove the `torch` module and add the `Type` annotation in the import statement in `torch/onnx/_internal/exporter.py`
- Remove the `torch` module from the import statement in `torch/onnx/_internal/exporter.py`
- Remove the `ops` module from the import statement in `test/onnx/test_fx_to_onnx_with_onnxruntime.py`
- Remove the `model_signature` property from the `Exporter` class, which is not used by the new `FXGraphExtractor` class
- Remove the `logger` property from the `Exporter` class, which is not used by the new `dynamo_export` function
- Remove the `UnsatisfiedDependencyError` class, which is not raised by the new `dynamo_export` function
- Remove the `export` abstract method from the `Exporter` class, which is replaced by the `generate_fx` abstract method in the `FXGraphExtractor` class
- Remove the `export_fx_to_onnx` method from the `FXSymbolicTraceExporter` class, which is moved to the `FXGraphExtractor` abstract class
- Remove the `import torch._dynamo` statement from the top level of `torch/onnx/_internal/fx/passes/shape_inference.py`, which could cause circular imports and unnecessary initialization of the dynamo module
- Add the `_module_expansion_symbolic_trace` function, which uses the `torch.fx.symbolic_trace` API to trace a model with module expansion
- Add the `fx_tracer` attribute to the file name of the ONNX model in the `tearDown` function of the `TestFxToOnnxWithOnnxRuntime` class, to distinguish the models exported by different FX tracers
- Add the `opset` argument to the `onnxscript.script` decorator in various places, which specifies the ONNX opset version to use for exporting the decorated functions
- Add the `import torch._dynamo` statement to the `_run` function of the `ShapeInferenceWithFakeTensor` class, which ensures that `torch._dynamo` is only imported when needed
- Add the `fx` module to the import statement in `test/onnx/test_fx_to_onnx_with_onnxruntime.py`, which contains the submodules for FX graph extraction, context, and serialization
- Add the `exporter` module to the import statement in `test/onnx/pytorch_test_common.py`, which contains the `FXGraphExtractor` abstract class that is used as a base class for the `fx_tracer` attribute of the `TestFxToOnnxWithOnnxRuntime` class
- Add the `functools` module and the `Union` annotation to the import statement in `torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py`, which are used for implementing the `FXSymbolicTracer` class
- Add the `onnx_utils` module to the import statement in `torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py`, which contains the `model_signature` function that is used for binding the model inputs with default values
- Add the `Protocol` and `runtime_checkable` annotations to the import statement in `torch/onnx/_internal/fx/io_adapter.py`, which are used for defining structural subtyping protocols that can be checked at runtime
- Add the `Type` annotation to the import statement in `test/onnx/pytorch_test_common.py`, which is used for type-checking the `fx_tracer` attribute of the `TestFxToOnnxWithOnnxRuntime` class
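Background for the `concrete_args` item in the walkthrough above: partial specialization is standard `torch.fx.symbolic_trace` behavior, shown here as a small self-contained example (not code from this PR). Fixing an input to a concrete value removes the input-dependent branch so that symbolic tracing can succeed.

```python
import torch
import torch.fx


class Gate(torch.nn.Module):
    def forward(self, x, use_relu: bool):
        # Input-dependent control flow: symbolic tracing cannot keep this branch symbolic.
        if use_relu:
            return torch.relu(x)
        return x


# Fixing use_relu=True specializes the branch away, so tracing succeeds.
traced = torch.fx.symbolic_trace(Gate(), concrete_args={"use_relu": True})
print(traced.graph)
```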