Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for edge dialect ops in exir/serde #106371

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 37 additions & 41 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,38 +181,6 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
)


def serialize_operator(target) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("torch._ops"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace("torch._ops", "torch.ops")
return f"{module}.{target.__name__}"
else: # TODO(zhxchen17) Don't catch all here.
return f"{target.__module__}.{target.__name__}"


def deserialize_operator(serialized_target: str):
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch.ops"):
module = torch.ops
serialized_target_names = serialized_target.split(".")[2:]
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target

target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target


def serialize_call_spec(call_spec: ep.CallSpec) -> CallSpec:
return CallSpec(
in_spec=pytree_to_str(call_spec.in_spec) if call_spec.in_spec else "",
Expand Down Expand Up @@ -391,6 +359,17 @@ def handle_output(self, node: torch.fx.Node):
assert isinstance(node_args, (tuple, list))
self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args]

def serialize_operator(self, target) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("torch._ops"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace("torch._ops", "torch.ops")
return f"{module}.{target.__name__}"
else: # TODO(zhxchen17) Don't catch all here.
return f"{target.__module__}.{target.__name__}"

def handle_call_function(self, node: torch.fx.Node):
assert node.op == "call_function"

Expand All @@ -402,7 +381,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.args),
outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))],
metadata=self.serialize_metadata(node),
Expand All @@ -411,14 +390,14 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.args),
outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))],
metadata=self.serialize_metadata(node),
)
elif isinstance(node.target, torch._ops.OpOverload):
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
# TODO: create a new tensor_values here, meta might have faketensor info
Expand All @@ -434,7 +413,7 @@ def handle_call_function(self, node: torch.fx.Node):
arg=self.serialize_input(a),
) for a in node.args]
ex_node = Node(
target=serialize_operator(node.target),
target=self.serialize_operator(node.target),
inputs=inputs,
outputs=[Argument.create(as_tensor=self.serialize_tensor_output(node.name, node.meta['val']))],
metadata=self.serialize_metadata(node),
Expand All @@ -455,14 +434,14 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
if nn_module_stack := node.meta.get("nn_module_stack"):
# Serialize to "fx_node_name:(orig_ref,type_str)"
nn_module_list = [
f"{k}:({v[0]},{serialize_operator(v[1])})"
f"{k}:({v[0]},{self.serialize_operator(v[1])})"
for k, v in nn_module_stack.items()
]
ret["nn_module_stack"] = ";".join(nn_module_list)

if source_fn := node.meta.get("source_fn"):
# Serialize to "fx_node_name,op_str"
op = serialize_operator(source_fn[1])
op = self.serialize_operator(source_fn[1])
ret["source_fn"] = f"{source_fn[0]},{op}"

return ret
Expand Down Expand Up @@ -760,14 +739,31 @@ def save_graph_module(self) -> Iterator[None]:
finally:
self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved

def deserialize_operator(self, serialized_target: str):
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch.ops"):
module = torch.ops
serialized_target_names = serialized_target.split(".")[2:]
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target

target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target

def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
val = s.value
if s.type == "as_expr":
if val.expr_str in self.symbol_name_to_symbol:
sym = self.symbol_name_to_symbol[val.expr_str]
else:
sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)

if isinstance(sym, sympy.Symbol):
self.symbol_name_to_symbol[val.expr_str] = sym

Expand Down Expand Up @@ -834,7 +830,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
# Nodes: convert to call_function nodes.
for serialized_node in serialized_graph.nodes:
try:
target = deserialize_operator(serialized_node.target)
target = self.deserialize_operator(serialized_node.target)
self.deserialize_node(serialized_node, target)

except Exception as e:
Expand Down Expand Up @@ -1064,7 +1060,7 @@ def deserialize_meta_func(serialized_target: str):
module = torch
serialized_target_names = serialized_target.split(".")[1:]
else:
return deserialize_operator(serialized_target)
return self.deserialize_operator(serialized_target)

target = module
for name in serialized_target_names:
Expand Down