diff --git a/exir/__init__.py b/exir/__init__.py index f95f9c57b2..341148ca7e 100644 --- a/exir/__init__.py +++ b/exir/__init__.py @@ -28,11 +28,7 @@ to_edge, ) from executorch.exir.tracer import ExirDynamoConfig -from torch._export import ( # lots of people are doing from exir import CallSpec, ExportGraphSignature, ExportedProgram which seems wrong - CallSpec, - ExportedProgram, - ExportGraphSignature, -) +from torch.export import ExportedProgram, ExportGraphSignature Value = Any @@ -42,7 +38,6 @@ "capture", "capture_multiple", "_capture_legacy_do_not_use", - "CallSpec", "ExportedProgram", "ExirExportedProgram", "ExecutorchProgram", diff --git a/exir/capture/_capture.py b/exir/capture/_capture.py index 575685ca86..a803a24537 100644 --- a/exir/capture/_capture.py +++ b/exir/capture/_capture.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import dataclasses import warnings from collections import namedtuple from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -26,10 +27,12 @@ from torch import _guards from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.eval_frame import Constraint -from torch._export import CallSpec, export, ExportedProgram, ExportGraphSignature from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.export import export from torch.export.exported_program import ( + ExportedProgram, + ExportGraphSignature, InputKind, InputSpec, ModuleCallEntry, @@ -53,6 +56,12 @@ ) +@dataclasses.dataclass +class CallSpec: + in_spec: Optional[pytree.TreeSpec] + out_spec: Optional[pytree.TreeSpec] + + @compatibility(is_backward_compatible=False) def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram: """ @@ -128,7 +137,7 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram: @compatibility(is_backward_compatible=False) def capture( # noqa: C901 - f: Callable[..., Any], + f: torch.nn.Module, args: Tuple[Value, ...], config: Optional[CaptureConfig] = None, constraints: Optional[List[Constraint]] = None, diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 58e4356a19..f8e38db990 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -271,10 +271,6 @@ def program(self, emit_stacktrace: bool = False) -> Program: ), lowered_exported_program.graph_module, ), - # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None) - # somewhere as we should pass it a list of tensors to the lowered module and output a - # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the - # inputs/outputs to the toplevel program will be in the format of the eager module. state_dict={}, # None because all data are consumed by delegate range_constraints=lowered_exported_program.range_constraints, module_call_graph=lowered_exported_program.module_call_graph, diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 19806d7ea1..23bda63bd9 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -464,7 +464,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult # Passes to convert a graph module from ATen to Edge IR -pre_op_replace_passes = PassManager( +pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager( passes=[ # ReplaceSymSizeOpPass need to be run before other passes which inherits # from ExportPass. ExportPass can not handle OpOverloadPacket in its @@ -479,7 +479,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult ] ).passes -post_op_replace_passes = PassManager( +post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager( passes=[ dead_code_elimination_pass, DebugHandleGeneratorPass(), diff --git a/exir/program/_program.py b/exir/program/_program.py index 343fcbf761..f33870221a 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -36,11 +36,12 @@ EXIREdgeDialectVerifier, get_aten_verifier, ) -from torch._export import ExportedProgram from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch.export.exported_program import ( _get_updated_range_constraints, ConstantArgument, + CustomObjArgument, + ExportedProgram, ExportGraphSignature, InputKind, InputSpec, @@ -74,7 +75,8 @@ def _get_updated_graph_signature( old_input_spec = old_signature.input_specs[i] arg = ( old_input_spec.arg - if isinstance(old_input_spec.arg, ConstantArgument) + if isinstance(old_input_spec.arg, (ConstantArgument, CustomObjArgument)) + # pyre-ignore else type(old_input_spec.arg)(node.name) ) new_input_specs.append( @@ -93,7 +95,8 @@ def _get_updated_graph_signature( old_output_spec = old_signature.output_specs[i] arg = ( old_output_spec.arg - if isinstance(old_output_spec.arg, ConstantArgument) + if isinstance(old_output_spec.arg, (ConstantArgument, CustomObjArgument)) + # pyre-ignore else type(old_output_spec.arg)(node.name) ) new_output_specs.append( diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index e01e86db2e..6c573464d0 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -699,7 +699,7 @@ def serialize_optional_tensor_args(a): # serialize/deserialize function. custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" self.custom_objs[custom_obj_name] = arg - return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name)) + return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, "")) else: raise SerializeError(f"Unsupported argument type: {type(arg)}") diff --git a/exir/tests/models.py b/exir/tests/models.py index ffcae2716c..963fa3ad44 100644 --- a/exir/tests/models.py +++ b/exir/tests/models.py @@ -192,7 +192,7 @@ class TensorSplit(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]: + def forward(self, input: Tensor, sections: int, dim: int = 0) -> Tuple[Tensor]: return torch.tensor_split(input, sections, dim)