In [1]:
import triton
import triton.language as tl
import torch

@triton.jit
def kernel(X, Y, Z, n: tl.constexpr):
    i = tl.program_id(0)
    off = i * 1024 + tl.arange(0, 1024)
    tl.store(Z + off, tl.load(X + off) + tl.load(Y + off))

# compile a kernel instance
import numpy as np
x = torch.randn(1024, device='cuda')
y = torch.randn(1024, device='cuda')
z = torch.empty_like(x)

# force compilation
kernel[(1,)](x, y, z, 1024)
z

tensor([ 1.1750, -0.6913, -1.1015,  ..., -0.5907,  2.2528, -0.2005],
       device='cuda:0')

In [2]:
def value(dict_):
    assert len(dict_)==1, 'dict has more than one value' # we're assuming a single env & a single input set
    return list(dict_.values())[0]

In [18]:
print(value(value(kernel.cache)).asm['ptx'])

//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_86
.address_size 64

	// .globl	kernel                  // -- Begin function kernel
                                        // @kernel
.visible .entry kernel(
	.param .u64 .ptr .global .align 1 kernel_param_0,
	.param .u64 .ptr .global .align 1 kernel_param_1,
	.param .u64 .ptr .global .align 1 kernel_param_2
)
.reqntid 128, 1, 1
{
	.reg .pred 	%p<7>;
	.reg .b32 	%r<31>;
	.reg .f32 	%f<25>;
	.reg .b64 	%rd<11>;
	.loc	1 6 0                           // 3810643283.py:6:0
$L__func_begin0:
	.loc	1 6 0                           // 3810643283.py:6:0

// %bb.0:
	ld.param.u64 	%rd7, [kernel_param_0];
	ld.param.u64 	%rd8, [kernel_param_1];
$L__tmp0:
	.loc	1 7 22                          // 3810643283.py:7:22
	// begin inline asm
	mov.u32 %r1, %ctaid.x;
	// end inline asm
	.loc	1 8 14                          // 3810643283.py:8:14
	shl.b32 	%r26, %r1, 10;
	ld.param.u64 	%rd9, [kernel_param_2];
	.loc	1 8 34                        

In [3]:
print(value(kernel.device_caches[0][0]).asm["ptx"])

AttributeError: 'JITFunction' object has no attribute 'device_caches'

In [20]:
code = value(value(kernel.cache)).asm['ptx']

def san(x:str):
    return x.replace("%", "%%")
    

print('\n'.join(["\"" + san(x) + "\\n\"" for x in code.split("\n")]))

"//\n"
"// Generated by LLVM NVPTX Back-End\n"
"//\n"
"\n"
".version 8.4\n"
".target sm_86\n"
".address_size 64\n"
"\n"
"	// .globl	kernel                  // -- Begin function kernel\n"
"                                        // @kernel\n"
".visible .entry kernel(\n"
"	.param .u64 .ptr .global .align 1 kernel_param_0,\n"
"	.param .u64 .ptr .global .align 1 kernel_param_1,\n"
"	.param .u64 .ptr .global .align 1 kernel_param_2\n"
")\n"
".reqntid 128, 1, 1\n"
"{\n"
"	.reg .pred 	%%p<7>;\n"
"	.reg .b32 	%%r<31>;\n"
"	.reg .f32 	%%f<25>;\n"
"	.reg .b64 	%%rd<11>;\n"
"	.loc	1 6 0                           // 3810643283.py:6:0\n"
"$L__func_begin0:\n"
"	.loc	1 6 0                           // 3810643283.py:6:0\n"
"\n"
"// %%bb.0:\n"
"	ld.param.u64 	%%rd7, [kernel_param_0];\n"
"	ld.param.u64 	%%rd8, [kernel_param_1];\n"
"$L__tmp0:\n"
"	.loc	1 7 22                          // 3810643283.py:7:22\n"
"	// begin inline asm\n"
"	mov.u32 %%r1, %%ctaid.x;\n"
"	// end inline asm\n"
"	.loc	1 8 14      

## Matmul

In [24]:
import triton
import triton.language as tl
import torch

# Kernel: computes C = A @ B
@triton.jit
def matmul_kernel(
    A_ptr, B_ptr, C_ptr,       # pointers to matrices
    M, N, K,                   # matrix dimensions
    stride_am, stride_ak,      # A strides
    stride_bk, stride_bn,      # B strides
    stride_cm, stride_cn,      # C strides
    BLOCK_M: tl.constexpr,     # block sizes
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute block offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(A_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak))
        b = tl.load(B_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn))
        acc += tl.dot(a, b)

    tl.store(C_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn), acc)

# Python wrapper
def matmul(A, B):
    M, K = A.shape
    K, N = B.shape
    C = torch.empty((M, N), device=A.device, dtype=torch.float32)

    BLOCK = 64  # you can tune this
    grid = (triton.cdiv(M, BLOCK), triton.cdiv(N, BLOCK))
    matmul_kernel[grid](
        A, B, C,
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        BLOCK, BLOCK, BLOCK
    )
    return C

# Example usage
A = torch.randn(128, 128, device='cuda', dtype=torch.float32)
B = torch.randn(128, 128, device='cuda', dtype=torch.float32)
C = matmul(A, B)
print(torch.allclose(C, A @ B))


False


In [25]:
M, K = A.shape
K, N = B.shape
C = torch.empty((M, N), device=A.device, dtype=torch.float32)

BLOCK = 64  # you can tune this
grid = (triton.cdiv(M, BLOCK), triton.cdiv(N, BLOCK))
matmul_kernel[grid](
    A, B, C,
    M, N, K,
    A.stride(0), A.stride(1),
    B.stride(0), B.stride(1),
    C.stride(0), C.stride(1),
    BLOCK, BLOCK, BLOCK
)
# return C

<triton.compiler.compiler.CompiledKernel at 0x7ff53c276da0>

In [27]:
code = value(value(matmul_kernel.cache)).asm['ptx']


In [54]:
params = [x for x in code.split('\n') if 'ld.param' in x]
param_to_reg = {}

for line in params:
    # Extract register
    reg = line.split()[1].rstrip(',')  
    # Extract param number
    param_num = int(line.split('_')[-1].rstrip('];'))
    param_to_reg[param_num] = reg

print(param_to_reg)
import re
pattern = re.compile(r'%r(d?)[^<]*<(\d+)>')


# reg_c = {}
for line in code.split('\n'):
    match = pattern.search(line)
    if match:
        reg_suffix = match.group(1)   # 'd' if %rd, '' if %r
        reg_type = 'rd' if reg_suffix == 'd' else 'r'
        reg_count = int(match.group(2))
        reg_c[reg_type] = reg_count
        # print(f"Register type: {reg_type}, Count: {reg_count}")
# print(reg_c)


def san(x:str):
    return x.replace("%", "%%")

def gen_dcl(reg: str, idx, reg_num):
    # print(reg)
    if "rd" in reg:
        return ".reg .u64 %%rd"+str(reg_num) + ";\nmov.u64 %%rd"+str(reg_num) + ", %" + str(idx) + ";\n"
    elif "r" in reg:
        return ".reg .u32 %%r"+str(reg_num) + ";\nmov.u32 %%r"+str(reg_num) + ", %" + str(idx) + ";\n"


start = False
end = False
new_code = ""
new_map = {}
for line in code.split('\n'):
    if '.reqntid' in line:
        start = True
        continue

    if not start:
        continue

    if end:
        break

    if "}" in line:
        new_code += "}\n"
        end = True
        break
    
    if "{" in line:
        new_code += san(line) + "\n"
        for k, v in param_to_reg.items():
            if "rd" in v:
                new_map[v] = "%rd" + str(reg_c["rd"] + 1)
                new_code += gen_dcl(v, k, reg_c["rd"] + 1)
                reg_c["rd"] += 1
            else:
                new_map[v] = "%r" + str(reg_c["r"] + 1)
                new_code += gen_dcl(v, k, reg_c["r"] + 1)
                reg_c["r"] += 1
    elif "param" in line or ".loc" in line:
        continue
    else:
        pattern = re.compile(r'(%r[d]?[\d<>\w]*)')
        regs = pattern.findall(line)

        for k, v in new_map.items():
            if k in regs:
                line = line.replace(k, v)
        new_code += san(line) + "\n"
print(new_map)
print('\n'.join(["\"" + (x) + "\\n\"" for x in new_code.split("\n")]))

{8: '%r211', 7: '%r210', 6: '%r209', 5: '%r208', 2: '%rd3', 1: '%rd2', 0: '%rd1'}
{'%r211': '%r1167', '%r210': '%r1168', '%r209': '%r1169', '%r208': '%r1170', '%rd3': '%rd117', '%rd2': '%rd118', '%rd1': '%rd119'}
"{\n"
".reg .u32 %%r1167;\n"
"mov.u32 %%r1167, %8;\n"
".reg .u32 %%r1168;\n"
"mov.u32 %%r1168, %7;\n"
".reg .u32 %%r1169;\n"
"mov.u32 %%r1169, %6;\n"
".reg .u32 %%r1170;\n"
"mov.u32 %%r1170, %5;\n"
".reg .u64 %%rd117;\n"
"mov.u64 %%rd117, %2;\n"
".reg .u64 %%rd118;\n"
"mov.u64 %%rd118, %1;\n"
".reg .u64 %%rd119;\n"
"mov.u64 %%rd119, %0;\n"
"	.reg .pred 	%%p<80>;\n"
"	.reg .b32 	%%r<1166>;\n"
"	.reg .f32 	%%f<610>;\n"
"	.reg .b64 	%%rd<116>;\n"
"$L__func_begin0:\n"
"\n"
"// %%bb.0:\n"
"$L__tmp0:\n"
"	// begin inline asm\n"
"	mov.u32 %%r212, %%ctaid.x;\n"
"	// end inline asm\n"
"	// begin inline asm\n"
"	mov.u32 %%r213, %%ctaid.y;\n"
"	// end inline asm\n"
"	shl.b32 	%%r1, %%r212, 6;\n"
"	mov.u32 	%%r2, %%tid.x;\n"
"	bfe.u32 	%%r3, %%r2, 4, 3;\n"
"	shl.b32 	%%r4, %%r2, 2;\n"
"	a

In [26]:
code = value(value(matmul_kernel.cache)).asm['ptx']

def san(x:str):
    return x.replace("%", "%%")
    

print('\n'.join(["\"" + san(x) + "\\n\"" for x in code.split("\n")]))

"//\n"
"// Generated by LLVM NVPTX Back-End\n"
"//\n"
"\n"
".version 8.4\n"
".target sm_86\n"
".address_size 64\n"
"\n"
"	// .globl	matmul_kernel           // -- Begin function matmul_kernel\n"
".extern .shared .align 16 .b8 global_smem[];\n"
"                                        // @matmul_kernel\n"
".visible .entry matmul_kernel(\n"
"	.param .u64 .ptr .global .align 1 matmul_kernel_param_0,\n"
"	.param .u64 .ptr .global .align 1 matmul_kernel_param_1,\n"
"	.param .u64 .ptr .global .align 1 matmul_kernel_param_2,\n"
"	.param .u32 matmul_kernel_param_3,\n"
"	.param .u32 matmul_kernel_param_4,\n"
"	.param .u32 matmul_kernel_param_5,\n"
"	.param .u32 matmul_kernel_param_6,\n"
"	.param .u32 matmul_kernel_param_7,\n"
"	.param .u32 matmul_kernel_param_8\n"
")\n"
".reqntid 128, 1, 1\n"
"{\n"
"	.reg .pred 	%%p<80>;\n"
"	.reg .b32 	%%r<1166>;\n"
"	.reg .f32 	%%f<610>;\n"
"	.reg .b64 	%%rd<116>;\n"
"	.loc	1 7 0                           // 3479778974.py:7:0\n"
"$L__func_begin0:\n"
"	.loc	1 7