diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 18aebaeada..6b3c0f06df 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -159,6 +159,15 @@ def backward(ctx, grad_output): elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: weight_t_maybe_fp8_dim0 = weight_hp_t else: + if ( + c.cast_config_weight_for_grad_input.scaling_granularity + is ScalingGranularity.AXISWISE + ): + # workaround from https://github.com/pytorch/pytorch/issues/141881 + # to avoid saving float8 weight from forward to backward when + # FSDP is on + weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0) + # Note: we need https://github.com/pytorch/pytorch/issues/136267 # to be solved to have a chance to reuse max(abs(weight, dim=...)) # from the forward to get max(abs(weight)) here without reading