In [1]:
import torch
import triton
import triton.language as tl

@triton.jit
def get_1d_offset(size, n_prev_chunks):
    return n_prev_chunks * size + tl.arange(0, size)

@triton.jit
def get_2d_offset(offs_0, offs_1, stride_0, stride_1=1):
    return tl.expand_dims(offs_0, 1) * stride_0 + tl.expand_dims(offs_1, 0) * stride_1

@triton.jit
def get_1d_mask(offs, max):
    return offs < max

@triton.jit
def get_2d_mask(offs_0, offs_1, max_0, max_1):
    return (tl.expand_dims(offs_0, 1) < max_0 ) & (tl.expand_dims(offs_1, 0) < max_1)

def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]

In [2]:
@triton.jit
def matmul_naive_k(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    bm: tl.constexpr,
    bn: tl.constexpr,
    bk: tl.constexpr,
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    # get number of chunks for each dimension
    rm = get_1d_offset(bm, pid_m)
    rn = get_1d_offset(bn, pid_n)
    rk = get_1d_offset(bk, 0) # K will always start from 0

    offs_a = a_ptr + get_2d_offset(rm, rk, stride_am, stride_ak)
    offs_b = b_ptr + get_2d_offset(rk, rn, stride_bk, stride_bn)

    acc = tl.zeros((bm, bn), dtype=tl.float32)
    for _ in range(0, k, bk):
        a = tl.load(offs_a, mask=get_2d_mask(rm, rk, m, k), other=0.0)
        b = tl.load(offs_b, mask=get_2d_mask(rk, rn, k, n), other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
        offs_a += bk * stride_ak
        offs_b += bk * stride_bk

    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)
    mask = get_2d_mask(rm, rn, m, n)
    tl.store(c, acc, mask=mask)


In [3]:
from functools import partial

def matmul(a, b, matmul_k_fn, bs=16, group_sz=None):
    assert a.shape[1] == b.shape[0], "matrix dims not compatible for matmul"
    (m, k), (_, n) = a.shape, b.shape
    c = torch.empty((m, n), device=a.device, dtype=torch.float16)
    grid = lambda meta: (triton.cdiv(m, meta['bm']),  triton.cdiv(n, meta['bn']))
    group_sz = {} if group_sz is None else {"group_sz":group_sz} # not used in naive_matmul, but will be in grouped_matmul further below 
    matmul_k_fn[grid](
        a, b, c,
        m, n, k,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        bm=bs, bn=bs, bk=bs, # Weirdness: allow_tf32 must be set to False for older GPUs, otherwise won't compile
        **group_sz
    )
    return c

naive_matmul = partial(matmul, matmul_k_fn=matmul_naive_k)

In [4]:
a = torch.ones((3, 4), dtype=torch.float32, device="cuda")
b = torch.ones((4, 5), dtype=torch.float32, device="cuda")
naive_matmul(a, b)

tensor([[4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.]], device='cuda:0', dtype=torch.float16)

In [5]:
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = naive_matmul(a, b)
torch_output = torch.matmul(a, b)
if torch.allclose(triton_output, torch_output, atol=5e-2, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

✅ Triton and Torch match


In [6]:
@triton.jit
def swizzle_k(x_ptr, z_ptr, group_sz: tl.constexpr):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)

    # get swizzled program ids
    pid_m_, pid_n_ = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)  # Weirdness: tl.swizzle2d doesn't work when simulating on CPU
    
    # get offsets aligned to original program ids
    offs_m = get_1d_offset(1, n_prev_chunks=pid_m)
    offs_n = get_1d_offset(1, n_prev_chunks=pid_n)
    
    # get mask for original program ids
    offs = get_2d_offset(offs_m, offs_n, stride_0=num_pid_n)
    mask = get_2d_mask(offs_m, offs_n, max_0=num_pid_m, max_1=num_pid_n )

    # get offsets aligned to swizzled program ids
    offs_sw_m = get_1d_offset(1, n_prev_chunks=pid_m_)
    offs_sw_n = get_1d_offset(1, n_prev_chunks=pid_n_)
    
    # get mask for swizzled program ids
    offs_sw = get_2d_offset(offs_sw_m, offs_sw_n, stride_0=num_pid_n)
    mask_sw = get_2d_mask(offs_sw_m, offs_sw_n, max_0=num_pid_m, max_1=num_pid_n)
    
    # load using original program ids and store using swizzled program ids
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(z_ptr + offs_sw, x, mask=mask_sw)

In [7]:
blocks_m, blocks_n = 5, 4
x = torch.arange(blocks_m * blocks_n, device='cuda').view(blocks_m, blocks_n)
x

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19]], device='cuda:0')

In [8]:
z = -torch.torch.ones_like(x)
z

tensor([[-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1]], device='cuda:0')

In [9]:
swizzle_k[(blocks_m, blocks_n)](x, z, group_sz=3)
z

tensor([[ 0,  3,  6,  9],
        [ 1,  4,  7, 10],
        [ 2,  5,  8, 11],
        [12, 14, 16, 18],
        [13, 15, 17, 19]], device='cuda:0')

In [10]:
@triton.jit
def grouped_matmul_k(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    bm: tl.constexpr,
    bn: tl.constexpr,
    bk: tl.constexpr,
    group_sz: tl.constexpr
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)

    # get swizzled program ids
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)

    # get offsets for each axis
    rm = get_1d_offset(bm, pid_m)
    rn = get_1d_offset(bn, pid_n)
    rk = get_1d_offset(bk, 0) # K will always start from 0

    # relevant offsets for a and b
    offs_a = a_ptr + get_2d_offset(rm, rk, stride_am, stride_ak)
    offs_b = b_ptr + get_2d_offset(rk, rn, stride_bk, stride_bn)

    # initialize accumulator
    acc = tl.zeros((bm, bn), dtype=tl.float32)
    for _ in range(0, k, bk):
        a = tl.load(offs_a, mask=get_2d_mask(rm, rk, m, k), other=0.0)
        b = tl.load(offs_b, mask=get_2d_mask(rk, rn, k, n), other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
        # update offsets for next iteration 
        offs_a += bk * stride_ak
        offs_b += bk * stride_bk

    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)
    mask = get_2d_mask(rm, rn, m, n)
    tl.store(c, acc, mask=mask)

In [11]:
grouped_matmul = partial(matmul, matmul_k_fn=grouped_matmul_k)

In [12]:
a = torch.ones((3, 4), dtype=torch.float32, device="cuda")
b = torch.ones((4, 5), dtype=torch.float32, device="cuda")
grouped_matmul(a, b, group_sz=4)

tensor([[4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.]], device='cuda:0', dtype=torch.float16)

In [13]:
@triton.jit
def masked_matmul_k(
    a_ptr,
    b_ptr,
    c_ptr,
    mask_ptr,
    m,
    n,
    k,
    e,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_mask,
    bm: tl.constexpr,
    bn: tl.constexpr,
    bk: tl.constexpr,
    group_sz: tl.constexpr
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)

    # get swizzled program ids
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)

    # get offsets for each axis
    rm = get_1d_offset(bm, pid_m)
    rn = get_1d_offset(bn, pid_n)
    rk = get_1d_offset(bk, 0) # K will always start from 0

    # token mask only operates on the m dimension
    tmask = tl.load(mask_ptr + rm * stride_mask)
    tmask = tl.expand_dims(tmask, 1)

    # initialize accumulator
    acc = tl.zeros((bm, bn), dtype=tl.float32)
    for e_i in range(e): # e is the number of experts
        # Get in_dim // 2^(e - e_i)
        k_i = k >> (e - e_i - 1)

        # relevant offsets for a and b
        # needs to be updated for each expert
        offs_a = a_ptr + get_2d_offset(rm, rk, stride_am, stride_ak)
        offs_b = b_ptr + get_2d_offset(rk, rn, stride_bk, stride_bn)
        
        # get masks for the current expert
        t_mask_i = (tmask == (e_i))
        a_mask_i = get_2d_mask(rm, rk, m, k_i)
        b_mask_i = get_2d_mask(rk, rn, k_i, n)

        # perform dot product for the current expert dimension
        for _ in range(0, k_i, bk):
            a = tl.load(offs_a, mask=t_mask_i & a_mask_i, other=0.0)
            b = tl.load(offs_b, mask=b_mask_i, other=0.0)
            acc += tl.dot(a, b, allow_tf32=True)

            # update offsets for next iteration
            offs_a += bk * stride_ak
            offs_b += bk * stride_bk

    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)
    mask = get_2d_mask(rm, rn, m, n)
    tl.store(c, acc, mask=mask)


In [14]:
def masked_matmul(a, b, mask, experts=4, bs=16,group_sz=None):
    assert a.shape[1] == b.shape[0], "matrix dims not compatible for matmul"
    assert mask.shape[0] == a.shape[0], "mask dims not compatible for matmul"
    (m, k), (_, n) = a.shape, b.shape
    e = experts
    c = torch.empty((m, n), device=a.device, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(m, meta['bm']),  triton.cdiv(n, meta['bn']))
    group_sz = {} if group_sz is None else {"group_sz":group_sz} # not used in naive_matmul, but will be in grouped_matmul further below 
    masked_matmul_k[grid](
        a, b, c, mask,
        m, n, k, e,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        mask.stride(0),
        bm=bs, bn=bs, bk=bs, # Weirdness: allow_tf32 must be set to False for older GPUs, otherwise won't compile
        **group_sz
    )
    return c

In [15]:
a = torch.ones((3, 4), dtype=torch.float32, device="cuda")
b = torch.ones((4, 5), dtype=torch.float32, device="cuda")
experts = 1
mask = torch.tensor([0, 0, 0], dtype=torch.int32, device="cuda")
masked_matmul(a, b, mask, experts, bs=16, group_sz=4)

tensor([[4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.]], device='cuda:0')

In [16]:
a = torch.ones((5, 128), dtype=torch.float32, device="cuda", requires_grad=True)
b = torch.ones((128, 256), dtype=torch.float32, device="cuda", requires_grad=True)
experts = 4
mask = torch.tensor([2, 1, 3, 0, 2], dtype=torch.int32, device="cuda")
output = masked_matmul(a, b, mask, experts, bs=16, group_sz=4)
print(f"contains 128: {torch.any(output == 128)}")
print(f"contains 64: {torch.any(output == 64)}")
print(f"contains 32: {torch.any(output == 32)}")
print(f"shape: {output.shape}")
print(output[0, :10])
print(output[1, :10])
print(output[2, :10])
print(output[3, :10])
print(output[4, :10])
output

contains 128: True
contains 64: True
contains 32: True
shape: torch.Size([5, 256])
tensor([64., 64., 64., 64., 64., 64., 64., 64., 64., 64.], device='cuda:0')
tensor([32., 32., 32., 32., 32., 32., 32., 32., 32., 32.], device='cuda:0')
tensor([128., 128., 128., 128., 128., 128., 128., 128., 128., 128.],
       device='cuda:0')
tensor([16., 16., 16., 16., 16., 16., 16., 16., 16., 16.], device='cuda:0')
tensor([64., 64., 64., 64., 64., 64., 64., 64., 64., 64.], device='cuda:0')


tensor([[ 64.,  64.,  64.,  ...,  64.,  64.,  64.],
        [ 32.,  32.,  32.,  ...,  32.,  32.,  32.],
        [128., 128., 128.,  ..., 128., 128., 128.],
        [ 16.,  16.,  16.,  ...,  16.,  16.,  16.],
        [ 64.,  64.,  64.,  ...,  64.,  64.,  64.]], device='cuda:0')

In [17]:
from torch.nn.functional import linear
bias = torch.ones((256), dtype=torch.float32, device="cuda", requires_grad=True)
torch_output = linear(a, b.T, bias=bias).mean()
torch_output.backward()

print(a.grad[0, :10])
print(b.grad[0, :10])
print(bias.grad[:10])



tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000,
        0.2000], device='cuda:0')
tensor([0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039,
        0.0039], device='cuda:0')
tensor([0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039, 0.0039,
        0.0039], device='cuda:0')


In [18]:
# Let's add a bias to compute the output of a linear layer
@triton.autotune(
        configs=get_cuda_autotune_config(),
        key=['M', 'N', 'K']
)
@triton.jit
def nested_linear_expand_kernel(
    x_ptr,
    w_ptr,
    o_ptr,
    b_ptr,
    mask_ptr,
    M,
    N,
    K,
    E,
    stride_xm,
    stride_xk,
    stride_wk,
    stride_wn,
    stride_om,
    stride_on,
    stride_bias,
    stride_mask,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)

    # get swizzled program ids
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # get offsets for each axis
    rm = get_1d_offset(BLOCK_SIZE_M, pid_m)
    rn = get_1d_offset(BLOCK_SIZE_N, pid_n)
    rk = get_1d_offset(BLOCK_SIZE_K, 0) # K will always start from 0

    # token mask only operates on the m dimension
    tmask = tl.load(mask_ptr + rm * stride_mask)
    tmask = tl.expand_dims(tmask, 1)

    # initialize accumulator
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for e_i in range(E): # e is the number of experts
        # Get in_dim // 2^(e - e_i)
        k_i = K >> (E - e_i - 1)

        # relevant offsets for a and b
        # needs to be updated for each expert
        offs_x = x_ptr + get_2d_offset(rm, rk, stride_xm, stride_xk)
        offs_w = w_ptr + get_2d_offset(rk, rn, stride_wk, stride_wn)
        
        # get masks for the current expert
        t_mask_i = (tmask == (e_i))
        x_mask_i = get_2d_mask(rm, rk, M, k_i)
        w_mask_i = get_2d_mask(rk, rn, k_i, N)

        # perform dot product for the current expert dimension
        for _ in range(0, k_i, BLOCK_SIZE_K):
            x = tl.load(offs_x, mask=t_mask_i & x_mask_i, other=0.0)
            w = tl.load(offs_w, mask=w_mask_i, other=0.0)
            acc += tl.dot(x, w)

            # update offsets for next iteration
            offs_x += BLOCK_SIZE_K * stride_xk
            offs_w += BLOCK_SIZE_K * stride_wk

    # Add bias after accumulation
    if b_ptr is not None:
        offs_b = b_ptr + rn * stride_bias
        b_mask = rn < N
        b = tl.load(offs_b, mask=b_mask, other=0.0)
        acc += b[None, :]

    o = o_ptr + get_2d_offset(rm, rn, stride_om, stride_on)
    mask = get_2d_mask(rm, rn, M, N)
    tl.store(o, acc, mask=mask)

In [19]:
# Let's build the nested linear contraction
@triton.autotune(
        configs=get_cuda_autotune_config(),
        key=['M', 'N', 'K']
)
@triton.jit
def nested_linear_contract_kernel(
    x_ptr,
    w_ptr,
    o_ptr,
    b_ptr,
    mask_ptr,
    M,
    N,
    K,
    E,
    stride_xm,
    stride_xk,
    stride_wk,
    stride_wn,
    stride_om,
    stride_on,
    stride_bias,
    stride_mask,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)

    # get swizzled program ids
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # get offsets for each axis
    rm = get_1d_offset(BLOCK_SIZE_M, pid_m)
    rn = get_1d_offset(BLOCK_SIZE_N, pid_n)
    rk = get_1d_offset(BLOCK_SIZE_K, 0) # K will always start from 0

    # token mask only operates on the m dimension
    tmask = tl.load(mask_ptr + rm * stride_mask)
    tmask = tl.expand_dims(tmask, 1)

    # mask for the current expert
    x_mask_i = get_2d_mask(rm, rk, M, K)

    # initialize accumulator
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for e_i in range(E): # e is the number of experts
        # Get out_dim // 2^(e - e_i)
        n_i = N >> (E - e_i - 1)

        # relevant offsets for a and b
        # needs to be updated for each expert
        offs_x = x_ptr + get_2d_offset(rm, rk, stride_xm, stride_xk)
        offs_w = w_ptr + get_2d_offset(rk, rn, stride_wk, stride_wn)

        # get masks for the current expert
        t_mask_i = (tmask == (e_i))
        w_mask_i = get_2d_mask(rk, rn, K, n_i)

        for _ in range(0, K, BLOCK_SIZE_K):
            x = tl.load(offs_x, mask=t_mask_i & x_mask_i, other=0.0)
            w = tl.load(offs_w, mask=w_mask_i, other=0.0)
            acc += tl.dot(x, w)

            # update offsets for next iteration
            offs_x += BLOCK_SIZE_K * stride_xk
            offs_w += BLOCK_SIZE_K * stride_wk

    # Add bias after accumulation
    if b_ptr is not None:
        offs_b = b_ptr + rn * stride_bias
        b_mask = rn < N
        b = tl.load(offs_b, mask=b_mask, other=0.0)
        acc += b[None, :]

    o = o_ptr + get_2d_offset(rm, rn, stride_om, stride_on)
    mask = get_2d_mask(rm, rn, M, N)
    tl.store(o, acc, mask=mask)



In [20]:
def nested_linear_expand(x, w, mask, b=None, experts=4):
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "Matrix x must be contiguous"
    M, K = x.shape
    K, N = w.shape
    # Allocate output
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))
    nested_linear_expand_kernel[grid](
        x, w, output, b, mask, M, N, K, experts,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        output.stride(0), output.stride(1),
        b.stride(0) if b is not None else None,
        mask.stride(0),
    )
    return output

In [21]:
# Test nested_linear_expand while keeping the mask fixed
x = torch.ones((5, 128), device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.ones((128, 256), device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.zeros((256), device="cuda", dtype=torch.float16, requires_grad=True)
mask = torch.tensor([2, 1, 3, 0, 0], dtype=torch.int32, device="cuda")
output = nested_linear_expand(x, w, mask, bias, experts=4)
print(output)
print(output.shape)

tensor([[ 64.,  64.,  64.,  ...,  64.,  64.,  64.],
        [ 32.,  32.,  32.,  ...,  32.,  32.,  32.],
        [128., 128., 128.,  ..., 128., 128., 128.],
        [ 16.,  16.,  16.,  ...,  16.,  16.,  16.],
        [ 16.,  16.,  16.,  ...,  16.,  16.,  16.]], device='cuda:0',
       dtype=torch.float16)
torch.Size([5, 256])


In [36]:
# test forward pass with pytorch linear expand
from layers import NestedLinearExpand
x = torch.ones((1, 5, 128), device="cuda", dtype=torch.float16, requires_grad=True)
mask = torch.tensor([2, 1, 3, 0, 0], dtype=torch.int32, device="cuda").view(1, 5)
nested_linear_expand = NestedLinearExpand(128, 256, 4)
nested_linear_expand.half()
nested_linear_expand.deterministic_init()
nested_linear_expand.cuda()
output = nested_linear_expand(x, mask)
print(output)
print(output.shape)


tensor([[[ 64.,  64.,  64.,  ...,  64.,  64.,  64.],
         [ 32.,  32.,  32.,  ...,  32.,  32.,  32.],
         [128., 128., 128.,  ..., 128., 128., 128.],
         [ 16.,  16.,  16.,  ...,  16.,  16.,  16.],
         [ 16.,  16.,  16.,  ...,  16.,  16.,  16.]]], device='cuda:0',
       dtype=torch.float16, grad_fn=<ViewBackward0>)
torch.Size([1, 5, 256])


In [22]:
def nested_linear_contract(x, w, mask, b=None, experts=4):
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "Matrix x must be contiguous"
    M, K = x.shape
    K, N = w.shape
    # Allocate output
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))
    nested_linear_contract_kernel[grid](
        x, w, output, b, mask, M, N, K, experts,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        output.stride(0), output.stride(1),
        b.stride(0) if b is not None else None,
        mask.stride(0),
    )
    return output



In [23]:
# Test nested_linear_contract while keeping the mask fixed
x = torch.ones((5, 256), device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.ones((256, 128), device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.zeros((128), device="cuda", dtype=torch.float16, requires_grad=True)
mask = torch.tensor([0, 1, 0, 0, 3], dtype=torch.int32, device="cuda")
output = nested_linear_contract(x, w, mask, bias, experts=4)
print(output)
print(output.shape)

tensor([[256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256.,
         256., 256., 256., 256.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256.,
         256., 256., 256., 256., 256., 256., 256., 

In [32]:
# test forward pass with pytorch linear contract
from layers import NestedLinearContract
x = torch.ones((1, 5, 256), device="cuda", dtype=torch.float16, requires_grad=True)
mask = torch.tensor([0, 1, 0, 0, 3], dtype=torch.int32, device="cuda").view(1, 5)
linear_contract = NestedLinearContract(256, 128, 4)
linear_contract.half()
linear_contract.deterministic_init()
linear_contract.cuda()
output = linear_contract(x, mask)
print(output)
print(output.shape)

tensor([[[256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256.,
          256., 256., 256., 256., 256.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
            0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [256., 256., 256., 256., 256., 256., 256., 256., 256., 256., 256.,
          256., 256., 256., 256., 2

In [24]:

# Let's implement the backward pass for grad_input which is grad_output * w^T
# and grad_weight which is x^T * grad_output

@triton.jit
def nested_linear_expand_dinput_kernel(
    grad_output_ptr,
    w_ptr,
    grad_input_ptr,
    mask_ptr,
    bias_ptr,
    M,
    N,
    K,
    E,
    stride_gom,
    stride_gon,
    stride_gok,
    stride_mask,
    stride_bias,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pass

@triton.jit
def nested_linear_contract_dweight_kernel(
    grad_output_ptr,
    x_ptr,
    grad_weight_ptr,
    mask_ptr,
    bias_ptr,
    M,
    N,
    K,
    E,
    stride_gom,
    stride_gon,
    stride_gok,
    stride_mask,
    stride_bias,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pass
        





In [25]:
from torch.nn.functional import linear

class NestedLinearExpandFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, bias, mask, num_experts):
        x = x.view(-1, x.shape[-1])
        w = w.transpose()
        ctx.save_for_backward(x, w)
        output = nested_linear_expand(x, w, mask, bias, num_experts)
        return output.view(x.shape[:-1] + (w.shape[-1],))
    
    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        dw = torch.matmul(grad_output, x)
        dx = torch.matmul(grad_output, w)
        db = grad_output.sum(dim=0)
        return dx.view(x.shape), dw.view(w.shape), db
    
class NestedLinearContractFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, bias, mask, num_experts):
        x = x.view(-1, x.shape[-1])
        w = w.transpose()
        ctx.save_for_backward(x, w)
        output = nested_linear_contract(x, w, mask, bias, num_experts)
        return output.view(x.shape[:-1] + (w.shape[-1],))
    
    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        x, w = ctx.saved_tensors
        dx = torch.matmul(grad_output, w)
        dw = torch.matmul(x, grad_output)
        db = grad_output.sum(dim=0)
        return dx.view(x.shape), dw.view(w.shape), db

In [28]:
# Lets test the backward pass on pytorch versions of the functions
from layers import NestedLinearExpand, NestedLinearContract
import time

x = torch.ones((1, 5, 128), device="cuda", dtype=torch.float32, requires_grad=True)
token_mask = torch.tensor([2, 1, 3, 0, 0], dtype=torch.int32, device="cuda").view(1, 5)
linear_expand = NestedLinearExpand(128, 256, 4)
linear_contract = NestedLinearContract(128, 64, 4)
linear_contract.deterministic_init()
linear_expand.deterministic_init()
linear_expand.cuda()
linear_contract.cuda()

def time_expand():
    total = 0
    for _ in range(20):
        start = time.time()
        linear_expand(x, token_mask)
        torch.cuda.synchronize()
        end = time.time()
        total += end - start
    return total / 20

def time_contract():
    total = 0
    for _ in range(20):
        start = time.time()
        linear_contract(x, token_mask)
        torch.cuda.synchronize()
        end = time.time()
        total += end - start
    return total / 20

print(f"Time for forward expand: {time_expand()}")
print(f"Time for forward contract: {time_contract()}")

Time for forward expand: 0.00045578479766845704
Time for forward contract: 0.0005160093307495118


In [29]:
# Lets test the forward pass on the triton versions
x = torch.ones((1, 5, 128), device="cuda", dtype=torch.float32, requires_grad=True)
w1 = torch.ones((128, 256), device="cuda", dtype=torch.float32, requires_grad=True)
w2 = torch.ones((128, 64), device="cuda", dtype=torch.float32, requires_grad=True)
token_mask = torch.tensor([2, 1, 3, 0, 0], dtype=torch.int32, device="cuda").view(1, 5)

def time_expand_triton():
    total = 0
    for _ in range(20):
        start = time.time()
        nested_linear_expand(x.view(-1, 128), w1, token_mask, None, 4)
        torch.cuda.synchronize()
        end = time.time()
        total += end - start
    return total / 20

def time_contract_triton():
    total = 0
    for _ in range(20):
        start = time.time()
        nested_linear_contract(x.view(-1, 128), w2, token_mask, None, 4)
        torch.cuda.synchronize()
        end = time.time()
        total += end - start
    return total / 20

print(f"Time for forward expand triton: {time_expand_triton()}")
print(f"Time for forward contract triton: {time_contract_triton()}")

Time for forward expand triton: 6.181001663208008e-05
Time for forward contract triton: 6.07609748840332e-05
