[float8] add _auto_filter_for_recipe to float8 #2410
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2410
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure: as of commit ded0931 with merge base 101c039, one new job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @vkuzo for review
looks good, can we make sure the name has an underscore and also add a test before landing?
Commits ae04451 to ded0931 (Compare)
Fixes #1207

## Problem

- float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see #1207).
- RCA in #1207 found that the attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in an approximately 40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on the larger linears.
- This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

### Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe by checking the following criteria:

1. dims not divisible by 16 (hardware requirement for float8)
2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above.
3. fqn matches one of the user-defined `filter_fqns`

It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns (a usage sketch follows this description).

## Results

Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over the bf16 TP baseline).

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16:

- [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597 TPS
- [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS
- [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS
- [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9) = ~625 TPS
- [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
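As a rough illustration of the integration described above, here is a hedged sketch of how a torchtitan-style conversion could adopt the auto filter. `convert_to_float8_training` and `Float8LinearConfig` are existing `torchao.float8` entry points; the `"rowwise"` recipe name and the import path and signature of `_auto_filter_for_recipe(recipe_name, filter_fqns)` are assumptions based on this PR's description, so that call is left commented out.

```python
# Hedged sketch: wiring the auto filter into float8 training conversion.
# The _auto_filter_for_recipe import path and signature are assumptions
# based on this PR's description, so they are commented out.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# from torchao.float8 import _auto_filter_for_recipe  # assumed import path

model = nn.Sequential(
    nn.Linear(8192, 1024),   # small wk/wv-like projection: the auto filter would skip it
    nn.Linear(8192, 28672),  # large FFN-like projection: still converted
)

config = Float8LinearConfig.from_recipe_name("rowwise")

# module_filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["output"])
convert_to_float8_training(model, config=config)  # pass module_filter_fn=... once available
```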
Part of pytorch/torchtitan#1207

## Problem

The default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

### Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (#2410) which automatically filters Linears for a given float8 recipe by checking the following criteria (see the sketch after this description):

1. dims not divisible by 16 (hardware requirement for float8)
2. dim sizes below thresholds that may result in worse perf for that given recipe, using simple heuristics based on the recipe perf tables
3. fqn matches one of the user-defined `filter_fqns`

I integrated a PoC into torchtitan, and the auto filter improved fp8 rowwise perf in both a local Llama3 8b run and a Llama3 70b MAST run, compared to the default filter_fn we have now.

It prevents users from hitting this common footgun, while also preserving the flexibility to define model-specific fqns.

## Results

See pytorch/torchtitan#1207 for Llama3 70b results; TL;DR is that filtering wk and wv improves TPS ~10% for vanilla TP and ~15% for async TP.
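To make the three criteria above concrete, the following is a minimal sketch of that filtering logic, not the actual torchao implementation: the function name, the threshold values, and the way `filter_fqns` is bound are illustrative assumptions, while the `(module, fqn) -> bool` callback shape mirrors torchao's `module_filter_fn` convention.

```python
# Minimal sketch of the filtering criteria described above (NOT the torchao code).
# Threshold values and the helper name are placeholders, not the PR's heuristics.
from typing import List

import torch.nn as nn


def example_auto_filter(module: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
    """Return True if `module` should be converted to float8 (rowwise-style recipe)."""
    if not isinstance(module, nn.Linear):
        return False

    # Criterion 3: fqn matches one of the user-defined filter_fqns -> skip conversion.
    if any(skip_fqn in fqn for skip_fqn in filter_fqns):
        return False

    out_features, in_features = module.out_features, module.in_features

    # Criterion 1: dims not divisible by 16 -> float8 kernels cannot be used.
    if out_features % 16 != 0 or in_features % 16 != 0:
        return False

    # Criterion 2: dims below recipe-specific thresholds -> quantization overhead
    # likely outweighs the matmul speedup (these numbers are placeholders).
    MIN_IN_FEATURES, MIN_OUT_FEATURES = 4096, 1024
    if in_features < MIN_IN_FEATURES or out_features < MIN_OUT_FEATURES:
        return False

    return True


# A caller would bind filter_fqns first so the result matches the
# (module, fqn) -> bool shape expected by module_filter_fn, e.g.:
# module_filter_fn = functools.partial(example_auto_filter, filter_fqns=["output"])
```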