Add support for edge dialect ops in exir/serde (#106371)
Summary:
Pull Request resolved: #106371

Adding support for edge dialect ops in `exir/serde`. This diff does the following:
- Moves the global `serialize_operator`/`deserialize_operator` implementations in `torch/_export/serde/serialize.py` into `GraphModuleSerializer` and `GraphModuleDeserializer`
- Adds implementations of `serialize_operator`/`deserialize_operator` inside `GraphModuleSerializer` and `GraphModuleDeserializer` in `exir/serde/serialize.py`, which is what lets `exir/serde` handle edge dialect ops (a sketch of the override pattern follows below)
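The exir-side classes are not part of this diff, so the sketch below is only a hedged illustration of the override pattern the commit enables: the `EdgeOpOverload` stand-in, the `EdgeGraphModule*` class names, and the `edge::` tagging are assumptions for illustration, not code from the commit.

```python
# Hedged sketch of the override pattern: subclass the export serializers and
# override the two hooks added by this commit. Only GraphModuleSerializer,
# GraphModuleDeserializer, and their serialize_operator/deserialize_operator
# methods come from the commit; everything else here is illustrative.
import torch
from torch._export.serde.serialize import (
    GraphModuleDeserializer,
    GraphModuleSerializer,
)


class EdgeOpOverload:
    """Placeholder stand-in for an edge dialect op wrapping an ATen overload."""

    def __init__(self, aten_op: torch._ops.OpOverload) -> None:
        self.op = aten_op
        self.__name__ = aten_op.__name__  # e.g. "add.Tensor"


class EdgeGraphModuleSerializer(GraphModuleSerializer):
    def serialize_operator(self, target) -> str:
        if isinstance(target, EdgeOpOverload):
            # Reuse the base torch.ops serialization for the wrapped overload,
            # tagged so the deserializer knows this was an edge dialect op.
            return "edge::" + super().serialize_operator(target.op)
        return super().serialize_operator(target)


class EdgeGraphModuleDeserializer(GraphModuleDeserializer):
    def deserialize_operator(self, serialized_target: str):
        if serialized_target.startswith("edge::"):
            aten_op = super().deserialize_operator(serialized_target[len("edge::"):])
            return EdgeOpOverload(aten_op)
        return super().deserialize_operator(serialized_target)
```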

Test Plan: CI, plus enabling edge dialect ops in `executorch/exir/tests/test_serde.py`

Reviewed By: chakriu, angelayi

Differential Revision: D47938280

fbshipit-source-id: a20f3e7441142bf84a0d0f44cba250148b5e5885
tarun292 authored and facebook-github-bot committed Aug 1, 2023
1 parent 5c3aae8 commit 8fea473
Showing 1 changed file with 37 additions and 41 deletions.
78 changes: 37 additions & 41 deletions torch/_export/serde/serialize.py
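For readers skimming the diff: the two helpers being moved map FX call targets to dotted-name strings and back. The snippet below mirrors that logic for a single ATen overload on a stock PyTorch install; it is an illustration only, not part of the commit.

```python
# Round-trip of the serialize_operator / deserialize_operator logic shown in
# the diff below, for one OpOverload target (illustration only).
import torch

op = torch.ops.aten.add.Tensor  # a typical call_function target in an FX graph

# serialize_operator: rewrite the "torch._ops..." module to "torch.ops..." and
# append the overload name, producing a plain string.
serialized = f"{op.__module__.replace('torch._ops', 'torch.ops')}.{op.__name__}"
print(serialized)  # torch.ops.aten.add.Tensor

# deserialize_operator: walk the dotted path back down from torch.ops.
target = torch.ops
for name in serialized.split(".")[2:]:
    target = getattr(target, name)
assert target is op  # the same OpOverload object is recovered
```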
@@ -181,38 +181,6 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
)


def serialize_operator(target) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("torch._ops"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace("torch._ops", "torch.ops")
return f"{module}.{target.__name__}"
else: # TODO(zhxchen17) Don't catch all here.
return f"{target.__module__}.{target.__name__}"


def deserialize_operator(serialized_target: str):
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch.ops"):
module = torch.ops
serialized_target_names = serialized_target.split(".")[2:]
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target

target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target


def serialize_call_spec(call_spec: ep.CallSpec) -> CallSpec:
return CallSpec(
in_spec=pytree_to_str(call_spec.in_spec) if call_spec.in_spec else "",
@@ -391,6 +359,17 @@ def handle_output(self, node: torch.fx.Node):
assert isinstance(node_args, (tuple, list))
self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args]

def serialize_operator(self, target) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("torch._ops"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace("torch._ops", "torch.ops")
return f"{module}.{target.__name__}"
else: # TODO(zhxchen17) Don't catch all here.
return f"{target.__module__}.{target.__name__}"

def handle_call_function(self, node: torch.fx.Node):
assert node.op == "call_function"

@@ -402,7 +381,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.args),
outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))],
metadata=self.serialize_metadata(node),
@@ -411,14 +390,14 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.args),
outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))],
metadata=self.serialize_metadata(node),
)
elif isinstance(node.target, torch._ops.OpOverload):
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
# TODO: create a new tensor_values here, meta might have faketensor info
@@ -434,7 +413,7 @@ def handle_call_function(self, node: torch.fx.Node):
arg=self.serialize_input(a),
) for a in node.args]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=inputs,
outputs=[Argument.create(as_tensor=self.serialize_tensor_output(node.name, node.meta['val']))],
metadata=self.serialize_metadata(node),
@@ -455,14 +434,14 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
if nn_module_stack := node.meta.get("nn_module_stack"):
# Serialize to "fx_node_name:(orig_ref,type_str)"
nn_module_list = [
f"{k}:({v[0]},{serialize_operator(v[1])})"
f"{k}:({v[0]},{self.serialize_operator(v[1])})"
for k, v in nn_module_stack.items()
]
ret["nn_module_stack"] = ";".join(nn_module_list)

if source_fn := node.meta.get("source_fn"):
# Serialize to "fx_node_name,op_str"
op = serialize_operator(source_fn[1])
op = self.serialize_operator(source_fn[1])
ret["source_fn"] = f"{source_fn[0]},{op}"

return ret
@@ -760,14 +739,31 @@ def save_graph_module(self) -> Iterator[None]:
finally:
self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved

def deserialize_operator(self, serialized_target: str):
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch.ops"):
module = torch.ops
serialized_target_names = serialized_target.split(".")[2:]
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target

target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target

def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
val = s.value
if s.type == "as_expr":
if val.expr_str in self.symbol_name_to_symbol:
sym = self.symbol_name_to_symbol[val.expr_str]
else:
sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)

if isinstance(sym, sympy.Symbol):
self.symbol_name_to_symbol[val.expr_str] = sym

@@ -834,7 +830,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
# Nodes: convert to call_function nodes.
for serialized_node in serialized_graph.nodes:
try:
target = deserialize_operator(serialized_node.target)
target = self.deserialize_operator(serialized_node.target)
self.deserialize_node(serialized_node, target)

except Exception as e:
@@ -1064,7 +1060,7 @@ def deserialize_meta_func(serialized_target: str):
module = torch
serialized_target_names = serialized_target.split(".")[1:]
else:
return deserialize_operator(serialized_target)
return self.deserialize_operator(serialized_target)

target = module
for name in serialized_target_names: