In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

import triton
import triton.language as tl

In [19]:
def cdiv(a,b): return (a + b - 1) // b

@triton.jit
def matmul(x_ptr, y_ptr, out_ptr, dx0, dx1, dy0, dy1, bs0: tl.constexpr, bs1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    
    offs_0 = pid_0 * bs0 + tl.arange(0,bs0)  # 1d vector
    offs_1 = pid_1 * bs1 + tl.arange(0,bs1)  # 1d vector

    offs_x = dx1 * offs_0[:,None] + offs_1[None, :]  # 2d matrix! - we multiply first offset by width, see image above
    mask_x_0 = offs_0 < dx0
    mask_x_1 = offs_1 < dx1
    mask_x = mask_x_0[:,None] & mask_x_1[None, :]
    x = tl.load(x_ptr + offs_x, mask=mask_x)
    
    offs_y = dy1 * offs_1[:,None] + offs_0[None, :]
    mask_y_0 = offs_0 < dy0
    mask_y_1 = offs_1 < dy1
    mask_y = mask_y_0[:,None] & mask_y_1[None, :]
    y = tl.load(y_ptr + offs_y, mask=mask_y)
    
    out = x * y
    
    tl.store(out_ptr + offs_x, out, mask=mask_x)

d0 = 5
d1 = 7
bs0 = 2
bs1 = 2
x = torch.arange(1, d0*d1+1).reshape(d0, d1).to("cuda")
y = torch.arange(d0*d1+1, 2*d0*d1+1).reshape(d1, d0).to("cuda")
print(x)
print(y)

grid = lambda meta: (cdiv(d0, meta['bs0']), cdiv(d1,  meta['bs1']))
out = torch.zeros(d0, d0, device="cuda")
matmul[grid](x, y, out, d0, d1, d1, d0, bs0=bs0, bs1=bs1)

print(out)

tensor([[ 1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
        [22, 23, 24, 25, 26, 27, 28],
        [29, 30, 31, 32, 33, 34, 35]], device='cuda:0')
tensor([[36, 37, 38, 39, 40],
        [41, 42, 43, 44, 45],
        [46, 47, 48, 49, 50],
        [51, 52, 53, 54, 55],
        [56, 57, 58, 59, 60],
        [61, 62, 63, 64, 65],
        [66, 67, 68, 69, 70]], device='cuda:0')
tensor([[  36.,   74.,  138.,  188.,  280.],
        [   0.,    0.,  328.,  378.,  510.],
        [ 572.,  732.,    0.,    0.,  570.],
        [ 624.,  816.,  882., 1102.,    0.],
        [   0.,  946., 1012., 1272., 1350.]], device='cuda:0')
