In [1]:
import torch
from sparsemm_kernels.up_dejavu import sparsemm_up_dejavu
from sparsemm_kernels.up_dense import sparsemm_up_dense
from sparsemm_kernels.up_neo import sparsemm_up_neo
from sparsemm_kernels.up_torchsparse import sparsemm_up_torchsparse
from sparsemm_kernels.up_cats import sparsemm_up_cats
from sparsemm_kernels.utils import idx_to_mask

# --- Up-projection correctness check -----------------------------------------
# Compares the sparse up-projection kernels against reference implementations
# on randomly initialised fp16 inputs.

# Problem sizes. P is the number of index rows, Q the number of selected
# indices per row.
BATCH_SIZE = 512
EMBED_DIM = 5120
HIDDEN_DIM = 13824
P = 1
Q = 10000

# Fix the RNG so that the index draw and the Xavier initialisation below are
# identical after a kernel restart (torch.manual_seed seeds every device).
torch.manual_seed(0)

X = torch.empty((BATCH_SIZE, EMBED_DIM), device="cuda", dtype=torch.float16)
Wup = torch.empty((HIDDEN_DIM, EMBED_DIM), device="cuda", dtype=torch.float16)
Wgate = torch.empty((HIDDEN_DIM, EMBED_DIM), device="cuda", dtype=torch.float16)

# Q indices in [0, HIDDEN_DIM), sorted ascending within each of the P rows.
# randint samples with replacement, so an index can appear more than once.
IDX = torch.randint(0, HIDDEN_DIM, (P, Q), device="cuda", dtype=torch.int32)
IDX = torch.sort(IDX, dim=1)[0]
# NOTE(review): presumably a per-hidden-unit mask built from IDX; confirm
# against sparsemm_kernels.utils.idx_to_mask.
MASK = idx_to_mask(IDX, Q, HIDDEN_DIM)

torch.nn.init.xavier_uniform_(X)
torch.nn.init.xavier_uniform_(Wup)
torch.nn.init.xavier_uniform_(Wgate)

# Index-based kernels: all receive the same (X, Wup, Wgate, IDX).
H_torchsparse = sparsemm_up_torchsparse(X, Wup, Wgate, IDX)
H_dejavu = sparsemm_up_dejavu(
    X, Wup, Wgate, IDX,
    ACT_TYPE="fatrelu", tune=True,
    BLOCK_SIZE_M=16, BLOCK_SIZE_K=16, BLOCK_SIZE_Q=16,
    num_stages=4, num_warps=4,
)
H_neo = sparsemm_up_neo(
    X, Wup, Wgate, IDX,
    ACT_TYPE="fatrelu", tune=True,
    BLOCK_SIZE_M=16, BLOCK_SIZE_K=16, BLOCK_SIZE_Q=16, GROUP_SIZE_Q=1,
    num_stages=4, num_warps=4,
)
# Mask-based kernel: receives MASK instead of IDX.
H_cats = sparsemm_up_cats(
    X, Wup, Wgate, MASK,
    ACT_TYPE="fatrelu", tune=True,
    BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=16, GROUP_SIZE_N=1,
    num_stages=4, num_warps=4,
)
# All-zeros tensor used as a negative control in the checks below.
H_zero = torch.zeros_like(H_torchsparse)

# Reference for the mask-based kernel: scale each weight row by its mask entry
# (MASK is reshaped to a column so it broadcasts across EMBED_DIM), then run
# the dense kernel on the masked weights.
Wup_masked = Wup * MASK.reshape(-1, 1)
Wgate_masked = Wgate * MASK.reshape(-1, 1)

H_torchcats = sparsemm_up_dense(X, Wup_masked, Wgate_masked)

# The index-based kernels should agree with the torchsparse reference.
print(torch.allclose(H_torchsparse, H_dejavu, atol=1e-3, rtol=1e-3))
print(torch.allclose(H_torchsparse, H_neo, atol=1e-3, rtol=1e-3))
# Negative control: should print False, otherwise the reference output is
# within tolerance of zero and the comparisons above are not informative.
print(torch.allclose(H_torchsparse, H_zero, atol=1e-3, rtol=1e-3))

# The mask-based kernel should agree with the dense-on-masked-weights reference.
print(torch.allclose(H_torchcats, H_cats, atol=1e-3, rtol=1e-3))

True
True
False
True


In [2]:
from sparsemm_kernels.down_dejavu import sparsemm_down_dejavu
from sparsemm_kernels.down_splitk import sparsemm_down_splitk
from sparsemm_kernels.down_torchsparse import sparsemm_down_torchsparse


# --- Down-projection correctness check ---------------------------------------
# Compares the sparse down-projection kernels against the torchsparse
# reference on randomly initialised fp16 inputs.

# Problem sizes. P is the number of index rows, Q the number of selected
# indices per row.
BATCH_SIZE = 512
EMBED_DIM = 5120
HIDDEN_DIM = 13824
P = 1
Q = 10000

# Fix the RNG so that the index draw and the Xavier initialisation below are
# identical after a kernel restart (torch.manual_seed seeds every device).
torch.manual_seed(0)

# H has one column per selected index; Wdown has one row per hidden unit.
H = torch.empty((BATCH_SIZE, Q), device="cuda", dtype=torch.float16)
Wdown = torch.empty((HIDDEN_DIM, EMBED_DIM), device="cuda", dtype=torch.float16)

# Draw the Q indices over the hidden dimension, i.e. over the HIDDEN_DIM rows
# of Wdown. Drawing from [0, EMBED_DIM) would leave every row at or above
# EMBED_DIM unselected and, because Q > EMBED_DIM, force repeated indices.
# NOTE(review): assumes the down kernels gather rows of Wdown by IDX; confirm
# against the sparsemm_down_* implementations.
IDX = torch.randint(0, HIDDEN_DIM, (P, Q), device="cuda", dtype=torch.int32)
IDX = torch.sort(IDX, dim=1)[0]
# Kept sized to the same dimension the indices are drawn from; not used by the
# checks in this cell.
MASK = idx_to_mask(IDX, Q, HIDDEN_DIM)

torch.nn.init.xavier_uniform_(H)
torch.nn.init.xavier_uniform_(Wdown)

H_torchsparse = sparsemm_down_torchsparse(H, Wdown, IDX)
H_dejavu = sparsemm_down_dejavu(
    H, Wdown, IDX, tune=True,
    BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_Q=16,
    num_stages=4, num_warps=4,
)
# The split-K kernel is given H transposed, and its result is transposed back
# before being compared with the other outputs.
H_splitk = sparsemm_down_splitk(
    H.T, Wdown, IDX, tune=True,
    BLOCK_SIZE_M=16, BLOCK_SIZE_Q=64, BLOCK_SIZE_N=32, GROUP_SIZE_Q=1,
    num_stages=4, num_warps=4,
).T
# All-zeros tensor used as a negative control in the checks below.
H_zero = torch.zeros_like(H_torchsparse)

# Both kernels should agree with the torchsparse reference.
print(torch.allclose(H_torchsparse, H_dejavu, atol=1e-3, rtol=1e-3))
print(torch.allclose(H_torchsparse, H_splitk, atol=1e-3, rtol=1e-3))
# Negative control: should print False, otherwise the reference output is
# within tolerance of zero and the comparisons above are not informative.
print(torch.allclose(H_torchsparse, H_zero, atol=1e-3, rtol=1e-3))

True
True
False
