In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found beneath it.
input_file_paths = (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_file_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import triton 
import torch 
import triton.language as tl

In [6]:
# Device used as the default `device` argument of test_flash_attn, i.e. where the
# random q/k/v test tensors are allocated.
DEVICE = torch.device("cuda")

In [47]:
@triton.jit 
def __attn_fwd_inner(Q,O,L,M,K_ptr,V_ptr,K_T_offsets,V_offsets,block_index_QO,softmax_scale,stride_K_SEQ_LEN,
                    stride_V_SEQ_LEN,BLOCK_SIZE_QO:tl.constexpr,BLOCK_SIZE_KV:tl.constexpr,
                    DIAGONAL:tl.constexpr,offset_QO_SEQ_LEN,offset_KV_SEQ_LEN,
                    SEQ_LEN:tl.constexpr,HEAD_DIM:tl.constexpr):
    """Run one pass of the online-softmax accumulation over a range of K/V tiles.

    The key range depends on the compile-time flag DIAGONAL:
      * DIAGONAL == False: keys in [0, block_index_QO * BLOCK_SIZE_QO), no causal mask.
      * DIAGONAL == True:  keys in [block_index_QO * BLOCK_SIZE_QO,
        (block_index_QO + 1) * BLOCK_SIZE_QO), with the causal mask
        (query index >= key index) applied to the scores.

    Args:
        Q: query tile; the caller loads it as [BLOCK_SIZE_QO, HEAD_DIM].
        O: running output accumulator; the caller creates it as
            [BLOCK_SIZE_QO, HEAD_DIM] float32.
        L: running per-row sum of exponentials, [BLOCK_SIZE_QO].
        M: running per-row maximum of the scaled scores, [BLOCK_SIZE_QO].
        K_ptr, V_ptr: base pointers that the offsets below are added to.
        K_T_offsets, V_offsets: element offsets of the first K / V tile; the caller
            builds them as [BLOCK_SIZE_KV, HEAD_DIM].
        block_index_QO: index of the query tile being processed.
        softmax_scale: multiplier applied to Q @ K^T. Scores are exponentiated with
            exp2, and the caller pre-multiplies this scale by 1/ln(2).
        stride_K_SEQ_LEN, stride_V_SEQ_LEN: strides used to step K / V offsets by rows.
        BLOCK_SIZE_QO, BLOCK_SIZE_KV: tile sizes (compile-time constants).
        DIAGONAL: selects the key range and masking as described above.
        offset_QO_SEQ_LEN: sequence indices of the query rows in this tile.
        offset_KV_SEQ_LEN: sequence indices of the key rows of the first K/V tile.
        SEQ_LEN: bound used to mask out-of-range key rows on load.
        HEAD_DIM: accepted for signature symmetry; not referenced in this body.

    Returns:
        The updated (O, L, M). O is left un-normalised; the caller divides by L.

    NOTE(review): the loop steps by BLOCK_SIZE_KV but the bounds are multiples of
    BLOCK_SIZE_QO. If BLOCK_SIZE_KV does not divide BLOCK_SIZE_QO, the last tile of
    the DIAGONAL == False pass reaches past `hi` without a causal mask, and the
    multiple_of hint on start_kv does not hold for the DIAGONAL == True pass.
    Confirm callers only use block sizes where BLOCK_SIZE_QO % BLOCK_SIZE_KV == 0.
    """
    if DIAGONAL:
        # Only the key range that overlaps this query tile.
        lo = block_index_QO* BLOCK_SIZE_QO
        hi = (block_index_QO+1)* BLOCK_SIZE_QO

        # Compiler hint: lo is a multiple of BLOCK_SIZE_QO.
        lo = tl.multiple_of(lo,BLOCK_SIZE_QO)
    else:
        # Every key strictly before this query tile.
        lo = 0 
        hi = block_index_QO* BLOCK_SIZE_QO
    # Advance the K/V offsets and the key indices to the first row of the range.
    K_T_offsets+= lo*stride_K_SEQ_LEN
    V_offsets+= lo* stride_V_SEQ_LEN
    offset_KV_SEQ_LEN+= lo 

    for start_kv in range(lo,hi,BLOCK_SIZE_KV):
        # Compiler hint only; start_kv is not used for addressing because the
        # offsets are advanced incrementally at the end of each iteration.
        start_kv = tl.multiple_of(start_kv,BLOCK_SIZE_KV)

        # Key rows at or beyond SEQ_LEN are loaded as zeros.
        mask_KV_SEQ_LEN = offset_KV_SEQ_LEN < SEQ_LEN

        # K is loaded row-major (one key per row) and transposed afterwards so that
        # Q @ K^T can be computed with a single tl.dot.
        K_DATA_T  = tl.load(K_ptr+K_T_offsets,mask = mask_KV_SEQ_LEN[:,None],other=0.0)
        K_DATA_T = tl.trans(K_DATA_T)
        S = tl.dot(Q,K_DATA_T) * softmax_scale

        if DIAGONAL:
            # Causal mask: a query may only attend to keys at the same or an earlier index.
            causal_mask  = offset_QO_SEQ_LEN[:,None] >= offset_KV_SEQ_LEN[None,:]
            S = tl.where(causal_mask,S,float('-inf'))
        # Online softmax: update the running row maximum and shift the scores by it.
        M_new = tl.maximum(M,tl.max(S,axis=1))
        S -= M_new[:,None]
        P = tl.exp2(S) # one row per query, one column per key of this tile
        L_new = tl.sum(P,axis=1)

        # Factor that rescales the previous accumulators from the old maximum to the new one.
        alpha = tl.exp2(M-M_new) # one value per query row
        L = L *alpha + L_new
        V = tl.load(V_ptr+V_offsets,mask =  mask_KV_SEQ_LEN[:,None],other=0.0)
        O = O* alpha[:,None] # alpha[:,None] broadcasts the per-row factor across the O columns
        # Accumulate P @ V on top of the rescaled O.
        O = tl.dot(P,V,acc=O)
        M = M_new
        # Step to the next K/V tile.
        K_T_offsets += BLOCK_SIZE_KV * stride_K_SEQ_LEN
        V_offsets += BLOCK_SIZE_KV * stride_V_SEQ_LEN
        offset_KV_SEQ_LEN += BLOCK_SIZE_KV
    return O,L,M
        
        

@triton.autotune(
    configs=[
        # Only tilings where BLOCK_SIZE_KV divides BLOCK_SIZE_QO are listed. The causal
        # split below uses key ranges bounded by multiples of BLOCK_SIZE_QO but walks
        # them in BLOCK_SIZE_KV steps; a KV tile larger than the query tile would run
        # past the unmasked range and attend to future keys.
        triton.Config({'BLOCK_SIZE_QO': 16, 'BLOCK_SIZE_KV': 16}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 32, 'BLOCK_SIZE_KV': 16}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 32, 'BLOCK_SIZE_KV': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 64, 'BLOCK_SIZE_KV': 16}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 64, 'BLOCK_SIZE_KV': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 64, 'BLOCK_SIZE_KV': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_QO': 128, 'BLOCK_SIZE_KV': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_QO': 128, 'BLOCK_SIZE_KV': 64}, num_stages=4, num_warps=8),
    ],
    key=['SEQ_LEN', 'HEAD_DIM'],
)
@triton.jit
def attn_fwd(Q_PTR,K_PTR,V_PTR,O_PTR,LSE_PTR,softmax_scale,
            stride_Q_BATCH,stride_Q_N_HEADS,stride_Q_SEQ_LEN,stride_Q_HEAD_DIM,
            stride_K_BATCH,stride_K_N_HEADS,stride_K_SEQ_LEN,stride_K_HEAD_DIM,
            stride_V_BATCH,stride_V_N_HEADS,stride_V_SEQ_LEN,stride_V_HEAD_DIM,
            stride_O_BATCH,stride_O_N_HEADS,stride_O_SEQ_LEN,stride_O_HEAD_DIM,
            stride_LSE_BATCH,stride_LSE_N_HEADS,stride_LSE_SEQ_LEN,
            BATCH_SIZE,N_HEADS:tl.constexpr,SEQ_LEN:tl.constexpr,HEAD_DIM:tl.constexpr,
            BLOCK_SIZE_QO:tl.constexpr,BLOCK_SIZE_KV:tl.constexpr 
            ):
    """Causal attention forward kernel: one program per (query tile, batch*head).

    Grid axis 0 selects the query tile and grid axis 1 is the flattened
    batch * N_HEADS index. Each program loads its query tile, folds in all earlier
    key tiles without masking, folds in the key range overlapping the query tile
    with the causal mask, normalises the result, and stores the output rows and
    their log-sum-exp.

    Args:
        Q_PTR, K_PTR, V_PTR, O_PTR: tensors addressed through the four strides
            (batch, head, sequence, head-dim) given for each of them.
        LSE_PTR: per-row log-sum-exp output addressed through the three LSE strides
            (batch, head, sequence).
        softmax_scale: multiplier applied to Q @ K^T before the softmax.
        stride_*: element strides of the corresponding tensor dimensions.
        BATCH_SIZE: batch size (not referenced in the body).
        N_HEADS, SEQ_LEN, HEAD_DIM: problem sizes (compile-time constants).
        BLOCK_SIZE_QO, BLOCK_SIZE_KV: tile sizes chosen by the autotuner.

    The stored LSE is in base-2 units: scores are multiplied by 1/ln(2) and
    exponentiated with exp2, and the stored value is M + log2(L).
    """
    # Fold 1/ln(2) into the scale so the inner loop can use exp2 in place of exp.
    rln2: tl.constexpr = 1.4426950408889634
    softmax_scale *= rln2
    tl.static_assert(BLOCK_SIZE_KV<=HEAD_DIM)
    # The two-pass causal split is only correct when KV tiles evenly divide the query tile.
    tl.static_assert(BLOCK_SIZE_QO % BLOCK_SIZE_KV == 0,
                     "BLOCK_SIZE_QO must be a multiple of BLOCK_SIZE_KV")
    block_index_QO = tl.program_id(0)
    block_head_id = tl.program_id(1)
    batch_index  = block_head_id // N_HEADS
    head_index = block_head_id % N_HEADS

    # Move the base pointers to the (batch, head) slice handled by this program.
    Q_PTR+= batch_index * stride_Q_BATCH + head_index*stride_Q_N_HEADS
    K_PTR+= batch_index * stride_K_BATCH + head_index*stride_K_N_HEADS
    V_PTR+= batch_index * stride_V_BATCH + head_index*stride_V_N_HEADS
    O_PTR+= batch_index * stride_O_BATCH + head_index*stride_O_N_HEADS

    offset_QO_SEQ_LEN = block_index_QO* BLOCK_SIZE_QO + tl.arange(0,BLOCK_SIZE_QO)
    offset_KV_SEQ_LEN  = tl.arange(0,BLOCK_SIZE_KV)
    offset_HEAD_DIM = tl.arange(0,HEAD_DIM)

    Q_offsets  = offset_QO_SEQ_LEN[:,None] * stride_Q_SEQ_LEN + offset_HEAD_DIM[None,:]* stride_Q_HEAD_DIM
    # K is addressed row-major here; the inner loop transposes it after loading.
    K_T_offsets  =  offset_KV_SEQ_LEN[:,None]* stride_K_SEQ_LEN + offset_HEAD_DIM[None,:]* stride_K_HEAD_DIM
    V_offsets  = offset_KV_SEQ_LEN[:,None]* stride_V_SEQ_LEN + offset_HEAD_DIM[None,:]* stride_V_HEAD_DIM
    # Query rows at or beyond SEQ_LEN are loaded as zeros and never stored.
    mask_OQ_SEQ_LEN = offset_QO_SEQ_LEN < SEQ_LEN

    Q_data = tl.load(Q_PTR+Q_offsets,mask = mask_OQ_SEQ_LEN[:,None],other= 0.0)

    # Running maximum, running sum and output accumulator of the online softmax.
    # The initial L is irrelevant: with M = -inf the first rescale factor is exp2(-inf) = 0.
    M = tl.full(shape=[BLOCK_SIZE_QO],value = float("-inf"),dtype=tl.float32)
    L = tl.full(shape=[BLOCK_SIZE_QO],value = float(1),dtype=tl.float32)
    O = tl.zeros([BLOCK_SIZE_QO,HEAD_DIM],dtype=tl.float32)

    # Pass 1: all keys strictly before this query tile, no mask needed.
    O,L,M = __attn_fwd_inner(
        Q_data,O,L,M,K_PTR,V_PTR,K_T_offsets,V_offsets,block_index_QO,softmax_scale,stride_K_SEQ_LEN,stride_V_SEQ_LEN,
        BLOCK_SIZE_QO,BLOCK_SIZE_KV,False,offset_QO_SEQ_LEN,offset_KV_SEQ_LEN,SEQ_LEN,HEAD_DIM
    )

    # Pass 2: the key range overlapping this query tile, with the causal mask.
    O,L,M = __attn_fwd_inner(
        Q_data,O,L,M,K_PTR,V_PTR,K_T_offsets,V_offsets,block_index_QO,softmax_scale,stride_K_SEQ_LEN,stride_V_SEQ_LEN,
        BLOCK_SIZE_QO,BLOCK_SIZE_KV,True,offset_QO_SEQ_LEN,offset_KV_SEQ_LEN,SEQ_LEN,HEAD_DIM
    )

    # Normalise each output row by its softmax denominator.
    O = O / L[:,None]
    LSE = M+ tl.math.log2(L)
    # Address LSE with all three strides so each (batch, head) pair writes its own rows;
    # leaving out the head term would make every head of a batch overwrite head 0.
    LSE_offsets = (batch_index*stride_LSE_BATCH
                   + head_index*stride_LSE_N_HEADS
                   + offset_QO_SEQ_LEN*stride_LSE_SEQ_LEN)
    tl.store(LSE_PTR + LSE_offsets, LSE, mask=mask_OQ_SEQ_LEN)

    O_offsets = offset_QO_SEQ_LEN[:,None] * stride_O_SEQ_LEN + offset_HEAD_DIM[None,:]* stride_O_HEAD_DIM
    tl.store(O_PTR+O_offsets,O,mask=mask_OQ_SEQ_LEN[:,None])




class FlashAttn(torch.autograd.Function):
    """Autograd wrapper that launches the Triton causal-attention forward kernel.

    Only `forward` is defined here. It stashes the tensors and sizes on `ctx`,
    but this class defines no `backward`.
    """

    @staticmethod
    def forward(ctx,q,k,v,scale):
        """Compute causal attention for float32 q, k, v of identical shape.

        Args:
            ctx: autograd context; receives the saved tensors and launch metadata.
            q, k, v: tensors of shape (BS, N_HEADS, SEQ_LEN, HEAD_DIM), all on the
                same device and all float32.
            scale: multiplier applied to Q @ K^T before the softmax.

        Returns:
            Tensor with the same shape, dtype and device as `q`.

        Raises:
            AssertionError: if shapes differ, HEAD_DIM exceeds 128 or is not a power
                of two, devices differ, or any dtype is not float32.
        """
        # The kernel indexes K and V with q's batch, head and sequence extents, so a
        # smaller k or v would be read out of bounds; require identical shapes.
        assert q.shape == k.shape == v.shape, "q, k and v must have identical shapes"
        assert q.shape[-1]<=128 #only support max_head dim of size 128
        # The kernel builds tl.arange(0, HEAD_DIM), which needs a power-of-two extent.
        assert q.shape[-1] > 0 and (q.shape[-1] & (q.shape[-1] - 1)) == 0, \
            "HEAD_DIM must be a power of two"
        assert q.device==k.device==v.device
        assert q.dtype == k.dtype == v.dtype == torch.float32
        BS,N_HEADS,SEQ_LEN,HEAD_DIM = q.shape
        O = torch.empty_like(q) 
        # Per-row log-sum-exp written by the kernel and kept for a backward pass.
        LSE =  torch.empty((BS,N_HEADS,SEQ_LEN),device=q.device,dtype=torch.float32)

        # Axis 0 enumerates query tiles and axis 1 enumerates (batch, head) pairs.
        # The original author chose this order for cache locality: neighbouring
        # programs then work on the same K/V blocks.
        grid = lambda args : (
            triton.cdiv(SEQ_LEN,args['BLOCK_SIZE_QO']),
            BS*N_HEADS
        )

        attn_fwd[grid](
            q,k,v,O,LSE,scale,
            q.stride(0),q.stride(1),q.stride(2),q.stride(3),
            k.stride(0),k.stride(1),k.stride(2),k.stride(3),
            v.stride(0),v.stride(1),v.stride(2),v.stride(3),
            O.stride(0),O.stride(1),O.stride(2),O.stride(3),
            LSE.stride(0),LSE.stride(1),LSE.stride(2),
            BS,N_HEADS,SEQ_LEN,HEAD_DIM,
        )
        ctx.save_for_backward(q,k,v,O,LSE)
        ctx.grid = grid
        ctx.BS = BS
        ctx.N_HEADS = N_HEADS
        ctx.SEQ_LEN = SEQ_LEN 
        ctx.HEAD_DIM = HEAD_DIM
        ctx.scale = scale
        return O
        


In [48]:
# Functional entry point for the custom autograd op.
triton_attention = FlashAttn.apply


def test_flash_attn(BS,N_HEADS,SEQ_LEN,HEAD_DIM,device=DEVICE,atol=5e-3):
    """Check the Triton attention against PyTorch's causal SDPA on random inputs.

    Draws float32 q, k, v (in that order) of shape (BS, N_HEADS, SEQ_LEN, HEAD_DIM)
    on `device`, runs both implementations with scale 1/sqrt(HEAD_DIM), and asserts
    they agree within the absolute tolerance `atol` (relative tolerance 0).

    Returns:
        Tuple (triton_output, reference_output).
    """
    dims = (BS, N_HEADS, SEQ_LEN, HEAD_DIM)
    q, k, v = (
        torch.randn(*dims, device=device, dtype=torch.float32)
        for _ in range(3)
    )
    root_head_dim = HEAD_DIM**0.5
    scale = 1/(root_head_dim)

    tri_out = triton_attention(q, k, v, scale)
    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    triton.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
    return tri_out, ref_out
    
    

In [49]:
# Smoke test: a single batch and a single head, 128 tokens, head dimension 32.
test_flash_attn(BS=1, N_HEADS=1, SEQ_LEN=128, HEAD_DIM=32)

(tensor([[[[ 1.7829e-01, -1.5816e-01,  1.2091e-01,  ...,  2.4630e+00,
             7.9585e-01, -2.1501e-01],
           [-1.9836e-01, -1.0919e-03,  2.2382e-01,  ...,  2.2296e+00,
             1.3938e-01, -2.9992e-01],
           [-2.1720e-01, -1.1830e-01, -1.6409e-02,  ...,  1.7270e+00,
            -3.0924e-01,  2.0510e-01],
           ...,
           [-7.1058e-02,  2.1507e-01,  3.2266e-01,  ...,  5.2523e-02,
             9.8962e-03,  1.8218e-01],
           [ 6.1470e-02,  9.7897e-02,  2.0129e-01,  ...,  9.8208e-02,
             1.2815e-02,  5.8018e-02],
           [ 2.0506e-01,  7.9163e-02,  1.3703e-01,  ...,  7.2430e-02,
             3.2650e-02, -1.7212e-01]]]], device='cuda:0'),
 tensor([[[[ 1.7829e-01, -1.5816e-01,  1.2091e-01,  ...,  2.4630e+00,
             7.9585e-01, -2.1501e-01],
           [-1.9836e-01, -1.0919e-03,  2.2382e-01,  ...,  2.2296e+00,
             1.3938e-01, -2.9992e-01],
           [-2.1720e-01, -1.1830e-01, -1.6409e-02,  ...,  1.7270e+00,
            -3.0924e-