# Flash Attention V3 Deepdive

Dao et al. introduced FlashAttention, a novel tiling strategy for parallelizing attention that eliminates intermediate
reads/writes to slow global memory by fusing all of the attention operations into a single GPU kernel. Dao
[15] restructured the algorithm as FlashAttention-2 to also parallelize over the sequence length dimension and
perform the inner loop of the forward pass over blocks of the key and value matrices, thus improving the occupancy
and distribution of work on the GPU.

$Q[k, :]$  --> the $k$-th row of $Q$

## Python Autograd API
**FWD PASS SETUP & LAUNCH**
1. Asserts: Check if the shapes of Q, K, V tensors are good. Check if values are the type and shape they should be
2. Allocate Input and output buffers.
3. Decide how to parallelize the workload: grids, warps, compute units, numblocks, block size w.r.t the workload under consideration
4. Launch the FWD Triton kernel using the `kernel[grid](...)` syntax
5. Save CTX needed for BWD

In [None]:
import torch
import triton
import triton.language as tl

def is_hip() -> bool:
    """Return True when Triton's active driver targets the HIP backend.

    Reads ``triton.runtime.driver.active.get_current_target().backend`` and
    compares it against the string ``'hip'``.
    """
    # The equality test already yields a bool, so no conditional expression is needed.
    return triton.runtime.driver.active.get_current_target().backend == 'hip'


class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention forward kernel.

    ``forward`` validates the inputs, allocates the output buffers, picks the
    launch grid and launches ``_attention_fwd``. ``backward`` is not
    implemented yet.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal: bool, softmax_scale: float):
        """Launch the forward kernel and stash what the backward pass needs.

        Args:
            ctx: autograd context; tensors are stored on it with
                ``ctx.save_for_backward(...)`` and scalars as attributes.
            q, k, v: tensors indexed as [BATCH, NUM_HEADS, SEQ_LEN, HEAD_EMB_DIM]
                (``shape[2]`` is used as the sequence length, ``shape[-1]`` as
                the head size).
            causal: selects ``STAGE=3`` (causal) instead of ``STAGE=1``.
            softmax_scale: scalar forwarded unchanged to the kernel.

        Returns:
            The output tensor ``o``, allocated with the same shape/dtype as ``q``.

        Raises:
            AssertionError: if the head sizes of q, k, v differ or are not one
                of 16, 32, 64, 128, 256.
        """
        ### ------ Check Inputs --------------
        # The head size is the last dimension of each tensor.
        HEAD_EMB_DIM_Q, HEAD_EMB_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_EMB_DIM_V = v.shape[-1]

        assert HEAD_EMB_DIM_Q == HEAD_EMB_DIM_K == HEAD_EMB_DIM_V
        # Open question from the author: why are only these head sizes supported?
        assert HEAD_EMB_DIM_K in {16, 32, 64, 128, 256}, "Only head size of 16, 32, 64, 128, 256 are supported"

        ### ------- Allocate Buffers ----------
        # One output row per query row; V's head size equals Q's (asserted
        # above), so the output can take q's shape.
        o = torch.empty_like(q)
        stage = 3 if causal else 1   # causal requires special considerations such as masking
        # float32 buffer with one value per (batch, head, query position), saved for BWD.
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        ### ------ Grid, Warps, EUs, Blocks etc ---------
        # Author's vocabulary map: SM == EU; WARPS (32 threads) == WAVEFRONTS (64 threads);
        # OCCUPANCY == WAVES_PER_EU == ACTIVE WARPS PER SM.
        # Open question: how does this affect perf (e.g. register pressure)?
        extra_kern_args = {}  # must exist on every backend because it is splatted into the launch
        if is_hip():
            waves_per_eu = 3 if HEAD_EMB_DIM_K <= 64 else 2  # open question: why this split?
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        # Tile sizes handed to the kernel's BLOCK_M / BLOCK_N constexpr parameters.
        # NOTE(review): fixed defaults; drop these kwargs if the kernel is later
        # wrapped in triton.autotune with BLOCK_M / BLOCK_N as tuned keys.
        BLOCK_M = 128
        BLOCK_N = 64

        # The grid can be 3D.
        # BATCH * NUM_HEADS is effectively the batch dimension, so the work is
        # parallelized across SEQ_LEN tiles (x) and BATCH * NUM_HEADS (y).
        # Example: SEQ_LEN=1024 and BLOCK_M=128 gives 8 programs along x, each
        # handling BLOCK_M positions of the sequence.
        x_grid = triton.cdiv(q.shape[2], BLOCK_M)
        y_grid = q.shape[0] * q.shape[1]
        z_grid = 1
        grid = (x_grid, y_grid, z_grid)

        # `kernel[grid](...)` is Triton's launch syntax: the subscript selects the
        # launch grid and the call supplies the kernel arguments.
        _attention_fwd[grid](
            q,
            k,
            v,
            softmax_scale,
            M,
            o,
            # Batch, Head, Seq, Emb strides so the kernel can address each tensor.
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_EMB_DIM_K,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            STAGE=stage,
            **extra_kern_args
        )

        ### Save context for BWD
        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_EMB_DIM_K
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass (not implemented yet).

        Hand-written backward passes are easy to get wrong. Once implemented,
        ``torch.autograd.gradcheck`` can validate it by comparing the analytical
        gradients against central finite differences, [f(a+h) - f(a-h)] / 2h.

        Raises:
            NotImplementedError: always, until the backward kernel is written.
        """
        raise NotImplementedError("_attention.backward is not implemented yet")

# Functional entry point: `attention(q, k, v, causal, softmax_scale)` runs
# `_attention.forward` through `torch.autograd.Function.apply`.
attention = _attention.apply

# python api
def flash_attention_v2(batch_size: int, num_heads: int, seq_len: int, emb_dim: int, causal: bool, mode: str) -> float:
    """Benchmark ``attention`` on random inputs and return the measured time.

    Args:
        batch_size, num_heads, seq_len, emb_dim: q, k and v are each created
            with shape ``(batch_size, num_heads, seq_len, emb_dim)``.
        causal: forwarded to ``attention``.
        mode: when equal to ``'fwd'``, q and k are cast to ``torch.float8_e5m2``
            before benchmarking; any other value leaves them in float16.

    Returns:
        The value reported by ``triton.testing.do_bench`` (milliseconds).
    """
    dtype = torch.float16
    device = 'cuda'
    shape = (batch_size, num_heads, seq_len, emb_dim)

    q = torch.randn(shape, dtype=dtype, device=device)
    k = torch.randn(shape, dtype=dtype, device=device)
    v = torch.randn(shape, dtype=dtype, device=device)

    if mode == 'fwd':
        # Open questions from the author: why e5m2 for the FWD pass, and what is `_nuz`?
        # NOTE(review): v stays float16 while q and k become float8_e5m2 —
        # confirm the kernel is meant to handle this dtype mix.
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)

    softmax_scale = 1.3

    # do_bench needs a zero-argument callable that it can invoke repeatedly.
    def fn():
        return attention(q, k, v, causal, softmax_scale)

    # Previously the timing was computed and dropped; return it to the caller.
    return triton.testing.do_bench(fn)

    

def main():
    """Run the forward attention benchmark for one fixed problem size."""
    BATCH_SIZE = 4
    NUM_HEADS = 16
    SEQ_LEN = 4096
    EMB_DIM = 1024
    # flash_attention_v2 uses its `emb_dim` argument as the last (per-head)
    # tensor dimension, and the forward pass only accepts head sizes up to 256,
    # so pass the per-head slice of the embedding (1024 // 16 == 64).
    HEAD_DIM = EMB_DIM // NUM_HEADS

    # We have many variations to support
    # 1. QKV - Standard MHA. Most efficient since the tensors are packed and a giant matmul can be launched.
    # 2. QK and V: useful for GQA
    # 3. Q and K and V
    flash_attention_v2(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal=True, mode='fwd')

## Triton Kernel

In [None]:
@triton.jit
def _attention_fwd(
    q, k, v,
    softmax_scale,
    M,
    output,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr
):
    """Forward flash-attention kernel (signature only; body not written yet).

    The parameter order mirrors the launch in ``_attention.forward``:

    * ``q, k, v`` - input tensors.
    * ``softmax_scale`` - scalar passed through from the Python API.
    * ``M`` - float32 buffer allocated by the launcher with one entry per
      (batch, head, query position); it is saved for the backward pass.
    * ``output`` - buffer allocated with ``torch.empty_like(q)``.
    * ``stride_*`` - the four ``.stride(0..3)`` values of q, k, v and the output.
    * ``Z, H, N_CTX`` - ``q.shape[0]``, ``q.shape[1]`` and ``q.shape[2]``.
    * ``HEAD_DIM`` - size of the last dimension of q/k/v.
    * ``BLOCK_M`` - tile size used by the launcher to split ``N_CTX`` across programs.
    * ``BLOCK_N`` - presumably the key/value tile size of the inner loop; TODO confirm.
    * ``STAGE`` - 3 for causal attention, 1 otherwise.
    """
    # TODO: implement the forward pass. Until then, launching this kernel leaves
    # `output` and `M` uninitialized.
    

**BWD PASS SETUP & LAUNCH**
1. Asserts: Check if the shapes of Q, K, V tensors are good. Check if values are the type and shape they should be
2. Allocate Input and output buffers.
3. Decide how to parallelize the workload: grids, warps, compute units, numblocks, block size w.r.t the workload under consideration
4. Launch the BWD Triton kernel(s) using the `kernel[grid](...)` syntax
5. Read back the CTX saved during FWD and return one gradient per FWD input