In [None]:
import logging 

# Configure the root logger before the project modules below are imported:
# INFO and above, each record prefixed with timestamp, logger name, source
# line number and level.
LOG_FORMAT = "[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s"
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

from dacite import from_dict 
from omegaconf import OmegaConf

from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig
from mlstm_kernels.utils.benchmark.run_benchmark import run_benchmarks
from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import create_training_kernel_benchmark

In [2]:
# Notes:
# - torch.compile actually makes the kernels (xl_chunk_siging) very slow; something is not working properly there.

In [3]:
# Half-open exponent window [low, high): sequence lengths are the powers of two
# 2**low ... 2**(high - 1).
sequence_length_limits = [9, 17]
sequence_lengths = [2**exponent for exponent in range(*sequence_length_limits)]
# One batch size per sequence length, as descending powers of two ending at 1,
# so that sequence_lengths[i] * batch_sizes[i] is the same for every i.
batch_sizes = [2**exponent for exponent in range(len(sequence_lengths))][::-1]


In [None]:
# Display both generated lists as the cell output to sanity-check their values.
sequence_lengths, batch_sizes

In [5]:
# Problem size for the quick benchmark. Each name is interpolated into the
# `fixed_params` section of the YAML config built from the f-string `cfg_yaml`.
B = 8  # batch_size
S = 8192  # sequence_length
NH = 32  # num_heads (other value left in the original note: 8)
DHQK = 64  # head_dim_qk (other value left in the original note: 256, "*2")
DHHV = 128  # head_dim_v (other value left in the original note: 512, "*2")
# Value head dimension summed over all heads; only the commented-out mamba
# entries of the config reference it.
D = DHHV * NH

In [6]:
# Benchmark configuration written as YAML text. Because this is an f-string,
# B, S, NH, DHQK and DHHV are substituted into `fixed_params`, and `{dict()}`
# renders as `{}` (an empty mapping for `vary_params`).
# NOTE: the `{...}` placeholders inside the commented-out YAML entries (e.g.
# `{2*D}`) are still evaluated by Python, so D must be defined even though
# those entries are inactive for the YAML parser.
# Active kernel specs come in pairs that differ only in `fwbw`
# (presumably forward-only vs. forward+backward timing -- confirm against the
# benchmark implementation).
cfg_yaml = f"""
vary_type: grid
vary_params: {dict()}
fixed_params: 
  batch_size: {B}
  sequence_length: {S}
  num_heads: {NH}
  head_dim_qk: {DHQK}
  head_dim_v: {DHHV}
  warmup: 10
  rep: 25

kernel_specs:
  - kernel_name: "chunkwise--triton_limit_chunk"
    fwbw: False
    dtype: bfloat16
    additional_params:
      chunk_size: 64
  - kernel_name: "chunkwise--triton_limit_chunk"
    fwbw: True
    dtype: bfloat16
    additional_params:
      chunk_size: 64
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: False
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: True
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 128
      normalize: False
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 128
      normalize: False
  # - kernel_name: "chunkwise--triton_xl_chunk_siging"
  #   fwbw: False
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  #     normalize: True
  # - kernel_name: "chunkwise--triton_xl_chunk_siging"
  #   fwbw: True
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  #     normalize: True

  # - kernel_name: "chunk_gla"
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   fwbw: False
  # - kernel_name: "chunk_gla"
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   fwbw: True
  # - kernel_name: "fused_chunk_gla"
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   fwbw: False
  # - kernel_name: "fused_chunk_gla"
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   fwbw: True
  - kernel_name: "chunk_simple_gla"
    dtype: bfloat16
    use_torch_compile: False
    fwbw: False
  - kernel_name: "chunk_simple_gla"
    dtype: bfloat16
    use_torch_compile: False
    fwbw: True

  # - kernel_name: "mamba"
  #   dtype: bfloat16
  #   fwbw: False
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: 1
  #     head_dim_v: {2*D}
  #     head_dim_qk: 16
  # - kernel_name: "mamba"
  #   dtype: bfloat16
  #   fwbw: True
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: 1
  #     head_dim_v: {2*D}
  #     head_dim_qk: 16
      
  # - kernel_name: "mamba2"
  #   dtype: bfloat16
  #   fwbw: False
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: {2*D//64}
  #     head_dim_v: 64
  #     head_dim_qk: 64
  # - kernel_name: "mamba2"
  #   dtype: bfloat16
  #   fwbw: True
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: {2*D//64}
  #     head_dim_v: 64
  #     head_dim_qk: 64

  # - kernel_name: "mamba2_noconv"
  #   dtype: bfloat16
  #   fwbw: False
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: {2*D//64}
  #     head_dim_v: 64
  #     head_dim_qk: 64
  # - kernel_name: "mamba2_noconv"
  #   dtype: bfloat16
  #   fwbw: True
  #   use_torch_compile: False
  #   additional_params:
  #     num_heads: {2*D//64}
  #     head_dim_v: 64
  #     head_dim_qk: 64

  
benchmark_name: "quick_kernel_benchmark"
"""
# Parse the YAML text with OmegaConf, convert it to a plain Python container,
# and let dacite build the BenchmarkConfig dataclass from that container.
cfg_container = OmegaConf.to_container(OmegaConf.create(cfg_yaml))
cfg_baseline = from_dict(data_class=BenchmarkConfig, data=cfg_container)

In [None]:
# Execute the benchmark described by cfg_baseline; the returned results are
# kept in res_df for the inspection cells that follow.
res_df = run_benchmarks(
    cfg_baseline,
    benchmark_creator=create_training_kernel_benchmark,
    run_garbage_collection=False,
)

In [None]:
# Keep only the columns whose label matches "R--.*" and transpose for display
# (assumes res_df is a pandas DataFrame; "R--" presumably marks the runtime
# measurements -- confirm against run_benchmarks).
res_df.filter(regex="R--.*", axis=1).T

In [None]:
# Keep only the columns whose label matches "M--.*" and transpose for display
# (assumes res_df is a pandas DataFrame; "M--" presumably marks the memory
# measurements -- confirm against run_benchmarks).
res_df.filter(regex="M--.*", axis=1).T

In [2]:
# Problem size for the attention-baseline benchmark; these names are
# interpolated into the YAML config assigned to `cfg_yaml` in this cell.
B = 8  # batch_size
S = 8192  # sequence_length
NH = 32  # num_heads
DHQK = 128  # head_dim_qk (original note: "*2")
DHHV = 128  # head_dim_v (original note: "*2")
# Value head dimension summed over all heads; the YAML in this cell does not
# reference it.
D = DHHV * NH
# Inserted as the `fwbw` field of every kernel spec in this cell's YAML.
fwbw = True

# YAML benchmark configuration for the attention kernels listed under
# `kernel_specs` (torch_flash, torch_cudnn, flashattn3). Because this is an
# f-string, B, S, NH, DHQK and DHHV fill `fixed_params`, `fwbw` is written into
# every kernel spec, and `{dict()}` renders as `{}` (an empty mapping for
# `vary_params`).
cfg_yaml = f"""
vary_type: grid
vary_params: {dict()}
fixed_params: 
  batch_size: {B}
  sequence_length: {S}
  num_heads: {NH}
  head_dim_qk: {DHQK}
  head_dim_v: {DHHV}
  warmup: 10
  rep: 25

kernel_specs:
  - kernel_name: "torch_flash"
    dtype: bfloat16
    fwbw: {fwbw}

  - kernel_name: "torch_cudnn"
    dtype: bfloat16
    fwbw: {fwbw}

  - kernel_name: "flashattn3"
    dtype: bfloat16
    fwbw: {fwbw}
    use_torch_compile: False

  
benchmark_name: "quick_kernel_benchmark"
"""
# Parse the YAML text with OmegaConf, convert it to a plain Python container,
# and let dacite build the BenchmarkConfig dataclass from that container.
cfg_container = OmegaConf.to_container(OmegaConf.create(cfg_yaml))
cfg_baseline = from_dict(data_class=BenchmarkConfig, data=cfg_container)

In [None]:
# Execute the benchmark described by cfg_baseline; the returned results are
# kept in res_df for the inspection cells that follow.
res_df = run_benchmarks(
    cfg_baseline,
    benchmark_creator=create_training_kernel_benchmark,
    run_garbage_collection=False,
)

In [None]:
# Keep only the columns whose label matches "R--.*" and transpose for display
# (assumes res_df is a pandas DataFrame; "R--" presumably marks the runtime
# measurements -- confirm against run_benchmarks).
res_df.filter(regex="R--.*", axis=1).T

In [None]:
# Keep only the columns whose label matches "M--.*" and transpose for display
# (assumes res_df is a pandas DataFrame; "M--" presumably marks the memory
# measurements -- confirm against run_benchmarks).
res_df.filter(regex="M--.*", axis=1).T