In [1]:
from sparsemm_kernels.autotune import autotune
from sparsemm_kernels.down_dejavu import bench_sparsemm_down_dejavu
from sparsemm_kernels.down_dense import bench_sparsemm_down_dense
from sparsemm_kernels.down_neo import bench_sparsemm_down_neo
from sparsemm_kernels.up_cats import bench_sparsemm_up_cats
from sparsemm_kernels.up_dejavu import bench_sparsemm_up_dejavu
from sparsemm_kernels.up_dense import bench_sparsemm_up_dense
from sparsemm_kernels.up_neo import bench_sparsemm_up_neo

# Problem shape shared by every benchmark/autotune cell in this notebook.
# Each cell passes these positionally as (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q);
# their exact meaning is defined by the sparsemm_kernels bench functions.
BATCH_SIZE = 512        # presumably rows of the activation matrix — TODO confirm
EMBED_DIM = 5_120       # model/embedding width
HIDDEN_DIM = 13_824     # hidden (intermediate) width
P = 1                   # NOTE(review): meaning not visible here — document it
Q = 10_000              # presumably the number of selected hidden units (< HIDDEN_DIM) — confirm

In [None]:
# Dense up-projection, benchmarked directly without an autotune search;
# presumably the reference point for the sparse up-projection variants below.
up_dense_args = (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q)
bench_sparsemm_up_dense(*up_dense_args)

In [None]:
# Search space for the `up_dejavu` kernel: power-of-two tile sizes 16..256 on
# the M, K and Q axes, plus pipeline stages and warp counts. The grid holds
# 5*5*5*4*2 = 1000 combinations; with n_trials=100 presumably only a subset is
# evaluated — confirm against `autotune`.
up_dejavu_space = {
    axis: [2 ** k for k in range(4, 9)]  # fresh list per key: 16, 32, 64, 128, 256
    for axis in ("BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_Q")
}
up_dejavu_space["num_stages"] = list(range(2, 6))
up_dejavu_space["num_warps"] = [4, 8]

autotune(
    bench_sparsemm_up_dejavu,
    (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q),
    up_dejavu_space,
    n_trials=100,
)

In [None]:
# Search space for the `up_neo` kernel: power-of-two tiles 16..256 on M, K and Q,
# a Q-axis group size of 1..16, plus pipeline stages and warp counts.
# Grid size is 5*5*5*5*4*2 = 5000; n_trials=100 presumably samples a subset.
up_neo_space = {
    axis: [2 ** k for k in range(4, 9)]  # 16, 32, 64, 128, 256
    for axis in ("BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_Q")
}
up_neo_space["GROUP_SIZE_Q"] = [2 ** k for k in range(5)]  # 1, 2, 4, 8, 16
up_neo_space["num_stages"] = list(range(2, 6))
up_neo_space["num_warps"] = [4, 8]

autotune(
    bench_sparsemm_up_neo,
    (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q),
    up_neo_space,
    n_trials=100,
)

In [None]:
# Search space for the `up_cats` kernel. Unlike the DejaVu/Neo variants it tiles
# the N axis (not Q) and groups along N. Grid size is 5*5*5*5*4*2 = 5000;
# n_trials=100 presumably samples a subset.
up_cats_space = {
    axis: [2 ** k for k in range(4, 9)]  # 16, 32, 64, 128, 256
    for axis in ("BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N")
}
up_cats_space["GROUP_SIZE_N"] = [2 ** k for k in range(5)]  # 1, 2, 4, 8, 16
up_cats_space["num_stages"] = list(range(2, 6))
up_cats_space["num_warps"] = [4, 8]

autotune(
    bench_sparsemm_up_cats,
    (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q),
    up_cats_space,
    n_trials=100,
)

In [None]:
# Dense down-projection, benchmarked directly without an autotune search;
# presumably the reference point for the sparse down-projection variants below.
down_dense_args = (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q)
bench_sparsemm_down_dense(*down_dense_args)

In [None]:
# Autotune the `down_dejavu` kernel over tile sizes on the M, Q and N axes plus
# pipeline stages and warp counts. The grid holds 5*5*5*4*2 = 1000 combinations;
# with n_trials=100 presumably only a subset is evaluated — confirm against `autotune`.
autotune(
    bench_sparsemm_down_dejavu,
    (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q),
    {
        "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_Q": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_N": [16, 32, 64, 128, 256],
        "num_stages": [2, 3, 4, 5],
        "num_warps": [4, 8],
    },
    n_trials=100,
)

In [None]:
# Autotune the `down_neo` kernel over tile sizes on the M, Q and N axes, an
# N-axis group size, pipeline stages and warp counts. The grid holds
# 5*5*5*5*4*2 = 5000 combinations; n_trials=100 presumably samples a subset.
autotune(
    bench_sparsemm_down_neo,
    (BATCH_SIZE, EMBED_DIM, HIDDEN_DIM, P, Q),
    {
        "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_Q": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_N": [16, 32, 64, 128, 256],
        "GROUP_SIZE_N": [1, 2, 4, 8, 16],
        "num_stages": [2, 3, 4, 5],
        "num_warps": [4, 8],
    },
    n_trials=100,
)