Merged
4 changes: 2 additions & 2 deletions — torchtitan/parallelisms/parallelize_llama.py

```diff
@@ -217,8 +217,8 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
     # for low precision training, it's useful to always save
-    # the result of max(abs(tensor))
-    torch.ops.aten.abs.default,
+    # the result of max, since the absolute maximum is
+    # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
```
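For context, the list being edited is a "save list" for selective activation checkpointing: the outputs of ops named in the set are stored during the forward pass, while everything else is recomputed during backward. A minimal sketch of that policy idea, using plain strings as stand-ins for the real `torch.ops` handles (the names and `policy` helper here are illustrative, not torchtitan's actual API):

```python
# Illustrative selective-checkpointing policy: ops in the save set keep
# their forward outputs; all other ops are recomputed in backward.
# Strings are stand-ins for torch.ops.* overload handles.

SAVE_LIST = {
    "aten._scaled_dot_product_flash_attention.default",
    "_c10d_functional.reduce_scatter_tensor.default",
    # the max result is cheap to store and feeds the float8 scale
    "aten.max.default",
}

def policy(op_name: str) -> str:
    """Decide whether an op's output is saved or recomputed."""
    return "SAVE" if op_name in SAVE_LIST else "RECOMPUTE"
```

Under this policy, `aten.max.default` is saved while `aten.abs.default` (after this PR) falls through to the recompute path.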
Contributor:
Since abs now needs to be recomputed, do we still want to keep max? Asking another way: why did we need to keep abs in the first place, given that we already keep the result of max? @vkuzo

Contributor:
We don't need to keep abs; I added it by mistake in my original PR. We just didn't see the cost of that mistake until the rowwise-scaled float8 recipe.
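To see why only the max result matters, here is a hedged sketch (not torchtitan's or torchao's actual code) of how an absolute maximum (amax) turns into a per-row float8 scaling factor. The `FP8_E4M3_MAX` constant and helper names are assumptions for illustration; 448.0 is the largest finite value representable in the float8 e4m3 format:

```python
# Sketch: rowwise float8 scaling from the saved amax.
# Only max(abs(row)) is needed downstream, so saving abs separately
# buys nothing once the max result is saved.

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3 value

def rowwise_scales(rows):
    """One scale per row, mapping the row's amax onto the fp8 range."""
    scales = []
    for row in rows:
        amax = max(abs(v) for v in row)  # the max(abs(tensor)) result
        scales.append(FP8_E4M3_MAX / amax if amax != 0 else 1.0)
    return scales

def quantize_row(row, scale):
    """Simulated quantization: scale, then clamp to the fp8 range."""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale)) for v in row]

rows = [[0.5, -2.0, 1.0], [100.0, -448.0, 7.0]]
scales = rowwise_scales(rows)
```

The intermediate `abs` tensor is only a stepping stone to `amax`, which is why saving it added memory cost without avoiding any recomputation that matters.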
