[float8 moe training] make using triton kernels for per-group scaling configurable #2405
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2405
Note: links to docs will display an error until the docs builds have completed. As of commit a285fc8 with merge base 8b12ddf: ⏳ 3 pending, 1 unrelated failure (FLAKY: the following job failed but was likely due to flakiness present on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py (outdated review thread, resolved)
def __init__(self, data: torch.Tensor):
def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
    cls.use_triton_for_per_group_scales = use_triton_for_per_group_scales
This looks weird; I think you should just let it fall through to `__init__` and then set it on the instance.
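For illustration, a minimal sketch of that suggestion, not the actual torchao implementation; the class name, the `as_subclass` construction, and the `_data` attribute are assumptions, only the flag name comes from the diff above:

```python
import torch

class ScaledGroupedMMTensor(torch.Tensor):
    # Sketch only: keep __new__ free of side effects and let the flag fall
    # through to __init__, where it is stored on the instance rather than on
    # the class (cls.<attr> would be silently shared by every instance).
    def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        return data.as_subclass(cls)

    def __init__(self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        self._data = data
        self.use_triton_for_per_group_scales = use_triton_for_per_group_scales
```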
I agree, I did that at first, but then in `__torch_function__` I don't have access to the instance because it's a classmethod. I'd like to make it less weird though; I'm open to ideas.
Technically, since this is just temporary for benchmarking comparison, I could just condition on an env var and avoid plumbing this through everywhere entirely, but that seemed weird as well. In retrospect I actually think it would be cleaner though. What do you think?
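For reference, the env-var alternative mentioned above could look roughly like this; the variable name is hypothetical and nothing in the PR defines it:

```python
import os

def _use_triton_for_per_group_scales() -> bool:
    # Hypothetical gate: read the setting from the environment once instead of
    # plumbing a flag through the tensor subclass (variable name is illustrative).
    return os.environ.get("MOE_SCALING_USE_TRITON", "1") == "1"
```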
One of the inputs to `__torch_function__` has to be one of these subclasses, right? So can't you just grab the flag from whichever input is the subclass instance?
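A sketch of that approach, assuming the flag is stored per instance as in the earlier sketch; the dispatch details are placeholders, not the real prototype code:

```python
import torch

class ScaledGroupedMMTensor(torch.Tensor):
    def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        return data.as_subclass(cls)

    def __init__(self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
        self.use_triton_for_per_group_scales = use_triton_for_per_group_scales

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # __torch_function__ is a classmethod, so there is no `self`, but at
        # least one of the inputs is guaranteed to be an instance of this
        # subclass: read the per-instance flag from whichever input that is.
        use_triton = next(
            (a.use_triton_for_per_group_scales
             for a in args if isinstance(a, ScaledGroupedMMTensor)),
            True,
        )
        # Placeholder: the real code would pick the triton or torch-native
        # per-group scaling path based on `use_triton` before dispatching.
        return super().__torch_function__(func, types, args, kwargs)
```

A caller could then opt out per tensor, e.g. `ScaledGroupedMMTensor(w, use_triton_for_per_group_scales=False)`, without touching any class-level state.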
True, that should work - updated.
[float8 moe training] make using triton kernels for per-group scaling configurable (pytorch#2405)
* improve moe training benchmarking
* lint
* readability improvements
* grab use_triton from args instead of class attribute
* add comment
Summary