
Commit 7a04053

[not for land yet] example of float8 with rowwise scaling
Summary: This is an example of how to call float8 training with rowwise scaling from torchao.

TODO: finalize the API in torchao, decide how we want to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20  loss:  8.3823  memory: 58.55GiB(61.63%)  tps: 6,424  mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 6cb13c7 commit 7a04053
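For context, here is a minimal sketch (not part of this commit) of the torchao-level call path that this change wires up. It assumes torchao v0.9.0+, where `Float8LinearConfig.from_recipe_name` is available; the toy model, shapes, and dtype are illustrative only.

```python
# Minimal sketch, not torchtitan code: build a rowwise float8 config from a recipe
# name and swap nn.Linear modules for Float8Linear, roughly what the handler in
# torchtitan/components/float8.py does. Assumes torchao v0.9.0+ and a CUDA device.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# toy stand-in for the transformer; the real flow swaps the linears inside llama3_8b
model = nn.Sequential(
    nn.Linear(4096, 4096), nn.SiLU(), nn.Linear(4096, 4096)
).to(device="cuda", dtype=torch.bfloat16)

# "rowwise" scales per row; "rowwise_with_gw_hp" additionally keeps the grad_weight
# computation in higher precision
config = Float8LinearConfig.from_recipe_name("rowwise")

# mutates the model in place, replacing nn.Linear instances with Float8Linear
convert_to_float8_training(model, config=config)

# rowwise scaling is intended to be used together with torch.compile for performance
model = torch.compile(model)
```

In torchtitan, this path is driven by the new `--float8.recipe_name` flag added in the diffs below.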

File tree: 3 files changed (+66, -26 lines)

torchtitan/components/float8.py

Lines changed: 36 additions & 15 deletions
```diff
@@ -49,25 +49,46 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e

-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_shard_enabled
-            and float8_config.enable_fsdp_float8_all_gather
-        )
-        self.config = Float8LinearConfig(
-            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
-        )
+        if float8_config.recipe_name is not None and not hasattr(
+            Float8LinearConfig, "from_recipe_name"
+        ):
+            logger.warning(
+                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
+                + "is too old, please install torchao v0.9.0 or later and try again",
+            )
+            return

         self.enabled = True

-        # for precompute_float8_dynamic_scale_for_fsdp
-        self.precompute_scale = (
-            enable_fsdp_float8_all_gather
-            and float8_config.precompute_float8_dynamic_scale_for_fsdp
-        )
+        if float8_config.recipe_name is not None:
+            assert (
+                not float8_config.enable_fsdp_float8_all_gather
+            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
+            assert (
+                not float8_config.force_recompute_fp8_weight_in_bwd
+            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
+            self.precompute_scale = False
+            logger.info(
+                f"Float8 training active with recipe {float8_config.recipe_name}"
+            )

-        logger.info("Float8 training active")
+        else:
+            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            enable_fsdp_float8_all_gather = (
+                parallel_dims.dp_shard_enabled
+                and float8_config.enable_fsdp_float8_all_gather
+            )
+            self.config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
+            )
+            # for precompute_float8_dynamic_scale_for_fsdp
+            self.precompute_scale = (
+                enable_fsdp_float8_all_gather
+                and float8_config.precompute_float8_dynamic_scale_for_fsdp
+            )
+            logger.info("Float8 tensorwise scaled training active")

     def convert(self, model: nn.Module):
         return self.convert_to_float8_training(model)
```

torchtitan/config_manager.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -613,20 +613,30 @@ def __init__(self):
         self.parser.add_argument(
             "--float8.enable_fsdp_float8_all_gather",
             action="store_true",
-            help="Whether enable float8 all-gather in FSDP",
+            help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
         )
         self.parser.add_argument(
             "--float8.precompute_float8_dynamic_scale_for_fsdp",
             action="store_true",
-            help="Whether precompute float8 scales dynamically for FSDP",
+            help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
         )
         self.parser.add_argument(
             "--float8.force_recompute_fp8_weight_in_bwd",
             action="store_true",
             help="""
             Whether to force the recomputation of FP8 weights during backward pass.
-            When using FSDP, it is recommended to enable `force_recompute_fp8_weight_in_bwd`
-            to prevent saving unsharded FP8 weights for backward computation.
+            When using FSDP with tensorwise scaling, it is recommended to enable
+            `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
+            for backward computation.
+            """,
+        )
+        self.parser.add_argument(
+            "--float8.recipe_name",
+            type=str,
+            default=None,
+            help="""
+            If specified, creates float8 config from recipe name, valid choices are
+            `rowwise` and `rowwise_with_gw_hp`.
             """,
         )
```
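To make the interplay between the new `--float8.recipe_name` flag and the existing tensorwise-oriented flags concrete, here is a small standalone sketch; the parser and example values are hypothetical (not torchtitan's actual config plumbing), and the mutual-exclusion checks mirror the asserts added in torchtitan/components/float8.py above.

```python
import argparse

# Hypothetical standalone parser mirroring the flags above; torchtitan's real
# wiring goes through JobConfig in config_manager.py.
parser = argparse.ArgumentParser()
parser.add_argument("--float8.enable_fsdp_float8_all_gather", action="store_true")
parser.add_argument("--float8.force_recompute_fp8_weight_in_bwd", action="store_true")
parser.add_argument("--float8.recipe_name", type=str, default=None)

args = vars(parser.parse_args(["--float8.recipe_name", "rowwise"]))

# recipe-based (rowwise) configs do not compose with the tensorwise-only flags
if args["float8.recipe_name"] is not None:
    assert not args["float8.enable_fsdp_float8_all_gather"]
    assert not args["float8.force_recompute_fp8_weight_in_bwd"]
```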

torchtitan/models/llama/parallelize_llama.py

Lines changed: 16 additions & 7 deletions
```diff
@@ -56,12 +56,23 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
+
         enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        # For now, float8 all-gather with TP is only supported for tensorwise
+        # float8 scaling recipes. For rowwise recipes, we use regular TP and
+        # all-gather happens in high precision.
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+
         apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=enable_float8_linear,
+            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )

@@ -115,7 +126,7 @@ def apply_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
-    enable_float8: bool,
+    enable_float8_tensorwise_tp: bool,
     enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
@@ -141,10 +152,8 @@ def apply_tp(
     )

     # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    if enable_float8:
-        # TODO(vkuzo): once float8 configuration supports delayed scaling,
-        # add a check here to enforce supported float8 all-gather configurations
+    # inputs may be different for float8 linears with tensorwise scaling.
+    if enable_float8_tensorwise_tp:
         # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
@@ -202,7 +211,7 @@ def apply_tp(
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)

     logger.info(
-        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
```
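As background for the comment in the hunk above, here is a rough illustration of the two scaling granularities; the helpers below are my own sketch, not torchao's implementation. Tensorwise dynamic scaling derives a single scale from the whole tensor's absmax, while rowwise scaling derives one scale per row, which is presumably what makes a sharded float8 all-gather less straightforward for the rowwise recipes.

```python
import torch

F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # one scale for the whole tensor, derived from its global absmax
    return F8_MAX / x.abs().max().clamp(min=1e-12)

def rowwise_scales(x: torch.Tensor) -> torch.Tensor:
    # one scale per row, derived from each row's absmax
    return F8_MAX / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)

w = torch.randn(8, 16)
print(tensorwise_scale(w).shape)  # torch.Size([])   -- a single scalar scale
print(rowwise_scales(w).shape)    # torch.Size([8, 1]) -- one scale per row
```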
