Commit 7a04053
[not for land yet] example of float8 with rowwise scaling
Summary:
This is an example of how to call torchao's float8 training with rowwise
scaling from torchtitan.
TODO: finalize the API in torchao, decide how to expose it in torchtitan,
and optimize performance.
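For readers unfamiliar with the term, here is a minimal sketch of how rowwise scaling differs from tensorwise scaling. It uses plain Python lists and is not torchao's implementation; the helper names (`tensorwise_scale`, `rowwise_scales`) and the `EPS` guard are hypothetical. The only outside fact it relies on is that the largest finite value of the float8 e4m3 format is 448.

```python
# Illustrative only: plain-Python scale computation, not torchao's code path.
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3
EPS = 1e-12           # hypothetical guard against division by zero


def tensorwise_scale(matrix):
    """One scale for the whole matrix, derived from its global abs-max."""
    amax = max(abs(v) for row in matrix for v in row)
    return FP8_E4M3_MAX / max(amax, EPS)


def rowwise_scales(matrix):
    """One scale per row, derived from that row's own abs-max."""
    return [FP8_E4M3_MAX / max(max(abs(v) for v in row), EPS) for row in matrix]


x = [
    [0.5, -1.0, 0.25],       # small-magnitude row
    [100.0, -224.0, 56.0],   # large-magnitude row
]

print(tensorwise_scale(x))  # 2.0
print(rowwise_scales(x))    # [448.0, 2.0]
```

With a single tensorwise scale, the small-magnitude row is multiplied by 2.0 and uses only a sliver of the float8 range. With rowwise scales, each row is stretched to the full range independently, at the cost of computing and storing one scale per row.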
```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%
// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62%
// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```
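To make the comparison easier to read, the throughput and memory numbers from the log lines above can be reduced to ratios against the bf16 baseline. The following is a small sketch; the dictionary keys, the regular expression, and the function names are illustrative choices, while the log lines are copied from the runs above.

```python
import re

# Log lines copied verbatim from the runs above; the keys are illustrative labels.
LOG_LINES = {
    "bf16 baseline": "step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%",
    "rowwise (torchao main)": "step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%",
    "rowwise (pytorch/ao#1629)": "step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62%",
    "tensorwise + fp8 all-gather": "step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%",
}

PATTERN = re.compile(r"memory:\s*([\d.]+)GiB.*?tps:\s*([\d,]+)")


def parse(line):
    """Return (memory in GiB, tokens per second) from one log line."""
    match = PATTERN.search(line)
    return float(match.group(1)), int(match.group(2).replace(",", ""))


def summarize(lines, baseline_key):
    """Map each run to (tps speedup vs. baseline, memory delta in GiB)."""
    base_mem, base_tps = parse(lines[baseline_key])
    summary = {}
    for name, line in lines.items():
        mem, tps = parse(line)
        summary[name] = (round(tps / base_tps, 3), round(mem - base_mem, 2))
    return summary


for name, (speedup, mem_delta) in summarize(LOG_LINES, "bf16 baseline").items():
    print(f"{name}: {speedup:.3f}x tps, {mem_delta:+.2f} GiB")
```

On these numbers, rowwise float8 is about 1.11x to 1.12x the baseline throughput while using roughly 11 to 19 GiB more memory, and tensorwise float8 with float8 all-gather is about 1.25x at slightly lower memory. Note that the torchao main rowwise line is reported at step 40 rather than step 20, so its loss is not directly comparable to the others.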
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6cb13c7 commit 7a04053
File tree
3 files changed, +66 -26 lines changed
- torchtitan
  - components
  - models/llama