In [1]:
%load_ext autoreload
%autoreload 2
import os
import pickle
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from mlstm_kernels.utils.benchmark.plot_results import (
    plot_benchmark_result_table,
    create_runtime_line_plot,
)

### Plot results for the head-dimension kernel benchmark

In [2]:
from typing import Literal


def get_result_df(
    fwbw: bool,
    data: Literal["runtime", "memory"],
    benchmark_folder: str | Path,
    add_batch_size_col: bool = False,
    remove_col_name_prefixes: bool = True,
    x_axis_param: str = "head_dim_v",
) -> pd.DataFrame:
    benchmark_folder = Path(benchmark_folder)
    benchmark_name = "constant_tokens_sequence_"
    fwbw_folder_path = None
    fw_folder_path = None
    for dir_item in benchmark_folder.iterdir():
        if dir_item.is_dir():
            if "fwbw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fwbw_folder_path is None
                fwbw_folder_path = dir_item
            elif "fw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fw_folder_path is None
                fw_folder_path = dir_item

    folder_path = fwbw_folder_path if fwbw else fw_folder_path

    result_df = pd.read_csv(folder_path / "results.csv")

    data_prefix = "R--" if data == "runtime" else "M--"
    bs_col = "P--batch_size|" if add_batch_size_col else ""
    result_df = result_df.filter(regex=f"P--{x_axis_param}|{bs_col}{data_prefix}.*")
    if remove_col_name_prefixes:
        result_df = result_df.rename(columns=lambda x: x[3:])

    return result_df

In [3]:
# Root directory of the kernel-benchmark outputs. Defaults to the original
# author's machine; set MLSTM_KERNEL_BENCHMARK_OUTPUT_DIR to reproduce elsewhere.
BENCHMARK_OUTPUT_DIR = Path(
    os.environ.get(
        "MLSTM_KERNEL_BENCHMARK_OUTPUT_DIR",
        "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks",
    )
)
# Kept as str (as before) so any downstream string handling keeps working.
MLSTM_FOLDER = str(BENCHMARK_OUTPUT_DIR / "2025-01-21_07-59-23__headdim_mlstm_triton_v0_1")
FLA_FOLDER = str(BENCHMARK_OUTPUT_DIR / "2025-01-21_10-10-59__headdim_fla_v0_1")
NUM_HEADS = 16
ALL_RESULT_FOLDERS = [MLSTM_FOLDER, FLA_FOLDER]

### Plot raw results

In [4]:
# Switches the raw-result cells below between forward+backward (True) and forward-only (False).
plot_fwbw: bool = True

In [None]:
# Raw mLSTM runtime table for the pass selected by plot_fwbw.
get_result_df(
    benchmark_folder=MLSTM_FOLDER,
    data="runtime",
    fwbw=plot_fwbw,
)

In [None]:
# Runtime vs. head_dim_v for the mLSTM kernels only (raw column names).
fig = create_runtime_line_plot(
    data_df=get_result_df(benchmark_folder=MLSTM_FOLDER, data="runtime", fwbw=plot_fwbw),
    x_label="head_dim_v",
    group_col_names=["head_dim_v"],
)

In [None]:
# Runtime vs. head_dim_v for the FLA kernels only (raw column names).
fig = create_runtime_line_plot(
    data_df=get_result_df(benchmark_folder=FLA_FOLDER, data="runtime", fwbw=plot_fwbw),
    x_label="head_dim_v",
    group_col_names=["head_dim_v"],
)

### Combine the mLSTM and FLA result DataFrames

In [8]:
# Runtime tables of both benchmark runs, loaded FLA first, then mLSTM.
fla_df, mlstm_df = (
    get_result_df(fwbw=plot_fwbw, data="runtime", benchmark_folder=folder)
    for folder in (FLA_FOLDER, MLSTM_FOLDER)
)

In [None]:
# Inspect the FLA runtime table (head_dim_v plus runtime columns).
fla_df

In [None]:
# Inspect the mLSTM runtime table (head_dim_v plus runtime columns).
mlstm_df

In [None]:
# Exploratory side-by-side view of both runtime tables aligned on head_dim_v.
# head_dim_v is moved into the index for *both* frames; previously only the
# mLSTM frame dropped it, leaving a stray head_dim_v column from FLA.
pd.concat(
    [
        mlstm_df.set_index("head_dim_v"),
        fla_df.set_index("head_dim_v"),
    ],
    axis=1,
)

In [12]:
def combine_dfs_on_common_column(
    fwbw: bool,
    data: Literal["runtime", "memory"],
    benchmark_folders: list[str | Path],
    column: str,
) -> pd.DataFrame:
    """Load the results of several benchmark folders and join them on ``column``.

    Each folder is read with ``get_result_df``; ``column`` is moved into the
    index, and the tables are concatenated side by side (outer join on the
    index values).

    Args:
        fwbw: Use forward+backward results if True, forward-only otherwise.
        data: ``"runtime"`` or ``"memory"`` columns.
        benchmark_folders: Benchmark output folders to combine.
        column: Swept parameter column (without ``P--`` prefix) to join on.

    Returns:
        The combined table, with ``column`` restored as a regular column.
        Columns that share a name across folders are all kept (no de-duplication).
    """

    def make_column_to_index(df: pd.DataFrame) -> pd.DataFrame:
        # Moves `column` from the columns into the index.
        return df.set_index(column)

    combined_df = pd.concat(
        [
            make_column_to_index(
                get_result_df(
                    fwbw=fwbw,
                    data=data,
                    benchmark_folder=folder,
                    # Without this, get_result_df keeps only head_dim_v and
                    # set_index fails for any other join column.
                    x_axis_param=column,
                )
            )
            for folder in benchmark_folders
        ],
        axis=1,
    )
    return combined_df.reset_index()

In [13]:
def _combine_head_dim_results(fwbw: bool, data: str) -> pd.DataFrame:
    """Join the results of every folder in ALL_RESULT_FOLDERS on ``head_dim_v``."""
    return combine_dfs_on_common_column(
        fwbw=fwbw,
        data=data,
        benchmark_folders=ALL_RESULT_FOLDERS,
        column="head_dim_v",
    )


all_runtime_results_fwbw_df = _combine_head_dim_results(fwbw=True, data="runtime")
all_runtime_results_fw_df = _combine_head_dim_results(fwbw=False, data="runtime")
all_memory_results_fwbw_df = _combine_head_dim_results(fwbw=True, data="memory")
all_memory_results_fw_df = _combine_head_dim_results(fwbw=False, data="memory")

In [14]:
# Lookup structure: all_results_dict[<"runtime"|"memory">][<"fw"|"fwbw">] -> combined DataFrame.
all_results_dict = dict(
    runtime=dict(fw=all_runtime_results_fw_df, fwbw=all_runtime_results_fwbw_df),
    memory=dict(fw=all_memory_results_fw_df, fwbw=all_memory_results_fwbw_df),
)

In [15]:
# Name of the pickle file that caches all_results_dict.
result_filename: str = "mlstm_tfla_paper_head_dim_benchmark_results.p"

In [16]:
# Cache the combined results in the working directory. The trailing ";" keeps
# IPython from echoing the byte count returned by write_bytes.
(Path(".") / result_filename).write_bytes(pickle.dumps(all_results_dict));

## Make paper plots

In [17]:
# Redefined so the paper-plot section can run on a fresh kernel from the cached pickle.
result_filename: str = "mlstm_tfla_paper_head_dim_benchmark_results.p"

In [18]:
# Reload the cached results. Unpickling can execute arbitrary code, so only
# load a file this notebook produced itself.
all_results_dict = pickle.loads((Path(".") / result_filename).read_bytes())

In [19]:
# The paper figure below uses only the runtime results (forward, then forward+backward).
all_runtime_results_fw_df, all_runtime_results_fwbw_df = (
    all_results_dict["runtime"][pass_name] for pass_name in ("fw", "fwbw")
)

In [None]:
# Transposed forward-only runtime table (one row per result column), rendered via pandas Styler.
all_runtime_results_fw_df.T.style

In [None]:
# Transposed forward+backward runtime table (one row per result column), rendered via pandas Styler.
all_runtime_results_fwbw_df.T.style

In [None]:
# Column names that the plot-column mapping has to translate.
all_runtime_results_fwbw_df.columns

In [23]:
from plot_config import (
    col_order_headdim,
    map_headdim_fwbw_data_col_to_plot_col_mapping,
    get_tb_plot_mpl_context,
    legend_order,
    GRIDSPEC_KWARGS,
    style_dict_headdim,
    savefig,
)

# Mamba result-variant suffix; the alternative value noted by the author is "_noconv".
# NOTE(review): not referenced by any cell shown here — confirm it is still needed.
MAMBA_VERSION: str = ""

In [24]:
# mLSTM chunk sizes forwarded to the plot_config column-mapping helper.
_mlstm_chunk_size_kwargs = dict(mlstm_exp_chunk_size=128, mlstm_sig_chunk_size=256)

# Rename the raw result columns to their plot names (fwbw table first, then fw).
headdim_fwbw_df, headdim_fw_df = (
    results_df.rename(
        columns=map_headdim_fwbw_data_col_to_plot_col_mapping(
            fwbw=is_fwbw, **_mlstm_chunk_size_kwargs
        )
    )
    for results_df, is_fwbw in (
        (all_runtime_results_fwbw_df, True),
        (all_runtime_results_fw_df, False),
    )
)

In [None]:
# Forward+backward runtime table with plot column names.
headdim_fwbw_df

In [None]:
# Forward-only runtime table with plot column names.
headdim_fw_df

In [None]:
# Paper figure: forward-only runtime (left panel) and forward+backward runtime
# (right panel) against head dimension, with one legend shared by both panels.
with get_tb_plot_mpl_context(fontsize_delta=-0.5):
    fig, (ax_left, ax_right) = plt.subplots(
        1,
        2,
        figsize=(12.5, 4),
        gridspec_kw=GRIDSPEC_KWARGS,
        sharex=True,
    )

    # `fig` is intentionally not rebound to the helper's return value. It must
    # stay the figure created above, which owns both axes and is the one saved later.
    create_runtime_line_plot(
        ax=ax_left,
        data_df=headdim_fw_df,
        style_dict=style_dict_headdim,
        group_col_names=["head_dim_v"],
        plot_column_order=col_order_headdim,
        ylim=[0, 20],
        legend_args=None,
        legend_order=legend_order,
        yticks=[0, 2.5, 5, 7.5, 10, 12.5, 15.0, 17.5],
        x_label="Head Dimension",
    )
    create_runtime_line_plot(
        ax=ax_right,
        data_df=headdim_fwbw_df,
        style_dict=style_dict_headdim,
        group_col_names=["head_dim_v"],
        plot_column_order=col_order_headdim,
        ylim=[0, 91],
        yticks=[0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        legend_args=None,
        x_label="Head Dimension",
    )

    # legend_args=None is passed to both panels (presumably suppressing per-axes
    # legends). Build a single figure-level legend from the left panel's entries instead.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = {
        "loc": "lower center",
        "ncol": 3,
        "bbox_to_anchor": (0.0, 0.87, 1.0, 0.102),
        "frameon": False,
        "facecolor": "white",
    }
    fig.legend(handles, labels, **legend_kwargs)

In [28]:
# Export the paper figure using plot_config.savefig.
savefig(filename="tfla_mlstm_kernel_benchmark-headdim--paper", fig=fig)