In [7]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
from pprint import pprint, pformat
import torch.nn.functional as F

In [3]:
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')
properties = driver.active.utils.get_device_properties(DEVICE.index)
pprint(f"{properties=}", underscore_numbers=True)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}

("properties={'max_shared_mem': 101376, 'max_num_regs': 65536, "
 "'multiprocessor_count': 64, 'warpSize': 32, 'sm_clock_rate': 1695000, "
 "'mem_clock_rate': 8001000, 'mem_bus_width': 384}")


In [None]:
@triton.jit
def layernorm_kernel(input_ptr: torch.Tensor, output_ptr: torch.Tensor, 
                     gamma: torch.Tensor, beta: torch.Tensor, 
                     input_row_stride: int,
                     num_rows: int, num_cols: int, 
                     eps: float,
                     BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    row_idx = tl.program_id(0)
    
    sum_x = 0.0
    sum_squared_x = 0.0
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    for col in tl.range(0, num_cols, BLOCK_SIZE, num_stages=num_stages):
        row_ptrs = row_start_ptr + col_offsets + col
        mask = (col_offsets + col) < num_cols
        x = tl.load(row_ptrs, mask=mask, other=0.0)
        sum_x += tl.sum(x, where=mask)
        sum_squared_x += t.sum(x * x, where=mask)
    
    mean = sum_x / num_cols
    var = (sum_squared_x / mean) - (mean * mean)
    inv_std = tl.rsqrt(var + eps)
    
    for col in tl.arange(0, num_cols, BLOCK_SIZE, num_stages=num_stages):
        row_ptrs = row_start_ptr + col_offsets + col
        mask = (col_offsets + col) < num_cols
        input_block = tl.load(row_ptrs, mask=mask, other=0.0)
        gammas = tl.load(gamma + col, mask=mask, other=0.0)
        betas = tl.load(beta + col, mask=mask, other=0.0)
        norm = (input_block - mean) * inv_std
        output = norm * gammas + betas
        tl.store(output_row_start_ptr + col_offsets + col, output, mask=mask)

        
        

In [16]:
def layernorm(a: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    b = torch.zeros_like(a)
    M, N = a.shape
    num_stages = 3
    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = 128 #triton.next_power_of_2(n_cols)
    grid= (M, (N + BLOCK_SIZE - 1) // BLOCK_SIZE)
    layernorm_kernel[grid](a, b, gamma, beta,a.stride(0), M, N, eps, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages)
    return b

In [17]:
torch.manual_seed(0)
a = torch.randn((64, 512), device=DEVICE, dtype=torch.float32)
gamma = torch.randn(512, device=DEVICE, dtype=torch.float32)
beta = torch.randn(512, device=DEVICE, dtype=torch.float32)
triton_output = layernorm(a, gamma, beta, eps=1e-5)
torch_output = F.layer_norm(A, normalized_shape=(D,), weight=gamma, bias=beta, eps=1e-5)
if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

CompilationError: at 14:4:
                     input_row_stride: int,
                     num_rows: int, num_cols: int, 
                     eps: float,
                     BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    row_idx = tl.program_id(0)

    sum_x = 0.0
    sum_squared_x = 0.0
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    for col in tl.range(0, num_cols, BLOCK_SIZE, num_stages=num_stages):
    ^
AssertionError('Loop-carried variable sum_x has initial type fp32 but is re-assigned to <[128], fp32> in loop! Please make sure that the type stays consistent.')