In [None]:
%load_ext autoreload
%autoreload 2
import sys

import pandas as pd

sys.path.append("../..")

from mlstm_kernels.utils.benchmark.plot_results import (
    plot_benchmark_result_table,
    create_runtime_bar_plot,
    rc_context_wrapper,
    select_columns,
    savefig
)
from pathlib import Path
from plot_config import linestyle_mapping, style_dict
import pickle

In [None]:
def load_throughput_results_for_ctxes(
    path_template: str,
    ctxes: list[int],
    column_regex: str = ".*(batch_size|prefill|R--).*",
) -> dict[int, pd.DataFrame]:
    """Load one benchmark ``results.csv`` per context length.

    Args:
        path_template: Path containing a ``{ctx}`` placeholder; it is filled in
            with each context length via ``str.format``.
        ctxes: Context lengths to load, one CSV each.
        column_regex: Only columns whose name matches this regex are kept
            (``DataFrame.filter`` uses ``re.search`` semantics). The default
            keeps the batch-size, prefill and result (``R--``) columns.

    Returns:
        Mapping from context length to its filtered results table, in the
        order given by ``ctxes``.
    """
    return {
        ctx: pd.read_csv(Path(path_template.format(ctx=ctx))).filter(regex=column_regex)
        for ctx in ctxes
    }

In [None]:
# All three benchmark runs live under the same output directory and share the
# same per-context file layout; only the run directory name differs.
# NOTE: absolute, machine-specific path — point this at your own benchmark outputs.
BENCHMARK_OUTPUT_DIR = Path(
    "/home/beck/wdir/dev_repos/mlstm_kernels/outputs_kernel_benchmarks_final"
)
RESULT_CSV_TEMPLATE = "hf_7B_throughput__pfl{ctx}_gl0_tcTrue_weightdtypebfloat16/results.csv"
CONTEXT_LENGTHS = [2048, 4096, 8192]


def throughput_path_template(run_dir_name: str) -> str:
    """Return the ``{ctx}`` path template of the benchmark run stored in ``run_dir_name``."""
    return str(BENCHMARK_OUTPUT_DIR / run_dir_name / RESULT_CSV_TEMPLATE)


llama_results = load_throughput_results_for_ctxes(
    path_template=throughput_path_template(
        "2024-12-06_12-48-14__throughput__forward_llama_v0"
    ),
    ctxes=CONTEXT_LENGTHS,
)

mamba_results = load_throughput_results_for_ctxes(
    path_template=throughput_path_template(
        "2024-12-06_13-02-27__throughput__forward_mamba_v2"
    ),
    ctxes=CONTEXT_LENGTHS,
)

mxlstm_results = load_throughput_results_for_ctxes(
    path_template=throughput_path_template(
        "2024-12-06_12-46-45__throughput__forward_xlstm_v0"
    ),
    ctxes=CONTEXT_LENGTHS,
)

In [None]:
# Inspect the loaded (column-filtered) Mamba results for context length 8192.
mamba_results[8192]

In [None]:
# collect all results sorted by context
# pd.concat(axis=1) aligns rows by index, so every model's CSV must contain the
# same batch-size sweep in the same order. Each model repeats the shared
# parameter columns (batch size, prefill length); verify they agree row by row
# before keeping only the first copy. Otherwise a missing or reordered row
# would silently pair runtimes measured at different batch sizes.
result_dicts = [
    llama_results,
    mamba_results,
    mxlstm_results,
]
combined_raw_data = {}
for ctx in [2048, 4096, 8192]:
    ctx_df = pd.concat([rd[ctx] for rd in result_dicts], axis=1)
    shared_param_cols = []
    for param_regex in ("batch_size", "prefill"):
        param_df = ctx_df.filter(regex=param_regex)
        reference_col = param_df.iloc[:, 0]
        # NaN (from rows missing in one model) never compares equal, so
        # misaligned row counts are caught here as well.
        if not param_df.eq(reference_col, axis=0).to_numpy().all():
            raise ValueError(
                f"Columns matching {param_regex!r} differ between models for "
                f"ctx={ctx}; the result files cannot be aligned row by row."
            )
        # select the shared parameter column only once
        shared_param_cols.append(param_df.iloc[:, [0]])
    combined_raw_data[ctx] = pd.concat(
        [*shared_param_cols, ctx_df.filter(regex=".*R--.*")],
        axis=1,
    )

In [None]:
# Persist the combined raw runtimes: one pickle holding every context length,
# plus a separate CSV per context length.
with Path("throughput_data.p").open("wb") as pickle_file:
    pickle.dump(combined_raw_data, pickle_file)

for ctx_len, ctx_df in combined_raw_data.items():
    ctx_df.to_csv(f"raw_data_throughput_{ctx_len}.csv")

In [None]:
# Reload the combined raw runtimes from the pickle written by the save cell, so
# the cells below can start from the saved file.
# pickle.loads can execute arbitrary code — only load pickles this notebook produced.
combined_raw_data = pickle.loads(Path("throughput_data.p").read_bytes())

In [None]:
def compute_throughput_tokens_per_sec(
    raw_data_dict: dict[int, pd.DataFrame],
) -> dict[int, pd.DataFrame]:
    """Convert measured runtimes into throughput (tokens per second).

    For context length ``ctx``, a row processes ``ctx * batch_size`` tokens.
    Dividing that by the runtime in seconds gives the throughput. The runtime
    (``R--``) columns are divided by 1000 to obtain seconds, i.e. they are
    assumed to be in milliseconds — TODO confirm against the benchmark writer.

    Args:
        raw_data_dict: Mapping from context length to a table with a
            batch-size column, a prefill column and one ``R--`` runtime column
            per model. The key must be numeric (an ``int``), because it is used
            as a multiplier.

    Returns:
        Mapping with the same keys to tables holding the first batch-size
        column, the first prefill column and, for every ``R--`` column, the
        tokens per second under the same column name. The input is not modified.
    """
    throughput_dict = {}
    for ctx, df in raw_data_dict.items():
        batch_size_col = df.filter(regex="batch_size").iloc[:, [0]]
        prefill_col = df.filter(regex="prefill").iloc[:, [0]]
        runtime_sec_df = df.filter(regex=".*R--.*") / 1000
        # The (n_rows, 1) array broadcasts across all runtime columns of each row.
        tokens_per_sec_df = ctx * batch_size_col.to_numpy() / runtime_sec_df
        throughput_dict[ctx] = pd.concat(
            [batch_size_col, prefill_col, tokens_per_sec_df], axis=1
        )
    return throughput_dict

In [None]:
# Convert the runtimes of every context length into tokens per second.
throughput_data = compute_throughput_tokens_per_sec(combined_raw_data)

In [None]:
# Per-context throughput tables (all batch sizes), for inspection.
throughput_data[2048]

In [None]:
throughput_data[4096]

In [None]:
throughput_data[8192]

In [None]:
throughput_8192 = throughput_data[8192].loc[
    throughput_data[8192]["P--batch_size"] == 8.0
]

In [None]:
throughput_4096 = throughput_data[4096].loc[
    throughput_data[4096]["P--batch_size"] == 16.0
]

In [None]:
throughput_2048 = throughput_data[2048].loc[
    throughput_data[2048]["P--batch_size"] == 32.0
]

In [None]:
throughput_df = pd.concat([throughput_2048, throughput_4096, throughput_8192])

In [None]:
throughput_df

## Plotting the throughput (tokens per second)

In [None]:
# Raw result column name -> short label shown in the plot. The "Batch Size" and
# "Context Length" labels are the grouping columns of the bar plot.
column_name_mapping = dict(
    [
        ("P--batch_size", "Batch Size"),
        ("P--prefill_length", "Context Length"),
        ("R--llama2__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-False_ucgm-False", "llama2"),
        ("R--llama3__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-False_ucgm-False", "llama3"),
        ("R--codestral_mamba__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False", "codestral_mamba"),
        ("R--falcon_mamba__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False", "falcon_mamba"),
        ("R--xlstm__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False_isd-bfloat16_ed-4096_nh-8_nb-32_vs-50304_wm-fused_ck-chunkwise--triton_xl_chunk_sk-native_sequence__triton_step_fused_sk-triton_fused_cs-128_akd-bfloat16", "xlstm"),
    ]
)

In [None]:
# Check the raw column names before they are mapped to short labels.
throughput_df.columns

In [None]:
# Swap the raw column names for their short labels; columns without an entry in
# the mapping keep their original name, so re-running this cell is harmless.
throughput_df = throughput_df.rename(columns=column_name_mapping)

In [None]:
# Round to whole tokens per second for plotting. Note: astype(int) raises if
# any throughput value is NaN (e.g. a missing runtime).
plot_throughput_df = (
    throughput_df
    .round(decimals=0)
    .astype(int)
)
plot_throughput_df

In [None]:
# Figure size: 1.5 x (12 cm x 8 cm), converted to inches (2.54 cm per inch).
fig = rc_context_wrapper(
    func=create_runtime_bar_plot,
    data_df=plot_throughput_df,
    group_col_names=["Batch Size", "Context Length"],
    style_dict=style_dict,
    figsize=(1.5 * 12 / 2.54, 1.5 * 8 / 2.54),
    y_label="Tokens per Second",
)

In [None]:
# Save the figure under the name "throughput" via the project's savefig helper
# (output directory and file format are decided by that helper).
savefig(fig, "throughput")