diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9728569ab1..ecf8332475 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -217,8 +217,8 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
     # for low precision training, it's useful to always save
-    # the result of max(abs(tensor))
-    torch.ops.aten.abs.default,
+    # the result of max, since the absolute maximum is
+    # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
 }
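
For context (not part of the diff): a minimal sketch of why the absolute maximum mentioned in the comment is worth saving under activation checkpointing. Float8-style low precision recipes derive a per-tensor scale from max(abs(tensor)), so if the checkpointing policy discards it, the reduction has to be recomputed in the backward pass. The helper names, the FP8_E4M3_MAX constant, and the particular scaling convention below are illustrative assumptions, not torchtitan's implementation.

```python
import torch

# Illustrative only: one common way a per-tensor scale is derived from the
# absolute maximum for float8-style quantization (assumed convention, not
# torchtitan code).
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def amax_to_scale(x: torch.Tensor) -> torch.Tensor:
    """Compute a per-tensor scale from the absolute maximum of x."""
    amax = x.abs().max()
    # Clamp to avoid division by zero for an all-zero tensor.
    return FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)


def quantize_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale, clamp, and cast x to float8; return the quantized tensor and scale."""
    scale = amax_to_scale(x)
    x_scaled = (x * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn), scale
```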