In [2]:
import pandas as pd
import plotly.express as px
import wandb

from typing import List

In [10]:
# Weights & Biases entity and the per-dataset projects to pull runs from.
# Each project path is built as "<ENTITY>/<PREFIX><dataset>".
ENTITY = "wei2912"
PREFIX = "ethz-ssl-tabular_"
DATASETS = ["jannis", "gas-drift-different-concentrations", "higgs", "covertype"]

api = wandb.Api()

# Build one flat record per run, tagged with the dataset whose project it came from.
runs_datas = []
for dataset in DATASETS:
    runs = api.runs(f"{ENTITY}/{PREFIX}{dataset}")
    for run in runs:
        runs_datas.append(
            dict(
                dataset=dataset,
                summary=run.summary._json_dict,
                config=run.config,
                name=run.name,
                job_type=run.job_type,
            )
        )

# json_normalize flattens the nested summary/config dicts into dotted
# "summary.*" / "config.*" columns.
runs_df = pd.json_normalize(runs_datas)

In [14]:
# Print the column names of runs_df (the flattened "summary.*" / "config.*" keys) for inspection.
print(runs_df.columns.values)

['dataset' 'name' 'job_type' 'summary.run.pl_iter2.train.size'
 'summary.run.pl_iter2.train.policy.patience'
 'summary.run.pl_iter1.train.per_epoch.train_accs' 'summary._timestamp'
 'summary.run.pl_iter3.l_pl_acc'
 'summary.run.pl_iter1.train.policy.patience'
 'summary.run.pl_iter2.threshold' 'summary.run.pl_iter4.size_l_pl'
 'summary.run.initial.train.per_epoch.lrs'
 'summary.run.pl_iter0.train.per_epoch.lrs'
 'summary.run.initial.train.per_epoch.train_losses'
 'summary.run.pl_iter4.train.batch_size'
 'summary.run.initial.train.policy.factor'
 'summary.run.pl_iter0.train.policy.factor'
 'summary.run.pl_iter3.train.per_epoch.val_losses'
 'summary.run.pl_iter4.train.max_epochs'
 'summary.run.pl_iter1.train.per_epoch.train_losses'
 'summary.run.pl_iter1.pl_acc' 'summary.run.pl_iter4.val.acc'
 'summary.run.pl_iter1.train.size' 'summary.run.pl_iter0.train.max_epochs'
 'summary.run.pl_iter0.n_pl' 'summary.run.pl_iter3.threshold'
 'summary.run.pl_iter3.train.per_epoch.val_accs'
 'summary.run

In [11]:
# Display the normalized runs table (one row per fetched run).
runs_df

Unnamed: 0,dataset,name,job_type,summary.run.pl_iter2.train.size,summary.run.pl_iter2.train.policy.patience,summary.run.pl_iter1.train.per_epoch.train_accs,summary._timestamp,summary.run.pl_iter3.l_pl_acc,summary.run.pl_iter1.train.policy.patience,summary.run.pl_iter2.threshold,...,config.direction,config.layer_size,summary.max_depth,summary.min_samples_leaf,summary.trial_number,summary.value_0,summary.value_1,config.max_depth,config.min_samples_leaf,summary.lr
0,jannis,silver-plant-1388,eval,81.0,250.0,1.000000,1.691346e+09,0.922222,250.0,0.6,...,,,,,,,,,,
1,jannis,fearless-energy-1387,eval,81.0,250.0,1.000000,1.691345e+09,0.822222,250.0,0.6,...,,,,,,,,,,
2,jannis,lucky-vortex-1386,eval,111.0,250.0,0.911111,1.691345e+09,0.853846,250.0,0.6,...,,,,,,,,,,
3,jannis,fragrant-cloud-1385,eval,111.0,250.0,0.988889,1.691345e+09,0.784615,250.0,0.6,...,,,,,,,,,,
4,jannis,lyric-galaxy-1384,eval,141.0,250.0,0.927928,1.691345e+09,0.812865,250.0,0.6,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3419,covertype,trial/2/good-sun-5,sweep,,,,1.691012e+09,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,3.0,2.0,0.54,0.68,3.0,3.0,0.2
3420,covertype,trial/1/bright-butterfly-4,sweep,,,,1.691012e+09,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,1.0,1.0,0.51,0.61,3.0,1.0,
3421,covertype,trial/1/vital-durian-3,sweep,,,,1.691012e+09,,,,...,"[MAXIMIZE, MAXIMIZE]",,3.0,1.0,1.0,0.54,0.68,3.0,1.0,0.3
3422,covertype,trial/0/usual-pine-2,sweep,,,,1.691012e+09,,,,...,"[MAXIMIZE, MAXIMIZE]",,,,,,,4.0,1.0,


In [24]:
# Maps the flattened runs_df column names to the short names used in evals_df.
COLUMNS = {
    "dataset": "dataset",
    "name": "name",
    "config.args.model": "model",
    "config.args.st_type": "st_type",
    "config.split.l_split": "l_split",
    "config.split.ul_split": "ul_split",
    "summary.test.acc": "test.acc",
}

# Keep only runs whose job_type is "eval", restricted to the columns above,
# then apply the short names.
is_eval_run = runs_df["job_type"] == "eval"
evals_df = runs_df.loc[is_eval_run, list(COLUMNS)].rename(columns=COLUMNS)
evals_df

Unnamed: 0,dataset,name,model,st_type,l_split,ul_split,test.acc
0,jannis,silver-plant-1388,mlp,curr,0.025,0.025,0.668
1,jannis,fearless-energy-1387,mlp,curr,0.025,0.025,0.696
2,jannis,lucky-vortex-1386,mlp,curr,0.025,0.050,0.626
3,jannis,fragrant-cloud-1385,mlp,curr,0.025,0.050,0.693
4,jannis,lyric-galaxy-1384,mlp,curr,0.025,0.075,0.664
...,...,...,...,...,...,...,...
3322,covertype,exalted-dawn-204,hgbt,curr,0.500,0.500,0.724
3323,covertype,feasible-butterfly-203,random-forest,curr,0.500,0.250,0.721
3324,covertype,drawn-thunder-202,random-forest,curr,0.500,0.500,0.717
3325,covertype,likely-eon-201,random-forest,curr,0.750,0.250,0.753


In [70]:
# Average the remaining value column (test.acc) over all eval runs that share the
# same (dataset, model, st_type, l_split, ul_split) combination.
mean_evals_df = (
    evals_df.drop(columns=["name"])
    .groupby(
        ["dataset", "model", "st_type", "l_split", "ul_split"],
        as_index=False,
        # Keep groups whose key contains a missing value (e.g. no st_type);
        # groupby would otherwise drop those rows.
        dropna=False,
    )
    .mean()
    # Label missing st_type explicitly. The filled column is assigned back rather
    # than calling `df["st_type"].fillna(..., inplace=True)`: that chained in-place
    # call operates on a temporary Series under pandas Copy-on-Write and would
    # leave the frame unchanged.
    .assign(st_type=lambda df: df["st_type"].fillna("None"))
)

# One scatter figure per dataset: x = l_split, y = ul_split, colour = mean test.acc,
# one facet column per model; st_type is shown on hover.
for dataset in DATASETS:
    fig = px.scatter(
        mean_evals_df[mean_evals_df["dataset"] == dataset],
        x="l_split",
        y="ul_split",
        color="test.acc",
        facet_col="model",
        hover_data=["st_type"],
        title=f"{PREFIX}{dataset}",
        color_continuous_scale="bluered",
        width=1200,
        height=400,
    )
    # Shorten each facet annotation to the text after the last "="
    # (presumably "model=<value>" -> "<value>").
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fig.show()