[ONNX] Restore readable names for parameters and buffers #104493
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104493
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit c737bb9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Are you adding a unit test to check that the names are actually being preserved?
This PR introduces a new pass that restores parameter and buffer names from the original module, which improves the readability of the exported ONNX graph. For example, if the original module has a parameter named `root.linear.0.weight` and FX renames it to `_param_constant9`, this pass renames it back.
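The idea behind such a pass can be illustrated with a small, pure-Python sketch (this is not the actual PyTorch implementation; `build_rename_map`, `fx_attrs`, and `original_named_params` are hypothetical names): match the FX-lifted attributes back to the original module's parameters by object identity, then normalize the dotted names into valid flat attribute names.

```python
# Hypothetical sketch of the renaming idea -- not the actual PyTorch pass.
# FX lifts parameters into opaque names like "_param_constant9"; the original
# module still knows them by dotted paths like "root.linear.0.weight".
def build_rename_map(fx_attrs, original_named_params):
    # Index the original parameters by object identity, since FX keeps
    # references to the same underlying tensors.
    by_id = {id(t): name for name, t in original_named_params.items()}
    rename = {}
    for fx_name, tensor in fx_attrs.items():
        if id(tensor) in by_id:
            # "." cannot appear in a flat attribute name, so normalize it.
            # (See #104670 for the collision caveat with this replacement.)
            rename[fx_name] = by_id[id(tensor)].replace(".", "_")
    return rename

weight = object()  # stands in for a parameter tensor
print(build_rename_map({"_param_constant9": weight},
                       {"root.linear.0.weight": weight}))
# {'_param_constant9': 'root_linear_0_weight'}
```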
@wschin @thiagocrepaldi updated PR, PTAL.
LGTM.
nice
Besides `fx_module` and `nn_module`, the implementation also created different references for them: `original_module` as another name for `nn_module`, and `module` as another name for `fx_module`. The first two names are clearer.
```python
For each `get_attr` node, if the target is a str representing a parameter or buffer
under `self.module`, we rename the parameter or buffer to its original name.
The parameters and buffers between `self.module` and `self.original_module` refer
```
```diff
- The parameters and buffers between `self.module` and `self.original_module` refer
+ The parameters and buffers between `self.fx_module` and `self.nn_module` refer
```
I'll fix these in a follow-up; merging this first to unblock you.
nvm, I think it's better to rebase to be safe. Another PR with a pass was merged.
```python
"""Restore parameter and buffer names from original module.

For each `get_attr` node, if the target is a str representing a parameter or buffer
under `self.module`, we rename the parameter or buffer to its original name.
```
```diff
- under `self.module`, we rename the parameter or buffer to its original name.
+ under `self.fx_module`, we rename the parameter or buffer to its original name.
```
There is a bunch of old `self.module` usage going on here.
Can't do anything about this one; it's in the base class.
```python
    nn_module: torch.nn.Module,
):
    super().__init__(diagnostic_context, fx_module)
    self.original_module = nn_module
```
```diff
- self.original_module = nn_module
+ self.nn_module = nn_module
```
```python
# TODO(bowbao): fix #104670 and replace "." with "/" to avoid collision.
normalized_name = new_name.replace(".", "_")
attr_value = getattr(self.module, old_name)
setattr(self.module, normalized_name, attr_value)
```
```diff
- setattr(self.module, normalized_name, attr_value)
+ setattr(self.fx_module, normalized_name, attr_value)
```
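The collision the TODO above refers to can be shown with a two-line sketch (the dotted names here are hypothetical examples, not from the PR): once `.` is replaced with `_`, two distinct dotted names can map to the same flat attribute name.

```python
# Why "." -> "_" can collide (#104670): distinct dotted names can
# normalize to the same flat attribute name.
def normalize(name):
    return name.replace(".", "_")

# A parameter "weight" under submodule "sub.block_a" vs. a parameter
# "block_a_weight" under submodule "sub" -- hypothetical examples.
print(normalize("sub.block_a.weight"))   # sub_block_a_weight
print(normalize("sub.block_a_weight"))   # sub_block_a_weight -- same name!
```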
## Context prior to this PR

#100017 was merged onto the PyTorch `main` branch with the goal of enabling `torch._dynamo.export` to perform symbolic tracing, defined here as tracing a model using fake inputs and weights. An input is fake when a `torch.Tensor` is replaced by a `torch._subclasses.FakeTensor`; a weight is fake when a `torch.nn.Parameter` is replaced by a `torch._subclasses.FakeTensor`.

For additional context, several strategies were discussed with Meta to enable this feature, including 1) calling `torch._dynamo.export` within a `torch._subclasses.FakeTensorMode` context, and 2) **fake**fying the input and model as a separate step and then calling `torch._dynamo.export` without an active `torch._subclasses.FakeTensorMode` context. In the end, 2) was preferred and implemented by #100017 to minimize the number of side effects fake tensor mode has on the code base.

As a consequence, the `torch._dynamo.export` API introduced a new argument called `fake_mode`. When symbolic tracing is used, the user must pass in the `fake_mode` used to fakefy both the input and the model; internally, `torch._dynamo.export` adopts this `fake_mode` instead of creating its own instance. This is needed because each `FakeTensorMode` instance holds metadata about the tensors/parameters it fakefied. Thus, using a real tensor/model while specifying a `fake_mode` to `torch._dynamo.export` is an error, as is specifying a `fake_mode` instance different from the one used to fakefy the model and input.

## Changes introduced from this PR

This PR integrates `torch._dynamo.export(fake_mode=...)` into `torch.onnx.dynamo_export`. In essence, it:

* Introduces a new public API `ONNXFakeContext`, which wraps a `FakeTensorMode` under the hood. This removes complexity from the user side while still allowing the exporter to leverage fake mode.
* Adds a new public API `enable_fake_mode` *context manager* that instantiates and returns an `ONNXFakeContext`.
* Adds a new `ExportOptions.fake_context` that persists the `ONNXFakeContext` created by `enable_fake_mode` and plumbs it through until it reaches the call to `torch._dynamo.export`.
* Adds a `model_state_dict` argument to the `ExportOutput.save` API.
  * When a model is exported with fake tensors, no actual data exists in the FX module and, therefore, in the ONNX graph. In fact, `torch.fx.make_fx` lifts initializers as model inputs when fake tensors are used.
  * #104493 is needed to enforce name matching between parameters and inputs.
  * A model checkpoint file or state_dict is needed to populate the ONNX graph with real initializers through the `export_output.save(model_state_dict=...)` API.

Symbolic tracing, or ONNX fake mode, is only enabled when the user instantiates the input and model within the `enable_fake_mode` context. Otherwise, real tracing is done, which preserves the current behavior.

## Usability

Because symbolic tracing depends heavily on changes made on the Dynamo side before they can be consumed by the ONNX exporter, this feature may have its API and assumptions changed as symbolic tracing matures upstream. Nonetheless, it is still important to merge this feature ASAP on the ONNX exporter side to "lock" changes on Dynamo that would otherwise break the ONNX exporter without warning.

Example:

```python
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        out = self.linear(x)
        return out

with torch.onnx.enable_fake_mode() as fake_context:
    x = torch.rand(5, 2, 2)
    model = Model()

# Export the model with fake inputs and parameters
export_options = ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
    model, x, export_options=export_options
)

model_state_dict = Model().state_dict()  # optional
export_output.save("/path/to/model.onnx", model_state_dict=model_state_dict)
```

## Next steps

* Add unit tests running the exported model with ORT. Today this is not possible because `make_fx`, used by our decomposition pass, lifts initializers as model inputs, but the initializer names are not preserved by FX tracing, causing a mismatch between initializer and input names. #104493 and #104741 should fix the initializer mismatch, enabling model execution.
* Revisit `ONNXTorchPatcher` and how the ONNX initializers are saved in the graph as external data. We can try to get rid of the PyTorch patcher. If we can't, we might prefer to create specific patchers, say an `FXSymbolicTracePatcher` used specifically during an export using `torch.fx.symbolic_trace`, and maybe an `ExportOutputSavePatcher` used specifically for `ExportOutput.save`, to prevent patching too many PyTorch APIs that we don't need.

## References

* [FakeTensor implementation](https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/fake_tensor.py)
* [PR that adds fake tensor support to torch._dynamo.export](#100017)
* [Short fake tensor documentation](https://pytorch.org/torchdistx/latest/fake_tensor.html)

Pull Request resolved: #103865
Approved by: https://github.com/BowenBao
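The deferred-initializer flow described above can be sketched in plain Python (a toy model of the behavior; `FakeModeExportOutput` is a hypothetical name, not the real `torch.onnx` class): the fake-mode export records only initializer names, and `save` fills in real values from the user-supplied state_dict, which is why name matching (#104493) matters.

```python
# Toy sketch of the fake-mode save flow -- not the real torch.onnx classes.
class FakeModeExportOutput:
    def __init__(self, initializer_names):
        # Under fake mode the graph holds no weight data, only names.
        self.initializer_names = initializer_names

    def save(self, model_state_dict):
        # Initializer names must line up with state_dict keys, which is
        # why restoring readable parameter names (#104493) is required.
        missing = [n for n in self.initializer_names
                   if n not in model_state_dict]
        if missing:
            raise ValueError(f"state_dict missing initializers: {missing}")
        return {n: model_state_dict[n] for n in self.initializer_names}

out = FakeModeExportOutput(["linear.weight", "linear.bias"])
print(out.save({"linear.weight": [[1.0, 0.0]], "linear.bias": [0.0]}))
```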
Stack from ghstack (oldest at bottom):
This PR introduces a new pass that restores parameter and buffer names from the original module. It is useful for the readability of the exported ONNX graph. For example, if the original module has a parameter named `root.linear.0.weight`, and the parameter is renamed to `_param_constant9` by FX, this pass will rename it back.