Skip to content

Commit

Permalink
[export] Handle serializing duplicate getitem nodes (#127633)
Browse files Browse the repository at this point in the history
We ran into a graph that looks something like the following, where we have 2 getitem calls to the same index (%getitem, %getitem_2 both query topk[0]):
```
graph():
    %x : [num_users=1] = placeholder[target=x]
    %topk : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%x, 2), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {})
    %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem, %getitem_2), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_tensor, 2), kwargs = {})
    return (mul, getitem_1)
```

The duplicate getitem call gets created during a pass, so there are a couple of solutions:

1. Change serializer to support the case of duplicate getitem calls
2. Change the pass so that it doesn’t produce duplicate getitem calls
3. Add a pass which dedups the getitem calls

As a framework, we should do 1 and 3 (through a CSE pass).

This PR implements solution 1. However, the serializer currently does some special handling for getitem nodes -- instead of directly serializing the getitem nodes, we serialize the output of the node that outputs a list of tensors (the %topk node in this example) into a list of nodes, one for each output ([%getitem, %getitem_1]). This fails when we have duplicate getitem nodes to the same index (%getitem_2), since we do not record that duplicate getitem node anywhere. So, the solution this PR takes is that the serializer will deduplicate the getitem nodes (%getitem_2 will be replaced with %getitem). This results in a semantically correct graph, but not one that is necessarily node-to-node identical to the original fx graph.
Pull Request resolved: #127633
Approved by: https://github.com/ydwu4
  • Loading branch information
angelayi authored and pytorchmergebot committed Jun 3, 2024
1 parent 12c4a2c commit 4d32de1
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 22 deletions.
43 changes: 43 additions & 0 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,49 @@ def forward(self, x):
dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}}
self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)

def test_multiple_getitem(self):
    """Round-trip a graph containing duplicate getitem calls to the same index.

    Serialization is expected to dedup the duplicate getitem node, so the
    deserialized graph is semantically equivalent but not node-identical.
    """

    class M(torch.nn.Module):
        def forward(self, x):
            a, b = torch.topk(x, 2)
            a = a * 2
            return a, b

    ep = torch.export.export(M(), (torch.ones(3),))

    # Manually splice in a second getitem node that reads the same topk
    # output, plus a mul consuming it, mimicking a pass that emits
    # duplicate getitem calls.
    for node in ep.graph.nodes:
        if node.op != "call_function":
            continue
        if node.target != torch.ops.aten.mul.Tensor:
            continue
        original_getitem = node.args[0]
        with ep.graph.inserting_before(original_getitem):
            duplicate_getitem = ep.graph.node_copy(original_getitem)
            scaled = ep.graph.call_function(
                torch.ops.aten.mul.Tensor, (duplicate_getitem, 2)
            )
            scaled.meta = copy.copy(duplicate_getitem.meta)
            node.args = (original_getitem, scaled)

    deserialized_ep = deserialize(serialize(ep))

    # The round-tripped program must compute the same values.
    inp = (torch.randn(3),)
    orig_res = ep.module()(*inp)
    res = deserialized_ep.module()(*inp)
    self.assertTrue(torch.allclose(orig_res[0], res[0]))
    self.assertTrue(torch.allclose(orig_res[1], res[1]))

    # The deserialized graph should have deduped getitem calls
    self.assertExpectedInline(
        deserialized_ep.graph_module.code.strip("\n"),
        """\
def forward(self, x):
    topk_default = torch.ops.aten.topk.default(x, 2); x = None
    getitem = topk_default[0]
    getitem_1 = topk_default[1]; topk_default = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2)
    mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor); getitem = mul_tensor = None
    return (mul, getitem_1)
""",
    )

@parametrize(
"name,case",
get_filtered_export_db_tests(),
Expand Down
53 changes: 31 additions & 22 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,6 @@ def _is_single_tensor_list_return(target: Any) -> bool:
return_type.getElementType(), torch.TensorType
)

def _output_node_at_index(node, index):
for user in node.users:
assert user.target is operator.getitem, f"{user} is not a getitem node"
if index == user.args[1]:
return user
return None



@dataclass
class GraphState:
Expand Down Expand Up @@ -427,6 +419,7 @@ def __init__(
self.graph_signature = graph_signature
self.module_call_graph = module_call_graph
self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
self.duplicate_getitem_nodes: Dict[str, str] = {}

@contextmanager
def save_graph_state(self):
Expand Down Expand Up @@ -552,6 +545,19 @@ def handle_call_function(self, node: torch.fx.Node):
def handle_get_attr(self, node):
    # Intentionally a no-op: get_attr nodes receive no per-node handling
    # in this traversal. NOTE(review): presumably attributes are captured
    # through other serializer state — confirm against the full file.
    pass

def _output_node_at_index(self, node, index):
user_node = None
for user in node.users:
assert user.target is operator.getitem, f"{user} is not a getitem node"
if index == user.args[1]:
if user_node is None:
user_node = user
else:
# We want to deduplicate getitem nodes that are trying to
# index to the same index
self.duplicate_getitem_nodes[user.name] = user_node.name
return user_node

def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
ret = {}
if stack_trace := node.meta.get("stack_trace"):
Expand Down Expand Up @@ -705,13 +711,16 @@ def serialize_input(
return Argument.create(
as_sym_bool=SymBoolArgument.create(as_name=arg.name)
)
else:
if isinstance(arg.meta["val"], ep.CustomObjArgument):
return Argument.create(
as_custom_obj=CustomObjArgument(
name=arg.name, class_fqn=arg.meta["val"].class_fqn
)
elif isinstance(arg.meta["val"], ep.CustomObjArgument):
return Argument.create(
as_custom_obj=CustomObjArgument(
name=arg.name, class_fqn=arg.meta["val"].class_fqn
)
)
elif arg.name in self.duplicate_getitem_nodes:
dedup_name = self.duplicate_getitem_nodes[arg.name]
return Argument.create(as_tensor=TensorArgument(name=dedup_name))
else:
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, inductor_tensor_buffers):
# Other branches are for arguments in fx node.
Expand Down Expand Up @@ -1121,7 +1130,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
# e.g "-> Tensor[]"
tensor_args = []
for idx, meta in enumerate(meta_val):
user_node = _output_node_at_index(node, idx)
user_node = self._output_node_at_index(node, idx)
name = (
user_node.name
if user_node is not None
Expand Down Expand Up @@ -1151,7 +1160,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
output_arguments.append(Argument.create(as_none=()))
elif isinstance(meta, FakeTensor):
assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType))
user_node = _output_node_at_index(node, idx)
user_node = self._output_node_at_index(node, idx)
name = (
user_node.name
if user_node is not None
Expand All @@ -1165,20 +1174,20 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
) and isinstance(
return_schema.real_type.getElementType(), torch.TensorType
)
user_node = _output_node_at_index(node, idx)
user_node = self._output_node_at_index(node, idx)
assert user_node is not None

args = []
for i, m in enumerate(meta):
if m is None:
continue
sub_user_node = _output_node_at_index(user_node, i)
sub_user_node = self._output_node_at_index(user_node, i)
assert sub_user_node is not None, f"No user found at index {i}"

args.append(self.serialize_tensor_output(sub_user_node.name, m))
output_arguments.append(Argument.create(as_tensors=args))
elif isinstance(meta, (int, SymInt)):
user_node = _output_node_at_index(node, idx)
user_node = self._output_node_at_index(node, idx)
name = (
user_node.name
if user_node is not None
Expand Down Expand Up @@ -1208,7 +1217,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:

if len(meta_val) == 1:
assert isinstance(meta_val[0], torch.Tensor)
user_node = _output_node_at_index(node, 0)
user_node = self._output_node_at_index(node, 0)
name = (
user_node.name
if user_node is not None
Expand All @@ -1218,7 +1227,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:

outputs = []
for i, element_meta_val in enumerate(meta_val):
user_node = _output_node_at_index(node, i)
user_node = self._output_node_at_index(node, i)
if isinstance(element_meta_val, list):
# e.g "-> Tensor[]"
assert user_node is not None
Expand All @@ -1228,7 +1237,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:
if not isinstance(m, torch.Tensor):
raise SerializeError(f"Serialize list output with type {type(m)} nyi")

sub_user_node = _output_node_at_index(user_node, j)
sub_user_node = self._output_node_at_index(user_node, j)
name = (
sub_user_node.name
if sub_user_node is not None
Expand Down

0 comments on commit 4d32de1

Please sign in to comment.