In [None]:
%load_ext autoreload
%autoreload 2
import sys

import pandas as pd

sys.path.append("../..")

from mlstm_kernels.utils.benchmark.plot_results import (
    plot_benchmark_result_table,
    rc_context_wrapper,
    select_columns,
)
from pathlib import Path
from plot_config import linestyle_mapping, style_dict
import pickle

In [None]:
# Collect all results batch size 1
# Every results.csv lives under the same benchmark output root and the same
# run-configuration sub-directory; only the timestamped run directory differs.
_BENCHMARK_ROOT = (
    "/home/beck/wdir/dev_repos/mlstm_kernels/outputs_kernel_benchmarks_final"
)
_RUN_CONFIG_DIR = "hf_7B_generation_time__pfl0_bs1_tcTrue_weightdtypebfloat16"


def _results_csv(run_dir):
    """Return the full path of the ``results.csv`` inside ``run_dir``."""
    return "/".join([_BENCHMARK_ROOT, run_dir, _RUN_CONFIG_DIR, "results.csv"])


falconmamba_gen_file = _results_csv(
    "2024-12-05_13-57-50__gen_time__gentime_falconmamba_cgmtrue_v0"
)
codestralmamba_gen_file = _results_csv(
    "2024-12-05_15-43-18__gen_time__codestral_mamba_gen_cgmtrue_v0"
)
mxlstmmamba_gen_file = _results_csv(
    "2024-12-05_08-42-43__gen_time__genttime_xlstm_v1"
)
llama_gen_file = _results_csv("2024-12-05_15-20-04__gen_time__llama_static_v0")

# Model key -> path of its results file. The insertion order fixes the column
# order of the tables that are concatenated from these files.
file_dict = dict(
    falconmamba=falconmamba_gen_file,
    codestralmamba=codestralmamba_gen_file,
    llama=llama_gen_file,
    mxlstmmamba=mxlstmmamba_gen_file,
)

In [None]:
def _load_runtime_and_memory(csv_path):
    """Read one benchmark results file and split it into two tables.

    The file is parsed a single time; both returned tables are column
    selections of that one frame.

    Args:
        csv_path: Path of a benchmark ``results.csv`` file.

    Returns:
        A ``(runtime_df, memory_df)`` tuple. Both keep the columns whose name
        matches ``.*generation``; ``runtime_df`` additionally keeps the
        ``R--`` columns and ``memory_df`` the ``M--`` columns.
    """
    results = pd.read_csv(csv_path)
    return (
        results.filter(regex=".*generation|.*R--.*"),
        results.filter(regex=".*generation|.*M--.*"),
    )


# Model key -> (runtime table, memory table), in the order of ``file_dict``.
dataframe_dict = {k: _load_runtime_and_memory(v) for k, v in file_dict.items()}

In [None]:
# Put the runtime tables (first tuple entry) of all models side by side.
runtime_tables = [tables[0] for tables in dataframe_dict.values()]
gen_time_all = pd.concat(runtime_tables, axis=1)

# Each source table carries its own generation-length column, so after the
# concat the column is repeated; only the first occurrence is kept.
generation_axis = gen_time_all.filter(regex=".*generation.*").iloc[:, [0]]

# Rescale the "R--" runtime columns by 1/1e3
# (presumably milliseconds -> seconds; verify against the benchmark output).
runtime_columns = gen_time_all.filter(regex=".*R--.*") / 1e3

gen_time_df = pd.concat([generation_axis, runtime_columns], axis=1)
gen_time_df

In [None]:
# Put the memory tables (second tuple entry) of all models side by side.
memory_tables = [tables[1] for tables in dataframe_dict.values()]
gen_mem_all = pd.concat(memory_tables, axis=1)

# Each source table carries its own generation-length column, so after the
# concat the column is repeated; only the first occurrence is kept.
generation_axis = gen_mem_all.filter(regex=".*generation.*").iloc[:, [0]]

# Rescale the "M--" memory columns by 1/1e9
# (presumably bytes -> gigabytes; verify against the benchmark output).
memory_columns = gen_mem_all.filter(regex=".*M--.*") / 1e9

gen_mem_df = pd.concat([generation_axis, memory_columns], axis=1)
gen_mem_df

In [None]:
# Bundle both result tables under descriptive keys so they can be exported
# together.
raw_data = dict(gen_time_seconds=gen_time_df, gen_mem_gb=gen_mem_df)

# One pickle file holding the whole bundle ...
with open("gen_time_mem_data.p", "wb") as pickle_file:
    pickle.dump(raw_data, pickle_file)

# ... plus one CSV file per table, named after the table's key.
for table_name, table in raw_data.items():
    table.to_csv(f"raw_data_{table_name}.csv")

## Plotting the raw data

In [None]:
# Overview plot of the generation time for every benchmarked configuration.
# The custom styling (linestyle_mapping / style_dict) is intentionally not
# passed here.
raw_time_plot_kwargs = dict(
    x_axis_param="generation_length",
    style_dict_colname_mapping_exact=False,
    y_label="Time [s]",
    title="Time to generate X tokens, no prefill",
)
fig = plot_benchmark_result_table(gen_time_df, **raw_time_plot_kwargs)

In [None]:
# Overview plot of the memory usage for every benchmarked configuration.
# The title previously repeated the runtime plot's "Time to generate ..."
# text although this figure shows memory; it now describes the plotted data.
fig = plot_benchmark_result_table(
    gen_mem_df,
    x_axis_param="generation_length",
    # linestyle_mapping=linestyle_mapping,
    # style_dict=style_dict,
    style_dict_colname_mapping_exact=False,
    y_label="Memory GB",
    title="Memory used to generate X tokens, no prefill",
)

## Final Plots - All results

In [None]:
# Benchmark configuration id of the single run shown per model in the final
# figures. The runtime and memory columns of a run share this id and differ
# only in their "R--" / "M--" prefix.
_llama_config = "__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-False_ucgm-False"
_mamba_config = "__ampdt-bfloat16__wdt-bfloat16__ucgg-False_ucgm-True"
_selected_config_ids = {
    "llama3": "llama3" + _llama_config,
    "llama2": "llama2" + _llama_config,
    "falcon_mamba": "falcon_mamba" + _mamba_config,
    "codestral_mamba": "codestral_mamba" + _mamba_config,
    "xlstm": (
        "xlstm__tcm__ampdt-bfloat16__wdt-bfloat16__ucgg-False_ucgm-True"
        "_isd-bfloat16_ed-4096_nh-8_nb-32_vs-50304_wm-fused"
        "_ck-chunkwise--triton_xl_chunk"
        "_sk-native_sequence__triton_step_fused"
        "_sk-triton_fused_cs-128_akd-bfloat16"
    ),
}

# The two mappings list the llama models in a different order; that order is
# kept exactly because it is the insertion order of the resulting dicts.
_runtime_order = ("llama3", "llama2", "falcon_mamba", "codestral_mamba", "xlstm")
_memory_order = ("llama2", "llama3", "falcon_mamba", "codestral_mamba", "xlstm")

selected_columns_runtime = {
    label: "R--" + _selected_config_ids[label] for label in _runtime_order
}
selected_columns_memory = {
    label: "M--" + _selected_config_ids[label] for label in _memory_order
}

# Appended to the output file names of the final figures.
filename_suffix = ""
# Whether the final figures are drawn with a legend.
add_legend = True

In [None]:
# Reduce the runtime table to the runs chosen for the final figure; the
# generation-length column is requested via the keep-regex.
generation_col_regex = ".*generation.*"
gen_time_plot_df = select_columns(
    gen_time_df,
    selected_columns_runtime,
    keep_col_regex=generation_col_regex,
)

In [None]:
# Final runtime figure: generation time over the number of generated tokens.
# The figure size is 12 x 8 scaled by 1.5 and divided by 2.54 (cm per inch).
figure_scale = 1.5
cm_per_inch = 2.54
time_figsize = (figure_scale * 12 / cm_per_inch, figure_scale * 8 / cm_per_inch)

time_plot_kwargs = dict(
    func=plot_benchmark_result_table,
    result_df=gen_time_plot_df,
    x_axis_param="generation_length",
    style_dict=style_dict,
    style_dict_colname_mapping_exact=False,
    y_label="Generation Time [s]",
    x_label="Generated Tokens",
    title="",  # no title is drawn on the final figure
    figsize=time_figsize,
    filename=f"generation_time{filename_suffix}",
    add_legend=add_legend,
)
fig = rc_context_wrapper(**time_plot_kwargs)

In [None]:
# Reduce the memory table to the runs chosen for the final figure; the
# generation-length column is requested via the keep-regex.
generation_col_regex = ".*generation.*"
gen_mem_plot_df = select_columns(
    gen_mem_df,
    selected_columns_memory,
    keep_col_regex=generation_col_regex,
)

In [None]:
# Final memory figure: GPU memory over the generation length.
# The figure size is 12 x 8 scaled by 1.5 and divided by 2.54 (cm per inch).
figure_scale = 1.5
cm_per_inch = 2.54
mem_figsize = (figure_scale * 12 / cm_per_inch, figure_scale * 8 / cm_per_inch)

mem_plot_kwargs = dict(
    func=plot_benchmark_result_table,
    result_df=gen_mem_plot_df,
    x_axis_param="generation_length",
    style_dict=style_dict,
    style_dict_colname_mapping_exact=False,
    y_label="GPU Memory [GB]",
    # NOTE(review): the runtime figure labels this axis "Generated Tokens";
    # confirm whether the two final figures are meant to differ here.
    x_label="Generation Length",
    title="",  # no title is drawn on the final figure
    figsize=mem_figsize,
    filename=f"generation_memory{filename_suffix}",
    add_legend=add_legend,
)
fig = rc_context_wrapper(**mem_plot_kwargs)