[ONNX] Do not run 'deduplicate_initializers' when 'keep_initializers_as_inputs' is True

ghstack-source-id: c0d7d0fb8aee0bc363be8cef28e4698f51ff1623
Pull Request resolved: #96320
BowenBao committed Aug 1, 2023
1 parent f23d755 commit c7d9359
13 changes: 10 additions & 3 deletions torch/onnx/utils.py
@@ -454,6 +454,11 @@ def forward(self, x):
This may allow for better optimizations (e.g. constant folding) by
backends/runtimes.
If True, the `deduplicate_initializers` pass will not be executed. This means
initializers with duplicated values will not be deduplicated and
will be treated as distinct inputs to the graph. This allows different
input initializers to be supplied at runtime following export.
If ``opset_version < 9``, initializers MUST be part of graph
inputs and this argument will be ignored and the behavior will be
equivalent to setting this argument to True.
@@ -1603,9 +1608,11 @@ def _export(
     module_typenames_to_export_as_functions,
     list(params_dict.keys()),
 )
-params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
-    graph, params_dict, getattr(model, "training", False)  # type: ignore[arg-type]
-)
+
+if keep_initializers_as_inputs is not True:
+    params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
+        graph, params_dict, getattr(model, "training", False)  # type: ignore[arg-type]
+    )
 _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
 if export_params:
     (
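The effect of the guard added in this commit can be sketched in plain Python. This is not PyTorch's implementation (the real pass is the C++ `_C._jit_pass_onnx_deduplicate_initializers`); `deduplicate_initializers` and `maybe_deduplicate` below are hypothetical stand-ins that only illustrate the control flow: deduplication collapses initializers with equal values onto one canonical entry, and skipping it when `keep_initializers_as_inputs=True` keeps every initializer as a distinct, independently overridable graph input.

```python
def deduplicate_initializers(params_dict):
    """Collapse initializers with identical values onto one canonical
    entry (a simplified stand-in for the real deduplication pass)."""
    seen = {}    # value -> canonical initializer name
    result = {}
    for name, value in params_dict.items():
        if value in seen:
            # Duplicate value: dropped in favor of the canonical name,
            # so this name is no longer a separate graph input.
            continue
        seen[value] = name
        result[name] = value
    return result


def maybe_deduplicate(params_dict, keep_initializers_as_inputs):
    # The fix this commit makes: only run deduplication when initializers
    # are NOT kept as graph inputs. Otherwise each initializer must stay
    # distinct so callers can supply different values at runtime.
    if keep_initializers_as_inputs is not True:
        return deduplicate_initializers(params_dict)
    return params_dict


params = {"w1": 1.0, "w2": 1.0, "b": 0.5}

# keep_initializers_as_inputs=True: pass skipped, all names survive.
print(maybe_deduplicate(params, keep_initializers_as_inputs=True))

# keep_initializers_as_inputs=False: "w2" folds into "w1".
print(maybe_deduplicate(params, keep_initializers_as_inputs=False))
```

Before this commit, the pass ran unconditionally, so exporting with `keep_initializers_as_inputs=True` could still silently merge same-valued initializers, making it impossible to feed them different values after export.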
