In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import wandb

from pathlib import Path

In [2]:
# Output directory for the SVG figures written by the plotting cells below.
# Created up front (with any missing parents); re-running the cell is harmless
# because an already-existing directory is accepted.
plots_dir_fp = Path("../plots/test_accs/")
plots_dir_fp.mkdir(exist_ok=True, parents=True)

In [38]:
# Weights & Biases coordinates: each dataset is read from the project
# f"{ENTITY}/{PREFIX}{dataset}".
ENTITY = "wei2912"
PREFIX = "ethz-ssl-tabular_"
DATASETS = ["jannis", "gas-drift-different-concentrations", "higgs", "covertype"]
MODELS = ["hgbt", "mlp", "random-forest"]

api = wandb.Api()

# Gather one plain-dict record per run, tagged with the dataset it came from.
runs_datas = []
for dataset in DATASETS:
    runs = api.runs(f"{ENTITY}/{PREFIX}{dataset}")
    for run in runs:
        record = {
            "dataset": dataset,
            "summary": run.summary._json_dict,
            "config": run.config,
            "name": run.name,
            "job_type": run.job_type,
        }
        runs_datas.append(record)

# Nested "summary" / "config" dicts are flattened into dot-separated columns
# (e.g. "summary.test.acc", "config.args.model").
runs_df = pd.json_normalize(runs_datas)
runs_df

Unnamed: 0,dataset,name,job_type,summary.run.initial.train.per_epoch.train_losses,summary.run.pl_iter4.train.per_epoch.train_losses,summary.run.pl_iter3.train.per_epoch.lrs,summary.run.pl_iter2.val.acc,summary.run.initial.train.acc,summary.run.pl_iter3.train.max_epochs,summary.run.initial.train.policy.patience,...,config.direction,config.layer_size,summary.max_depth,summary.trial_number,summary.min_samples_leaf,summary.value_0,summary.value_1,config.max_depth,config.min_samples_leaf,summary.lr
0,jannis,silver-plant-1388,eval,0.015744,0.039872,0.000640,0.675,1.0,1000.0,250.0,...,,,,,,,,,,
1,jannis,fearless-energy-1387,eval,0.013957,0.024301,0.000128,0.679,1.0,1000.0,250.0,...,,,,,,,,,,
2,jannis,lucky-vortex-1386,eval,0.012094,0.336839,0.000128,0.642,1.0,1000.0,250.0,...,,,,,,,,,,
3,jannis,fragrant-cloud-1385,eval,0.012667,0.061781,0.000640,0.681,1.0,1000.0,250.0,...,,,,,,,,,,
4,jannis,lyric-galaxy-1384,eval,0.012701,0.091316,0.000128,0.633,1.0,1000.0,250.0,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3419,covertype,trial/2/good-sun-5,sweep,,,,,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,2.0,3.0,0.54,0.68,3.0,3.0,0.2
3420,covertype,trial/1/bright-butterfly-4,sweep,,,,,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,1.0,1.0,0.51,0.61,3.0,1.0,
3421,covertype,trial/1/vital-durian-3,sweep,,,,,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,1.0,1.0,0.54,0.68,3.0,1.0,0.3
3422,covertype,trial/0/usual-pine-2,sweep,,,,,,,,...,"[MAXIMIZE, MAXIMIZE]",,,,,,,4.0,1.0,


In [4]:
# Show every flattened column name of `runs_df` as a NumPy array, to look up
# the exact "summary.*" / "config.*" keys selected in the next cell.
runs_df.columns.values

array(['dataset', 'name', 'job_type', 'summary.run.pl_iter3.l_pl_acc',
       'summary.run.pl_iter1.train.policy.factor',
       'summary.run.pl_iter4.train.policy.factor',
       'summary.run.pl_iter0.train.per_epoch.val_accs',
       'summary.run.pl_iter3.n_pl', 'summary.run.pl_iter1.l_pl_acc',
       'summary.run.pl_iter1.train.batch_size',
       'summary.run.pl_iter4.train.per_epoch.lrs',
       'summary.run.pl_iter0.train.policy.patience',
       'summary.run.pl_iter0.pl_acc',
       'summary.run.pl_iter3.train.max_epochs',
       'summary.run.pl_iter1.train.per_epoch.lrs',
       'summary.run.pl_iter3.train.policy.patience',
       'summary.run.pl_iter4.val.acc', 'summary.run.pl_iter1.train.acc',
       'summary.run.pl_iter1.train.size', 'summary.run.pl_iter0.size_ul',
       'summary.run.pl_iter0.threshold',
       'summary.run.pl_iter0.train.per_epoch.train_losses',
       'summary.run.initial.train.size',
       'summary.run.pl_iter0.train.max_epochs',
       'summary.run.pl_

In [10]:
# Flattened W&B column name -> short column name used in the rest of the notebook.
COLUMNS = {
    "dataset": "dataset",
    "name": "name",
    "config.args.model": "model",
    "config.args.st_type": "st_type",
    "config.args.seed": "seed",
    "config.split.l_split": "l_split",
    "config.split.ul_split": "ul_split",
    "summary.test.acc": "test.acc",
}
# Runs with a seed above this value are excluded from the analysis.
MAX_SEED = 4

evals_df = (
    # Keep only evaluation runs and only the columns of interest. `list(COLUMNS)`
    # passes an explicit list of labels instead of a `dict_keys` view.
    runs_df.loc[runs_df["job_type"] == "eval", list(COLUMNS)]
    .rename(columns=COLUMNS)
    .loc[lambda frame: frame["seed"] <= MAX_SEED]
    # Replace a missing `st_type` with the literal string "None".
    # This is assigned back rather than done with
    # `evals_df["st_type"].fillna(..., inplace=True)`: that chained in-place call
    # acts on a temporary column object, triggers a FutureWarning in pandas 2.2
    # and leaves the frame unchanged under Copy-on-Write.
    .assign(st_type=lambda frame: frame["st_type"].fillna("None"))
)
evals_df

Unnamed: 0,dataset,name,model,st_type,seed,l_split,ul_split,test.acc
292,jannis,misty-elevator-887,mlp,curr,4,0.025,0.025,0.685
293,jannis,sunny-violet-886,mlp,curr,4,0.025,0.050,0.682
294,jannis,hearty-wood-885,mlp,curr,4,0.025,0.075,0.684
295,jannis,bright-universe-884,mlp,curr,4,0.025,0.100,0.679
296,jannis,misunderstood-butterfly-883,mlp,curr,4,0.050,0.025,0.650
...,...,...,...,...,...,...,...,...
3322,covertype,exalted-dawn-204,hgbt,curr,0,0.500,0.500,0.724
3323,covertype,feasible-butterfly-203,random-forest,curr,0,0.500,0.250,0.721
3324,covertype,drawn-thunder-202,random-forest,curr,0,0.500,0.500,0.717
3325,covertype,likely-eon-201,random-forest,curr,0,0.750,0.250,0.753


In [11]:
# Average the remaining value columns (`seed`, `test.acc`) over all runs that
# share the same dataset / model / st_type / split configuration.
mean_evals_df = (
    evals_df.drop(columns=["name"])
    .groupby(
        ["dataset", "model", "st_type", "l_split", "ul_split"],
        as_index=False,
        # Keep groups whose `st_type` is missing instead of silently dropping them.
        dropna=False,
    )
    .mean()
    # Label a missing `st_type` as the string "None". Assigned back via
    # `.assign` because `mean_evals_df["st_type"].fillna(..., inplace=True)` is a
    # chained in-place call that does not update the frame under Copy-on-Write
    # (and warns in pandas 2.2).
    .assign(st_type=lambda frame: frame["st_type"].fillna("None"))
)
mean_evals_df

Unnamed: 0,dataset,model,st_type,l_split,ul_split,seed,test.acc
0,covertype,hgbt,,0.025,0.00,2.0,0.5260
1,covertype,hgbt,,0.050,0.00,2.0,0.5670
2,covertype,hgbt,,0.075,0.00,2.0,0.5870
3,covertype,hgbt,,0.100,0.00,2.0,0.6060
4,covertype,hgbt,,0.250,0.00,2.0,0.6806
...,...,...,...,...,...,...,...
499,jannis,random-forest,curr,0.250,0.50,2.0,0.6916
500,jannis,random-forest,curr,0.250,0.75,2.0,0.6906
501,jannis,random-forest,curr,0.500,0.25,2.0,0.7174
502,jannis,random-forest,curr,0.500,0.50,2.0,0.7150


In [12]:
def _facet_value_only(annotation):
    """Keep only the text after the last "=" in a facet annotation label."""
    return annotation.update(text=annotation.text.split("=")[-1])


# One scatter figure per dataset: l_split on x, ul_split on y, colour encodes
# the mean test accuracy, one facet column per model (ordered as in MODELS).
for dataset in DATASETS:
    dataset_means = mean_evals_df.loc[mean_evals_df["dataset"] == dataset]
    fig = px.scatter(
        dataset_means,
        x="l_split",
        y="ul_split",
        color="test.acc",
        color_continuous_scale="bluered",
        facet_col="model",
        category_orders={"model": MODELS},
        hover_data=["st_type"],
        title=f"{dataset}",
        width=1200,
        height=400,
    )
    fig.for_each_annotation(_facet_value_only)
    fig.update_layout(font={"size": 18})
    # Save the figure as SVG in the plots directory, then render it inline.
    fig.write_image(plots_dir_fp / f"{dataset}_ul_split_vs_l_split.svg")
    fig.show()

In [15]:
# Per-run results restricted to three `l_split` values, ordered by `ul_split`.
l_split_mask = evals_df["l_split"].isin([0.025, 0.1, 0.25])
indiv_evals_df = evals_df.loc[l_split_mask].sort_values(by="ul_split")
indiv_evals_df

Unnamed: 0,dataset,name,model,st_type,seed,l_split,ul_split,test.acc
2944,covertype,devout-puddle-582,hgbt,,3,0.025,0.00,0.557
551,jannis,morning-pond-628,random-forest,,4,0.250,0.00,0.712
552,jannis,stellar-tree-627,hgbt,,4,0.250,0.00,0.709
1324,gas-drift-different-concentrations,vivid-capybara-689,random-forest,,4,0.250,0.00,0.919
787,jannis,usual-grass-392,hgbt,,2,0.250,0.00,0.734
...,...,...,...,...,...,...,...,...
2017,higgs,pious-mountain-700,mlp,curr,3,0.250,0.75,0.543
1990,higgs,ruby-planet-727,mlp,curr,3,0.100,0.75,0.497
1981,higgs,upbeat-feather-736,mlp,curr,3,0.025,0.75,0.491
2058,higgs,lyric-glitter-659,random-forest,curr,3,0.250,0.75,0.607


In [16]:
def _plot_test_acc_boxes(evals, title, out_fp, l_split_order):
    """Draw, save and show one box plot of test accuracy against `ul_split`.

    Args:
        evals: rows to plot; must provide the columns `ul_split`, `test.acc`,
            `l_split`, `model` and `st_type`.
        title: figure title.
        out_fp: path the figure image is written to.
        l_split_order: ordering of the `l_split` colour categories.

    Returns:
        The plotly figure, after it has been written to `out_fp` and shown.
    """
    fig = px.box(
        evals,
        x="ul_split",
        y="test.acc",
        color="l_split",
        facet_col="model",
        category_orders={"l_split": l_split_order, "model": MODELS},
        hover_data=["st_type"],
        points="all",  # overlay every individual run on top of the boxes
        title=title,
        width=1200,
        height=400,
    )
    # Keep only the text after the last "=" in each facet annotation label.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fig.update_layout(font=dict(size=18))
    fig.write_image(out_fp)
    fig.show()
    return fig


# The colour ordering is shared by all figures, so compute it once.
l_split_order = sorted(indiv_evals_df["l_split"].unique())

# Two figures per dataset: runs with ul_split <= 0.1 ("low") and runs with
# ul_split >= 0.25 ("high"). Values strictly between the two bounds appear in
# neither figure.
for dataset in DATASETS:
    dataset_evals = indiv_evals_df[indiv_evals_df["dataset"] == dataset]

    fig_0 = _plot_test_acc_boxes(
        dataset_evals[dataset_evals["ul_split"] <= 0.1],
        title=f"{dataset} (low UL-data)",
        out_fp=plots_dir_fp / f"{dataset}_test_acc_vs_ul_split_low.svg",
        l_split_order=l_split_order,
    )

    fig_1 = _plot_test_acc_boxes(
        dataset_evals[dataset_evals["ul_split"] >= 0.25],
        title=f"{dataset} (high UL-data)",
        out_fp=plots_dir_fp / f"{dataset}_test_acc_vs_ul_split_high.svg",
        l_split_order=l_split_order,
    )

In [58]:
# Per-group mean and standard deviation over the individual runs. The `seed`
# statistics are dropped, leaving the ("test.acc", "mean") and
# ("test.acc", "std") columns next to the grouping keys.
agg_evals_df = (
    indiv_evals_df.drop(columns=["name"])
    .groupby(
        ["dataset", "model", "st_type", "l_split", "ul_split"],
        as_index=False,
        dropna=False,
    )
    .agg(["mean", "std"])
    .drop(columns=["seed"])
    .reset_index()
)

# Split values that make up the printed 2x2 table: one row per `l_split`,
# one column per `ul_split`.
table_l_splits = [0.025, 0.25]
table_ul_splits = [0.0, 0.25]
agg_evals_df = agg_evals_df[
    agg_evals_df["l_split"].isin(table_l_splits)
    & agg_evals_df["ul_split"].isin(table_ul_splits)
]


def _latex_cell(mean, std):
    """Format a (mean, std) accuracy pair as a percentage in LaTeX math mode."""
    # "\\pm" keeps the emitted text "\pm" without the invalid "\p" escape
    # sequence, which newer Python versions warn about.
    return f"${mean*100:3.1f} \\pm {std*100:3.1f}$"


for dataset in DATASETS:
    for model in MODELS:
        df = agg_evals_df[
            (agg_evals_df["dataset"] == dataset) & (agg_evals_df["model"] == model)
        ]
        # [mean, std] pairs in l_split-major order. Only the first matching row
        # is used for each (l_split, ul_split) combination; a combination with
        # no matching row raises IndexError.
        accs = [
            df[(df["l_split"] == l_split) & (df["ul_split"] == ul_split)]
            .iloc[0]["test.acc"]
            .values.tolist()
            for l_split in table_l_splits
            for ul_split in table_ul_splits
        ]
        print(f"=== {dataset}/{model} ===")
        # Emit one LaTeX table row per l_split value.
        n_cols = len(table_ul_splits)
        for row_start in range(0, len(accs), n_cols):
            row = accs[row_start : row_start + n_cols]
            print(" & ".join(_latex_cell(mean, std) for mean, std in row) + " \\\\")

=== jannis/hgbt ===
$62.7 \pm 6.1$ & $62.6 \pm 4.1$ \\
$71.9 \pm 1.6$ & $71.0 \pm 2.1$ \\
=== jannis/mlp ===
$64.7 \pm 4.4$ & $63.2 \pm 7.6$ \\
$70.0 \pm 1.0$ & $70.7 \pm 0.9$ \\
=== jannis/random-forest ===
$66.1 \pm 3.9$ & $62.1 \pm 3.8$ \\
$71.6 \pm 1.6$ & $69.8 \pm 1.9$ \\
=== gas-drift-different-concentrations/hgbt ===
$65.3 \pm 3.7$ & $64.5 \pm 2.2$ \\
$94.7 \pm 0.4$ & $95.0 \pm 0.7$ \\
=== gas-drift-different-concentrations/mlp ===
$76.1 \pm 2.9$ & $66.0 \pm 7.0$ \\
$97.0 \pm 0.1$ & $97.2 \pm 0.4$ \\
=== gas-drift-different-concentrations/random-forest ===
$67.5 \pm 3.4$ & $64.4 \pm 5.3$ \\
$93.1 \pm 0.8$ & $92.9 \pm 1.3$ \\
=== higgs/hgbt ===
$54.8 \pm 2.1$ & $53.0 \pm 1.8$ \\
$65.6 \pm 1.7$ & $64.0 \pm 1.2$ \\
=== higgs/mlp ===
$52.1 \pm 2.7$ & $51.5 \pm 2.3$ \\
$54.9 \pm 2.9$ & $53.3 \pm 3.3$ \\
=== higgs/random-forest ===
$53.8 \pm 2.7$ & $51.5 \pm 3.2$ \\
$65.8 \pm 0.9$ & $62.5 \pm 1.4$ \\
=== covertype/hgbt ===
$52.6 \pm 2.4$ & $52.0 \pm 3.8$ \\
$68.1 \pm 0.8$ & $67.4 \pm 