diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 5e1872c249ed..13b234c173e5 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -1060,6 +1060,27 @@ def forward(self, x, y): inp = (torch.randn(3), None) self.assertTrue(torch.allclose(unf(*inp), M1()(*inp))) + def test_unflatten_root_module_type(self) -> None: + class M(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class M1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = M() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.m(x) + + inp = (torch.randn(3),) + ep = torch.export.export(M1(), inp) + unf = torch.export.unflatten(ep) + self.assertIsNotNone(unf.type_name()) + self.assertEqual(unf.type_name().split(".")[-1], "M1") + self.assertEqual(unf.m.type_name().split(".")[-1], "M") + self.assertTrue(torch.allclose(unf(*inp), M1()(*inp))) + if __name__ == "__main__": run_tests() diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index bdea191190e3..de72c8f505d9 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -301,7 +301,7 @@ def get_flat_arg_paths(self) -> list[str]: return [] -class UnflattenedModule(torch.nn.Module): +class UnflattenedModule(_SubmoduleBase, torch.nn.Module): def __init__( self, export_module: ExportedProgram, @@ -340,6 +340,7 @@ def _id(obj): _inplace_buffer_and_input_mutations(export_graph, self.graph_signature) _fix_nn_module_stacks(export_graph) + self._ty = _root_module_type(export_graph) self.ivals = _IVals() # for any intermediate value of a mutation that is read, track the mutation @@ -858,6 +859,17 @@ def forward(self, buffer, x): output_node.args = ((user_outputs),) +def _root_module_type(graph: torch.fx.Graph) -> Optional[str]: + for node in graph.nodes: + if "nn_module_stack" not in node.meta: + continue + + for path, ty in node.meta["nn_module_stack"].values(): + if not path: + return ty + return None + + def _fix_nn_module_stacks(graph): # For each nn module stack in the graph, check if the fqns in it represent a stack: # 1. Each fqn must be a prefix of the next fqn.