Bug description
With the latest PR enabling Float8 rowwise to support TP, I hit the following error when full activation checkpointing is active. The error does not occur with checkpointing disabled.
_fused_scaled_matmul_reduce_scatter_fallback
[rank5]:[E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0] raise ValueError(
[rank5]:[E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0] ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
....
[rank5]: ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
[P1734043454]
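For reference, the shapes in the traceback are consistent with the scale tensor having been flattened across A's leading dims: 5 × 8192 = 40960. A minimal sketch of the invariant the check enforces (plain PyTorch, not the actual torchao code path; `good_scale`/`bad_scale` are illustrative names):

```python
import torch

# A has the shape reported in the log: (batch, seq, hidden).
A = torch.randn(5, 8192, 4096)

# Row-wise scale with leading dims preserved: one scale per row.
good_scale = A.abs().amax(dim=-1, keepdim=True)   # shape (5, 8192, 1)

# What the traceback shows instead: the leading dims collapsed,
# 5 * 8192 = 40960, so the scale arrived as (40960, 1).
bad_scale = good_scale.reshape(-1, 1)             # shape (40960, 1)

assert good_scale.shape[:-1] == A.shape[:-1]      # passes
assert bad_scale.shape[:-1] != A.shape[:-1]       # this mismatch triggers the ValueError
```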
Versions
Latest TorchAO nightly + TP PR + PyTorch nightly (Feb 10)
TP PR = https://github.com/pytorch/torchtitan/pull/808/files + pytorch/ao#1718
This was run at 256 scale with full activation checkpointing (AC), TP=2, and FP8 rowwise.