In [1]:
import numpy as np
import torch
import networkx as nx
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Ridge
from src.CBN import CausalBayesianNetwork as CBN
import modularised_utils as mut
import Linear_Additive_Noise_Models as lanm
import operations as ops
import evaluation_utils as evut
import params
import torchvision
import random
import joblib
import opt_utils as oput
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import time
import ot
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
import torch.nn as nn



In [2]:
resolution = 32
correlation = 0.85

experiment_map = {
    2: 'cm2f2',
    4: 'cm4f2',
    8: 'cm8f2',
    16: 'cm16f2',
    32: 'cm32f2'
}

experiment = experiment_map.get(resolution)

In [3]:
# Variable names
D = 'Digit'
C = 'Color'
I = 'Image'

D_ = 'Digit_'
C_ = 'Color_'
I_ = 'Image_'

In [4]:
num_llsamples = params.n_samples[experiment][0] if experiment in params.n_samples else 1500
num_hlsamples = params.n_samples[experiment][1] if experiment in params.n_samples else 1500

In [6]:
class ColorMNISTDataGenerator:
    def __init__(self, image_size=resolution, correlation=correlation):
        self.correlation = correlation
        self.colors = {
            0: (1.0, 0.0, 0.0),  # Red
            1: (1.0, 0.6, 0.0),  # Orange
            2: (0.8, 1.0, 0.0),  # Yellow-Green
            3: (0.2, 1.0, 0.0),  # Green
            4: (0.0, 1.0, 0.4),  # Blue-Green
            5: (0.0, 1.0, 1.0),  # Cyan
            6: (0.0, 0.4, 1.0),  # Light Blue
            7: (0.2, 0.0, 1.0),  # Blue
            8: (0.8, 0.0, 1.0),  # Purple
            9: (1.0, 0.0, 0.6)   # Pink
        }
        # self.transform = transforms.Compose([
        #     transforms.ToPILImage(),
        #     transforms.Resize(image_size),
        #     transforms.ToTensor(),
        #     transforms.Normalize((0.5,), (0.5,))
        # ])
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])
        mnist = torchvision.datasets.MNIST('data/', train=True, download=True)
        self.mnist_data = {}
        for i in range(len(mnist)):
            if mnist.targets[i].item() not in self.mnist_data:
                self.mnist_data[mnist.targets[i].item()] = []
            self.mnist_data[mnist.targets[i].item()].append(mnist.data[i])

    def generate_samples(self, n, intervention=None, pad_to_length=None):
        images = []
        digits = []
        colors = []
        original_lengths = []
        for _ in range(n):
            u_conf = np.random.randint(10)
            if intervention is None:
                digit = u_conf if np.random.random() < self.correlation else np.random.randint(10)
                color = u_conf if np.random.random() < self.correlation else np.random.randint(10)
            else:
                intervention_dict = intervention.vv()
                digit = intervention_dict['digit'] if 'digit' in intervention_dict else (u_conf if np.random.random() < self.correlation else np.random.randint(10))
                color = intervention_dict['color'] if 'color' in intervention_dict else (u_conf if np.random.random() < self.correlation else np.random.randint(10))
            idx = np.random.randint(len(self.mnist_data[digit]))
            img = self.mnist_data[digit][idx]
            color_rgb = self.colors[color]
            img_color = torch.tensor(img).float().unsqueeze(0).repeat(3, 1, 1)
            for c in range(3):
                img_color[c] *= color_rgb[c]
            img_final = self.transform(img_color)  # (3, H, W)
            img_np = img_final.numpy()
            img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, 3)
            if pad_to_length is None:
                # Flatten the entire image (including background)
                images.append(img_np.flatten())
                original_lengths.append(len(img_np.flatten()))
            else:
                # Mask for non-background pixels (assuming background is -1 after normalization)
                mask = ~np.all(np.abs(img_np + 1) < 1e-4, axis=2)
                non_bg_pixels = img_np[mask]  # shape: (num_non_bg_pixels, 3)
                images.append(non_bg_pixels.flatten())
                original_lengths.append(len(non_bg_pixels.flatten()))
            digits.append(digit)
            colors.append(color)
        images = np.array(images)
        digits = np.array(digits)
        colors = np.array(colors)
        digits_onehot = np.eye(10)[digits]
        colors_onehot = np.eye(10)[colors]
        if pad_to_length is None:
            samples = np.concatenate([images, digits_onehot, colors_onehot], axis=1)
            return samples  
        else:
            # Pad to pad_to_length as before
            n = len(images)
            max_len = pad_to_length
            images_padded = np.zeros((n, max_len))
            padding_info = []
            for i, img in enumerate(images):
                images_padded[i, :len(img)] = img
                padding_amount = max_len - len(img)
                padding_info.append({
                    'original_length': len(img),
                    'padded_length': max_len,
                    'padding_added': padding_amount,
                    'digit': digits[i],
                    'color': colors[i]
                })
            samples = np.concatenate([images_padded, digits_onehot, colors_onehot], axis=1)
            return samples, padding_info

In [7]:
iota0 = None  # observational
iota1 = ops.Intervention({'digit': 6})
iota2 = ops.Intervention({'digit': 8})
iota3 = ops.Intervention({'digit': 4})
iota4 = ops.Intervention({'color': 7})  
iota5 = ops.Intervention({'color': 0})
iota6 = ops.Intervention({'color': 4})
iota7 = ops.Intervention({'digit': 6, 'color': 7})
iota8 = ops.Intervention({'digit': 8, 'color': 0})
iota9 = ops.Intervention({'digit': 4, 'color': 4})

eta0 = None  # observational
eta1 = ops.Intervention({'D_': 6})
eta2 = ops.Intervention({'D_': 8})
eta3 = ops.Intervention({'D_': 4})
eta4 = ops.Intervention({'C_': 7})
eta5 = ops.Intervention({'C_': 0})
eta6 = ops.Intervention({'C_': 4})
eta7 = ops.Intervention({'D_': 6, 'C_': 7})
eta8 = ops.Intervention({'D_': 8, 'C_': 0})
eta9 = ops.Intervention({'D_': 4, 'C_': 4})

# omega mapping
omega = {
    iota0: eta0,
    iota1: eta1,
    iota2: eta2,
    iota3: eta3,
    iota4: eta4,
    iota5: eta5,
    iota6: eta6,
    iota7: eta7,
    iota8: eta8,
    iota9: eta9
}

Ill_relevant = list(set(omega.keys()))
Ihl_relevant = list(set(omega.values()))

In [8]:
def get_components(samples):
    """
    Extract image, digit, and color components from samples array
    For new structure: images are flat arrays of non-background pixels (padded).
    """
    image_pixels = samples.shape[1] - 20  
    images = samples[:, :image_pixels]  
    digits_onehot = samples[:, image_pixels:image_pixels+10]
    colors_onehot = samples[:, image_pixels+10:]
    digits = np.argmax(digits_onehot, axis=1)
    colors = np.argmax(colors_onehot, axis=1)
    return {
        'images': images,  
        'digits': digits,
        'colors': colors,
        'digits_onehot': digits_onehot,
        'colors_onehot': colors_onehot
    }

def show_samples(components, num_samples=3):
    fig, axes = plt.subplots(1, num_samples, figsize=(2*num_samples, 2))
    
    for i in range(num_samples):
        img = components['images'][i].transpose(1,2,0)
        img = (img + 1) / 2  
        
        if num_samples == 1:
            ax = axes
        else:
            ax = axes[i]
            
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f'D:{components["digits"][i]}\nC:{components["colors"][i]}')
    
    plt.tight_layout()
    plt.show()

    return img

In [9]:
def low_to_high_level(samples):
    """
    Transform low-level samples to high-level by clustering non-background pixels into one variable I.
    D and C remain unchanged as they are already high-level variables.

    Args:
    - samples: Low level samples with shape (n_samples, padded_length+20) where:
        - First padded_length: non-background image pixels (padded)
        - Next 10: digit one-hot (D)
        - Last 10: color one-hot (C)
    """
    image_pixels = samples.shape[1] - 20
    # Split into components
    images = samples[:, :image_pixels]  # Non-background image features (padded)
    D = samples[:, image_pixels:image_pixels+10]   # Digit one-hot
    C = samples[:, image_pixels+10:]       # Color one-hot

    I = images.mean(axis=1, keepdims=True) 

    high_level_samples = np.concatenate([D, C, I], axis=1)

    return high_level_samples

In [10]:
data_generator = ColorMNISTDataGenerator()

In [None]:
# Generate samples
Dll_samples = {}
for iota in Ill_relevant:
    Dll_samples[iota] = data_generator.generate_samples(num_llsamples, iota)


# Generate high-level samples
Dhl_samples = {}
for eta in Ihl_relevant:
    if eta is not None:
        # Find corresponding low-level intervention
        iota = [i for i, e in omega.items() if e == eta][0]
        Dhl_samples[eta] = low_to_high_level(Dll_samples[iota])
    else:
        # For observational distribution
        Dhl_samples[eta] = low_to_high_level(Dll_samples[None])

In [12]:
from src.CBN import CausalBayesianNetwork

def create_cmnist_graphs():
    # Low-level causal graph (VL)
    ll_edges = [('digit', 'pixels'), ('color', 'pixels')]  
    ll_causal_graph = CausalBayesianNetwork(ll_edges)
    
    # High-level causal graph (VH)
    hl_edges = [('D_', 'I_'), ('C_', 'I_')]  
    hl_causal_graph = CausalBayesianNetwork(hl_edges)
    
    return ll_causal_graph, hl_causal_graph

# Create the graphs as CausalBayesianNetwork objects
ll_causal_graph, hl_causal_graph = create_cmnist_graphs()

In [16]:
test_size = 0.1

data_observational_ll, Dll_obs_test = train_test_split(Dll_samples[None], test_size=test_size, random_state=42)
data_observational_hl, Dhl_obs_test = train_test_split(Dhl_samples[None], test_size=test_size, random_state=42)

In [17]:
def get_ll_coefficients(data, img_dim=None, n_channels=3, use_ridge=False, alpha=1.0):
    """
    For each channel (R, G, B), fit a linear regression to predict pixel values using:
      - One-hot digit (10 features)
      - One-hot color (10 features)
      - x-coordinate (1 feature)
      - y-coordinate (1 feature)
    Returns:
      - coeffs: dict mapping channel index to regression coefficients
      - residuals: dict mapping channel index to residuals (1D array, one per pixel sample)
    """
    img_size = data.shape[1] - 20
    if img_dim is None:
        img_dim = int(np.sqrt(img_size / n_channels))
    
    digit_onehot = data[:, img_size:img_size+10]
    color_onehot = data[:, img_size+10:img_size+20]
    
    n_samples = data.shape[0]
    coeffs = {}
    residuals = {}

    for channel in range(n_channels):
        X_list = []
        y_list = []
        for i in range(n_samples):
            img_flat = data[i, :img_size]
            img = img_flat.reshape((img_dim, img_dim, n_channels))
            for x in range(img_dim):
                for y in range(img_dim):
                    # Predictors: [digit_onehot, color_onehot, x, y]
                    predictors = np.concatenate([
                        digit_onehot[i],         # shape (10,)
                        color_onehot[i],         # shape (10,)
                        [x], [y]                 # shape (2,)
                    ])
                    X_list.append(predictors)
                    y_list.append(img[x, y, channel])
        X = np.vstack(X_list)  
        y = np.array(y_list)   

        model = Ridge(alpha=alpha) if use_ridge else LinearRegression()
        model.fit(X, y)
        coeffs[channel] = model.coef_  
        residuals[channel] = y - model.predict(X) 

    return coeffs, residuals

def get_hl_coefficients(data, use_ridge=False, alpha=1.0):
    """
    For the high-level data (aggregated image), fit a linear regression to predict
    the aggregated image value using:
      - One-hot digit
      - One-hot color 
    Returns:
      - coeffs: 1D numpy array of regression coefficients 
      - residuals: 1D array of residuals (noise) for the high-level image variable
    """
    D = data[:, :10][:, :-1]  
    C = data[:, 10:20][:, :-1] 
    I = data[:, 20]
    X = np.column_stack((D, C)) 
    y = I
    model = Ridge(alpha=alpha) if use_ridge else LinearRegression()
    model.fit(X, y)
    residuals = y - model.predict(X)
    coeffs = model.coef_  
    return coeffs, residuals

In [18]:
ll_coeffs, ll_residuals = get_ll_coefficients(data_observational_ll)
hl_coeffs, hl_residuals = get_hl_coefficients(data_observational_hl)

In [19]:
# For low-level: concatenate all residuals for all images and all channels
U_L = np.column_stack([
    ll_residuals[0],  
    ll_residuals[1],
    ll_residuals[2]
])  

N = data_observational_ll.shape[0]
img_dim = 32
n_pixels = img_dim * img_dim
U_L = U_L.reshape(N, n_pixels * 3)  

# For high-level: all residuals for all images
U_H = hl_residuals.reshape(-1, 1)  

In [None]:
joblib.dump(U_L, f"data/{experiment}/U_L.pkl")
joblib.dump(U_H, f"data/{experiment}/U_H.pkl")

joblib.dump(data_observational_ll, f"data/{experiment}/data_observational_ll.pkl")
joblib.dump(data_observational_hl, f"data/{experiment}/data_observational_hl.pkl")

In [17]:
Ds = {}
for iota in Ill_relevant:
    Ds[iota] = (Dll_samples[iota], Dhl_samples[omega[iota]])

In [18]:
joblib.dump((ll_causal_graph, Ill_relevant), f"data/{experiment}/LL.pkl")
joblib.dump(ll_coeffs, f"data/{experiment}/ll_coeffs.pkl")

joblib.dump((hl_causal_graph, Ihl_relevant), f"data/{experiment}/HL.pkl")
joblib.dump(hl_coeffs, f"data/{experiment}/hl_coeffs.pkl")

joblib.dump(Ds, f"data/{experiment}/Ds.pkl")

joblib.dump(omega, f"data/{experiment}/omega.pkl")

joblib.dump(data_observational_ll, f"data/{experiment}/Dll_obs_train.pkl")
joblib.dump(data_observational_hl, f"data/{experiment}/Dhl_obs_train.pkl")

joblib.dump(Dll_obs_test, f"data/{experiment}/Dll_obs_test.pkl")
joblib.dump(Dhl_obs_test, f"data/{experiment}/Dhl_obs_test.pkl")

joblib.dump(ll_coeffs, f"data/{experiment}/ll_endogenous_coeff_dict.pkl")
joblib.dump(hl_coeffs, f"data/{experiment}/hl_endogenous_coeff_dict.pkl")

joblib.dump(ll_residuals, f"data/{experiment}/exogenous_LL.pkl")
joblib.dump(hl_residuals, f"data/{experiment}/exogenous_HL.pkl")

['data/cm32f2/exogenous_HL.pkl']

In [19]:
def det_ll_func(parent_info_ll, iota):
    # parent_info_ll: dict with keys 'digit_onehot', 'color_onehot', 'x', 'y', 'coeffs'
    digit_onehot = parent_info_ll['digit_onehot'].copy()  # Add copy
    color_onehot = parent_info_ll['color_onehot'].copy()  # Add copy
    x_coords = parent_info_ll['x']
    y_coords = parent_info_ll['y']
    coeffs = parent_info_ll['coeffs']  # dict: channel -> 22-dim vector

    # Add intervention handling
    if iota is not None:
        intervention_dict = iota.vv()
        if 'digit' in intervention_dict:
            digit_onehot = np.zeros_like(digit_onehot)
            digit_onehot[:, intervention_dict['digit']] = 1
        if 'color' in intervention_dict:
            color_onehot = np.zeros_like(color_onehot)
            color_onehot[:, intervention_dict['color']] = 1

    N = digit_onehot.shape[0]
    img_dim = len(x_coords)
    n_channels = len(coeffs)
    det = np.zeros((N, img_dim * img_dim * n_channels))
    for i in range(N):
        det_i = []
        for c in range(n_channels):
            det_img = []
            for x in range(img_dim):
                for y in range(img_dim):
                    predictors = np.concatenate([digit_onehot[i], color_onehot[i], [x], [y]])
                    det_img.append(np.dot(coeffs[c], predictors))
            det_i.append(det_img)
        det[i, :] = np.concatenate(det_i)
    return torch.tensor(det, dtype=torch.float32)

def det_hl_func(parent_info_hl, omega_iota):
    # parent_info_hl: dict with keys 'digit_onehot', 'color_onehot', 'coeffs'
    digit_onehot = parent_info_hl['digit_onehot'].copy()  
    color_onehot = parent_info_hl['color_onehot'].copy()  
    coeffs = parent_info_hl['coeffs']  

    if omega_iota is not None:
        intervention_dict = omega_iota.vv()
        if 'D_' in intervention_dict:
            digit_onehot = np.zeros_like(digit_onehot)
            digit_onehot[:, intervention_dict['D_']] = 1
        if 'C_' in intervention_dict:
            color_onehot = np.zeros_like(color_onehot)
            color_onehot[:, intervention_dict['C_']] = 1

    N = digit_onehot.shape[0]
    det = []
    for i in range(N):
        predictors = np.concatenate([digit_onehot[i][:-1], color_onehot[i][:-1]])  
        det.append(np.dot(coeffs, predictors))
    return torch.tensor(np.array(det).reshape(-1, 1), dtype=torch.float32)

In [20]:
parent_info_ll = {
    'digit_onehot': data_observational_ll[:, -20:-10],
    'color_onehot': data_observational_ll[:, -10:],
    'x': np.arange(img_dim),
    'y': np.arange(img_dim),
    'coeffs': ll_coeffs
}
parent_info_hl = {
    'digit_onehot': data_observational_hl[:, :10],
    'color_onehot': data_observational_hl[:, 10:20],
    'coeffs': hl_coeffs
}

In [21]:
det_ll_dict = {}
for iota in Ill_relevant:
    det_ll_dict[iota] = det_ll_func(parent_info_ll, iota)

det_hl_dict = {}
for eta in Ihl_relevant:
    det_hl_dict[eta] = det_hl_func(parent_info_hl, eta)

In [None]:
joblib.dump(det_ll_dict, f"data/{experiment}/det_ll_dict.pkl")
joblib.dump(det_hl_dict, f"data/{experiment}/det_hl_dict.pkl")

In [22]:
digit_onehotL = data_observational_ll[:, 3072:3082]  
color_onehotL = data_observational_ll[:, 3082:3092]  

digit_onehotH = data_observational_hl[:, :10]  
color_onehotH = data_observational_hl[:, 10:20]  

### DIROCA optimization

In [15]:
def compute_empirical_radius(N, eta, c1=1.0, c2=1.0, alpha=2.0, m=3):
    """
    Compute epsilon_N(eta) for empirical Wasserstein case.

    Parameters:
    - N: int, number of samples
    - eta: float, confidence level (0 < eta < 1)
    - c1: float, constant from theorem (default 1.0, adjust if needed)
    - c2: float, constant from theorem (default 1.0, adjust if needed)
    - alpha: float, light-tail exponent (P[exp(||ξ||^α)] ≤ A)
    - m: int, ambient dimension

    Returns:
    - epsilon: float, the concentration radius
    """
    assert 0 < eta < 1, "eta must be in (0,1)"
    threshold = np.log(c1 / eta) / c2
    if N >= threshold:
        exponent = min(1/m, 0.5)
    else:
        exponent = 1 / alpha

    epsilon = (np.log(c1 / eta) / (c2 * N)) ** exponent
    return epsilon


In [16]:
l = len(ll_causal_graph.nodes())
h = len(hl_causal_graph.nodes())

In [None]:
ll_bound = round(compute_empirical_radius(N=num_llsamples, eta=0.05, c1=1000.0, c2=1.0, alpha=2.0, m=l), 3)
hl_bound = round(compute_empirical_radius(N=num_hlsamples, eta=0.05, c1=1000.0, c2=1.0, alpha=2.0, m=h), 3)

In [65]:
epsilon, delta = ll_bound, hl_bound

eta_max = 0.001
eta_min = 0.001

max_iter = 5000
num_steps_min = 5
num_steps_max = 2

robust_L = True 
robust_H = True

initialization = 'random'

tol  = 1e-4
seed = 23

In [66]:
opt_params_erica = {
                        'U_L': U_L,
                        'U_H': U_H,
                        'Ill': Ill_relevant,
                        'parent_info_ll': parent_info_ll,
                        'parent_info_hl': parent_info_hl,
                        'digit_onehotL': digit_onehotL,
                        'color_onehotL': color_onehotL,
                        'digit_onehotH': digit_onehotH,
                        'color_onehotH': color_onehotH,
                        'omega': omega,
                        'epsilon': epsilon,
                        'delta': delta,
                        'eta_min': eta_min,
                        'eta_max': eta_max,
                        'num_steps_min': num_steps_min,
                        'num_steps_max': num_steps_max,
                        'max_iter': max_iter,
                        'tol': tol,
                        'seed': seed,
                        'robust_L': robust_L,
                        'robust_H': robust_H,
                        'initialization': initialization,
                        'experiment': experiment
                    }

In [69]:
def empirical_objective_cm(
    U_L, U_H, T, Theta, Phi, Ill, omega,
    det_ll_dict, det_hl_dict, 
    parent_info_ll, parent_info_hl,
    digit_onehot, color_onehot, digit_onehotH, color_onehotH  
):
    loss_iota = 0
    N = U_L.shape[0]
    for iota in Ill:
        det_ll = det_ll_dict[iota]
        det_hl = det_hl_dict[omega[iota]]
        endo_ll = det_ll + (U_L + Theta)  
        endo_ll_full = torch.cat([
            endo_ll, 
            torch.tensor(digit_onehot, dtype=endo_ll.dtype, device=endo_ll.device),
            torch.tensor(color_onehot, dtype=endo_ll.dtype, device=endo_ll.device)
        ], dim=1) 
        endo_hl = det_hl + (U_H + Phi)  
        endo_hl_full = torch.cat([
            endo_hl,
            torch.tensor(digit_onehotH, dtype=endo_hl.dtype, device=endo_hl.device),
            torch.tensor(color_onehotH, dtype=endo_hl.dtype, device=endo_hl.device)
        ], dim=1) 
        mapped_ll = (T @ endo_ll_full.T).T  
        diff = mapped_ll - endo_hl_full
        loss_iota += torch.norm(diff, p='fro')**2 / diff.numel()
    loss = loss_iota / len(Ill)
    return loss

In [70]:
def run_empirical_erica_optimization(
    U_L, U_H, Ill, omega, epsilon, delta, eta_min, eta_max,
    num_steps_min, num_steps_max, max_iter, tol, seed, robust_L, robust_H, initialization, experiment,
    det_ll_dict, det_hl_dict, parent_info_ll, parent_info_hl, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH
):
    torch.manual_seed(seed)

    method  = 'erica' if robust_L or robust_H else 'enrico'
    num_steps_min = 1 if method == 'enrico' else num_steps_min

    U_L = torch.as_tensor(U_L, dtype=torch.float32)
    U_H = torch.as_tensor(U_H, dtype=torch.float32)
    
    N, l = U_L.shape
    _, h = U_H.shape

    T = torch.randn(21, 3092, requires_grad=True)
    if initialization == 'random':
        Theta = torch.randn(N, l, requires_grad=True)
        Phi   = torch.randn(N, h, requires_grad=True)
    elif initialization == 'projected':
        Theta = oput.init_in_frobenius_ball((N, l), epsilon)
        Phi   = oput.init_in_frobenius_ball((N, h), delta)

    # Create optimizers
    optimizer_T   = torch.optim.Adam([T], lr=eta_min)
    optimizer_max = torch.optim.Adam([Theta, Phi], lr=eta_max)
    
    prev_T_objective = float('inf')
    
    for iteration in tqdm(range(max_iter)):
        objs_T, objs_max = [], []
        # Step 1: Minimize with respect to T
        for _ in range(num_steps_min):
            optimizer_T.zero_grad()
            T_objective = empirical_objective_cm(
                U_L, U_H, T, Theta, Phi, Ill, omega,
                det_ll_dict, det_hl_dict, parent_info_ll, parent_info_hl, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH
            )
            objs_T.append(T_objective.item())
            T_objective.backward()
            optimizer_T.step()
        # Step 2: Maximize with respect to Theta and Phi
        if method == 'erica':
            for _ in range(num_steps_max):
                optimizer_max.zero_grad()
                max_objective = -empirical_objective_cm(
                    U_L, U_H, T, Theta, Phi, Ill, omega,
                    det_ll_dict, det_hl_dict, parent_info_ll, parent_info_hl, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH
                )
                max_objective.backward()
                optimizer_max.step()
                # Project onto constraint sets
                with torch.no_grad():
                    Theta.data = oput.project_onto_frobenius_ball(Theta, epsilon)
                    Phi.data   = oput.project_onto_frobenius_ball(Phi, delta)
                mobj = empirical_objective_cm(
                    U_L, U_H, T, Theta, Phi, Ill, omega,
                    det_ll_dict, det_hl_dict, parent_info_ll, parent_info_hl, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH
                )
                objs_max.append(mobj.item())

        with torch.no_grad():
            current_T_objective = T_objective.item()
            if abs(prev_T_objective - current_T_objective) < tol:
                print(f"Converged at iteration {iteration + 1}")
                break
            prev_T_objective = current_T_objective
            
    T       = T.detach().numpy()
    paramsL = {'pert_U': Theta.detach().numpy(), 'radius_worst': epsilon,
                    'pert_hat': U_L, 'radius': epsilon}
    paramsH = {'pert_U': Phi.detach().numpy(), 'radius_worst': delta,
                    'pert_hat': U_H, 'radius': delta}
    
    if method == 'erica':
        radius_worst_L          = evut.compute_empirical_worst_case_distance(paramsL)
        paramsL['radius_worst'] = radius_worst_L
        radius_worst_H          = evut.compute_empirical_worst_case_distance(paramsH)
        paramsH['radius_worst'] = radius_worst_H

    opt_params = {'L': paramsL, 'H': paramsH}

    save_dir = f"data/{experiment}/{method}"
    os.makedirs(save_dir, exist_ok=True)
    joblib.dump(opt_params, f"data/{experiment}/{method}/opt_params.pkl")

    return opt_params, T

In [None]:
eps_delta_values = [100, 30, 8, 4, ll_bound, 2, 1]
diroca_train_results_empirical = {}

for eps_delta in eps_delta_values:
    print(f"Training for ε=δ = {eps_delta}")
    
    if eps_delta == ll_bound:
        opt_params_erica['epsilon'] = ll_bound
        opt_params_erica['delta']   = hl_bound
    else:
        opt_params_erica['epsilon'] = eps_delta
        opt_params_erica['delta']   = eps_delta
    
    opt_params_erica['det_ll_dict'] = det_ll_dict
    opt_params_erica['det_hl_dict'] = det_hl_dict
    
    # Run DIROCA optimization
    params_empirical, T_empirical = run_empirical_erica_optimization(**opt_params_erica)
    
    if eps_delta == ll_bound:
        diroca_train_results_empirical['T_'+str(ll_bound)+'-'+str(hl_bound)] = {
            'optimization_params': params_empirical,
            'T_matrix': T_empirical
        }
    else:
        diroca_train_results_empirical['T_'+str(eps_delta)] = {
            'optimization_params': params_empirical,
            'T_matrix': T_empirical
        }

print("\nTraining completed. T matrices stored in trained_results dictionary.")
print("Available ε=δ values:", list(diroca_train_results_empirical.keys()))

In [None]:
def perfect_abstraction(px_samples, py_samples, tau_threshold=1e-2):

    tau_adj = np.linalg.pinv(px_samples) @ py_samples
    tau_adj_mask = np.abs(tau_adj) > tau_threshold
    tau_adj = tau_adj * tau_adj_mask

    return tau_adj

def noisy_abstraction(px_samples, py_samples, tau_threshold=1e-1, refit_coeff=False):

    tau_adj_hat = np.linalg.pinv(px_samples) @ py_samples
    tau_mask_hat = np.argmax(np.abs(tau_adj_hat), axis=1)
    abs_nodes = py_samples.shape[1]
    tau_mask_hat = np.eye(abs_nodes)[tau_mask_hat]
    tau_mask_hat *= np.array(np.abs(tau_adj_hat) > tau_threshold, dtype=np.int32)
    
    if refit_coeff:
        for y in range(tau_mask_hat.shape[1]):
            block = np.where(tau_mask_hat[:, y] == 1)[0]
            if len(block) > 0:
                tau_adj_hat[block, y] = np.linalg.pinv(px_samples[:, block]) @ py_samples[:, y]
    
    tau_adj_hat = tau_mask_hat * tau_adj_hat
    return tau_adj_hat

def abs_lingam_reconstruction_v2_modified(df_base, df_abst, n_paired_samples=None, style="Perfect", tau_threshold=1e-2):
    """
    
    Args:
        df_base: numpy array of concrete observations (pixels only)
        df_abst: numpy array of abstract observations (image value only)
    """
    # Get dimensions
    n_samples_base, d = df_base.shape
    n_samples_abst = df_abst.shape[0]
    
    
    # Create joint dataset D_J by taking a subset of paired samples
    if n_paired_samples is None:
        n_paired_samples = min(n_samples_base, n_samples_abst)
    
    n_paired_samples = min(n_paired_samples, n_samples_base, n_samples_abst)
    
    # Create D_J using the first n_paired_samples
    D_J_base = df_base[:n_paired_samples]  # pixels only
    D_J_abst = df_abst[:n_paired_samples] 
    #D_J_abst = df_abst[:n_paired_samples].reshape(-1, 1)  # image value only, ensure 2D
    
    if style == "Perfect":
        T = perfect_abstraction(D_J_base, D_J_abst, tau_threshold)
    elif style == "Noisy":
        T = noisy_abstraction(D_J_base, D_J_abst, tau_threshold, False)
    else:
        raise ValueError(f"Unknown style {style}")

    return T

def run_abs_lingam_complete_modified(data_observational_ll, data_observational_hl, n_paired_samples=1000):

    df_base = data_observational_ll
    df_abst = data_observational_hl
    
    styles = ["Perfect", "Noisy"]
    results = {}
    
    for style in styles:
        T = abs_lingam_reconstruction_v2_modified(
            df_base, 
            df_abst,
            n_paired_samples=n_paired_samples,
            style=style
        )
        results[style] = {'T': T}

    return results

In [74]:
abslingam_results = run_abs_lingam_complete_modified(data_observational_ll, data_observational_hl)

T_pa_pixels = abslingam_results['Perfect']['T']
T_na_pixels = abslingam_results['Noisy']['T']

diroca_train_results_empirical['T_pa'] = {
    'optimization_params': {
        'L': {'pert_U': U_L},
        'H': {'pert_U': U_H}
    }, 
    'T_matrix': T_pa_pixels.T
}

diroca_train_results_empirical['T_na'] = {
    'optimization_params': {
        'L': {'pert_U': U_L},
        'H': {'pert_U': U_H}
    }, 
    'T_matrix': T_na_pixels.T
}

### GRADCA optimization

In [None]:
params_enrico, T_enrico = run_empirical_erica_optimization(**{**opt_params_erica, 'robust_L': False, 'robust_H': False})

In [100]:
diroca_train_results_empirical['T_0.00'] = {
                                'optimization_params': params_enrico,
                                'T_matrix': T_enrico
                            }

### BARYCA optimization

In [90]:
def run_empirical_bary_optim_new(
    det_ll_dict,
    det_hl_dict,
    Ill_relevant,
    omega,
    U_L, U_H,
    digit_onehotL, color_onehotL,
    digit_onehotH, color_onehotH,
    max_iter, tol, seed
):
    """
    Barycentric optimization using reconstructed endogenous samples.
    Returns T of shape (21, 3092).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    bary_L_list = []
    bary_H_list = []

    for iota in Ill_relevant:
        # Reconstruct low-level and high-level endogenous samples
        det_ll = det_ll_dict[iota]
        det_hl = det_hl_dict[omega[iota]]

        # Convert to torch tensors if needed
        det_ll = torch.tensor(det_ll, dtype=torch.float32) if not torch.is_tensor(det_ll) else det_ll
        det_hl = torch.tensor(det_hl, dtype=torch.float32) if not torch.is_tensor(det_hl) else det_hl
        U_L_tensor = torch.tensor(U_L, dtype=torch.float32) if not torch.is_tensor(U_L) else U_L
        U_H_tensor = torch.tensor(U_H, dtype=torch.float32) if not torch.is_tensor(U_H) else U_H

        # Reconstruct full endogenous inputs
        endo_ll = det_ll + U_L_tensor  # (N, 3072)
        endo_ll_full = torch.cat([
            endo_ll,
            torch.tensor(digit_onehotL, dtype=endo_ll.dtype, device=endo_ll.device),
            torch.tensor(color_onehotL, dtype=endo_ll.dtype, device=endo_ll.device)
        ], dim=1)  # (N, 3092)

        endo_hl = det_hl + U_H_tensor  # (N, 1)
        endo_hl_full = torch.cat([
            endo_hl,
            torch.tensor(digit_onehotH, dtype=endo_hl.dtype, device=endo_hl.device),
            torch.tensor(color_onehotH, dtype=endo_hl.dtype, device=endo_hl.device)
        ], dim=1)  # (N, 21)

        # Compute barycenters (mean over samples)
        bary_L = endo_ll_full.mean(dim=0, keepdim=True)  # (1, 3092)
        bary_H = endo_hl_full.mean(dim=0, keepdim=True)  # (1, 21)

        bary_L_list.append(bary_L)
        bary_H_list.append(bary_H)

    # Stack all barycenters
    bary_L = torch.cat(bary_L_list, dim=0)  # (num_interventions, 3092)
    bary_H = torch.cat(bary_H_list, dim=0)  # (num_interventions, 21)

    # Optimize T such that bary_H ≈ bary_L @ T.T, so T: (21, 3092)
    h = bary_H.shape[1]
    l = bary_L.shape[1]
    T = torch.randn(h, l, requires_grad=True)
    optimizer_T = torch.optim.Adam([T], lr=0.001)

    previous_objective = float('inf')
    for step in tqdm(range(int(max_iter))):
        optimizer_T.zero_grad()
        diff = bary_L @ T.T - bary_H  # (num_interventions, 21)
        objective_T = torch.norm(diff, p='fro') ** 2 / diff.numel()

        if abs(previous_objective - objective_T.item()) < tol:
            print(f"Converged at step {step + 1}/{max_iter} with objective: {objective_T.item()}")
            break

        previous_objective = objective_T.item()
        objective_T.backward()
        optimizer_T.step()

    return T.detach().cpu().numpy()  # (21, 3092)


In [95]:
opt_params_bary = {
                        'U_L':U_L,
                        'U_H':U_H,
                        'Ill_relevant':Ill_relevant,
                        'omega':omega,
                        'digit_onehotL':digit_onehotL,
                        'color_onehotL':color_onehotL,
                        'digit_onehotH':digit_onehotH,
                        'color_onehotH':color_onehotH,
                        'det_ll_dict':det_ll_dict,
                        'det_hl_dict':det_hl_dict,
                        'max_iter':2000,
                        'tol':tol,
                        'seed':seed
                    }
                          

In [96]:

T_bary = run_empirical_bary_optim_new(**opt_params_bary)
params_bary = {'L':{}, 'H':{}}

 14%|█▍        | 276/2000 [00:00<00:02, 722.73it/s]

Converged at step 277/2000 with objective: 0.007532527204602957





In [97]:
diroca_train_results_empirical['T_b'] = {
                                'optimization_params': params_bary,
                                'T_matrix': T_bary
                            }

### RSACA optimization

In [81]:
def empirical_objective_no_max(U_L, U_H, T, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH, Ill, omega):

    loss_iota = 0
    for iota in Ill:
       
        det_ll = det_ll_dict[iota]
        det_hl = det_hl_dict[omega[iota]]

        endo_ll = det_ll + U_L  
        endo_ll_full = torch.cat([
            endo_ll, 
            torch.tensor(digit_onehotL, dtype=endo_ll.dtype, device=endo_ll.device),
            torch.tensor(color_onehotL, dtype=endo_ll.dtype, device=endo_ll.device)
        ], dim=1) 
        endo_hl = det_hl + U_H  
        endo_hl_full = torch.cat([
            endo_hl,
            torch.tensor(digit_onehotH, dtype=endo_hl.dtype, device=endo_hl.device),
            torch.tensor(color_onehotH, dtype=endo_hl.dtype, device=endo_hl.device)
        ], dim=1) 
        mapped_ll = (T @ endo_ll_full.T).T  

        diff = mapped_ll - endo_hl_full

        loss_iota += torch.norm(diff, p='fro')**2 / (diff.shape[0] * diff.shape[1])

    loss = loss_iota / len(Ill)
    return loss

In [83]:
def run_empirical_smooth_optimization(U_L, U_H, Ill, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH, omega, eta_min,
                                    num_steps_min, max_iter, tol, seed,
                                    noise_sigma, num_noise_samples):
    """
    Run empirical optimization with randomized smoothing.
    """
    torch.manual_seed(seed)

    U_L = torch.as_tensor(U_L, dtype=torch.float32)
    U_H = torch.as_tensor(U_H, dtype=torch.float32)
    
    N, l = U_L.shape
    _, h = U_H.shape
    
    T = torch.randn(21, 3092, requires_grad=True)
    optimizer_T = torch.optim.Adam([T], lr=eta_min)
    
    prev_T_objective = float('inf')
    
    for iteration in tqdm(range(max_iter)):
        objs_T = []
        
        for _ in range(num_steps_min):
            optimizer_T.zero_grad()
            
            smoothed_objective = torch.tensor(0.0)
            
            for _ in range(num_noise_samples):
                # Add noise to T
                noise = torch.randn_like(T) * noise_sigma
                noisy_T = T + noise
                
                T_objective = empirical_objective_no_max(
                    U_L, U_H, noisy_T, digit_onehotL, color_onehotL, digit_onehotH, color_onehotH, Ill, omega
                )
                smoothed_objective += T_objective
            
            smoothed_objective /= num_noise_samples
            objs_T.append(smoothed_objective.item())
            
            
            smoothed_objective.backward()
            
            torch.nn.utils.clip_grad_norm_([T], max_norm=1.0)
            
            optimizer_T.step()
            
            # Check for NaN
            if torch.isnan(T).any():
                print("T contains NaN! Returning zero matrix.")
                print('Failed at step:', iteration + 1)
                return torch.zeros_like(T).detach()
        
        # Check convergence of T's objective
        with torch.no_grad():
            current_T_objective = smoothed_objective.item()
            if abs(prev_T_objective - current_T_objective) < tol:
                print(f"Converged at iteration {iteration + 1}")
                break
            prev_T_objective = current_T_objective
            
    opt_params = {'L': {}, 'H': {}}
    return opt_params, T.detach().numpy()

In [84]:
opt_params_smooth = {
                        'U_L': U_L,
                        'U_H': U_H,
                        'Ill': Ill_relevant,
                        'digit_onehotL': digit_onehotL,
                        'color_onehotL': color_onehotL,
                        'digit_onehotH': digit_onehotH,
                        'color_onehotH': color_onehotH,
                        'omega': omega,
                        'eta_min': eta_min,
                        'num_steps_min': num_steps_min,
                        'max_iter': 300,
                        'tol': tol,
                        'seed': seed,
                        'noise_sigma': 0.1,
                        'num_noise_samples': 10
                        }

In [85]:
params_smooth, T_smooth = run_empirical_smooth_optimization(**opt_params_smooth)

100%|██████████| 300/300 [39:24<00:00,  7.88s/it]


In [86]:
diroca_train_results_empirical['T_s'] = {
                                'optimization_params': params_smooth,
                                'T_matrix': T_smooth
                            }

## Save Results

In [None]:
joblib.dump(diroca_train_results_empirical, f"data/{experiment}/diroca_train_results_empirical.pkl")