# Imports

In [64]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd


import sys

%load_ext autoreload
%autoreload 2

sys.path.insert(0, '.')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [91]:
from analysis import sweep2df, plot_typography, stats2string


In [3]:
# Single switch forwarded as `usetex=` to both the tueplots style bundle and
# the project's plot_typography helper in the cells below.
USETEX = True

In [4]:
# Merge the tueplots NeurIPS 2022 bundle into matplotlib's global rcParams,
# forwarding the USETEX flag to the bundle.
plt.rcParams.update(bundles.neurips2022(usetex=USETEX))
# Disabled below: extra LaTeX preamble packages (amsfonts for mathbb, amsmath
# for boldsymbol).
# NOTE(review): the preamble is given as a list here; recent matplotlib
# releases expect a single string — confirm before re-enabling.
# plt.rcParams.update({
#     'text.latex.preamble': [r'\usepackage{amsfonts}', # mathbb
#                             r'\usepackage{amsmath}'] # boldsymbol
# })

In [5]:
# Apply the project's typography settings via the helper imported from
# `analysis`, reusing the same USETEX flag as the style bundle.
# NOTE(review): small/medium/big are presumably font sizes — confirm against
# analysis.plot_typography.
plot_typography(usetex=USETEX, small=12, medium=16, big=20)

In [6]:
# Weights & Biases coordinates shared by every sweep lookup in this notebook.
ENTITY = "causal-representation-learning"
PROJECT = "lti-ica"

# API client (200 s timeout) and a handle on all runs of the project.
api = wandb.Api(timeout=200)
runs = api.runs(f"{ENTITY}/{PROJECT}")

# Data loading

## Max variability

In [7]:
# Fetch the max-variability sweep and unpack the five objects returned by
# sweep2df (summary frame plus train/val log-likelihood and MCC histories).
# NOTE(review): this is the only sweep cell calling sweep2df with load=True;
# the others pass load=False — confirm the difference is intended.
SWEEP_ID = "6u3mgtpz"
filename = f"max_var_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
(
    df_max_var,
    train_log_likelihood_max_var,
    train_mcc_max_var,
    val_log_likelihood_max_var,
    val_mcc_max_var,
) = sweep2df(sweep.runs, filename, save=True, load=True)

### Max variability 10 dimensions

In [11]:
# Fetch the 10-dimensional max-variability sweep and unpack the five objects
# returned by sweep2df into *_max_var_10 variables.
SWEEP_ID = "woiubqya"
filename = f"max_var_10_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
(
    df_max_var_10,
    train_log_likelihood_max_var_10,
    train_mcc_max_var_10,
    val_log_likelihood_max_var_10,
    val_mcc_max_var_10,
) = sweep2df(sweep.runs, filename, save=True, load=False)

### Rerun for 8 dimensions

In [14]:
# Fetch the 8-dimensional max-variability rerun and unpack the five objects
# returned by sweep2df into *_max_var_8 variables.
SWEEP_ID = "f2n0z65l"
filename = f"max_var_8_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
(
    df_max_var_8,
    train_log_likelihood_max_var_8,
    train_mcc_max_var_8,
    val_log_likelihood_max_var_8,
    val_mcc_max_var_8,
) = sweep2df(sweep.runs, filename, save=True, load=False)

### Concatenate

In [146]:
# Stack the three max-variability summary frames row-wise; each frame keeps
# its own index labels.
df_max_var_concat = pd.concat(
    objs=(df_max_var, df_max_var_8, df_max_var_10),
    axis=0,
)

In [147]:
# Mean MCC metrics per (num_comp, zero_means, use_B, use_C) configuration.
# The metric columns are selected BEFORE aggregating: averaging the whole
# frame first raises TypeError on non-numeric columns under pandas >= 2.0
# and computes means that are immediately thrown away.
df_max_var_concat.groupby(["num_comp", "zero_means", "use_B", "use_C"])[
    ["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]
].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,zero_means,use_B,use_C,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2,False,True,True,0.967982,0.987145,0.863717,0.872004
3,False,True,True,0.999989,0.999995,0.997175,0.999113
5,False,True,True,0.995257,0.997676,0.824089,0.826335
8,False,True,True,0.976543,0.983045,0.435802,0.435802
10,False,True,True,0.995729,0.996371,0.732558,0.734933


In [148]:
# Map each component count (ascending) to the rows with that num_comp,
# grouped by the full (num_comp, zero_means, use_B, use_C) configuration.
df_max_var_concat_dict = {
    n_comp: df_max_var_concat.loc[df_max_var_concat["num_comp"] == n_comp].groupby(
        ["num_comp", "zero_means", "use_B", "use_C"]
    )
    for n_comp in sorted(df_max_var_concat["num_comp"].unique())
}

In [149]:
# One formatted stats string per component count, in the dict's key order.
max_var_stats = [stats2string(grouped) for grouped in df_max_var_concat_dict.values()]

## Minimal segments

### Original sweep

In [17]:
# Fetch the minimal-segment sweep and unpack the five objects returned by
# sweep2df into *_min_segment variables.
SWEEP_ID = "shrjtedq"  # "03w02539"
filename = f"min_segment_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
(
    df_min_segment,
    train_log_likelihood_min_segment,
    train_mcc_min_segment,
    val_log_likelihood_min_segment,
    val_mcc_min_segment,
) = sweep2df(sweep.runs, filename, save=True, load=False)

### Rerun for 8 and 10 dimensions

In [75]:
# Fetch the 8- and 10-dimensional minimal-segment rerun and unpack the five
# objects returned by sweep2df into *_min_segment_8_10 variables.
SWEEP_ID = "dvn24tw0"
filename = f"min_segment_8_10_{SWEEP_ID}"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
(
    df_min_segment_8_10,
    train_log_likelihood_min_segment_8_10,
    train_mcc_min_segment_8_10,
    val_log_likelihood_min_segment_8_10,
    val_mcc_min_segment_8_10,
) = sweep2df(sweep.runs, filename, save=True, load=False)

Encountered a faulty run with ID upbeat-sweep-15


### Concatenate

In [141]:
# Stack the original minimal-segment frame and the 8/10-dimensional rerun
# row-wise; each frame keeps its own index labels.
df_min_segment_concat = pd.concat(
    objs=(df_min_segment, df_min_segment_8_10),
    axis=0,
)

In [142]:
# Mean MCC metrics per (num_comp, zero_means, use_B, use_C) configuration.
# The metric columns are selected BEFORE aggregating: averaging the whole
# frame first raises TypeError on non-numeric columns under pandas >= 2.0
# and computes means that are immediately thrown away.
df_min_segment_concat.groupby(["num_comp", "zero_means", "use_B", "use_C"])[
    ["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]
].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,zero_means,use_B,use_C,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2,False,False,False,0.866319,0.911332,0.58251,0.619459
2,False,False,True,0.766681,0.820678,0.511342,0.524911
2,False,True,False,0.729862,0.826553,0.516959,0.553255
2,False,True,True,0.697367,0.888114,0.596271,0.629053
2,True,False,False,0.632901,0.847892,0.619947,0.667954
2,True,False,True,0.674872,0.830931,0.680293,0.717334
2,True,True,False,0.734144,0.842189,0.69317,0.708631
2,True,True,True,0.725062,0.875474,0.732886,0.753187
3,False,False,False,0.900751,0.941721,0.741451,0.747056
3,False,False,True,0.910463,0.941608,0.748249,0.760785


In [143]:
# Map each component count (ascending) to the rows with that num_comp,
# grouped by the full (num_comp, zero_means, use_B, use_C) configuration.
df_min_segment_concat_dict = {
    n_comp: df_min_segment_concat.loc[df_min_segment_concat["num_comp"] == n_comp].groupby(
        ["num_comp", "zero_means", "use_B", "use_C"]
    )
    for n_comp in sorted(df_min_segment_concat["num_comp"].unique())
}

In [144]:
# One formatted stats string per component count, in the dict's key order.
min_segment_stats = [stats2string(grouped) for grouped in df_min_segment_concat_dict.values()]

# Render LaTeX text for the results table

In [154]:
# Build one text chunk per component count: a "-----<comp>" header line
# followed by the minimal-segment cells and then the max-variability cells.
# NOTE(review): the two stats lists are paired purely by position with the
# sorted num_comp values of the minimal-segment frame; this is only correct if
# both experiments cover the same component counts — confirm.
all_stats = [
    f"-----{comp!s}\n{min_segment}{max_var}\n"
    for comp, min_segment, max_var in zip(
        sorted(df_min_segment_concat.num_comp.unique()),
        min_segment_stats,
        max_var_stats,
    )
]

In [155]:
# Emit the chunks one per line so they can be pasted into the LaTeX table.
print(*all_stats, sep="\n")

-----2
$0.866\scriptscriptstyle\pm 0.033$ & $0.767\scriptscriptstyle\pm 0.158$ & $0.730\scriptscriptstyle\pm 0.191$ & $0.697\scriptscriptstyle\pm 0.090$ & $0.633\scriptscriptstyle\pm 0.104$ & $0.675\scriptscriptstyle\pm 0.132$ & $0.734\scriptscriptstyle\pm 0.158$ & $0.725\scriptscriptstyle\pm 0.115$ & $0.968\scriptscriptstyle\pm 0.055$ & 

-----3
$0.901\scriptscriptstyle\pm 0.054$ & $0.910\scriptscriptstyle\pm 0.061$ & $0.916\scriptscriptstyle\pm 0.044$ & $0.861\scriptscriptstyle\pm 0.062$ & $0.659\scriptscriptstyle\pm 0.201$ & $0.618\scriptscriptstyle\pm 0.241$ & $0.633\scriptscriptstyle\pm 0.110$ & $0.667\scriptscriptstyle\pm 0.109$ & $1.000\scriptscriptstyle\pm 0.000$ & 

-----5
$0.892\scriptscriptstyle\pm 0.057$ & $0.929\scriptscriptstyle\pm 0.025$ & $0.928\scriptscriptstyle\pm 0.026$ & $0.911\scriptscriptstyle\pm 0.045$ & $0.657\scriptscriptstyle\pm 0.116$ & $0.618\scriptscriptstyle\pm 0.079$ & $0.620\scriptscriptstyle\pm 0.025$ & $0.539\scriptscriptstyle\pm 0.078$ & $0.995\script