Add backward kernel for exp #736
Conversation
Please add this to run.py; you can grep for rms_norm again as a reference for the backward implementation, and share perf/accuracy results from tritonbench.
@Sibylau can give pointers on how to do it.
The gist is here: https://docs.google.com/document/d/1BiQaFJRBufzcLNPSMhVKuC9gv-KEFpN4B49CqOjZC74/edit?usp=sharing
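For context, the backward of exp only needs the forward output: for y = exp(x), dL/dx = dL/dy * exp(x) = dL/dy * y. Here is a minimal, illustrative sketch of that math in plain PyTorch autograd (the class name ExpFn is made up; this is not the actual Helion kernel code from exp.py):

```python
import torch

class ExpFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x)
        ctx.save_for_backward(y)  # saving y (not x) avoids recomputing exp in backward
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # dL/dx = dL/dy * exp(x) = dL/dy * y
        return grad_out * y
```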
From what I see, tritonbench does not have a bwd operator for this op (the same is true for some other ops too).
Yeah, let's also add backward support for this operator to tritonbench (we can look at get_bwd_fn in the tritonbench layer_norm and rms_norm operators for how to do this; we also need to make sure the inputs have requires_grad=True).
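For reference, a rough sketch of the get_bwd_fn pattern described above, assuming the convention of returning a closure that runs backward on a cached forward output (names and signature here are illustrative, not the exact tritonbench API):

```python
import torch

def get_bwd_fn(fwd_fn, x):
    # x must be a leaf tensor created with requires_grad=True so autograd
    # records the graph through the forward call.
    y = fwd_fn(x)
    grad_out = torch.randn_like(y)

    def bwd():
        # retain_graph=True lets the benchmark invoke backward repeatedly
        y.backward(grad_out, retain_graph=True)
        return x.grad

    return bwd

# Example: benchmark the backward of exp (hypothetical usage)
# x = torch.randn(2**20, device="cuda", requires_grad=True)
# bwd = get_bwd_fn(torch.exp, x)
# bwd()
```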
tritonbench PR: meta-pytorch/tritonbench#501
Force-pushed from 5bc32b8 to 61446b8
Force-pushed from 61446b8 to 67819f3
Test Plan:
Adding new unit tests in test_examples.py
Testing with run.py and exp.py
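For illustration, a minimal sketch of the kind of forward/backward accuracy check the test plan describes (the real test lives in test_examples.py and calls the exp example from exp.py; exp_kernel below is a made-up stand-in):

```python
import torch

def exp_kernel(x):
    # Stand-in for the kernel-backed exp from exp.py (hypothetical).
    return torch.exp(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, device=device, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

y = exp_kernel(x)
y_ref = torch.exp(x_ref)  # eager reference

grad_out = torch.randn_like(y)
y.backward(grad_out)
y_ref.backward(grad_out)

torch.testing.assert_close(y, y_ref)            # forward accuracy
torch.testing.assert_close(x.grad, x_ref.grad)  # backward accuracy
```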