Add the workaround to support rowwise scaled_gemm for fp32 outputs (#2431)
Summary:
Running rowwise scaling on fp32 tensors fails with the following error (P1794222725):
```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```
This PR adds an option to explicitly use bfloat16 as the output dtype of the rowwise scaled_gemm and then cast the result back to the original precision.
It is disabled unless explicitly enabled, which can be done by setting:
```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```
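The control flow of the workaround can be sketched roughly as below. This is a minimal illustration, not the code in this PR: `ScaledMMConfig` and `plan_rowwise_output` are hypothetical stand-ins, and dtypes are plain strings so the sketch runs without a GPU. Only the flag name `convert_dtypes_for_rowwise_scaled_mm` and the error message come from this summary.

```python
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ScaledMMConfig:
    # Hypothetical stand-in for the real config class; only the flag name
    # is taken from this commit. The False default is an assumption.
    convert_dtypes_for_rowwise_scaled_mm: bool = False


def plan_rowwise_output(
    requested_dtype: str, config: ScaledMMConfig
) -> Tuple[str, Optional[str]]:
    """Return (dtype the kernel should emit, dtype to cast back to or None)."""
    if requested_dtype == "bfloat16":
        # Natively supported by the rowwise kernel: no cast needed.
        return "bfloat16", None
    if config.convert_dtypes_for_rowwise_scaled_mm:
        # Workaround: compute in bfloat16, then cast back to the caller's dtype.
        return "bfloat16", requested_dtype
    raise RuntimeError(
        "Only bf16 high precision output types are supported for row-wise scaling."
    )


config = ScaledMMConfig()
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
kernel_dtype, cast_back = plan_rowwise_output("float32", config)
```

Note that casting back restores the fp32 dtype but not fp32 precision: the values have already been rounded to bfloat16 by the kernel.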
Differential Revision: D73552660