In [1]:
%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pickle
from mlstm_kernels.utils.benchmark.plot_results import (
    plot_benchmark_result_table,
    create_runtime_line_plot,
)

### Plot Results for kernel benchmark with 7B model size

In [2]:
from typing import Literal


def get_result_df(
    fwbw: bool,
    data: Literal["runtime", "memory"],
    benchmark_folder: str | Path,
    add_batch_size_col: bool = False,
    remove_col_name_prefixes: bool = True,
) -> pd.DataFrame:
    benchmark_folder = Path(benchmark_folder)
    benchmark_name = "constant_tokens_sequence_"
    fwbw_folder_path = None
    fw_folder_path = None
    for dir_item in benchmark_folder.iterdir():
        if dir_item.is_dir():
            if "fwbw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fwbw_folder_path is None
                fwbw_folder_path = dir_item
            elif "fw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fw_folder_path is None
                fw_folder_path = dir_item

    folder_path = fwbw_folder_path if fwbw else fw_folder_path

    result_df = pd.read_csv(folder_path / "results.csv")

    data_prefix = "R--" if data == "runtime" else "M--"
    bs_col = "P--batch_size|" if add_batch_size_col else ""
    result_df = result_df.filter(regex=f"P--sequence_length|{bs_col}{data_prefix}.*")
    if remove_col_name_prefixes:
        result_df = result_df.rename(columns=lambda x: x[3:])

    return result_df

In [3]:
# Every benchmark run lives in the same output directory; only the run folder differs.
BENCHMARK_OUTPUT_ROOT = "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks"

# Run folders of earlier benchmarks, kept for reference only (not used below):
#   FLASHATTN_FOLDER -> 2025-01-20_15-12-37__consttok_flashattn_v2
#   MLSTM_FOLDER     -> 2025-01-20_15-05-42__consttok_mlstm_triton_v2
#   FLA_FOLDER       -> 2025-01-20_15-31-14__consttok_fla_v2
#   MAMBA_FOLDER     -> 2025-01-20_15-52-47__consttok_mamba_v2_1
# RERUN:
#   MAMBA_FOLDER     -> 2025-01-22_23-42-35__consttok_mamba_rerun_v0
#   FLA_FOLDER       -> 2025-01-23_01-04-04__consttok_fla_rerun_v0-nh16

# Runs that are combined and plotted in this notebook.
LIGHTNATTN_NH32_FOLDER = f"{BENCHMARK_OUTPUT_ROOT}/2025-01-24_19-17-47__consttok_lightning_attn2_lightnattn_v0"
LIGHTNATTN_NH64_FOLDER = f"{BENCHMARK_OUTPUT_ROOT}/2025-01-24_19-19-58__consttok_lightning_attn2_lightnattn_v0"
MLSTM_FOLDER_NH32 = f"{BENCHMARK_OUTPUT_ROOT}/2025-01-24_21-35-50__consttok_mlstm_triton_lightnattn_v1"
MLSTM_FOLDER_NH64 = f"{BENCHMARK_OUTPUT_ROOT}/2025-01-24_22-16-44__consttok_mlstm_triton_lightnattn_v1"

NUM_HEADS = 32  # for our mlstm kernels
# MLSTM_FOLDER = MLSTM_FOLDER_NH32 if NUM_HEADS == 32 else MLSTM_FOLDER_NH64

# Order matters downstream: on duplicate column names the first folder's column wins.
ALL_RESULT_FOLDERS = [
    LIGHTNATTN_NH32_FOLDER,
    LIGHTNATTN_NH64_FOLDER,
    MLSTM_FOLDER_NH32,
    MLSTM_FOLDER_NH64,
]

### Plot raw results

In [4]:
# Switch for the raw-result cells below: True reads the "fwbw" result folder,
# False the "fw" one (passed as `fwbw=` to get_result_df).
plot_fwbw = True

In [None]:
# Raw runtime table of the lightning-attention NH64 run, including the batch size column.
raw_runtime_with_bs_df = get_result_df(
    benchmark_folder=LIGHTNATTN_NH64_FOLDER,
    data="runtime",
    fwbw=plot_fwbw,
    add_batch_size_col=True,
)
raw_runtime_with_bs_df

In [None]:
# Quick-look runtime plot for the lightning-attention NH32 run.
lightnattn_nh32_runtime_df = get_result_df(
    benchmark_folder=LIGHTNATTN_NH32_FOLDER,
    data="runtime",
    fwbw=plot_fwbw,
)
fig = create_runtime_line_plot(
    data_df=lightnattn_nh32_runtime_df,
    group_col_names=["sequence_length"],
)

In [None]:
# Quick-look runtime plot for the lightning-attention NH64 run.
lightnattn_nh64_runtime_df = get_result_df(
    benchmark_folder=LIGHTNATTN_NH64_FOLDER,
    data="runtime",
    fwbw=plot_fwbw,
)
fig = create_runtime_line_plot(
    data_df=lightnattn_nh64_runtime_df,
    group_col_names=["sequence_length"],
)

In [None]:
# Quick-look runtime plot for the mLSTM NH64 run.
# NOTE(review): this cell pins fwbw=False instead of using `plot_fwbw` like the
# cells above — confirm that is intended.
mlstm_nh64_fw_runtime_df = get_result_df(
    benchmark_folder=MLSTM_FOLDER_NH64,
    data="runtime",
    fwbw=False,
)
fig = create_runtime_line_plot(
    data_df=mlstm_nh64_fw_runtime_df,
    group_col_names=["sequence_length"],
)

### Combine result DataFrames

In [9]:
def combine_dfs(
    fwbw: bool, data: Literal["runtime", "memory"], benchmark_folders: list[str | Path]
) -> pd.DataFrame:
    """Merge the result tables of several benchmark folders into one wide table.

    Each folder is loaded with ``get_result_df`` and the tables are joined
    side by side on their ``sequence_length`` value. Joining on the value
    rather than on the row position keeps measurements paired with the right
    sequence length even if the folders list their rows in a different order
    or cover different sequence lengths (missing entries become NaN).

    Args:
        fwbw: Forwarded to ``get_result_df``.
        data: Forwarded to ``get_result_df``.
        benchmark_folders: Folders to combine. When two folders provide a
            column with the same name, the column of the earlier folder is kept.

    Returns:
        One table with ``sequence_length`` as first column followed by the
        de-duplicated result columns of all folders.
    """
    key_col = "sequence_length"
    per_folder_dfs = [
        get_result_df(fwbw=fwbw, data=data, benchmark_folder=folder).set_index(key_col)
        for folder in benchmark_folders
    ]
    # axis=1 concat aligns on the index, which is now the sequence length.
    combined_df = pd.concat(per_folder_dfs, axis=1)
    # remove duplicate columns by name (first occurrence wins)
    combined_df = combined_df.loc[:, ~combined_df.columns.duplicated()]
    return combined_df.reset_index()

In [10]:
# Build the four combined tables in one sweep; the unpacking order below matches
# the iteration order (runtime before memory, fwbw before fw).
(
    all_runtime_results_fwbw_df,
    all_runtime_results_fw_df,
    all_memory_results_fwbw_df,
    all_memory_results_fw_df,
) = (
    combine_dfs(fwbw=use_fwbw, data=data_kind, benchmark_folders=ALL_RESULT_FOLDERS)
    for data_kind in ("runtime", "memory")
    for use_fwbw in (True, False)
)

In [11]:
# Nested lookup: data kind ("runtime" / "memory") -> run kind ("fw" / "fwbw") -> table.
all_results_dict = dict(
    runtime=dict(fw=all_runtime_results_fw_df, fwbw=all_runtime_results_fwbw_df),
    memory=dict(fw=all_memory_results_fw_df, fwbw=all_memory_results_fwbw_df),
)

In [12]:
# Name of the pickle file that stores the combined result tables.
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results_lightn_attn.p"

In [13]:
# Write the combined result tables to a pickle in the current working directory.
results_cache_path = Path(".") / result_filename
with results_cache_path.open("wb") as cache_file:
    pickle.dump(all_results_dict, cache_file)

## Make paper plots

In [14]:
# Name of the pickled results that the plotting section loads (same file name
# as the one used when the results were written).
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results_lightn_attn.p"

In [15]:
# Read the combined result tables back from the current working directory.
# NOTE: unpickling can execute arbitrary code — only load files written by this notebook.
results_cache_path = Path(".") / result_filename
with results_cache_path.open("rb") as cache_file:
    all_results_dict = pickle.load(cache_file)

In [16]:
# Pull the tables used for plotting out of the loaded dict.
runtime_results, memory_results = (
    all_results_dict["runtime"],
    all_results_dict["memory"],
)
all_runtime_results_fw_df = runtime_results["fw"]
all_runtime_results_fwbw_df = runtime_results["fwbw"]
all_memory_results_fwbw_df = memory_results["fwbw"]

In [None]:
# Show the "fw" runtime table transposed (one row per result column) via a pandas Styler.
all_runtime_results_fw_df.T.style

In [None]:
# Show the "fwbw" runtime table transposed (one row per result column) via a pandas Styler.
all_runtime_results_fwbw_df.T.style

In [None]:
# Show the "fwbw" memory table transposed (one row per result column) via a pandas Styler.
all_memory_results_fwbw_df.T.style

In [None]:
# List the column names available in the "fwbw" runtime table.
all_runtime_results_fwbw_df.columns

In [21]:
from plot_config import (
    get_col_order_lightnattn,
    map_consttoken_fwbw_lightnattn_data_col_to_plot_col_mapping,
    get_tb_plot_mpl_context,
    legend_order,
    GRIDSPEC_KWARGS,
    get_style_dict_lightnattn,
    savefig,
)

# Head counts passed to the column-mapping helper.
num_heads = [32, 64]
# Chunk sizes in descending order, as expected by the style / column-order helpers below.
chunk_sizes = [4096, 2048, 1024, 512, 256, 128]

# Extra column that is added to the last figure and gets its own fixed color.
additional_col = "mlstmsig_triton_xl_chunk--nh-64-cs-256"
override_color_mapping = dict([(additional_col, "#9a3c73")])

In [None]:
# Display the combined "fwbw" runtime table.
all_runtime_results_fwbw_df

In [23]:
# Build the data-column -> plot-column mapping once, so the column selection and
# the rename below always work on exactly the same set of keys.
consttoken_fwbw_col_mapping = (
    map_consttoken_fwbw_lightnattn_data_col_to_plot_col_mapping(
        fwbw=True, num_heads=num_heads, half_qkdim=False
    )
)
# Keep the sequence length plus every column the mapping knows about ...
consttoken_fwbw_raw_df = all_runtime_results_fwbw_df[
    ["sequence_length", *consttoken_fwbw_col_mapping.keys()]
]
# ... and give those columns their plot names.
consttoken_fwbw_df = consttoken_fwbw_raw_df.rename(
    columns=consttoken_fwbw_col_mapping
)

In [24]:
# Build the data-column -> plot-column mapping once, so the column selection and
# the rename below always work on exactly the same set of keys.
consttoken_fw_col_mapping = (
    map_consttoken_fwbw_lightnattn_data_col_to_plot_col_mapping(
        fwbw=False, num_heads=num_heads, half_qkdim=False
    )
)
# Keep the sequence length plus every column the mapping knows about ...
consttoken_fw_raw_df = all_runtime_results_fw_df[
    ["sequence_length", *consttoken_fw_col_mapping.keys()]
]
# ... and give those columns their plot names.
consttoken_fw_df = consttoken_fw_raw_df.rename(columns=consttoken_fw_col_mapping)

In [25]:
# Build the data-column -> plot-column mapping once, so the column selection and
# the rename below always work on exactly the same set of keys.
consttoken_memory_fwbw_col_mapping = (
    map_consttoken_fwbw_lightnattn_data_col_to_plot_col_mapping(
        fwbw=True, num_heads=num_heads, half_qkdim=False
    )
)
# Keep the sequence length plus every column the mapping knows about ...
consttoken_memory_fwbw_raw_df = all_memory_results_fwbw_df[
    ["sequence_length", *consttoken_memory_fwbw_col_mapping.keys()]
]
# ... and give those columns their plot names.
consttoken_memory_fwbw_df = consttoken_memory_fwbw_raw_df.rename(
    columns=consttoken_memory_fwbw_col_mapping
)

In [None]:
# Display the "fwbw" memory table with plot column names.
consttoken_memory_fwbw_df

In [None]:
# Display the "fwbw" runtime table with plot column names.
consttoken_fwbw_df

In [None]:
# Display the "fw" runtime table with plot column names.
consttoken_fw_df

In [None]:
# Columns that end up in the paper figures.
filter_regex = "sequence_length|mlstmsig.*|mlstmexp_triton_limit_chunk|lightnattn_nh32|lightnattn_nh64"
consttoken_runtime_fwbw_df = consttoken_fwbw_df.filter(regex=filter_regex)
consttoken_runtime_fw_df = consttoken_fw_df.filter(regex=filter_regex)

# Rebuild the memory table from the raw frame on every run instead of overwriting
# `consttoken_memory_fwbw_df` from itself: otherwise re-running this cell would
# divide the already converted values by 1e9 a second time.
memory_plot_cols_df = consttoken_memory_fwbw_raw_df.rename(
    columns=map_consttoken_fwbw_lightnattn_data_col_to_plot_col_mapping(
        fwbw=True, num_heads=num_heads, half_qkdim=False
    )
).filter(regex=filter_regex)
is_value_col = memory_plot_cols_df.columns != "sequence_length"
consttoken_memory_fwbw_df = pd.concat(
    [
        memory_plot_cols_df["sequence_length"],
        # -1 entries become NaN so they are not plotted as values (presumably they
        # mark failed measurements — TODO confirm); dividing by 1e9 matches the
        # "GPU Memory [GB]" axis label used below, assuming raw values are bytes.
        memory_plot_cols_df.replace(-1, float("nan")).loc[:, is_value_col] / 1e9,
    ],
    axis=1,
)
consttoken_runtime_fwbw_df

In [None]:
# Two-panel runtime figure: "fw" runtimes on the left, "fwbw" runtimes on the right.
with get_tb_plot_mpl_context(fontsize_delta=1):
    fig, (ax_left, ax_right) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(12.5, 4),
        gridspec_kw=GRIDSPEC_KWARGS,
        sharex=True,
    )

    # (axis, data, panel-specific arguments); everything else is shared by both panels.
    panel_specs = [
        (
            ax_left,
            consttoken_runtime_fw_df,
            {
                "ylim": [0, 15.5],
                "legend_order": legend_order,
                "yticks": [0, 2.5, 5, 7.5, 10, 12.5, 15.0],
            },
        ),
        (ax_right, consttoken_runtime_fwbw_df, {"ylim": [0, 61]}),
    ]
    for panel_ax, panel_df, panel_overrides in panel_specs:
        # Style dict and column order are rebuilt for every panel, as before.
        fig = create_runtime_line_plot(
            ax=panel_ax,
            data_df=panel_df,
            style_dict=get_style_dict_lightnattn(
                chunk_sizes=chunk_sizes, colormap=plt.cm.copper, cmap_start_end=(0.2, 1)
            ),
            group_col_names=["sequence_length"],
            plot_column_order=get_col_order_lightnattn(chunk_sizes=chunk_sizes),
            legend_args=None,
            **panel_overrides,
        )

    # One shared legend for the whole figure, taken from the left panel.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = dict(
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
    )
    fig.legend(handles, labels, **legend_kwargs)

In [30]:
# savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark_lightnattn--paper")

In [None]:
# Three-panel figure: "fw" runtime, "fwbw" runtime and "fwbw" GPU memory.
with get_tb_plot_mpl_context(fontsize_delta=0):
    fig, (ax_left, ax_middle, ax_right) = plt.subplots(
        nrows=1,
        ncols=3,
        figsize=(18, 4),
        gridspec_kw={"wspace": 0.18, "hspace": 0.1},
        sharex=True,
    )

    # (axis, data, panel-specific arguments); only the memory panel overrides the y label.
    # legend_order / yticks are intentionally left at their defaults in this figure.
    panel_specs = [
        (ax_left, consttoken_runtime_fw_df, {}),
        (ax_middle, consttoken_runtime_fwbw_df, {}),
        (ax_right, consttoken_memory_fwbw_df, {"y_label": "GPU Memory [GB]"}),
    ]
    for panel_ax, panel_df, panel_overrides in panel_specs:
        # Style dict and column order are rebuilt for every panel, as before.
        fig = create_runtime_line_plot(
            ax=panel_ax,
            data_df=panel_df,
            style_dict=get_style_dict_lightnattn(
                chunk_sizes=chunk_sizes, colormap=plt.cm.copper, cmap_start_end=(0.2, 1)
            ),
            group_col_names=["sequence_length"],
            plot_column_order=get_col_order_lightnattn(chunk_sizes=chunk_sizes),
            legend_args=None,
            **panel_overrides,
        )

    # One shared legend for the whole figure, taken from the left panel.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = dict(
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
    )
    fig.legend(handles, labels, **legend_kwargs)

In [None]:
# Two-panel figure: "fwbw" runtime on the left, "fwbw" GPU memory on the right,
# both including the additional column with its overridden color.
with get_tb_plot_mpl_context(fontsize_delta=0):
    fig, (ax_left, ax_right) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(12.5, 4),
        gridspec_kw={"wspace": 0.18, "hspace": 0.1},
        sharex=True,
    )

    # (axis, data, panel-specific arguments).
    # legend_order / yticks are intentionally left at their defaults in this figure;
    # an earlier middle panel plotting `consttoken_fwbw_df` is no longer drawn.
    panel_specs = [
        (ax_left, consttoken_runtime_fwbw_df, {"ylim": (10, 77)}),
        (ax_right, consttoken_memory_fwbw_df, {"y_label": "GPU Memory [GB]"}),
    ]
    for panel_ax, panel_df, panel_overrides in panel_specs:
        # Style dict and column order are rebuilt for every panel, as before.
        fig = create_runtime_line_plot(
            ax=panel_ax,
            data_df=panel_df,
            style_dict=get_style_dict_lightnattn(
                chunk_sizes=chunk_sizes,
                colormap=plt.cm.copper,
                cmap_start_end=(0.2, 1),
                override_color_mapping=override_color_mapping,
            ),
            group_col_names=["sequence_length"],
            plot_column_order=get_col_order_lightnattn(
                chunk_sizes=chunk_sizes, additional_col=additional_col
            ),
            legend_args=None,
            **panel_overrides,
        )

    # One shared legend for the whole figure, taken from the left panel.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = dict(
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
        columnspacing=0.75,  # 1.0
    )
    fig.legend(handles, labels, **legend_kwargs)

In [34]:
# Save whichever figure `fig` currently refers to under the paper file name
# (output location/format are decided by `savefig` from plot_config).
savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark_lightnattn--paper")