[float8 moe training] make using triton kernels for per-group scaling configurable #2405


Merged: 5 commits merged into main on Jun 18, 2025

Conversation

@danielvegamyhre (Contributor) commented on Jun 18, 2025

Summary

  • Make using triton kernels for per-group scaling configurable, so we can measure the speedup from avoiding the d2h sync when benchmarking e2e training and validate the expected speedup (a minimal sketch follows this list). We can remove this before graduating out of prototype.
  • Improve the benchmarking script and make compile configurable via an argument (I discovered a bug with torch.compile and tune_scaled_grouped_mm, so I want the benchmark script to be usable as a repro for debugging).
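
A minimal sketch of the configurable flag from the first bullet, assuming a wrapper tensor subclass along the lines of the one touched in this PR. The class name and the two scaling helpers below are hypothetical placeholders; only the use_triton_for_per_group_scales flag name comes from the diff shown later in this thread.

import torch


def _triton_per_group_scales(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for the triton kernel path (the real kernel avoids a d2h sync).
    return t.abs().amax(dim=-1, keepdim=True)


def _torch_per_group_scales(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for the plain PyTorch fallback path.
    return t.abs().amax(dim=-1, keepdim=True)


class GroupScaledTensor(torch.Tensor):
    # Hypothetical wrapper subclass carrying a per-instance triton toggle.

    def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        self._data = data
        # Stored on the instance (not the class) so later code can read it
        # from whichever argument is the subclass.
        self.use_triton_for_per_group_scales = use_triton_for_per_group_scales

    def per_group_scales(self) -> torch.Tensor:
        fn = (
            _triton_per_group_scales
            if self.use_triton_for_per_group_scales
            else _torch_per_group_scales
        )
        return fn(self._data)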

@danielvegamyhre added the topic: not user facing label (use this tag if you don't want this PR to show up in release notes) on Jun 18, 2025

pytorch-bot (bot) commented on Jun 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2405

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 3 Pending, 1 Unrelated Failure

As of commit a285fc8 with merge base 8b12ddf:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jun 18, 2025
@danielvegamyhre (Contributor Author)

@drisspg @vkuzo for review


def __init__(self, data: torch.Tensor):
def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
cls.use_triton_for_per_group_scales = use_triton_for_per_group_scales
Contributor

This looks weird; I think you should just let it fall through to __init__ and then set it on the instance.

@danielvegamyhre (Contributor Author) Jun 18, 2025

I agree. I did that at first, but then in __torch_function__ I don't have access to the instance because it's a classmethod. I'd like to make it less weird though, and am open to ideas.

Technically, since this is just temporary for benchmarking comparison, I could just condition on an env var and avoid plumbing this through everywhere, but that seemed weird as well. In retrospect I actually think it would be cleaner, though. What do you think?
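
For reference, the env-var alternative mentioned here could look roughly like the following; the variable name is hypothetical and this is not what the PR ended up doing:

import os

# Hypothetical module-level toggle; defaults to the triton path being on.
_USE_TRITON_FOR_PER_GROUP_SCALES = (
    os.environ.get("USE_TRITON_FOR_PER_GROUP_SCALES", "1") == "1"
)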

Contributor

One of the inputs to __torch_function__ has to be one of these subclasses, right? So can't you just grab it from whichever input is the subclass instance?

Contributor Author

True, that should work - updated.
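
A rough sketch of the reviewer's suggestion, not the PR's exact code: since at least one input to __torch_function__ must be an instance of the subclass, the flag can be read from that instance rather than from a class attribute. The helper name below is hypothetical, and it assumes the flag is stored on instances as use_triton_for_per_group_scales.

from torch.utils._pytree import tree_flatten


def _use_triton_from_args(args, kwargs, subclass_type, default: bool = True) -> bool:
    # Hypothetical helper for use inside the classmethod __torch_function__:
    # flatten args/kwargs, find the first instance of the tensor subclass,
    # and return its per-instance flag, falling back to `default`.
    flat_args, _ = tree_flatten((args, kwargs or {}))
    for a in flat_args:
        if isinstance(a, subclass_type):
            return a.use_triton_for_per_group_scales
    return default

Inside __torch_function__, something along these lines replaces reading cls.use_triton_for_per_group_scales, which is what made the original version awkward.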

@danielvegamyhre merged commit 101c039 into main on Jun 18, 2025
18 of 19 checks passed
xiaowangintel pushed a commit to xiaowangintel/ao that referenced this pull request on Jun 24, 2025
… configurable (pytorch#2405)

* improve moe training benchmarking

* lint

* readability improvements

* grab use_triton from args instead of class attribute

* add comment