Skip to content

fix unflatten with HOPs #138978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6562,6 +6562,51 @@ def test(ep):

test(export(M(), inp))

def test_set_grad_unflatten(self):
class M1(torch.nn.Module):
def forward(self, a, b):
with torch.no_grad():
return a + b

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.m1 = M1()

def forward(self, a, b):
return self.m1(a, b)

inp = (torch.ones(3, 3), torch.ones(3, 3))
ep = export(M(), inp)
epm = ep.module()
ufm = torch.export.unflatten(ep)
self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))

def test_cond_unflatten(self):
class M1(torch.nn.Module):
def forward(self, p, a, b):
def true_fn(x, y):
return x + y

def false_fn(x, y):
return x - y

return torch.cond(p, true_fn, false_fn, [a, b])

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.m1 = M1()

def forward(self, p, a, b):
return self.m1(p, a, b)

inp = (torch.tensor(False), torch.ones(3, 3), torch.ones(3, 3))
ep = export(M(), inp)
epm = ep.module()
ufm = torch.export.unflatten(ep)
self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))

def test_unflatten_multiple_graphs_shared_submodule(self):
class N(torch.nn.Module):
def forward(self, x, b):
Expand Down
9 changes: 8 additions & 1 deletion torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,13 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:

elif isinstance(target, torch._ops.HigherOrderOperator):
args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs)
metadata = self.deserialize_metadata(serialized_node.metadata)
for x in (*args, *kwargs.values()):
if isinstance(x, torch.fx.Node) and x.op == "get_attr":
# this means that we have deserialized a graph argument, but
# unfortunately the schema for it does not include metadata;
# so we reuse the metadata of the HOP call for such arguments
x.meta.update(metadata)
# If HOP returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
Expand All @@ -1714,7 +1721,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
fx_node.meta.update(metadata)

elif isinstance(target, (torch._ops.OpOverload, *_registered_extension_types())):
# For convenience: if this node returns a single tensor, name the
Expand Down
4 changes: 3 additions & 1 deletion torch/distributed/pipelining/_unflatten.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Set

import torch
from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
Expand All @@ -12,11 +12,13 @@ def _outline_submodules(orig_graph: torch.fx.Graph):
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
seen_nodes,
seen_modules,
seen_attrs,
None,
[("", 0)],
"",
Expand Down
29 changes: 27 additions & 2 deletions torch/export/unflatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,14 @@ def __init__(
self.ivals = _IVals()
# record any intermediate value x that is used, with the modules that used it,
# and generate instructions to read the corresponding attribute
seen_modules = _outline_submodules(export_graph, self)
seen_modules, seen_attrs = _outline_submodules(export_graph, self)
# for each read intermediate value x, find the module that created it,
# and generate instructions to update the corresponding attribute;
# finally, initialize all these attributes
self.ivals.create(seen_modules.values())
# move attributes that correspond to graph arguments for HOPs
# from exported program to unflattened submodules
_copy_graph_attrs(export_module._graph_module, self, seen_attrs)

self.range_constraints = export_module.range_constraints
self.equality_constraints: List = []
Expand Down Expand Up @@ -760,6 +763,7 @@ def __init__(
nodes: Tuple[torch.fx.Node, ...],
seen_nodes,
seen_modules,
seen_attrs,
parent,
module_stack: List[Tuple[str, int]],
module_id,
Expand All @@ -770,6 +774,7 @@ def __init__(
self.nodes = nodes
self.seen_nodes = seen_nodes
self.seen_modules = seen_modules
self.seen_attrs = seen_attrs
self.parent = parent
self.module_stack = module_stack
self.module_id = module_id
Expand Down Expand Up @@ -1135,6 +1140,7 @@ def run_from(self, node_idx):
self.nodes,
self.seen_nodes,
self.seen_modules,
self.seen_attrs,
self,
self.module_stack + [next_module],
next_module_key.split("@")[0],
Expand All @@ -1146,6 +1152,11 @@ def run_from(self, node_idx):
# The only remaining possibility is that we are in the right stack
# frame. Copy the node into this frame's graph and increment the node counter.
assert node_module_stack == self.module_stack

if node.op == "get_attr":
# this must be a graph argument for a HOP
self.seen_attrs[self.child_fqn].add(node.target)

self.copy_node(node)
node_idx += 1

Expand All @@ -1163,11 +1174,13 @@ class _SubmoduleEntry:
def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
seen_nodes,
seen_modules,
seen_attrs,
None,
[("", 0)],
"",
Expand All @@ -1178,7 +1191,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu
},
module=root_module,
).run_outer()
return seen_modules
return seen_modules, seen_attrs


def _reorder_submodules(
Expand Down Expand Up @@ -1303,6 +1316,18 @@ def create(self, partitions):
)


def _copy_graph_attrs(
gm: torch.fx.GraphModule,
root_module: UnflattenedModule,
seen_attrs: Dict[str, Set[str]],
):
for child_fqn, names in seen_attrs.items():
module = _get_attr(root_module, child_fqn) if child_fqn else root_module
for name in names:
val = getattr(gm, name)
setattr(module, name, val)


def _deduplicate_modules(partitions):
for shared_submodules in partitions:
for i, entry in enumerate(shared_submodules):
Expand Down
Loading