In [None]:
import optuna
from os import environ
from optuna.visualization._pareto_front import _get_pareto_front_info
import numpy as np
import pandas as pd
from functools import cache, partial
import seaborn as sns

# Storage location handed to Optuna when loading studies. It is read from the
# OPTUNA_STORAGE environment variable and is None when that variable is unset.
optuna_storage = environ.get("OPTUNA_STORAGE")


def get_pareto_front(study):
    """Return the objective values of the study's best (Pareto-front) trials.

    Args:
        study: Optuna study passed straight to ``_get_pareto_front_info``.

    Returns:
        A 2-D array with one row per best trial, ordered by ascending first
        column. When there are no best trials an empty ``(0, 2)`` array is
        returned, so that column indexing in callers still works.
    """
    # NOTE(review): ``_get_pareto_front_info`` is a private Optuna helper called
    # positionally. The third argument looks like ``include_dominated_trials``;
    # confirm against the installed Optuna version.
    pareto = _get_pareto_front_info(study, None, False, None, None, None)
    # Each entry is a pair; the second element holds the values to plot.
    values = [entry[1] for entry in pareto.best_trials_with_values]
    if not values:
        # np.array([]) would be 1-D and break the ``[:, 0]`` indexing below.
        return np.empty((0, 2))
    xy = np.asarray(values)
    # Order by the first column so the front is drawn as one continuous line.
    return xy[xy[:, 0].argsort()]


def non_best_scatter(study):
    """Return the objective values of the trials that are not on the Pareto front.

    Args:
        study: Optuna study passed straight to ``_get_pareto_front_info``.

    Returns:
        A 2-D array with one row per non-best trial, in the order Optuna
        reports them. When there are no such trials an empty ``(0, 2)`` array
        is returned, so that column indexing in callers still works.
    """
    # NOTE(review): private Optuna helper called positionally. The third
    # argument looks like ``include_dominated_trials``; confirm against the
    # installed Optuna version.
    pareto = _get_pareto_front_info(study, None, True, None, None, None)
    values = [entry[1] for entry in pareto.non_best_trials_with_values]
    if not values:
        # Without this guard the result would be a 1-D empty array.
        return np.empty((0, 2))
    return np.asarray(values)


# Study loader with the storage pre-bound, memoised on the exact call arguments
# so that repeated requests for the same study reuse the loaded object.
_load_study_from_storage = partial(optuna.study.load_study, storage=optuna_storage)
load_study = cache(_load_study_from_storage)

In [None]:
# Prefix shared by every study name in this notebook.
version = "v11"

# Study names per dataset, in the form "<version>/<dataset>/<strategy>".
# Every dataset uses the same strategies in the same order; the plotting code
# relies on that, because a study's colour is its position in the list and the
# legend is built from the "eurowind" entry only.
study_names = {
    dataset: [
        f"{version}/{dataset}/{strategy}"
        for strategy in (
            "joint-mlp",
            "joint-kan",
            "wisemlp",
            "wisekan",
            "si-mlp",
            "si-kan",
            "mlp",
            "kan",
            "ewc-mlp",
            "ewc-kan",
            "packnet",
        )
    ]
    for dataset in ("eurowind", "riverradar", "feynman")
}

In [None]:
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from functools import cache
import matplotlib.patches

# Colormap indexed by a study's position within its dataset's list.
cmap = plt.get_cmap("tab20")
# Maximum number of non-best trials drawn per study in the background scatter.
n_samples = 400


def visualize_study(
    ax: Axes,
    study_name: str,
    color,
    flipxy: bool = False,
    x_func: callable = lambda x: x,
    y_func: callable = lambda x: x,
):
    """Draw one study's Pareto front plus a faint scatter of its other trials.

    Args:
        ax: Axes to draw on.
        study_name: Name of the study to load; its third "/"-separated
            component is used as the legend label.
        color: Colour of the front line. It must be numeric (e.g. an RGBA
            tuple), because a lightened variant is computed arithmetically.
        flipxy: When True, swap the two objective columns before plotting.
        x_func: Transform applied to the x values.
        y_func: Transform applied to the y values.

    Returns:
        The list of line artists returned by ``ax.plot`` for the front.
    """

    def to_plot_coords(values):
        # Optionally swap the objective columns, then apply the axis transforms.
        if flipxy:
            values = values[:, ::-1]
        return x_func(values[:, 0]), y_func(values[:, 1])

    # The storage is already bound into ``load_study``. Passing it again would
    # create a separate cache entry and load the same study a second time.
    study = load_study(study_name=study_name)
    label = study_name.split("/")[2]
    # Labels containing "kan" get an upward triangle, all others a downward one.
    marker = "^" if "kan" in label else "v"

    x, y = to_plot_coords(get_pareto_front(study))
    lines = ax.plot(
        x, y, "-", color=color, label=label, linewidth=1, marker=marker, markersize=2
    )

    # Only the first ``n_samples`` non-best trials are drawn.
    dominated = non_best_scatter(study)[:n_samples]
    print(study_name, len(study.get_trials(False)))
    if dominated.size:
        x, y = to_plot_coords(dominated)
        # Blend every channel halfway towards 1 to get a lighter colour.
        faded = np.array(color) + (1 - np.array(color)) * 0.5
        ax.scatter(x, y, s=0.1, color=faded, marker=marker)
    # With no non-best trials there is nothing to scatter; an empty array may
    # be 1-D, so the column indexing above would fail.
    return lines


# A single row of three panels, one per dataset.
fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3, figsize=(7, 2), dpi=300)


def plot_dataset(ax, cmap, dataset):
    """Draw every study of one dataset onto ``ax``.

    Args:
        ax: Axes to draw on.
        cmap: Callable mapping a study's list position to its colour.
        dataset: Key into ``study_names`` selecting the studies to draw.

    Returns:
        A flat list of the line artists returned by ``visualize_study``, in
        study order (previously nothing was returned).
    """
    handles = []
    for i, study_name in enumerate(study_names[dataset]):
        # Objectives are swapped and the y values negated before plotting.
        handles.extend(
            visualize_study(ax, study_name, cmap(i), flipxy=True, y_func=lambda x: -x)
        )
    return handles


# Fill each panel with its dataset, then set its title, log x-axis and y-range.
lines = plot_dataset(ax0, cmap, "eurowind")
ax0.set(title="TI Europe Wind Farm", xscale="log", ylim=(0.5, 0.9))

plot_dataset(ax1, cmap, "riverradar")
ax1.set(title="TI River Radar", xscale="log", ylim=(-0.05, 0.62))

plot_dataset(ax2, cmap, "feynman")
ax2.set(title="TI Feynman", xscale="log", ylim=(0.5, 1.05))

# Axis captions: y on the left panel, x on the middle panel.
ax0.set_ylabel(r"R2 $\longrightarrow$ ")
ax1.set_xlabel(r"$\longleftarrow$ # Parameters")

# One figure-level legend with a colour patch per strategy. The labels come
# from the "eurowind" list and are coloured by list position.
labels = [name.split("/")[2] for name in study_names["eurowind"]]
patches = [
    matplotlib.patches.Patch(color=cmap(i), label=label)
    for i, label in enumerate(labels)
]
fig.legend(patches, labels, loc="center left", frameon=False, bbox_to_anchor=(0.9, 0.5))
plt.savefig("figures/pareto_front.jpg", bbox_inches="tight", dpi=600)

In [None]:
@cache
def _load_strategy_frame(strategy, study_version):
    """Load one strategy's trials for all three datasets into a tidy frame (memoised)."""
    # Study-name component -> display title used in the "Dataset" column.
    scenario_titles = {
        "eurowind": "TI Europe Wind Farm",
        "feynman": "TI Feynman",
        "riverradar": "TI River Radar",
    }
    frames = [
        load_study(
            study_name=f"{study_version}/{scenario}/{strategy}"
        ).trials_dataframe()
        for scenario in scenario_titles
    ]
    return (
        # The concat keys form the outer index level, which reset_index()
        # exposes as the "level_0" column.
        pd.concat(frames, keys=list(scenario_titles.values()))
        .reset_index()
        # The first objective is negated before being renamed to "R2".
        .assign(values_0=lambda d: -d["values_0"])
        .rename(
            columns={
                "level_0": "Dataset",
                "values_0": "R2",
                "params_model.norm": "Norm",
                "params_model.grid_range": "Grid Range",
                "params_model.n_hidden_layers": "Hidden Layers",
            }
        )
        .sort_values("Dataset")
    )


def get_strategy_studies(strategy):
    """Return the trials of ``strategy`` across all three datasets as one DataFrame.

    The underlying load is memoised per (strategy, version). Each call returns
    a fresh copy, so callers may sort or add columns without affecting the
    cached frame or other callers.

    Args:
        strategy: Final component of the study name, e.g. ``"wisekan"``.

    Returns:
        DataFrame sorted by "Dataset", with the negated first objective in "R2"
        and the model parameters renamed to "Norm", "Grid Range" and
        "Hidden Layers".
    """
    # Uses the notebook-wide ``version`` rather than a second hard-coded copy.
    return _load_strategy_frame(strategy, version).copy()


# Keep the cache-management helpers the previously @cache-decorated function exposed.
get_strategy_studies.cache_clear = _load_strategy_frame.cache_clear
get_strategy_studies.cache_info = _load_strategy_frame.cache_info

In [None]:
# R2 of the WiseKAN trials per dataset, split by the "Norm" parameter.
norm_fig, norm_ax = plt.subplots(figsize=(3.5, 2))
df = get_strategy_studies("wisekan")
sns.boxenplot(data=df, x="Dataset", y="R2", hue="Norm", showfliers=True, ax=norm_ax)
# The legend is placed outside the axes, to the right.
norm_ax.legend(loc="upper left", bbox_to_anchor=(1, 1), frameon=False, title="Norm")
norm_ax.set_ylim(-0.5, 1.0)
norm_ax.set_title("WiseKAN Normalization")
norm_fig.savefig("figures/wisekan_norm.jpg", bbox_inches="tight", dpi=400)

In [None]:
plt.figure(figsize=(3.5, 2))
# Derive a new frame rather than sorting or adding columns in place:
# get_strategy_studies is memoised, so in-place edits would leak into every
# other cell that requests the same strategy.
df = (
    get_strategy_studies("wisekan")
    # Sort by grid range first, then by dataset.
    .sort_values(["Grid Range", "Dataset"])
    # Combine the grid range and norm columns into a single hue label.
    .assign(
        **{
            "Grid Range & Norm": lambda d: d["Grid Range"].astype(str)
            + " "
            + d["Norm"].astype(str)
        }
    )
)

sns.boxenplot(data=df, x="Dataset", y="R2", hue="Grid Range & Norm", showfliers=False)
plt.legend(loc="upper left", bbox_to_anchor=(1, 1), frameon=False)
plt.ylim(-0.0, 1.0)
plt.title("WiseKAN Grid Range & Normalization")
plt.savefig("figures/wisekan_grid_range.jpg", bbox_inches="tight", dpi=400)