In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.runtime import driver

In [4]:
# torch.device for the CUDA GPU that PyTorch currently has selected; its
# `.index` is what the property query below is keyed on.
DEVICE = torch.device("cuda", torch.cuda.current_device())

In [6]:
# Query the active GPU's hardware limits through Triton's driver layer. The bare
# expression on the next line makes the notebook display the returned dict
# (shared-memory bytes, register count, SM count, warp size, clocks, bus width).
properties = driver.active.utils.get_device_properties(DEVICE.index)
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def attention_kernel(Q_ptr: torch.Tensor, 
                     K_ptr: torch.Tensor, 
                     V_ptr: torch.Tensor, 
                     output_ptr: torch.Tensor,
                     batch_stride: int, seq_stride: int,
                     head_stride: int,
                     B: int, L: int, heads: int, d_k: int, 
                     BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    """Compute one BLOCK_SIZE x BLOCK_SIZE tile of the scaled scores Q @ K^T / sqrt(d_k).

    Work in progress: the tile is accumulated and scaled, but it is not yet
    softmax-normalised, multiplied by V, or written to ``output_ptr``.

    Grid layout:
        axis 0 -> flattened (batch, head) pair, ``batch_idx * heads + head_idx``
        axis 1 -> ``seq_idx``
        axis 2 -> block of query rows covered by this program

    Addressing (as implied by the pointer arithmetic): Q, K and V share the same
    strides. Inside one (batch, head, seq_idx) slice, element (row, d) lives at
    ``row * d_k + d``, i.e. row-major with the feature dim contiguous.
    NOTE(review): ``seq_idx * seq_stride`` is added on top of the ``row * d_k``
    row offset. Confirm this is intended (e.g. grid axis 1 has size 1);
    otherwise rows are offset twice.

    Args:
        Q_ptr, K_ptr, V_ptr: base pointers of the query/key/value tensors
            (V_ptr is not read yet).
        output_ptr: destination for the attention output (not written yet).
        batch_stride, seq_stride, head_stride: element strides shared by Q, K, V.
        B: unused by the kernel so far (presumably the batch size).
        L: sequence length; bounds the query rows and key columns of the tile.
        heads: number of heads; used to split program_id(0).
        d_k: head dimension, i.e. the reduction length of Q @ K^T.
        BLOCK_SIZE: edge length of the square tile and of each reduction step.
        num_stages: software-pipelining depth passed to ``tl.range``.
    """
    pid_bh = tl.program_id(0)
    batch_idx = pid_bh // heads
    head_idx = pid_bh % heads
    seq_idx = tl.program_id(1)
    pid_block = tl.program_id(2)

    # Q, K and V are addressed with identical strides, so one base offset serves all three.
    base_idx = batch_idx * batch_stride + head_idx * head_stride + seq_idx * seq_stride

    offs_q = pid_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # query rows of this tile
    offs_key = tl.arange(0, BLOCK_SIZE)                         # key rows (first key block only)
    offs_d = tl.arange(0, BLOCK_SIZE)                           # feature offsets within one reduction step

    # Q tile is [query, d]; K^T tile is [d, key]. The feature axis is contiguous in
    # both, so a key row is d_k elements long, matching Q's row stride.
    q_ptrs = Q_ptr + base_idx + offs_q[:, None] * d_k + offs_d[None, :]
    kt_ptrs = K_ptr + base_idx + offs_d[:, None] + offs_key[None, :] * d_k

    # Keep partial tiles at the end of the sequence from reading out of bounds.
    q_row_ok = offs_q < L
    key_col_ok = offs_key < L

    out = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    num_d_blocks = tl.cdiv(d_k, BLOCK_SIZE)
    for d_blk in tl.range(0, num_d_blocks, num_stages=num_stages):
        d_start = d_blk * BLOCK_SIZE
        # The reduction bound applies to Q's columns and to K^T's rows.
        d_ok = (d_start + offs_d) < d_k
        q = tl.load(q_ptrs + d_start, mask=q_row_ok[:, None] & d_ok[None, :], other=0.0)
        k_t = tl.load(kt_ptrs + d_start, mask=d_ok[:, None] & key_col_ok[None, :], other=0.0)
        out += tl.dot(q, k_t)

    # tl.rsqrt accepts only floating-point inputs; d_k is an integer argument.
    out *= tl.rsqrt(d_k * 1.0)
    # TODO: softmax over keys, iterate over the remaining key blocks, multiply by V,
    # and store the result to output_ptr.

    
    