@danielvegamyhre (Contributor) commented Feb 4, 2025

This fixes a bug in how per-op AC (activation checkpointing) interacts with float8 training. The bug surfaced as an unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be found in this thread starting here.

TL;DR

  • I tested different model layers and permutations of configurations and found the minimal repro was an FFN with per-op AC + FSDP + compile + float8 row-wise scaling.
  • I analyzed the Triton kernels generated for the compiled FFN forward pass in bf16 vs fp8, tracking peak cumulative memory allocated and what was saved for backward. I found the fp8 forward kernel was saving a huge buffer/tensor for backward which the bf16 kernel was not.
  • I tracked how the various buffers were used and determined this huge buffer saved for backward was holding abs(W3), where W3 is an unsharded weight tensor.
  • I tested removing abs() from the save list for per-op AC, and confirmed it resolved the memory issue without impacting TPS.
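The shape of the fix can be sketched as follows. This is illustrative only: string names stand in for the actual `torch.ops.aten.*` op overloads in the per-op AC save list, and only the ops quoted later in this thread are shown.

```python
# Sketch only: the real save list holds torch.ops.aten.* op overloads,
# not strings. Op names taken from the snippet discussed in this PR.
save_list = {
    "aten._scaled_dot_product_flash_attention.default",
    "_c10d_functional.reduce_scatter_tensor.default",
    "aten.abs.default",  # the bug: saves abs(W), a full weight-sized buffer
    "aten.max.default",  # the scalar amax actually needed for float8 scaling
}

# The fix: stop saving abs(). It is cheap to recompute in backward, and
# only the tiny max(abs(tensor)) result needs to be kept.
save_list -= {"aten.abs.default"}
```

With abs() recomputed on the fly, only the scalar max result is stashed for backward, which is what eliminates the weight-sized saved buffer.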

Benchmarks

float8 row-wise WITH storing abs() op output:

[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10  loss:  9.9765  memory: 58.55GiB(61.63%)  tps: 6,379  mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20  loss:  8.3610  memory: 58.55GiB(61.63%)  tps: 6,390  mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30  loss:  7.6860  memory: 58.55GiB(61.63%)  tps: 6,386  mfu: 37.39%

float8 row-wise WITHOUT storing abs() op output:

[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10  loss: 10.1946  memory: 47.77GiB(50.28%)  tps: 6,293  mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20  loss:  8.4865  memory: 47.77GiB(50.28%)  tps: 6,429  mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30  loss:  7.6982  memory: 47.77GiB(50.28%)  tps: 6,420  mfu: 37.60%
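For reference, the peak-memory delta from the logs above works out to roughly 10.8 GiB:

```python
# Peak memory values taken from the rank0 logs above (GiB).
with_abs = 58.55
without_abs = 47.77

savings_gib = with_abs - without_abs
pct_of_peak = savings_gib / with_abs
print(f"{savings_gib:.2f} GiB saved ({pct_of_peak:.1%} of peak)")
# -> 10.78 GiB saved (18.4% of peak)
```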

I also confirmed this does not cause a regression in memory usage or TPS with tensorwise scaling:

float8 tensorwise WITH storing abs() output:

[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10  loss:  9.8912  memory: 47.77GiB(50.28%)  tps: 6,881  mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20  loss:  8.4239  memory: 47.77GiB(50.28%)  tps: 6,879  mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30  loss:  7.6615  memory: 47.77GiB(50.28%)  tps: 6,877  mfu: 40.27%

float8 tensorwise WITHOUT storing abs() output:

[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10  loss:  9.9628  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20  loss:  8.5451  memory: 47.77GiB(50.28%)  tps: 6,871  mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30  loss:  7.8286  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Feb 4, 2025
@danielvegamyhre (Contributor Author)

cc @vkuzo

@vkuzo (Contributor) left a comment


awesome, thank you for finding this!

torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# for low precision training, it's useful to always save
# the result of max(abs(tensor))
Contributor

pls modify comments accordingly?

Contributor Author

Done

# for low precision training, it's useful to always save
# the result of max(abs(tensor))
torch.ops.aten.abs.default,
torch.ops.aten.max.default,
Contributor

since abs needs to be recomputed, do we still want to keep max? Asking in another way, why do we need to keep abs in the first place, since we already keep the result of max? @vkuzo

Contributor

we don't need to keep abs, I made a mistake adding it in my original PR. We just didn't see the cost of the mistake until the rowwise scaled float8 recipe.
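To make the asymmetry concrete, here is a toy pure-Python illustration (not the actual float8 code path): the abs(t) intermediate is as large as the tensor itself, while max(abs(t)) is a single scalar, which is all the float8 scale computation needs.

```python
# Toy illustration only, not the real float8 implementation.
t = [-3.0, 1.5, -0.5, 2.0]       # stand-in for an unsharded weight tensor

abs_t = [abs(x) for x in t]      # intermediate: same element count as t
amax = max(abs_t)                # final result: one scalar

# Saving abs_t for backward duplicates the whole tensor; saving amax is
# essentially free, and abs can be recomputed from t whenever needed.
assert len(abs_t) == len(t)
assert amax == 3.0
```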

@tianyu-l (Contributor) left a comment

lgtm, thank you!

@tianyu-l (Contributor) left a comment

could you please fix the CPU CI error before merging?

@danielvegamyhre (Contributor Author)

> could you please fix the CPU CI error before merging?

Yeah, it seems unrelated to my changes (`RuntimeError: 0 active drivers ([]). There should only be one.`).

I will try rerunning.

@tianyu-l (Contributor) commented Feb 5, 2025

seems unrelated, feel free to merge

@danielvegamyhre (Contributor Author)

> seems unrelated, feel free to merge

I don't have merging powers in this repo :( could you merge when you can please?

@tianyu-l tianyu-l merged commit 8824727 into pytorch:main Feb 6, 2025
5 of 6 checks passed