Skip to content

Commit

Permalink
Speed up _extract_graph_with_inputs_outputs (#125937)
Browse files Browse the repository at this point in the history
_extract_graph_with_inputs_outputs() does membership testing on the input nodes but often that collection is a list so the test is O(n).  Ensure it's a set before looping over all the nodes.

This change speeds up the internal repro (D57090987) by about 18%:
before:
```
708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k
0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps
```
after:
```
583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k
0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps
```

Pull Request resolved: #125937
Approved by: https://github.com/oulgen, https://github.com/anijain2305
  • Loading branch information
aorenste authored and pytorchmergebot committed May 11, 2024
1 parent 4457cd9 commit a5c93a6
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torch/_functorch/partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
env[node] = new_node

for node in joint_graph.nodes:
if node in inputs:
if node in env:
# Node must be one of our inputs. (Any member of env which wasn't an
# input to start must have been created by this loop and won't be in
# joint_graph.nodes).
continue
elif node.op == "placeholder":
env[node] = InvalidNode
Expand Down

0 comments on commit a5c93a6

Please sign in to comment.