[Float8] Unable to run asyncTP + Float8 row with 'full' AC active, leading dims mismatch #864

@lessw2020

Description

Bug description

With the latest PR adding TP support for Float8 rowwise, I hit the following error when full activation checkpointing (AC) is active. The error does not occur when checkpointing is disabled.

_fused_scaled_matmul_reduce_scatter_fallback
[rank5]:[E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0]     raise ValueError(
[rank5]:[E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0] ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
....
[rank5]: ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
[P1734043454]
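For context on the shape mismatch: A is 3-D ([5, 8192, 4096]) while its row-wise scale has been flattened to 2-D ([40960, 1], where 40960 = 5 * 8192), so the leading dims no longer match. Below is a minimal, hypothetical sketch of the validation that fires here; `check_rowwise_scale` is an illustrative stand-in, not the actual fused-op implementation.

```python
def check_rowwise_scale(a_shape, scale_shape):
    # For row-wise scaling, every dim of A except the last (the
    # reduction dim) is a "leading" dim; the scale must match those
    # leading dims exactly and carry a trailing size-1 dim.
    expected = tuple(a_shape[:-1]) + (1,)
    if tuple(scale_shape) != expected:
        raise ValueError(
            f"For row-wise scaling, the leading dims of A_scale must "
            f"match the leading dims of A (A shape: {a_shape}, "
            f"A_scale shape: {scale_shape})"
        )

# The failing case from the log: the scale was flattened to
# (5 * 8192, 1) = (40960, 1) while A stayed 3-D.
a_shape = (5, 8192, 4096)

# A scale whose leading dims match A's would pass the check:
check_rowwise_scale(a_shape, (5, 8192, 1))
```

This suggests the scale tensor produced under full AC loses (or never regains) A's leading batch dim before reaching `_fused_scaled_matmul_reduce_scatter_fallback`.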

Versions

Latest TorchAO nightly + TP PR + PyTorch nightly (Feb 10)
TP PR = https://github.com/pytorch/torchtitan/pull/808/files + pytorch/ao#1718
This was at 256 scale, with full AC
TP=2 + FP8 rowwise
