In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

from benchplots import *
from benchutils import *
from benchmodels import *
from plotly.subplots import make_subplots

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0


In [2]:
df = pd.read_csv('timings.csv', index_col=0)

# The plotting cell below pivots on (method, n) and reads t_ref / t_annot /
# t_total, so fail early with a readable message if the CSV doesn't fit that.
# If the file ever holds several runs per (method, n), aggregate them first,
# e.g. df.drop(columns="run").groupby(["method", "n"]).mean().reset_index()
_missing_cols = {"method", "n", "t_ref", "t_annot", "t_total"} - set(df.columns)
assert not _missing_cols, f"timings.csv is missing columns: {sorted(_missing_cols)}"
assert not df.duplicated(subset=["method", "n"]).any(), (
    "timings.csv has repeated (method, n) rows; aggregate runs before plotting"
)

In [3]:
# Timing column drawn in the heatmap, and the cell count used for the bar panel.
TIME_TO_SHOW = "t_total"
n0 = 200_000

def fmt_seconds_compact(s: float) -> str:
    """Format a duration in seconds as a short label for heatmap cells and ticks.

    - missing values (None / NaN of any float type / pd.NA) -> ""
    - under 1 second  -> one decimal place, e.g. "0.4s"
    - under 1 minute  -> whole seconds,     e.g. "42s"
    - otherwise       -> whole minutes,     e.g. "12m" (no hours unit)

    Rounding is applied before the unit is chosen, so a value that rounds up
    into the next unit is labelled in that unit (0.97 -> "1s", 59.7 -> "1m").
    """
    # pd.isna also catches np.float32 NaN and pd.NA, which the previous
    # isinstance(s, float) check let through to int(round(nan)) / float(pd.NA).
    if s is None or pd.isna(s):
        return ""
    s = float(s)
    if round(s, 1) < 1:
        return f"{s:.1f}s"
    if round(s) < 60:
        return f"{int(round(s))}s"
    # cap at minutes (no hours)
    return f"{int(round(s / 60.0))}m"


# -----------------------------
# Left panel: heatmap (log10)
# -----------------------------
mat = df.pivot(index="method", columns="n", values=TIME_TO_SHOW)

# sort methods by mean runtime (DESCENDING); mean() skips NaN cells, so a method
# missing the larger n values is ranked only by the n values it has
order = mat.mean(axis=1).sort_values(ascending=False).index
mat = mat.loc[order]

# log10 is undefined for non-positive times; mask them to NaN so the heatmap
# gets a blank cell instead of -inf (and no divide-by-zero warning).
z = np.log10(mat.where(mat > 0).to_numpy())
# otypes=[str] skips numpy's first-element type probe and allows empty input
heat_text = np.vectorize(fmt_seconds_compact, otypes=[str])(mat.values)

x_cells = [str(n) for n in mat.columns.astype(int)]
methods = list(mat.index)

# Colorbar ticks: positions in log-space, labels in real time
tick_seconds = [0.1, 1, 10, 60, 600, 3600]   # 0.1s stands in for "0s" on log scale
tick_vals    = np.log10(tick_seconds)
tick_text    = ["0s", "1s", "10s", "1m", "10m", "1h"]

# ---------------------------------------
# Right panel: stacked bars at n = n0
# ---------------------------------------
d = df[df["n"] == n0].copy()
# Without this check a missing n0 silently produces an all-zero bar panel
# under a subplot title that names a specific n.
if d.empty:
    raise ValueError(
        f"No timings at n={n0}; available n values: {sorted(df['n'].unique().tolist())}"
    )
# one row per method, in the heatmap's order; methods lacking an n0 row become NaN
d = d.set_index("method").reindex(methods)

# missing phase timings are drawn as zero-length bar segments
t_ref   = d["t_ref"].fillna(0.0).to_numpy()
t_annot = d["t_annot"].fillna(0.0).to_numpy()
# where t_total is missing, fall back to the sum of the two phases
t_total = d["t_total"].fillna(d["t_ref"].fillna(0.0) + d["t_annot"].fillna(0.0)).to_numpy()

# 10-minute tick increments on the right panel, from 0 up to the first
# multiple of `step` that is >= the largest total.
# NOTE(review): if every total is under `step`, the autoranged axis shows only "0".
step = 600  # 10 minutes
max_total = float(np.nanmax(t_total)) if len(t_total) else 0.0
tick_max = int(np.ceil(max_total / step) * step) if max_total > 0 else step
tick_seconds_bar = list(range(0, tick_max + step, step))
tick_text_bar = ["0" if t == 0 else fmt_seconds_compact(t) for t in tick_seconds_bar]

# -----------------------------
# Combine into one panel
# -----------------------------
fig = make_subplots(
    rows=1,
    cols=2,
    shared_yaxes=True,
    column_widths=[0.50, 0.50],
    horizontal_spacing=0.06,
    # n in the right title comes from n0 so it cannot drift from the plotted data
    subplot_titles=("Runtime scaling", f"Wall time at n={n0:,}"),
)
fig.update_annotations(yshift=10)  # moves subplot titles up a bit

# Shared coloraxis, with the (horizontal) colorbar placed directly under the
# LEFT panel. Center it on that panel's actual x-domain (paper coords) instead
# of a hand-estimated constant, so it stays aligned if widths/spacing change.
left_x0, left_x1 = fig.layout.xaxis.domain
fig.update_layout(
    coloraxis=dict(
        colorscale="Blues",
        colorbar=dict(
            orientation="h",
            tickmode="array",
            tickvals=tick_vals,
            ticktext=tick_text,
            title="Wall time",
            x=(left_x0 + left_x1) / 2,  # center under left panel
            xanchor="center",
            y=-0.16,            # below plots
            yanchor="top",
            len=0.42,           # width of the bar under left panel
            thickness=14,
        ),
    )
)

# Heatmap (left). z holds log10 seconds, so the hover points at the formatted
# wall-time text rather than the raw log-scaled z value.
fig.add_trace(
    go.Heatmap(
        z=z,
        x=x_cells,
        y=methods,
        text=heat_text,
        texttemplate="%{text}",
        textfont=dict(size=13),
        hovertemplate="%{y}<br>n=%{x}<br>%{text}<extra></extra>",
        coloraxis="coloraxis",
        showscale=True,
    ),
    row=1,
    col=1,
)

# Stacked bars (right) — no inside text
ANNOT_COLOR = "#1f4e79"
REF_COLOR   = "#b0b7c3"

# One horizontal bar trace per phase; with barmode="stack" (set later) they are
# stacked in this order. Trace names match the boxed legend labels so hover
# text and legend agree.
for bar_values, bar_name, bar_color in (
    (t_annot, "Annotate", ANNOT_COLOR),
    (t_ref, "Ref setup", REF_COLOR),
):
    fig.add_trace(
        go.Bar(
            x=bar_values,
            y=methods,
            orientation="h",
            name=bar_name,
            marker=dict(color=bar_color),
        ),
        row=1,
        col=2,
    )

# Axes formatting
# Left panel: cell counts are treated as categories (one column each).
fig.update_xaxes(row=1, col=1, type="category", title_text="Number of cells")

# Right panel: explicit 10-minute ticks, but no hard-coded range so the
# axis still autoranges from zero.
fig.update_xaxes(
    row=1,
    col=2,
    title_text="Wall time",
    rangemode="tozero",
    tickmode="array",
    tickvals=tick_seconds_bar,
    ticktext=tick_text_bar,
)

# No "Method" axis title on either panel; method labels appear on the left only.
fig.update_yaxes(row=1, col=1, title_text="")
fig.update_yaxes(row=1, col=2, title_text="", showticklabels=False)

# Plotly's built-in legend is switched off; a boxed legend is drawn instead.
fig.update_layout(showlegend=False)

# Boxed legend in the top-right of the RIGHT panel: one colored swatch per line.
legend_entries = [(ANNOT_COLOR, "Annotate"), (REF_COLOR, "Ref setup")]
legend_html = "<br>".join(
    f"<span style='color:{swatch}'>■</span> {label}" for swatch, label in legend_entries
)
fig.add_annotation(
    text=legend_html,
    showarrow=False,
    align="left",
    font=dict(size=14),
    # anchored by its top-right corner: x in the right panel's domain, y in paper
    xref="x2 domain",
    x=0.99,
    xanchor="right",
    yref="paper",
    y=0.98,
    yanchor="top",
    # near-opaque white box with a faint border
    bgcolor="rgba(255,255,255,0.95)",
    bordercolor="rgba(0,0,0,0.25)",
    borderwidth=1,
    borderpad=6,
)

# Figure size, stacking, and margins (large bottom margin leaves room for the
# horizontal colorbar placed below the plots).
fig.update_layout(
    width=1250,
    height=800,
    barmode="stack",
    margin=dict(t=90, r=30, b=150, l=190),
)

# Nudge the two subplot titles (the first two layout annotations, created by
# make_subplots) further up; later annotations are left alone.
for subplot_title in fig.layout.annotations[:2]:
    subplot_title.y += 0.02

# Font sizes: method labels on the left y axis; tick and title fonts on both x axes.
fig.update_yaxes(tickfont=dict(size=16), row=1, col=1)
for panel_col in (1, 2):
    fig.update_xaxes(tickfont=dict(size=14), title_font=dict(size=16), row=1, col=panel_col)


add_paper_styling(fig, lines=False)

# Strip the right panel's y-axis decorations after the shared styling pass.
# NOTE(review): presumably add_paper_styling re-enables some of these — confirm.
right_panel_y_reset = dict(
    showticklabels=False,
    tickmode=None,
    tickvals=None,
    ticktext=None,
    tickangle=45,
    showgrid=False,
    zeroline=False,
    ticks="",
    ticklen=0,
)
fig.update_yaxes(row=1, col=2, **right_panel_y_reset)

fig.show()


In [6]:
# Export a high-resolution PNG; scale=3 multiplies the layout's width/height.
# (Static export needs plotly's image engine, e.g. kaleido, installed.)
fig.write_image("runtime.png", scale=3)