Make fused RMSNorm a registered op #199

Closed
lessw2020 opened this issue Apr 5, 2024 · 2 comments
Labels: bug (Something isn't working), enhancement (New feature or request)

@lessw2020 (Contributor):

Adding this as a tracking issue to unblock #181 from landing. Per @wanchaol:

IMO we should also register the fwd/bwd RMSNorm kernel as a PyTorch custom op, so that:

- it is compatible with PT2; I believe the FusedRMSNorm path currently graph-breaks when torch.compile is turned on
- other components (e.g. DTensor) can provide a sharding rule for the custom op, so that it is compatible with tensor parallelism
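For context, here is a minimal sketch of what such a registration could look like with the `torch.library.custom_op` API (available in PyTorch 2.4+). The `titan::fused_rms_norm` name and the eager reference body are illustrative assumptions, not the actual torchtitan kernel; a real registration would dispatch to the fused kernel, attach the fused backward via `register_autograd`, and provide a DTensor sharding rule.

```python
import torch
from torch.library import custom_op


# Hypothetical op name; the real kernel and namespace live in torchtitan.
@custom_op("titan::fused_rms_norm", mutates_args=())
def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference eager implementation for illustration only; a real
    # registration would call the fused Triton/CUDA kernel here instead.
    x_fp32 = x.float()
    rstd = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rstd).type_as(x) * weight


@fused_rms_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Shape/dtype propagation only, so torch.compile (and DTensor) can trace
    # the op as a single node without running the kernel, avoiding the graph
    # break the eager fused path would otherwise cause.
    return torch.empty_like(x)
```

Because the op is a single registered node, PT2 no longer has to trace into the fused kernel's internals, and DTensor can key a sharding propagation rule off the op's schema.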

@tianyu-l (Contributor) commented May 8, 2024:

Update: hit IMA (illegal memory access) issues with both my implementation (#296) and @wconstab's (#303). Debugging with @lessw2020.

@tianyu-l (Contributor):

Closing this, as fused RMSNorm is now supported with Tensor Parallelism (#404).
