# Imports

In [1]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

import sys

# Reload edited modules automatically before each cell executes, so that
# changes to local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Put the current working directory first on the module search path.
# NOTE(review): presumably this is so the local `analysis` module (imported in
# the next cell) is found — confirm it lives next to this notebook.
sys.path.insert(0, '.')

In [2]:
from analysis import sweep2df, plot_typography


In [3]:
USETEX: bool = True  # render figure text with LaTeX; passed to the tueplots bundle and plot_typography

In [4]:
# Apply the tueplots NeurIPS 2022 figure style; USETEX toggles LaTeX text rendering.
plt.rcParams.update(bundles.neurips2022(usetex=USETEX))
# If \mathbb / \boldsymbol are needed in labels, append the packages to the existing
# rcParams["text.latex.preamble"] *string*, e.g. r"\usepackage{amsfonts}\usepackage{amsmath}".
# (Current matplotlib requires a single string here, not a list.)

In [5]:
# Font sizes for small / medium / big text, set via the local `analysis.plot_typography` helper.
# NOTE(review): units presumably points — confirm in `analysis`.
plot_typography(
    usetex=USETEX,
    small=12,
    medium=16,
    big=20,
)

In [6]:
# Constants: W&B entity/project holding every sweep analysed in this notebook
PROJECT = "lti-ica"
ENTITY = "causal-representation-learning"

# W&B API client (timeout in seconds) and the collection of runs in the project
api = wandb.Api(timeout=200)
runs = api.runs("/".join((ENTITY, PROJECT)))

# Data loading

## Max variability

In [None]:
# Max-variability sweep.
SWEEP_ID = "6u3mgtpz"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "max_var_" + SWEEP_ID
# sweep2df (local `analysis` module) returns five values: the summary DataFrame
# aggregated below and, judging by their names, train/val log-likelihood and MCC
# histories. NOTE(review): this cell passes load=True while the other sweep cells
# pass load=False — presumably load=True reuses a saved copy of `filename`; confirm.
(
    df_max_var,
    train_log_likelihood_max_var,
    train_mcc_max_var,
    val_log_likelihood_max_var,
    val_mcc_max_var,
) = sweep2df(sweep.runs, filename, save=True, load=True)

In [16]:
# Distinct `num_comp` values of the max-variability sweep, in order of first appearance.
num_comps = df_max_var["num_comp"].unique()

In [31]:
# Mean of the MCC metrics for each number of components.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_max_var.groupby("num_comp")[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].mean()

Unnamed: 0_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,0.967982,0.986363,0.863717,0.872004
3,0.999989,0.999995,0.997175,0.999113
5,0.995257,0.997355,0.824089,0.826335


In [32]:
# Standard deviation (pandas default ddof=1) of the MCC metrics for each number of
# components; a group containing a single run yields NaN.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_max_var.groupby("num_comp")[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].std()

Unnamed: 0_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,0.054551,0.017632,0.098386,0.106228
3,1.7e-05,9e-06,0.001646,0.000368
5,0.008955,0.004699,0.233161,0.235198


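### Rerun for 8 and 10 dimensions

**Note:** in the saved output below, the summaries contain only runs with `num_comp = 8`; no 10-dimensional runs appear for this sweep.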

In [31]:
# Max-variability rerun for 8 and 10 dimensions.
SWEEP_ID = "f2n0z65l"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "max_var_8_10_" + SWEEP_ID
# sweep2df (local `analysis` module) returns the summary DataFrame followed by
# four objects named after train/val log-likelihood and MCC.
(
    df_max_var_8_10,
    train_log_likelihood_max_var_8_10,
    train_mcc_max_var_8_10,
    val_log_likelihood_max_var_8_10,
    val_mcc_max_var_8_10,
) = sweep2df(sweep.runs, filename, save=True, load=False)

In [32]:
# Mean of the MCC metrics for each number of components (8/10-dimension rerun).
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_max_var_8_10.groupby("num_comp")[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].mean()

Unnamed: 0_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
8,0.976543,0.984138,0.435802,0.435802


In [33]:
# Standard deviation (pandas default ddof=1) of the MCC metrics for each number of
# components (8/10-dimension rerun); a group containing a single run yields NaN.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_max_var_8_10.groupby("num_comp")[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].std()

Unnamed: 0_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
8,0.010628,0.008797,0.102631,0.102631


## Minimal segments

### Original sweep

In [21]:
# Minimal-segments sweep (an earlier sweep id for this experiment was "03w02539").
SWEEP_ID = "shrjtedq"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "min_segment_" + SWEEP_ID
# sweep2df (local `analysis` module) returns the summary DataFrame followed by
# four objects named after train/val log-likelihood and MCC.
(
    df_min_segment,
    train_log_likelihood_min_segment,
    train_mcc_min_segment,
    val_log_likelihood_min_segment,
    val_mcc_min_segment,
) = sweep2df(sweep.runs, filename, save=True, load=False)

In [22]:
# Mean of the MCC metrics per (num_comp, use_B, use_C) configuration.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_min_segment.groupby(["num_comp", "use_B", "use_C"])[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,use_B,use_C,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2,False,False,0.74961,0.887405,0.601228,0.643706
2,False,True,0.720776,0.825444,0.595818,0.621123
2,True,False,0.732003,0.835528,0.605064,0.630943
2,True,True,0.711214,0.874172,0.664579,0.69112
3,False,False,0.780123,0.86793,0.687267,0.729727
3,False,True,0.764321,0.859162,0.682664,0.692265
3,True,False,0.774278,0.868033,0.705219,0.734749
3,True,True,0.764163,0.856587,0.728543,0.744332
5,False,False,0.774317,0.846832,0.702643,0.719298
5,False,True,0.773904,0.835337,0.693275,0.707001


In [23]:
# Standard deviation (pandas default ddof=1) of the MCC metrics per
# (num_comp, use_B, use_C) configuration; a group containing a single run yields NaN.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_min_segment.groupby(["num_comp", "use_B", "use_C"])[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].std()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,use_B,use_C,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2,False,False,0.143044,0.054397,0.105367,0.090867
2,False,True,0.145781,0.078683,0.155281,0.159072
2,True,False,0.165058,0.074605,0.169585,0.169743
2,True,True,0.098364,0.041601,0.113889,0.109193
3,False,False,0.18807,0.109065,0.138182,0.106827
3,False,True,0.226299,0.125249,0.156497,0.153223
3,True,False,0.169024,0.112382,0.127367,0.123624
3,True,True,0.132415,0.081142,0.085279,0.084514
5,False,False,0.15066,0.105896,0.113396,0.105216
5,False,True,0.172979,0.126659,0.129851,0.116479


### Rerun for 8 and 10 dimensions

In [34]:
# Minimal-segments rerun for 8 and 10 dimensions.
SWEEP_ID = "dvn24tw0"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "min_segment_8_10_" + SWEEP_ID
# sweep2df (local `analysis` module) returns the summary DataFrame followed by
# four objects named after train/val log-likelihood and MCC. In the saved output it
# reported one faulty run for this sweep.
(
    df_min_segment_8_10,
    train_log_likelihood_min_segment_8_10,
    train_mcc_min_segment_8_10,
    val_log_likelihood_min_segment_8_10,
    val_mcc_min_segment_8_10,
) = sweep2df(sweep.runs, filename, save=True, load=False)

Encountered a faulty run with ID upbeat-sweep-15


In [35]:
# Mean of the MCC metrics per (num_comp, use_B, use_C) configuration (8/10-dimension rerun).
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_min_segment_8_10.groupby(["num_comp", "use_B", "use_C"])[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,use_B,use_C,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
8,False,False,0.763699,0.82455,0.691167,0.726957
8,False,True,0.683156,0.76277,0.614221,0.644257
8,True,False,0.821271,0.851866,0.714802,0.727268
10,False,False,0.731828,0.763287,0.660218,0.662591
10,False,True,0.762728,0.781123,0.64251,0.658213
10,True,False,0.796362,0.818015,0.66241,0.673324


In [36]:
# Standard deviation (pandas default ddof=1) of the MCC metrics per
# (num_comp, use_B, use_C) configuration (8/10-dimension rerun).
# A group containing a single run yields NaN. This is the case for (10, False, False)
# in the saved output, presumably because the faulty run reported while loading was dropped.
# Select the metric columns *before* aggregating: reducing the whole frame first is
# wasted work and, since pandas 2.0 (numeric_only=False by default), raises
# TypeError if the frame contains any non-numeric column.
df_min_segment_8_10.groupby(["num_comp", "use_B", "use_C"])[["train_mcc", "max_train_mcc", "val_mcc", "max_val_mcc"]].std()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,train_mcc,max_train_mcc,val_mcc,max_val_mcc
num_comp,use_B,use_C,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
8,False,False,0.214612,0.154562,0.163221,0.125428
8,False,True,0.259031,0.19009,0.184958,0.160647
8,True,False,0.204077,0.174552,0.17199,0.153041
10,False,False,,,,
10,False,True,0.177957,0.191293,0.177224,0.159062
10,True,False,0.178736,0.177976,0.157806,0.144909
