In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path
from typing import List, Union

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import process_viztrace_json

In [None]:
# Benchmarks to load; each one was traced `n_runs` times into
# benchmark_jsons/<benchmark_name>_<run>.json.
BENCHMARK_NAMES = [
    "tsfresh_sequential",
    "tsfresh_mp",
    "seglearn",
    "tsfel_sequential",
    "tsfel_mp",
    "tsflex_sequential",
    "tsflex_mp",
]
n_runs = 20
# Max distance for matching samples of different runs in the as-of join.
MERGE_TOLERANCE = pd.Timedelta("100ms")


def _merge_run(
    df_acc: pd.DataFrame, df_run: pd.DataFrame, tolerance=MERGE_TOLERANCE
) -> pd.DataFrame:
    """As-of join one run's trace onto the accumulated multi-run frame.

    The frame whose index ends latest becomes the left side, so its timestamps
    form the merged time axis. Rows of the other frame are matched onto it with
    ``pd.merge_asof`` (default backward direction) within ``tolerance``.
    """
    left, right = sorted(
        [df_acc, df_run], key=lambda df: df.index[-1], reverse=True
    )
    return pd.merge_asof(
        left,
        right,
        left_index=True,
        right_index=True,
        tolerance=tolerance,
    )


benchmark_dict = {}  # benchmark name -> [df_mem, df_cpu] (per-run columns + aggregates)
benchmark_endtimes = {}  # benchmark name -> last memory-trace index of every run
for benchmark_name in BENCHMARK_NAMES:
    df_cpu, df_mem = None, None
    for i in range(n_runs):
        df_new_cpu, df_new_mem = process_viztrace_json(
            f"benchmark_jsons/{benchmark_name}_{i}.json"
        )

        benchmark_endtimes.setdefault(benchmark_name, []).append(df_new_mem.index[-1])

        if df_cpu is None:
            df_cpu = df_new_cpu
        else:
            # Keep only the sampled CPU column, under a run-specific name, so that
            # repeated merges never clash on column names (which would otherwise
            # produce _x/_y suffixes that collide with earlier merges).
            df_new_cpu = df_new_cpu[["args.cpu_percent"]].rename(
                columns={"args.cpu_percent": f"args.cpu_percent_{i}"}
            )
            df_cpu = _merge_run(df_cpu, df_new_cpu)

        if df_mem is None:
            df_mem = df_new_mem
        else:
            # Same idea for memory: keep the "args." columns, renamed per run.
            df_new_mem = df_new_mem.filter(like="args.").rename(
                columns={"args.rss": f"args.rss{i}", "args.vms": f"args.vms{i}"}
            )
            df_mem = _merge_run(df_mem, df_new_mem)

    # mem usage, aggregated over all runs' rss columns
    rss_runs = df_mem.filter(like="args.rss")
    df_mem["mean_rss"] = rss_runs.mean(axis=1)
    df_mem["std_rss"] = rss_runs.std(axis=1)
    df_mem["max_rss"] = rss_runs.max(axis=1)

    # cpu usage, aggregated over all runs' cpu_percent columns
    cpu_runs = df_cpu.filter(like="args.cpu_percent")
    df_cpu["mean_usage"] = cpu_runs.mean(axis=1)
    df_cpu["std"] = cpu_runs.std(axis=1)

    benchmark_dict[benchmark_name] = [df_mem, df_cpu]

In [None]:
# benchmark_dict["tsfel_sequential"][0].filter(like="args.rss").mean(
#     axis=1
# ) + benchmark_dict["tsfel_sequential"][0].filter(like="args.rss").rolling(
#     "1s"
# ).std().mean(
#     axis=1
# )

# np.max((benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').mean(axis=1) + benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').rolling('1s').std().mean(axis=1), benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').max(axis=1)), axis=0)

# benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').mean(axis=1).plot()
# (benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').mean(axis=1) +
# (benchmark_dict['tsfel_sequential'][0].filter(like='args.rss').mean(axis=1)

In [None]:
# Per-run end times, one column per benchmark, converted from Timedelta to seconds.
df_endtime = pd.DataFrame(benchmark_endtimes).apply(
    lambda run_ends: run_ends.dt.total_seconds()
)
df_endtime.info()

In [None]:
# Summary statistics of the end times per benchmark, fastest (lowest mean) first.
(
    pd.DataFrame(benchmark_endtimes)
    .describe()
    .transpose()
    .sort_values(by="mean")
)

In [None]:
def _hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Turn a ``#rrggbb`` string into a Plotly ``rgba(r, g, b, alpha)`` string."""
    hex_digits = hex_color.lstrip("#")
    red, green, blue = (int(hex_digits[k : k + 2], 16) for k in (0, 2, 4))
    return f"rgba({red}, {green}, {blue}, {alpha})"


def _scale_toggle_menu(axis: str, label_prefix: str, title: str, y: float) -> dict:
    """Build a dropdown menu with "Log" / "Linear" buttons for ``axis``.

    Each button calls ``relayout`` with ``{axis: {"type": <scale>, "title": title}}``.
    """
    return dict(
        buttons=[
            dict(
                args=[{axis: {"type": scale, "title": title}}],
                label=f"{label_prefix}: {scale_label}",
                method="relayout",
            )
            for scale, scale_label in (("log", "Log"), ("linear", "Linear"))
        ],
        direction="down",
        showactive=True,
        pad={"r": 10, "t": 10},
        yanchor="top",
        y=y,
    )


# Log-axis tick positions: 1-3 steps per decade on y, 1-2-5 steps per decade on x.
y_tickvals = [i * (10 ** pw) for pw in range(-1, 4) for i in (1, 3)]
x_tickvals = [i * (10 ** pw) for pw in range(-2, 3) for i in (1, 2, 5)]

colors = [
    "#1f78b4",
    "#33a02c",
    "#6a3d9a",
    "#f0027f",
    "#bf5b17",
    "#666666",
]

fig = make_subplots(
    shared_xaxes=True,
    subplot_titles=[
        f"Strided-window feature extraction - averaged over {n_runs} runs"
    ],
)
fig.update_layout(height=500)
fig.update_yaxes(type="log", tickvals=y_tickvals, title_text="Memory usage (MB)")
fig.update_xaxes(type="log", tickvals=x_tickvals, title_text="Runtime (s)")

library_list = []  # libraries in first-seen order; position -> colour index
for benchmark_name, (df_mem, _df_cpu) in benchmark_dict.items():
    library = benchmark_name.split("_")[0]
    if library not in library_list:
        library_list.append(library)

    # One colour per library (wrapping if there are more libraries than colours);
    # multiprocessing variants are drawn dashed, sequential ones solid.
    color = colors[library_list.index(library) % len(colors)]
    line_kwargs = {"line_dash": "dash"} if "_mp" in benchmark_name.lower() else {}
    runtime_s = df_mem.index.total_seconds()

    # Mean RSS over all runs.
    fig.add_trace(
        go.Scatter(
            x=runtime_s,
            y=df_mem["mean_rss"] / 1e6,
            name=benchmark_name,
            legendgroup=library,
            line_color=color,
            **line_kwargs,
        ),
        row=1,
        col=1,
    )

    # Shaded band from the mean trace (added just above) up to the max RSS.
    # It shares the library's legend group so it toggles together with it.
    fig.add_trace(
        go.Scatter(
            name="upper memory bound",
            x=runtime_s,
            y=df_mem["max_rss"] / 1e6,
            marker=dict(color="#444"),
            line=dict(width=0),
            mode="lines",
            fillcolor=_hex_to_rgba(color, 0.1),
            fill="tonexty",
            showlegend=False,
            legendgroup=library,
        ),
        row=1,
        col=1,
    )


updatemenus = [
    _scale_toggle_menu("yaxis", "Y-axis", "Memory usage (MB)", y=0.9),
    _scale_toggle_menu("xaxis", "X-axis", "Runtime (s)", y=1.05),
]

fig.update_layout(updatemenus=updatemenus)
fig.show()

In [None]:
def figs_to_html(
    figs: "List[go.Figure]",
    html_path: Union[Path, str],
    append=False,
    include_plotlyjs=True,
):
    """Save a list of figures in a single HTML file.

    Every figure is written as an HTML fragment (``full_html=False``), and the
    fragments are concatenated in the given order. The file is written as UTF-8.

    :param figs: A list of plotly figures (any object with a compatible
        ``to_html(full_html=..., include_plotlyjs=...)`` method works).
    :param html_path: the HTML path where the figures will be saved; missing
        parent directories are created.
    :param append: if True, append to an existing file instead of overwriting it.
    :param include_plotlyjs: forwarded to ``to_html``. When it is ``True``
        (inline the full plotly.js bundle), only the first figure of this call
        embeds the bundle; the following figures reuse it and are written with
        ``include_plotlyjs=False``. Any other value is passed to every figure.
    """
    html_path = Path(html_path)
    html_path.parent.mkdir(parents=True, exist_ok=True)

    with open(html_path, "a" if append else "w", encoding="utf-8") as f:
        for idx, fig in enumerate(figs):
            # Inlining the (multi-MB) plotly.js bundle once is enough: later
            # fragments in the same page can use the already-loaded library.
            fig_plotlyjs = (
                False if include_plotlyjs is True and idx > 0 else include_plotlyjs
            )
            f.write(fig.to_html(full_html=False, include_plotlyjs=fig_plotlyjs))


# Export the memory-usage figure (with plotly.js inlined) to a standalone HTML file.
figs_to_html(figs=[fig], html_path="benchmark.html", include_plotlyjs=True)