In [None]:
# Configure root logging first: logging.basicConfig() does nothing if the root
# logger already has handlers, so it is called before the heavier imports below
# (presumably so that any log output they produce uses this format).
import logging 

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s",
)

import torch

import sys
# Make the parent directory importable so `mlstm_kernels` can be resolved from a
# local checkout (assumes the notebook is run from a subdirectory of the repo
# root — TODO confirm). Because this *appends*, an `mlstm_kernels` found earlier
# on sys.path (e.g. an installed copy) takes precedence.
sys.path.append("..")
# Lightning Attention-2 baseline kernel and its per-head slope helper.
from mlstm_kernels.baselines.lightning_attention.lightning_attn2 import lightning_attn2
from mlstm_kernels.baselines.lightning_attention.utils import _build_slope_tensor

# Config parsing: YAML string -> OmegaConf -> plain container -> BenchmarkConfig dataclass.
from dacite import from_dict 
from omegaconf import OmegaConf

# Benchmark harness.
from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig
from mlstm_kernels.utils.benchmark.run_benchmark import run_benchmarks
from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import create_training_kernel_benchmark

### Quick smoke test: check that `lightning_attn2` runs

In [2]:
# Small problem size for the smoke test:
# batch size, number of heads, sequence length, head dimension.
b, h, n, d = 4, 8, 512, 128

# Inputs are bfloat16 tensors on the first CUDA device.
dtype = torch.bfloat16
device = torch.device("cuda", 0)

In [3]:
torch.manual_seed(0)

# q, k, v: random (b, h, n, d) leaf tensors with requires_grad=True.
# The generator is consumed in order, so the RNG draws are q, then k, then v.
q, k, v = (
    torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
    for _ in range(3)
)

# Per-head slope tensor for lightning_attn2, kept in float32 on q's device.
s = _build_slope_tensor(h).to(device=q.device, dtype=torch.float32)

In [4]:
# Forward pass through the Lightning Attention-2 kernel.
out = lightning_attn2(q, k, v, s)

# q/k/v were created with requires_grad_(), and the benchmark below also runs
# lightning_attn2 with `fwbw: True`, so exercise the backward kernel here too:
# a broken backward then fails in this quick check instead of mid-benchmark.
# torch.ones_like gives a materialised, contiguous upstream gradient (unlike the
# stride-0 expanded grad produced by `out.sum().backward()`), which is the safer
# input for a custom Triton backward.
out.backward(torch.ones_like(out))
assert all(t.grad is not None for t in (q, k, v)), "backward did not populate q/k/v grads"

In [None]:
# Shape of the kernel output (a torch.Size), to compare against the (b, h, n, d) inputs.
out.size()

### Benchmark: chunkwise Triton kernels vs. `lightning_attn2`

In [6]:
# Problem size for the benchmark (fed into `fixed_params` of the config below).
# Other sizes tried previously: DHQK=256, DHHV=512, NH=8.
S = 8192  # sequence_length
B = 8  # batch_size
NH = 32  # num_heads
DHQK = 128  # head_dim_qk
DHHV = 128  # head_dim_v

# Total value width across all heads (not referenced by the config below).
D = NH * DHHV

In [7]:
# Benchmark configuration as a YAML f-string:
# - a grid with no varied parameters, i.e. a single setting at the sizes above;
# - `warmup` / `rep` are forwarded to the benchmark harness;
# - each kernel is listed twice, once with `fwbw: False` and once with
#   `fwbw: True` (presumably forward-only vs. forward+backward);
# - YAML-commented specs are disabled and ignored by the parser.
# `{{}}` renders as a literal `{}`, i.e. an empty mapping for vary_params.
cfg_yaml = f"""
vary_type: grid
vary_params: {{}}
fixed_params: 
  batch_size: {B}
  sequence_length: {S}
  num_heads: {NH}
  head_dim_qk: {DHQK}
  head_dim_v: {DHHV}
  warmup: 10
  rep: 25

kernel_specs:
  - kernel_name: "chunkwise--triton_limit_chunk"
    fwbw: False
    dtype: bfloat16
    additional_params:
      chunk_size: 64
  - kernel_name: "chunkwise--triton_limit_chunk"
    fwbw: True
    dtype: bfloat16
    additional_params:
      chunk_size: 64
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: False
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: True
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 128
      normalize: False
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 128
      normalize: False

  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 256
      normalize: False
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 256
      normalize: False
    
  - kernel_name: "lightning_attn2"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
  - kernel_name: "lightning_attn2"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False


  
benchmark_name: "quick_kernel_benchmark"
"""

# YAML text -> OmegaConf node -> plain Python container -> typed BenchmarkConfig.
cfg_baseline = from_dict(
    data_class=BenchmarkConfig,
    data=OmegaConf.to_container(OmegaConf.create(cfg_yaml)),
)

In [None]:
# Run every kernel spec in cfg_baseline through the training-kernel benchmark.
# The result is used as a pandas DataFrame below (.filter / .T).
res_df = run_benchmarks(
    cfg_baseline,
    benchmark_creator=create_training_kernel_benchmark,
    run_garbage_collection=False,
)

In [None]:
# Full results table returned by run_benchmarks (layout: TODO confirm — with an
# empty vary_params grid this is presumably a single row).
res_df

In [None]:
# Result columns whose label contains "R--" or "P--" (DataFrame.filter matches
# with re.search, so this is not anchored to the start of the label), transposed
# so each selected column becomes a row.
# NOTE(review): presumably R = runtime and P = parameter columns — confirm.
res_df.filter(regex="(R|P)--.*", axis="columns").transpose()

In [None]:
# Result columns whose label contains "P--" or "M--" (unanchored re.search
# match), transposed so each selected column becomes a row.
# NOTE(review): presumably M = peak-memory columns — confirm.
res_df.filter(regex="(P|M)--.*", axis="columns").transpose()