In [1]:
%load_ext autoreload
%autoreload 2
import sys

import pandas as pd

sys.path.append("../..")
sys.path.append(".")

from plot_results_for_paper import (
    plot_benchmark_result_table,
    create_runtime_bar_plot,
    rc_context_wrapper,
    rc_context_wrapper_quarter_col_plot,
    select_columns,
    savefig,
)
from pathlib import Path
from plot_config import linestyle_mapping, style_dict
import pickle
import matplotlib.pyplot as plt
from plot_config_for_paper import FIGSIZE

In [5]:
def load_throughput_results_for_ctxes(path_template: str, ctxes: list[int]):
    """Read one benchmark CSV per context length.

    ``path_template`` is formatted with ``ctx=<context length>`` to obtain the
    path of each file. From every file only the columns whose name contains
    ``batch_size``, ``prefill`` or ``R--`` are kept.

    Returns a dict mapping each entry of ``ctxes`` to its filtered DataFrame.
    """
    column_pattern = ".*(batch_size|prefill|R--).*"
    return {
        ctx_len: pd.read_csv(Path(path_template.format(ctx=ctx_len))).filter(
            regex=column_pattern
        )
        for ctx_len in ctxes
    }

In [6]:
# DISABLED (kept for provenance): reads the per-model `results.csv` files for the
# context lengths 2048..32768 via `load_throughput_results_for_ctxes`.
# NOTE(review): the path templates are absolute paths under /home/beck/...;
# they need to be adapted before this cell can be re-enabled elsewhere.

# llama_results = load_throughput_results_for_ctxes(
#     path_template="/home/beck/wdir/dev_repos/mlstm_kernels/outputs_kernel_benchmarks_final/2024-12-06_19-44-48__throughput__forward_llama_v3/hf_7B_throughput__pfl{ctx}_gl0_tcTrue_weightdtypebfloat16/results.csv",
#     ctxes=[2048, 4096, 8192, 16384, 32768],
# )

# mamba_results = load_throughput_results_for_ctxes(
#     path_template="/home/beck/wdir/dev_repos/mlstm_kernels/outputs_kernel_benchmarks_final/2024-12-06_13-02-27__throughput__forward_mamba_v2/hf_7B_throughput__pfl{ctx}_gl0_tcTrue_weightdtypebfloat16/results.csv",
#     ctxes=[2048, 4096, 8192, 16384, 32768],
# )

# mxlstm_results = load_throughput_results_for_ctxes(
#     path_template="/home/beck/wdir/dev_repos/mlstm_kernels/outputs_kernel_benchmarks_final/2024-12-06_12-46-45__throughput__forward_xlstm_v0/hf_7B_throughput__pfl{ctx}_gl0_tcTrue_weightdtypebfloat16/results.csv",
#     ctxes=[2048, 4096, 8192, 16384, 32768],
# )

In [7]:
# DISABLED (kept for provenance): builds `combined_raw_data`, one DataFrame per
# context length, by concatenating the per-model frames column-wise on the
# `P--batch_size` index and keeping the first `prefill` column plus all `R--`
# columns.
# # collect all results sorted by context
# result_dicts = [
#     llama_results,
#     mamba_results,
#     mxlstm_results,
# ]
# combined_raw_data = {}
# for ctx in [2048, 4096, 8192, 16384, 32768]:
#     ctx_df = pd.concat(
#         [rd[ctx].set_index("P--batch_size") for rd in result_dicts], axis=1
#     )
#     # select batch size only once
#     ctx_df = pd.concat(
#         [
#             ctx_df.filter(regex="prefill").take([0], axis=1),
#             ctx_df.filter(regex=".*R--.*"),
#         ],
#         axis=1,
#     )
#     combined_raw_data[ctx] = ctx_df

In [8]:
# DISABLED (kept for provenance): writes `combined_raw_data` to the pickle file
# `throughput_data.p` and additionally to one `raw_data_throughput_<ctx>.csv`
# per context length.
# with open("throughput_data.p", "wb") as f:
#     pickle.dump(combined_raw_data, f)

# for k, v in combined_raw_data.items():
#     v.to_csv(f"raw_data_throughput_{k}.csv")

In [9]:
# Load the combined per-context frames from `throughput_data.p` (path is
# relative to the current working directory).
# NOTE: unpickling can execute arbitrary code; only load a file you produced yourself.
combined_raw_data = pickle.loads(Path("throughput_data.p").read_bytes())

In [None]:
# Scratch check: 2048 times every index value of the 2048 frame
# (presumably the batch sizes, i.e. tokens per forward pass — TODO confirm).
2048 * combined_raw_data[2048].index.values

In [None]:
# Scratch check for ctx=2048: (2048 * index value) divided by the raw `R--`
# columns. Unlike `compute_throughput_tokens_per_sec`, no /1000 scaling here.
raw_2048 = combined_raw_data[2048]
tokens_per_batch_2048 = (2048 * raw_2048.index.values)[:, None]
vals = tokens_per_batch_2048 / raw_2048.filter(regex=".*R--.*")
vals

In [None]:
# Scratch check: wrap `vals` in a DataFrame labelled with the `R--` columns and
# the index of the 2048 frame.
raw_2048_frame = combined_raw_data[2048]
pd.DataFrame(
    data=vals,
    index=raw_2048_frame.index,
    columns=raw_2048_frame.filter(regex=".*R--.*").columns,
)

In [None]:
# Scratch check: first entry of the `P--prefill_length` column of the 2048 frame.
combined_raw_data[2048]["P--prefill_length"].values[0]

In [14]:
def compute_throughput_tokens_per_sec(
    raw_data_dict: dict[int, pd.DataFrame],
) -> dict[int, pd.DataFrame]:
    """Convert raw runtimes into tokens-per-second tables, one per context length.

    Args:
        raw_data_dict: Maps a context length to a DataFrame. The key must be
            numeric because it is multiplied with the frame's index values
            (used as batch sizes). Each frame needs at least one column whose
            name matches ``prefill``; its ``R--`` columns hold the runtimes.
            The runtimes are divided by 1000 before use, i.e. they are
            presumably given in milliseconds.

    Returns:
        A dict with the same keys. Each value has the columns
        ``P--batch_size`` (copy of the index), ``P--prefill_length`` (first
        value of the first ``prefill`` column, cast to int), followed by the
        ``R--`` columns holding ``ctx * batch_size / (runtime / 1000)``.
        The input frames are not modified.

    Raises:
        IndexError: If a frame has no ``prefill`` column or no rows.
    """
    throughput_dict = {}
    for ctx, df in raw_data_dict.items():
        prefill_length = int(df.filter(regex="prefill").iloc[0, 0])

        runtimes_sec = df.filter(regex=".*R--.*") / 1000
        # Column vector of shape (n_rows, 1) so it broadcasts over all R-- columns.
        tokens_per_batch = (ctx * df.index.values)[:, None]
        tokens_per_sec_df = tokens_per_batch / runtimes_sec

        # Put the parameter columns in front of the result columns; the index
        # is kept and additionally exposed as a regular column.
        tokens_per_sec_df.insert(0, "P--batch_size", tokens_per_sec_df.index)
        tokens_per_sec_df.insert(1, "P--prefill_length", prefill_length)

        throughput_dict[ctx] = tokens_per_sec_df
    return throughput_dict

In [15]:
# Tokens-per-second tables, keyed like `combined_raw_data` (one frame per context length).
throughput_data = compute_throughput_tokens_per_sec(combined_raw_data)

In [None]:
# Inspect the tokens-per-second table for the key 2048.
throughput_data[2048]

In [None]:
# Inspect the tokens-per-second table for the key 4096.
throughput_data[4096]

In [None]:
# Inspect the tokens-per-second table for the key 8192.
throughput_data[8192]

In [None]:
# Inspect the tokens-per-second table for the key 16384.
throughput_data[16384]

In [None]:
# Inspect the tokens-per-second table for the key 32768.
throughput_data[32768]

In [21]:
def _rows_with_batch_size(ctx_len, batch_size):
    """Rows of ``throughput_data[ctx_len]`` whose ``P--batch_size`` equals ``batch_size``."""
    frame = throughput_data[ctx_len]
    return frame.loc[frame["P--batch_size"] == batch_size]


# One batch size per context length, chosen so that context length times batch
# size is the same for all five selections (2048 * 32 = ... = 32768 * 2 = 65536).
throughput_32768 = _rows_with_batch_size(32768, 2.0)
throughput_16384 = _rows_with_batch_size(16384, 4.0)
throughput_8192 = _rows_with_batch_size(8192, 8.0)
throughput_4096 = _rows_with_batch_size(4096, 16.0)
throughput_2048 = _rows_with_batch_size(2048, 32.0)

In [22]:
# Stack the per-context selections row-wise, ordered from the shortest to the
# longest context length.
selections_short_to_long = [
    throughput_2048,
    throughput_4096,
    throughput_8192,
    throughput_16384,
    throughput_32768,
]
throughput_df = pd.concat(selections_short_to_long)

In [None]:
# Inspect the stacked selection (one block of rows per context length).
throughput_df

## Plotting the throughput data (tokens per second)

In [24]:
# Short display names for the columns of the throughput table. Columns that are
# not listed here keep their original name when the mapping is applied with
# `column_name_mapping.get(col, col)`.
# NOTE(review): `P--` presumably marks benchmark parameters and `R--` per-model
# results; the long `R--` keys must match the benchmark column names exactly.
column_name_mapping = {
    "P--batch_size": "BS",
    "P--prefill_length": "CTX",
    "R--llama2__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False": "llama2",
    "R--llama3__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False": "llama3",
    "R--codestral_mamba__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False": "codestral_mamba",
    "R--falcon_mamba__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False": "falcon_mamba",
    "R--xlstm__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-True_ucgm-False_isd-bfloat16_ed-4096_nh-8_nb-32_vs-50304_wm-fused_ck-chunkwise--triton_xl_chunk_sk-native_sequence__triton_step_fused_sk-triton_fused_cs-128_akd-bfloat16": "xlstm",
}

In [None]:
# Inspect the current column names of the stacked throughput table.
throughput_df.columns

In [26]:
# Replace known raw column names by their short labels; unknown names are kept.
# This relabels `throughput_df` in place.
new_col_names = list(
    map(
        lambda raw_name: column_name_mapping.get(raw_name, raw_name),
        throughput_df.columns,
    )
)
throughput_df.columns = new_col_names

In [None]:
# Integer-valued copy of the throughput table used as plot input:
# round to zero decimals first, then cast every column to int.
plot_throughput_df = throughput_df.round(decimals=0)
plot_throughput_df = plot_throughput_df.astype(int)
plot_throughput_df

In [None]:
# Legend settings handed to the plotting helper (presumably forwarded to the
# matplotlib legend call — TODO confirm in `create_runtime_bar_plot`).
throughput_legend_args = dict(
    loc="lower center",
    ncol=2,
    bbox_to_anchor=(-0.05, 1.02, 1.0, 0.502),
    frameon=False,
    facecolor="white",
)

# Grouped bar plot of tokens per second, one group per (BS, CTX) combination.
fig = rc_context_wrapper_quarter_col_plot(
    func=create_runtime_bar_plot,
    data_df=plot_throughput_df,
    group_col_names=["BS", "CTX"],
    bar_label_font_size=30,
    style_dict=style_dict,
    figsize=FIGSIZE,  # (1.6 * 12 * 1 / 2.54, 1.5 * 8 * 1 / 2.54),
    y_label="Tokens per Second",
    legend_args=throughput_legend_args,
)
plt.show()

In [26]:
# DISABLED: uncomment to export the figure under the name "paper-throughput"
# via the imported `savefig` helper.
# savefig(fig, "paper-throughput")