diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py index d73cc9fcfbf..153d52ca6f1 100644 --- a/backends/cadence/aot/program_builder.py +++ b/backends/cadence/aot/program_builder.py @@ -12,6 +12,7 @@ from torch import Tensor from torch._export.verifier import Verifier from torch.export import ExportedProgram +from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature from torch.export.graph_signature import ( ExportGraphSignature, InputKind, @@ -20,6 +21,7 @@ OutputSpec, TensorArgument, ) +from torch.utils import _pytree as pytree class IrMode(Enum): @@ -87,17 +89,25 @@ def get_verifiers(self) -> Optional[list[Verifier]]: def get_program(self) -> ExportedProgram: gm = self.get_graph_module() + graph_signature = ExportGraphSignature(self.input_specs, self.output_specs) + in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1] + out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1] return ExportedProgram( root=gm, graph=gm.graph, - graph_signature=ExportGraphSignature( - input_specs=self.input_specs, output_specs=self.output_specs - ), + graph_signature=graph_signature, # pyre-ignore[6]: Incompatible parameter type. constants=self.constants, state_dict=self.state_dict, range_constraints={}, - module_call_graph=[], + module_call_graph=[ + ModuleCallEntry( + "", + ModuleCallSignature( + inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec + ), + ) + ], # pyre-ignore[6]: Incompatible parameter type. verifiers=self.get_verifiers(), )