In [1]:
from pathlib import Path
import csv

def build_kan_neat_dict(seeds, functions, base_dir="special_functions_results"):
    """
    Build:
    {
        "KAN-NEAT": {
            <function_name>: {
                <seed>: [best_fitness_0, best_fitness_1, ...]
            }, ...
        }
    }

    Expects files at:
      {base_dir}/kan-{function}-seed{seed}/results_{seed}.csv
    with columns: generation,best_fitness,avg_fitness,species_count,complexity,hidden_nodes
    """
    out = {"KAN-NEAT": {}}
    base = Path(base_dir)

    for fn in functions:
        out["KAN-NEAT"][fn] = {}
        for seed in seeds:
            csv_path = base / f"kan-{fn}-seed{seed}" / f"results_{seed}.csv"
            best_list = []

            if csv_path.exists():
                with csv_path.open("r", newline="", encoding="utf-8") as f:
                    reader = csv.DictReader(f)
                    rows = [r for r in reader if "generation" in r and "best_fitness" in r]

                # Keep order by generation just in case the CSV isn't sorted
                try:
                    rows.sort(key=lambda r: int(float(r["generation"])))
                except Exception:
                    pass  # fall back to file order

                for r in rows:
                    try:
                        best_list.append(float(r["best_fitness"]))
                    except (ValueError, TypeError):
                        # skip malformed rows
                        continue
            else:
                # File missing -> store empty list (or handle as you prefer)
                print(f"Warning: missing {csv_path}")
                pass

            out["KAN-NEAT"][fn][seed] = best_list

    return out

In [2]:
functions = ['ellipj', 'jv', 'lpmv_1']
seeds = [42, 123, 456, 789, 1000]

kan_neat_dict = build_kan_neat_dict(seeds, functions)

In [4]:
kan_neat_dict

{'KAN-NEAT': {'ellipj': {42: [0.8707212041910491,
    0.5425357239603084,
    0.46323512019581126,
    0.3328463594516022,
    0.33199884930846096,
    0.33199884930846096,
    0.33199884930846096,
    0.25356868478633543,
    0.1457082117016645,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.09717474254658696,
    0.07787984992608192,
    0.07787984992608192,
    0.07787984992608192,
    0.07787984992608192,
   

In [5]:
def build_pykan_dict(seeds, functions, base_dir="special_functions_results/pykan_results"):
    """
    Build:
    {
        "PyKAN": {
            <function_name>: {
                <seed>: [train_loss_step_1, train_loss_step_2, ...]
            }, ...
        }
    }

    Expects files at:
      {base_dir}/pykan-{function}-seed{seed}/losses.txt
    TSV with headers: Step, Train_Loss, Test_Loss
    """
    out = {"PyKAN": {}}
    base = Path(base_dir)

    for fn in functions:
        out["PyKAN"][fn] = {}
        for seed in seeds:
            losses_path = base / f"pykan-{fn}-seed{seed}" / "losses.txt"
            train_losses = []

            if losses_path.exists():
                with losses_path.open("r", encoding="utf-8") as f:
                    reader = csv.DictReader(f, delimiter="\t")
                    rows = []
                    for row in reader:
                        # normalize keys: handle accidental case/space differences
                        row = { (k or "").strip().lower(): (v or "").strip() for k, v in row.items() }
                        rows.append(row)

                    # Ensure sorted by step if the file isn't already
                    try:
                        rows.sort(key=lambda r: int(float(r.get("step", "0"))))
                    except Exception:
                        pass

                    for r in rows:
                        val = r.get("train_loss")
                        if val is None or val == "":
                            continue
                        try:
                            train_losses.append(float(val))
                        except ValueError:
                            continue
            else:
                # Missing file -> keep empty list (same behavior as the KAN-NEAT helper)
                pass

            out["PyKAN"][fn][seed] = train_losses

    return out

In [6]:
pykan_dict = build_pykan_dict(seeds, functions)

In [28]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from pathlib import Path

def _read_best_runs(csv_path):
    df = pd.read_csv(csv_path)
    method_map = {"kan-neat": "KAN-NEAT", "pykan": "PyKAN"}
    df["method_norm"] = df["method"].str.lower().map(method_map)
    best = {"KAN-NEAT": {}, "PyKAN": {}}
    for _, row in df.iterrows():
        best[row["method_norm"]][row["function"]] = int(row["seed"])
    return best

def _fit_to_len(series, L=50):
    out = np.full(L, np.nan, dtype=float)
    n = min(len(series), L)
    if n > 0:
        out[:n] = np.asarray(series[:n], dtype=float)
    return out

def _normalize_by_first_finite(y):
    idx = np.where(np.isfinite(y))[0]
    if idx.size == 0:
        return y
    base = y[idx[0]]
    if base == 0 or not np.isfinite(base):
        return y
    return y / base

def plot_convergence_per_function_same_x(
    kan_neat_dict,
    pykan_dict,
    best_csv_path,
    output_dir="convergence_plots",
    L=50,
    alpha_bg=0.3,
    lw_bg=1.2,
    lw_best=2.8,
    dpi=200,
    yscale="linear",            # "linear" | "log" | "symlog"
    normalize=False,            # True -> start each curve at ~1 (relative improvement)
    yformat=None,               # e.g. "%.3f" or None for smart formatter
    max_y_ticks=6,              # fewer, cleaner ticks
):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    best_lookup = _read_best_runs(best_csv_path)
    fnames = set(kan_neat_dict.get("KAN-NEAT", {}).keys()) | set(pykan_dict.get("PyKAN", {}).keys())
    method_colors = {"KAN-NEAT": "tab:blue", "PyKAN": "tab:orange"}
    x = np.arange(1, L + 1)

    for fn in sorted(fnames):
        fig, ax = plt.subplots(figsize=(8, 5), dpi=dpi)

        all_y_values = []

        # KAN-NEAT
        if "KAN-NEAT" in kan_neat_dict and fn in kan_neat_dict["KAN-NEAT"]:
            best_seed_kn = best_lookup.get("KAN-NEAT", {}).get(fn)
            for seed, series in sorted(kan_neat_dict["KAN-NEAT"][fn].items(), key=lambda kv: kv[0]):
                if not series:
                    continue
                y = _fit_to_len(series, L=L)
                if normalize:
                    y = _normalize_by_first_finite(y)
                all_y_values.append(y)

                if seed == best_seed_kn:
                    ax.plot(x, y, linewidth=lw_best, label=f"KAN-NEAT best (seed {seed})",
                            color=method_colors["KAN-NEAT"])
                else:
                    ax.plot(x, y, linewidth=lw_bg, alpha=alpha_bg, color=method_colors["KAN-NEAT"])

        # PyKAN
        if "PyKAN" in pykan_dict and fn in pykan_dict["PyKAN"]:
            best_seed_pk = best_lookup.get("PyKAN", {}).get(fn)
            for seed, series in sorted(pykan_dict["PyKAN"][fn].items(), key=lambda kv: kv[0]):
                if not series:
                    continue
                y = _fit_to_len(series, L=L)
                if normalize:
                    y = _normalize_by_first_finite(y)
                all_y_values.append(y)

                if seed == best_seed_pk:
                    ax.plot(x, y, linewidth=lw_best, label=f"PyKAN best (seed {seed})",
                            color=method_colors["PyKAN"])
                else:
                    ax.plot(x, y, linewidth=lw_bg, alpha=alpha_bg, color=method_colors["PyKAN"])

        # --- Y-axis friendliness ---
        # Scale
        if yscale == "log":
            # if any non-positive values exist, fall back to symlog for safety
            if all_y_values and np.nanmin(np.concatenate(all_y_values)) <= 0:
                ax.set_yscale("symlog", linthresh=1e-6)
            else:
                ax.set_yscale("log")
        elif yscale == "symlog":
            ax.set_yscale("symlog", linthresh=1e-6)

        # Ticks & formatting
        ax.yaxis.set_major_locator(mticker.MaxNLocator(max_y_ticks))
        if yformat:
            ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(yformat))
        else:
            sf = mticker.ScalarFormatter(useMathText=True)
            sf.set_powerlimits((-2, 3))  # show 1e±k only outside this range
            ax.yaxis.set_major_formatter(sf)

        ax.set_xlabel('Training Steps/Generations')
        ax.set_ylabel('Training RMSE/Loss')
        ax.set_title(f'{fn}')
        ax.grid(True, linestyle="--", alpha=0.4)
        ax.set_xlim(1, L)

        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(frameon=True)

        fig.tight_layout()
        out_path = output_dir / f"{fn}_convergence.png"
        fig.savefig(out_path)
        plt.close(fig)


In [29]:
plot_convergence_per_function_same_x(kan_neat_dict, pykan_dict,
                                     "best_results_per_function_method.csv",
                                     output_dir="plots", L=50)