Skip to content

Commit

Permalink
[Release/2.2] [export] Do not copy state_dict in run_decomp (#115753)
Browse files Browse the repository at this point in the history
Fixes #114628

Cherry-pick of  #115269 into release/2.2 branch
Approved by: https://github.com/thiagocrepaldi, https://github.com/ydwu4

Co-authored-by: angelayi <yiangela7@gmail.com>
  • Loading branch information
Thiago Crepaldi and angelayi committed Dec 14, 2023
1 parent 1b70285 commit 8be2611
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test/export/test_export.py
Expand Up @@ -1454,7 +1454,7 @@ def forward(self, x):
m = M()
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
ep = export(m, inp)

state_dict = ep.state_dict

FileCheck().check_count(
"torch.ops.aten.t.default", 1, exactly=True
Expand All @@ -1469,6 +1469,7 @@ def forward(self, x):
"torch.ops.aten.t.default", 0, exactly=True
).run(core_aten_ep.graph_module.code)
self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))
self.assertEqual(id(state_dict), id(ep.state_dict))

def test_export_decomps_dynamic(self):
class M(torch.nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions torch/export/exported_program.py
Expand Up @@ -466,14 +466,13 @@ def update_arg(old_arg, new_ph):
for inp_dim1, inp_dim2 in self.equality_constraints
]

state_dict = self.state_dict.copy()
lift_constant_tensor_pass(gm, new_graph_signature)
_replace_sym_size_ops_pass(gm)
exported_program = ExportedProgram(
gm,
gm.graph,
new_graph_signature,
state_dict,
self.state_dict,
new_range_constraints,
new_equality_constraints,
copy.deepcopy(self.module_call_graph),
Expand Down

0 comments on commit 8be2611

Please sign in to comment.