[XLA:GPU] In Multi-output fusion, use BFS to check reachability #63707

Merged
merged 1 commit into from Mar 14, 2024

copybara-service[bot]

Currently, MultiOutputFusion uses HloReachabilityMap, which builds a reachability matrix on creation and answers element lookups from it on request. Because multi-output fusion rebuilds the reachability map after every step, this can take a long time.
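
To make the cost of the matrix approach concrete, here is a minimal Python sketch of a dense reachability map: one N-bit row per node, filled in post order. It assumes a toy graph given as a dict from each instruction name to its operand names; `MatrixReachability` and `post_order` are hypothetical names for illustration, and this is a simplified model of the idea, not XLA's `HloReachabilityMap` code.

```python
from typing import Dict, List

# Hypothetical toy model: each instruction name maps to its operand names.
Graph = Dict[str, List[str]]


def post_order(operands: Graph) -> List[str]:
    """Orders nodes so that every operand comes before its users."""
    order: List[str] = []
    visited = set()

    def visit(node: str) -> None:
        if node in visited:
            return
        visited.add(node)
        for op in operands[node]:
            visit(op)
        order.append(node)

    for node in operands:
        visit(node)
    return order


class MatrixReachability:
    """Illustrative dense map: one N-bit row per node (not XLA's actual code)."""

    def __init__(self, operands: Graph) -> None:
        self.index: Dict[str, int] = {}
        # rows[i] has bit j set iff node j reaches node i.
        self.rows: List[int] = []
        for node in post_order(operands):
            i = len(self.rows)
            self.index[node] = i
            row = 1 << i  # a node reaches itself
            for op in operands[node]:
                # Operands precede users in post order, so their rows are final.
                row |= self.rows[self.index[op]]
            self.rows.append(row)

    def is_reachable(self, src: str, dst: str) -> bool:
        # A lookup is a single bit test; the cost is paid up front in __init__.
        return bool((self.rows[self.index[dst]] >> self.index[src]) & 1)


graph = {"p": [], "a": ["p"], "b": ["p"], "c": ["a"]}  # p -> a -> c, p -> b
m = MatrixReachability(graph)
assert m.is_reachable("p", "c")
assert not m.is_reachable("b", "c")
```

In this sketch lookups are cheap, but construction ORs an N-bit row for every operand edge and stores N² bits in total, and that whole cost is paid again every time the map is rebuilt.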

This change makes MultiOutputFusion use a simple breadth-first search (BFS) reachability check, bounded by the post-order index. Namely:

  • To build the "map", the instructions are sorted in post order and their indices are stored.
  • On request, a BFS is started from the destination node, and it only visits nodes whose post-order index is >= that of the source node.
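
The two steps above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the code in this change: the graph is a toy dict from each instruction name to its operand names, and the search from the destination walks backward through operands. `PostOrderBfsReachability` and `post_order` are hypothetical names. The bound is safe because along any path from the source to the destination, post-order indices only increase, so nodes earlier than the source can never lie on such a path.

```python
from collections import deque
from typing import Dict, List

# Hypothetical toy model: each instruction name maps to its operand names.
Graph = Dict[str, List[str]]


def post_order(operands: Graph) -> List[str]:
    """Orders nodes so that every operand comes before its users."""
    order: List[str] = []
    visited = set()

    def visit(node: str) -> None:
        if node in visited:
            return
        visited.add(node)
        for op in operands[node]:
            visit(op)
        order.append(node)

    for node in operands:
        visit(node)
    return order


class PostOrderBfsReachability:
    """Illustrative sketch: store post-order indices, search lazily per query."""

    def __init__(self, operands: Graph) -> None:
        self.operands = operands
        # "Building the map" only records each node's post-order index.
        self.index = {node: i for i, node in enumerate(post_order(operands))}

    def is_reachable(self, src: str, dst: str) -> bool:
        lo = self.index[src]
        if self.index[dst] < lo:
            return False  # a src -> dst path would need index[dst] >= index[src]
        queue = deque([dst])
        seen = {dst}
        while queue:
            node = queue.popleft()
            if node == src:
                return True
            for op in self.operands[node]:
                # Skip nodes that come before src in post order: they cannot
                # lie on any path from src to dst.
                if op not in seen and self.index[op] >= lo:
                    seen.add(op)
                    queue.append(op)
        return False


graph = {"p": [], "a": ["p"], "b": ["p"], "c": ["a"]}  # p -> a -> c, p -> b
r = PostOrderBfsReachability(graph)
assert r.is_reachable("p", "c")
assert not r.is_reachable("b", "c")  # "a" is pruned: its index is below b's
```

Here construction is just one post-order walk, so rebuilding after every fusion step is cheap; each query may still visit many nodes, but the bound cuts off everything ordered before the source. Replacing the deque with a stack would turn this into a depth-first search with the same answers, since visiting order does not change which nodes are reachable.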

With this approach, the MultiOutputFusion pass on one example HLO module drops from 34 minutes to 4 minutes. More advanced data structures reduce this even further (to 0.5 minutes); they will be introduced in subsequent changes.

PiperOrigin-RevId: 615730140
copybara-service bot merged commit 8cdcf82 into master on Mar 14, 2024
copybara-service bot deleted the test_615049679 branch on March 14, 2024 at 11:38