Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove CallSpec #1618

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 1 addition & 6 deletions exir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@
to_edge,
)
from executorch.exir.tracer import ExirDynamoConfig
from torch._export import ( # lots of people are doing from exir import CallSpec, ExportGraphSignature, ExportedProgram which seems wrong
CallSpec,
ExportedProgram,
ExportGraphSignature,
)
from torch.export import ExportedProgram, ExportGraphSignature

Value = Any

Expand All @@ -42,7 +38,6 @@
"capture",
"capture_multiple",
"_capture_legacy_do_not_use",
"CallSpec",
"ExportedProgram",
"ExirExportedProgram",
"ExecutorchProgram",
Expand Down
13 changes: 11 additions & 2 deletions exir/capture/_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import dataclasses
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
Expand All @@ -26,10 +27,12 @@
from torch import _guards
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.eval_frame import Constraint
from torch._export import CallSpec, export, ExportedProgram, ExportGraphSignature
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import export
from torch.export.exported_program import (
ExportedProgram,
ExportGraphSignature,
InputKind,
InputSpec,
ModuleCallEntry,
Expand All @@ -53,6 +56,12 @@
)


@dataclasses.dataclass
class CallSpec:
in_spec: Optional[pytree.TreeSpec]
out_spec: Optional[pytree.TreeSpec]


@compatibility(is_backward_compatible=False)
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
"""
Expand Down Expand Up @@ -128,7 +137,7 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:

@compatibility(is_backward_compatible=False)
def capture( # noqa: C901
f: Callable[..., Any],
f: torch.nn.Module,
args: Tuple[Value, ...],
config: Optional[CaptureConfig] = None,
constraints: Optional[List[Constraint]] = None,
Expand Down
4 changes: 0 additions & 4 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,6 @@ def program(self, emit_stacktrace: bool = False) -> Program:
),
lowered_exported_program.graph_module,
),
# TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
# somewhere as we should pass it a list of tensors to the lowered module and output a
# list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
# inputs/outputs to the toplevel program will be in the format of the eager module.
state_dict={}, # None because all data are consumed by delegate
range_constraints=lowered_exported_program.range_constraints,
module_call_graph=lowered_exported_program.module_call_graph,
Expand Down
4 changes: 2 additions & 2 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult

# Passes to convert a graph module from ATen to Edge IR

pre_op_replace_passes = PassManager(
pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
passes=[
# ReplaceSymSizeOpPass need to be run before other passes which inherits
# from ExportPass. ExportPass can not handle OpOverloadPacket in its
Expand All @@ -479,7 +479,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
]
).passes

post_op_replace_passes = PassManager(
post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
passes=[
dead_code_elimination_pass,
DebugHandleGeneratorPass(),
Expand Down
9 changes: 6 additions & 3 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
EXIREdgeDialectVerifier,
get_aten_verifier,
)
from torch._export import ExportedProgram
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export.exported_program import (
_get_updated_range_constraints,
ConstantArgument,
CustomObjArgument,
ExportedProgram,
ExportGraphSignature,
InputKind,
InputSpec,
Expand Down Expand Up @@ -74,7 +75,8 @@ def _get_updated_graph_signature(
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
if isinstance(old_input_spec.arg, ConstantArgument)
if isinstance(old_input_spec.arg, (ConstantArgument, CustomObjArgument))
# pyre-ignore
else type(old_input_spec.arg)(node.name)
)
new_input_specs.append(
Expand All @@ -93,7 +95,8 @@ def _get_updated_graph_signature(
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
if isinstance(old_output_spec.arg, ConstantArgument)
if isinstance(old_output_spec.arg, (ConstantArgument, CustomObjArgument))
# pyre-ignore
else type(old_output_spec.arg)(node.name)
)
new_output_specs.append(
Expand Down
2 changes: 1 addition & 1 deletion exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def serialize_optional_tensor_args(a):
# serialize/deserialize function.
custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
self.custom_objs[custom_obj_name] = arg
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name))
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, ""))
else:
raise SerializeError(f"Unsupported argument type: {type(arg)}")

Expand Down
2 changes: 1 addition & 1 deletion exir/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class TensorSplit(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]:
def forward(self, input: Tensor, sections: int, dim: int = 0) -> Tuple[Tensor]:
return torch.tensor_split(input, sections, dim)


Expand Down