shunting314

Add the backward formula of swiglu in examples/swiglu.py
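For readers who want the formula spelled out, here is a minimal scalar sketch. It assumes the forward pass is `silu(x1) * x2`, which this thread does not state explicitly. The names `swiglu_fwd_reference` and `swiglu_bwd_reference` are hypothetical and are not the Helion kernels in examples/swiglu.py; the sketch only illustrates the chain rule that an elementwise backward of this kind would apply.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def swiglu_fwd_reference(x1: float, x2: float) -> float:
    # Assumed forward definition: silu(x1) * x2, where silu(x) = x * sigmoid(x).
    return x1 * sigmoid(x1) * x2


def swiglu_bwd_reference(gout: float, x1: float, x2: float) -> tuple[float, float]:
    # d silu(x) / dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    sig = sigmoid(x1)
    dsilu = sig * (1.0 + x1 * (1.0 - sig))
    dx1 = gout * x2 * dsilu   # gradient with respect to x1
    dx2 = gout * x1 * sig     # gradient with respect to x2, i.e. gout * silu(x1)
    return dx1, dx2
```

A finite-difference comparison against `swiglu_fwd_reference` is a quick way to sanity-check the two gradients before comparing a real kernel against autograd.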

@meta-cla meta-cla bot added the CLA Signed label Oct 1, 2025


@helion.kernel()
def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
Contributor

@oulgen oulgen Oct 1, 2025

Please add this to run.py; there are two lists you need to update there.

Also, please run it with TritonBench and generate perf/accuracy numbers.
cc: @yf225

Author

@shunting314 shunting314 Oct 1, 2025

@oulgen do you have an example of how to do that for a backward kernel? I can find a few examples for fwd but not for bwd.

Author

Oh, found 'rms_norm-bwd' in run.py. Will follow it.

Author

I ran this command:

python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd

but I don't see the numbers for Helion. Any ideas?

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------
(4, 1024, 4096)                1.01139                         1                         1.03097                                1
(4, 2048, 4096)                1.02854                         1                         1.00777                                1
(4, 4096, 4096)                1.03631                         1                         1.03787                                1
(4, 8192, 4096)                0.841614                        1                         1.04048                                1
        average                0.979463                        1                         1.02927                                1

@oulgen

Contributor

@yf225 can you help?
