Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions test/export/test_unflatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,27 @@ def forward(self, x, y):
inp = (torch.randn(3), None)
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))

def test_unflatten_root_module_type(self) -> None:
class M(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + x

class M1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m = M()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.m(x)

inp = (torch.randn(3),)
ep = torch.export.export(M1(), inp)
unf = torch.export.unflatten(ep)
self.assertIsNotNone(unf.type_name())
self.assertEqual(unf.type_name().split(".")[-1], "M1")
self.assertEqual(unf.m.type_name().split(".")[-1], "M")
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))


if __name__ == "__main__":
run_tests()
14 changes: 13 additions & 1 deletion torch/export/unflatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_flat_arg_paths(self) -> list[str]:
return []


class UnflattenedModule(torch.nn.Module):
class UnflattenedModule(_SubmoduleBase, torch.nn.Module):
def __init__(
self,
export_module: ExportedProgram,
Expand Down Expand Up @@ -340,6 +340,7 @@ def _id(obj):

_inplace_buffer_and_input_mutations(export_graph, self.graph_signature)
_fix_nn_module_stacks(export_graph)
self._ty = _root_module_type(export_graph)

self.ivals = _IVals()
# for any intermediate value of a mutation that is read, track the mutation
Expand Down Expand Up @@ -858,6 +859,17 @@ def forward(self, buffer, x):
output_node.args = ((user_outputs),)


def _root_module_type(graph: torch.fx.Graph) -> Optional[str]:
for node in graph.nodes:
if "nn_module_stack" not in node.meta:
continue

for path, ty in node.meta["nn_module_stack"].values():
if not path:
return ty
return None


def _fix_nn_module_stacks(graph):
# For each nn module stack in the graph, check if the fqns in it represent a stack:
# 1. Each fqn must be a prefix of the next fqn.
Expand Down
Loading