# Flash Attention Deepdive

## Overview
Flash attention has evolved over the last 3 years with v1, v2 and v3. A table with the key differences is provided at the end of this notebook. 
- Analysis and Limitations of Vanilla attention
- Softmax and Online Softmax
- More on 
- Vanilla attention deepdive
- Flash Attention

Advanced topics to discuss:
- Flex attention
- DeepSeek attention methods
- Other customizations from Character AI etc.

### Why Flash Attention?
* Vanilla attention is $O(N^2)$ in runtime and memory. As the need for larger and larger sequence lengths grows, this quadratic relationship becomes a big bottleneck.
* Prior attempts have mostly targeted FLOPs reduction (e.g. sparse attention) using modified versions of attention, without taking SW-HW co-design into account.
    * HBM is an order of magnitude slower than SMEM.


### How
Flash Attention V1 is inspired by the paper `Data Movement Is All You Need`. This paper analyzes the data movement and storage patterns of attention ops. The key takeaway is that vanilla attention is memory-bound (figure 1).

* **IO Aware**: Takes into account the GPU memory hierarchy and compute architecture. Basically, move beyond the mathematical / PyTorch formulation of attention and think critically about the memory and compute properties of the target devices (MI300X, H100, H200 etc).
    * **Tiling**: To take full advantage of the various memory capacities and bandwidths (HBM, SMEM, L2 cache, registers, etc), the attention algorithm is reorganized in a way that doesn't change the end result (i.e. it stays exact). It uses "math tricks" to make the algorithm tiling compatible. **NOTE:** By default the attention function is not "tilable" due to the non-associative and non-distributive nature of the softmax function: each row needs its max and its sum of exponentials over the whole row. More on this later.

* **Activation Recomputation**
How do you get here? By combining the data movement paper with facts/research observations on which parts (elementwise vs matmul) are memory bound.


In simpler terms, the flash attention algorithm provides a **fused kernel implementation** of the original attention function $$O = \text{dropout}(\text{softmax}(\text{mask}(QK^\top)))V$$.

Formally, Dao et al. introduced FlashAttention, a novel tiling strategy for parallelizing attention that eliminates intermediate reads/writes to slow global memory through fusing all of the attention operations into a single GPU kernel. Dao [15] restructured the algorithm as FlashAttention-2 to also parallelize over the sequence length dimension and perform the inner loop of the forward pass over blocks of the key and value matrices, thus improving the occupancy and distribution of work on the GPU.

## Outcome 
As a result of Flash Attention v1:
* Wall clock speedup of 2-4x over optimized baselines
* O(N) in memory
* About 25-50% FLOPS utilization (relative to the theoretical max), versus a mere 10% FLOPS utilization with vanilla attention. Note that this utilization is still low if you compare it to GEMMs. We'll address some of this in **Flash Attention v2**.

## Revisit GPU Fundamentals

### GPU Memory Hierarchy
| Memory Level         | H100                   | MI300X                 | Description                                                                                  |
|-----------------------|-------------------------|-------------------------|----------------------------------------------------------------------------------------------|
| Registers            | 256KB per SM, ~90 TB/s  | 256KB per CU, ~100 TB/s| Fastest memory, private to each thread. Stores variables and intermediate values for execution.|
| Shared Memory (SMEM) | up to 228 KB per SM, ~50 TB/s  | 64 KB (LDS) per CU, ~60 TB/s | Programmable memory local to a thread block. Used for data sharing and reuse within a block. NVIDIA combines SMEM and the L1 cache into a unified 256 KB region per SM, and the programmer can control the split. For applications with heavy thread-level data sharing (e.g., matrix multiplication), more of this 256 KB can be allocated as shared memory. |
| L1 Data Cache             | 128KB per SM, ~50 TB/s  | 128KB per CU, ~60 TB/s | Automatically caches frequently accessed local and global memory. Shared within an SM/CU.    |
| L2 Cache             | 50MB, 37.5 TB/s         | 256MB for 8 XCDs, 50 TB/s          | Last-level cache before accessing HBM. Shared across all SMs/CUs on the GPU. On MI300X the 256 MB figure is the AMD Infinity Cache; each XCD additionally has its own 4 MB L2.               |
| HBM3                 | 80GB, 3.35 TB/s         | 192GB, 5.3 TB/s        | High Bandwidth Memory used for storing large datasets and models. Accessible by all SMs/CUs.|

* We'll take into account the realization that HBM reads/writes (roughly 2 TB/s on an A100-class GPU) are an order of magnitude slower than SMEM reads/writes (roughly 20 TB/s).

**GPU Compute Hierarchy**
* Threads: Each thread has access to a private set of registers. Add numbers
* Warps: A group of 32 threads. On the AMD side the equivalent is called a wavefront and is a group of 64 threads.
* Warpgroups: A group of 4 contiguous warps. This is the unit that issues the WGMMA instruction on H100s.
    * [warp specialization](https://pytorch.org/blog/warp-specialization/#:~:text=Warp%20specialization%20(WS)%20is%20a,task%20differentiation%20or%20cooperative%20processing)
* Threadblocks (a.k.a. CTAs, cooperative thread arrays):
   * When a block is assigned to an SM, it is divided into warps. Say we have a 2D block of size 32*32; then total threads = 1024. Given a wavefront size of 64 on AMD, we get a total of 1024/64 = 16 wavefronts for this block.
   * The H100 supports a maximum of 1024 threads per block and up to 32 blocks per SM.
* ThreadBlock Clusters 
* WARPS PER SM =
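
The warp arithmetic from the threadblock bullet can be checked with a few lines of Python. This is a minimal sketch: `warps_per_block` is a hypothetical helper (not part of any GPU API), and the warp sizes are the ones quoted in the list (32 for NVIDIA warps, 64 for AMD wavefronts).

```python
import math


def warps_per_block(threads_per_block: int, warp_size: int) -> int:
    """Number of warps (wavefronts on AMD) that one thread block is split into."""
    # A partially filled warp still occupies a full warp slot, hence the ceiling.
    return math.ceil(threads_per_block / warp_size)


threads = 32 * 32  # a 2D block of 32 x 32 threads
print(warps_per_block(threads, warp_size=32))  # NVIDIA warps
print(warps_per_block(threads, warp_size=64))  # AMD wavefronts
```

Since the scheduler hides latency by switching between resident warps, this count is one of the inputs to occupancy.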

When a warp encounters an instruction where an operand isn't ready (e.g., waiting for memory access), the scheduler can perform "context switching" to execute another warp instead of waiting idle, which helps hide memory latency.

The compute hierarchy and the memory hierarchy are closely intertwined.

## Vanilla Attention 
### Formulation
Remember that the IO unaware formulation of attention is as follows:

$\text{Let } Q, K, V \in \mathbb{R}^{N \times d} \text{ be the query, key and value input sequences}$
$ \text{associated to a single head, where } N \text{ is the sequence length and } d \text{ is the} $
$ \text{head dimension. Then the attention output } O \text{ is computed as:} $ 

\begin{align*}
& S = \alpha QK^\top \in \mathbb{R}^{N \times N}, && \\ 
& P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, && \\ 
& O = PV \in \mathbb{R}^{N \times d}, &&
\end{align*}

Then, given B batches and H heads (MHA), each of Q, K, V has the following shape:
* Q: `[BatchSize (B), NumHeads (H), Seq Len (N), Embedding Dim (d)]`
* K: `[BatchSize (B), NumHeads (H), Seq Len (N), Embedding Dim (d)]`
* V: `[BatchSize (B), NumHeads (H), Seq Len (N), Embedding Dim (d)]`

You can parallelize the GeMM across BATCH and HEAD dimensions. 


### Limitation

#### I - Attention is $O(N^2)$ in Memory Complexity
**Materialization**: if N is too large, it is **NOT possible** to fit the entire $QK^\top$ in `SMEM`. Since $QK^\top \in \mathbb{R}^{N \times N}$, attention is $O(N^2)$ in memory complexity, where N is the number of tokens / sequence length.

* N = 8192.
* B = 1
* d = 128 [does not matter in this calculation]
* NUM HEADS = 16
* Precision = bf16 (2 bytes)
* Then the NxN matrix has 8192 x 8192 x 1 x 16 = 1,073,741,824 elements, where each element takes 2 bytes.
* <mark>Total memory required if NxN is materialized: 2,147,483,648 bytes (~2147 MB, i.e. 2 GiB).</mark> This is far TOO large to fit in SMEM, which is measured in KBs.
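
The arithmetic in the list above is easy to reproduce. A small sketch, where `score_matrix_bytes` is a hypothetical helper name and the inputs are the values from the list:

```python
def score_matrix_bytes(seq_len: int, batch: int, heads: int, bytes_per_el: int) -> int:
    """Bytes needed to materialize the full N x N score matrix for every (batch, head)."""
    return seq_len * seq_len * batch * heads * bytes_per_el


n_bytes = score_matrix_bytes(seq_len=8192, batch=1, heads=16, bytes_per_el=2)
print(n_bytes)            # total bytes
print(n_bytes / 2**30)    # the same amount in GiB
```

Doubling `seq_len` quadruples the result, which is the $O(N^2)$ growth this section is about.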

<u>Understanding the Space & Time Complexity of Attention</u>
* Lower Bounds on IO access
* Parameterization

#### II - Attention is IO Bound
**IO Bound**: The standard computation requires 6 trips to HBM (3 loads and 3 stores) and is memory bound.
* Load $Q$ & $K$   --> compute $X = QK^\top$   --> Store $X$
* Load $X$         --> compute $A = \text{softmax}(X)$   --> Store $A$
* Load $A$, $V$    --> compute $O = AV$          --> Store $O$

We already know that HBM access is 10x slower than SMEM access. 
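
To get a feel for how much HBM traffic the three steps generate, the sketch below counts the bytes loaded and stored per head. Both function names are hypothetical. The "fused" figure is an idealized lower bound that assumes Q, K, V are each read once and O is written once; a real tiled kernel re-reads some tiles, so treat it as a best case rather than a measurement.

```python
def standard_attention_hbm_bytes(n: int, d: int, bytes_per_el: int = 2) -> int:
    """HBM bytes moved by the unfused, three-step attention for one head."""
    step1 = 2 * n * d + n * n        # load Q and K, store X
    step2 = n * n + n * n            # load X, store A
    step3 = n * n + n * d + n * d    # load A and V, store O
    return (step1 + step2 + step3) * bytes_per_el


def fused_lower_bound_hbm_bytes(n: int, d: int, bytes_per_el: int = 2) -> int:
    """Idealized fused kernel: read Q, K, V once and write O once."""
    return 4 * n * d * bytes_per_el


n, d = 8192, 128
unfused = standard_attention_hbm_bytes(n, d)
fused = fused_lower_bound_hbm_bytes(n, d)
print(unfused, fused, unfused / fused)
```

Under these assumptions the ratio works out to $1 + N/d$, so the gap widens as the sequence gets longer.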

Add the distribution of time vs FLOPs chart here, and link to glossary.

### Addressing the Limitations
1. Tiling
Matrix multiplication has the nice properties of **associativity** and **distributivity**. Therefore $QK^\top$ can be tiled, achieving massive parallelism on GPUs.
2. However, the subsequent `SOFTMAX` operation CANNOT be tiled directly. Additionally, using 'Safe Softmax' requires us to scan the entire sequence length for the max value. We'll use online softmax to make softmax tilable, together with activation recomputation. Doing so, we'll be able to compute the entire attention op without materializing the intermediate S and P matrices in HBM.

1. Left: The GPU is a hierarchical computing and memory system.
    * The HBM bandwidth (GB/s) is an order of magnitude (10x) lower than that of SMEM
    * However, SMEM is orders of magnitude smaller than HBM in capacity (HBM: 80GB H100, 192 GB MI300x vs SMEM: )
2. Middle: This is the core logic / kernel `_attention_fwd` of the flash attention algorithm. It avoids materializing the N x N matrix
    * The **outer for loop** loads the 
    
<img src="diagrams/inner_loop_attn_paper.png" alt="Description" width="1000" height="400">

#### Online Softmax
**Safe Softmax**: In practice we use the `Safe Softmax` operation. Safe Softmax **subtracts the max value to avoid overflow and underflow** at a given bit-width. Example: in fp16 (2 bytes) the largest finite value is 65504. Therefore if $x_{i}=12$, then $e^{12} \approx 162755$ is too large to fit into the dynamic range of fp16.

* fp16: `1 sign bit |  5 exponent bits |  10 mantissa bits`, therefore higher precision but lower dynamic range
* bf16: `1 sign bit |  8 exponent bits |  7 mantissa bits`, therefore lower precision but higher dynamic range

$$\text{Safe Softmax}(x_i) = \frac{e^{x_i - M}}{\sum_{j=1}^{n} e^{x_j - M}} \quad \text{where } M = \text{max} (x_i)$$
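
The fp16 overflow argument can be demonstrated without a GPU. The sketch below uses the standard library's `struct` module, whose `"e"` format code packs a Python float as IEEE half precision and raises `OverflowError` when the value does not fit. `fits_in_fp16` is a hypothetical helper name.

```python
import math
import struct


def fits_in_fp16(value: float) -> bool:
    """Return True if `value` can be packed as an IEEE half-precision float."""
    try:
        struct.pack("<e", value)  # "e" is the half-precision format code
        return True
    except OverflowError:
        return False


x = [1.0, 6.0, 12.0]
m = max(x)
naive = [math.exp(v) for v in x]          # e^12 exceeds the fp16 maximum
shifted = [math.exp(v - m) for v in x]    # every value now lies in (0, 1]

print([fits_in_fp16(v) for v in naive])
print([fits_in_fp16(v) for v in shifted])
```

After the shift, the largest exponent is always $e^0 = 1$, so the numerators can no longer overflow regardless of how large the raw scores are.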


Online softmax was introduced by Milakov and Gimelshein (NVIDIA) in 2018, in the paper `Online normalizer calculation for softmax`.

1. Take care of overflow / underflow by subtracting the max value. This matters most when the bit width is 16 (bf16, fp16).
2. 

The **implication / advantage of doing Online Softmax** is that equation (7) and equation (8) can be fused, since you no longer require $m_{N}$ and only need $m_{i}$ and $m_{i-1}$. All of this hinges on the fact that $d_{N} = d'_{N}$.
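
A minimal sketch of the online normalizer in plain Python follows. It keeps a running max `m` and a running normalizer `d`, and rescales `d` whenever the max changes, so the max and the normalizer are obtained in one pass instead of two. The function name `online_softmax` and the variable names are illustrative choices, not taken from the paper's code.

```python
import math


def online_softmax(xs):
    """Softmax where the row max and the normalizer are computed in a single pass."""
    m = -math.inf   # running max m_i
    d = 0.0         # running normalizer d'_i = sum of exp(x_j - m_i) for j <= i
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new max, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in xs], m, d


a = [1, 6, 11, 4, 5]
probs, m_final, d_final = online_softmax(a)

# Two-pass safe softmax normalizer for comparison.
d_two_pass = sum(math.exp(x - max(a)) for x in a)
print(m_final, d_final, d_two_pass)
```

After the last element, the running normalizer equals the two-pass one, which is the $d_N = d'_N$ property the paragraph above relies on.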

### Understanding the space-time complexity of Attention 
* Lower Bounds on IO access
* Parameterization

In [10]:
import math 

m = -math.inf
l = 0.0

a = [1, 6, 11, 4, 5]
m = max(a)                                    # O(N)

# calculate the normalization term (denominator)
for x in a:                                   # O(N)
    l += math.exp(x - m)
print(f"normalization: {l}")

safe_softmax = []
for x in a:                                   # O(N)
    prob = math.exp(x - m) / l
    safe_softmax.append(prob)
    
print(f"safe_softmax: {safe_softmax}")

normalization: 1.0101739810710686
safe_softmax: [4.494268374874213e-05, 0.006670085673698849, 0.9899284863184842, 0.0009026979338625064, 0.0024537873902059763]


1.0000000000000002

## Tile Programming
As shown above using the **online softmax** provides us with **eventual consistency**. Therefore we can now parallelize across 

* X: seq length
* Y: batch*head
* Z: 1 [no parallelism]

<mark>Each **CUDA THREAD BLOCK** gets a **UNIQUE PROGRAM ID**</mark>

<img src="diagrams/fa_triton_program_id_mapping.png" alt="Description" width="700" height="500">


### Making Attention "Tilable" - FWD PASS MATH


* `BLOCK_M`: tile size along the query sequence dimension
* `BLOCK_N`: tile size along the key/value sequence dimension


A block-level partial softmax is computed. For the partial softmax to converge to the actual softmax we need to maintain 2 statistics: the running row max `m(x)` and the running normalizer `l(x)`. We do this recursively over all N/BLOCK_SIZE blocks, accumulating the results in the accumulator O.

**Notation**
1. Given a shared memory size of $M$ we determined the block size to be  

    BLOCK_SIZE = ($B_c$,   $B_r$) where $B_c = \left\lceil\frac{M}{4d}\right\rceil$ and $B_r = \min\left(\left\lceil\frac{M}{4d}\right\rceil, d\right) $

   We'll tune the block_size and find the best M based on our GPU target arch.

3. Init Tensors
    * **Output tensor**: Allocate 0s for $O$ where $O = PV \in \mathbb{R}^{N \times d}$
    * m, l
4. Next, let's chunk the Q, K, and V matrices into blocks
    * **row wise split**: divide the Q matrix into chunks ($Q_1, Q_2, .... Q_{Tr}$) where each chunk $Q_i$ is $B_r$ X $d$ size. We have a total of $T_r$ such chunks where $T_r = N / B_r$
    * **col wise split**: divide the K and V matrix into chunks ($K_1, K_2, .... K_{Tc}$) where each chunk $K_i$ is $B_c$ X $d$ size. We have a total of $T_c$ such chunks where $T_c = N / B_c$

5. Block wise output $O$ 

6. Implement **Fused Kernel** using Tiling and Online Softmax 
    * **Outer Loop**: Load $K_j$ and $V_j$ for each j in [1 .... $Tc$]
    * **Inner Loop**: Load $Q_i$ for each i in [1 ... $Tr$]
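
The steps in this list can be sketched end to end in plain Python (no GPU, no Triton) and checked against standard attention. Assumptions to keep in mind: `M` is treated as a count of elements, the scale $\alpha$ is taken as 1, and the accumulator is kept unnormalized and divided by `l` once at the end (a simplification compared with the per-step normalization in the v1 algorithm box). All function names are hypothetical.

```python
import math
import random


def block_sizes(M: int, d: int):
    """B_c = ceil(M / 4d), B_r = min(B_c, d), as in the notation above."""
    B_c = math.ceil(M / (4 * d))
    return B_c, min(B_c, d)


def reference_attention(Q, K, V):
    """Standard attention: each full row of softmax(Q K^T) is materialized."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(v - m) for v in s]
        l = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / l
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, B_r, B_c):
    """Tiled forward pass with online softmax; never builds the N x N matrix."""
    N, d = len(Q), len(Q[0])
    O = [[0.0] * d for _ in range(N)]   # unnormalized output accumulator
    m = [-math.inf] * N                 # running row max
    l = [0.0] * N                       # running row normalizer
    for j0 in range(0, N, B_c):                     # outer loop: K_j, V_j tiles
        Kj, Vj = K[j0:j0 + B_c], V[j0:j0 + B_c]
        for i0 in range(0, N, B_r):                 # inner loop: Q_i tiles
            for i in range(i0, min(i0 + B_r, N)):
                s = [sum(a * b for a, b in zip(Q[i], k)) for k in Kj]
                m_new = max(m[i], max(s))
                p = [math.exp(v - m_new) for v in s]
                scale = math.exp(m[i] - m_new)      # rescale what was accumulated so far
                l[i] = l[i] * scale + sum(p)
                for c in range(d):
                    O[i][c] = O[i][c] * scale + sum(p[t] * Vj[t][c] for t in range(len(Vj)))
                m[i] = m_new
    return [[O[i][c] / l[i] for c in range(d)] for i in range(N)]


random.seed(0)
N, d = 8, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]

ref = reference_attention(Q, K, V)
out = tiled_attention(Q, K, V, B_r=2, B_c=3)
max_err = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
print(block_sizes(M=32768, d=64), max_err)
```

The tile sizes `B_r=2` and `B_c=3` are deliberately chosen so that they do not divide `N` evenly, which exercises the ragged last tile.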
     

<img src="diagrams/flash_attention_algo_v1.png" alt="Description" width="900" height="500">


1. Batches and heads are treated independently [heads don't pass around info]. Therefore, the effective matrix becomes `[B*H, N, d]`, where `B*H` acts as the batch dimension. In addition, **Flash Attention 2** parallelizes along the sequence dimension as well.

* **Flash attention v1** parallelizes over the batch size and the number of heads, since each head is independent. 1 threadblock is used to process 1 attention head. Therefore, given a `batch size=B` and `number of heads=H`, we have **total number of threadblocks launched = H * B**
* `A100 has 108 SMs`, `H100 has 132 SMs (114 on the PCIe variant)` and `MI300X has 304 CUs`. Therefore, the above scheduling of a) parallelization across heads and b) 1 head per SM is efficient when the total number of heads (B * H) is >= 100 or so, since we end up using most compute resources of the GPU.
* However, in the case of long sequences (Llama2 8192) we usually have to go with both smaller batch sizes and fewer heads. Therefore this scheduling becomes inefficient. We'll fix this in **FlashAttention 2**
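
A rough way to see the scheduling problem is to count launched thread blocks against the number of SMs. This is a sketch under a simplifying assumption (at most one thread block resident per SM at a time); the helper names and the `block_m=128` tile size are illustrative.

```python
import math


def blocks_fa1(batch: int, heads: int) -> int:
    """v1-style launch: one thread block per (batch, head)."""
    return batch * heads


def blocks_fa2(batch: int, heads: int, seq_len: int, block_m: int) -> int:
    """v2-style launch: additionally split the query sequence into tiles."""
    return batch * heads * math.ceil(seq_len / block_m)


def sm_utilization(num_blocks: int, num_sms: int) -> float:
    """Fraction of SMs that receive at least one block (one block per SM assumed)."""
    return min(1.0, num_blocks / num_sms)


NUM_SMS = 108                 # A100
b, h, n = 1, 16, 8192         # long sequence, small batch

print(blocks_fa1(b, h), sm_utilization(blocks_fa1(b, h), NUM_SMS))
print(blocks_fa2(b, h, n, block_m=128), sm_utilization(blocks_fa2(b, h, n, block_m=128), NUM_SMS))
```

With these numbers the v1-style launch leaves most SMs idle, while the extra sequence dimension produces far more blocks than there are SMs.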


<center><img src="diagrams/inner_loop_attn.png" alt="Description" width="500" height="300"></center>

## Memory Store and Load
* Transpose
* Materialization
* Load Store Units, Vectorized Load Stores
* Strides, Contiguous Memory, Non-contiguous Memory
* Moving data from HBM to SMEM. Inner workings of SMEM
* L1 cache, L2 cache, Registers

Explain how Q is actually stored in memory.
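
As a starting point for that explanation: a contiguous `[B, H, N, d]` tensor is one flat array in memory, and the strides tell you how many elements to skip to move one step along each dimension. The sketch below assumes row-major (C-order) layout and counts strides in elements, which is also how PyTorch's `Tensor.stride()` reports them. The helper names are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, counted in elements."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flat_offset(index, strides):
    """Position of element `index` in the flat buffer."""
    return sum(i * s for i, s in zip(index, strides))


shape = (4, 16, 4096, 64)           # [B, H, N, d]
strides = contiguous_strides(shape)
print(strides)                                # (stride_b, stride_h, stride_n, stride_d)
print(flat_offset((1, 2, 3, 5), strides))     # batch 1, head 2, token 3, feature 5
```

This is the same calculation a kernel performs when it computes `batch_id * stride_b + head_id * stride_h` to find the start of one head's `[N, d]` slice.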

## Coding Flash Attention FWD
* Flash Attention 2 FWD Math

### Feature Support List
* Causal Masking
* Dropout
    * Requires a random number generator per thread. Philox Seed 
    * Modern LLMs like Llama 2 and 3 do not use attention dropout. Earlier encoder-only models like BERT do use attention dropout
* Attention Bias
* Attention Variants
    * Multi-Head Attention
    * Grouped Query Attention: Grouped query attention helps reduce the KV cache size by sharing K, V across a group of query heads. This requires index manipulation to make the core attention $O= \text{softmax}(Q K^T)V$ work
    * Multi-Query Attention
    * Sliding Window Attention (SWA): A default value of `(-1, -1)` means attend to everything on the left and right.
    * Ring Attention and Context parallelism
* Sequence Length Support
    * Variable Seq Lengths
    * Fixed Sequence Lengths.  
* FP8 Support
    * Scaling Strategies
* More Modern Attention Variants
  * MLA - Multi-head Latent Attention
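
The "index manipulation" mentioned for grouped query attention usually boils down to mapping each query head to the K/V head it shares. A minimal sketch, assuming consecutive query heads form a group (the common convention); `kv_head_for` is a hypothetical helper name:

```python
def kv_head_for(q_head: int, num_heads_q: int, num_heads_kv: int) -> int:
    """Index of the K/V head that query head `q_head` attends with."""
    assert num_heads_q % num_heads_kv == 0, "query heads must split evenly into groups"
    group_size = num_heads_q // num_heads_kv
    return q_head // group_size


# MHA: every query head has its own K/V head.
print([kv_head_for(h, 8, 8) for h in range(8)])
# GQA: 8 query heads share 2 K/V heads (groups of 4).
print([kv_head_for(h, 8, 2) for h in range(8)])
# MQA: all query heads share a single K/V head.
print([kv_head_for(h, 8, 1) for h in range(8)])
```

MHA and MQA fall out as the two extremes of the same formula, which is why a single kernel can cover all three variants.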
 

### Python Autograd API
This section creates the Python wrapper that integrates into PyTorch by inheriting from the `torch.autograd.Function` class and implementing the `forward()` and `backward()` static methods.

**FWD PASS SETUP & LAUNCH**

The broad recipe we follow: 
1. Asserts: Check if the shapes of Q, K, V tensors are good. Check if values are the type and shape they should be.
2. Allocate Input and Output buffers.
3. Decide how to parallelize the workload - reason about grids, warps, compute units, numblocks, block size w.r.t the workload under consideration
4. Launch FWD Triton Kernel using the `kernel_name[grid](<kernel params>)`
5. Save the CTX needed for BWD. This context would not be necessary if we were doing inference. The context is essentially data that the BWD pass will require.
    - LSE


It's key to remember that while our $X \in \mathbb{R}^{B \times H \times \text{Seq} \times d}$, the Transformer is operating at a token level. Therefore we can generalize the input to `[Num Tokens, d]` where `Num Tokens = B * H * Seq`.

In [None]:
import math

import torch
import triton

device = 'cuda' if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 4
SEQ_LEN = 4096
EMB_DIM = 1024

# HEADS_Q == HEADS_KV is Multi-head attention (MHA)
# HEADS_Q > HEADS_KV is GQA (MQA when HEADS_KV == 1)
NUM_HEADS_Q = 16
NUM_HEAD_KV = 16


# python api
# Add support for QKV, QK and V, Q and K and V
def flash_attention_v2(
    batch_size: int,
    num_heads_q: int,
    num_heads_kv: int,
    seq_len: int,
    emb_dim: int,
    causal: bool,
    mode: str,
    dropout: bool
):
    dtype = torch.float16
    # Per-head dimension d. Assumes emb_dim is the model dimension split evenly across query heads.
    head_dim = emb_dim // num_heads_q

    # Flash Attention 1: Parallelizes over Batch and Heads
    # Flash Attention 2: Parallelizes over Batch, Heads and Sequence Len
    q = torch.randn((batch_size, num_heads_q, seq_len, head_dim), dtype=dtype, device=device)
    k = torch.randn((batch_size, num_heads_kv, seq_len, head_dim), dtype=dtype, device=device)
    v = torch.randn((batch_size, num_heads_kv, seq_len, head_dim), dtype=dtype, device=device)

    if mode == 'fwd':
        q = q.to(torch.float8_e5m2)  # why e5m2 for FWD pass ?
        k = k.to(torch.float8_e5m2)  # what is _nuz ?

    softmax_scale = 1 / math.sqrt(head_dim)
    # FusedAttention is defined in the next cell.
    fn = lambda: FusedAttention.apply(q, k, v, causal, softmax_scale, dropout)
    ms = triton.testing.do_bench(fn)  # benchmark the op and report time in milliseconds.
    return ms


flash_attention_v2(
    batch_size=BATCH_SIZE,
    num_heads_q=NUM_HEADS_Q,
    num_heads_kv=NUM_HEAD_KV,
    seq_len=SEQ_LEN,
    emb_dim=EMB_DIM,
    causal=True,
    mode='fwd',
    dropout=False
)

In [11]:
import triton
import torch 

def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == 'hip'


BLOCK_M = 128  # B_r: query tile size (fixed here as an assumption; normally picked by tuning)
BLOCK_N = 64   # B_c: key/value tile size (fixed here as an assumption; normally picked by tuning)


class FusedAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal: bool, softmax_scale: float, dropout: bool = False):
        """
        The ctx object is used to build the context that is required during the bwd pass;
        we'll use ctx.save_for_backward(...) to do this.
        """
        ### ------ 1. Check Inputs --------------
        HEAD_EMB_DIM_Q = q.shape[-1]  # q: [B, H, S, d]
        HEAD_EMB_DIM_K = k.shape[-1]
        HEAD_EMB_DIM_V = v.shape[-1]

        assert HEAD_EMB_DIM_Q == HEAD_EMB_DIM_K == HEAD_EMB_DIM_V, "All heads must have the same embedding dimension"
        assert HEAD_EMB_DIM_K in {16, 32, 64, 128, 256}, "Only head sizes of 16, 32, 64, 128, 256 are supported"
        # why are only these supported ??
        # - usually we prefer powers of 2 because they map cleanly onto GPU tile and warp sizes.
        # -

        ### ------- 2. Allocate Buffers ----------
        o = torch.empty_like(q)      # [B, H, Seq, d]
        stage = 3 if causal else 1   # causal requires special considerations such as masking.
        # Per-row softmax statistics written by the kernel and reused in the BWD pass.
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        ### ------ 3. Grid, Warps, EUs, Blocks etc ---------
        # SM == EU;
        # WARPS (32 threads) == WAVEFRONTS (64 threads) ;
        # OCCUPANCY == WAVES_PER_EU == ACTIVE WARPS PER SM
        # How does this affect perf
            # - register pressure
        extra_kern_args = {}
        if is_hip():
            # Increase the active warps count if ...
            # Decrease the warp count if ...
            waves_per_eu = 3 if HEAD_EMB_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        num_heads_q = q.shape[1]
        batch_size = q.shape[0]

        # BLOCK_M is B_r
        # If seq_len=4096 and BLOCK_M=128 then NUM_BLOCKS=32. Each threadblock will process BLOCK_M elements of the seq.
        x_grid = triton.cdiv(q.shape[2], BLOCK_M)
        y_grid = num_heads_q * batch_size    # HEADS and BATCH_SIZE are truly independent. No TB communication required for parallelization.
        z_grid = 1
        grid = (x_grid, y_grid, z_grid)  # parallelization was done from inwards to outwards.

        ### ------ 4. Launch ---------
        # Argument order follows the _attention_fwd kernel signature defined below.
        _attention_fwd[grid](
            q,
            k,
            v,
            softmax_scale,  # softmax scale.
            o,              # Final Output
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #  Batch, Head, Seq, Emb strides required for correct memory access.
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #  Batch, Head, Seq, Emb strides required for correct memory access.
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #  Batch, Head, Seq, Emb strides required for correct memory access.
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #  Batch, Head, Seq, Emb strides required for correct memory access.
            batch_size,              # Batch size
            num_heads_q,             # Number of query heads
            q.shape[2],              # N_CTX: SeqLen
            M,                       # softmax statistics (lse), required for BWD
            dropout,                 # dropout
            causal,                  # causal attention flag
            HEAD_DIM=HEAD_EMB_DIM_K,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            STAGE=stage,
            **extra_kern_args
        )

        ### Save context for BWD
        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = softmax_scale  # softmax scale, 1 / sqrt(d)
        ctx.HEAD_DIM = HEAD_EMB_DIM_K
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """
        It is easy to make mistakes while implementing the bwd pass. To help with this we can use
        torch.autograd.gradcheck(), which compares against the finite difference [f(a+h) - f(a-h)] / 2h.
        In addition to that people use tools like PySolver.
        """
        pass
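
The finite-difference idea from the `backward()` docstring can be tried on a tiny, GPU-free example before writing any kernel. The sketch below is not the notebook's backward pass: it checks the gradient of a single attention row, $f(x) = \text{softmax}(x) \cdot v$, whose analytic gradient is $\partial f / \partial x_j = p_j (v_j - f)$. All names are illustrative.

```python
import math


def softmax(x):
    m = max(x)
    e = [math.exp(t - m) for t in x]
    s = sum(e)
    return [t / s for t in e]


def f(x, v):
    """Scalar output of one attention row: softmax(x) . v"""
    return sum(p * w for p, w in zip(softmax(x), v))


def analytic_grad(x, v):
    """df/dx_j = p_j * (v_j - f(x))"""
    p, out = softmax(x), f(x, v)
    return [p_j * (v_j - out) for p_j, v_j in zip(p, v)]


def numeric_grad(x, v, h=1e-5):
    """Central difference [f(a+h) - f(a-h)] / 2h, one coordinate at a time."""
    grads = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        grads.append((f(xp, v) - f(xm, v)) / (2 * h))
    return grads


x = [0.3, -1.2, 2.0, 0.5]
v = [1.0, -2.0, 0.5, 3.0]
max_diff = max(abs(a - n) for a, n in zip(analytic_grad(x, v), numeric_grad(x, v)))
print(max_diff)
```

The same comparison, done in double precision on small inputs, is what a gradient checker automates for a full `autograd.Function`.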




### FWD KERNEL in TRITON

Given:
* S = 8192
* BLOCK_M = 64
* d = 128
* NUM_HEADS (H) = 32
* BATCH_SIZE (B) = 8

grid_x = NUM_BLOCKS_IN_SEQ_DIM = ceil_div(S, BLOCK_M) = 8192 / 64 = 128. # We'll process 64 tokens at a time
grid_y = B*H = 256 # easy parallelism along the Batch and Head dims, as done in flash attention v1.

We can think about the parallelism as
* Pure parallelism: the task is trivially parallelizable, no ifs and buts.
* Tiling: GPU programming required, warps etc.

A column is a full sequence for a given batch and head. Each column is chunked into blocks of size 64, such that there are a total of 128 chunks.
**Example 1 (3, 12)**
- `start_m = program_id(0) = 3`       # processing **the 4th chunk** of the seq [0-63, 64-127, 128-191, **192-255 (4th chunk)**, .... 8128-8191]
- `offset_bh = program_id(1) = 12`
- `offset_bh // H` = 12 // 32 =  0th Batch
- `offset_bh  % H` = 12 % 32  = 12th Head


**Example 2 (*, 50)**
- For a given sequence chunk we current process 
    - offset_bh = 50
    - offset_bh // H = 50 // 32 =  1st Batch
    - offset_bh  % H = 50 % 32  = 18th Head


**Example 3 (127, 120)**
- start_m = 127      # Last sequence chunk [8128-8191]
- offset_bh = 120
- offset_bh // H = 120 // 32 = 3rd Batch
- offset_bh % H = 120 % 32 = 24th Head
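
The three examples follow one rule, which can be written down as a small function. `decode_program_id` is a hypothetical helper that mirrors what a program instance does with its two program IDs; the constants are the ones from the "Given" list (H = 32, BLOCK_M = 64).

```python
def decode_program_id(pid_x: int, pid_y: int, num_heads: int, block_m: int):
    """Map (program_id(0), program_id(1)) to (batch, head, first and last query row)."""
    batch = pid_y // num_heads
    head = pid_y % num_heads
    row_start = pid_x * block_m
    row_end = row_start + block_m - 1
    return batch, head, (row_start, row_end)


H, BLOCK_M = 32, 64
print(decode_program_id(3, 12, H, BLOCK_M))      # Example 1
print(decode_program_id(0, 50, H, BLOCK_M))      # Example 2 (any sequence chunk)
print(decode_program_id(127, 120, H, BLOCK_M))   # Example 3
```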


In [24]:
import triton
import triton.language as tl


@triton.jit
def _attn_fwd_inner(
    acc,
    li,
    mi,
    q,  # [BLOCK_M, HEAD_DIM]
    K_block_ptr, V_block_ptr,
    start_m,
    qk_scale,  # softmax scale, defaults to 1 / sqrt(HEAD_DIM)
    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr, offs_n: tl.constexpr,  #
    N_CTX: tl.constexpr,  # The operation is parallelized by dividing the sequence length (N_CTX) into blocks of size BLOCK_N.
    fp8_v: tl.constexpr
):
    """
    Load blocks of K and V from HBM for the Q block that was passed in.
    """
    # lo, hi mark the low/high index positions over the key/value sequence length.
    if STAGE == 1:
        # causal: key blocks strictly before this query block (no masking needed)
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        # causal: the diagonal block, which overlaps this query block
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    else:
        # causal = False, so keep the entire range 0, N_CTX [no masking]
        lo, hi = 0, N_CTX


@triton.jit
def _attention_fwd(
    Q,                       # Q:  M, K matrix
    K,                       # K:  M, K matrix
    V,                       # V:  M, K matrix
    softmax_scale, 
    OUT,                     # O:  M, K matrix
    stride_qb, stride_qh, stride_qm, stride_qk,  # jump next batch, jump next head, jump next seq, jump next d  
    stride_kb, stride_kh, stride_kn, stride_kk,  # jump next batch, jump next head, jump next seq, jump next d  
    stride_vb, stride_vh, stride_vk, stride_vn,  # jump next batch, jump next head, jump next seq, jump next d  
    stride_ob, stride_oh, stride_om, stride_on,  # jump next batch, jump next head, jump next seq, jump next d  
    B,                       # Batch Size 
    H,                       # Number of Heads
    N_CTX,                   # Sequence Length: 8192
    lse,                     # Log Sum Exponential. Required for BWD. 
    dropout_p,               # Attention Dropout 
    is_causal,               # Causal Attention Flag
    HEAD_DIM: tl.constexpr,  # d
    BLOCK_M: tl.constexpr,   # Tile size in the QUERY SEQ DIMENSION
    BLOCK_N: tl.constexpr,   # Tile size for Key/Value dimension
    STAGE: tl.constexpr      # 3 if causal else 1 (set by the launcher)
):
    """
    3. Variable sequence Lengths Layout (thd) --> (token, head,  dim_per_head)
    4. Context Parallelism
        - Due to intern token communication during attention, we need to all gather during fwd and reduce scatter during bwd

    """
    """
    Here we try to create a highly parallelizable "instance of the program". Therefore we must
    1. Define the program with the correct way for a program instance to know what data to work [which block, which tile, which batch which head & so on]
    2. 
    """
    start_m = tl.program_id(0)    # get the current program instance along x axis. Achieve parallelism by tiling along seq dim [X-axis]
    offset_bh = tl.program_id(1)  # get the current program instance along y axis.  Parallelism across the Batch * Head

    off_b = offset_bh // H  # Current batch.  
    off_h = offset_bh % H   # Current head in that batch 

    # represents the starting point of the block. 
    qkv_offset =  off_b.to(tl.int64) * stride_qb  + off_h.to(tl.int64) * stride_qh

    # With very-large tensor sizes, the offsets can be huge. Use int64 to avoid overflow. 
    # Get Q given a head and batch iie qkv_offset = batch_id * stride_batch_Q + head_id * stride_head_Q. 
    # Programs gets dispatched onto the BLOCKs(The abstraction for Triton)
    # make_block_ptr: Returns a pointer to a block in a parent tensor
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,              # offset Q by qkv_offset to get to the current batch * head for the current threadblock.
        shape=(N_CTX, HEAD_DIM),          # (N, d)
        strides=(stride_qm, stride_qk),   # The strides of the parent tensor [SeqLen X EmbedDim] ==> how to jump rows and cols of the Q [mk] matrix.
        offsets=(start_m * BLOCK_M, 0),   # offsets to the block ==> where this program's Q tile starts
        block_shape=(BLOCK_M, HEAD_DIM),  # Size of the block to load from HBM to SMEM / LDS [tile size of seqlen, d]
        order=(1, 0),                     # The order of the original data format [TN format]
    )

    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,              # The base pointer to the parent tensor
        shape=(HEAD_DIM, N_CTX),          # K is addressed transposed: (head_dim, seqlen)
        strides=(stride_kk, stride_kn),   # The strides of the parent tensor [EmbedDim X SeqLen] ==> How to jump rows and cols of the K [kn] matrix
        offsets=(0, 0),                   # Since for a given Q_i we get "ALL" K_j and V_j for j = 1 ....
        block_shape=(HEAD_DIM, BLOCK_N),  # HEAD DIM, SEQ TILE SIZE
        order=(0, 1),                     # The order of the original data format
    )

    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,              # The base pointer to the parent tensor
        shape=(N_CTX, HEAD_DIM),          # Shape of the parent tensor (seqlen, head_dim)
        strides=(stride_vk, stride_vn),   # The strides of the parent tensor [SeqLen X EmbedDim] ==> How to jump rows and cols of the V matrix
        offsets=(0, 0),                   # Since for a given Q_i we get "ALL" K_j and V_j for j = 1 ....
        block_shape=(BLOCK_N, HEAD_DIM),  # SEQ TILE SIZE, HEAD DIM
        order=v_order,                    # The order of the original data format
    )
    O_block_ptr = tl.make_block_ptr(
        base=OUT + qkv_offset,             # start of current block
        shape=(N_CTX, HEAD_DIM),           # N X d
        strides=(stride_om, stride_on),    # (Jump Seq, Jump d) 
        offsets=(start_m * BLOCK_M, 0),    # Output BLOCKS
        block_shape=(BLOCK_M, HEAD_DIM),   # Each block is M x d
        order=(1, 0)
    )

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # epilogue

    # Since we are doing everything block-wise, we need the right offset for storing the
    # per-row statistics: skip offset_bh full sequences, then index the rows of this block.
    # NOTE: m_i comes from the main loop over K/V blocks, which is not written out in this cell yet.
    m_ptrs = lse + offset_bh * N_CTX + offs_m
    tl.store(pointer=m_ptrs, value=m_i)  # store the m_i block.
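
The `lo, hi` selection in `_attn_fwd_inner` is easier to reason about outside of Triton. The sketch below mirrors the three branches in plain Python, taking the stage values at face value from the inner function (how the launcher's `stage` value maps onto them is not shown in this notebook). `key_range` is a hypothetical helper name.

```python
def key_range(stage: int, start_m: int, block_m: int, n_ctx: int):
    """Range [lo, hi) of key/value positions visited for query block `start_m`."""
    if stage == 1:
        # Causal, key blocks strictly before the query block: no masking needed.
        return 0, start_m * block_m
    if stage == 2:
        # Causal, the diagonal block: needs element-wise masking inside the tile.
        return start_m * block_m, (start_m + 1) * block_m
    # Non-causal: every key position is visited.
    return 0, n_ctx


BLOCK_M, N_CTX = 64, 8192
print(key_range(1, 3, BLOCK_M, N_CTX))
print(key_range(2, 3, BLOCK_M, N_CTX))
print(key_range(3, 3, BLOCK_M, N_CTX))
```

For a causal query block, stages 1 and 2 together cover exactly the keys at or before the end of that block, so roughly half of the key blocks are skipped on average.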



### FWD Pass Correctness Check

### BWD KERNEL
### BWD Pass Correctness Check

## Kernel Tuning
**Tile Size**
1. Use a smaller tile size (16x16) for latency-sensitive workloads. Use a bigger one for throughput-sensitive workloads.
2. Smaller SMEM => Smaller tile size to fit in the fused op
3. Higher accuracy due to accumulation over smaller ranges. 

## Benchmarking
Compare across input types and sizes. 

1. Compute throughput: GFLOP/s
2. Data movement: GB/s
3. Wall clock time (milliseconds)
4. GPU-level metrics across the compute and memory hierarchy
    * Compute
    * How many warps are active out of the maximum possible warps - Occupancy
    * Thread Divergence
    * Cache hit rate
    * Register spills
    * Benchmarking: How many simultaneous blocks can run? What is the right balance between blocks and tile size? Does that / should that vary across input sizes and types (large matrix, small matrix, fp8, fp16)?
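
To turn a measured wall clock time into a throughput number, a common convention is to count only the two matmuls ($QK^\top$ and $PV$) at 2 FLOPs per multiply-add, and to halve the count for causal attention. A sketch under that convention; the function names are hypothetical and the 2.0 ms timing is a made-up placeholder, not a measurement:

```python
def attention_fwd_flops(batch: int, heads: int, seq_len: int, head_dim: int, causal: bool = False) -> int:
    """Forward-pass FLOPs, counting only the Q K^T and P V matmuls."""
    flops = 2 * 2 * batch * heads * seq_len * seq_len * head_dim
    return flops // 2 if causal else flops


def tflops_per_second(flops: int, ms: float) -> float:
    """Throughput in TFLOP/s for a run that took `ms` milliseconds."""
    return flops / (ms * 1e-3) / 1e12


flops = attention_fwd_flops(batch=4, heads=16, seq_len=4096, head_dim=64)
print(flops)
print(tflops_per_second(flops, ms=2.0))   # placeholder timing
```

Dividing the result by the GPU's peak TFLOP/s for the chosen dtype gives the FLOPS utilization figure discussed in the Outcome section.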

## Understanding the Tri Dao Repo in detail

## Glossary of Terms
1. Compute Bound vs Memory Bound
2. Arithmetic Intensity 
3. 