## 1.2 Optimizing Attention with FlashAttention-2
### 1.2.1 Benchmarking PyTorch Attention

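The task is to report the timings (or out-of-memory errors) for each configuration. I do not get any out-of-memory errors. The measured memory matches a back-of-the-envelope estimate. The unfused PyTorch attention keeps two `seq_len × seq_len` matrices per batch element (e.g. the attention scores and the softmax probabilities), which is `batch_size * seq_len * seq_len * 2` values. Multiply by 4 bytes per float32 value and divide by 1024^3 to convert bytes to GB.

Memory therefore scales quadratically with sequence length: each doubling of `seq_len` roughly quadruples memory (4.1 GB at 8192 vs. 16.1 GB at 16384), while `d_model` has comparatively little effect.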
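
The memory estimate can be written as a small function. This is a sketch: the name `attention_activation_gib` is hypothetical, and `batch_size = 8` is an assumption. The batch size is not stated here, but 8 is the value for which the formula reproduces the measured ~16.1 GB at `seq_len` = 16384.

```python
def attention_activation_gib(batch_size: int, seq_len: int,
                             n_matrices: int = 2, bytes_per_elem: int = 4) -> float:
    """GiB needed for n_matrices tensors of shape (batch_size, seq_len, seq_len)."""
    n_bytes = batch_size * seq_len * seq_len * n_matrices * bytes_per_elem
    return n_bytes / 1024**3


# ASSUMPTION: batch_size = 8 (not stated in the text; it is the value that matches the table).
for seq_len in (4096, 8192, 16384):
    print(seq_len, attention_activation_gib(8, seq_len))  # -> 1.0, 4.0, 16.0
```

The estimate accounts for almost all of the measured memory at long sequence lengths; the small remainder, which dominates at short lengths, is presumably inputs, weights and other buffers.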

<details>
<summary><strong>Results Table</strong></summary>

| d_model | seq_len | Forward (ms) | Backward (ms) | Memory (GB) | Status   |
|---------|---------|--------------|---------------|-------------|----------|
| 16      | 256     | 0.112        | 0.288         | 0.067       | success  |
| 16      | 1024    | 0.228        | 0.585         | 0.127       | success  |
| 16      | 4096    | 2.535        | 6.091         | 1.072       | success  |
| 16      | 8192    | 9.946        | 23.730        | 4.085       | success  |
| 16      | 16384   | 38.974       | 93.947        | 16.107      | success  |
| 32      | 256     | 0.107        | 0.275         | 0.091       | success  |
| 32      | 1024    | 0.241        | 0.609         | 0.130       | success  |
| 32      | 4096    | 2.703        | 6.166         | 1.081       | success  |
| 32      | 8192    | 10.255       | 23.993        | 4.106       | success  |
| 32      | 16384   | 40.150       | 95.034        | 16.150      | success  |
| 64      | 256     | 0.107        | 0.275         | 0.115       | success  |
| 64      | 1024    | 0.262        | 0.623         | 0.134       | success  |
| 64      | 4096    | 2.968        | 6.686         | 1.100       | success  |
| 64      | 8192    | 11.466       | 26.377        | 4.149       | success  |
| 64      | 16384   | 45.483       | 104.839       | 16.236      | success  |
| 128     | 256     | 0.107        | 0.276         | 0.164       | success  |
| 128     | 1024    | 0.300        | 0.704         | 0.144       | success  |
| 128     | 4096    | 3.562        | 7.916         | 1.137       | success  |
| 128     | 8192    | 13.772       | 30.826        | 4.235       | success  |
| 128     | 16384   | 54.428       | 122.644       | 16.408      | success  |

</details>
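
Timings like these are usually collected with a warmup-then-measure loop. The sketch below is a generic, framework-agnostic version and not necessarily how the numbers above were produced; the function `benchmark` and its return keys are hypothetical, named to mirror the `warmup_steps`, `benchmark_steps`, `avg` and `std` fields in the model results later in this section.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup_steps: int = 5,
              benchmark_steps: int = 5) -> dict:
    """Call fn() warmup_steps times untimed, then time benchmark_steps calls (seconds).

    For GPU work, fn should end with torch.cuda.synchronize(): CUDA kernel launches are
    asynchronous, so without it the timer would stop before the GPU has finished.
    """
    for _ in range(warmup_steps):  # untimed: absorbs compilation, caching, autotuning
        fn()
    times = []
    for _ in range(benchmark_steps):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return {
        "warmup_steps": warmup_steps,
        "benchmark_steps": benchmark_steps,
        "avg": statistics.mean(times),
        "std": statistics.pstdev(times),  # population std, like numpy.std's default
    }


print(benchmark(lambda: sum(range(100_000))))
```

Warmup matters especially for `torch.compile`, whose first calls include compilation time.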


## 1.3 Benchmarking JIT-Compiled Attention


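Using `torch.compile` on the attention function gives the results below. For `seq_len` ≥ 4096 the compiled version is roughly 1.6× to 2.5× faster in both the forward and backward passes, with the largest gains at small `d_model` (e.g. 38.974 ms → 15.773 ms forward at `d_model` = 16, `seq_len` = 16384). At short sequence lengths the gain is small, and the compiled forward pass is sometimes slightly slower (e.g. 0.107 ms → 0.128 ms at `d_model` = 32, `seq_len` = 256). Peak memory is essentially unchanged, so compilation does not remove the quadratic cost of materializing the attention matrix.
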
<details>
<summary><strong>Results Table</strong></summary>

| d_model | seq_len | Forward (ms) | Backward (ms) | Memory (GB) | Status   |
|---------|---------|--------------|---------------|-------------|----------|
| 16      | 256     | 0.100        | 0.210         | 0.067       | success  |
| 16      | 1024    | 0.173        | 0.338         | 0.127       | success  |
| 16      | 4096    | 1.065        | 2.515         | 1.072       | success  |
| 16      | 8192    | 4.263        | 9.728         | 4.085       | success  |
| 16      | 16384   | 15.773       | 40.295        | 16.107      | success  |
| 32      | 256     | 0.128        | 0.187         | 0.091       | success  |
| 32      | 1024    | 0.177        | 0.352         | 0.130       | success  |
| 32      | 4096    | 1.431        | 2.790         | 1.082       | success  |
| 32      | 8192    | 5.636        | 10.752        | 4.106       | success  |
| 32      | 16384   | 17.487       | 40.146        | 16.150      | success  |
| 64      | 256     | 0.130        | 0.186         | 0.115       | success  |
| 64      | 1024    | 0.296        | 0.385         | 0.134       | success  |
| 64      | 4096    | 1.680        | 3.144         | 1.100       | success  |
| 64      | 8192    | 5.895        | 12.360        | 4.149       | success  |
| 64      | 16384   | 22.951       | 49.938        | 16.236      | success  |
| 128     | 256     | 0.127        | 0.186         | 0.164       | success  |
| 128     | 1024    | 0.333        | 0.466         | 0.144       | success  |
| 128     | 4096    | 2.283        | 4.368         | 1.137       | success  |
| 128     | 8192    | 8.201        | 16.820        | 4.235       | success  |
| 128     | 16384   | 31.853       | 67.688        | 16.408      | success  |

</details>
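
The speedup is the ratio of eager to compiled time for the same configuration. A small sketch, with a few rows copied by hand from the two tables above (the dictionaries and the helper `speedups` are hypothetical names used only for illustration):

```python
# (d_model, seq_len) -> (forward_ms, backward_ms), copied from the two tables above.
eager = {
    (16, 16384): (38.974, 93.947),
    (128, 16384): (54.428, 122.644),
    (32, 256): (0.107, 0.275),
}
compiled = {
    (16, 16384): (15.773, 40.295),
    (128, 16384): (31.853, 67.688),
    (32, 256): (0.128, 0.187),
}


def speedups(eager, compiled):
    """Return {config: (forward_speedup, backward_speedup)}; > 1 means compiled is faster."""
    return {
        key: tuple(e / c for e, c in zip(eager[key], compiled[key]))
        for key in eager
    }


for (d_model, seq_len), (fwd, bwd) in speedups(eager, compiled).items():
    print(f"d_model={d_model:>3}, seq_len={seq_len:>5}: fwd {fwd:.2f}x, bwd {bwd:.2f}x")
```

Values below 1 (as for the forward pass at `seq_len` = 256) mean the compiled version is slower.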

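Let us now use `torch.compile` on the entire Transformer model. This also helps, but less than compiling attention alone: the forward pass speeds up by about 1.72× for the small model down to about 1.09× for 2.7B, and a full step (`forward_only=False`) by about 1.31× down to about 1.07×. The benefit shrinks as the model grows, plausibly because larger models spend more of their time in large matrix multiplications that are already efficient without compilation. In the results, `avg` and `std` are wall-clock times per step in seconds.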

```python
import numpy as np
import pandas as pd
from IPython.display import display  # implicit inside Jupyter; imported so the cell also runs as a script


def show(title, results):
    """Print a title and render one benchmark result set as a table (one row per model)."""
    print(title)
    display(pd.DataFrame(results).T)


results = {
    'small': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.009206475573591888), 'std': np.float64(7.099247922736725e-05)},
    'medium': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.02347335300873965), 'std': np.float64(2.0991335990701267e-05)},
    'large': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.055189921008422974), 'std': np.float64(1.4823861521927467e-05)},
    'xl': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.1119085440179333), 'std': np.float64(4.086507445538356e-05)},
    '2.7B': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.16058695279061794), 'std': np.float64(8.178852687122767e-06)},
}
show("## Forward Only, Compile", results)

results = {
    'small': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.0158657583873719), 'std': np.float64(6.784736402343876e-05)},
    'medium': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.030838705226778985), 'std': np.float64(0.0002046254887764915)},
    'large': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.06573201038409024), 'std': np.float64(0.00022784046083677796)},
    'xl': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.1278103150194511), 'std': np.float64(0.00018463351561487138)},
    '2.7B': {'forward_only': True, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.17442646059207617), 'std': np.float64(3.606307832233569e-05)},
}
show("## Forward Only, No Compile", results)

results = {
    'small': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.036864083562977615), 'std': np.float64(0.00014968863947294376)},
    'medium': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.09583398420363665), 'std': np.float64(0.00012088628142813287)},
    'large': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.22588904437143356), 'std': np.float64(0.0002103569821551783)},
    'xl': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.4523045166162774), 'std': np.float64(0.0003389353421556619)},
    '2.7B': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.6798262841999531), 'std': np.float64(0.00010838741021184491)},
}
show("## Full Step, Compile", results)

results = {
    'small': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.0484661613823846), 'std': np.float64(0.0003632654241477262)},
    'medium': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.11584463163744659), 'std': np.float64(0.0005004669921443062)},
    'large': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.25648152017965914), 'std': np.float64(0.0002820108995609013)},
    'xl': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.49998851676937195), 'std': np.float64(0.00021257470922022222)},
    '2.7B': {'forward_only': False, 'warmup_steps': 5, 'benchmark_steps': 5, 'avg': np.float64(0.7249885434051976), 'std': np.float64(0.0010648416588595853)},
}
show("## Full Step, No Compile", results)
```
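
<details>
<summary><strong>Results Tables</strong></summary>

**Forward Only, Compile**

| model  | forward_only | warmup_steps | benchmark_steps | avg (s)  | std (s) |
|--------|--------------|--------------|-----------------|----------|---------|
| small  | True         | 5            | 5               | 0.009206 | 7.1e-05 |
| medium | True         | 5            | 5               | 0.023473 | 2.1e-05 |
| large  | True         | 5            | 5               | 0.05519  | 1.5e-05 |
| xl     | True         | 5            | 5               | 0.111909 | 4.1e-05 |
| 2.7B   | True         | 5            | 5               | 0.160587 | 8e-06   |

**Forward Only, No Compile**

| model  | forward_only | warmup_steps | benchmark_steps | avg (s)  | std (s)  |
|--------|--------------|--------------|-----------------|----------|----------|
| small  | True         | 5            | 5               | 0.015866 | 6.8e-05  |
| medium | True         | 5            | 5               | 0.030839 | 0.000205 |
| large  | True         | 5            | 5               | 0.065732 | 0.000228 |
| xl     | True         | 5            | 5               | 0.12781  | 0.000185 |
| 2.7B   | True         | 5            | 5               | 0.174426 | 3.6e-05  |

**Full Step, Compile**

| model  | forward_only | warmup_steps | benchmark_steps | avg (s)  | std (s)  |
|--------|--------------|--------------|-----------------|----------|----------|
| small  | False        | 5            | 5               | 0.036864 | 0.00015  |
| medium | False        | 5            | 5               | 0.095834 | 0.000121 |
| large  | False        | 5            | 5               | 0.225889 | 0.00021  |
| xl     | False        | 5            | 5               | 0.452305 | 0.000339 |
| 2.7B   | False        | 5            | 5               | 0.679826 | 0.000108 |

**Full Step, No Compile**

| model  | forward_only | warmup_steps | benchmark_steps | avg (s)  | std (s)  |
|--------|--------------|--------------|-----------------|----------|----------|
| small  | False        | 5            | 5               | 0.048466 | 0.000363 |
| medium | False        | 5            | 5               | 0.115845 | 0.0005   |
| large  | False        | 5            | 5               | 0.256482 | 0.000282 |
| xl     | False        | 5            | 5               | 0.499989 | 0.000213 |
| 2.7B   | False        | 5            | 5               | 0.724989 | 0.001065 |

</details>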
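
The per-model benefit of compiling the whole model is the ratio of uncompiled to compiled `avg` step time. A small sketch, with values copied from the tables above and rounded to six decimals; the helper `speedup_by_model` is a hypothetical name used only for illustration.

```python
# model -> (no_compile_avg_s, compile_avg_s), copied from the four tables above.
forward = {
    "small": (0.015866, 0.009206),
    "medium": (0.030839, 0.023473),
    "large": (0.065732, 0.055190),
    "xl": (0.127810, 0.111909),
    "2.7B": (0.174426, 0.160587),
}
full_step = {
    "small": (0.048466, 0.036864),
    "medium": (0.115845, 0.095834),
    "large": (0.256482, 0.225889),
    "xl": (0.499989, 0.452305),
    "2.7B": (0.724989, 0.679826),
}


def speedup_by_model(table):
    """Return {model: eager_time / compiled_time}; > 1 means compiling helps."""
    return {model: eager / compiled for model, (eager, compiled) in table.items()}


for name, table in (("forward only", forward), ("full step", full_step)):
    for model, ratio in speedup_by_model(table).items():
        print(f"{name:>12} {model:>6}: {ratio:.2f}x")
```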
