Closed
Labels
export-triaged (this tag is used to tag issues that have been looked at by the PT2 Export team and had the next step determined), oncall: export
Description
🐛 Describe the bug
After running run_decompositions, the exported program is not the same.
pytorch/torch/export/exported_program.py
Lines 351 to 361 in e0d2a24
def run_decompositions(
    self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
) -> "ExportedProgram":
    """
    Run a set of decompositions on the exported program and returns a new
    exported program. By default we will run the Core ATen decompositions to
    get operators in the
    `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
    For now, we do not decompose joint graphs.
    """
However, the state_dict is also updated.
pytorch/torch/export/exported_program.py
Lines 458 to 472 in e0d2a24
state_dict = self.state_dict.copy()
lift_constant_tensor_pass(gm, new_graph_signature, state_dict)
_replace_sym_size_ops_pass(gm)
exported_program = ExportedProgram(
    gm,
    gm.graph,
    new_graph_signature,
    state_dict,
    new_range_constraints,
    new_equality_constraints,
    copy.deepcopy(self.module_call_graph),
    self.example_inputs,
    self.verifier,
    self.tensor_constants,
)
Should decomposition happen in place, instead of creating a whole new exported program?
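To make the question concrete, here is a minimal sketch of the two behaviours being compared. It uses plain Python stand-ins rather than the real torch.export classes: `Program`, `transform_functional`, `transform_in_place` and the `lifted_const` key are all hypothetical names, and lists stand in for tensors. It also assumes that a pass like lift_constant_tensor_pass adds entries to the state_dict it is given, which the quoted snippet suggests but does not show.

```python
class Program:
    """Hypothetical stand-in for ExportedProgram; holds only a state_dict."""

    def __init__(self, state_dict):
        self.state_dict = state_dict

    def transform_functional(self):
        # Same pattern as the quoted snippet: shallow-copy the dict,
        # mutate the copy, and wrap it in a brand-new object.
        state_dict = self.state_dict.copy()
        state_dict["lifted_const"] = [0.0]  # stand-in for a lifted constant
        return Program(state_dict)

    def transform_in_place(self):
        # The alternative raised in the question: mutate self and return it.
        self.state_dict["lifted_const"] = [0.0]
        return self


original = Program({"weight": [1.0, 2.0]})
new = original.transform_functional()

new_is_distinct = new is not original          # a new object is returned
original_keys = sorted(original.state_dict)    # original dict is untouched
new_keys = sorted(new.state_dict)              # the copy gained an entry
# dict.copy() is shallow: both dicts point at the same value objects.
shares_values = new.state_dict["weight"] is original.state_dict["weight"]

mutated = Program({"weight": [1.0, 2.0]})
same = mutated.transform_in_place()
in_place_is_same = same is mutated             # caller's object was modified
mutated_keys = sorted(mutated.state_dict)
```

In this sketch the functional variant leaves the caller's program and its dict keys unchanged while still sharing the underlying values, whereas the in-place variant changes the object the caller already holds.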
Versions
main branch
cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @thiagocrepaldi