In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pickle
from mlstm_kernels.utils.benchmark.plot_results import plot_benchmark_result_table, create_runtime_line_plot
import seaborn as sns

## Make paper plots

In [2]:
# Pickled benchmark results (runtime & memory, each split into "fw" / "fwbw"),
# read from the current working directory by the next cell.
result_filename = "mlstm_tfla_paper_consttoken_benchmark_results.p"

In [3]:
# NOTE: unpickling can execute arbitrary code — only load result files produced by our own benchmark runs.
all_results_dict = pickle.loads((Path(".") / result_filename).read_bytes())

In [4]:
# Runtime tables for the forward-only and forward+backward benchmarks.
all_runtime_results_fw_df, all_runtime_results_fwbw_df = (
    all_results_dict["runtime"]["fw"],
    all_results_dict["runtime"]["fwbw"],
)

In [5]:
# Memory tables for the forward-only and forward+backward benchmarks.
all_memory_results_fw_df, all_memory_results_fwbw_df = (
    all_results_dict["memory"]["fw"],
    all_results_dict["memory"]["fwbw"],
)

In [None]:
# Inspect the forward+backward runtimes of the xl_chunk_siging kernel variants
# (raw-string regex, consistent with the plotting cells; the f-prefix had no placeholders).
all_runtime_results_fwbw_df.filter(regex=r"sequence_length|.*xl_chunk_siging.*")

In [None]:
# Forward+backward runtime vs. sequence length for every xl_chunk_siging variant.
fig = create_runtime_line_plot(
    group_col_names=["sequence_length"],
    data_df=all_runtime_results_fwbw_df.filter(
        regex=r"sequence_length|.*xl_chunk_siging.*"
    ),
)

In [None]:
# Forward+backward memory vs. sequence length for every xl_chunk_siging variant.
# The raw memory values are not in GB: the chunk-size analysis later in this notebook divides
# the same table by 1e9 to get GB. Convert here too, so the data matches the "[GB]" axis label.
memory_fwbw_selected_df = all_memory_results_fwbw_df.filter(regex=r"sequence_length|.*xl_chunk_siging.*")
memory_fwbw_selected_gb_df = memory_fwbw_selected_df.copy()
memory_value_cols = memory_fwbw_selected_df.columns.drop("sequence_length")
memory_fwbw_selected_gb_df[memory_value_cols] = memory_fwbw_selected_df[memory_value_cols] / 1e9
fig = create_runtime_line_plot(
    data_df=memory_fwbw_selected_gb_df,
    group_col_names=["sequence_length"],
    y_label="Memory [GB]",
)

In [9]:
# Keep only the sequence-length column plus the xl_chunk_siging kernel columns, for runtime and memory alike.
selected_mlstm_runtime_fwbw_df, selected_mlstm_memory_fwbw_df = (
    results_df.filter(regex=r"sequence_length|.*xl_chunk_siging.*")
    for results_df in (all_runtime_results_fwbw_df, all_memory_results_fwbw_df)
)

In [22]:
# Sequence length at which the chunk-size sweep is compared. Previously this value was ignored
# because 8192 was hard-coded below; editing it now actually changes the selection.
seq_len = 8192
selected_mlstm_runtime_ctx_fwbw_df = (
    selected_mlstm_runtime_fwbw_df[selected_mlstm_runtime_fwbw_df["sequence_length"] == seq_len]
    .drop(columns=["sequence_length"])
    .T
)
selected_mlstm_memory_ctx_fwbw_df = (
    selected_mlstm_memory_fwbw_df[selected_mlstm_memory_fwbw_df["sequence_length"] == seq_len]
    .drop(columns=["sequence_length"])
    .T
)
# The cells below assign a single column name, so exactly one benchmark row must match.
assert selected_mlstm_runtime_ctx_fwbw_df.shape[1] == 1, f"expected one runtime row for sequence_length={seq_len}"
assert selected_mlstm_memory_ctx_fwbw_df.shape[1] == 1, f"expected one memory row for sequence_length={seq_len}"

In [16]:
def extract_chunksize(specifier: str):
    """Parse the integer chunk size out of a kernel specifier string.

    Takes the segment after the last ``"__"``, keeps what precedes its first
    ``"_"``, and converts the piece after the first ``"-"`` to ``int``,
    e.g. ``"kernel__cs-64_extra"`` -> ``64``.

    Raises:
        IndexError: if that segment contains no ``"-"``.
        ValueError: if the extracted piece is not an integer.
    """
    last_segment = specifier.split("__")[-1]
    leading_field = last_segment.partition("_")[0]
    return int(leading_field.split("-")[1])

In [53]:
# One row per chunk size with its runtime, chunk size as a regular column.
runtime_df = (
    selected_mlstm_runtime_ctx_fwbw_df
    .rename(index=extract_chunksize)
    .rename_axis("chunk_size")
    .set_axis(["runtime"], axis=1)
    .reset_index()
)

In [54]:
# One row per chunk size with its memory usage scaled down by 1e9 (shown as GB in the figure).
memory_df = (
    selected_mlstm_memory_ctx_fwbw_df
    .rename(index=extract_chunksize)
    .rename_axis("chunk_size")
    .set_axis(["memory"], axis=1)
    .div(1e9)
    .reset_index()
)

In [None]:
# Join runtime and memory by chunk size rather than by row position. The two frames are built
# from separate result tables, and nothing here guarantees their rows come out in the same order.
memory_runtime_df = pd.merge(runtime_df, memory_df, on="chunk_size", how="inner", validate="one_to_one")
assert len(memory_runtime_df) == len(runtime_df) == len(memory_df), "chunk sizes differ between runtime and memory results"
memory_runtime_df

In [97]:
import numpy as np


def create_double_bar_plot(
    data: pd.DataFrame,
    y_col_left: str,
    y_col_right: str,
    x_col: str,
    left_color,
    right_color,
    left_label: str,
    right_label: str,
    x_label: str,
    figsize: tuple[float, float],
    bar_width: float = 0.4,
):
    """Draw two bar series side by side at each x tick, each against its own y-axis.

    Args:
        data: Frame with one row per x tick.
        y_col_left: Column plotted as bars on the left y-axis.
        y_col_right: Column plotted as bars on the right (twin) y-axis.
        x_col: Column whose values label the x ticks.
        left_color: Matplotlib color for the left bars.
        right_color: Matplotlib color for the right bars.
        left_label: Left y-axis label and legend entry.
        right_label: Right y-axis label and legend entry.
        x_label: X-axis label.
        figsize: Figure size passed to ``plt.subplots``.
        bar_width: Width of each bar; the pair is centred on the tick.

    Returns:
        The created matplotlib Figure.
    """
    half_width = bar_width / 2
    centers = np.arange(len(data))

    fig, ax_left = plt.subplots(figsize=figsize)

    # Left series sits half a bar to the left of each tick.
    left_bars = ax_left.bar(
        x=centers - half_width,
        height=data[y_col_left],
        width=bar_width,
        color=left_color,
        label=left_label,
    )
    ax_left.set_ylabel(ylabel=left_label)
    ax_left.set_xticks(centers)
    ax_left.set_xticklabels(data[x_col])
    ax_left.set_xlabel(x_label)

    # Right series goes on a twin axis sharing the x-axis, half a bar to the right of each tick.
    ax_right = ax_left.twinx()
    right_bars = ax_right.bar(
        x=centers + half_width,
        height=data[y_col_right],
        width=bar_width,
        color=right_color,
        label=right_label,
    )
    ax_right.set_ylabel(ylabel=right_label)

    for axis in (ax_left, ax_right):
        axis.spines.top.set_visible(False)
    ax_left.grid(alpha=0.2, which="both")

    # A single figure-level legend combines the handles from both axes.
    fig.legend(
        [left_bars, right_bars],
        [left_label, right_label],
        loc="lower center",
        ncol=5,
        bbox_to_anchor=(0.0, 0.87, 1.0, 0.102),
        frameon=False,
        facecolor="white",
    )
    return fig


In [109]:
from plot_config import (
    get_tb_plot_mpl_context,
    savefig
)

In [None]:
# Paper figure: runtime (left axis) and GPU memory (right axis) per chunk size.
with get_tb_plot_mpl_context(fontsize_delta=0):
    fig = create_double_bar_plot(
        data=memory_runtime_df,
        x_col="chunk_size",
        x_label="Chunk Size",
        y_col_left="runtime",
        left_color=plt.colormaps["tab10"](0),
        left_label="Time [ms]",
        y_col_right="memory",
        right_color=plt.colormaps["tab10"](1),
        right_label="GPU Memory [GB]",
        figsize=(6, 2.5),
        bar_width=0.4,
    )

In [122]:
# savefig(fig=fig, filename="tfla_mlstm_kernel_memory_vs_runtime--paper")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Scratch: dual-axis bar chart via seaborn. The figure is named `demo_fig` (not `fig`), so
# running this cell does not replace the paper figure that `savefig(fig=fig, ...)` exports.
# Both barplots use the same categorical x positions on shared-x twin axes, so the right-axis
# bars are drawn on top of the left-axis bars.

# Sample data
x = ['A', 'B', 'C', 'D']
series1 = [10, 20, 30, 40]  # Data for the left y-axis
series2 = [100, 200, 300, 400]  # Data for the right y-axis

# Create a DataFrame for plotting
df = pd.DataFrame({
    'x': x,
    'series1': series1,
    'series2': series2
})

# Initialize the plot
demo_fig, ax1 = plt.subplots(figsize=(8, 6))

# Plot the first series on the left y-axis
sns.barplot(x='x', y='series1', data=df, color='blue', ax=ax1, label='Series 1')
ax1.set_ylabel('Series 1 (Left Y-Axis)', color='blue')
ax1.tick_params(axis='y', labelcolor='blue')

# Create the second y-axis
ax2 = ax1.twinx()

# Plot the second series on the right y-axis
sns.barplot(x='x', y='series2', data=df, color='red', ax=ax2, alpha=0.6, label='Series 2')
ax2.set_ylabel('Series 2 (Right Y-Axis)', color='red')
ax2.tick_params(axis='y', labelcolor='red')

# Add a legend (collects labelled artists from both axes)
demo_fig.legend(loc='upper left', bbox_to_anchor=(0.1, 0.9))

# Show the plot
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Scratch: dual-axis bar chart with bars offset by half a width, so both series stay visible.
# The figure is named `demo_fig` (not `fig`), so running this cell does not replace the paper
# figure that `savefig(fig=fig, ...)` exports.

# Sample data
x = ['A', 'B', 'C', 'D']
series1 = [10, 20, 30, 40]  # Data for the left y-axis
series2 = [100, 200, 300, 400]  # Data for the right y-axis

# Bar width
bar_width = 0.4

# Create positions for the bars
x_positions = np.arange(len(x))

# Initialize the plot
demo_fig, ax1 = plt.subplots(figsize=(8, 6))

# Plot the first series on the left y-axis
bars1 = ax1.bar(x_positions - bar_width / 2, series1, bar_width, color='blue', label='Series 1')
ax1.set_ylabel('Series 1 (Left Y-Axis)', color='blue')
ax1.tick_params(axis='y', labelcolor='blue')
ax1.set_xticks(x_positions)
ax1.set_xticklabels(x)

# Create the second y-axis
ax2 = ax1.twinx()

# Plot the second series on the right y-axis
bars2 = ax2.bar(x_positions + bar_width / 2, series2, bar_width, color='red', label='Series 2', alpha=0.7)
ax2.set_ylabel('Series 2 (Right Y-Axis)', color='red')
ax2.tick_params(axis='y', labelcolor='red')

# Add a legend with explicit handles from both axes
demo_fig.legend([bars1, bars2], ['Series 1', 'Series 2'], loc='upper left', bbox_to_anchor=(0.1, 0.9))

# Show the plot
plt.show()

In [15]:
# savefig(fig=fig, filename="tfla_mlstm_kernel_benchmark--paper")