
Preserve original scope names in exported ONNX graph #75100

Closed
marcocaccin opened this issue Apr 1, 2022 · 14 comments
Assignees
Labels
feature A request for a proper, new feature. module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@marcocaccin

marcocaccin commented Apr 1, 2022

🚀 The feature, motivation and pitch

When a PyTorch model is exported, the graph nodes take uninformative names like f"{op_type}_{op_cnt}", and so do tensors (besides the input and output tensors of the overall graph).
In some contexts, it would be extremely helpful if each node in the exported ONNX graph were named after the scope of the PyTorch module from which it came. Ideally, intermediate tensors would also be named after the node that created them.

Here's a very subjective list of why it matters:

Visualisation, Debugging

netron_view

Let's say one goal of exporting to ONNX is to visualise a model (without loss of generality, let's think Netron) to reason about its architecture. For anything but the most trivial NN, this is simply not possible if nodes are named Conv_{N} and tensors are {M}.

As soon as the node name becomes unet.encoder.stage1.conv0 and its output tensor unet.encoder.stage1.conv0.output0, I can immediately tell which node I am looking at, where a skip connection goes, and so on.

ONNX graph manipulation

For some people, exporting a model to ONNX is not the end of the story. There may be graph- or weight-level manipulations that must be applied to sub-graphs defined by the name scopes of their nodes or intermediate tensors: to give an idea, a silly example would be "apply batchnorm fusion only on the decoder of the UNet".

TensorFlow does it 😛

Let's be clear: I am by no means a TF fan, but credit is due here. On that side, the feature I'm pitching is the default in tf2onnx and is a design decision that makes a lot of sense.

Notes

I fully understand that name scopes may lose some of their meaning when running graph-level optimisations during the export if sets of nodes are replaced by new ones and there is no 1:1 relationship. Even in this case, though, it may be possible to assign a "common ancestor name"... or fall back to the current implementation if nothing can be done.
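The "common ancestor name" fallback is easy to sketch: given the dotted scope names of the nodes being replaced, take their longest common dotted prefix. A minimal, purely illustrative helper (not part of any PyTorch API):

```python
def common_ancestor_scope(scopes):
    """Longest common dotted prefix of a set of scope names.

    For example, when Conv+BN fusion replaces nodes scoped
    "unet.encoder.stage1.conv0" and "unet.encoder.stage1.bn0"
    with a single node, the merged node could be named after
    their common ancestor "unet.encoder.stage1".
    """
    parts = [s.split(".") for s in scopes]
    prefix = []
    for segments in zip(*parts):
        # stop at the first level where the scopes diverge
        if len(set(segments)) != 1:
            break
        prefix.append(segments[0])
    return ".".join(prefix)
```

If the nodes share no prefix at all, the result is the empty string, which is exactly the "fall back to the current implementation" case.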

Alternatives

1. Fork PyTorch and diverge

Since at least some of the ONNX export code is in C++, there is no easy monkey-patching solution: the library must be recompiled.

2. Hack around with a heavy hand

Someone suggested that it is possible to monkey patch the _slow_forward method of torch.nn.Module to give meaningful names to the tensors produced during graph tracing (note: not the nodes, but going from output tensor names to node names is trivial).
I tried it and it is a bit hit and miss; I would need a much deeper understanding of the guts of JIT to ensure that 100% of the output tensors actually get a useful name.

Alternatively, or in conjunction, I am also trying to monkey patch torch.onnx.utils._export() to use the inlined_graph, dump it, and rename the ONNX nodes as a post-processing step.

Additional context

The scope name would be readily available from the scopeName() of each torch._C.Node of the TorchScript's inlined_graph, but the export instead uses the graph where such information has been stripped off (for example, here is the first place where the choice of using the graph is made and becomes consequential when exporting a ScriptModule).

To make use of this metadata down the line, the serialisation logic of torch._C.Graph._export_onnx() must also be adapted. This is where the nodes of the exported ONNX proto take their names on the fly.
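For reference, the scope information can be pulled out of a traced module like this. This is a sketch relying on the internal torch._C.Node API, so the exact format returned by scopeName() is not guaranteed:

```python
import torch


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return torch.relu(self.conv(x))


def node_scopes(module, example_input):
    """List (op kind, scope name) for every node of the inlined graph.

    scopeName() is the internal hook this issue proposes to propagate
    into the exported ONNX node names; it is empty for nodes that were
    not created inside a submodule's forward.
    """
    traced = torch.jit.trace(module, example_input)
    return [(n.kind(), n.scopeName()) for n in traced.inlined_graph.nodes()]


pairs = node_scopes(Tiny(), torch.randn(1, 1, 4, 4))
```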

[Edit: updated markdown header level for better rendering]

@alcabilone

Hello, is anyone assigned to this, or could I have a try?

@mruberry mruberry added module: onnx Related to torch.onnx feature A request for a proper, new feature. triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Apr 4, 2022
@BowenBao
Collaborator

BowenBao commented Apr 5, 2022

Have you tried exporting with verbose=True? It adds stacktrace information to doc_string field of each node. https://pytorch.org/docs/stable/onnx.html#torch.onnx.export

@marcocaccin
Author

Have you tried exporting with verbose=True? It adds stacktrace information to doc_string field of each node. https://pytorch.org/docs/stable/onnx.html#torch.onnx.export

I tried that first, but TL;DR it simply does not work.
There are a couple of problems with this approach:

  1. The docs are not telling the truth (this deserves a separate GitHub issue, but let's not get derailed). To be more specific, the documentation shows that the docstring of each node should contain a scope: section (which is indeed what I'm looking for, albeit in a slightly different format), but in reality this section does not exist.
    You can easily reproduce this with:

    import onnx
    import torch

    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.chunk0 = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.BatchNorm2d(1), torch.nn.ReLU())
            self.chunk1 = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.ReLU())

        def forward(self, x):
            x = self.chunk0(x)
            x = self.chunk1(x)
            return x

    model_pt = MyModel()
    input_tensor = torch.randn(1, 1, 32, 32)
    torch.onnx.export(
        model_pt,
        input_tensor,
        "model_test.onnx",
        verbose=True,
        input_names=["inputs"],
        output_names=["outputs"],
    )
    model_onnx = onnx.load("model_test.onnx")
    print(onnx.helper.printable_graph(model_onnx.graph))

the output will look like

graph torch-jit-export (
  %inputs[FLOAT, 1x1x32x32]
) initializers (
  %chunk1.0.weight[FLOAT, 1x1x1x1]
  %chunk1.0.bias[FLOAT, 1]
  %onnx::Conv_16[FLOAT, 1x1x1x1]
  %onnx::Conv_17[FLOAT, 1]
) {
  %input.4 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%inputs, %onnx::Conv_16, %onnx::Conv_17)
  %input.8 = Relu(%input.4)
  %input.12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%input.8, %chunk1.0.weight, %chunk1.0.bias)
  %outputs = Relu(%input.12)
  return %outputs
}
  2. Related to (1), one cannot reliably infer the name scope of every node. Two conditions must hold for that to be possible:
    • The node must have parameters, which is true only for stateful layers such as Conv, not for stateless ones such as ReLU or Add
    • The node must map 1:1 to a torch.nn.Module, otherwise even the names of the layer parameters are mangled into something meaningless. As an example, look at chunk0 in the example above, where the Conv+BN layers have been fused.

@BowenBao
Collaborator

BowenBao commented Apr 6, 2022

@marcocaccin thanks for the detailed response. The verbose argument means the final ONNX graph will include the doc_string field of each node. In your example above, it appears that onnx.helper.printable_graph() did not render the doc_string field. If you simply call print(model_onnx), you will get

graph(
...
  node {
    input: "input.12"
    output: "outputs"
    name: "Relu_3"
    op_type: "Relu"
    doc_string: "/home/bowbao/pytorch/torch/nn/functional.py(1406): relu\n/home/bowbao/pytorch/torch/nn/modules/activation.py(98): forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1098): _slow_forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1110): _call_impl\n/home/bowbao/pytorch/torch/nn/modules/container.py(139): forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1098): _slow_forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1110): _call_impl\nrepro_scope_docstring.py(12): forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1098): _slow_forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1110): _call_impl\n/home/bowbao/pytorch/torch/jit/_trace.py(118): wrapper\n/home/bowbao/pytorch/torch/jit/_trace.py(132): forward\n/home/bowbao/pytorch/torch/nn/modules/module.py(1110): _call_impl\n/home/bowbao/pytorch/torch/jit/_trace.py(1166): _get_trace_graph\n/home/bowbao/pytorch/torch/onnx/utils.py(401): _trace_and_get_graph_from_model\n/home/bowbao/pytorch/torch/onnx/utils.py(450): _create_jit_graph\n/home/bowbao/pytorch/torch/onnx/utils.py(545): _model_to_graph\n/home/bowbao/pytorch/torch/onnx/utils.py(809): _export\n/home/bowbao/pytorch/torch/onnx/utils.py(129): export\n/home/bowbao/pytorch/torch/onnx/__init__.py(336): export\nrepro_scope_docstring.py(23): <module>\n"
...

The doc_string field should also be visible through Netron.

I agree the scoped names provide much better clarity. Will discuss internally at microsoft, and tracked via https://msdata.visualstudio.com/Vienna/_workitems/edit/1737678

@marcocaccin
Author

@BowenBao thank you!
A small clarification/nit about your last comment though: the doc_string field contains a sort of traceback of where in the code the node comes from, but it does not seem to contain any reference to the scope name.
I don't have access to the MS link, but if there's anything I can help with (e.g., providing a gist of how I currently solve the problem), please ping me at any moment 🙂!

@garymm
Collaborator

garymm commented Apr 11, 2022

Same request from @lutzroeder to make visualizing structure in Netron possible.
Basically, what TensorFlow is doing. Nodes are named "module1/submodule2/nodename_x" and the "/" are interpreted by Netron as hierarchy.

@fatcat-z fatcat-z added the onnx-triaged triaged by ONNX team label Apr 18, 2022
@BowenBao BowenBao self-assigned this Jul 7, 2022
@BowenBao
Collaborator

Same request from @lutzroeder to make visualizing structure in Netron possible. Basically, what TensorFlow is doing. Nodes are named "module1/submodule2/nodename_x" and the "/" are interpreted by Netron as hierarchy.

@garymm, @lutzroeder Is there any example showing how "interpreted by Netron as hierarchy" looks and works? I'm trying to see whether a similar effect is achieved by updating the PyTorch node naming.

@RalphMao

This feature will be super helpful!

@cchan-lm

cchan-lm commented Sep 9, 2022

Tried the nightly to test out #82038 and #82040 and it's already very useful! Do you know when these will be in an official PyTorch release?

@BowenBao
Collaborator

BowenBao commented Sep 27, 2022

@cchan-lm Happy to hear that you find it useful; please feel free to give us feedback! This should be ready in the next PyTorch release, 1.13, which is coming soon.

@marcocaccin
Author

This is fantastic @BowenBao, thank you for your hard work and making this feature happen!

@cchan-lm

@BowenBao That is awesome news, and we're looking forward to 1.13! For feedback, should we open another issue or comment here?

@BowenBao
Collaborator

@BowenBao That is awesome news, and we're looking forward to 1.13! For feedback, should we open another issue or comment here?

Please open new issue, but feel free to mention this one for reference if needed.

@kevalmorabia97
Contributor

kevalmorabia97 commented Oct 17, 2022

Currently the node names do not correspond to submodule names. For example, if I export the onnx graph for torchvision.models.resnet18, I see nodes with names like /layer1/layer1.0/conv2/Conv while in the actual model, the corresponding submodule name is layer1.0.conv2. Is it possible to have the submodule name in the onnx node name i.e. /layer1.0.conv2/Conv or /layer1/0/conv2/Conv?
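In the meantime, the dotted submodule path can be recovered from the exporter's slash-separated node names with a small heuristic. This is a hypothetical helper, not an official API; it assumes the naming scheme shown above, where each scope segment cumulatively extends the previous one:

```python
def node_name_to_module(node_name: str) -> str:
    """Heuristically recover "layer1.0.conv2" from "/layer1/layer1.0/conv2/Conv".

    Segments like "layer1.0" already contain their parent path, so a
    segment that extends the path so far replaces it; any other
    segment is appended with a dot.
    """
    segments = [s for s in node_name.split("/") if s]
    if len(segments) < 2:
        return ""
    path = ""
    for seg in segments[:-1]:  # drop the trailing op type, e.g. "Conv"
        if not path:
            path = seg
        elif seg.startswith(path + "."):
            # cumulative segment, e.g. "layer1.0" extending "layer1"
            path = seg
        else:
            path = path + "." + seg
    return path
```

This only papers over the issue for analysis scripts; having the exporter emit the dotted submodule name directly would still be preferable.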
