In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.runtime import driver

In [5]:
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")

In [6]:
properties = driver.active.utils.get_device_properties(DEVICE.index)
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def attention_kernel(Q_ptr: torch.Tensor, 
                     K_ptr: torch.Tensor, 
                     V_ptr: torch.Tensor, 
                     output_ptr: torch.Tensor,
                     batch_stride: int, seq_stride: int,
                     head_stride: int,
                     B: int, L: int, heads: int, d_k: int,
                     MAX_D_K: tl.constexpr, 
                     BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    
    pid_BH = tl.program_id(0)
    batch_idx = pid_BH // heads
    head_idx = pid_BH % heads
    
    query_block_idx = tl.program_id(1)
    query_start = query_block_idx * BLOCK_SIZE
    
    num_blocks = tl.cdiv(d_k, BLOCK_SIZE)
    offs_row = tl.arange(0, BLOCK_SIZE)
    offs_col = tl.arange(0, MAX_D_K)
    q_block_ptr = Q_ptr + batch_idx * batch_stride + head_idx * head_stride + query_start * seq_stride
    k_base_ptr = K_ptr + batch_idx * batch_stride + head_idx * head_stride
    v_base_ptr = V_ptr + batch_idx * batch_stride + head_idx * head_stride
    Q_start = q_block_ptr + offs_row[:, None] * MAX_D_K + offs_col[None, :]
    
    q = tl.load(Q_start, 
                mask=(query_start + tl.arange(0, BLOCK_SIZE)[:, None] < L) & (offs_col[None, :] < d_k), 
                other=0.0) # (BLOCK_SIZE, d_k)
    
    
    m_i = tl.full((BLOCK_SIZE,), -float('inf'), tl.float32)
    l_i = tl.zeros((BLOCK_SIZE,), tl.float32)
    acc = tl.zeros((BLOCK_SIZE, d_k), tl.float32)
    for key_block in tl.range(0, L, BLOCK_SIZE, num_stages=num_stages):
        mask = (offs_row  + k * BLOCK_SIZE < d_k)[None, :]
        K_tile_ptrs = k_base_ptr + key_block * seq_stride
        V_tile_ptrs = v_base_ptr + key_block * seq_stride
        mask = (key_block + offs_row[None, :]) < L & (offs_col[:, None] < d_k)
        k_t = tl.load(K_tile_ptrs + offs_row[None, :] * MAX_D_K + offs_col[:, None], 
                      mask=k_mask, other=0.0) # (d_k, BLOCK_SIZE)
        v = tl.load(V_tile_ptrs + offs_row[:, None] * d_k + offs_col[None, :], 
                   mask=mask, other=0.0) # (BLOCK_SIZE, d_k)
        scores = tl.dot(q, k_t)
    
        scores *= tl.rsqrt(d_k * 1.0) # (BLOCK_SIZE, BLOCK_SIZE)
    
        m_new = tl.maximum(m_i, tl.max(scores, axis=1)) # (BLOCK_SIZE)
        exp_scores = tl.exp(scores - m_new[:, None]) # (BLOCK_SIZE, BLOCK_SIZE)
        l_new = l_i * tl.exp(m_i - m_new) + tl.sum(exp_scores, axis=1) # (BLOCK_SIZE)
    
        acc = acc *  tl.exp(m_i - m_new) [:, None] + tl.dot(exp_scores, v) # (BLOCK_SIZE, d_k) 
        m_i = m_new
        l_i = l_new 
    
    
    out = acc / l_i[:, None]
    out_base_idx = batch_idx * batch_stride + head_idx * head_stride + query_start * seq_stride
    out_ptrs = out_ptr + out_base_idx + offs_row[:, None] * d_k + offs_col[None, :]
    out_mask = offs_row[:, None] < (L - query_start)
    tl.store(out_ptrs, out, mask=out_mask)
    
    
    

    
    

In [31]:
tl.transpose?

Object `tl.transpose` not found.


In [None]:
def attention_triton(q, k, v):
    """
    Compute attention using Triton kernel
    Args:
        q: (batch_size, num_heads, seq_len, d_k)
        k: (batch_size, num_heads, seq_len, d_k)
        v: (batch_size, num_heads, seq_len, d_k)
    Returns:
        output: (batch_size, num_heads, seq_len, d_k)
    """
    batch_size, num_heads, seq_len, d_k = q.shape
    
    # Output tensor
    output = torch.empty_like(q)
    
    # Calculate strides
    batch_stride = num_heads * seq_len * d_k
    head_stride = seq_len * d_k
    seq_stride = d_k
    
    # Block size configuration
    BLOCK_SIZE = 32  # Adjust based on your GPU architecture
    num_stages = 3
    
    # Grid configuration
    grid = (batch_size * num_heads, triton.cdiv(L, BLOCK_SIZE))
    
    # Launch kernel
    attention_kernel[grid](
        q, k, v, output,
        batch_stride, seq_stride, head_stride,
        batch_size, seq_len, num_heads, d_k,
        BLOCK_SIZE, num_stages
    )
    
    return output



In [33]:
# Test parameters
batch_size = 2
num_heads = 4
seq_len = 16
d_k = 64

# Generate random data
torch.manual_seed(0)  # For reproducibility
q = torch.randn(batch_size, num_heads, seq_len, d_k, device="cuda", dtype=torch.float32)
k = torch.randn(batch_size, num_heads, seq_len, d_k, device="cuda", dtype=torch.float32)
v = torch.randn(batch_size, num_heads, seq_len, d_k, device="cuda", dtype=torch.float32)

# Compute attention using Triton kernel
output_triton = attention_triton(q, k, v)

CompilationError: at 19:15:
                     BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):

    pid_BH = tl.program_id(0)
    batch_idx = pid_BH // heads
    head_idx = pid_BH % heads

    query_block_idx = tl.program_id(1)
    query_start = query_block_idx * BLOCK_SIZE

    num_blocks = tl.cdiv(d_k, BLOCK_SIZE)
    offs_row = tl.arange(0, BLOCK_SIZE)
    offs_col = tl.arange(0, d_k)
               ^

In [17]:
tl.rsqrt??

[0;31mSignature:[0m [0mtl[0m[0;34m.[0m[0mrsqrt[0m[0;34m([0m[0mx[0m[0;34m,[0m [0m_builder[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Computes the element-wise inverse square root of :code:`x`.

:param x: the input values
:type x: Block
[0;31mSource:[0m   
[0;34m@[0m[0mcore[0m[0;34m.[0m[0mbuiltin[0m[0;34m[0m
[0;34m[0m[0;34m@[0m[0m_check_dtype[0m[0;34m([0m[0mdtypes[0m[0;34m=[0m[0;34m[[0m[0;34m"fp32"[0m[0;34m,[0m [0;34m"fp64"[0m[0;34m][0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m@[0m[0m_add_math_1arg_docstr[0m[0;34m([0m[0;34m"inverse square root"[0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m@[0m[0mcore[0m[0;34m.[0m[0m_tensor_member_fn[0m[0;34m[0m
[0;34m[0m[0;32mdef[0m [0mrsqrt[0m[0;34m([0m[0mx[0m[0;34m,[0m [0m_builder[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mx[0m [0;34m=[0m [0mcore[0m[0;34m.[0m[0m_to_tensor[0m[0;3