In [1]:
# Configure root logging before the remaining imports: some of them emit INFO
# records at import time (see the numexpr messages in this cell's output), and
# they are only shown, in this format, if a handler is already installed.
import logging 

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s",
)

import torch

import sys
# Make the parent directory importable so the `mlstm_kernels` imports below resolve.
# NOTE(review): assumes the notebook is run from a direct subdirectory of the
# repository root — confirm; installing the package would make this unnecessary.
sys.path.append("..")
# Lightning attention baseline kernel and its slope-tensor helper.
from mlstm_kernels.baselines.lightning_attention.lightning_attn2 import lightning_attn2
from mlstm_kernels.baselines.lightning_attention.utils import _build_slope_tensor

# Config parsing: YAML string -> OmegaConf -> plain container -> dataclass via dacite.
from dacite import from_dict 
from omegaconf import OmegaConf

# Benchmark harness used in the benchmark section.
from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig
from mlstm_kernels.utils.benchmark.run_benchmark import run_benchmarks
from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import create_training_kernel_benchmark

[2025-01-24 18:49:53,457][numexpr.utils:146][INFO] - Note: detected 224 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-01-24 18:49:53,457][numexpr.utils:149][INFO] - Note: NumExpr detected 224 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-01-24 18:49:53,458][numexpr.utils:162][INFO] - NumExpr defaulting to 16 threads.


### Quick test: is the kernel runnable?

In [2]:
# Smoke-test problem size; (b, h, n, d) is the shape of q/k/v in the next cell,
# presumably (batch, heads, sequence length, head dim) as in the benchmark config.
b, h, n, d = 4, 8, 512, 128
dtype, device = torch.bfloat16, torch.device("cuda:0")

In [3]:
torch.manual_seed(0)
# Draw q, k, v in that order right after seeding so their values are reproducible.
# The generator keeps the draw order, and its loop variable does not leak into the
# notebook namespace (where `_` is IPython's last-output variable).
q, k, v = (
    torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
    for _ in range(3)
)
# Slope tensor for the h heads, on the same device as q and cast to float32.
s = _build_slope_tensor(h).to(device=q.device, dtype=torch.float32)

In [4]:
# Forward pass of the lightning attention baseline on the random inputs.
out = lightning_attn2(q, k, v, s)
# q/k/v were created with requires_grad_(), and the benchmark below also times
# forward+backward, so check that the backward kernel runs too.
out.sum().backward()
assert all(t.grad is not None for t in (q, k, v)), "backward did not populate q/k/v grads"

In [5]:
# q, k and v all share the shape (b, h, n, d) here, so the output should match v's shape.
assert out.shape == v.shape, f"unexpected output shape {tuple(out.shape)}"
out.shape

torch.Size([4, 8, 512, 128])

### Benchmark

Note: lightning attention does not support large head dimensions; with them, the kernel fails with an out-of-shared-memory error.

In [6]:
# Benchmark problem size. Lightning attention fails with an out-of-shared-memory
# error for large head dimensions, so both head dims are kept at 128 here.
B = 2  # batch size
S = 8192  # sequence length
NH = 32  # number of heads
DHQK, DHHV = 128, 128  # head dim of q/k and of v
D = NH * DHHV  # total value width across all heads (derived; not part of the YAML config)

In [7]:
cfg_yaml = f"""
vary_type: grid
vary_params: {dict()}
fixed_params: 
  batch_size: {B}
  sequence_length: {S}
  num_heads: {NH}
  head_dim_qk: {DHQK}
  head_dim_v: {DHHV}
  warmup: 5
  rep: 10

kernel_specs:
  # - kernel_name: "chunkwise--triton_limit_chunk"
  #   fwbw: False
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 64
  # - kernel_name: "chunkwise--triton_limit_chunk"
  #   fwbw: True
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 64
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: False
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  # - kernel_name: "chunkwise--triton_xl_chunk"
  #   fwbw: True
  #   dtype: bfloat16
  #   additional_params:
  #     chunk_size: 128
  # - kernel_name: "chunkwise--triton_xl_chunk_siging"
  #   fwbw: False
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   additional_params:
  #     chunk_size: 128
  #     normalize: False
  # - kernel_name: "chunkwise--triton_xl_chunk_siging"
  #   fwbw: True
  #   dtype: bfloat16
  #   use_torch_compile: False
  #   additional_params:
  #     chunk_size: 128
  #     normalize: False
  - kernel_name: "lightning_attn2"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
  - kernel_name: "lightning_attn2"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False

  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: False
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 256
      normalize: False
  - kernel_name: "chunkwise--triton_xl_chunk_siging"
    fwbw: True
    dtype: bfloat16
    use_torch_compile: False
    additional_params:
      chunk_size: 256
      normalize: False
    


  
benchmark_name: "quick_kernel_benchmark"
"""
cfg_baseline = from_dict(
    data_class=BenchmarkConfig, data=OmegaConf.to_container(OmegaConf.create(cfg_yaml))
)

In [8]:
# Run every kernel spec for each parameter combination. The resulting frame has
# P-- (parameter), R-- (runtime, ms) and M-- (peak memory, bytes) columns per kernel.
res_df = run_benchmarks(
    cfg_baseline,
    benchmark_creator=create_training_kernel_benchmark,
    run_garbage_collection=False,
)

[2025-01-24 18:49:55,567][mlstm_kernels.utils.benchmark.run_benchmark:42][INFO] - Parameter combination (1/1): {'batch_size': 2, 'sequence_length': 8192, 'num_heads': 32, 'head_dim_qk': 128, 'head_dim_v': 128, 'warmup': 5, 'rep': 10}
[2025-01-24 18:49:59,532][mlstm_kernels.utils.benchmark.run_benchmark:56][INFO] - ('Kernel (1/4): lightning_attn2____bfloat16__fw finished.', ' Runtime: 2.3291521072387695 ms. Peak memory: 0.809649152 GB.')
[2025-01-24 18:50:01,440][mlstm_kernels.utils.benchmark.run_benchmark:56][INFO] - ('Kernel (2/4): lightning_attn2____bfloat16__fwbw finished.', ' Runtime: 19.371231079101562 ms. Peak memory: 1.749174272 GB.')
[2025-01-24 18:50:05,695][mlstm_kernels.utils.benchmark.run_benchmark:56][INFO] - ('Kernel (3/4): chunkwise--triton_xl_chunk_siging____bfloat16__fw__cs-256_n-False finished.', ' Runtime: 1.472864031791687 ms. Peak memory: 0.955253248 GB.')
[2025-01-24 18:50:08,003][mlstm_kernels.utils.benchmark.run_benchmark:56][INFO] - ('Kernel (4/4): chunkwise--t

In [9]:
# Full results table: one row per parameter combination, with P-- (parameters),
# R-- (runtime, ms) and M-- (peak memory, bytes) columns for each kernel.
res_df

Unnamed: 0,P--batch_size,P--sequence_length,P--num_heads,P--head_dim_qk,P--head_dim_v,P--warmup,P--rep,R--lightning_attn2____bfloat16__fw,M--lightning_attn2____bfloat16__fw,R--lightning_attn2____bfloat16__fwbw,M--lightning_attn2____bfloat16__fwbw,R--chunkwise--triton_xl_chunk_siging____bfloat16__fw__cs-256_n-False,M--chunkwise--triton_xl_chunk_siging____bfloat16__fw__cs-256_n-False,R--chunkwise--triton_xl_chunk_siging____bfloat16__fwbw__cs-256_n-False,M--chunkwise--triton_xl_chunk_siging____bfloat16__fwbw__cs-256_n-False
0,2,8192,32,128,128,5,10,2.329152,809649152,19.371231,1749174272,1.472864,955253248,5.038624,2043676160


In [10]:
res_df.filter(regex="(R|P)--.*", axis=1).T

Unnamed: 0,0
P--batch_size,2.0
P--sequence_length,8192.0
P--num_heads,32.0
P--head_dim_qk,128.0
P--head_dim_v,128.0
P--warmup,5.0
P--rep,10.0
R--lightning_attn2____bfloat16__fw,2.329152
R--lightning_attn2____bfloat16__fwbw,19.371231
R--chunkwise--triton_xl_chunk_siging____bfloat16__fw__cs-256_n-False,1.472864


In [11]:
res_df.filter(regex="(P|M)--.*", axis=1).T

Unnamed: 0,0
P--batch_size,2
P--sequence_length,8192
P--num_heads,32
P--head_dim_qk,128
P--head_dim_v,128
P--warmup,5
P--rep,10
M--lightning_attn2____bfloat16__fw,809649152
M--lightning_attn2____bfloat16__fwbw,1749174272
M--chunkwise--triton_xl_chunk_siging____bfloat16__fw__cs-256_n-False,955253248
