In [1]:
from typing import Optional
import pandas as pd
from texttable import Texttable
import latextable
from scipy.stats import ttest_ind_from_stats


def fmt_cell(
    df: pd.DataFrame,
    row: str,
    column_key: str,
    bold: bool = False,
    decimal_points: int = 2,
) -> str:
    mean_key = f"{column_key}/mean"
    std_key = f"{column_key}/std"

    std_str = f"{df.loc[row, std_key]:.{decimal_points}f}".lstrip("0")
    mean_str = f"{df.loc[row, mean_key]:.{decimal_points}f}"
    if bold:
        return f"\\textbf{{{mean_str}}},\\textbf{{{std_str}}}"
    else:
        return f"{mean_str},{std_str}"


def latex_float(f):
    float_str = "{0:.2g}".format(f)
    if "e" in float_str:
        base, exponent = float_str.split("e")
        return r"{0} \times 10^{{{1}}}".format(base, int(exponent))
    else:
        return float_str


def fmt_row(
    df: pd.DataFrame,
    column_keys: list[str],
    metric_objectives: list[bool],
    row: str,
    row_name: str,
    decimal_points: list[int],
    ttest_row: Optional[str] = None,
) -> list[str]:
    # The row's cell is bolded if it is the best in the column
    bold = []
    param_count = df.loc[row]["trainable_parameters/mean"]

    if ttest_row is not None:
        column = column_keys[0]
        mean1 = df.loc[row, f"{column}/mean"]
        std1 = df.loc[row, f"{column}/std"]
        mean2 = df.loc[ttest_row, f"{column}/mean"]
        std2 = df.loc[ttest_row, f"{column}/std"]
        t, p = ttest_ind_from_stats(
            mean1,
            std1,
            param_count,
            mean2,
            std2,
            5,
            equal_var=False,
            alternative="two-sided",
        )
        if p < 0.05 and mean1 > mean2:
            row_name += "*"

    for column_key, metric_objective, digits in zip(
        column_keys, metric_objectives, decimal_points
    ):
        mean_key = f"{column_key}/mean"
        if metric_objective:
            bold.append(
                round(df[mean_key].max(), digits)
                == round(df.loc[row, mean_key], digits)
            )
        else:
            bold.append(
                round(df[mean_key].min(), digits)
                == round(df.loc[row, mean_key], digits)
            )
    return [f"{row_name} ({param_count/1e3:.0f}k)"] + [
        fmt_cell(df, row, column_key, bold=bold[i], decimal_points=decimal_points[i])
        for i, column_key in enumerate(column_keys)
    ]


# Please note that these are from a different set of runs than the ones in the paper
# The results are similar, but not identical
df = pd.read_csv("AAAI25_clkan_9798c8d.csv")
df

Unnamed: 0,scenario,strategy,00_R2/mean,00_R2/std,00_MSE/mean,00_MSE/std,AVG_R2/mean,AVG_R2/std,BWT_R2/mean,BWT_R2/std,...,AVG_MSE/mean,AVG_MSE/std,BWT_MSE/mean,BWT_MSE/std,FWT_MSE/mean,FWT_MSE/std,R2_DIAG/mean,R2_DIAG/std,trainable_parameters/mean,trainable_parameters/count
0,eurowind,ewc-kan,0.727313,0.03723,0.39449,0.053859,0.703218,0.004958,-0.06602476,0.009245142,...,0.25647,0.008079,0.07418046,0.01193685,1.057334,0.06082,0.752536,0.002191,29193,5
1,eurowind,ewc-mlp,0.759202,0.019224,0.348357,0.027812,0.678016,0.006032,-0.03062813,0.004786216,...,0.279199,0.005474,0.03624898,0.006295605,1.045864,0.029282,0.690333,0.006471,44030,5
2,eurowind,joint-kan,0.870426,0.002331,0.131075,0.002358,,,,,...,,,,,,,,,347761,5
3,eurowind,joint-mlp,0.865046,0.003891,0.136517,0.003937,,,,,...,,,,,,,,,276401,5
4,eurowind,kan,0.523521,0.064968,0.689311,0.093987,0.623523,0.006094,-0.1402132,0.007834178,...,0.334271,0.00619,0.1346087,0.008201333,0.943203,0.057065,0.743483,0.003577,26082,5
5,eurowind,mlp,0.628567,0.036686,0.537344,0.053073,0.64963,0.012733,-0.06854786,0.01603812,...,0.307947,0.014316,0.06698709,0.01841287,0.981247,0.029338,0.714551,0.000589,40115,5
6,eurowind,packnet,0.841414,0.005067,0.229422,0.00733,0.780847,0.005756,0.0,0.0,...,0.19033,0.005577,0.0,0.0,1.004771,0.0637,0.766563,0.005835,254675,5
7,eurowind,si-kan,0.882328,0.004186,0.170233,0.006055,0.733884,0.004742,-1.49e-05,8.88e-06,...,0.215828,0.003039,2.12e-05,1.28e-05,1.000512,0.075146,0.735982,0.005216,18147,5
8,eurowind,si-mlp,0.884197,0.0078,0.167529,0.011284,0.700209,0.00837,-0.0002222321,0.0003225993,...,0.243545,0.006172,0.0003076941,0.0003657641,0.99446,0.055979,0.698307,0.005559,67697,5
9,eurowind,wisekan,0.898859,0.02083,0.146319,0.030134,0.798062,0.009082,0.0,0.0,...,0.170548,0.011904,0.0,0.0,0.941554,0.015767,0.789027,0.010413,288600,5


## `scenario/feynman`

In [2]:
def display_table(table, r2_decimals=3):
    table.index = table["strategy"]
    none_entry = "NA"
    row_keys = ["AVG_R2", "BWT_R2", "R2_DIAG"]
    metric_objective = [True, True, True]
    decimal_points = [r2_decimals, 3, 3]

    # Forward transfer does not make any sense because every task is orthogonal
    # so zero-shot learning is not possible.

    dtable = Texttable()  # d for display table
    dtable.header(
        list(
            map(
                lambda s: "\\multicolumn{1}{c}{" + s + "}",
                ["Strategy (\\# param)", "R2 ↑", "$R2_{bwt}$ ↑", "$R2_{on}$ ↑ "],
            )
        )
    )

    joint_df = table[table.index.str.contains("joint")]
    table = table[~table.index.str.contains("joint")]
    dtable.add_row(
        fmt_row(
            joint_df,
            ["00_R2"],
            metric_objective,
            "joint-mlp",
            "joint-mlp",
            decimal_points,
        )
        + [none_entry] * 2
    )
    dtable.add_row(
        fmt_row(
            joint_df,
            ["00_R2"],
            metric_objective,
            "joint-kan",
            "joint-kan",
            decimal_points,
        )
        + [none_entry] * 2
    )

    # Drop joint-mlp and joint-kan
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "wisemlp",
            "wisemlp",
            decimal_points,
            ttest_row="wisekan",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "wisekan",
            "wisekan",
            decimal_points,
            ttest_row="wisemlp",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "si-mlp",
            "si-mlp",
            decimal_points,
            ttest_row="si-kan",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "si-kan",
            "si-kan",
            decimal_points,
            ttest_row="si-mlp",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "ewc-mlp",
            "ewc-mlp",
            decimal_points,
            ttest_row="ewc-kan",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "ewc-kan",
            "ewc-kan",
            decimal_points,
            ttest_row="ewc-mlp",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "mlp",
            "mlp",
            decimal_points,
            ttest_row="kan",
        )
    )
    dtable.add_row(
        fmt_row(
            table,
            row_keys,
            metric_objective,
            "kan",
            "kan",
            decimal_points,
            ttest_row="mlp",
        )
    )
    dtable.add_row(
        fmt_row(table, row_keys, metric_objective, "packnet", "packnet", decimal_points)
    )

    print(latextable.draw_latex(dtable, use_booktabs=True))

In [3]:
display_table(df[df["scenario"] == "feynman"].copy(), 4)

\begin{table}
	\begin{center}
		\begin{tabular}{llll}
			\toprule
			\multicolumn{1}{c}{Strategy (\# param)} & \multicolumn{1}{c}{R2 ↑} & \multicolumn{1}{c}{$R2_{bwt}$ ↑} & \multicolumn{1}{c}{$R2_{on}$ ↑ } \\
			\midrule
			joint-mlp (155k) & 0.9998,.0001 & NA & NA \\
			joint-kan (274k) & \textbf{0.9999},\textbf{.0000} & NA & NA \\
			wisemlp (308k) & 0.9951,.0022 & \textbf{0.000},\textbf{.000} & 0.995,.002 \\
			wisekan* (247k) & \textbf{0.9994},\textbf{.0001} & \textbf{0.000},\textbf{.000} & \textbf{0.999},\textbf{.000} \\
			si-mlp (53k) & 0.9636,.0592 & -0.029,.063 & 0.992,.001 \\
			si-kan (71k) & 0.9992,.0001 & \textbf{-0.000},\textbf{.000} & \textbf{0.999},\textbf{.000} \\
			ewc-mlp (58k) & 0.9142,.0220 & -0.073,.024 & 0.984,.001 \\
			ewc-kan* (77k) & 0.9958,.0029 & -0.002,.003 & 0.998,.001 \\
			mlp* (60k) & 0.8012,.0076 & -0.107,.009 & 0.901,.001 \\
			kan (94k) & 0.6936,.0360 & -0.223,.037 & 0.909,.001 \\
			packnet (364k) & 0.9258,.0123 & \textbf{0.000},\textbf{.000} & 0.

## `scenario/eurowind`

In [4]:
display_table(df[df["scenario"] == "eurowind"].copy(), 3)

\begin{table}
	\begin{center}
		\begin{tabular}{llll}
			\toprule
			\multicolumn{1}{c}{Strategy (\# param)} & \multicolumn{1}{c}{R2 ↑} & \multicolumn{1}{c}{$R2_{bwt}$ ↑} & \multicolumn{1}{c}{$R2_{on}$ ↑ } \\
			\midrule
			joint-mlp (276k) & 0.865,.004 & NA & NA \\
			joint-kan (348k) & \textbf{0.870},\textbf{.002} & NA & NA \\
			wisemlp (278k) & 0.781,.006 & \textbf{0.000},\textbf{.000} & 0.780,.008 \\
			wisekan* (289k) & \textbf{0.798},\textbf{.009} & \textbf{0.000},\textbf{.000} & \textbf{0.789},\textbf{.010} \\
			si-mlp (68k) & 0.700,.008 & \textbf{-0.000},\textbf{.000} & 0.698,.006 \\
			si-kan* (18k) & 0.734,.005 & \textbf{-0.000},\textbf{.000} & 0.736,.005 \\
			ewc-mlp (44k) & 0.678,.006 & -0.031,.005 & 0.690,.006 \\
			ewc-kan* (29k) & 0.703,.005 & -0.066,.009 & 0.753,.002 \\
			mlp* (40k) & 0.650,.013 & -0.069,.016 & 0.715,.001 \\
			kan (26k) & 0.624,.006 & -0.140,.008 & 0.743,.004 \\
			packnet (255k) & 0.781,.006 & \textbf{0.000},\textbf{.000} & 0.767,.006 \\
			\botto

## `scenario/riverradar`

In [5]:
display_table(df[df["scenario"] == "riverradar"].copy())

\begin{table}
	\begin{center}
		\begin{tabular}{llll}
			\toprule
			\multicolumn{1}{c}{Strategy (\# param)} & \multicolumn{1}{c}{R2 ↑} & \multicolumn{1}{c}{$R2_{bwt}$ ↑} & \multicolumn{1}{c}{$R2_{on}$ ↑ } \\
			\midrule
			joint-mlp (442k) & \textbf{0.620},\textbf{.022} & NA & NA \\
			joint-kan (869k) & 0.609,.030 & NA & NA \\
			wisemlp* (525k) & 0.586,.011 & \textbf{0.000},\textbf{.000} & 0.551,.010 \\
			wisekan (710k) & 0.574,.010 & \textbf{0.000},\textbf{.000} & 0.516,.011 \\
			si-mlp (149k) & 0.452,.010 & -0.001,.003 & 0.427,.009 \\
			si-kan (461k) & 0.465,.005 & -0.001,.001 & 0.426,.005 \\
			ewc-mlp (33k) & 0.232,.044 & -0.228,.062 & 0.371,.006 \\
			ewc-kan* (166k) & 0.343,.007 & -0.144,.015 & 0.413,.006 \\
			mlp (89k) & 0.153,.049 & -0.341,.066 & 0.413,.001 \\
			kan (508k) & 0.159,.013 & -0.283,.018 & 0.378,.003 \\
			packnet (523k) & \textbf{0.591},\textbf{.007} & \textbf{0.000},\textbf{.000} & \textbf{0.553},\textbf{.007} \\
			\bottomrule
		\end{tabular}
	\end{center