-
Notifications
You must be signed in to change notification settings - Fork 21.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[export] Handle serializing duplicate getitem nodes (#127633)
We ran into a graph that looks something like the following, where we have 2 getitem calls to the same index (%getitem, %getitem_2 both query topk[0]): ``` graph(): %x : [num_users=1] = placeholder[target=x] %topk : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%x, 2), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {}) %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 1), kwargs = {}) %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%topk, 0), kwargs = {}) %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem, %getitem_2), kwargs = {}) %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_tensor, 2), kwargs = {}) return (mul, getitem_1) ``` The duplicate getitem call gets created during a pass.. so there are a couple of solutions: 1. Change serializer to support the case of duplicate getitem calls 2. Change the pass so that it doesn’t produce duplicate getitem calls 3. Add a pass which dedups the getitem calls As a framework, we should do 1 and 3 (through a CSE pass). This PR implements solution 1. However, the serializer currently does some special handling for getitem nodes -- instead of directly serializing the getitem nodes, we serialize the output of the node that outputting a list of tensors (the %topk node in this example) into a list nodes for each output ([%getitem, %getitem_1]). This fails when we have duplicate getitem nodes to the same index (%getitem_2), since we do not record that duplicate getitem node anywhere. So, the solution this PR takes is that the serializer will deduplicate the getitem nodes (%getitem_2 will be replaced with %getitem). This would result in a sematically correct graph, but not necessarily node-to-node identical as the original fx graph. Pull Request resolved: #127633 Approved by: https://github.com/ydwu4
- Loading branch information
1 parent
12c4a2c
commit 4d32de1
Showing
2 changed files
with
74 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters