In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily enumerate every file below the Kaggle input directory and print its full path.
input_file_paths = (
    os.path.join(folder, file_name)
    for folder, _, file_names in os.walk('/kaggle/input')
    for file_name in file_names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import torch
import triton
import triton.language as tl 

In [6]:
@triton.jit
def attn_impl(Q,K,V,softmax_scale,M,O,stride_Q_batch,
            stride_Q_head,
            stride_Q_seq,
            stride_Q_dim,
            stride_K_batch,
            stride_K_head,
            stride_K_seq,
            stride_K_dim,
            stride_V_batch,
            stride_V_head,
            stride_V_seq,
            stride_V_dim,
            stride_O_batch,
            stride_O_head,
            stride_O_seq,
            stride_O_dim,
            BATCH_SIZE,
            NUM_HEADS: tl.constexpr,
            SEQ_LEN: tl.constexpr,
            HEAD_DIM: tl.constexpr,
            BLOCK_SIZE_Q: tl.constexpr,
            BLOCK_SIZE_KV: tl.constexpr,
            STAGE: tl.constexpr,
               ):
    """
    Forward pass of blocked (flash-style) attention with an online softmax.

    Each program handles one block of BLOCK_SIZE_Q queries for one (batch, head)
    pair and streams over the keys/values in blocks of BLOCK_SIZE_KV.

    Q, K, V, O - pointers to tensors indexed as (batch, head, seq, head_dim) through
                 the stride_* arguments; each tensor is addressed with its own strides.
    softmax_scale - scalar multiplied into Q @ K^T before the softmax.
    M - pointer that receives the per-query log-sum-exp (running max + log of the
        running sum). No strides are passed for M, so it is addressed as a contiguous
        (BATCH_SIZE, NUM_HEADS, SEQ_LEN) buffer - the caller must allocate it that way.
    BATCH_SIZE : number of sequences (not used for indexing inside the kernel).
    NUM_HEADS : number of heads; used to split program_id(1) into (batch, head).
    SEQ_LEN : number of tokens.
    HEAD_DIM : per-head feature size.
    BLOCK_SIZE_Q : number of query rows handled by one program.
    BLOCK_SIZE_KV : number of key/value rows loaded per loop iteration.
    STAGE : 3 selects causal attention (key index <= query index); any other value
            attends to every key.

    The launch grid is expected to be (cdiv(SEQ_LEN, BLOCK_SIZE_Q), BATCH_SIZE * NUM_HEADS).
    SEQ_LEN does not need to be a multiple of the block sizes: out-of-range rows are
    zero-padded on load, out-of-range keys are masked out and out-of-range query rows
    are not stored.
    """
    block_index_q = tl.program_id(0)      # which block of queries this program owns
    index_batch_head = tl.program_id(1)   # flattened (batch, head) index
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS

    # 64-bit offsets so large tensors cannot overflow the pointer arithmetic.
    batch64 = index_batch.to(tl.int64)
    head64 = index_head.to(tl.int64)
    # Every tensor gets its own (batch, head) offset: reusing Q's strides for K/V/O is
    # only correct when all four tensors happen to share one memory layout.
    q_offset = batch64 * stride_Q_batch + head64 * stride_Q_head
    k_offset = batch64 * stride_K_batch + head64 * stride_K_head
    v_offset = batch64 * stride_V_batch + head64 * stride_V_head
    o_offset = batch64 * stride_O_batch + head64 * stride_O_head

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0),
    )
    # K is read transposed (HEAD_DIM, SEQ_LEN) by swapping its strides, so Q @ K^T is a plain dot.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + o_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)  # absolute query indices
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)                               # key indices within one KV block

    # Online-softmax state. Kept in float32: tl.dot / tl.maximum produce float32 values,
    # so float16 accumulators would both lose precision and change type inside the loop.
    m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float('inf')  # running row maximum
    l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32)                 # running row sum of exp
    O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)   # un-normalised output

    Q_block = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")

    # Causal: keys after the last query of this block can never be attended to, so stop there.
    # Non-causal: every key is visited.
    if STAGE == 3:
        kv_end = tl.minimum((block_index_q + 1) * BLOCK_SIZE_Q, SEQ_LEN)
    else:
        kv_end = SEQ_LEN

    for start_kv in range(0, kv_end, BLOCK_SIZE_KV):
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)  # alignment hint for the compiler
        kv_idx = start_kv + offs_kv                         # absolute key indices of this block

        K_block = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
        QK_block = tl.dot(Q_block, K_block) * softmax_scale

        if SEQ_LEN % BLOCK_SIZE_KV != 0:
            # The last KV block is partially out of range: padded keys must not get any weight.
            QK_block = tl.where(kv_idx[None, :] < SEQ_LEN, QK_block, float('-inf'))
        if STAGE == 3:
            # A query may only attend to keys at the same or an earlier position.
            QK_block = tl.where(offs_q[:, None] >= kv_idx[None, :], QK_block, float('-inf'))

        # Key 0 is visible to every query in the first block, so m_ij is finite from the
        # first iteration on and exp() below never sees (-inf) - (-inf).
        m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
        P_block = tl.exp(QK_block - m_ij[:, None])
        l_ij = tl.sum(P_block, 1)

        alpha = tl.exp(m_i - m_ij)  # rescales what was accumulated under the previous maximum
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        V_block = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
        O_block = O_block * alpha[:, None]
        # tl.dot needs both operands in the same dtype; accumulation stays in float32.
        O_block = tl.dot(P_block.to(V_block.dtype), V_block, O_block)

        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))

    # l_i >= 1 for every row (the row maximum contributes exp(0)), so no epsilon is needed.
    O_block = O_block / l_i[:, None]

    # Log-sum-exp per query row, stored for later use by a backward pass.
    m_ptrs = M + index_batch_head.to(tl.int64) * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i + tl.log(l_i), mask=offs_q < SEQ_LEN)

    tl.store(O_block_ptr, O_block.to(O.type.element_ty), boundary_check=(0,))

    
    

In [11]:
class TritonAttention(torch.autograd.Function):
    """Autograd wrapper that launches the `attn_impl` Triton kernel.

    Only the forward pass is implemented; `backward` raises NotImplementedError.
    """

    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        """Compute attention output for Q, K, V.

        Args:
            ctx: autograd context; receives the saved tensors and launch metadata.
            Q, K, V: tensors laid out as (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM),
                all with the same shape.
            causal: truthy to apply a causal mask (kernel STAGE 3), falsy for full
                attention (kernel STAGE 1).
            softmax_scale: scalar multiplied into Q @ K^T before the softmax.

        Returns:
            Tensor with the same shape, dtype and device as Q.
        """
        assert Q.shape[-1]==V.shape[-1]==K.shape[-1]
        assert Q.dim() == 4, "expected Q of shape (batch, heads, seq_len, head_dim)"
        # The kernel reads K and V with Q's sequence length, so the shapes must agree.
        assert Q.shape == K.shape == V.shape, "Q, K and V must have identical shapes"

        # Layout is (batch, heads, seq, dim) - the same order the kernel strides below assume.
        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

        O = torch.empty_like(Q)  # output has the same shape as Q
        stage = 3 if causal else 1

        # One program per (query block, batch*head) pair.
        grid = lambda args: (
            triton.cdiv(SEQ_LEN, args['BLOCK_SIZE_Q']),  # which block of queries
            BATCH_SIZE * NUM_HEADS,                      # which head of which batch element
            1,
        )

        # Per-query softmax statistics buffer handed to the kernel and saved for backward.
        M = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
        )

        # Clamp to [16, 64]: the lower bound keeps the blocks large enough for tl.dot.
        BLOCK_SIZE_Q = max(16, min(64, SEQ_LEN))
        BLOCK_SIZE_KV = max(16, min(64, SEQ_LEN))

        attn_impl[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            M=M,  # statistics buffer, not a mask
            O=O,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=BATCH_SIZE,
            NUM_HEADS=NUM_HEADS,
            SEQ_LEN=SEQ_LEN,
            HEAD_DIM=HEAD_DIM,
            STAGE=stage,
            BLOCK_SIZE_Q=BLOCK_SIZE_Q,
            BLOCK_SIZE_KV=BLOCK_SIZE_KV,
        )

        ctx.save_for_backward(Q, K, V, O, M)
        ctx.grid = grid
        ctx.softmax_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_DIM
        ctx.causal = causal

        return O

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass placeholder; fails loudly instead of silently returning no gradients."""
        raise NotImplementedError("TritonAttention backward pass is not implemented yet")

In [12]:
def test_op(BATCH_SIZE,SEQ_LEN,NUM_HEADS,HEAD_DIM,causal,dtype=torch.float32):
    """Compare TritonAttention against a plain PyTorch attention reference on CUDA.

    Args:
        BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM: sizes of the random inputs, which are
            created with shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
        causal: truthy to test causal attention, falsy for full attention.
        dtype: dtype in which the random values are sampled; the tensors are converted
            to float16 before being used.

    Raises:
        AssertionError: if the two outputs differ by more than 1e-2 in absolute value.
    """
    input_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)

    def _random_input():
        # Sample in `dtype`, mark as requiring grad, then convert to float16.
        return (
            torch.empty(input_shape, dtype=dtype, device="cuda")
            .normal_(mean=0.0, std=0.5)
            .requires_grad_()
        ).to(torch.float16)

    Q = _random_input()
    K = _random_input()
    V = _random_input()

    softmax_scale = 1 / (HEAD_DIM**0.5)

    # Reference implementation.
    scores = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    if causal:
        visible = torch.tril(
            torch.ones((SEQ_LEN, SEQ_LEN), device="cuda", dtype=torch.bool)
        )
        scores = scores.masked_fill(~visible, float('-inf'))
    # Softmax in float32 so the reference itself is not limited by float16 rounding.
    probs = torch.softmax(scores.float(), dim=-1).to(V.dtype)
    naive_out = torch.matmul(probs, V).half()

    triton_out = TritonAttention.apply(Q,K,V,causal,softmax_scale).to(torch.float16)

    # assert_close raises on mismatch and returns None, so it must not be wrapped in `assert`
    # (that would fail even when the outputs match).
    torch.testing.assert_close(triton_out, naive_out, atol=1e-2, rtol=0)


In [None]:
# Smoke test: one batch, two heads, 32 tokens of width 32, causal attention.
test_op(
    BATCH_SIZE=1,
    SEQ_LEN=32,
    NUM_HEADS=2,
    HEAD_DIM=32,
    causal=1,
)