In [1]:
%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pickle
from mlstm_kernels.utils.benchmark.plot_results import plot_benchmark_result_table, create_runtime_line_plot

## Make paper plots

In [2]:
# File name of the pickled benchmark results; it is opened relative to the
# current working directory when the results are loaded.
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results.p"

In [3]:
# Deserialize the benchmark results from the working directory.
# NOTE(review): pickle executes arbitrary code while loading, so only point
# `result_filename` at files produced by a trusted benchmark run.
with (Path(".") / result_filename).open("rb") as results_file:
    all_results_dict = pickle.load(results_file)

In [4]:
# The results dict is indexed first by metric ("runtime" / "memory") and then
# by pass ("fw" / "fwbw"); pull out the three entries used for plotting.
all_runtime_results_fw_df, all_runtime_results_fwbw_df = (
    all_results_dict["runtime"][pass_key] for pass_key in ("fw", "fwbw")
)
all_memory_results_fwbw_df = all_results_dict["memory"]["fwbw"]

In [None]:
# Inspection only: show the forward+backward runtime results as the cell
# output via the `.style` accessor (nothing is assigned).
all_runtime_results_fwbw_df.style

In [None]:
# Inspection only: list the raw column names of the forward+backward runtime
# results (nothing is assigned).
all_runtime_results_fwbw_df.columns

In [7]:
from plot_config import (
    map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping,
    get_tb_plot_mpl_context,
    GRIDSPEC_KWARGS,
    savefig,
    get_style_dict_appendix,
    get_col_order_appendix
)

In [8]:
# Forward+backward runtimes: keep "sequence_length" plus every column that is a
# key of the fwbw mapping, then rename those columns to the mapping's values.
consttoken_fwbw_raw_df = all_runtime_results_fwbw_df[
    [
        "sequence_length",
        *map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=True).keys(),
    ]
]
consttoken_fwbw_df = consttoken_fwbw_raw_df.rename(
    columns=map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=True)
)

In [9]:
# Forward-only runtimes: keep "sequence_length" plus every column that is a key
# of the fw mapping (fwbw=False), then rename those columns to the mapping's values.
consttoken_fw_raw_df = all_runtime_results_fw_df[
    [
        "sequence_length",
        *map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=False).keys(),
    ]
]
consttoken_fw_df = consttoken_fw_raw_df.rename(
    columns=map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=False)
)

In [10]:
# Forward+backward memory results: same column selection and renaming as for
# the fwbw runtimes, applied to the memory frame.
consttoken_memory_fwbw_raw_df = all_memory_results_fwbw_df[
    [
        "sequence_length",
        *map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=True).keys(),
    ]
]
consttoken_memory_fwbw_df = consttoken_memory_fwbw_raw_df.rename(
    columns=map_consttoken_fwbw_appendix_data_col_to_plot_col_mapping(fwbw=True)
)

In [None]:
# Column-name pattern handed to DataFrame.filter(regex=...): the x-axis column
# plus the kernel columns that should appear in the figures.
filter_regex = "sequence_length|mlstmsig.*|mlstmexp_triton_limit_chunk|chunk_gla|fused_chunk_gla|chunk_simple_gla"

# Runtime frames only need the column filter.
mlstmsig_consttoken_runtime_fwbw_df, mlstmsig_consttoken_runtime_fw_df = (
    frame.filter(regex=filter_regex) for frame in (consttoken_fwbw_df, consttoken_fw_df)
)

# Memory frame: keep "sequence_length" untouched; in every other column replace
# -1 by NaN and divide by 1e9.
# NOTE(review): -1 is presumably a "no measurement" sentinel and the division
# presumably converts bytes to GB (the plots label this axis "GPU Memory [GB]")
# - confirm against the benchmark writer.
mlstmsig_consttoken_memory_fwbw_df = consttoken_memory_fwbw_df.filter(regex=filter_regex).pipe(
    lambda mem_df: pd.concat(
        [
            mem_df["sequence_length"],
            mem_df.replace(-1, float("nan")).loc[:, mem_df.columns != "sequence_length"] / 1e9,
        ],
        axis=1,
    )
)

# Display the filtered forward+backward runtime table as the cell output.
mlstmsig_consttoken_runtime_fwbw_df

In [12]:
# Chunk sizes passed to the style and column-order helpers, largest first.
chunk_sizes = sorted((128, 256, 512, 1024, 2048, 4096), reverse=True)

In [None]:
# Standalone line plot of the forward+backward runtimes over sequence length.
# Optional overrides kept for reference (not passed):
#   legend_order=legend_order, yticks=[0, 2.5, 5, 7.5, 10, 12.5, 15.0]
fig = create_runtime_line_plot(
    data_df=mlstmsig_consttoken_runtime_fwbw_df,
    style_dict=get_style_dict_appendix(
        chunk_sizes=chunk_sizes,
        colormap=plt.cm.copper,
        cmap_start_end=(0.2, 1),
    ),
    group_col_names=["sequence_length"],
    plot_column_order=get_col_order_appendix(chunk_sizes=chunk_sizes),
)

In [None]:
# Three-panel figure, left to right: forward runtime, forward+backward runtime,
# forward+backward memory.
with get_tb_plot_mpl_context(fontsize_delta=0):
    fig, (ax_left, ax_middle, ax_right) = plt.subplots(
        nrows=1,
        ncols=3,
        figsize=(18, 4),
        gridspec_kw={"wspace": 0.18, "hspace": 0.1},
        sharex=True,
    )

    # One entry per panel: (target axis, data frame, panel-specific keyword overrides).
    panels = [
        (ax_left, mlstmsig_consttoken_runtime_fw_df, {}),
        (ax_middle, mlstmsig_consttoken_runtime_fwbw_df, {}),
        (ax_right, mlstmsig_consttoken_memory_fwbw_df, {"y_label": "GPU Memory [GB]"}),
    ]
    # Optional overrides kept for reference (not passed):
    #   legend_order=legend_order, yticks=[0, 2.5, 5, 7.5, 10, 12.5, 15.0]
    for panel_ax, panel_df, panel_overrides in panels:
        # `fig` is rebound to the helper's return value on every iteration; the
        # shared legend below is attached to whatever the last call returned.
        fig = create_runtime_line_plot(
            ax=panel_ax,
            data_df=panel_df,
            style_dict=get_style_dict_appendix(
                chunk_sizes=chunk_sizes, colormap=plt.cm.copper, cmap_start_end=(0.2, 1)
            ),
            group_col_names=["sequence_length"],
            plot_column_order=get_col_order_appendix(chunk_sizes=chunk_sizes),
            **panel_overrides,
            # Per-axis legends are disabled in favour of one figure-level legend.
            legend_args=None,
        )

    # Build a single legend for the whole figure from the left panel's entries.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = dict(
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
    )
    fig.legend(handles, labels, **legend_kwargs)

In [15]:
# savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark--paper")

In [None]:
# Two-panel figure, left to right: forward+backward runtime, forward+backward memory.
with get_tb_plot_mpl_context(fontsize_delta=-0.5):
    fig, (ax_left, ax_right) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(12.5, 4),
        gridspec_kw={"wspace": 0.15, "hspace": 0.1},
        sharex=True,
    )

    # One entry per panel: (target axis, data frame, panel-specific keyword overrides).
    panels = [
        (ax_left, mlstmsig_consttoken_runtime_fwbw_df, {}),
        (ax_right, mlstmsig_consttoken_memory_fwbw_df, {"y_label": "GPU Memory [GB]"}),
    ]
    # Optional overrides kept for reference (not passed):
    #   legend_order=legend_order, yticks=[0, 2.5, 5, 7.5, 10, 12.5, 15.0]
    for panel_ax, panel_df, panel_overrides in panels:
        # `fig` is rebound to the helper's return value on every iteration; the
        # shared legend below is attached to whatever the last call returned.
        fig = create_runtime_line_plot(
            ax=panel_ax,
            data_df=panel_df,
            style_dict=get_style_dict_appendix(
                chunk_sizes=chunk_sizes, colormap=plt.cm.copper, cmap_start_end=(0.2, 1)
            ),
            group_col_names=["sequence_length"],
            plot_column_order=get_col_order_appendix(chunk_sizes=chunk_sizes),
            **panel_overrides,
            # Per-axis legends are disabled in favour of one figure-level legend.
            legend_args=None,
        )

    # Build a single legend for the whole figure from the left panel's entries.
    handles, labels = ax_left.get_legend_handles_labels()
    legend_kwargs = dict(
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
    )
    fig.legend(handles, labels, **legend_kwargs)

In [17]:
# Hand the most recently built figure (`fig`) to the `savefig` helper imported
# from plot_config, under the given base file name.
savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark_simple_gla_appendix--paper")