## Ring Flash Attention
This repo implements [RingAttention](https://github.com/lhao499/RingAttention) on top of [FlashAttention](https://github.com/Dao-AILab/flash-attention). It currently provides:
- `ring_flash_attn_func`: ring attention version of `flash_attn_func`
- `ring_flash_attn_varlen_func`: ring attention version of `flash_attn_varlen_func`
- `zigzag_ring_flash_attn_func`: an optimized version of `ring_flash_attn_func`, see [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
- `zigzag_ring_flash_attn_varlen_func`: an optimized version of `ring_flash_attn_varlen_func`
- `stripe_flash_attn_func`: stripe attention version of `ring_flash_attn_func`; the block size is set to 1 so that the flash_attn API can be used.
Note that
- all functions have the `*_func`, `*_kvpacked_func`, and `*_qkvpacked_func` variants implemented.
- the varlen versions only support passing one `cu_seqlens`.
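To make the difference between the three layouts concrete, here is a minimal sketch of how token indices could be assigned to ranks. It assumes the zigzag layout discussed in issue#2 (each rank takes one chunk from the front and the mirrored chunk from the back) and a stripe block size of 1. The helper names `ring_shard`, `zigzag_shard`, `stripe_shard`, and `causal_work` are hypothetical and are not part of this repo's API; the actual sharding code may differ.

```python
def ring_shard(seq_len, world_size, rank):
    """Contiguous split: rank r owns the r-th chunk of the sequence."""
    chunk = seq_len // world_size
    return list(range(rank * chunk, (rank + 1) * chunk))


def zigzag_shard(seq_len, world_size, rank):
    """Assumed zigzag split: 2 * world_size chunks, rank r owns chunk r
    and the mirrored chunk (2 * world_size - 1 - r)."""
    chunk = seq_len // (2 * world_size)
    mirror = 2 * world_size - 1 - rank
    front = range(rank * chunk, (rank + 1) * chunk)
    back = range(mirror * chunk, (mirror + 1) * chunk)
    return list(front) + list(back)


def stripe_shard(seq_len, world_size, rank):
    """Stripe split with block size 1: token t goes to rank t % world_size."""
    return list(range(rank, seq_len, world_size))


def causal_work(tokens):
    """Number of (query, key) pairs a rank computes under a causal mask."""
    return sum(t + 1 for t in tokens)


if __name__ == "__main__":
    seq_len, world_size = 16, 4
    for shard in (ring_shard, zigzag_shard, stripe_shard):
        work = [causal_work(shard(seq_len, world_size, r)) for r in range(world_size)]
        print(shard.__name__, work)
```

Under a causal mask the contiguous split gives the last rank far more work than the first, while the zigzag and stripe splits spread it much more evenly. This is the motivation for the optimized variants.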
The main idea is to use the `softmax_lse` output from the flash attention kernels.
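As an illustration of why `softmax_lse` is sufficient, the sketch below merges two normalized partial attention outputs using only their log-sum-exp values and checks the result against attention over the full key/value set. It is a pure-Python toy for a single query vector. The function names `attention_block` and `merge` are hypothetical, the `1/sqrt(d)` scaling is omitted, and this is not the repo's actual update code.

```python
import math


def attention_block(q, ks, vs):
    """Attention of one query over one block of keys/values.

    Returns (out, lse), where out is the softmax-weighted average of vs
    and lse = log(sum(exp(scores))) for this block."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in ks]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(vs[0])
    out = [sum(w * v[d] for w, v in zip(weights, vs)) / denom for d in range(dim)]
    return out, m + math.log(denom)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two normalized partial outputs using their lse values."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each block is re-weighted by its share of the total softmax mass.
    w_a = math.exp(lse_a - lse)
    w_b = math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


if __name__ == "__main__":
    q = [0.1, -0.3]
    ks = [[1.0, 0.5], [-0.2, 0.7], [0.3, 0.3], [0.9, -1.0]]
    vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    full_out, full_lse = attention_block(q, ks, vs)
    out_1, lse_1 = attention_block(q, ks[:2], vs[:2])  # first K/V block
    out_2, lse_2 = attention_block(q, ks[2:], vs[2:])  # second K/V block
    merged_out, merged_lse = merge(out_1, lse_1, out_2, lse_2)
    print(full_out, merged_out)
```

Because the merge only needs `(out, lse)` per block, each rank can process the K/V blocks it receives from the ring one at a time and fold the results together.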
The current performance on 8xH800 and 8xA100 is ([benchmark/benchmark_qkvpacked_func.py](benchmark/benchmark_qkvpacked_func.py)):
| | GPU | theoretical flash_attn | ring_attn | zigzag_ring | stripe_attn |
| -------------------- | ------ | -------------------- | --------- | ----------- | ----------- |
| fwd only (iter/sec) | 8xH800 | 2418.4 / 8 = 302.3 | 208.0 | 283.0 | 259.6 |
| | | | 68.8% | **93.6%** | 85.9% |
| fwd + bwd (iter/sec) | 8xH800 | 705.2 / 8 = 88.2 | 54.3 | 75.7 | 76.9 |
| | | | 61.5% | 85.9% | **87.2%** |
| fwd only (iter/sec) | 8xA100 | 1545.9 / 8 = 193.2 | 124.4 | 179.0 | 163.9 |
| | | | 64.3% | **92.7%** | 84.8% |
| fwd + bwd (iter/sec) | 8xA100 | 470.6 / 8 = 58.8 | 33.3 | 49.5 | 45.9 |
| | | | 56.6% | **84.1%** | 78.1% |
Note that
- when running the benchmark with 8 GPUs, the flash attn code runs with 1/8 of the computation of ring attention, which is why its throughput is divided by 8 in the table.
- NVLink between GPUs is required for high performance.
- the varlen versions are slow at the moment. Please use the non-varlen versions if possible.
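The percentage rows in the table can be recomputed from the throughput numbers. The sketch below uses the 8xH800 forward-only values from the table; the helper name `ring_efficiency` is hypothetical and not part of the benchmark script. Other rows may differ from the table in the last digit because of rounding.

```python
def ring_efficiency(flash_attn_iters_per_sec, n_gpus, ring_iters_per_sec):
    """Throughput of a ring variant as a percentage of the theoretical
    baseline, i.e. single-GPU flash_attn throughput divided by n_gpus."""
    theoretical = flash_attn_iters_per_sec / n_gpus
    return 100.0 * ring_iters_per_sec / theoretical


if __name__ == "__main__":
    # 8xH800, forward only (values copied from the table)
    flash_attn = 2418.4
    for name, iters in [("ring_attn", 208.0), ("zigzag_ring", 283.0), ("stripe_attn", 259.6)]:
        print(f"{name}: {ring_efficiency(flash_attn, 8, iters):.1f}%")
```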
### Limits
The current implementation has some numerical errors. The likely cause is that flash attention returns a bf16 value for each block, so the per-block results cannot be accumulated at the original fp32 precision.
In addition, because extra fp32 buffers must be kept during computation, the memory usage is higher than the theoretical limit.
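To give a feel for the size of the rounding involved, the sketch below emulates bf16 rounding in pure Python and compares an accumulation over bf16-rounded block outputs with the same accumulation at full precision. The helper `to_bf16` is hypothetical, handles finite values only, and the block values are made up for illustration. This is not a measurement of the repo's actual error.

```python
import struct


def to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value
    (round-to-nearest-even) and return it as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits of the fp32 encoding.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    # Made-up per-block outputs and merge weights (the weights sum to 1).
    block_outputs = [0.1234567, 0.7654321, 0.3141592, 0.2718281]
    weights = [0.4, 0.3, 0.2, 0.1]

    exact = sum(w * o for w, o in zip(weights, block_outputs))
    rounded = sum(w * to_bf16(o) for w, o in zip(weights, block_outputs))
    print("full precision:", exact)
    print("bf16 block outputs:", rounded)
    print("absolute error:", abs(exact - rounded))
```

bf16 keeps only 8 bits of significand, so each block output carries a relative error of up to about 2^-9 before it is ever accumulated, regardless of the precision of the accumulator.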
### TODOs
- [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
- [x] Implement `stripe_flash_attn_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `*_kvpacked_func` and `*_func` variant for all APIs
- [ ] Optimize `*_varlen_func`
- [ ] Try to upstream to flash attention.
### Test
```bash
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
```
### Benchmark
```bash
torchrun --nproc_per_node 8 benchmark/benchmark_qkvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_qkvpacked_func.py
```