Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Restore readable names for parameters and buffers #104493

Closed
wants to merge 7 commits into from
10 changes: 9 additions & 1 deletion torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,9 @@ def export(self) -> ExportOutput:
)

# TODO: Design the passes API
graph_module = pre_export_passes(self.options, graph_module, updated_model_args)
graph_module = pre_export_passes(
self.options, self.model, graph_module, updated_model_args
)

# TODO: Defer `import onnxscript` out of `import torch` path
# https://github.com/pytorch/pytorch/issues/103764
Expand Down Expand Up @@ -610,6 +612,7 @@ def dynamo_export(
@_beartype.beartype
def pre_export_passes(
options: ResolvedExportOptions,
original_model: Union[torch.nn.Module, Callable],
fx_module: torch.fx.GraphModule,
fx_module_args: Sequence[Any],
):
Expand Down Expand Up @@ -647,6 +650,11 @@ def pre_export_passes(
diagnostic_context, module, options.onnxfunction_dispatcher
).analyze(infra.levels.ERROR)

if isinstance(original_model, torch.nn.Module):
module = passes.RestoreParameterAndBufferNames(
diagnostic_context, module, original_model
).run()

# ONNX does not support None inputs. During graph building, all None inputs
# are removed. Here we register this step to input adapter.
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/_internal/fx/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .decomp import Decompose
from .functionalization import Functionalize, RemoveInputMutation
from .readability import RestoreParameterAndBufferNames
from .shape_inference import ShapeInferenceWithFakeTensor
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder

Expand All @@ -8,6 +9,7 @@
"Functionalize",
"MovePlaceholderToFront",
"RemoveInputMutation",
"RestoreParameterAndBufferNames",
"ReplaceGetAttrWithPlaceholder",
"ShapeInferenceWithFakeTensor",
]
111 changes: 111 additions & 0 deletions torch/onnx/_internal/fx/passes/readability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

from typing import Dict, List, Sequence, Tuple, Union

import torch
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics


class RestoreParameterAndBufferNames(_pass.Transform):
    """Restore parameter and buffer names from original module.

    This pass is useful for readability of the exported ONNX graph. It restores the
    parameter and buffer names from the original module. For example, if the original
    module has a parameter named `root.linear.0.weight`, and the parameter is renamed to
    `_param_constant9` by FX, this pass will rename it back.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        original_module: torch.nn.Module,
    ):
        """Initialize the pass.

        Args:
            diagnostic_context: Context used to emit diagnostics about renames.
            module: The FX graph module whose `get_attr` targets will be renamed.
            original_module: The user's original `nn.Module`; its parameter/buffer
                tensors are expected to be the same objects held by `module`, so
                they can be used as lookup keys for the readable names.
        """
        super().__init__(diagnostic_context, module)
        self.original_module = original_module

    @_beartype.beartype
    def _rename_param_and_buffer(
        self,
        diagnostic: diagnostics.Diagnostic,
        nodes: Sequence[torch.fx.Node],
        new_name: str,
    ) -> None:
        """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.

        Args:
            diagnostic: Diagnostic to record the rename on.
            nodes: All `get_attr` nodes that reference the same old target. Must be
                non-empty and share a single target.
            new_name: Readable dotted name from the original module, e.g.
                `linear.0.weight`.
        """
        assert len(nodes) > 0, "`nodes` cannot be empty"
        assert (
            len({node.target for node in nodes}) == 1
        ), "`nodes` must all have same `target`"
        old_name = nodes[0].target
        assert isinstance(old_name, str), f"Expected str, got type({old_name})"
        # Parameter/buffer name cannot contain ".", so the dotted original name is
        # flattened with "_" before being set as a module attribute.
        normalized_name = new_name.replace(".", "_")
        attr_value = getattr(self.module, old_name)
        # Re-register the same tensor object under the readable name, then drop the
        # FX-generated name. `self.module` is the graph module held by the base
        # `_pass.Transform` class.
        setattr(self.module, normalized_name, attr_value)
        delattr(self.module, old_name)
        # Swap every `get_attr` node over to the new target. A fresh node is created
        # because `node.target` of an existing node should not be mutated in place.
        for node in nodes:
            with self.module.graph.inserting_before(node):
                new_node = self.module.graph.get_attr(normalized_name)
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                self.module.graph.erase_node(node)
        diagnostic.with_additional_message(
            f"Renamed 'self.{old_name}' to 'self.{normalized_name}', "
            f"normalized from original parameter name '{new_name}'."
        )

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Restore parameter and buffer names from original module.

        For each `get_attr` node, if the target is a str representing a parameter or buffer
        under `self.module`, we rename the parameter or buffer to its original name.
        The parameters and buffers between `self.module` and `self.original_module` refer
        to the same objects, allowing us to use it as key to retrieve the original name.
        """
        assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args"
        assert (
            len(kwargs) == 0
        ), "RestoreParameterAndBufferNames does not take any kwargs"
        # Map each parameter/buffer tensor (by object identity via hashing) back to
        # its dotted name in the original module.
        state_to_readable_name: Dict[Union[torch.nn.Parameter, torch.Tensor], str] = {}
        state_to_readable_name.update(
            {v: k for k, v in self.original_module.named_parameters()}
        )
        state_to_readable_name.update(
            {v: k for k, v in self.original_module.named_buffers()}
        )
        diagnostic = self.diagnostic_context.inflight_diagnostic()

        # Old FX-generated target -> (all nodes sharing it, readable new name).
        old_name_to_nodes: Dict[str, Tuple[List[torch.fx.Node], str]] = {}

        for node in self.module.graph.nodes:
            if node.op == "get_attr":
                assert isinstance(
                    node.target, str
                ), f"Expected str, got type({node.target})"
                if node.target in old_name_to_nodes:
                    # We have already processed this parameter/buffer.
                    old_name_to_nodes[node.target][0].append(node)
                    continue
                attr_value = getattr(self.module, node.target)
                if (
                    isinstance(attr_value, (torch.nn.Parameter, torch.Tensor))
                    and attr_value in state_to_readable_name
                ):
                    readable_name = state_to_readable_name[attr_value]
                    old_name_to_nodes[node.target] = ([node], readable_name)
                    continue

                diagnostic.with_additional_message(
                    f"Cannot find readable name for self.{node.target}: {type(attr_value)}. The name is unchanged."
                )
                if isinstance(attr_value, torch.nn.Parameter):
                    # If it is a parameter we treat it more seriously.
                    diagnostic.level = diagnostics.levels.WARNING
                else:
                    diagnostic.level = diagnostics.levels.NONE

        # Perform renames after the scan so the graph is not mutated while iterating.
        for nodes, new_name in old_name_to_nodes.values():
            self._rename_param_and_buffer(diagnostic, nodes, new_name)

        return self.module