In [1]:
import sys

import torch
from torch.utils.flop_counter import FlopCounterMode

sys.path.append("../..")
from mlstm_kernels.components.conv import CausalConv1d, CausalConv1dConfig
from mlstm_kernels.mlstm import get_available_mlstm_kernels, get_mlstm_kernel
from mlstm_kernels.mlstm.backend_module import mLSTMBackend, mLSTMBackendConfig

In [2]:
get_available_mlstm_kernels()

['recurrent_step--step_torch_autograd',
 'recurrent_step--step_triton',
 'recurrent_step--step_fused_triton',
 'recurrent_sequence--sequence_torch_autograd',
 'chunkwise--torch_autograd',
 'chunkwise--torch_ownbw',
 'chunkwise--max_triton',
 'chunkwise--max_triton_v1',
 'chunkwise--max_triton_v2',
 'chunkwise--max_triton_v3',
 'chunkwise--triton',
 'chunkwise--stable_triton',
 'parallel--torch_autograd',
 'parallel--torch_ownbw',
 'parallel--stable_torch_autograd',
 'parallel--stable_torch_ownbw',
 'parallel--triton']

In [3]:
mlstm_backend_module = mLSTMBackend(mLSTMBackendConfig(kernel_name="chunkwise--torch_ownbw", chunk_size=128))

In [10]:
S = 1  # seq len
B = 1  # batch size
NH = 1  # num heads
DHQK = 1024  # dim per head
DHV = 1024  # dim per head

EPS = 0.0

vecI_offset = 0.0
vecF_offset = 6.0
DTYPE = torch.bfloat16  # torch.bfloat16
DEVICE = torch.device("cuda:0")

In [11]:
torch.manual_seed(0)
matQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
matK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
matV = torch.randn((B, NH, S, DHV), dtype=DTYPE, device=DEVICE)
# vecI = 0.00001 * torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)
# vecF = -30. + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)
vecI = vecI_offset + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)
vecF = vecF_offset + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)

In [12]:
with FlopCounterMode():
    out = mlstm_backend_module(matQ, matK, matV, vecI, vecF, EPS)

Module      FLOP    % Total
--------  ------  ---------
Global         0         0%


AssertionError: Sequence length 1 is not divisible by chunk size 128.

In [25]:
S = 10  # seq len
B = 1  # batch size
DHQK = 1  # dim per head
kernel_size = 4
z_c = torch.randn((B, S, DHQK), dtype=DTYPE, device=DEVICE)
conv1d = CausalConv1d(CausalConv1dConfig(feature_dim=DHQK, kernel_size=kernel_size)).to(dtype=DTYPE, device=DEVICE)
conv1d

CausalConv1d(
  (conv): Conv1d(1, 1, kernel_size=(4,), stride=(1,), padding=(3,))
)

In [26]:
with FlopCounterMode():
    out_conv = conv1d(z_c)

Module                  FLOP    % Total
--------------------  ------  ---------
CausalConv1d             104    100.00%
 - aten.convolution      104    100.00%
 CausalConv1d.conv       104    100.00%
  - aten.convolution     104    100.00%


In [29]:
flop_conv = 2 * kernel_size * (S + kernel_size - 1) * DHQK
flop_conv

104