[helion] backward support for swiglu #756
base: main
8d8a27d to 99587ae
@helion.kernel()
def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
Please add this to run.py; there are two lists you need to update there.
Also, please run it with tritonbench and generate perf/accuracy numbers.
cc: @yf225
@oulgen do you have an example of doing that for a backward kernel? I can find a few examples for fwd but none for bwd.
Oh, I found 'rms_norm-bwd' in run.py. I will follow it.
I ran this command:
python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd
but I don't see any numbers for Helion. Any ideas?
(B, T, H)          liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------
(4, 1024, 4096)                 1.01139                        1                         1.03097                                1
(4, 2048, 4096)                 1.02854                        1                         1.00777                                1
(4, 4096, 4096)                 1.03631                        1                         1.03787                                1
(4, 8192, 4096)                0.841614                        1                         1.04048                                1
average                        0.979463                        1                         1.02927                                1
@yf225 can you help?
99587ae to abbe582
abbe582 to 5d29b48
Add the backward formula of swiglu in examples/swiglu.py
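
For reference, the backward formula can be sketched in plain Python. This is not the Helion kernel from this PR; it is a scalar, elementwise sketch that assumes the forward pass is the usual SwiGLU definition, out = silu(x1) * x2 with silu(x) = x * sigmoid(x). The names `swiglu_fwd_ref` and `swiglu_bwd_ref` are hypothetical and only mirror the argument order of `swiglu_bwd(gout, x1, x2)`.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def swiglu_fwd_ref(x1: float, x2: float) -> float:
    """Assumed forward: silu(x1) * x2, where silu(x) = x * sigmoid(x)."""
    return x1 * sigmoid(x1) * x2


def swiglu_bwd_ref(gout: float, x1: float, x2: float) -> tuple[float, float]:
    """Gradients of the assumed forward with respect to x1 and x2.

    d silu(x)/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    """
    s = sigmoid(x1)
    silu = x1 * s
    dsilu = s * (1.0 + x1 * (1.0 - s))
    dx1 = gout * x2 * dsilu  # chain rule through silu(x1)
    dx2 = gout * silu        # x2 enters the product linearly
    return dx1, dx2


def finite_diff(x1: float, x2: float, eps: float = 1e-6) -> tuple[float, float]:
    """Central-difference estimate of the same two gradients (gout = 1)."""
    d1 = (swiglu_fwd_ref(x1 + eps, x2) - swiglu_fwd_ref(x1 - eps, x2)) / (2 * eps)
    d2 = (swiglu_fwd_ref(x1, x2 + eps) - swiglu_fwd_ref(x1, x2 - eps)) / (2 * eps)
    return d1, d2
```

Comparing `swiglu_bwd_ref(1.0, x1, x2)` against `finite_diff(x1, x2)` at a few points is a cheap sanity check of the formula, independent of any kernel or benchmark harness.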