In [None]:
import os
from collections import defaultdict
from typing import Literal

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from panda.utils.plot_utils import apply_custom_style, make_box_plot

In [None]:
# Apply the project's plotting style; the YAML path is relative to the
# notebook's working directory.
apply_custom_style("../../config/plotting.yaml")

In [None]:
# Output directory for evaluation-metric figures; created if it does not exist.
# NOTE(review): the cells below save under "scalinglaw_figs/" and never reference
# figs_save_dir -- confirm which location is intended.
figs_save_dir = os.path.join("../../figures", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
# Root working directory, read from the WORK environment variable.
# When WORK is unset this falls back to "", so every derived path becomes
# relative to the current working directory.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Powers of two from 2**0 up to 2**7.
scalinglaw_splits = [2**exponent for exponent in range(8)]

In [None]:
# Display the split sizes defined above.
scalinglaw_splits

In [None]:
# Evaluation split whose metric CSVs are read below.
# Alternative split kept for reference: "final_skew40/test_zeroshot"
data_split = "test_zeroshot"

# Short run label -> run name (used as a directory component below).
# Alternative "ic1" run kept for reference: "pft_chattn_noembed_pretrained_correct-0"
run_names_chattn = dict(
    ic1="pft_chattn_mlm_sys20k_ic1-0",
    ic2="pft_chattn_mlm_sys10490_ic2-0",
    ic4="pft_chattn_mlm_sys5245_ic4-0",
    ic8="pft_chattn_mlm_sys2623_ic8-0",
    ic16="pft_chattn_mlm_sys1312_ic16-0",
    ic32="pft_chattn_mlm_sys656_ic32-0",
    ic64="pft_chattn_mlm_sys328_ic64-0",
    ic128="pft_chattn_mlm_sys164_ic128-0",
)

# Run group -> {short run label -> directory holding that run's metric CSVs}.
run_metrics_dirs_all_groups = {
    "chattn": {
        abbrv: os.path.join(WORK_DIR, "eval_results", "patchtst", name, data_split)
        for abbrv, name in run_names_chattn.items()
    },
}

In [None]:
# Display the resolved results directory for every run.
run_metrics_dirs_all_groups

In [None]:
# Load every metrics CSV found in each run's results directory.
# Layout: metrics_all[run_group][run_abbrv][prediction_length] -> DataFrame.to_dict() of that CSV.
metrics_all = defaultdict(lambda: defaultdict(dict))
for run_group, run_metrics_dir_dict in run_metrics_dirs_all_groups.items():
    print(f"Run group: {run_group}")
    for run_abbrv, run_metrics_dir in run_metrics_dir_dict.items():
        if not os.path.exists(run_metrics_dir):
            print(f"Run metrics directory does not exist for {run_abbrv}: {run_metrics_dir}")
            continue
        run_abbrv = str(run_abbrv)
        print(f"{run_abbrv}: {run_metrics_dir}")

        # Select the metrics files *before* parsing their names. Sorting the raw
        # directory listing with a name-parsing key crashes on any entry that does
        # not look like "<prefix>_pred<int>.csv" (hidden files, logs, ...).
        metrics_files = []
        for file in os.listdir(run_metrics_dir):
            if not file.endswith(".csv") or "_pred" not in file:
                continue
            pred_length_str = file.split("_pred")[1].split(".csv")[0]
            if not pred_length_str.isdigit():
                continue
            metrics_files.append((int(pred_length_str), file))

        # Ascending prediction length, so the per-run dicts are ordered by horizon.
        for prediction_length, file in sorted(metrics_files):
            with open(os.path.join(run_metrics_dir, file)) as f:
                metrics = pd.read_csv(f).to_dict()
            metrics_all[run_group][run_abbrv][prediction_length] = metrics

In [None]:
# Run groups for which metrics were loaded.
metrics_all.keys()

In [None]:
# Convert each loaded table from DataFrame.to_dict() form ({column: {row: value}})
# into plain per-column lists, dropping the "system" identifier column.
# Layout: unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length][metric] -> list
unrolled_metrics_all_groups = defaultdict(lambda: defaultdict(dict))
for run_group, all_metrics_of_run_group in metrics_all.items():
    for run_abbrv, all_metrics_of_run_abbrv in all_metrics_of_run_group.items():
        for prediction_length, metrics in all_metrics_of_run_abbrv.items():
            # Filter instead of metrics.pop("system"): popping mutated metrics_all in
            # place, so re-running this cell raised KeyError on the second execution.
            unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length] = {
                column: list(values_by_row.values())
                for column, values_by_row in metrics.items()
                if column != "system"
            }

In [None]:
# Shallow copy of the "chattn" group: {run_abbrv: {prediction_length: {metric: values}}}.
unrolled_metrics_all_combined = dict(unrolled_metrics_all_groups["chattn"])

In [None]:
def get_summary_metrics_dict(unrolled_metrics, metric_name):
    """Summarise one metric per model and prediction length, ignoring NaNs.

    Args:
        unrolled_metrics: ``{model_name: {prediction_length: {metric_name: values}}}``.
        metric_name: which metric to read from each per-prediction-length dict.

    Returns:
        ``defaultdict(dict)`` mapping each model name to a dict with the keys
        ``"prediction_lengths"``, ``"means"``, ``"medians"`` and ``"stds"``; the
        three statistic lists are aligned with ``"prediction_lengths"``.
    """
    summary = defaultdict(dict)
    for model_name, per_length_metrics in unrolled_metrics.items():
        lengths = list(per_length_metrics.keys())
        summary[model_name]["prediction_lengths"] = lengths
        samples = [per_length_metrics[length][metric_name] for length in lengths]
        summary[model_name]["means"] = [np.nanmean(vals) for vals in samples]
        summary[model_name]["medians"] = [np.nanmedian(vals) for vals in samples]
        summary[model_name]["stds"] = [np.nanstd(vals) for vals in samples]
    return summary

In [None]:
def plot_metrics_by_prediction_length(metrics_dict, metric_name, show_std_envelope=False):
    """Plot the per-model median of a metric against prediction length.

    Args:
        metrics_dict: ``{model_name: summary}`` where each summary provides the
            keys ``"prediction_lengths"``, ``"medians"``, ``"means"`` and ``"stds"``.
        metric_name: used as the figure title.
        show_std_envelope: when true, also shade ``means +/- stds`` for each model.
    """
    plt.figure(figsize=(5, 4))
    for model_name, summary in metrics_dict.items():
        lengths = summary["prediction_lengths"]
        plt.plot(lengths, summary["medians"], marker="o", label=model_name)
        spread = np.array(summary["stds"])
        if not show_std_envelope:
            continue
        # The envelope is centred on the mean although the line shows the median.
        lower = summary["means"] - spread
        upper = summary["means"] + spread
        plt.fill_between(lengths, lower, upper, alpha=0.2)
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# Run groups configured above.
run_metrics_dirs_all_groups.keys()

In [None]:
# Metrics summarised for every run group below.
metric_names_chosen = ["mse", "mae", "smape", "spearman"]

In [None]:
# all_metrics_dict[run_group][metric_name] -> per-run summary statistics
# (output of get_summary_metrics_dict).
all_metrics_dict = defaultdict(dict)

for run_group in run_metrics_dirs_all_groups:
    all_metrics_dict[run_group] = dict(
        (name, get_summary_metrics_dict(unrolled_metrics_all_groups[run_group], name))
        for name in metric_names_chosen
    )

In [None]:
# Colors of matplotlib's tab10 colormap.
# NOTE(review): not referenced by any later cell visible here.
default_colors = plt.cm.tab10.colors

In [None]:
# Run labels that have an MSE summary in the "chattn" group.
all_metrics_dict["chattn"]["mse"].keys()

In [None]:
# Run labels with unrolled metrics in the "chattn" group.
unrolled_metrics_all_groups["chattn"].keys()

In [None]:
# Run labels in the combined mapping used for plotting.
unrolled_metrics_all_combined.keys()

In [None]:
# Number of runs that will be plotted (one color per run below).
n_runs = len(unrolled_metrics_all_combined)
print(n_runs)

In [None]:
# One RGBA color per run, sampled from the "Blues" colormap at positions 1.0 down to 0.1.
bar_colors = plt.cm.Blues(np.linspace(1.0, 0.1, n_runs)).tolist()
print(len(bar_colors))

In [None]:
# Prediction length of interest.
# NOTE(review): this name is reassigned to 128 in the scaling-law fitting cell
# further down, before the make_box_plot call that reads it.
selected_pred_length = 512

In [None]:
# Run labels available for plotting.
unrolled_metrics_all_combined.keys()

In [None]:
# Prediction lengths available for the "ic2" run.
unrolled_metrics_all_combined["ic2"].keys()

In [None]:
# Metric names available for the "ic2" run at prediction length 128.
unrolled_metrics_all_combined["ic2"][128].keys()

In [None]:
# Run label -> number of systems, used as the x-coordinate in the scaling plots.
ic_to_n_systems = dict(
    ic1=20979,
    ic2=10490,
    ic4=5245,
    ic8=2623,
    ic16=1312,
    ic32=656,
    ic64=328,
    ic128=164,
)

In [None]:
def make_scaling_plot(
    unrolled_metrics: dict,
    prediction_length: int = 128,
    metric_to_plot: str = "smape",
    stat_to_plot: Literal["median", "mean"] = "median",
    colormap: str = "Blues",
    legend_kwargs: dict | None = None,
    figsize: tuple = (4, 4),
    save_path: str | None = None,
    use_inv_spearman: bool = True,
    show_legend: bool = True,
    title: str | None = None,
) -> None:
    """Plot one metric against the number of systems at a fixed prediction length.

    One box (``stat_to_plot="median"``) or one marker with error bar
    (``stat_to_plot="mean"``) is drawn per entry of ``unrolled_metrics``, placed on a
    base-2 logarithmic x-axis at the system count looked up in the module-level
    ``ic_to_n_systems`` mapping. The figure is shown with ``plt.show()``.

    Args:
        unrolled_metrics: ``{ic_split: {prediction_length: {metric_name: values}}}``.
            Every ``ic_split`` key must be present in ``ic_to_n_systems``.
        prediction_length: key selecting which per-split metrics are plotted.
        metric_to_plot: metric name read from the per-prediction-length dict.
        stat_to_plot: ``"median"`` draws a 40-60 percentile box with 25-75 percentile
            whiskers and a median line; ``"mean"`` draws the mean with standard-error bars.
        colormap: matplotlib colormap name, sampled evenly with one color per split.
        legend_kwargs: forwarded to ``plt.legend`` when ``show_legend`` is true.
        figsize: figure size passed to ``plt.figure``.
        save_path: if given, the figure is saved there (parent directory is created).
        use_inv_spearman: when plotting ``"spearman"``, plot ``1 - value`` instead.
        show_legend: whether to draw the legend.
        title: optional figure title.

    Raises:
        ValueError: if ``stat_to_plot`` is neither ``"median"`` nor ``"mean"``.
        KeyError: if a split, prediction length or metric name is missing.
    """
    # Fail before any figure is created.
    if stat_to_plot not in ("median", "mean"):
        raise ValueError(f"Invalid stat_to_plot: {stat_to_plot}")
    legend_kwargs = {} if legend_kwargs is None else legend_kwargs

    invert_spearman = metric_to_plot == "spearman" and use_inv_spearman
    if metric_to_plot == "smape":
        metric_to_plot_title = "sMAPE"
    elif invert_spearman:
        metric_to_plot_title = "1 - Spearman"
    else:
        metric_to_plot_title = metric_to_plot.upper()

    # Raw metric values keyed by number of systems, in ascending order.
    metric_at_predlength = {
        ic_to_n_systems[ic_split]: metrics_by_predlength_dict[prediction_length][metric_to_plot]
        for ic_split, metrics_by_predlength_dict in unrolled_metrics.items()
    }
    metric_at_predlength = dict(sorted(metric_at_predlength.items()))

    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    colors = plt.get_cmap(colormap)(np.linspace(0, 1.0, len(metric_at_predlength)))
    plt.figure(figsize=figsize)
    for i, (n_systems, metric_vals) in enumerate(metric_at_predlength.items()):
        metric_vals = np.asarray(metric_vals, dtype=float)
        metric_vals = metric_vals[~np.isnan(metric_vals)]
        if metric_vals.size == 0:
            # Nothing to summarise for this split; np.percentile would raise on an
            # empty array and the mean would be NaN.
            continue
        if invert_spearman:
            metric_vals = 1 - metric_vals

        if stat_to_plot == "median":
            median_val = np.median(metric_vals)
            box_percentile_range = (40, 60)
            whisker_percentile_range = (25, 75)
            # Width proportional to x so that boxes look equally wide on the log axis.
            box_width = 0.5 * n_systems
            alpha_val = 0.8

            lower_box, upper_box = np.percentile(metric_vals, box_percentile_range)
            lower_whisker, upper_whisker = np.percentile(metric_vals, whisker_percentile_range)

            box_half_width = box_width / 2
            whisker_cap_width = box_half_width * 0.5

            # Box spanning the inner percentile range.
            box = plt.Rectangle(
                (n_systems - box_half_width, lower_box),
                box_width,
                upper_box - lower_box,
                fill=True,
                facecolor=colors[i],
                alpha=alpha_val,
                linewidth=1,
                edgecolor="black",
                zorder=5,
                label=rf"$N_{{sys}}={n_systems}$",
            )
            plt.gca().add_patch(box)

            # Median line, drawn above the box.
            plt.hlines(
                median_val,
                n_systems - box_half_width,
                n_systems + box_half_width,
                colors="black",
                linewidth=2.5,
                zorder=10,
            )

            # Whiskers from the box edges out to the whisker percentiles, each with a cap.
            for box_edge, whisker_end in ((lower_box, lower_whisker), (upper_box, upper_whisker)):
                plt.vlines(
                    n_systems,
                    box_edge,
                    whisker_end,
                    colors="black",
                    linestyle="-",
                    linewidth=1,
                    zorder=5,
                )
                plt.hlines(
                    whisker_end,
                    n_systems - whisker_cap_width,
                    n_systems + whisker_cap_width,
                    colors="black",
                    linewidth=1,
                    zorder=5,
                )
        else:  # stat_to_plot == "mean"
            mean_val = np.mean(metric_vals)
            ste_val = np.std(metric_vals) / np.sqrt(len(metric_vals))

            plt.scatter(
                n_systems,
                mean_val,
                s=36,  # equivalent to markersize=6 squared
                edgecolors="black",
                linewidths=0.2,
                label=rf"$N_{{sys}}={n_systems}$",
                color=colors[i],
            )
            plt.errorbar(
                n_systems,
                mean_val,
                yerr=ste_val,
                fmt="none",
                color=colors[i],
                capsize=5,  # Add T-shaped caps to the error bars
            )

    if show_legend:
        plt.legend(**legend_kwargs)
    if title is not None:
        plt.title(title, fontweight="bold")
    plt.xlabel("Number of Systems", fontweight="bold")
    plt.ylabel(metric_to_plot_title, fontweight="bold")
    plt.xscale("log", base=2)
    plt.tight_layout()
    if save_path is not None:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Median box plot of sMAPE vs. number of systems at prediction length 512.
# show_legend=False, so legend_kwargs is passed but never reaches plt.legend.
metric_to_plot = "smape"
prediction_length = 512
stat_to_plot = "median"
make_scaling_plot(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    stat_to_plot=stat_to_plot,
    prediction_length=prediction_length,
    colormap="cividis_r",
    show_legend=False,
    title=rf"$L_{{pred}}={prediction_length}$",
    legend_kwargs={"loc": "upper right", "frameon": True, "ncol": 1, "fontsize": 8},
    save_path=f"scalinglaw_figs/{metric_to_plot}_{prediction_length}_{stat_to_plot}.pdf",
)

In [None]:
# Mean (with standard-error bars) of sMAPE vs. number of systems at prediction length 512.
# show_legend=False, so legend_kwargs is passed but never reaches plt.legend.
metric_to_plot = "smape"
prediction_length = 512
stat_to_plot = "mean"
make_scaling_plot(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    stat_to_plot=stat_to_plot,
    prediction_length=prediction_length,
    colormap="cividis_r",
    show_legend=False,
    title=rf"$L_{{pred}}={prediction_length}$",
    legend_kwargs={"loc": "upper right", "frameon": True, "ncol": 1, "fontsize": 8},
    save_path=f"scalinglaw_figs/{metric_to_plot}_{prediction_length}_{stat_to_plot}.pdf",
)

In [None]:
# Mean (with standard-error bars) of sMAPE vs. number of systems at prediction length 256.
# show_legend=False, so legend_kwargs is passed but never reaches plt.legend.
metric_to_plot = "smape"
prediction_length = 256
stat_to_plot = "mean"
make_scaling_plot(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    stat_to_plot=stat_to_plot,
    prediction_length=prediction_length,
    colormap="cividis_r",
    show_legend=False,
    title=rf"$L_{{pred}}={prediction_length}$",
    legend_kwargs={"loc": "upper right", "frameon": True, "ncol": 1, "fontsize": 8},
    save_path=f"scalinglaw_figs/{metric_to_plot}_{prediction_length}_{stat_to_plot}.pdf",
)

In [None]:
def make_scaling_plot_v2(
    unrolled_metrics: dict,
    prediction_lengths: list[int] | None = None,
    metric_to_plot: str = "smape",
    colormap: str = "Blues",
    legend_kwargs: dict | None = None,
    figsize: tuple = (4, 4),
    save_path: str | None = None,
    use_inv_spearman: bool = True,
    show_legend: bool = True,
    ylim: tuple | None = None,
    alpha_val: float = 0.8,
    markersize: float = 5,
) -> tuple[dict[int, dict[int, float]], dict[int, dict[int, float]], list[mlines.Line2D]]:
    """Plot the mean of a metric vs. number of systems, one line per prediction length.

    System counts are looked up in the module-level ``ic_to_n_systems`` mapping and
    plotted on a base-2 logarithmic x-axis; a shaded band shows mean +/- standard error.
    NaN metric values are dropped before the statistics are computed.

    Args:
        unrolled_metrics: ``{ic_split: {prediction_length: {metric_name: values}}}``.
            Every ``ic_split`` key must be present in ``ic_to_n_systems``.
        prediction_lengths: prediction lengths to plot; defaults to ``[128, 256, 512]``.
        metric_to_plot: metric name read from the per-prediction-length dict.
        colormap: matplotlib colormap name, sampled on [0, 0.9], one color per line.
        legend_kwargs: forwarded to ``plt.legend`` when ``show_legend`` is true.
        figsize: figure size passed to ``plt.figure``.
        save_path: if given, the figure is saved there (parent directory is created).
        use_inv_spearman: when plotting ``"spearman"``, plot ``1 - value`` instead.
        show_legend: whether to draw the legend on this figure.
        ylim: optional y-axis limits.
        alpha_val: opacity of the lines and markers.
        markersize: marker size of the lines.

    Returns:
        ``(mean_vals_dict, ste_vals_dict, legend_handles)``. Both dicts map
        ``prediction_length -> {n_systems: value}`` and are sorted ascending at both
        levels. ``legend_handles`` is a list of proxy ``Line2D`` artists, one per
        prediction length in the same order as the dict keys, regardless of
        ``show_legend`` (so it can always be reused for a standalone legend).

    Raises:
        KeyError: if a split, prediction length or metric name is missing.
    """
    if prediction_lengths is None:
        prediction_lengths = [128, 256, 512]
    legend_kwargs = {} if legend_kwargs is None else legend_kwargs

    invert_spearman = metric_to_plot == "spearman" and use_inv_spearman
    if metric_to_plot == "smape":
        metric_to_plot_title = "sMAPE"
    elif invert_spearman:
        metric_to_plot_title = "1 - Spearman"
    else:
        metric_to_plot_title = metric_to_plot.upper()

    mean_vals_dict = defaultdict(dict)
    ste_vals_dict = defaultdict(dict)
    for ic_split, metrics_by_predlength_dict in unrolled_metrics.items():
        n_systems = int(ic_to_n_systems[ic_split])
        for prediction_length in prediction_lengths:
            metric_vals = np.asarray(metrics_by_predlength_dict[prediction_length][metric_to_plot], dtype=float)
            metric_vals = metric_vals[~np.isnan(metric_vals)]
            if invert_spearman:
                metric_vals = 1 - metric_vals
            if metric_vals.size == 0:
                # No valid samples: record NaN explicitly instead of triggering
                # empty-slice / divide-by-zero warnings.
                mean_val = ste_val = np.nan
            else:
                mean_val = np.mean(metric_vals)
                ste_val = np.std(metric_vals) / np.sqrt(metric_vals.size)
            mean_vals_dict[prediction_length][n_systems] = mean_val
            ste_vals_dict[prediction_length][n_systems] = ste_val

    # Sort by prediction length, and within each prediction length by number of
    # systems, so lines are drawn left to right and both dicts stay aligned.
    mean_vals_dict = {pl: dict(sorted(by_n.items())) for pl, by_n in sorted(mean_vals_dict.items())}
    ste_vals_dict = {pl: dict(sorted(by_n.items())) for pl, by_n in sorted(ste_vals_dict.items())}

    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    colors = plt.get_cmap(colormap)(np.linspace(0, 0.9, len(mean_vals_dict)))
    plt.figure(figsize=figsize)
    legend_handles = []
    for i, (prediction_length, metrics_dict_by_n_systems) in enumerate(mean_vals_dict.items()):
        n_systems = list(metrics_dict_by_n_systems.keys())
        mean_vals = np.array(list(metrics_dict_by_n_systems.values()))
        ste_vals = np.array(list(ste_vals_dict[prediction_length].values()))
        line_label = rf"$L_{{pred}}={prediction_length}$"
        plt.plot(
            n_systems,
            mean_vals,
            marker="o",
            markersize=markersize,
            linestyle="-",
            label=line_label,
            color=colors[i],
            alpha=alpha_val,
        )
        plt.fill_between(
            n_systems,
            mean_vals - ste_vals,
            mean_vals + ste_vals,
            alpha=0.2,
            color=colors[i],
        )
        # Proxy artist mirroring the plotted line, for use in a separate legend figure.
        legend_handles.append(
            mlines.Line2D(
                [0],
                [0],
                color=colors[i],
                marker="o",
                markersize=markersize,
                linestyle="-",
                alpha=alpha_val,
                label=line_label,
            )
        )

    if show_legend:
        plt.legend(**legend_kwargs)

    plt.xlabel("Number of Systems", fontweight="bold")
    plt.ylabel(metric_to_plot_title, fontweight="bold")
    plt.xscale("log", base=2)
    if ylim is not None:
        plt.ylim(ylim)
    plt.tight_layout()
    if save_path is not None:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
    return mean_vals_dict, ste_vals_dict, legend_handles

In [None]:
# Prediction lengths available for the "ic2" run, in the order they were loaded.
# NOTE(review): presumably every other run has the same prediction lengths -- confirm.
all_pred_lengths = list(unrolled_metrics_all_combined["ic2"].keys())
print(all_pred_lengths)

In [None]:
# Combined plot: mean sMAPE vs. number of systems, one line per prediction length.
# The returned dicts and legend handles are reused by the fitting cells below.
metric_to_plot = "smape"
# NOTE(review): prediction_lengths and stat_to_plot are assigned here but not passed
# to the call below, which uses all_pred_lengths instead.
prediction_lengths = [128, 256, 512]
stat_to_plot = "mean"

mean_vals_dict, ste_vals_dict, legend_handles = make_scaling_plot_v2(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    prediction_lengths=all_pred_lengths,
    colormap="cividis",
    show_legend=False,
    figsize=(4, 4),
    alpha_val=1.0,
    markersize=4,
    # legend_kwargs={"loc": "lower center", "frameon": True, "ncol": 4, "fontsize": 5},
    save_path=f"scalinglaw_figs/{metric_to_plot}_combined.pdf",
)

In [None]:
# Inspect the legend handles returned by make_scaling_plot_v2.
legend_handles

In [None]:
# Standalone figure that contains only the legend, laid out in 4 columns,
# built from the handles returned by make_scaling_plot_v2.
plt.figure(figsize=(6, 1))

# Add the legend with the combined handles
legend = plt.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=4,
    framealpha=1.0,
    fontsize=16,
)

# Hide the axis ticks so only the legend is visible.
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=0)
plt.savefig("scalinglaw_figs/scalinglaw_legend_horizontal.pdf", bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Standalone figure that contains only the legend, laid out in a single column,
# built from the handles returned by make_scaling_plot_v2.
plt.figure(figsize=(2, 3))

# Add the legend with the combined handles
legend = plt.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=1,
    framealpha=1.0,
    fontsize=16,
)

# Hide the axis ticks so only the legend is visible.
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=0)
plt.savefig("scalinglaw_figs/scalinglaw_legend_vertical.pdf", bbox_inches="tight")
plt.show()
plt.close()

### Attempt to fit scaling law

In [None]:
# Prediction lengths for which mean values were computed.
mean_vals_dict.keys()

In [None]:
def scaling_law(num_systems, c0, c1, alpha):
    """Evaluate the power law ``c0 + c1 * num_systems ** (-alpha)``.

    Args:
        num_systems: scalar or array-like (list, tuple, ndarray) of system counts.
        c0: constant offset.
        c1: scale of the power-law term.
        alpha: decay exponent (applied with a negative sign).

    Returns:
        A NumPy float for scalar input, otherwise an ndarray with the shape of
        ``num_systems``.
    """
    # Coerce to a float array so plain lists work even when the parameters are
    # ordinary Python floats (list ** float raises TypeError).
    num_systems = np.asarray(num_systems, dtype=float)
    return c0 + c1 * num_systems ** (-alpha)

In [None]:
from scipy.optimize import curve_fit

In [None]:
selected_pred_length = 128

# Observations to fit: mean metric value as a function of the number of systems.
fit_n_systems = np.array(list(mean_vals_dict[selected_pred_length].keys()), dtype=float)
fit_mean_vals = np.array(list(mean_vals_dict[selected_pred_length].values()), dtype=float)

# Curve fitting; all three parameters are constrained to be non-negative.
initial_guess = [1.0, 1.0, 1.0]  # Initial parameter guess [c0, c1, alpha]
params, pcov = curve_fit(
    scaling_law,
    fit_n_systems,
    fit_mean_vals,
    p0=initial_guess,
    bounds=([0, 0, 0], [np.inf, np.inf, np.inf]),
    maxfev=10000,
)

# Extract the fitted parameters; the errors are the square roots of the
# diagonal of the returned covariance matrix.
c0, c1, alpha = params
param_errors = np.sqrt(np.diag(pcov))

# Coefficient of determination of the fit on the fitted points.
mean_vals_pred = scaling_law(fit_n_systems, c0, c1, alpha)
print(f"mean_vals_pred shape: {mean_vals_pred.shape}")
ss_tot = np.sum((fit_mean_vals - np.mean(fit_mean_vals)) ** 2)
ss_res = np.sum((fit_mean_vals - mean_vals_pred) ** 2)
r_squared = 1 - (ss_res / ss_tot)
print(f"R² = {r_squared:.4f}")

In [None]:
# Plot the scaling law fit for selected_pred_length
plt.figure(figsize=(3, 4))

fit_points = mean_vals_dict[selected_pred_length]
fit_x = list(fit_points.keys())
fit_y = list(fit_points.values())
# legend_handles follows the key order of mean_vals_dict, so look the color up by
# the position of selected_pred_length instead of hard-coding an index.
fit_color = legend_handles[list(mean_vals_dict.keys()).index(selected_pred_length)].get_color()

# Plot data and fitted curve
plt.scatter(
    fit_x,
    fit_y,
    color=fit_color,
    marker="o",
    label=rf"$L_{{pred}}={selected_pred_length}$",
)
plt.errorbar(
    fit_x,
    fit_y,
    yerr=list(ste_vals_dict[selected_pred_length].values()),
    fmt="none",
    ecolor=fit_color,
    capsize=2,
    alpha=1.0,
    elinewidth=1,
)
# Generate smooth curve, log-spaced between the smallest and largest system count
x_smooth = np.logspace(
    np.log10(min(fit_x)),
    np.log10(max(fit_x)),
    100,
)
y_smooth = scaling_law(x_smooth, c0, c1, alpha)
plt.plot(
    x_smooth,
    y_smooth,
    color=fit_color,
    linestyle="-",
    label="Fitted Curve",
)

# Set labels and title with scaling law formula
plt.xlabel("Number of Systems", fontweight="bold")
# plt.ylabel("sMAPE", fontweight="bold")
plt.title(
    rf"$\mathbb{{E}}[\mathrm{{sMAPE}}] = {c0:.2f} + {c1:.2f} \cdot N_{{sys}}^{{-{alpha:.4f}}}$",
    fontsize=10,
)
plt.xscale("log", base=2)
# plt.grid(True, alpha=0.3)
plt.legend(frameon=True, fontsize=8, loc="upper right")

# Add R² value
plt.text(
    0.05,
    0.05,
    f"R² = {r_squared:.4f}",
    transform=plt.gca().transAxes,
    fontsize=10,
    bbox=dict(facecolor="white", alpha=0.8),
    ha="left",
    va="bottom",
)

plt.tight_layout()
plt.savefig(
    f"scalinglaw_figs/{metric_to_plot}_{selected_pred_length}_fit.pdf",
    bbox_inches="tight",
)
plt.show()

### Fit scaling laws on all prediction lengths sequentially

In [None]:
from tqdm import tqdm

In [None]:
# Fit the three-parameter scaling law independently for every prediction length.
# Result: {pred_length: {"params": [c0, c1, alpha], "param_errors": ..., "r_squared": ...}}
scaling_law_params_by_predlength = {}

for pred_length, mean_vals_dict_by_predlength in tqdm(mean_vals_dict.items(), desc="Fitting scaling laws"):
    print(f"Fitting scaling law for prediction length {pred_length}")

    # Observations for this prediction length, converted once.
    n_systems_obs = np.array(list(mean_vals_dict_by_predlength.keys()), dtype=float)
    mean_vals_obs = np.array(list(mean_vals_dict_by_predlength.values()), dtype=float)

    # Curve fitting; all three parameters are constrained to be non-negative.
    initial_guess = [1.0, 1.0, 1.0]  # Initial parameter guess [c0, c1, alpha]
    params, pcov = curve_fit(
        scaling_law,
        n_systems_obs,
        mean_vals_obs,
        p0=initial_guess,
        bounds=([0, 0, 0], [np.inf, np.inf, np.inf]),
        maxfev=10000,
    )

    # Parameter errors: square roots of the diagonal of the covariance matrix.
    param_errors = np.sqrt(np.diag(pcov))

    print(f"Fitted parameters for prediction length {pred_length}:")
    print(f"c0 = {params[0]:.4e} ± {param_errors[0]:.4e}")
    print(f"c1 = {params[1]:.4e} ± {param_errors[1]:.4e}")
    print(f"alpha = {params[2]:.4f} ± {param_errors[2]:.4f}")

    # compute r2 score on the fitted points
    mean_vals_pred = scaling_law(n_systems_obs, params[0], params[1], params[2])
    ss_tot = np.sum((mean_vals_obs - np.mean(mean_vals_obs)) ** 2)
    ss_res = np.sum((mean_vals_obs - mean_vals_pred) ** 2)
    r_squared = 1 - (ss_res / ss_tot)
    print(f"R² = {r_squared:.4f}")

    scaling_law_params_by_predlength[pred_length] = {
        "params": params,
        "param_errors": param_errors,
        "r_squared": r_squared,
    }

print(scaling_law_params_by_predlength)

In [None]:
# Plot the scaling law fit for all prediction lengths
plt.figure(figsize=(3, 4))

# Create colormap: one color per fitted prediction length.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
colormap = "cividis"
colors = plt.get_cmap(colormap)(np.linspace(0, 0.9, len(scaling_law_params_by_predlength)))

# Plot data and fitted curves
for i, (pred_length, scaling_params) in enumerate(scaling_law_params_by_predlength.items()):
    c0, c1, alpha = scaling_params["params"]
    r_squared = scaling_params["r_squared"]

    # Plot data points
    plt.scatter(
        list(mean_vals_dict[pred_length].keys()),
        list(mean_vals_dict[pred_length].values()),
        color=colors[i],
        marker="o",
        label=f"Data (pred_len={pred_length})",
        alpha=0.8,
        s=20,
    )

    # plt.errorbar(
    #     list(mean_vals_dict[pred_length].keys()),
    #     list(mean_vals_dict[pred_length].values()),
    #     yerr=list(ste_vals_dict[pred_length].values()),
    #     fmt='none',
    #     ecolor=colors[i],
    #     capsize=2,
    #     alpha=1.0,
    #     elinewidth=1
    # )
    # Plot fitted curve, log-spaced between the smallest and largest system count
    x_smooth = np.logspace(
        np.log10(min(mean_vals_dict[pred_length].keys())),
        np.log10(max(mean_vals_dict[pred_length].keys())),
        100,
    )
    y_smooth = scaling_law(x_smooth, c0, c1, alpha)
    plt.plot(
        x_smooth,
        y_smooth,
        color=colors[i],
        linestyle="-",
        linewidth=1,
        label=f"Fit (pred_len={pred_length}): sMAPE = {c0:.2f} + {c1:.2f} · N_sys^(-{alpha:.4f}), R² = {r_squared:.4f}",
    )

# Set plot properties
plt.xlabel("Number of Systems", fontweight="bold")
plt.ylabel("sMAPE", fontweight="bold")
plt.title("Scaling Law Fits", fontweight="bold")
plt.xscale("log", base=2)
# put x ticks at 2**8, 2**9, 2**10, 2**11, 2**12, 2**13, 2**14
plt.xticks([2**8, 2**9, 2**10, 2**11, 2**12, 2**13, 2**14])
# plt.grid(True, alpha=0.3)
# plt.legend(frameon=True, fontsize=12, loc="upper right")

plt.tight_layout()
plt.savefig("scalinglaw_figs/scaling_law_fits_all_pred_lengths.pdf", bbox_inches="tight")
plt.show()

### Fit scaling law on number of systems and prediction length

In [None]:
# Define the scaling law function
def scaling_law_full(num_systems, pred_length, c0, c1, alpha, c2, beta):
    """Evaluate ``c0 + c1 * num_systems ** (-alpha) + c2 * pred_length ** (-beta)``.

    Args:
        num_systems: scalar or array-like of system counts.
        pred_length: scalar or array-like of prediction lengths; must broadcast
            against ``num_systems``.
        c0: constant offset.
        c1: scale of the number-of-systems term.
        alpha: decay exponent of the number-of-systems term.
        c2: scale of the prediction-length term.
        beta: decay exponent of the prediction-length term.

    Returns:
        A NumPy float for scalar inputs, otherwise an ndarray of the broadcast shape.
    """
    # Coerce to float arrays so plain lists work even when the parameters are
    # ordinary Python floats (list ** float raises TypeError).
    num_systems = np.asarray(num_systems, dtype=float)
    pred_length = np.asarray(pred_length, dtype=float)
    return c0 + c1 * num_systems ** (-alpha) + c2 * pred_length ** (-beta)

In [None]:
# Flatten {pred_length: {num_systems: mean}} into {(num_systems, pred_length): mean}.
mean_vals_dict_full = {
    (num_systems, pred_length): value
    for pred_length, systems_dict in mean_vals_dict.items()
    for num_systems, value in systems_dict.items()
}

In [None]:
# (num_systems, pred_length) pairs available for the joint fit.
mean_vals_dict_full.keys()

In [None]:
# Distinct prediction lengths and system counts among the flattened keys.
# The values pass through a set, so the arrays are not sorted; the only later use
# visible here is printing their shapes.
pred_length_arr = np.array(list(set([x[1] for x in mean_vals_dict_full.keys()])))
num_systems_arr = np.array(list(set([x[0] for x in mean_vals_dict_full.keys()])))

In [None]:
# Number of distinct prediction lengths and system counts.
print(pred_length_arr.shape)
print(num_systems_arr.shape)

In [None]:
# Prepare data for curve fitting: one row [num_systems, pred_length] per observation
# in X, with the matching mean value at the same position in y.
X = np.array([[num_sys, pred_len] for (num_sys, pred_len) in mean_vals_dict_full.keys()])
y = np.array(list(mean_vals_dict_full.values()))

In [None]:
# Shape of the design matrix built above.
X.shape

In [None]:
# Curve fitting: joint fit of the five-parameter law over every
# (num_systems, pred_length) observation.
initial_guess = [
    1.0,
    1.0,
    1.0,
    1.0,
    1.0,
]  # Initial parameter guess [c0, c1, alpha, c2, beta]

# The wrapper unpacks each row of X into (num_systems, pred_length) and evaluates
# the law row by row; all parameters are constrained to be non-negative.
params, pcov = curve_fit(
    lambda X, c0, c1, alpha, c2, beta: np.array([scaling_law_full(x[0], x[1], c0, c1, alpha, c2, beta) for x in X]),
    X,
    y,
    p0=initial_guess,
    bounds=([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf, np.inf]),
    maxfev=10000,
)

# Extract the fitted parameters; param_errors follows the same order
# [c0, c1, alpha, c2, beta] as params.
c0, c1, alpha, c2, beta = params
param_errors = np.sqrt(np.diag(pcov))

print(f"Fitted parameters: c0={c0:.4f}, c1={c1:.4f}, alpha={alpha:.4f}, c2={c2:.4f}, beta={beta:.4f}")
print(f"Parameter errors: {param_errors}")

In [None]:
# Report each fitted parameter with its own standard error. param_errors is ordered
# like params: [c0, c1, alpha, c2, beta]. The previous indices were shifted from c1
# onward, so every value after c0 was shown with the wrong error.
print("Fitted parameters:")
print(f"c0 = {c0:.4e} ± {param_errors[0]:.4e}")
print(f"c1 = {c1:.4e} ± {param_errors[1]:.4e}")
print(f"alpha = {alpha:.4f} ± {param_errors[2]:.4f}")
print(f"c2 = {c2:.4e} ± {param_errors[3]:.4e}")
print(f"beta = {beta:.4f} ± {param_errors[4]:.4f}")

In [None]:
# Create plots for scaling law visualization: observed means (markers) and the
# jointly fitted law (lines), one color per prediction length.
plt.figure(figsize=(5, 5))

# NOTE(review): unique_n_systems is computed but not used in this cell.
unique_n_systems = sorted(set(n_sys for (n_sys, _) in mean_vals_dict_full.keys()))
unique_pred_lengths = sorted(set(pred_len for (_, pred_len) in mean_vals_dict_full.keys()))
colors = plt.cm.cividis(np.linspace(0, 0.9, len(unique_pred_lengths)))

# Plot for each prediction length
for i, pred_len in enumerate(unique_pred_lengths):
    # Extract data for this prediction length
    n_systems = []
    actual_vals = []
    for (n_sys, p), value in mean_vals_dict_full.items():
        if p == pred_len:
            n_systems.append(n_sys)
            actual_vals.append(value)

    # Sort by number of systems
    sorted_indices = np.argsort(n_systems)
    n_systems = np.array(n_systems)[sorted_indices]
    actual_vals = np.array(actual_vals)[sorted_indices]

    # Plot actual values and fitted curve
    plt.scatter(
        n_systems,
        actual_vals,
        color=colors[i],
        marker="o",
        label=f"L_pred = {pred_len}",
    )

    # Generate smooth curve, log-spaced over the observed range of system counts
    n_systems_smooth = np.logspace(np.log10(min(n_systems)), np.log10(max(n_systems)), 100)
    y_smooth = [scaling_law_full(n, pred_len, c0, c1, alpha, c2, beta) for n in n_systems_smooth]
    plt.plot(n_systems_smooth, y_smooth, color=colors[i], linestyle="-", alpha=0.7)

plt.xlabel("Number of Systems", fontweight="bold")
plt.ylabel("sMAPE", fontweight="bold")
plt.title("Scaling with Number of Systems", fontweight="bold")
plt.xscale("log", base=2)
plt.grid(True, alpha=0.3)
plt.legend(frameon=True, fontsize=8, ncol=2)

plt.tight_layout()
plt.savefig("scalinglaw_figs/scaling_law_fits_all_splits.pdf", bbox_inches="tight")
plt.show()

### Older box plots

In [None]:
# Box plot of sMAPE per run via the project's make_box_plot helper.
# NOTE(review): selected_pred_length was last set to 128 in the fitting section, and
# this call rebinds legend_handles (previously the handles from make_scaling_plot_v2).
legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="smape",  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    alpha_val=0.8,
    show_legend=True,
    legend_kwargs={"loc": "lower right", "frameon": True, "ncol": 1, "framealpha": 1.0},
    # save_path="scalinglaw_figs/smape_128.pdf",
)

In [None]:
# Standalone figure that contains only the legend returned by make_box_plot
# (3 columns); it is shown but not saved.
plt.figure(figsize=(4, 0.6))
# Add the legend
plt.legend(
    handles=legend_handles,
    loc="center",
    frameon=True,
    ncol=3,
    framealpha=1.0,
)
# Hide the axis ticks so only the legend is visible.
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=0)
# plt.savefig("ablations_figs/ablations_legend.pdf", bbox_inches="tight")
plt.show()
plt.close()