Skip to content

Commit

Permalink
Prototype for export_for_training
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
tugsbayasgalan committed Jun 19, 2024
1 parent 5ddab68 commit 14f8fc4
Show file tree
Hide file tree
Showing 2 changed files with 484 additions and 1 deletion.
16 changes: 15 additions & 1 deletion test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

HAS_TORCHREC = True
except ImportError:
except (ImportError, AttributeError):
HAS_TORCHREC = False

try:
Expand Down Expand Up @@ -1148,6 +1148,20 @@ def forward(self, x, y, z):
ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6
)

def test_simple_export_for_training(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)

gm = torch.export._trace._export_for_training(
Foo(), (torch.ones(2, 2),)
).graph_module
print(gm.graph)

def test_derived_dim_out_of_order_simplified_repeat_non_derived(self):
class Foo(torch.nn.Module):
def forward(self, x, y, y1, z):
Expand Down
Loading

0 comments on commit 14f8fc4

Please sign in to comment.