Skip to content

Commit

Permalink
Fix bug in graph partitioner
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#125133

Title

Reviewed By: PaulZhang12

Differential Revision: D56688411
  • Loading branch information
tugsbayasgalan authored and facebook-github-bot committed May 7, 2024
1 parent 5e93499 commit 3808a94
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,18 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
kjt = kjt.to("meta")
sharded_model(kjt.values(), kjt.lengths())

ep = torch.export.export(
from torch.export import _trace
ep_pre = _trace._export(
sharded_model,
(
kjt.values(),
kjt.lengths(),
),
{},
pre_dispatch=True,
strict=False,
)
ep = ep_pre.run_decompositions()

ep.module()(kjt.values(), kjt.lengths())

Expand Down Expand Up @@ -338,15 +341,18 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:

sharded_model(kjt.values(), kjt.lengths())

ep = torch.export.export(
from torch.export import _trace

ep = _trace._export(
sharded_model,
(
kjt.values(),
kjt.lengths(),
),
{},
strict=False,
)
pre_dispatch=True,
).run_decompositions()
ep.module()(kjt.values(), kjt.lengths())

# PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
Expand Down

0 comments on commit 3808a94

Please sign in to comment.