Add backwards support to FlexAttention (#123902)
# Summary

This is part one of adding backwards support to FlexAttention. This PR focuses on the eager implementation and on wiring up enough of `templated_attention_backward` (name change coming soon 😉) to get through aot_eager. Notably, this does not yet wire up the Triton template, in order to make this PR easier to review; that will be the next follow-up PR.

#### Structure

We pass both the forward and the joint graph to the backwards HOP, since both need to be inlined into the backwards calculation:
- the forward graph is needed in order to re-compute the scores
- the joint graph is needed in order to construct the correct gradients after the softmax_grad calculation

### Attached AOT Graph

https://gist.github.com/drisspg/ce4c041f8df8a5a7983c5174705cf2b5

Pull Request resolved: #123902
Approved by: https://github.com/Chillee
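The structure above can be illustrated with a minimal eager sketch (not the actual PyTorch implementation): the backward pass first re-runs the forward math to re-compute the scores, then applies the post-softmax gradient calculation, and finally backprops through the score modification. The `score_mod` name and the helper functions here are hypothetical stand-ins for FlexAttention's score-modification hook.

```python
# Hedged sketch of an eager attention backward in the spirit of this PR:
# the scores are RE-COMPUTED inside backward (the "forward graph"), and
# the gradient w.r.t. the raw QK product is obtained by differentiating
# through score_mod (the "joint graph"). Names are illustrative only.
import torch

def attention_fwd(q, k, v, score_mod):
    scores = score_mod(q @ k.transpose(-2, -1))
    p = torch.softmax(scores, dim=-1)
    return p @ v

def attention_bwd(q, k, v, dout, score_mod):
    # Forward graph inlined into backward: re-compute the scores.
    qk = (q @ k.transpose(-2, -1)).detach().requires_grad_(True)
    scores = score_mod(qk)
    p = torch.softmax(scores, dim=-1)
    # Post-softmax gradient calculation.
    dv = p.transpose(-2, -1) @ dout
    dp = dout @ v.transpose(-2, -1)
    ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))  # softmax grad
    # Joint graph: backprop ds through score_mod to reach the QK product.
    (dqk,) = torch.autograd.grad(scores, qk, grad_outputs=ds)
    dq = dqk @ k
    dk = dqk.transpose(-2, -1) @ q
    return dq, dk, dv
```

A quick way to sanity-check such a sketch is to compare its gradients against autograd on `attention_fwd` with a simple `score_mod` such as a scale factor.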
1 parent 720e5f3 · commit 8c21925 · 7 changed files with 707 additions and 38 deletions.
```diff
@@ -1,3 +1,3 @@
 from .cond import cond
 from .while_loop import while_loop
-from .templated_attention import templated_attention
+from .templated_attention import templated_attention, templated_attention_backward
```