From 066f88974a5fbd7bdcc8e2a25e3d31926f327a16 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 27 Jan 2025 13:49:22 -0800
Subject: [PATCH] float8 rowwise training: add FSDP workaround

Summary:

Adds the workaround from https://github.com/pytorch/pytorch/issues/141881
to the torchao float8 rowwise recipe, to reduce memory usage when FSDP
is on.

Test Plan:

Tested in torchtitan: LLaMa 3 8B training on 8 H100s with the rowwise
recipe, peak memory decreased from 67 GiB to 59 GiB.

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/float8/float8_linear.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 18aebaeada..6b3c0f06df 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -159,6 +159,15 @@ def backward(ctx, grad_output):
         elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
             weight_t_maybe_fp8_dim0 = weight_hp_t
         else:
+            if (
+                c.cast_config_weight_for_grad_input.scaling_granularity
+                is ScalingGranularity.AXISWISE
+            ):
+                # workaround from https://github.com/pytorch/pytorch/issues/141881
+                # to avoid saving float8 weight from forward to backward when
+                # FSDP is on
+                weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0)
+
             # Note: we need https://github.com/pytorch/pytorch/issues/136267
             # to be solved to have a chance to reuse max(abs(weight, dim=...))
             # from the forward to get max(abs(weight)) here without reading
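
A minimal, hypothetical sketch of the pattern the added lines use, shown
outside of torchao for clarity: the extra term `grad_output_reshaped[0, 0] * 0`
is numerically a no-op, but it makes the weight (and any low-precision cast
derived from it) depend on grad_output, which only exists in the backward
pass; per the patch's inline comment, the intent is to avoid saving the float8
weight from forward to backward when FSDP is on. In the sketch below, the name
`ZeroDepMatmul` is invented, a bfloat16 cast stands in for torchao's float8
cast, and plain torch.mm stands in for the scaled float8 matmuls; it is an
illustration of the trick, not the torchao implementation.

    import torch

    class ZeroDepMatmul(torch.autograd.Function):
        """Matmul that recreates a low-precision weight copy in backward."""

        @staticmethod
        def forward(ctx, input, weight_hp):
            # save only the high-precision weight; the low-precision copy used
            # for grad_input is recreated in backward
            ctx.save_for_backward(input, weight_hp)
            return torch.mm(input, weight_hp.t())

        @staticmethod
        def backward(ctx, grad_output):
            input, weight_hp = ctx.saved_tensors
            # numerically a no-op, but it ties the weight to grad_output, which
            # exists only in backward, so the cast below is a backward-time
            # computation rather than something saved from forward
            weight_hp = weight_hp + (grad_output[0, 0] * 0)
            weight_lowp = weight_hp.to(torch.bfloat16)  # stand-in for the float8 cast
            grad_input = torch.mm(grad_output, weight_lowp.to(grad_output.dtype))
            grad_weight = torch.mm(grad_output.t(), input)
            return grad_input, grad_weight

    # usage: shapes are arbitrary, gradients flow to both input and weight
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    out = ZeroDepMatmul.apply(x, w)
    out.sum().backward()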