Skip to content

Commit

Permalink
Add support for edge dialect ops in exir/serde
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#106371

Adding support for edge dialect ops in `exir/serde`. This diff does the following:
- Moves the global `serialize_operator/deserialize_operator` implementations in`export/serde/serialize.py` into `GraphModuleSerializer` and `GraphModuleDeserializer`
- Adds implementations of `serialize_operator/deserialize_operator` inside `GraphModuleSerializer` and `GraphModuleDeserializer` in `exir/serde/serialize.py`

Reviewed By: chakriu, angelayi

Differential Revision: D47938280

fbshipit-source-id: 8a8aed65c51d9cc57b4b37499421be971c19d344
  • Loading branch information
tarun292 authored and facebook-github-bot committed Aug 2, 2023
1 parent f8a19a7 commit 13ca8c9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 17 deletions.
2 changes: 2 additions & 0 deletions exir/serde/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ python_library(
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
],
)

Expand Down
109 changes: 101 additions & 8 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import json
import logging
import operator
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import executorch.exir as exir
import executorch.exir.memory as memory
Expand All @@ -21,7 +21,9 @@
import torch._export.serde.schema as schema
import torch._export.serde.serialize as export_serialize
from executorch.backends.compile_spec_schema import CompileSpec as delegate_CompileSpec
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir import delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.lowered_backend_module import (
LoweredBackendModule as ExirLoweredBackendModule,
)
Expand All @@ -31,7 +33,6 @@
)
from torch.fx.experimental import symbolic_shapes


log: logging.Logger = logging.getLogger(__name__)


Expand All @@ -42,6 +43,28 @@ def __init__(
super().__init__(graph_signature, call_spec)
self.state_dict: Dict[str, torch.Tensor] = {} # TODO(T157676982)

def serialize_operator(
self,
target: Union[
str,
EdgeOpOverload,
torch._ops.OpOverload,
torch._ops.HigherOrderOperator,
],
) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("executorch.exir.dialects"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace(
"executorch.exir.dialects.edge._ops",
"executorch.exir.dialects.edge.ops",
)
return f"{module}.{target.__name__}"

return super().serialize_operator(target)

def handle_call_function(self, node: torch.fx.Node) -> None:
assert node.op == "call_function"

Expand All @@ -54,10 +77,22 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
)
self.graph_state.nodes.append(ex_node)
return

elif node.target is executorch_call_delegate:
elif isinstance(node.target, EdgeOpOverload):
assert node.target._op is not None
ex_node = schema.Node(
target=export_serialize.serialize_operator(node.target),
target=self.serialize_operator(node.target),
# pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
# `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
# TODO: create a new tensor_values here, meta might have faketensor info
metadata=self.serialize_metadata(node),
)
self.graph_state.nodes.append(ex_node)
return
elif node.target is delegate.executorch_call_delegate:
ex_node = schema.Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_call_delegate_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
metadata=self.serialize_metadata(node),
Expand All @@ -67,6 +102,20 @@ def handle_call_function(self, node: torch.fx.Node) -> None:

super().handle_call_function(node)

def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
if isinstance(node.target, EdgeOpOverload):
# Store the original edge op
edge_op = node.target
# Replace the edge op with the original ATen op so that we can just call into
# the serialize_outputs implementation present in the parent class.
node.target = edge_op._op
ret = super().serialize_outputs(node)
# Replace the edge op back.
node.target = edge_op
else:
ret = super().serialize_outputs(node)
return ret

def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
meta = super().serialize_metadata(node)

Expand Down Expand Up @@ -265,6 +314,21 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
super().__init__()
self.state_dict: Dict[str, Any] = state_dict # TODO(T157676982)

def deserialize_operator(self, serialized_target: str) -> str:
if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
module = exir_ops.edge
serialized_target_names = serialized_target.split(".")[5:]

target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target

return super().deserialize_operator(serialized_target)

# pyre-ignore
def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
if target == "memory.alloc":
Expand All @@ -278,7 +342,7 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return

elif target is executorch_call_delegate:
elif target is delegate.executorch_call_delegate:
if (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
Expand All @@ -298,7 +362,21 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No

fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return

elif isinstance(target, EdgeOpOverload):
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
name = (
serialized_node.outputs[0].value.name
if export_serialize._is_single_tensor_return(target._op)
else None # FX will generate a name for us.
)
args, kwargs = self.deserialize_inputs(target._op, serialized_node)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
return
elif isinstance(target, str):
# Create a dummy fake op if the target does not exist
# because we cannot create a call_function node w/o a
Expand All @@ -317,6 +395,21 @@ def fake_op(x):

super().deserialize_node(serialized_node, target)

def deserialize_outputs(
self, serialized_node: schema.Node, fx_node: torch.fx.Node
) -> None:
if isinstance(fx_node.target, EdgeOpOverload):
# Store the original edge op
edge_op = fx_node.target
# Replace the edge op with the original ATen op so that we can just call into
# node deserialize_outputs implementation present in the parent class.
fx_node.target = edge_op._op
super().deserialize_outputs(serialized_node, fx_node)
# Replace the edge op back.
fx_node.target = edge_op
else:
super().deserialize_outputs(serialized_node, fx_node)

def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
res = super().deserialize_metadata(metadata)

Expand Down
12 changes: 3 additions & 9 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@

# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
def setUp(self) -> None:
# TODO(gasoon): Remove this once serde is fully migrated to Edge ops
self.edge_complie_config = EdgeCompileConfig(_use_edge_ops=False)

def check_ep(
self,
ep1: TorchExportedProgram,
Expand All @@ -53,7 +49,7 @@ def check_serde(self, m, inputs) -> None:
aten_new = deserialize(*serialize(aten.exported_program))
self.check_ep(aten.exported_program, aten_new, inputs)

edge = aten.to_edge(self.edge_complie_config)
edge = aten.to_edge()
edge_new = deserialize(*serialize(edge.exported_program))
self.check_ep(edge.exported_program, edge_new, inputs)

Expand Down Expand Up @@ -118,7 +114,7 @@ def forward(self, x):
model_inputs = (torch.ones(1),)
edgeir_m = exir.capture(
sin_module, model_inputs, exir.CaptureConfig(pt2_mode=True)
).to_edge(self.edge_complie_config)
).to_edge()
max_value = model_inputs[0].shape[0]
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
lowered_sin_module = to_backend(
Expand Down Expand Up @@ -160,9 +156,7 @@ def forward(self, a, x, b):
m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

ep = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge(
self.edge_complie_config
)
ep = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
edge = to_backend(ep.exported_program, AddMulPartitionerDemo)
edge_new = deserialize(*serialize(edge))
self.check_ep(edge, edge_new, inputs)

0 comments on commit 13ca8c9

Please sign in to comment.