# Split activation function in Triton

This notebook develops a performant Triton kernel to implement the PyTorch function:

In [6]:
import torch

def torch_splacc(A, B):
    C = torch.matmul(A, B)
    Cl, Cr = torch.tensor_split(C, 2, dim=1)
    D = torch.sigmoid(Cl) * torch.tanh(Cr)
    return D

This addresses the Triton issue: https://github.com/openai/triton/issues/984

## Utilities

We first define some utilities for creating test data.

In [7]:
import torch

def make_tensors(n_feat):
    torch.manual_seed(0)

    K = n_feat
    M = n_feat // 2
    N = n_feat

    print(f"{M=}, {N=}, {K=}")

    A = torch.randn((M, K), device="cuda", dtype=torch.float32)
    B = torch.randn((K, 2 * N), device="cuda", dtype=torch.float32)
    Bl, Br = torch.tensor_split(B, 2, dim=1)
    # We transpose B so that it is stored in column-major order. This makes `tl.load`
    # slightly faster.
    BlT = Bl.T.contiguous()
    BrT = Br.T.contiguous()

    # B interleaved, transposed.
    BiT = torch.cat((BlT, BrT), dim=1).reshape((2 * n_feat, n_feat)).contiguous()

    return A, B, BlT, BrT, BiT

A, B, BlT, BrT, BiT = make_tensors(2048)
A.shape, B.shape, BiT.shape

M=1024, N=2048, K=2048


(torch.Size([1024, 2048]), torch.Size([2048, 4096]), torch.Size([4096, 2048]))

# Kernel v1 - two accumulators

In [8]:
import triton
import triton.language as tl


@triton.jit
def kernel_v1(A_ptr, BlT_ptr, BrT_ptr, out_ptr, M: tl.constexpr,
              N: tl.constexpr, K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
              BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    """
    Inputs
    ------
    A : (M, K) matrix.
    B : (K, 2*N) matrix -> Bl, Br : (K, N) matrices -> BlT, BrT : (N, K) matrices.
    out: (M, N) matrix.

    Kernel instance with coordinates (i, j) computes: 
        out[(i:i+1) * BLOCK_SIZE_M, (j:j+1) * BLOCK_SIZE_N] 
    from:
        A[(i:i+1) * BLOCK_SIZE_M, :] and 
        Bl[:, (j:j+1) * BLOCK_SIZE_N] and Br[:, (j:j+1) * BLOCK_SIZE_N].
    The computation is performed incrementally. The k^th iteration multiplies: 
        A[(i:i+1) * BLOCK_SIZE_M, (k:k+1) * BLOCK_SIZE_K] with
        B[(k:k+1) * BLOCK_SIZE_K, (j:j+1) * BLOCK_SIZE_N]
    """
    i = tl.program_id(0)
    j = tl.program_id(1)

    M_idxs = tl.arange(
        0, BLOCK_SIZE_M) + i * BLOCK_SIZE_M  # (i:i+1) * BLOCK_SIZE_M
    N_idxs = tl.arange(
        0, BLOCK_SIZE_N) + j * BLOCK_SIZE_N  # (j:j+1) * BLOCK_SIZE_N

    acc_left = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    acc_right = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

        # (k:k+1) * BLOCK_SIZE_K
        K_idxs = tl.arange(0, BLOCK_SIZE_K) + k * BLOCK_SIZE_K

        A_blk_idxs = M_idxs[:, None] * K + K_idxs[None, :]
        A_blk = tl.load(A_ptr + A_blk_idxs)

        # Create a transposed index so we load BT as B.
        BT_blk_idxs = N_idxs[None, :] * K + K_idxs[:, None]
        BlT_blk = tl.load(BlT_ptr + BT_blk_idxs)

        acc_left += tl.dot(A_blk, BlT_blk)

        BrT_blk = tl.load(BrT_ptr + BT_blk_idxs)
        acc_right += tl.dot(A_blk, BrT_blk)

    out = tl.sigmoid(acc_left) * tl.libdevice.tanh(acc_right)

    out_idxs = M_idxs[:, None] * N + N_idxs[None, :]
    tl.store(out_ptr + out_idxs, out)

In [9]:
def dispatch_v1(A, BlT, BrT, out, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
    M, K = A.shape
    N, _ = BlT.shape
    assert BlT.shape == BrT.shape
    assert out.shape == (M, N)

    grid = M // BLOCK_SIZE_M, N // BLOCK_SIZE_N
    kernel_v1[grid](A, BlT, BrT, out, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N,
                    BLOCK_SIZE_K)

## Output Comparison

In [10]:
torch_out = torch_splacc(A, B)
kv1_out = torch.empty_like(torch_out)
dispatch_v1(A, BlT, BrT, kv1_out, 64, 64, 64)

In [11]:
from conch import mad
mad(kv1_out, torch_out)

0.056364

In [12]:
torch_out.abs().mean()

tensor(0.4934, device='cuda:0')

## Performance comparison

In [13]:
from triton.testing import do_bench

do_bench(lambda : torch_splacc(A, B), warmup=100, rep=500)[0] * 1000

988.1600141525269

In [14]:
%%script false --no-raise-error

from functools import partial
from conch import grid_search, results_to_df

v1_perf = grid_search(partial(dispatch_v1, A, BlT, BrT, kv1_out),
                      BLOCK_SIZE_M=(16, 1024),
                      BLOCK_SIZE_N=(16, 1024),
                      BLOCK_SIZE_K=(16, 1024),)

results_to_df(v1_perf)

# Kernel v2 - view-sum split

In [15]:
x = torch.arange(4 * 6).reshape((4, 6))
x

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])

In [16]:
even_mask = torch.arange(0, 2) % 2 == 0
even_mask

tensor([ True, False])

In [17]:
even_mask = torch.broadcast_to(even_mask, (4 * 6 // 2, -1))
even_mask

tensor([[ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False]])

In [18]:
even_odd = x.view((-1, 2))
even_odd

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15],
        [16, 17],
        [18, 19],
        [20, 21],
        [22, 23]])

In [19]:
left = torch.where(even_mask, even_odd, torch.zeros_like(even_odd)).sum(dim=1)
left

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22])

In [20]:
right = torch.where(~even_mask, even_odd, torch.zeros_like(even_odd)).sum(dim=1)
right

tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23])

In [21]:
# out = torch.sigmoid(left) * torch.tanh(right)
out = left + right
out

tensor([ 1,  5,  9, 13, 17, 21, 25, 29, 33, 37, 41, 45])

In [22]:
out.view((4, 3))

tensor([[ 1,  5,  9],
        [13, 17, 21],
        [25, 29, 33],
        [37, 41, 45]])

In [23]:
Ci = A @ BiT.T
Clr = Ci.view((-1, 2))
left_mask_row = torch.arange(0, 2, device="cuda") % 2 == 0
left_mask = torch.broadcast_to(left_mask_row, Clr.shape)

left = torch.where(left_mask, Clr, torch.zeros_like(Clr, device="cuda")).sum(dim=1)
right = torch.where(~left_mask, Clr, torch.zeros_like(Clr, device="cuda")).sum(dim=1)

out = torch.sigmoid(left) * torch.tanh(right)
out = out.view((A.shape[0], B.shape[1] // 2))

In [24]:
left.reshape((A.shape[0], -1))

tensor([[-13.9087,  -2.3076,  35.4975,  ...,  30.0138,  62.2520,  92.0306],
        [-49.8847,  -9.4593,  14.9723,  ...,  18.6266,  24.8574, -62.9675],
        [-44.1483, -36.1945,   6.1588,  ...,  51.8622,  29.8774, 100.3384],
        ...,
        [-10.6014, -24.2593, -33.9749,  ...,  52.2683,  61.5418,  30.9633],
        [ 12.0531, -19.7057, -56.3911,  ...,  17.2008,   1.2616,  55.5883],
        [-31.5721, -50.6414, -13.3048,  ...,  30.0606, -23.0794, -40.1447]],
       device='cuda:0')

In [25]:
mad(torch_out, out)

0.0001519844

In [26]:
i = 0; j = 0
M_idxs = torch.arange(0, 64)
N2_idxs = torch.arange(0, 128)
acc = torch.zeros((64, 128), device="cuda")

for k in range(2048 // 64):

    K_idxs = torch.arange(0, 64) + k * 64

    A_blk_idxs = M_idxs[:, None] * 2048 + K_idxs[None, :]
    A_blk = A.view(-1)[A_blk_idxs]

    BiT_blk_idxs = N2_idxs[None, :] * 2048 + K_idxs[:, None]
    print(BiT_blk_idxs)
    Bi_blk = BiT.view(-1)[BiT_blk_idxs]

    acc += A_blk @ Bi_blk

left_right = acc.view((-1, 2))
left_mask_row = torch.arange(0, 2, device="cuda") % 2 == 0
left_mask = torch.broadcast_to(left_mask_row, left_right.shape)

left = torch.where(left_mask, left_right, torch.zeros_like(left_right, device="cuda")).sum(dim=1)
right = torch.where(~left_mask, left_right, torch.zeros_like(left_right, device="cuda")).sum(dim=1)

out = torch.zeros_like(BiT.T).ravel()
out_idxs = M_idxs[:, None] * 2 * 2048 + N2_idxs[None, :]
out[out_idxs] = acc
out = out.view(BiT.T.shape)

tensor([[     0,   2048,   4096,  ..., 256000, 258048, 260096],
        [     1,   2049,   4097,  ..., 256001, 258049, 260097],
        [     2,   2050,   4098,  ..., 256002, 258050, 260098],
        ...,
        [    61,   2109,   4157,  ..., 256061, 258109, 260157],
        [    62,   2110,   4158,  ..., 256062, 258110, 260158],
        [    63,   2111,   4159,  ..., 256063, 258111, 260159]])
tensor([[    64,   2112,   4160,  ..., 256064, 258112, 260160],
        [    65,   2113,   4161,  ..., 256065, 258113, 260161],
        [    66,   2114,   4162,  ..., 256066, 258114, 260162],
        ...,
        [   125,   2173,   4221,  ..., 256125, 258173, 260221],
        [   126,   2174,   4222,  ..., 256126, 258174, 260222],
        [   127,   2175,   4223,  ..., 256127, 258175, 260223]])
tensor([[   128,   2176,   4224,  ..., 256128, 258176, 260224],
        [   129,   2177,   4225,  ..., 256129, 258177, 260225],
        [   130,   2178,   4226,  ..., 256130, 258178, 260226],
        ...,

In [27]:
out

tensor([[-13.9087, -48.8563,  -2.3075,  ...,   0.0000,   0.0000,   0.0000],
        [-49.8847, -28.3231,  -9.4593,  ...,   0.0000,   0.0000,   0.0000],
        [-44.1483,  35.9161, -36.1946,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')

In [28]:
acc.shape

torch.Size([64, 128])

In [29]:
acc

tensor([[-13.9087, -48.8563,  -2.3075,  ...,  46.4978,   9.1448,   8.6787],
        [-49.8847, -28.3231,  -9.4593,  ..., -21.7991,  28.8791,  16.2089],
        [-44.1483,  35.9161, -36.1946,  ...,  35.2683, -21.0021,  31.0537],
        ...,
        [  0.3285,  34.0651, -19.2530,  ...,  16.3052, -23.8426,  13.2760],
        [-43.8314,  41.4952,  20.6124,  ..., -19.5780, -37.9010,  20.8687],
        [ 26.8188, -32.7362,  14.7291,  ...,  21.8951, -19.5081,  32.6895]],
       device='cuda:0')

In [30]:
BiT_blk_idxs.shape

torch.Size([64, 128])

In [31]:
acc

tensor([[-13.9087, -48.8563,  -2.3075,  ...,  46.4978,   9.1448,   8.6787],
        [-49.8847, -28.3231,  -9.4593,  ..., -21.7991,  28.8791,  16.2089],
        [-44.1483,  35.9161, -36.1946,  ...,  35.2683, -21.0021,  31.0537],
        ...,
        [  0.3285,  34.0651, -19.2530,  ...,  16.3052, -23.8426,  13.2760],
        [-43.8314,  41.4952,  20.6124,  ..., -19.5780, -37.9010,  20.8687],
        [ 26.8188, -32.7362,  14.7291,  ...,  21.8951, -19.5081,  32.6895]],
       device='cuda:0')

In [32]:
left

tensor([-13.9087,  -2.3075,  35.4975,  ...,  -8.8585, -40.6374, -19.5081],
       device='cuda:0')

In [33]:
left.shape, left.sum()

(torch.Size([4096]), tensor(-1046.0355, device='cuda:0'))

In [34]:
@triton.jit
def kernel_v2(A_ptr, BiT_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr,
              K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
              BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
              B_LEFT_SIZE: tl.constexpr):
    """
    Inputs
    ------
    A : (M, K) matrix.
    B : (K, 2*N) matrix -> Bl, Br : (K, N) matrices -> BlT, BrT : (N, K) matrices
        -> BiT : (2*N, K) matrix.
    out : (M, N) matrix.

    Kernel instance with coordinates (i, j) computes:
        out[(i:i+1) * BLOCK_SIZE_M, (j:j+1) * BLOCK_SIZE_N]
    from:
        A[(i:i+1) * BLOCK_SIZE_M, :] and
        Bl[:, (j:j+1) * BLOCK_SIZE_N] and Br[:, (j:j+1) * BLOCK_SIZE_N].
    Equivalent to:
        Bi[:, (j:j+1) * BLOCK_SIZE_N * 2]
    The computation is performed incrementally. The k^th iteration multiplies:
        A[(i:i+1) * BLOCK_SIZE_M, (k:k+1) * BLOCK_SIZE_K] with
        Bi[(k:k+1) * BLOCK_SIZE_K, (j:j+1) * BLOCK_SIZE_N * 2]


    """
    i = tl.program_id(0)
    j = tl.program_id(1)

    M_idxs = tl.arange(
        0, BLOCK_SIZE_M) + i * BLOCK_SIZE_M  # (i:i+1) * BLOCK_SIZE_M
    N2_idxs = tl.arange(
        0, 2 * BLOCK_SIZE_N) + j * 2 * BLOCK_SIZE_N  # (j:j+1) * BLOCK_SIZE_N * 2

    acc = tl.zeros((BLOCK_SIZE_M, 2 * BLOCK_SIZE_N), dtype=tl.float32) # (M', 2 * N')

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

        # (k:k+1) * BLOCK_SIZE_K
        K_idxs = tl.arange(0, BLOCK_SIZE_K) + k * BLOCK_SIZE_K

        A_blk_idxs = M_idxs[:, None] * K + K_idxs[None, :]
        A_blk = tl.load(A_ptr + A_blk_idxs) # (M', K')

        # Create a transposed index so we load BiT as Bi.
        BiT_blk_idxs = N2_idxs[None, :] * M + K_idxs[:, None]
        Bi_blk = tl.load(BiT_ptr + BiT_blk_idxs) # (K', 2 * N')

        # (M', K') @ (K', 2 * N') -> (M', 2 * N')
        acc += tl.dot(A_blk, Bi_blk)
        # tl.static_print("", K, acc.shape, A_blk.shape, Bi_blk.shape)
        # tl.device_print("N2_idxs:", N2_idxs)
        tl.device_print("M: ", M)
        tl.device_print("N2_idxs: ", N2_idxs)
        tl.device_print("K_idxs: ", K_idxs)
        tl.device_print("BiT_blk_idxs: ", BiT_blk_idxs)

    # At this point, even columns of `acc` are from multiplication with Bl and odd are Br.

    # left_right[0, :] is multiplication with Bl, left_right[1, :] with Br.
    left_right = tl.view(acc, (B_LEFT_SIZE, 2))
    left_mask_row = tl.arange(0, 2) % 2 == 0
    left_mask = tl.broadcast_to(left_mask_row[None, :], (B_LEFT_SIZE, 2))

    left = tl.sum(tl.where(left_mask, left_right, tl.zeros_like(left_right)),
                  axis=1)
    right = tl.sum(tl.where(~left_mask, left_right, tl.zeros_like(left_right)),
                   axis=1)

    out = tl.sigmoid(left) * tl.libdevice.tanh(right)

    N_idxs = tl.arange(0, BLOCK_SIZE_N) + j * BLOCK_SIZE_N
    out_idxs = M_idxs[:, None] * N + N_idxs[None, :]
    tl.store(out_ptr + out_idxs, tl.view(left, (BLOCK_SIZE_M, BLOCK_SIZE_N)))

    # out_idxs = M_idxs[:, None] * 2 * N + N2_idxs[None, :]
    # tl.device_print("", acc)
    # tl.store(out_ptr + out_idxs, acc)

In [35]:
def dispatch_v2(A, BiT, out, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
    M, K = A.shape
    N = BiT.shape[0] // 2
    print(f"{M=}, {N=}, {K=}")
    # assert out.shape == (M, N)
    B_LEFT_SIZE = BLOCK_SIZE_M * BLOCK_SIZE_N

    grid = M // BLOCK_SIZE_M, N // BLOCK_SIZE_N
    grid = (1, 1)
    kernel_v2[grid](A, BiT, out, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N,
                    BLOCK_SIZE_K, B_LEFT_SIZE)

## Output Comparison

In [36]:
kv2_out = torch.zeros_like(torch_out)
dispatch_v2(A, BiT, kv2_out, 64, 64, 64)

M=1024, N=2048, K=2048


In [5]:
# mad(torch_out, kv2_out)

In [37]:
kv2_out

M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024


tensor([[-13.8840, -48.8040,  -2.2996,  ...,   0.0000,   0.0000,   0.0000],
        [ -9.3257,  58.1085,  21.5608,  ...,   0.0000,   0.0000,   0.0000],
        [ 51.1375, -37.5246, -10.8521,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')

In [38]:
kv2_out[kv2_out.nonzero(as_tuple=True)].reshape(64, 64)

tensor([[-13.8840, -48.8040,  -2.2996,  ...,   7.0649,  21.2516, -63.2098],
        [ -9.3257,  58.1085,  21.5608,  ...,  -6.0641,  49.0198,  67.8923],
        [ 51.1375, -37.5246, -10.8521,  ..., -33.0007, -48.6052, -50.2526],
        ...,
        [ 77.1428,  49.7725,  35.1742,  ..., -63.4789,  83.2342,  37.8914],
        [ 19.9355,  -9.4671,  46.8979,  ..., -12.8243, -19.3970,  54.0743],
        [ 24.7029, -41.7822,  53.8070,  ..., -15.4960,  13.0463, -12.8770]],
       device='cuda:0')

In [39]:
kv2_out.sum()

tensor(1195.5902, device='cuda:0')

In [40]:
kv2_out[0, :50]

tensor([ -13.8840,  -48.8040,   -2.2996,  -38.0388,  -49.8608,  -28.2969,
          -9.4357, -130.2953,  -44.0970,   35.8804,  -36.1738,   33.2724,
          -3.3789,  -62.8795,   25.7825,   44.2237,  -15.1689,  -68.1102,
          72.6834,    8.2080,   57.8004, -102.8315,  -39.6391,   81.6484,
          87.1128,    2.1834,    9.2295,   24.1082,  -11.6941,   33.3712,
          22.8068,  104.0304,    3.6490,  -39.3678,  -52.6911,   -3.4252,
          -9.3545,   25.1722,  -66.7197,   27.3832,  -16.1054,   -6.8569,
           6.7390,   51.5501,  -31.7114,  -49.9286,   29.7636,   -9.5393,
          13.4559,  -15.2010], device='cuda:0')

In [41]:
1599367490 / 2048

780941.1572265625

In [42]:
format(1599367490, "b")

'1011111010101000110100101000010'

In [43]:
format(64512, "b")

'1111110000000000'

In [45]:
N2_idxs = torch.arange(0, 128)
K_idxs = torch.arange(0, 64)
M = 1024

BiT_blk_idxs = N2_idxs[None, :] * M + K_idxs[:, None]

In [46]:
BiT_blk_idxs

tensor([[     0,   1024,   2048,  ..., 128000, 129024, 130048],
        [     1,   1025,   2049,  ..., 128001, 129025, 130049],
        [     2,   1026,   2050,  ..., 128002, 129026, 130050],
        ...,
        [    61,   1085,   2109,  ..., 128061, 129085, 130109],
        [    62,   1086,   2110,  ..., 128062, 129086, 130110],
        [    63,   1087,   2111,  ..., 128063, 129087, 130111]])

In [47]:
131072 / 2

65536.0

In [None]:
kout = torch.zeros_like(BiT.T)
dispatch_v2(A, BiT, kout, 64, 64, 64)

M=1024, N=2048, K=2048


M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024
M: 1024


In [None]:
kout

tensor([[-13.8840,   0.0000, -49.8608,  ...,   0.0000,   0.0000,   0.0000],
        [-48.8040,   0.0000, -28.2969,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.2996,   0.0000,  -9.4357,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')

In [None]:
kout[:5, :5]

tensor([[ -13.8840,    0.0000,  -49.8608,    0.0000,  -44.0970],
        [ -48.8040,    0.0000,  -28.2969,    0.0000,   35.8804],
        [  -2.2996,    0.0000,   -9.4357,    0.0000,  -36.1738],
        [ -38.0388,    0.0000, -130.2953,    0.0000,   33.2724],
        [  35.4611,    0.0000,   14.9742,    0.0000,    6.1525]],
       device='cuda:0')

In [None]:
(A @ BiT.T)

tensor([[-13.9087, -48.8564,  -2.3076,  ...,  25.4502,  92.0306, -54.3416],
        [-49.8847, -28.3231,  -9.4593,  ..., -33.7343, -62.9675,  30.6559],
        [-44.1483,  35.9161, -36.1945,  ...,   8.3418, 100.3384,   7.5319],
        ...,
        [-10.6014, -64.9981, -24.2593,  ..., -66.8915,  30.9633,  18.4129],
        [ 12.0531,   5.7368, -19.7057,  ...,  27.1508,  55.5883, -68.4664],
        [-31.5721, -25.5214, -50.6414,  ..., -43.7947, -40.1447,  47.0789]],
       device='cuda:0')

In [None]:
kout.shape

torch.Size([2048, 4096])

In [None]:
(A @ BiT.T).shape

torch.Size([1024, 4096])

In [None]:
(kout - (A @ BiT.T)).abs()

RuntimeError: The size of tensor a (2048) must match the size of tensor b (1024) at non-singleton dimension 0

In [None]:
kv2_out

tensor([[ -24.9350,   40.8251,  -48.7699,  ...,  -53.9297,  -56.3020,
          -76.1971],
        [ -26.5357,  -60.3090,  -54.5616,  ...,   36.5730,   46.6989,
            1.3354],
        [   3.1034,  -18.8726,   62.4309,  ...,  -44.4121,  -14.8037,
           68.7750],
        ...,
        [ -42.1502,  -33.6206,   22.9567,  ...,   64.9995,  -11.6071,
           53.7004],
        [   6.7739,   53.8059,  -51.6598,  ..., -112.1690,   64.8516,
           51.9093],
        [ -41.7245,   -6.9654,  -15.8447,  ...,   -3.6583,    4.6197,
          -15.5340]], device='cuda:0')

In [None]:
Cl = torch.chunk(torch.matmul(A, B), 2, dim=1)[0]
Cl

tensor([[-24.9423,  40.8552, -48.8122,  ...,   1.3447,  15.2233, -42.7197],
        [ 44.9170, -42.2078, -14.1963,  ..., -13.5295,  34.4011,  32.6486],
        [  7.5865,  59.2033,  -3.1652,  ..., -26.3880,  98.2293,   7.2516],
        ...,
        [ 13.1106,  -0.2458, -47.1841,  ..., -30.5524,  36.6943,  58.9802],
        [ 24.9614,  41.7778, -70.7494,  ...,   7.6766,  12.5642, -42.0483],
        [-17.0046, -21.3861,  31.9466,  ...,   8.2247,   2.7370, -11.2509]],
       device='cuda:0')

In [None]:
kv2_out[0, :10]

tensor([-24.9350,  40.8251, -48.7699,  29.0279,  44.8871, -42.1707, -14.1934,
        -49.6960,   7.5818,  59.1495], device='cuda:0')

In [None]:
Cl[0, :10]

tensor([-24.9423,  40.8552, -48.8122,  29.0392,  50.7437,  15.4867, 102.5665,
        -12.7903,  51.2901,  65.8413], device='cuda:0')

In [None]:
kv2_out.sum(), Cl.sum()

(tensor(-124058.6406, device='cuda:0'), tensor(61682.1367, device='cuda:0'))

In [None]:
torch_out

tensor([[ 1.4713e-11,  1.0000e+00, -6.3259e-22,  ...,  7.9326e-01,
          1.0000e+00,  2.7994e-19],
        [ 9.9987e-01,  4.6706e-19, -6.8333e-07,  ...,  1.2704e-06,
         -1.0000e+00,  1.0000e+00],
        [-9.9949e-01,  1.0000e+00, -4.0496e-02,  ..., -3.4391e-12,
          1.0000e+00,  9.9929e-01],
        ...,
        [-1.0000e+00,  4.3887e-01,  3.2225e-21,  ..., -5.3858e-14,
          1.0000e+00, -1.0000e+00],
        [ 1.0000e+00,  1.0000e+00, -1.8791e-31,  ...,  9.9954e-01,
          1.0000e+00,  5.4782e-19],
        [-4.1210e-08,  5.1535e-10, -9.9819e-01,  ..., -9.9973e-01,
          9.3918e-01,  1.2995e-05]], device='cuda:0')

In [None]:
mad(torch_out, kv2_out)

2.0

## Perf Comparison

In [None]:
v2_perf = grid_search(
    partial(dispatch_v2, A, BiT, kv2_out),
    BLOCK_SIZE_M=(32, 1024),
    BLOCK_SIZE_N=(32, 1024),
    BLOCK_SIZE_K=(16, 1024),
)

In [None]:
results_to_df(v2_perf)

Unnamed: 0,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,Time (us)
0,128,64,16,328.704
1,128,32,32,374.783993
2,64,64,32,377.855986
3,128,32,16,378.879994
4,64,64,16,384.000003
5,256,32,16,429.055989
6,64,128,16,430.079997
7,64,32,16,450.560004
8,64,32,32,473.087996
9,32,128,16,534.528017


# Pytorch Reference

In [None]:
import torch.nn.functional as F

def torch_splac(x, Wt, b):
    Z = F.linear(x, Wt, b) # (n_batch, n_feat_out).
    z1, z2 = torch.chunk(Z, chunks=2, dim=-1) # Both (n_batch, n_feat_out/2).
    A = z1 * F.gelu(z2) # (n_batch, n_feat_out/2).
    return A

# Triton reshape-sum split test

In [None]:
import triton
import triton.language as tl

@triton.jit
def reshape_sum_split_kernel(x_ptr, out_ptr, X_ROWS : tl.constexpr, X_COLS: tl.constexpr):

    x_row_idxs = tl.arange(0, X_ROWS)
    x_col_idxs = tl.arange(0, X_COLS)
    x_idxs = x_row_idxs[None, :] * X_COLS + x_col_idxs[:, None]
    x = tl.load(x_ptr + x_idxs)

    tl.store(out_ptr + x_idxs, x)

def reshape_sum_split_dispatch(x: torch.Tensor):

    X_ROWS, X_COLS = x.shape
    out = torch.empty((X_ROWS, X_COLS), device="cuda", dtype=torch.float32)

    reshape_sum_split_kernel[(1,)](x, out, X_ROWS, X_COLS)

    return out

In [None]:
x = torch.arange(4 * 8, device="cuda").reshape(4, 8).float()
x

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23.],
        [24., 25., 26., 27., 28., 29., 30., 31.]], device='cuda:0')

In [None]:
reshape_sum_split_dispatch(x)

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23.],
        [24., 25., 26., 27., 28., 29., 30., 31.]], device='cuda:0')

Need to do even * odd

In [None]:
y = x.reshape((-1, 2))
y

tensor([[ 0.,  1.],
        [ 2.,  3.],
        [ 4.,  5.],
        [ 6.,  7.],
        [ 8.,  9.],
        [10., 11.],
        [12., 13.],
        [14., 15.],
        [16., 17.],
        [18., 19.],
        [20., 21.],
        [22., 23.],
        [24., 25.],
        [26., 27.],
        [28., 29.],
        [30., 31.]], device='cuda:0')

In [None]:
even_mask = y % 2 == 0
odd_mask = ~even_mask
even_mask, odd_mask

(tensor([[ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False],
         [ True, False]], device='cuda:0'),
 tensor([[False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True],
         [False,  True]], device='cuda:0'))

In [None]:
evens = torch.sum(torch.where(even_mask, y, torch.zeros_like(y)), dim=1, keepdim=True)
evens

tensor([[ 0.],
        [ 2.],
        [ 4.],
        [ 6.],
        [ 8.],
        [10.],
        [12.],
        [14.],
        [16.],
        [18.],
        [20.],
        [22.],
        [24.],
        [26.],
        [28.],
        [30.]], device='cuda:0')

In [None]:
odds = torch.sum(torch.where(odd_mask, y, torch.zeros_like(y)), dim=1, keepdim=True)
odds

tensor([[ 1.],
        [ 3.],
        [ 5.],
        [ 7.],
        [ 9.],
        [11.],
        [13.],
        [15.],
        [17.],
        [19.],
        [21.],
        [23.],
        [25.],
        [27.],
        [29.],
        [31.]], device='cuda:0')

In [None]:
evens * odds

tensor([[  0.],
        [  6.],
        [ 20.],
        [ 42.],
        [ 72.],
        [110.],
        [156.],
        [210.],
        [272.],
        [342.],
        [420.],
        [506.],
        [600.],
        [702.],
        [812.],
        [930.]], device='cuda:0')

In [None]:
torch.sum(1, keepdim=True)

TypeError: sum() received an invalid combination of arguments - got (int, keepdim=bool), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)


In [None]:
x.reshape(12)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

In [None]:
x.T.reshape(12)

tensor([ 0.,  4.,  8.,  1.,  5.,  9.,  2.,  6., 10.,  3.,  7., 11.])

In [None]:
x = torch.zeros(3, 4)
x

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [None]:
x[:, ::2] = 1
x

tensor([[1., 0., 1., 0.],
        [1., 0., 1., 0.],
        [1., 0., 1., 0.]])

Need to find a way to multiply odd rows by even rows - and therefore output all zeros



In [None]:
y = x.reshape((-1, 2))
y

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [None]:
x.reshape((1, -1)).T.ravel()

tensor([1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.])