Remove abs from save list for per op AC to fix bug causing unexpected increase in peak memory usage for float8 training with row-wise scaling #820
Conversation
cc @vkuzo
vkuzo
left a comment
awesome, thank you for finding this!
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# for low precision training, it's useful to always save
# the result of max(abs(tensor))
pls modify comments accordingly?
Done
# for low precision training, it's useful to always save
# the result of max(abs(tensor))
torch.ops.aten.abs.default,
torch.ops.aten.max.default,
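The save list above drives a simple membership check in the per op AC policy: an op's forward output is kept if and only if the op is in the list. A minimal sketch of that check, with op names as plain strings rather than torch.ops handles (an assumption for illustration; the real list holds the op objects themselves):

```python
# Hedged sketch (not torchtitan's actual code): the per op AC save list
# as a set of op names. Ops in the list have their forward outputs
# saved; everything else is recomputed during the backward pass.
SAVE_LIST = {
    "aten._scaled_dot_product_flash_attention.default",
    "_c10d_functional.reduce_scatter_tensor.default",
    # For low precision training, only the result of max(abs(tensor))
    # is worth saving; this PR drops "aten.abs.default" from the list,
    # since abs is cheap to recompute and its output is a full tensor.
    "aten.max.default",
}

def should_save(op_name: str) -> bool:
    # Save the op's output during forward iff it is in the save list.
    return op_name in SAVE_LIST
```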
since abs needs to be recomputed, do we still want to keep max? Asking in another way, why do we need to keep abs in the first place, since we already keep the result of max? @vkuzo
we don't need to keep abs, I made a mistake adding it in my original PR. We just didn't see the cost of the mistake until the rowwise scaled float8 recipe.
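To illustrate why keeping max but not abs makes sense: float8 scaling only needs the single scalar amax (the result of max(abs(tensor))), while saving abs(tensor) keeps a full-tensor activation alive. A minimal sketch under stated assumptions (the function name and the E4M3 max of 448.0 are illustrative, not torchtitan code):

```python
def amax_to_float8_scale(amax: float, float8_max: float = 448.0) -> float:
    # Only this one scalar (the saved max(abs(tensor)) result) is needed
    # to compute the float8 scale; the full abs(tensor) output can be
    # recomputed in backward instead of being held in memory. 448.0 is
    # the largest representable value in the e4m3 format.
    return float8_max / amax if amax > 0.0 else 1.0
```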
Force-pushed from 2266bad to d607298.
tianyu-l
left a comment
lgtm, thank you!
tianyu-l
left a comment
could you please fix the CPU CI error before merging?
Yeah, it seems unrelated to my changes. I will try rerunning.
seems unrelated, feel free to merge
I don't have merging powers in this repo :( could you merge when you can please?
This fixes a bug in how per op AC interacts with float8 training. The bug surfaced as an unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be found in this thread, starting here.
TL;DR
- Per op AC was saving the output of abs(W3), where W3 is an unsharded weight tensor.
- Removed abs() from the save list for per op AC, and confirmed it resolved the memory issue without impacting TPS.

Benchmarks
float8 row-wise WITH storing abs() op output:
float8 row-wise WITHOUT storing abs() op output:
I also confirmed this does not cause a regression in memory usage or TPS with tensorwise scaling:
float8 tensorwise WITH storing abs() output:
float8 tensorwise WITHOUT storing abs() output: