run_decompositions in ExportedProgram doesn't keep the same model state_dict #114628

@titaiwangms

🐛 Describe the bug

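After running `run_decompositions`, the returned exported program does not keep the same `state_dict` as the original one. According to its docstring, the method returns a new exported program: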

```python
def run_decompositions(
    self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
) -> "ExportedProgram":
    """
    Run a set of decompositions on the exported program and returns a new
    exported program. By default we will run the Core ATen decompositions to
    get operators in the
    `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
    For now, we do not decompose joint graphs.
    """
```

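However, the `state_dict` is also updated. The method copies `self.state_dict` and then passes the copy to `lift_constant_tensor_pass`, which may add lifted constants to it, so the new `ExportedProgram` is built with a `state_dict` that differs from the original: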

```python
state_dict = self.state_dict.copy()
lift_constant_tensor_pass(gm, new_graph_signature, state_dict)
_replace_sym_size_ops_pass(gm)
exported_program = ExportedProgram(
    gm,
    gm.graph,
    new_graph_signature,
    state_dict,
    new_range_constraints,
    new_equality_constraints,
    copy.deepcopy(self.module_call_graph),
    self.example_inputs,
    self.verifier,
    self.tensor_constants,
)
```

Should the decomposition happen in place instead of creating a whole new `ExportedProgram`?
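To make the difference between the two options concrete, here is a minimal pure-Python sketch. It does not call PyTorch: plain dicts stand in for `state_dict`, `fake_lift_constants` is a hypothetical stand-in for `lift_constant_tensor_pass` (assumed only to add entries to the dict it receives), and the key names are made up for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for lift_constant_tensor_pass (NOT the real PyTorch
# function): it mutates the dict it receives by adding lifted constants.
def fake_lift_constants(state_dict: Dict[str, object], constants: Dict[str, object]) -> None:
    for name, value in constants.items():
        state_dict[name] = value


def decompose_copy(state_dict: Dict[str, object], constants: Dict[str, object]) -> Dict[str, object]:
    """Pattern shown in the issue: shallow-copy the state_dict, then mutate the copy."""
    new_state_dict = state_dict.copy()  # new dict; the values themselves are shared, not cloned
    fake_lift_constants(new_state_dict, constants)
    return new_state_dict


def decompose_in_place(state_dict: Dict[str, object], constants: Dict[str, object]) -> Dict[str, object]:
    """Alternative raised in the question: mutate the caller's dict directly."""
    fake_lift_constants(state_dict, constants)
    return state_dict


def diff_keys(before: Dict[str, object], after: Dict[str, object]) -> Tuple[List[str], List[str]]:
    """Return (added, removed) keys between two state_dicts."""
    added = sorted(set(after) - set(before))
    removed = sorted(set(before) - set(after))
    return added, removed


original = {"linear.weight": "W", "linear.bias": "b"}
constants = {"lifted_constant0": "c"}  # hypothetical constant name

new = decompose_copy(original, constants)
print(diff_keys(original, new))        # (['lifted_constant0'], [])
print("lifted_constant0" in original)  # False: the original dict is untouched

target = dict(original)
result = decompose_in_place(target, constants)
print(result is target)                # True: no new dict is created
```

With the copy pattern, the original program keeps its old `state_dict` while the returned program gets extra entries. Because the copy is shallow, the unchanged entries still refer to the same value objects. The in-place variant avoids the divergence but mutates the caller's data.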

Versions

main branch

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @thiagocrepaldi

Labels: export-triaged, oncall: export