test rowwise fp32 #2431
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2431. Note: links to docs will display an error until the docs builds have completed.

❌ 1 new failure as of commit b0240a2 with merge base 2025b75.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D73552660
```
if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
    output_dtype = torch.bfloat16
```
Instead of adding a flag, TBH I think we can just enable this on by default, like this, and file an issue in PyTorch core to add float32 output to scaled_mm:
```
output_dtype_to_use = output_dtype
if is_rowwise_scaling:
    # work around torch._scaled_mm not having a float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output_dtype_to_use = torch.bfloat16
output = torch._scaled_mm(..., output_dtype_to_use, ...)
...
if is_rowwise_scaling and output_dtype == torch.float32:
    # work around torch._scaled_mm not having a float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output = output.to(output_dtype)
```
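For reference, a self-contained sketch of this workaround as a standalone helper might look like the following. This is illustrative only: the function name and the shape-based rowwise check are assumptions, and `torch._scaled_mm` is a private API whose signature can vary across PyTorch versions.

```python
import torch

def _scaled_mm_with_fp32_workaround(a_fp8, b_fp8, scale_a, scale_b, output_dtype):
    # Rowwise scaling: one scale per row of a, one scale per column of b.
    is_rowwise_scaling = (
        scale_a.shape == (a_fp8.shape[0], 1)
        and scale_b.shape == (1, b_fp8.shape[1])
    )
    output_dtype_to_use = output_dtype
    if is_rowwise_scaling and output_dtype == torch.float32:
        # work around torch._scaled_mm not supporting float32 output
        # for row-wise scaling (pytorch/pytorch#156771)
        output_dtype_to_use = torch.bfloat16
    output = torch._scaled_mm(
        a_fp8, b_fp8, scale_a, scale_b, out_dtype=output_dtype_to_use
    )
    if output.dtype != output_dtype:
        # cast back to the originally requested precision
        output = output.to(output_dtype)
    return output
```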
Makes sense, I'll change it to enable by default and file an issue.
Can we file an issue in core to add this to torch._scaled_mm, and enable the workaround without a config for now? Also add a test?
Updated to enable the workaround by default. Included fp16 and fp32 dtypes in the existing test cases. The additional changes are formatting fixes generated by the linter. The PyTorch issue: pytorch/pytorch#156771
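A minimal sketch of what that dtype coverage could look like, assuming torchao's float8 training entry points (the actual test lives in the repo's float8 test suite and uses its own helpers; running it requires an fp8-capable GPU, SM89+):

```python
import pytest
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_rowwise_scaling_preserves_dtype(dtype):
    # Swap the nn.Linear for a rowwise-scaled float8 linear and check
    # that the output dtype matches the input dtype after the workaround.
    m = torch.nn.Linear(64, 64, device="cuda", dtype=dtype)
    config = Float8LinearConfig.from_recipe_name("rowwise")
    convert_to_float8_training(m, config=config)
    x = torch.randn(16, 64, device="cuda", dtype=dtype, requires_grad=True)
    y = m(x)
    assert y.dtype == dtype
    y.sum().backward()
```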
@y-sq, maybe export again?
@vkuzo sorry, there were some un-synced files between GitHub and fbcode, so the previous exports all failed. The PR should be updated now.
Force-pushed from abbed3a to 303d6a6.
Summary:

Running rowwise scaling on fp32 tensors got the error (P1794222725):

```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```

This PR adds an option to explicitly use bfloat16 as the output of rowwise_scaled, and cast it back to the original precision. It can be enabled by setting:

```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```

Differential Revision: D73552660
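For context, a minimal reproduction of the RuntimeError above might look like this. The shapes and unit scales are made up for illustration, the positional scale_a/scale_b call layout is assumed from recent PyTorch releases, and it needs an fp8-capable GPU (SM89+) with a PyTorch build from before pytorch/pytorch#156771 is addressed:

```python
import torch

M, K, N = 16, 32, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.ones(M, 1, device="cuda")  # one scale per row of a
scale_b = torch.ones(1, N, device="cuda")  # one scale per column of b

# Raises: RuntimeError: Only bf16 high precision output types are
# supported for row-wise scaling.
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)
```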