In [1]:
# imports
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from tueplots import bundles

# Reload modified modules (e.g. `analysis`, `care_nl_ica` imported in the next cell) before
# executing code, so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Put the kernel's current working directory first on the import path so that local
# modules such as `analysis` can be imported by the next cell.
sys.path.insert(0, '.')

In [2]:
from analysis import sweep2df
from care_nl_ica.models.sinkhorn import learn_permutation

In [3]:
# NeurIPS 2022 figure style (tueplots) rendered with LaTeX.
plt.rcParams.update(bundles.neurips2022(usetex=True))
# `text.latex.preamble` must be a single string: list values were deprecated in
# Matplotlib 3.3 (previously joined with newlines) and are rejected by later releases.
plt.rcParams.update({
    'text.latex.preamble': "\n".join([
        r'\usepackage{amsfonts}',  # mathbb
        r'\usepackage{amsmath}',   # boldsymbol
    ])
})

  self[key] = other[key]


In [4]:
# Constants
ENTITY = "causal-representation-learning"
PROJECT = "nl-causal-representations"

# W&B API
api = wandb.Api(timeout=200)
runs = api.runs(ENTITY + "/" + PROJECT)

# Data loading

In [11]:
# Load two SEM sweeps and pool their runs into a single frame / set of per-run lists.
SWEEP_ID = "gz1yqsuk"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
filename = f"sem_sweep_{SWEEP_ID}"
print(f"Loading sweep with {SWEEP_ID=}")
df1, (true_unmix_jacobians1, est_unmix_jacobians1, permute_indices1) = sweep2df(sweep.runs, filename, save=True, load=True)

SWEEP_ID = "5gzpzb23"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
filename = f"sem_sweep_{SWEEP_ID}"
print(f"Loading sweep with {SWEEP_ID=}")
df2, (true_unmix_jacobians2, est_unmix_jacobians2, permute_indices2) = sweep2df(sweep.runs, filename, save=True, load=True)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; pd.concat with its
# default ignore_index=False keeps the original row labels exactly as append did.
df = pd.concat([df1, df2])
# Per-run lists are concatenated in the same order as the rows of `df` (downstream code
# pairs them by position) — assumes sweep2df returns rows and lists in matching order.
true_unmix_jacobians = true_unmix_jacobians1 + true_unmix_jacobians2
est_unmix_jacobians = est_unmix_jacobians1 + est_unmix_jacobians2
permute_indices = permute_indices1 + permute_indices2

Loading sweep with SWEEP_ID='gz1yqsuk'
Loading sweep with SWEEP_ID='5gzpzb23'


## 3D SEM

In [9]:

SWEEP_ID = "vfv1je0d"  # alternative sweep ID kept for reference: "nz4r5d8a"
filename = "sem_3d_sweep_" + SWEEP_ID
sweep = api.sweep("/".join([ENTITY, PROJECT, SWEEP_ID]))
# load=False: presumably bypasses any locally cached copy — see sweep2df
df3d, (true_unmix_jacobians3d, est_unmix_jacobians3d, permute_indices3d) = sweep2df(
    sweep.runs, filename, save=True, load=False
)

## 5D SEM

In [10]:
SWEEP_ID = "h6y1gkvo"  # alternative sweep ID kept for reference: "3ldi48id"
sweep = api.sweep("{}/{}/{}".format(ENTITY, PROJECT, SWEEP_ID))
filename = "sem_5d_sweep_{}".format(SWEEP_ID)
df5d, (true_unmix_jacobians5d, est_unmix_jacobians5d, permute_indices5d) = sweep2df(
    sweep.runs, filename, load=False, save=True
)

Encountered a faulty run with ID happy-sweep-37
Encountered a faulty run with ID wandering-sweep-28


## 8D SEM

In [5]:
SWEEP_ID = "7sscc3w1"
filename = "sem_8d_sweep_%s" % SWEEP_ID
sweep = api.sweep("/".join((ENTITY, PROJECT, SWEEP_ID)))
df8d, (true_unmix_jacobians8d, est_unmix_jacobians8d, permute_indices8d) = sweep2df(
    sweep.runs,
    filename,
    save=True,
    load=False,
)

Encountered a faulty run with ID apricot-sweep-49


## 10D SEM

In [5]:
SWEEP_ID = "7lsb5ud3"
sweep = api.sweep(ENTITY + "/" + PROJECT + "/" + SWEEP_ID)
filename = "sem_10d_sweep_" + SWEEP_ID
df10d, (true_unmix_jacobians10d, est_unmix_jacobians10d, permute_indices10d) = sweep2df(
    sweep.runs, filename, load=False, save=True
)

## MLP from Monti et al.

In [17]:
SWEEP_ID = "77huh2ue"  # alternative sweep IDs kept for reference: "q99ne3vj", "fhaza97x"
filename = "monti_sweep_%s" % SWEEP_ID
sweep = api.sweep("{}/{}/{}".format(ENTITY, PROJECT, SWEEP_ID))
df_monti, (true_unmix_jacobians_monti, est_unmix_jacobians_monti, permute_indices_monti) = sweep2df(
    sweep.runs,
    filename,
    save=True,
    load=False,
)

# Permutation learning and evaluation

In [7]:
def learning_stats(df, true_unmix_jacobians, est_unmix_jacobians, permute_indices, hamming_threshold=1e-2, selector_col="nonlin_sem", weight_threshold=None, dag_permute=True, num_steps=5000, lr=1e-4, verbose=True):
    """Run ``learn_permutation`` on every run and print summary statistics per group.

    Runs are grouped by every combination of a unique ``df.dim`` value and a unique
    ``df[selector_col]`` value. A run belongs to a group when its ground-truth Jacobian has
    ``dim`` rows (``.shape[0]``) and its ``selector_col`` value equals the group's selector.
    For each group, the mean/std of ``df.mcc`` and the means of the three values returned by
    ``learn_permutation`` (printed as ``Acc(order)``, ``Acc`` and ``SHD``) are printed. When
    no run matches, a "No experiments" line is printed instead.

    Args:
        df: one row per run with at least the columns ``dim``, ``mcc`` and ``selector_col``.
            Rows are paired with the three per-run sequences below *by position*.
        true_unmix_jacobians: per-run ground-truth Jacobians; ``.shape[0]`` is compared to ``dim``.
        est_unmix_jacobians: per-run estimated Jacobians, in the same order as ``df``.
        permute_indices: per-run permutation indices, in the same order as ``df``.
        hamming_threshold: forwarded to ``learn_permutation``.
        selector_col: column of ``df`` used to split runs of the same dimension.
        weight_threshold: forwarded to ``learn_permutation`` as ``threshold``.
        dag_permute: forwarded to ``learn_permutation``.
        num_steps: forwarded to ``learn_permutation``.
        lr: forwarded to ``learn_permutation``; defaults to the previously hard-coded 1e-4.
        verbose: forwarded to ``learn_permutation``. It presumably controls that function's
            per-run logging (e.g. "Correct order identified"); pass False for a compact summary.

    Returns:
        None. All results are printed.
    """
    import warnings

    lengths = (len(df), len(true_unmix_jacobians), len(est_unmix_jacobians), len(permute_indices))
    if len(set(lengths)) != 1:
        # zip() below pairs runs by position and silently stops at the shortest input,
        # so a length mismatch means runs may be matched to the wrong Jacobians.
        warnings.warn(
            "learning_stats: length mismatch between df and the per-run sequences "
            f"(df, true, est, permute) = {lengths}; runs are paired by position"
        )

    separator = "----------------------------------"
    # Materialise the per-run records once instead of re-zipping them for every group.
    run_records = list(zip(df[selector_col], true_unmix_jacobians, est_unmix_jacobians, permute_indices))

    for dim in df.dim.unique():
        for selector in df[selector_col].unique():
            success = []
            hamming = []
            accuracy = []
            for selector_item, j_gt, j_est, permute in run_records:
                if j_gt.shape[0] == dim and selector_item == selector:
                    # NOTE(review): `triu_weigth` is kept exactly as in the original call;
                    # presumably learn_permutation's keyword is spelled this way — verify
                    # against care_nl_ica.models.sinkhorn before renaming it.
                    s, h, a = learn_permutation(
                        j_gt, j_est, permute,
                        triu_weigth=20., tril_weight=10., diag_weight=6.,
                        num_steps=num_steps, lr=lr, verbose=verbose,
                        drop_smallest=True, threshold=weight_threshold, binary=True,
                        hamming_threshold=hamming_threshold, dag_permute=dag_permute,
                    )
                    success.append(s)
                    hamming.append(h)
                    accuracy.append(a)

            mcc = df.mcc[(df.dim == dim) & (df[selector_col] == selector)]
            print(separator)
            print(separator)
            if success:
                print(f"{dim=} ({selector_col}={selector})\tMCC={mcc.mean():.3f}+{mcc.std():.3f}\tAcc(order):{np.array(success).mean():.3f}\t  Acc:{np.array(accuracy).mean():.3f}\tSHD:{np.array(hamming).mean():.6f}\t[{len(success)} items]")
            else:
                print(f"No experiments for {dim=} ({selector_col}={selector})")
            print(separator)
            print(separator)

## 3D SEM

In [14]:
# 3D SEM: uses the default hamming_threshold (1e-2), passed explicitly here for visibility
learning_stats(
    df3d,
    true_unmix_jacobians3d,
    est_unmix_jacobians3d,
    permute_indices3d,
    hamming_threshold=1e-2,
)

Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
----------------------------------
----------------------------------
dim=3 (nonlin_sem=False)	MCC=1.000+0.000	Acc(order):1.000	  Acc:1.000	SHD:0.000000	[27 items]
----------------------------------
----------------------------------
Correct order identified
Correct order identified
Correct order identified
Correct order id

## 5D SEM

In [13]:
learning_stats(
    df5d,
    true_unmix_jacobians5d,
    est_unmix_jacobians5d,
    permute_indices5d,
    hamming_threshold=1e-3,
)

Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
Correct order identified
----------------------------------
true_jac=tensor([[1.6245, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.8148, 2.5022, 0.0000, 0.0000, 0.0000],
        [1.2530, 3.0065, 2.5013, 0.0000, 0.0000],
        [1.2919, 2.0411, 2.0376, 2.5548, 0.0000],
        [1.9508, 2.8156, 2.4423, 2.2518, 3.0212]])
est_jac=tensor([[0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [1., 0., 0., 1., 0.],
        [0., 1., 1., 1., 1.]])
tensor([[5.0000e-01, 5.0000e-01, 6.4698e-08, 2.1337e-09, 2.1152e-09],
        [5.0000e-01, 5.0000e-01, 3.4496e-09, 6.8762e-07, 6.8501e-07],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 8.2248e-08, 6.5731e-08],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 2.1817e-07],
        [2.0000e+00, 2.0000e+00, 7.6314e-

## 8D SEM

In [12]:
learning_stats(
    df8d,
    true_unmix_jacobians8d,
    est_unmix_jacobians8d,
    permute_indices8d,
    hamming_threshold=1e-3,
)

Correct order identified
----------------------------------
true_jac=tensor([[1.0850, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.9611, 3.5985, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.4147, 4.0179, 2.0769, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.7145, 4.0698, 2.3540, 3.0299, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.8764, 3.3702, 2.4392, 3.5258, 2.2418, 0.0000, 0.0000, 0.0000],
        [1.6313, 2.3522, 1.3941, 2.1047, 1.8677, 2.3872, 0.0000, 0.0000],
        [1.1212, 2.8969, 2.0693, 3.0300, 3.6201, 2.0695, 1.2522, 0.0000],
        [1.7588, 3.3396, 1.7038, 2.6641, 2.6519, 1.6928, 2.2463, 2.7388]])
est_jac=tensor([[0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 1., 1., 1., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 1., 1., 1.,

## 10D SEM

In [8]:
learning_stats(
    df10d,
    true_unmix_jacobians10d,
    est_unmix_jacobians10d,
    permute_indices10d,
    hamming_threshold=1e-3,
)

----------------------------------
true_jac=tensor([[1.1922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.9500, 3.7520, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.6217, 2.8972, 2.1334, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.0580, 2.2932, 2.2195, 1.6608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.6683, 3.5299, 2.6947, 2.4147, 3.4664, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.0154, 2.9480, 2.0143, 1.6252, 3.6249, 2.3881, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.3168, 3.3697, 2.4557, 2.6948, 2.9257, 1.4711, 3.4429, 0.0000, 0.0000,
         0.0000],
        [1.3574, 4.1002, 2.1081, 2.3233, 2.8957, 1.6435, 3.2952, 2.4223, 0.0000,
         0.0000],
        [1.0870, 2.8801, 1.5293, 1.3736, 2.4271, 1.3763, 2.3635, 2.6303, 2.8053,
         0.0000],
        [1.9490, 2.4977, 2.2379, 2.5739, 3.5761, 2.1033, 3.5990, 

## MLP from Monti et al.

In [18]:
# Monti et al. MLP runs: grouped by the `n_mixing_layer` column instead of the default
# `nonlin_sem`, and without DAG-based permutation.
learning_stats(
    df_monti,
    true_unmix_jacobians_monti,
    est_unmix_jacobians_monti,
    permute_indices_monti,
    selector_col="n_mixing_layer",
    hamming_threshold=1e-3,
    weight_threshold=None,
    dag_permute=False,
)

----------------------------------
true_jac=tensor([[ 0.0025,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0128,  0.0021,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0314, -0.0108,  0.0021,  0.0000,  0.0000,  0.0000],
        [-0.0772,  0.0332, -0.0109,  0.0016,  0.0000,  0.0000],
        [ 0.1414, -0.0690,  0.0295, -0.0095,  0.0022,  0.0000],
        [-0.2077,  0.1093, -0.0530,  0.0234, -0.0095,  0.0021]])
est_jac=tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 0., 0.]])
tensor([[5.0000e-01, 5.0000e-01, 5.0000e-01, 4.0059e-15, 5.8765e-18, 0.0000e+00],
        [5.4537e-10, 1.0000e+00, 5.5369e-10, 5.5369e-10, 5.8286e-15, 0.0000e+00],
        [5.0000e-01, 5.0000e-01, 5.0000e-01, 4.2770e-07, 1.5501e-09, 0.0000e+00],
        [9.9999e-01, 1.5077e+00, 1.5077e+00, 1.5077e+00, 4.7045e-14, 0.0000e+00],
        [1.2109e+00, 1.2109e+0