
[ONNX] Restore readable names for parameters and buffers #104493

Closed · wants to merge 7 commits

Conversation

@BowenBao (Collaborator) commented Jul 1, 2023

Stack from ghstack (oldest at bottom):

This PR introduces a new pass that restores parameter and buffer names from the original module.

This improves the readability of the exported ONNX graph. For example, if the original module has a parameter named `root.linear.0.weight` that FX renamed to `_param_constant9`, this pass renames it back.
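
For illustration, here is a minimal sketch of what such a pass does. The function name and the identity-based tensor matching are assumptions for the sketch, not the PR's actual implementation (which also routes through the exporter's diagnostics machinery):

```python
import torch


def restore_readable_names(
    fx_module: torch.fx.GraphModule, original_module: torch.nn.Module
) -> None:
    # Map each parameter/buffer (by tensor identity) to its qualified name
    # in the original module. Identity matching assumes FX lifted the
    # original tensor objects unchanged.
    tensor_to_name = {id(p): n for n, p in original_module.named_parameters()}
    tensor_to_name.update({id(b): n for n, b in original_module.named_buffers()})

    for node in fx_module.graph.nodes:
        # FX lifts parameters/buffers as flat get_attr targets such as
        # "_param_constant9"; only those are considered here.
        if node.op != "get_attr" or "." in node.target:
            continue
        attr_value = getattr(fx_module, node.target)
        original_name = tensor_to_name.get(id(attr_value))
        if original_name is None:
            continue  # e.g. a lifted constant with no counterpart in the module
        # "." is not valid in a flat attribute name, so normalize it to "_"
        # (see the collision TODO discussed in the review below).
        normalized_name = original_name.replace(".", "_")
        setattr(fx_module, normalized_name, attr_value)
        node.target = normalized_name

    fx_module.recompile()
```

With this, a `get_attr` node targeting `_param_constant9` would end up targeting `root_linear_0_weight` instead, so the corresponding ONNX initializer carries a readable name.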

@pytorch-bot bot commented Jul 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104493

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c737bb9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: onnx (torch.onnx related changes that should show up in the release notes) label Jul 1, 2023
@BowenBao added the topic: improvements (topic category) label Jul 1, 2023
@BowenBao marked this pull request as ready for review July 5, 2023 16:43
@BowenBao marked this pull request as draft July 5, 2023 16:45
@BowenBao marked this pull request as ready for review July 5, 2023 17:00
@BowenBao requested a review from wschin July 5, 2023 17:00
@thiagocrepaldi added the module: onnx (Related to torch.onnx) label Jul 5, 2023
@thiagocrepaldi (Collaborator) left a comment


Are you adding a unit test to check that names are actually preserved?
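
For reference, a test along these lines could check it. This is a hypothetical sketch, not the PR's actual test; the `model_proto` access and the asserted name shapes are assumptions:

```python
import torch


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


def test_parameter_names_are_preserved():
    export_output = torch.onnx.dynamo_export(MLP(), torch.rand(1, 2))
    initializer_names = {
        initializer.name
        for initializer in export_output.model_proto.graph.initializer
    }
    # Expect "linear"-derived names rather than "_param_constant0"-style ones.
    assert not any(n.startswith("_param_constant") for n in initializer_names)
    assert any("linear" in n for n in initializer_names)
```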


@BowenBao (Collaborator, Author) commented Jul 5, 2023

@wschin @thiagocrepaldi updated PR, PTAL.

test/onnx/test_fx_to_onnx.py — outdated review thread (resolved)
@wschin (Collaborator) left a comment


LGTM.


@thiagocrepaldi (Collaborator) left a comment


nice

torch/onnx/_internal/fx/passes/readability.py — three outdated review threads (resolved)
@BowenBao added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jul 7, 2023
@thiagocrepaldi (Collaborator) left a comment


Besides `fx_module` and `nn_module`, the implementation also creates different references for them: `original_module` as another name for `nn_module`, and `module` as another name for `fx_module`. The first two names are clearer.


For each `get_attr` node, if the target is a str representing a parameter or buffer
under `self.module`, we rename the parameter or buffer to its original name.
The parameters and buffers between `self.module` and `self.original_module` refer
Suggested change:
- The parameters and buffers between `self.module` and `self.original_module` refer
+ The parameters and buffers between `self.fx_module` and `self.nn_module` refer

@BowenBao (Author) replied:

I'll fix these in a follow-up; merging this first to unblock you.

@BowenBao (Author) replied:

nvm, I think it's better to rebase to be safe. Another PR with a pass was merged.

"""Restore parameter and buffer names from original module.

For each `get_attr` node, if the target is a str representing a parameter or buffer
under `self.module`, we rename the parameter or buffer to its original name.
Suggested change:
- under `self.module`, we rename the parameter or buffer to its original name.
+ under `self.fx_module`, we rename the parameter or buffer to its original name.

There is a bunch of old `self.module` usage going on here.

@BowenBao (Author) replied:

Can't do anything; this is in the base class.

torch/onnx/_internal/fx/passes/readability.py — review thread (resolved)
nn_module: torch.nn.Module,
):
super().__init__(diagnostic_context, fx_module)
self.original_module = nn_module
Suggested change:
- self.original_module = nn_module
+ self.nn_module = nn_module

# TODO(bowbao): fix #104670 and replace "." with "/" to avoid collision.
normalized_name = new_name.replace(".", "_")
attr_value = getattr(self.module, old_name)
setattr(self.module, normalized_name, attr_value)
Suggested change:
- setattr(self.module, normalized_name, attr_value)
+ setattr(self.fx_module, normalized_name, attr_value)


pytorchmergebot pushed a commit that referenced this pull request Jul 11, 2023
## Context prior to this PR

#100017 was merged onto the PyTorch `main` branch with the goal of enabling `torch._dynamo.export` to perform symbolic tracing.
In that context, symbolic tracing is defined as tracing a model using fake inputs and weights. An input is fake when a `torch.Tensor` is replaced by a `torch._subclasses.FakeTensor`, whereas a weight is fake when a `torch.nn.Parameter` is replaced by a `torch._subclasses.FakeTensor`.

For additional context, several strategies were discussed with Meta to enable this feature, including 1) calling `torch._dynamo.export` within a `torch._subclasses.FakeTensorMode` context and 2) **fake**fying the input and model as a separate step and then calling `torch._dynamo.export` without an active `torch._subclasses.FakeTensorMode` context. In the end, 2) was preferred and implemented by #100017 to minimize the number of side effects fake tensor mode has on the code base.

As a consequence, the `torch._dynamo.export` API introduced a new argument called `fake_mode`. When symbolic tracing is used, the user must pass in the `fake_mode` used to fakefy both the input and the model. Internally, `torch._dynamo.export` adopts this `fake_mode` instead of creating its own instance. This is needed because each instance of `FakeTensorMode` holds metadata on the tensors/parameters it fakefied. Thus, using a real tensor/model while specifying a `fake_mode` to `torch._dynamo.export` is an error, as is specifying a `fake_mode` instance different from the one used to fakefy the model and input.
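
In rough terms, strategy 2) looks like the sketch below. `deepcopy_to_fake_tensor` is a Dynamo utility assumed here for fakefying a module's weights; the exact call sequence is illustrative rather than code from either PR:

```python
import torch
import torch._dynamo
from torch._dynamo.utils import deepcopy_to_fake_tensor
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode()

# Fakefy the input and the model as a separate step, with no
# FakeTensorMode context active (strategy 2).
fake_input = fake_mode.from_tensor(torch.rand(5, 2, 2))
fake_model = deepcopy_to_fake_tensor(torch.nn.Linear(2, 2), fake_mode)

# Pass the same fake_mode in; torch._dynamo.export adopts it instead of
# creating its own instance, keeping the fakefied tensors' metadata consistent.
graph_module, guards = torch._dynamo.export(fake_model, fake_input, fake_mode=fake_mode)
```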

## Changes introduced from this PR

This PR is intended to integrate `torch._dynamo.export(fake_mode=...)` into `torch.onnx.dynamo_export`. In essence, it
* Introduces a new public API `ONNXFakeContext`, which wraps a `FakeTensorMode` under the hood. This removes complexity from the user side while still allowing the exporter to leverage the fake mode.
* Adds a new public API `enable_fake_mode` *context manager* that instantiates and returns an `ONNXFakeContext`.
* Adds a new `ExportOptions.fake_context` that is used to persist the `ONNXFakeContext` created by `enable_fake_mode` and plumb it through until it reaches the call to `torch._dynamo.export`.
* Adds a `model_state_dict` argument to the `ExportOutput.save` API.
  * When a model is exported with fake tensors, no actual data exists in the FX module and, therefore, in the ONNX graph.
    * In fact, `torch.fx.make_fx` lifts initializers as model inputs when fake tensors are used
      * #104493 is needed to enforce name matching between parameters and inputs
    * A model checkpoint file or state_dict is needed to populate the ONNX graph with real initializers through the `export_output.save(model_state_dict=...)` API

Symbolic tracing, or ONNX fake mode, is only enabled when the user instantiates the input and model within the `enable_fake_mode` context. Otherwise, real tracing is done, which preserves the current behavior.

## Usability

Because symbolic tracing depends heavily on changes landing on the Dynamo side before they can be consumed by the ONNX exporter, this feature may have its API and assumptions changed as symbolic tracing matures upstream. Nonetheless, it is still important to have this feature merged ASAP on the ONNX exporter side to "lock in" changes on Dynamo that would otherwise break the ONNX exporter without warning.

Example:

```python
import torch


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        out = self.linear(x)
        return out

with torch.onnx.enable_fake_mode() as fake_context:
    x = torch.rand(5, 2, 2)
    model = Model()

# Export the model with fake inputs and parameters
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
    model, x, export_options=export_options
)

model_state_dict = Model().state_dict()  # optional
export_output.save("/path/to/model.onnx", model_state_dict=model_state_dict)
```

## Next steps

* Add unit tests running the exported model with ORT
Today this is not yet possible because the `make_fx` used by our decomposition pass lifts initializers as model inputs, and the initializer names are not preserved by FX tracing, causing a mismatch between the initializer and input names.
#104493 and #104741 should fix the initializer mismatch, enabling model execution.

* Revisit `ONNXTorchPatcher` and how the ONNX initializers are saved in the graph as external data
We can try to get rid of the PyTorch patcher. If we can't, we might prefer to create specific patchers, say an `FXSymbolicTracePatcher` used specifically during an export using `torch.fx.symbolic_trace`, and maybe an `ExportOutputSavePatcher` used specifically for `ExportOutput.save`, to avoid patching too many PyTorch APIs that we don't need.

## References

* [FakeTensor implementation](https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/fake_tensor.py)
* [PR that adds fake tensor support to torch._dynamo.export](#100017)
* [Short fake tensor documentation](https://pytorch.org/torchdistx/latest/fake_tensor.html)
Pull Request resolved: #103865
Approved by: https://github.com/BowenBao
@facebook-github-bot deleted the gh/BowenBao/257/head branch July 11, 2023 14:15