In [1]:
# imports
from tueplots import bundles
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from tueplots  import figsizes

import sys

%load_ext autoreload
%autoreload 2

sys.path.insert(0, '.')

In [15]:
from analysis import sweep2df, format_violin, RED, BLUE
from care_nl_ica.models.sinkhorn import learn_permutation

In [3]:
# Start from the tueplots NeurIPS 2022 style bundle with usetex switched on.
plt.rcParams.update(bundles.neurips2022(usetex=True))
# `text.latex.preamble` has to be ONE string: passing a list has been
# deprecated since Matplotlib 3.3 and is rejected by later releases.
plt.rcParams.update({
    'text.latex.preamble': '\n'.join([
        r'\usepackage{amsfonts}',  # mathbb
        r'\usepackage{amsmath}',   # boldsymbol
    ])
})

  self[key] = other[key]


In [273]:
# Weights & Biases coordinates shared by every sweep cell below
ENTITY = "causal-representation-learning"
PROJECT = "nl-causal-representations"

# API client (constructed with timeout=200) and a handle on the project's runs
api = wandb.Api(timeout=200)
runs = api.runs(f"{ENTITY}/{PROJECT}")

# Data loading

In [33]:

SWEEP_ID = "vfv1je0d"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
# Peek at the "dep_mat" column of the "dep_mat_table" entry stored in the first
# logged artifact of the run at index 59 of this sweep.
(
    sweep.runs[59]
    .logged_artifacts()[0]
    .get("dep_mat_table")
    .get_column("dep_mat", "numpy")
)  # .use_artifact("dep_mat_table")
# Alternative access path kept from the original cell:
# artifact = run.use_artifact('causal-representation-learning/experiment/run-iczvn8od-dep_mat_table:v0', type='run_table')
# artifact.get("dep_mat_table")

array([[ 1.93882871,  1.40007854,  1.50939941, -1.52444494,  0.01124954,
         0.00467368,  1.57117176,  1.79014182, -0.01026418]])

In [11]:
# Load two SEM sweeps through `sweep2df` (called with save=True, load=True) and
# merge their results into one frame plus three combined sequences.
SWEEP_ID = "gz1yqsuk"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
filename = f"sem_sweep_{SWEEP_ID}"
print(f"Loading sweep with {SWEEP_ID=}")
df1, (true_unmix_jacobians1, est_unmix_jacobians1, permute_indices1) = sweep2df(sweep.runs, filename, save=True, load=True)

SWEEP_ID = "5gzpzb23"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
filename = f"sem_sweep_{SWEEP_ID}"
print(f"Loading sweep with {SWEEP_ID=}")
df2, (true_unmix_jacobians2, est_unmix_jacobians2, permute_indices2) = sweep2df(sweep.runs, filename, save=True, load=True)

# `DataFrame.append` was deprecated in pandas 1.4 and removed in 2.0;
# `pd.concat` with default arguments keeps the same rows and index labels.
df = pd.concat([df1, df2])
# NOTE(review): `+` is presumably list concatenation here (sweep2df appears to
# return per-run lists) — confirm against `analysis.sweep2df`.
true_unmix_jacobians = true_unmix_jacobians1 + true_unmix_jacobians2
est_unmix_jacobians = est_unmix_jacobians1 + est_unmix_jacobians2
permute_indices = permute_indices1 + permute_indices2

Loading sweep with SWEEP_ID='gz1yqsuk'
Loading sweep with SWEEP_ID='5gzpzb23'


## 3D SEM

In [99]:

# 3D SEM sweep; the trailing comment keeps a second sweep id noted in the original cell.
SWEEP_ID = "vfv1je0d"  # "nz4r5d8a"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "sem_3d_sweep_" + SWEEP_ID
(
    df3d,
    (true_unmix_jacobians3d, est_unmix_jacobians3d, permute_indices3d),
) = sweep2df(sweep.runs, filename, save=True, load=False)

## 5D SEM

In [268]:
# 5D SEM sweep; the trailing comment keeps a second sweep id noted in the original cell.
SWEEP_ID = "h6y1gkvo"  # "3ldi48id"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "sem_5d_sweep_" + SWEEP_ID
(
    df5d,
    (true_unmix_jacobians5d, est_unmix_jacobians5d, permute_indices5d),
) = sweep2df(sweep.runs, filename, save=True, load=False)

## 8D SEM

In [263]:
# 8D SEM sweep.
SWEEP_ID = "7sscc3w1"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "sem_8d_sweep_" + SWEEP_ID
(
    df8d,
    (true_unmix_jacobians8d, est_unmix_jacobians8d, permute_indices8d),
) = sweep2df(sweep.runs, filename, save=True, load=False)

Encountered a faulty run with ID apricot-sweep-49
Encountered a faulty run with ID apricot-sweep-49


## 10D SEM

## MLP from Monti et al.

In [274]:
# Sweep for the MLP setup from Monti et al.; the trailing comments keep other
# sweep ids noted in the original cell.
SWEEP_ID = "77huh2ue"  # "q99ne3vj" #"fhaza97x"
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
filename = "monti_sweep_" + SWEEP_ID
(
    df_monti,
    (true_unmix_jacobians_monti, est_unmix_jacobians_monti, permute_indices_monti),
) = sweep2df(sweep.runs, filename, save=True, load=False)

In [233]:
# Display the summary frame of the Monti et al. sweep (one row per run).
df_monti

Unnamed: 0,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss
0,charmed-sweep-25,6,False,3,5,False,False,False,False,0.837911,1.854845
1,clean-sweep-22,6,False,3,5,False,False,False,False,0.829384,1.989501
2,sweepy-sweep-20,6,False,3,4,False,False,False,False,0.838729,1.86473
3,trim-sweep-13,6,False,3,3,False,False,False,False,0.830675,1.848964
4,lucky-sweep-14,6,False,3,3,False,False,False,False,0.999311,1.036702
5,zany-sweep-12,6,False,3,3,False,False,False,False,0.987399,1.1389
6,eager-sweep-9,6,False,3,2,False,False,False,False,0.999632,1.02977
7,wild-sweep-10,6,False,3,2,False,False,False,False,0.999327,1.048075
8,vocal-sweep-11,6,False,3,3,False,False,False,False,0.998918,1.047589
9,zany-sweep-4,6,False,3,1,False,False,False,False,0.999831,1.028984


# Pre-processing

In [277]:
def learning_stats(
    df,
    true_unmix_jacobians,
    est_unmix_jacobians,
    permute_indices,
    hamming_threshold=1e-2,
    selector_col="nonlin_sem",
    weight_threshold=None,
    dag_permute=True,
    num_steps=5000,
    lr=1e-4,
    verbose=True,
):
    """Run `learn_permutation` per run and print grouped summary statistics.

    Runs are grouped by every combination of a unique value in ``df.dim`` and a
    unique value in ``df[selector_col]``. For each group, `learn_permutation`
    is called on every run whose ground-truth Jacobian has ``shape[0] == dim``
    and whose selector value matches. The three values it returns are collected
    and their means are printed as ``Acc(order)``, ``SHD`` and ``Acc``, together
    with the mean and standard deviation of ``df.mcc`` for the group.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``dim``, ``mcc`` and ``selector_col``.
    true_unmix_jacobians, est_unmix_jacobians, permute_indices : sequence
        Per-run inputs, zipped together with ``df[selector_col]``.
        NOTE(review): this assumes they are in the same order as the rows of
        ``df`` — confirm against the code that produces them.
    hamming_threshold : float
        Forwarded to `learn_permutation` as ``hamming_threshold``.
    selector_col : str
        Name of the column used as the second grouping key.
    weight_threshold : float or None
        Forwarded to `learn_permutation` as ``threshold``.
    dag_permute : bool
        Forwarded to `learn_permutation` as ``dag_permute``.
    num_steps : int
        Forwarded to `learn_permutation`; the default is the previously
        hard-coded value.
    lr : float
        Forwarded to `learn_permutation`; the default is the previously
        hard-coded value.
    verbose : bool
        Forwarded to `learn_permutation`; the default is the previously
        hard-coded value.

    Returns
    -------
    None
        The statistics are only printed.
    """
    rule = "----------------------------------"

    for dim in df.dim.unique():
        for selector in df[selector_col].unique():
            success = []
            hamming = []
            accuracy = []
            for (selector_item, j_gt, j_est, permute) in zip(df[selector_col], true_unmix_jacobians, est_unmix_jacobians, permute_indices):
                if j_gt.shape[0] == dim and selector_item == selector:
                    # The `triu_weigth` spelling is kept from the original
                    # call; presumably it matches the callee's keyword name.
                    s, h, a = learn_permutation(
                        j_gt,
                        j_est,
                        permute,
                        triu_weigth=20.,
                        tril_weight=10.,
                        diag_weight=6.,
                        num_steps=num_steps,
                        lr=lr,
                        verbose=verbose,
                        drop_smallest=True,
                        threshold=weight_threshold,
                        binary=True,
                        hamming_threshold=hamming_threshold,
                        dag_permute=dag_permute,
                    )

                    success.append(s)
                    hamming.append(h)
                    accuracy.append(a)

            # MCC is taken from the frame, filtered by the same two keys.
            mcc = df.mcc[(df.dim == dim) & (df[selector_col] == selector)]
            print(rule)
            print(rule)
            if len(success) > 0:
                print(f"{dim=} ({selector_col}={selector})\tMCC={mcc.mean():.3f}+{mcc.std():.3f}\tAcc(order):{np.array(success).mean():.3f}\t  Acc:{np.array(accuracy).mean():.3f}\tSHD:{np.array(hamming).mean():.6f}\t[{len(success)} items]")
            else:
                print(f"No experiments for {dim=} ({selector_col}={selector})")
            print(rule)
            print(rule)

## 3D SEM

In [230]:
# Evaluate the 3D SEM sweep with the function's default settings.
learning_stats(
    df3d,
    true_unmix_jacobians3d,
    est_unmix_jacobians3d,
    permute_indices3d,
)

tensor([[ 1.2796e-03,  1.1690e+00,  7.3501e-03],
        [-1.7509e+00, -1.1760e+00, -1.0902e+00],
        [ 1.2264e+00,  1.3568e+00, -4.6644e-03]])
Correct order identified
tensor([[-1.7941, -1.3232, -0.0021],
        [-0.0187,  1.9023, -0.0063],
        [-1.7499, -1.3033, -1.1856]])
Correct order identified
tensor([[1.4663, 1.3997, 0.0030],
        [1.1139, 1.6406, 1.2161],
        [1.8614, 0.0077, 0.0053]])
Correct order identified
tensor([[-0.0078,  0.0041,  1.2849],
        [-1.7284, -1.0818, -1.5966],
        [ 0.0085, -1.8809, -1.4370]])
Correct order identified
tensor([[ 1.5253,  1.3723,  1.9197],
        [-0.0031, -0.0029, -1.5467],
        [ 0.0139,  1.8123,  1.6050]])
Correct order identified
tensor([[-0.0050, -1.1637,  0.0052],
        [ 0.0072, -1.3557, -1.2176],
        [-1.0853, -1.1673, -1.7357]])
Correct order identified
tensor([[ 0.0059,  1.8653,  0.0109],
        [-1.2188, -1.1149, -1.6398],
        [-0.0069, -1.4626, -1.3991]])
Correct order identified
tensor([[ 1.73

## 5D SEM

In [269]:
# Evaluate the 5D SEM sweep; the fifth positional argument is `hamming_threshold`.
learning_stats(
    df5d,
    true_unmix_jacobians5d,
    est_unmix_jacobians5d,
    permute_indices5d,
    hamming_threshold=1e-3,
)

Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
----------------------------------
true_jac=tensor([[1.6245, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.8148, 2.5022, 0.0000, 0.0000, 0.0000],
        [1.2530, 3.0065, 2.5013, 0.0000, 0.0000],
        [1.2919, 2.0411, 2.0376, 2.5548, 0.0000],
        [1.9508, 2.8156, 2.4423, 2.2518, 3.0212]])
est_jac=tensor([[0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 0.]])
tensor([[5.0000e-01, 5.0000e-01, 1.6908e-07, 9.5173e-09, 1.0258e-08],
        [5.0000e-01, 5.0000e-01, 2.2775e-07, 4.6060e-09, 5.3467e-09],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.7450e-07, 1.7112e-07],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00]])
S_DAG=tensor([[5.000

## 8D SEM

In [271]:
# Evaluate the 8D SEM sweep; the fifth positional argument is `hamming_threshold`.
learning_stats(
    df8d,
    true_unmix_jacobians8d,
    est_unmix_jacobians8d,
    permute_indices8d,
    hamming_threshold=1e-3,
)

----------------------------------
true_jac=tensor([[1.0850, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.9611, 3.5985, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.4147, 4.0179, 2.0769, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.7145, 4.0698, 2.3540, 3.0299, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.8764, 3.3702, 2.4392, 3.5258, 2.2418, 0.0000, 0.0000, 0.0000],
        [1.6313, 2.3522, 1.3941, 2.1047, 1.8677, 2.3872, 0.0000, 0.0000],
        [1.1212, 2.8969, 2.0693, 3.0300, 3.6201, 2.0695, 1.2522, 0.0000],
        [1.7588, 3.3396, 1.7038, 2.6641, 2.6519, 1.6928, 2.2463, 2.7388]])
est_jac=tensor([[0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 1., 1., 1., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 1., 1., 1., 1.]])
tensor([[3.7194e-0

## 10D SEM

In [146]:
# Show the first estimated unmixing Jacobian from the Monti et al. sweep.
est_unmix_jacobians_monti[0]

array([[  0.26582664,   0.        ,   0.        ,   0.        ,
          0.        ,   0.        ],
       [ -1.42705572,   0.16583258,   0.        ,   0.        ,
          0.        ,   0.        ],
       [  4.06942463,  -0.93084276,   0.18074501,   0.        ,
          0.        ,   0.        ],
       [-11.65866947,   3.33683395,  -1.02096283,   0.08979765,
          0.        ,   0.        ],
       [ 20.40413284,  -7.05351353,   2.93062115,  -0.7304942 ,
          0.2068955 ,   0.        ],
       [-30.47312927,  11.30219364,  -5.26044369,   1.71105742,
         -0.73918879,   0.17236085]])

## MLP from Monti et al.

In [279]:
# Evaluate the Monti et al. sweep, grouped by the number of mixing layers.
learning_stats(
    df_monti,
    true_unmix_jacobians_monti,
    est_unmix_jacobians_monti,
    permute_indices_monti,
    selector_col="n_mixing_layer",
    weight_threshold=None,
    hamming_threshold=1e-3,
    dag_permute=False,
)

----------------------------------
true_jac=tensor([[ 0.0025,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0128,  0.0021,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0314, -0.0108,  0.0021,  0.0000,  0.0000,  0.0000],
        [-0.0772,  0.0332, -0.0109,  0.0016,  0.0000,  0.0000],
        [ 0.1414, -0.0690,  0.0295, -0.0095,  0.0022,  0.0000],
        [-0.2077,  0.1093, -0.0530,  0.0234, -0.0095,  0.0021]])
est_jac=tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 0., 0.]])
tensor([[5.0000e-01, 5.0000e-01, 5.0000e-01, 4.0059e-15, 5.8765e-18, 0.0000e+00],
        [5.4537e-10, 1.0000e+00, 5.5369e-10, 5.5369e-10, 5.8286e-15, 0.0000e+00],
        [5.0000e-01, 5.0000e-01, 5.0000e-01, 4.2770e-07, 1.5501e-09, 0.0000e+00],
        [9.9999e-01, 1.5077e+00, 1.5077e+00, 1.5077e+00, 4.7045e-14, 0.0000e+00],
        [1.2109e+00, 1.2109e+0

In [261]:
# Manual evaluation restricted to the Monti runs with n_mixing_layer == 5,
# using a shorter optimisation (num_steps=1500) and hamming_threshold=1e-4.
success, hamming, acc = [], [], []
for cond, j_gt, j_est, permute in zip(
    df_monti["n_mixing_layer"],
    true_unmix_jacobians_monti,
    est_unmix_jacobians_monti,
    permute_indices_monti,
):
    if cond != 5:
        continue
    s, h, a = learn_permutation(
        j_gt,
        j_est,
        permute,
        triu_weigth=20.,
        tril_weight=10.,
        diag_weight=6.,
        num_steps=1500,
        lr=1e-4,
        verbose=True,
        drop_smallest=True,
        threshold=None,
        binary=True,
        dag_permute=False,
        hamming_threshold=1e-4,
    )
    success.append(s)
    hamming.append(h)
    acc.append(a)

# Summary line framed by two separator lines above and below.
print("----------------------------------", "----------------------------------", sep="\n")
print(f"Acc (order):{np.array(success).mean()}\t SHD:{np.array(hamming).mean()}\t Acc:{np.array(acc).mean()}\t[{len(success)} items]")
print("----------------------------------", "----------------------------------", sep="\n")

----------------------------------
true_jac=tensor([[ 0.0025,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0128,  0.0021,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0314, -0.0108,  0.0021,  0.0000,  0.0000,  0.0000],
        [-0.0772,  0.0332, -0.0109,  0.0016,  0.0000,  0.0000],
        [ 0.1414, -0.0690,  0.0295, -0.0095,  0.0022,  0.0000],
        [-0.2077,  0.1093, -0.0530,  0.0234, -0.0095,  0.0021]])
est_jac=tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 0., 0.]])
tensor([[5.0000e-01, 5.0000e-01, 5.0000e-01, 3.9847e-15, 5.8989e-18, 0.0000e+00],
        [5.4260e-10, 1.0000e+00, 5.5087e-10, 5.5087e-10, 5.8508e-15, 0.0000e+00],
        [5.0000e-01, 5.0000e-01, 5.0000e-01, 3.3272e-06, 1.5698e-09, 0.0000e+00],
        [9.9990e-01, 1.5068e+00, 1.5068e+00, 1.5068e+00, 4.7461e-14, 0.0000e+00],
        [1.2115e+00, 1.2115e+0