In [None]:
import sys

sys.path.append("..")

import torch

In [None]:
from mlstm_kernels.torch.backend_module import mLSTMBackendConfig, mLSTMBackend

In [None]:
# Backend configuration.
# - chunkwise path: the mLSTMexp TFLA kernel ("chunkwise--triton_xl_chunk")
# - step path: the triton step kernel, selected for inference
# Only the output is requested (return_last_states=False), with chunks of 256.
mlstm_backend_config = mLSTMBackendConfig(
    chunk_size=256,
    return_last_states=False,
    chunkwise_kernel="chunkwise--triton_xl_chunk",
    sequence_kernel="native_sequence__triton",
    step_kernel="triton",
)

# Instantiate the backend from the configuration above.
mlstm_backend = mLSTMBackend(mlstm_backend_config)

In [None]:
# Input setup for the kernel call: CUDA device, bfloat16 tensors.
DEVICE = torch.device("cuda")
DTYPE = torch.bfloat16

# Problem dimensions. Presumably B = batch size, NH = number of heads,
# S = sequence length, DHQK / DHHV = per-head widths of q,k / v -- TODO confirm
# against the library's naming conventions.
B, NH, S = 2, 4, 512
DHQK, DHHV = 128, 256

# Seeded random inputs. The draw order (matQ, matK, matV, vecI, vecF) is kept
# fixed because it determines the generated values.
torch.manual_seed(1)
tensor_opts = {"dtype": DTYPE, "device": DEVICE}
matQ = torch.randn(B, NH, S, DHQK, **tensor_opts)
matK = torch.randn(B, NH, S, DHQK, **tensor_opts)
matV = torch.randn(B, NH, S, DHHV, **tensor_opts)
vecI = torch.randn(B, NH, S, **tensor_opts)
# vecF is a standard-normal draw shifted by +3.0.
vecF = torch.randn(B, NH, S, **tensor_opts) + 3.0

In [None]:
# Run the configured backend module on the inputs; the result is kept as matH1
# for comparison against the direct kernel call.
matH1 = mlstm_backend(
    q=matQ,
    k=matK,
    v=matV,
    i=vecI,
    f=vecF,
)

In [None]:
# Import the mLSTMexp TFLA kernel function directly, without going through mLSTMBackend
from mlstm_kernels.torch.chunkwise.triton_xl_chunk import mlstm_chunkwise__xl_chunk

In [None]:
# Call the chunkwise kernel function directly on the same inputs, using the
# same chunk size (256) and without returning the last states.
matH2 = mlstm_chunkwise__xl_chunk(
    q=matQ,
    k=matK,
    v=matV,
    i=vecI,
    f=vecF,
    chunk_size=256,
    return_last_states=False,
)

In [None]:
# Compare the backend output with the direct kernel output within tolerance.
# Left as the final bare expression so the boolean is shown as the cell output.
torch.allclose(matH1, matH2, rtol=1e-3, atol=1e-5)