In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pickle
from mlstm_kernels.utils.benchmark.plot_results import (
    plot_benchmark_result_table,
    create_runtime_line_plot,
)

### Plot Results for kernel benchmark with 7B model size

In [31]:
from typing import Literal


def get_result_df(
    fwbw: bool,
    data: Literal["runtime", "memory"],
    benchmark_folder: str | Path,
    add_batch_size_col: bool = False,
    remove_col_name_prefixes: bool = True,
) -> pd.DataFrame:
    benchmark_folder = Path(benchmark_folder)
    benchmark_name = "constant_tokens_sequence_"
    fwbw_folder_path = None
    fw_folder_path = None
    for dir_item in benchmark_folder.iterdir():
        if dir_item.is_dir():
            if "fwbw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fwbw_folder_path is None
                fwbw_folder_path = dir_item
            elif "fw" in dir_item.stem.split(benchmark_name)[-1]:
                assert fw_folder_path is None
                fw_folder_path = dir_item

    folder_path = fwbw_folder_path if fwbw else fw_folder_path

    result_df = pd.read_csv(folder_path / "results.csv")

    data_prefix = "R--" if data == "runtime" else "M--"
    bs_col = "P--batch_size|" if add_batch_size_col else ""
    result_df = result_df.filter(regex=f"P--sequence_length|{bs_col}{data_prefix}.*")
    if remove_col_name_prefixes:
        result_df = result_df.rename(columns=lambda x: x[3:])

    return result_df

In [32]:
# Output folders of the individual kernel benchmark runs, in the order
# mLSTM, FlashAttention, FLA, Mamba.
ALL_RESULT_FOLDERS = [
    "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks/2025-01-20_15-05-42__consttok_mlstm_triton_v2",
    "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks/2025-01-20_15-12-37__consttok_flashattn_v2",
    "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks/2025-01-20_15-31-14__consttok_fla_v2",
    "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks/2025-01-20_15-52-47__consttok_mamba_v2_1",
]
# Alternative Mamba output folder, kept for reference (not used):
# "/home/beck/wdir/cleaned_repos/mlstm_kernels_internal-speedbench/outputs_kernel_benchmarks/2025-01-20_15-31-04__consttok_mamba_v2"

# Named handles for the individual runs.
MLSTM_FOLDER, FLASHATTN_FOLDER, FLA_FOLDER, MAMBA_FOLDER = ALL_RESULT_FOLDERS

### Plot raw results

In [33]:
# Passed as `fwbw=` to get_result_df in the raw-result cells:
# True -> forward+backward ("fwbw") results, False -> forward-only ("fw").
plot_fwbw = True

In [None]:
# Preview the runtime table of the FlashAttention run, keeping the batch-size column.
get_result_df(
    fwbw=plot_fwbw,
    data="runtime",
    benchmark_folder=FLASHATTN_FOLDER,
    add_batch_size_col=True,
)

In [None]:
# Raw runtime plot for the FlashAttention run (fw or fwbw according to `plot_fwbw`),
# with sequence_length passed as the group column.
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=plot_fwbw,
        data="runtime",
        benchmark_folder=FLASHATTN_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

In [None]:
# Raw runtime plot for the mLSTM run (fw or fwbw according to `plot_fwbw`),
# with sequence_length passed as the group column.
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=plot_fwbw,
        data="runtime",
        benchmark_folder=MLSTM_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

In [None]:
# Raw runtime plot for the mLSTM run, forward-only results (fwbw=False is
# hard-coded here, independent of `plot_fwbw`).
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=False,
        data="runtime",
        benchmark_folder=MLSTM_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

In [None]:
# Raw runtime plot for the FLA run (fw or fwbw according to `plot_fwbw`),
# with sequence_length passed as the group column.
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=plot_fwbw,
        data="runtime",
        benchmark_folder=FLA_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

In [None]:
# Raw runtime plot for the FLA run, forward-only results (fwbw=False is
# hard-coded here, independent of `plot_fwbw`).
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=False,
        data="runtime",
        benchmark_folder=FLA_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

In [None]:
# Raw runtime plot for the Mamba run (fw or fwbw according to `plot_fwbw`),
# with sequence_length passed as the group column.
fig = create_runtime_line_plot(
    data_df=get_result_df(
        fwbw=plot_fwbw,
        data="runtime",
        benchmark_folder=MAMBA_FOLDER,
    ),
    group_col_names=["sequence_length"],
)

### Combine result DataFrames

In [41]:
def combine_dfs(
    fwbw: bool, data: Literal["runtime", "memory"], benchmark_folders: list[str | Path]
):
    """Place the result tables of several benchmark folders side by side.

    Each folder is loaded with ``get_result_df`` (default column options) and
    the resulting frames are concatenated column-wise. When a column name
    occurs more than once, only its first occurrence is kept.

    Args:
        fwbw: Forwarded to ``get_result_df``.
        data: Forwarded to ``get_result_df``.
        benchmark_folders: Benchmark folders to load, in column order.

    Returns:
        The combined DataFrame with unique column names.
    """
    per_folder_frames = []
    for folder in benchmark_folders:
        per_folder_frames.append(
            get_result_df(fwbw=fwbw, data=data, benchmark_folder=folder)
        )

    # axis=1 aligns rows on the index; the frames are not joined on any column.
    side_by_side = pd.concat(per_folder_frames, axis=1)

    # Keep only the first column for every repeated column name.
    first_occurrence = ~side_by_side.columns.duplicated()
    return side_by_side.loc[:, first_occurrence]

In [42]:
# Build one combined table per (data kind, pass) pairing over all result folders.
all_runtime_results_fwbw_df = combine_dfs(
    fwbw=True, data="runtime", benchmark_folders=ALL_RESULT_FOLDERS
)
all_runtime_results_fw_df = combine_dfs(
    fwbw=False, data="runtime", benchmark_folders=ALL_RESULT_FOLDERS
)
all_memory_results_fwbw_df = combine_dfs(
    fwbw=True, data="memory", benchmark_folders=ALL_RESULT_FOLDERS
)
all_memory_results_fw_df = combine_dfs(
    fwbw=False, data="memory", benchmark_folders=ALL_RESULT_FOLDERS
)

In [43]:
# Nested lookup: data kind ("runtime" / "memory") -> pass ("fw" / "fwbw") -> combined table.
all_results_dict = {
    "runtime": {"fw": all_runtime_results_fw_df, "fwbw": all_runtime_results_fwbw_df},
    "memory": {"fw": all_memory_results_fw_df, "fwbw": all_memory_results_fwbw_df},
}

In [44]:
# File name of the pickled results; it is resolved relative to the current working directory.
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results.p"

In [45]:
# Persist the combined results as a pickle in the current working directory.
with open(Path(".") / result_filename, "wb") as f:
    pickle.dump(all_results_dict, f)

## Make paper plots

In [46]:
# NOTE(review): same value as the earlier `result_filename` assignment; presumably
# repeated so this section can start from the pickle without re-running the
# result-combining cells — confirm.
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results.p"

In [47]:
# Reload the results dict from the pickle in the current working directory.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
with open(Path(".") / result_filename, "rb") as f:
    all_results_dict = pickle.load(f)

In [48]:
# Pull the two runtime tables out of the loaded dict.
all_runtime_results_fw_df = all_results_dict["runtime"]["fw"]
all_runtime_results_fwbw_df = all_results_dict["runtime"]["fwbw"]

In [None]:
# Display the forward-only runtime table transposed (original columns become rows) via the pandas Styler.
all_runtime_results_fw_df.T.style

In [None]:
# Display the forward+backward runtime table transposed (original columns become rows) via the pandas Styler.
all_runtime_results_fwbw_df.T.style

In [None]:
# List the available column names of the forward+backward runtime table.
all_runtime_results_fwbw_df.columns

In [52]:
from plot_config import (
    col_order_consttoken,
    map_consttoken_fwbw_data_col_to_plot_col_mapping,
    get_tb_plot_mpl_context,
    legend_order,
    GRIDSPEC_KWARGS,
    style_dict,
    savefig,
)

In [53]:
# Select sequence_length plus the columns listed as keys of the fwbw mapping ...
consttoken_fwbw_raw_df = all_runtime_results_fwbw_df[
    ["sequence_length"]
    + list(map_consttoken_fwbw_data_col_to_plot_col_mapping(fwbw=True).keys())
]
# ... and rename those columns to the mapping's values.
consttoken_fwbw_df = consttoken_fwbw_raw_df.rename(
    columns=map_consttoken_fwbw_data_col_to_plot_col_mapping(fwbw=True)
)

In [54]:
# Select sequence_length plus the columns listed as keys of the fw mapping ...
consttoken_fw_raw_df = all_runtime_results_fw_df[
    ["sequence_length"]
    + list(map_consttoken_fwbw_data_col_to_plot_col_mapping(fwbw=False).keys())
]
# ... and rename those columns to the mapping's values.
consttoken_fw_df = consttoken_fw_raw_df.rename(
    columns=map_consttoken_fwbw_data_col_to_plot_col_mapping(fwbw=False)
)

In [None]:
# Display the renamed forward+backward table used for the paper plot.
consttoken_fwbw_df

In [None]:
# Display the renamed forward-only table used for the paper plot.
consttoken_fw_df

In [None]:
# Two-panel paper figure: forward-only runtimes on the left axis,
# forward+backward runtimes on the right axis, with one legend for the figure.
#
# Figure height in inches. This used to be set to 4.5 but was never used,
# while the figure was created with a hard-coded height of 3.5; it now holds
# the height that is actually rendered, so the output is unchanged.
fig_height = 3.5
with get_tb_plot_mpl_context(fontsize_delta=1):
    fig, (ax_left, ax_right) = plt.subplots(
        1,
        2,
        figsize=(16, fig_height),
        gridspec_kw=GRIDSPEC_KWARGS,
        sharex=True,
    )

    # Left panel: forward pass only.
    fig = create_runtime_line_plot(
        ax=ax_left,
        data_df=consttoken_fw_df,
        style_dict=style_dict,
        group_col_names=["sequence_length"],
        plot_column_order=col_order_consttoken,
        ylim=[0, 15.5],
        legend_args=None,
        legend_order=legend_order,
        yticks=[0, 2.5, 5, 7.5, 10, 12.5, 15.0],
    )
    # Right panel: forward + backward pass.
    fig = create_runtime_line_plot(
        ax=ax_right,
        data_df=consttoken_fwbw_df,
        style_dict=style_dict,
        group_col_names=["sequence_length"],
        plot_column_order=col_order_consttoken,
        ylim=[0, 68],
        legend_args=None,
    )
    # The legend entries are taken from the left axis only and attached to
    # the figure object returned by the last create_runtime_line_plot call.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = {
        "loc": "lower center",
        "ncol": 5,
        "bbox_to_anchor": (0.0, 0.87, 1.0, 0.102),
        "frameon": False,
        "facecolor": "white",
    }
    fig.legend(handles, labels, **legend_kwargs)

In [58]:
# Save the paper figure via plot_config.savefig under the given base file name
# (output location and file format are decided inside savefig, which is not visible here).
savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark--paper")