Skip to content

Commit

Permalink
[ONNX] Do not run 'deduplicate_initializers' when 'keep_initializers_…
Browse files Browse the repository at this point in the history
…as_inputs' is True (#96320)

### Proposal
When arg of 'keep_initializers_as_inputs' is True, it's quite possible that parameters are set by initializer of input.
Hence we should disable de-duplicate initializer optimization when 'keep_initializers_as_inputs==True'.

- [x] Update doc related to `keep_initializers_as_inputs`.
Pull Request resolved: #96320
Approved by: https://github.com/abock, https://github.com/thiagocrepaldi
  • Loading branch information
BowenBao authored and pytorchmergebot committed Aug 1, 2023
1 parent cfa4edc commit 05b2a6c
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ def forward(self, x):
This may allow for better optimizations (e.g. constant folding) by
backends/runtimes.
If True, `deduplicate_initializers` pass will not be executed. This means
initializers with duplicated values will not be deduplicated and
will be treated as distinct inputs to the graph. This allows different
input initializers to be supplied at the runtime following export.
If ``opset_version < 9``, initializers MUST be part of graph
inputs and this argument will be ignored and the behavior will be
equivalent to setting this argument to True.
Expand Down Expand Up @@ -1603,9 +1608,11 @@ def _export(
module_typenames_to_export_as_functions,
list(params_dict.keys()),
)
params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type]
)

if keep_initializers_as_inputs is not True:
params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type]
)
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
if export_params:
(
Expand Down

0 comments on commit 05b2a6c

Please sign in to comment.