In [1]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.runtime import driver

In [11]:
# Pin all tensors/kernels in this notebook to the CUDA device PyTorch currently
# treats as active. There is deliberately no CPU fallback: the Triton kernels
# below need a GPU.
DEVICE = torch.device("cuda:" + str(torch.cuda.current_device()))

In [13]:
# Ask the active Triton backend for this GPU's hardware limits. It returns a dict
# with keys such as max_shared_mem, max_num_regs, multiprocessor_count and
# warpSize. The walrus both binds `properties` and leaves it as the cell's last
# expression so the dict is displayed.
(properties := driver.active.utils.get_device_properties(DEVICE.index))

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def relu_kernel(input_ptr, output_ptr, batch_stride: int, seq_stride: int,
                B: int, L: int, H: int,
                BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    """Element-wise ReLU, ``out = max(x, 0)``, over one (batch, seq) row of H values.

    Each program handles the row selected by ``program_id(0)`` (batch index) and
    ``program_id(1)`` (sequence index). The kernel is therefore meant to be
    launched on a 2-D grid such as ``(B, L)``. The row is walked in chunks of
    ``BLOCK_SIZE`` elements, with a mask covering the ragged tail.

    Args:
        input_ptr: pointer to the first element of the input tensor.
        output_ptr: pointer to the first element of the output tensor.
        batch_stride: stride, in elements, between consecutive batch indices.
        seq_stride: stride, in elements, between consecutive sequence indices.
        B, L: presumably the batch and sequence sizes. They are not read here;
            the launch grid is what bounds the program ids.
        H: number of elements in each row.
        BLOCK_SIZE: number of elements processed per loop iteration.
        num_stages: software-pipelining depth forwarded to ``tl.range``.

    Assumptions (not checked in the kernel -- TODO confirm at the call site):
        * the innermost (H) dimension is contiguous (stride 1) in both tensors;
        * input and output share the same batch/seq strides, e.g. the output
          was allocated with ``torch.empty_like`` on a contiguous input.
    """
    # Promote the program ids to int64. Otherwise batch_idx * batch_stride is
    # evaluated in 32 bits and overflows once the offset exceeds 2**31 - 1.
    batch_idx = tl.program_id(0).to(tl.int64)
    seq_idx = tl.program_id(1).to(tl.int64)

    base_idx = batch_idx * batch_stride + seq_idx * seq_stride
    input_start_ptr = input_ptr + base_idx
    output_start_ptr = output_ptr + base_idx

    offs = tl.arange(0, BLOCK_SIZE)
    # Forward num_stages so the loop's loads can be software-pipelined.
    for k in tl.range(0, tl.cdiv(H, BLOCK_SIZE), num_stages=num_stages):
        offs_k = k * BLOCK_SIZE + offs
        mask = offs_k < H
        x = tl.load(input_start_ptr + offs_k, mask=mask, other=0.0)
        # tl.maximum is element-wise; tl.max would *reduce* the block to a scalar.
        # PropagateNan.ALL keeps NaN inputs as NaN, matching torch.relu.
        out = tl.maximum(x, 0.0, propagate_nan=tl.PropagateNan.ALL)
        tl.store(output_start_ptr + offs_k, out, mask=mask)

    