Skip to content

Commit 00529fa

Browse files
authored
[float8] fix all-gather in 2D with DTensor(WeightWithDynamicFloat8CastTensor) (#590)
* [float8][2D] fix bug in precomputing scales Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * [float8] fix all-gather in 2D with DTensor(WeightWithDynamicFloat8CastTensor) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * remove record_function after debugging Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * add ASCII diagram Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 28f48fd commit 00529fa

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

test/float8/test_dtensor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torchao.float8 import Float8LinearConfig
2828
from torchao.float8.float8_linear_utils import convert_to_float8_training
2929

30+
from torchao.float8.config import CastConfig, ScalingType
3031
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
3132
from torchao.float8.float8_tensor import (
3233
Float8Tensor,
@@ -43,6 +44,11 @@
4344
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
4445
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
4546
from torch.distributed.tensor.parallel import parallelize_module
47+
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
48+
from torch.testing._internal.distributed._tensor.common_dtensor import (
49+
ModelArgs,
50+
Transformer,
51+
)
4652
from tqdm import tqdm
4753

4854

@@ -303,6 +309,38 @@ def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
303309
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
304310

305311

312+
def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
    """Check that ``distribute_tensor`` keeps the float8 FSDP weight subclass.

    Builds a ``Transformer`` on CUDA, converts it to float8 training with
    ``enable_fsdp_float8_all_gather=True`` and dynamic weight scaling, then
    shards one attention weight along dim 0 and another along dim 1 over
    ``tp_mesh``. Each result must be a ``DTensor`` whose ``_local_tensor`` is
    still a ``WeightWithDynamicFloat8CastTensor``.

    Args:
        tp_mesh: device mesh passed to ``distribute_tensor`` for sharding.

    Raises:
        AssertionError: if either distributed weight is not a ``DTensor``
            wrapping a ``WeightWithDynamicFloat8CastTensor``.
    """
    torch.manual_seed(42)
    model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda()
    convert_to_float8_training(
        model,
        config=Float8LinearConfig(
            enable_fsdp_float8_all_gather=True,
            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
        ),
    )
    # test Float8ColwiseParallel
    colwise_param = distribute_tensor(
        model.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
    )
    assert (
        isinstance(colwise_param, DTensor)
        and isinstance(
            colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
        )
    ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}"
    # test Float8RowwiseParallel
    rowwise_param = distribute_tensor(
        model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
    )
    # Report the rowwise tensor on failure (the message previously printed
    # colwise_param, which had already passed its own check above).
    assert (
        isinstance(rowwise_param, DTensor)
        and isinstance(
            rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
        )
    ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {rowwise_param}"
342+
343+
306344
if __name__ == "__main__":
307345
# float8 only works on CUDA H100 so we only test cuda and we follow
308346
# other test files to not use TestCase but instead just add the test
@@ -315,6 +353,7 @@ def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
315353
_test_dtensor_fp8_autograd,
316354
_test_fp8_mlp_tensor_parallelism_eager,
317355
_test_fp8_mlp_tensor_parallelism_compile,
356+
_test_distribute_fsdp_tensor_subclass,
318357
]
319358

320359
for test in tqdm(tests, desc="Running tests"):

torchao/float8/fsdp_utils.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,43 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
8484
torch.ops.aten.as_strided.default,
8585
torch.ops.aten._to_copy.default,
8686
torch.ops.aten._pin_memory.default,
87+
torch.ops.aten.split.Tensor,
88+
torch.ops.aten.clone.default,
8789
}
8890

91+
# How Tensor Parallel (TP) and FSDP2 work
92+
93+
# Initialization: apply TP first then FSDP2
94+
# nn.Linear(weight=torch.Tensor)
95+
# |
96+
# | apply float8 linear, `convert_to_float8_training`
97+
# |
98+
# Float8Linear(weight=WeightWithDynamicFloat8CastTensor)
99+
# |
100+
# | apply tensor parallel, `parallelize_module` shards rowwise/colwise
101+
# |
102+
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
103+
# device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)),
104+
# placements=(Shard(dim=0),)))
105+
# |
106+
# | apply FSDP2, `fully_shard` shards rowwise (dim=0)
107+
# |
108+
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
109+
# device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')),
110+
# placements=(Shard(dim=0), Shard(dim=0))))
111+
112+
# Forward and backward: FSDP runs first then TP
113+
# Float8Linear(weight=DTensor(local_tensor=WeightWithDynamicFloat8CastTensor,
114+
# device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')),
115+
# placements=(Shard(dim=0), Shard(dim=0))))
116+
# |
117+
# | FSDP unshards parameters within dp mesh
118+
# |
119+
# Float8Linear(weight=DTensor(local_tensor=Float8Tensor,
120+
# device_mesh=DeviceMesh([0, 1], mesh_dim_names=('tp',)),
121+
# placements=(Shard(dim=0),)))
122+
# |
123+
# | TP compute with torch.mm(input, weight)
89124

90125
class WeightWithDynamicFloat8CastTensor(torch.Tensor):
91126
@staticmethod
@@ -195,8 +230,17 @@ def fsdp_post_all_gather(
195230
(data,) = all_gather_outputs
196231
(scale,) = metadata
197232
if out is not None:
198-
assert isinstance(out, Float8Tensor), f"{type(out)}"
199-
out._scale = scale
233+
from torch.distributed._tensor import DTensor
234+
if isinstance(out, Float8Tensor):
235+
out._scale = scale
236+
elif isinstance(out, DTensor) and isinstance(
237+
out._local_tensor, Float8Tensor
238+
):
239+
out._local_tensor._scale = scale
240+
else:
241+
raise RuntimeError(
242+
f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}"
243+
)
200244
return
201245
return Float8Tensor(
202246
data,

0 commit comments

Comments
 (0)