In [1]:
import json
import pickle
import time
from datetime import datetime, timedelta
from itertools import product

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import joblib
import os

import matplotlib.pyplot as plt
from scipy.stats import wilcoxon
from tqdm import tqdm

# Local modules
import modularised_utils as mut
import opt_utils as oput
import evaluation_utils as evut
import Linear_Additive_Noise_Models as lanm
import operations as ops
import params
import random

# Seed every global RNG this notebook draws from, so Restart & Run All is
# reproducible. NumPy's legacy global state feeds np.random.choice in the
# ρ-shift setup; the stdlib `random` module feeds contaminate_omega_map
# (random.sample / random.choice), which np.random.seed does not cover.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)



In [2]:
# Experiment identifier: selects the data/<experiment>/ directory read by the
# joblib loads below and is passed to the mut.load_* helpers.
experiment = 'lucas6x3'

In [3]:
# Trained abstraction results, keyed by method name. Later cells read each
# entry's 'T_matrix'; the 'T_8' entry is additionally read through
# 'optimization_params' -> 'L' / 'H' -> 'pert_U'.
# NOTE(review): joblib.load unpickles the file, so only load trusted artifacts.
T_results_emp = joblib.load(f"data/{experiment}/diroca_train_results_empirical.pkl")

In [4]:
# Toggle: estimate the linear coefficients from the loaded samples (True) or
# read the stored coefficients for this experiment (False).
coeff_estimation = True

# Observational samples for the low-level (LL) and high-level (HL) models
# (the "_test" files — presumably the held-out split; confirm against the
# training notebook).
Dll_obs = joblib.load(f"data/{experiment}/Dll_obs_test.pkl")
Dhl_obs = joblib.load(f"data/{experiment}/Dhl_obs_test.pkl")

# Per-intervention models; later cells index them by intervention and read `.F`.
LLmodels = joblib.load(f"data/{experiment}/LLmodels.pkl")
HLmodels = joblib.load(f"data/{experiment}/HLmodels.pkl")

num_llsamples = Dll_obs.shape[0]
num_hlsamples = Dhl_obs.shape[0]

# Graph and intervention collection for each level.
Gll, Ill = mut.load_model(experiment, 'LL')
Ghl, Ihl = mut.load_model(experiment, 'HL')

n_varsll = len(Gll.nodes())
n_varshl = len(Ghl.nodes())

# Map from low-level interventions to their high-level counterparts.
omega = mut.load_omega_map(experiment)

if coeff_estimation:
    ll_coeffs = mut.get_coefficients(Dll_obs, Gll)
    hl_coeffs = mut.get_coefficients(Dhl_obs, Ghl)
else:
    ll_coeffs = mut.load_coeffs(experiment, 'LL')
    hl_coeffs = mut.load_coeffs(experiment, 'HL')

# Abduction step: recover the noise samples (and their mean / covariance)
# from the observations under the chosen coefficients.
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll_obs, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl_obs, Ghl, hl_coeffs)

# Empirical evaluation data; the next cell treats it as a dict keyed by
# intervention, with None as one of the keys.
data = evut.generate_empirical_data(LLmodels, HLmodels, omega, U_ll_hat, U_hl_hat)

In [5]:
# Which settings to evaluate on.
test_observ        = True   # include the observational setting (key None in `data`)
test_interv        = True   # include every interventional setting (non-None keys)
metric             = 'fro'  # NOTE(review): the visible cells pass 'fro' literally rather than reading this
num_iter           = 20     # NOTE(review): not referenced in the visible cells

if test_observ and test_interv:
    test_data = data

elif test_observ:
    test_data = {None: data[None]}

elif test_interv:
    test_data = {k: v for k, v in data.items() if k is not None}

else:
    # Without this branch `test_data` would stay undefined (or silently keep a
    # stale value from an earlier run of the cell).
    raise ValueError("At least one of test_observ / test_interv must be True.")

## 0-shift

In [None]:
if coeff_estimation:
    # The data each method is scored on depends only on the intervention, not
    # on the method, so build the (low-level, high-level) pairs once.
    zero_shift_pairs = [
        (LLmodels[iota].F @ U_ll_hat.T, HLmodels[omega[iota]].F @ U_hl_hat.T)
        for iota in Ill
    ]

    results_single = {}
    for name, method_data in T_results_emp.items():
        T = method_data['T_matrix']

        # One Frobenius distance per intervention between the abstracted
        # low-level data (T @ D_l) and the high-level data D_h.
        errors = [
            evut.compute_empirical_distance(T @ D_l, D_h, 'fro')
            for D_l, D_h in zero_shift_pairs
        ]

        results_single[name] = {
            'errors': errors,
            'mean': np.mean(errors),
            'ci': np.std(errors),  # plain standard deviation, despite the key name
        }

    # Not read anywhere in this cell; kept in case later cells rely on these globals.
    max_mean = max(v['mean'] for v in results_single.values())
    scale_factor = 1/max_mean

    # Sort by mean error (best first)
    results_single = dict(sorted(results_single.items(), key=lambda x: x[1]['mean']))

    # Print results
    print("\n" + "="*100)
    print(f"{'Method':<15} {'Error (mean ± std)':<35}")
    print("="*100)

    for method, stats in results_single.items():
        print(f"{method:<15} {stats['mean']:>8.4f} ± {stats['ci']:<8.4f}")

    # After running the 0-shift test, we can load the coefficients to proceed.
    ll_coeffs = mut.load_coeffs(experiment, 'LL')
    hl_coeffs = mut.load_coeffs(experiment, 'HL')
else:
    print('No coeff estimation')


## ρ-shift

In [21]:
rad_values = np.arange(0.05, 16.05, 1).tolist()
sample_forms = ['sample']  # alternative: ['boundary', 'sample']

# Noise samples recovered by abduction, per level.
hat_dict = {'L': U_ll_hat, 'H': U_hl_hat}

# Perturbed noise stored with the 'T_8' training run, used as the alternative center.
worst = 'T_8'
U_worst_L = T_results_emp[worst]['optimization_params']['L']['pert_U']
U_worst_H = T_results_emp[worst]['optimization_params']['H']['pert_U']

# Subsample the stored noise (without replacement) to the same number of rows
# as the abducted noise, so both centers have matching shapes.
# np.random.choice raises ValueError if the stored noise has fewer rows.
target_samplesL = U_ll_hat.shape[0]
target_samplesH = U_hl_hat.shape[0]

indicesL = np.random.choice(U_worst_L.shape[0], size=target_samplesL, replace=False)
indicesH = np.random.choice(U_worst_H.shape[0], size=target_samplesH, replace=False)

U_worst_L = U_worst_L[indicesL]
U_worst_H = U_worst_H[indicesH]

worst_dict = {'L': U_worst_L, 'H': U_worst_H}

# Pick the center of the perturbation family. An unknown value is rejected
# explicitly instead of leaving `center_matrix` undefined or stale.
center = 'worst'
center_options = {'hat': hat_dict, 'worst': worst_dict}
if center not in center_options:
    raise ValueError(f"Unknown center {center!r}; expected one of {sorted(center_options)}.")
center_matrix = center_options[center]

coverage_type='uniform'

### a. Family of Perturbations

In [22]:
# Build one perturbation family per level with identical settings. The
# generator is consumed in order ('L' first, then 'H'), so the two helper
# calls happen in the same sequence as writing them out one after the other.
# k=100 is presumably the number of perturbations per family — confirm in evut.
pert_family_L, pert_family_H = (
    evut.generate_perturbation_family(
        np.zeros_like(hat_dict[level]),
        k=100,
        r_mu=0.0,
        r_sigma=1.0,
        coverage=coverage_type,
    )
    for level in ('L', 'H')
)

In [23]:
# results[sample_form]['empirical'][method] -> flat list of distances, one per
# (perturbation, intervention) pair.
results = {
    sample_form: {
        'empirical': {method: [] for method in T_results_emp.keys()}
    } for sample_form in sample_forms
}

for pert_L, pert_H in zip(pert_family_L, pert_family_H):
    # The perturbed noise depends only on the perturbation, and the model
    # outputs only on (perturbation, intervention): compute them once here
    # rather than again for every method and sample form.
    pert_noise_L = center_matrix['L'].T + pert_L.T
    pert_noise_H = center_matrix['H'].T + pert_H.T

    perturbed_pairs = [
        (LLmodels[iota].F @ pert_noise_L, HLmodels[omega[iota]].F @ pert_noise_H)
        for iota in Ill
    ]

    # NOTE(review): sample_form does not influence the computation in this
    # cell; every form receives the same distances.
    for sample_form in sample_forms:
        for name, method_data in T_results_emp.items():
            T = method_data['T_matrix']

            results[sample_form]['empirical'][name].extend(
                evut.compute_empirical_distance(T @ base_norm, abst_norm, 'fro')
                for base_norm, abst_norm in perturbed_pairs
            )

In [None]:
# Print results with ranking
print("\n" + "="*100)
print(f"{'Rank':<5} {'Method':<15} {'Empirical Distance (mean ± std)':<35}")
print("="*100)

# Rankings per sample form, computed once and reused by the summary below.
sorted_by_form = {}

for sample_form in sample_forms:
    print(f"\nSample form: {sample_form}")
    print("-"*100)

    # (mean, std) of the collected distances for every method
    method_stats = {
        method: (np.mean(results[sample_form]['empirical'][method]),
                 np.std(results[sample_form]['empirical'][method]))
        for method in T_results_emp.keys()
    }

    # Sort methods by mean error (worst to best)
    sorted_methods = sorted(method_stats.items(), key=lambda x: x[1][0], reverse=True)
    sorted_by_form[sample_form] = sorted_methods

    # Print ranked results. The column is labelled "mean ± std", so print the
    # standard deviation itself (the earlier version printed std/10).
    for rank, (method, (mean, std)) in enumerate(sorted_methods, 1):
        print(f"{rank:<5} {method:<15} "
              f"{mean:>8.4f} ± {std:<8.4f}")

print("="*100)

# Print summary of best and worst methods
for sample_form in sample_forms:
    print(f"\nSummary for {sample_form} sampling:")
    print("-"*100)

    sorted_methods = sorted_by_form[sample_form]
    if not sorted_methods:
        continue  # no methods to summarise

    # Ranking is worst-to-best, so the ends of the list are the extremes.
    worst_method, (worst_error, worst_std) = sorted_methods[0]
    best_method, (best_error, best_std) = sorted_methods[-1]
    print(f"{'Worst:':<7} {worst_method:<15} {worst_error:>8.4f} ± {worst_std:<8.4f}")
    print(f"{'Best:':<7} {best_method:<15} {best_error:>8.4f} ± {best_std:<8.4f}")

In [25]:
def compute_errors(
    T_results_emp,
    Ill_relevant,
    LLmodels,
    HLmodels,
    omega,
    base_noise_L,
    base_noise_H,
    shift=False,
    r_mu=0.0,
    r_sigma=0.0,
    num_perturbations=1,
    coverage_type='uniform'
):
    """
    Score every abstraction map on (optionally perturbed) noise.

    For each perturbation and each intervention in ``Ill_relevant`` the
    low-level data ``LLmodels[iota].F @ noise_L`` is mapped through the
    method's ``T_matrix`` and compared with the high-level data
    ``HLmodels[omega[iota]].F @ noise_H`` via
    ``evut.compute_empirical_distance(..., 'fro')``.

    Args:
        T_results_emp (dict): method name -> dict holding a 'T_matrix' entry.
        Ill_relevant (iterable): interventions to evaluate (keys of LLmodels/omega).
        LLmodels, HLmodels (dict): intervention -> model exposing ``.F``.
        omega (dict): low-level intervention -> high-level intervention.
        base_noise_L, base_noise_H (np.ndarray): noise arrays; they are
            transposed before being multiplied by ``.F``.
        shift (bool): if True, perturbations are drawn with
            ``evut.generate_perturbation_family`` and added to the base noise;
            if False, the unperturbed noise is evaluated exactly once and
            ``r_mu``, ``r_sigma``, ``num_perturbations`` and ``coverage_type``
            are ignored.
        r_mu, r_sigma, num_perturbations, coverage_type: forwarded to
            ``evut.generate_perturbation_family`` (as ``r_mu``, ``r_sigma``,
            ``k`` and ``coverage``) when ``shift`` is True.

    Returns:
        tuple: ``(results, stats)`` where ``results[method]`` is the flat list
        of distances (perturbation-major, intervention-minor) and
        ``stats[method]`` is ``{'mean', 'ci', 'all'}``. ``'ci'`` holds the
        plain standard deviation of the distances; ``'all'`` is the distance
        list as an array.
    """
    if shift:
        pert_family_L = evut.generate_perturbation_family(
            np.zeros_like(base_noise_L),
            k=num_perturbations,
            r_mu=r_mu,
            r_sigma=r_sigma,
            coverage=coverage_type
        )
        pert_family_H = evut.generate_perturbation_family(
            np.zeros_like(base_noise_H),
            k=num_perturbations,
            r_mu=r_mu,
            r_sigma=r_sigma,
            coverage=coverage_type
        )
    else:
        # No shift: a single all-zero "perturbation" leaves the noise untouched.
        pert_family_L = [np.zeros_like(base_noise_L)]
        pert_family_H = [np.zeros_like(base_noise_H)]

    results = {method: [] for method in T_results_emp.keys()}

    for pert_L, pert_H in zip(pert_family_L, pert_family_H):
        # Neither the perturbed noise nor the model outputs depend on the
        # method, so compute them once per perturbation.
        pert_noise_L = base_noise_L.T + pert_L.T
        pert_noise_H = base_noise_H.T + pert_H.T
        model_outputs = [
            (LLmodels[iota].F @ pert_noise_L, HLmodels[omega[iota]].F @ pert_noise_H)
            for iota in Ill_relevant
        ]

        for method_name, method_data in T_results_emp.items():
            T = method_data['T_matrix']
            results[method_name].extend(
                evut.compute_empirical_distance(T @ base_norm, abst_norm, 'fro')
                for base_norm, abst_norm in model_outputs
            )

    # Summary statistics per method.
    stats = {}
    for method, vals in results.items():
        vals = np.array(vals)
        stats[method] = {'mean': np.mean(vals), 'ci': np.std(vals), 'all': vals}

    return results, stats

## F-misspecification

In [88]:
def contaminate_structural_matrix(M, contamination_fraction, contamination_type, num_segments=10, seed=None):
    """
    Return a contaminated copy of the matrix M; M itself is never modified.

    Args:
        M (np.ndarray): 2-D matrix to contaminate (n x m). Integer input is
            promoted to float so the contamination is not truncated.
        contamination_fraction (float): Magnitude f of the contamination
            (e.g. between 0.05 and 1.0). f = 0 returns the matrix unchanged.
        contamination_type (str): One of
            'multiplicative' - every entry on or above the main diagonal is
                multiplied by its own factor drawn uniformly from
                [1 - f, 1 + f]; entries below the diagonal are untouched.
            'nonlinear' - returns M + f * sin(M) (no randomness).
            'piecewise' - each row is cut at random column positions into up
                to ``num_segments`` contiguous segments and each segment is
                scaled by its own factor drawn uniformly from [1 - f, 1 + f].
        num_segments (int): Number of segments for 'piecewise' (default: 10).
            Values below 2 leave the rows unchanged.
        seed (int, optional): Seed for ``np.random.default_rng``. With the
            default ``None`` every call draws fresh, non-reproducible noise.

    Returns:
        np.ndarray: The contaminated matrix, same shape as M.

    Raises:
        ValueError: If M is not 2-D or ``contamination_type`` is unknown.
    """
    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError(f"M must be a 2-D matrix, got {M.ndim} dimension(s).")
    if not np.issubdtype(M.dtype, np.inexact):
        # Integer rows would otherwise silently truncate the scaled values.
        M = M.astype(float)

    rng = np.random.default_rng(seed)
    n_rows, n_cols = M.shape

    if contamination_type == "multiplicative":
        noise = rng.uniform(low=1.0 - contamination_fraction,
                            high=1.0 + contamination_fraction,
                            size=M.shape)
        # 1 on and above the diagonal, 0 below: only that part is perturbed.
        mask = np.triu(np.ones_like(M))
        return M * (1 - mask + mask * noise)

    if contamination_type == "nonlinear":
        return M + contamination_fraction * np.sin(M)

    if contamination_type == "piecewise":
        M_cont = M.copy()
        if num_segments < 2:
            return M_cont  # nothing to segment

        for i in range(n_rows):
            if n_cols > 1:
                # Interior cut points; duplicates are possible and produce
                # empty segments.
                cuts = np.sort(rng.integers(low=1, high=n_cols, size=num_segments - 1))
            else:
                # A row with fewer than two entries has no interior cut point
                # (rng.integers(1, 1) would raise); treat it as one segment.
                cuts = np.empty(0, dtype=int)
            edges = np.concatenate(([0], cuts, [n_cols]))

            for start, end in zip(edges[:-1], edges[1:]):
                factor = 1.0 + rng.uniform(low=-contamination_fraction,
                                           high=contamination_fraction)
                M_cont[i, start:end] = M[i, start:end] * factor
        return M_cont

    raise ValueError("Unknown contamination type. Choose among 'multiplicative', 'nonlinear', or 'piecewise'.")


In [None]:
# Define contamination levels to test
contamination_levels = np.linspace(0.0, 1.0, 100)
contamination_seed = 0   # makes the random contaminations reproducible
n_runs = 1               # repetitions per contamination level

for cont_type in ['piecewise']:
    print(f"\nContamination type: {cont_type}")
    print("="*100)

    # Fresh seed stream per contamination type, so each type is reproducible
    # on its own. contaminate_structural_matrix uses np.random.default_rng,
    # which np.random.seed does not control.
    seed_stream = np.random.default_rng(contamination_seed)

    # Store results for plotting
    plot_results = {method: {'means': [], 'stds': []} for method in T_results_emp.keys()}

    # Run experiment for each contamination level
    for cont_frac in tqdm(contamination_levels):
        abstraction_error = {name: [] for name in T_results_emp.keys()}

        for _ in range(n_runs):
            # Contaminate every model once per run, so that all methods are
            # scored on the same contaminated data (otherwise the ranking
            # mixes method quality with contamination luck).
            contaminated_pairs = []
            for iota in Ill:
                seed_L, seed_H = (int(s) for s in seed_stream.integers(0, 2**32, size=2))
                L_i = contaminate_structural_matrix(
                    LLmodels[iota].F,
                    contamination_fraction=cont_frac,
                    contamination_type=cont_type,
                    seed=seed_L,
                )
                H_i = contaminate_structural_matrix(
                    HLmodels[omega[iota]].F,
                    contamination_fraction=cont_frac,
                    contamination_type=cont_type,
                    seed=seed_H,
                )
                contaminated_pairs.append((L_i @ (hat_dict['L'].T), H_i @ (hat_dict['H'].T)))

            for name, res in T_results_emp.items():
                T = res['T_matrix']
                total = 0

                for base_norm, abst_norm in contaminated_pairs:
                    tau_base = T @ base_norm
                    dist = evut.compute_empirical_distance(tau_base, abst_norm, 'fro')
                    # Normalise by the square root of the number of entries.
                    d = tau_base.shape[0] * tau_base.shape[1]
                    dist /= np.sqrt(d)
                    total += dist

                # Store average error over interventions for this run
                iter_avg = total / len(contaminated_pairs)
                abstraction_error[name].append(iter_avg)

        # Store results for this contamination level
        for method in T_results_emp.keys():
            mean_e = np.mean(abstraction_error[method])
            std_e = np.std(abstraction_error[method])
            plot_results[method]['means'].append(mean_e)
            plot_results[method]['stds'].append(std_e)

    # Compute and print the overall averages
    print(f"{'Method':<15} {'Mean ± std':<35}")
    print("-" * 100)

    # Mean and std of the per-level means, for each method
    method_averages = []
    for method in T_results_emp.keys():
        mean = np.mean(plot_results[method]['means'])
        std = np.std(plot_results[method]['means'])
        method_averages.append((method, mean, std))

    # Sort by mean (worst to best)
    method_averages.sort(key=lambda x: x[1], reverse=True)

    # Print sorted averages
    for method, mean, std in method_averages:
        ci = std
        print(f"{method:<15} {mean:>8.4f} ± {ci:<8.4f}")

    print("="*100)

### ω-misspecification

In [91]:
def contaminate_omega_map(original_omega, num_misalignments, seed=None):
    """
    Randomly re-target a subset of entries in the ω map to simulate mapping
    misspecification.

    Args:
        original_omega (dict): Original intervention mapping.
            For example: {None: None, iota1: H_i1, iota2: H_i1, iota3: H_i2, ...}
            It is not modified.
        num_misalignments (int): Desired number of misaligned mappings. Values
            above the number of non-None keys are capped; values below zero
            are treated as zero.
        seed (int, optional): If given, a private ``random.Random(seed)`` is
            used so the result is reproducible and the global ``random`` state
            is left untouched. If None, the global ``random`` module is used.

    Returns:
        dict: A shallow copy of the mapping in which up to
        ``num_misalignments`` non-None keys point to a different target. A
        selected key is left as is when no alternative target exists. The
        None key is never altered.
    """
    rng = random if seed is None else random.Random(seed)

    # Only non-None keys are eligible for corruption.
    omega_keys = [k for k in original_omega if k is not None]

    # Distinct non-None targets in first-seen order. A set would also
    # deduplicate, but its iteration order can change between interpreter
    # runs, which would make seeded results irreproducible.
    all_targets = list(dict.fromkeys(
        original_omega[k] for k in omega_keys if original_omega[k] is not None
    ))

    contaminated_omega = original_omega.copy()

    # Keep the sample size within [0, number of eligible keys].
    num_to_corrupt = max(0, min(num_misalignments, len(omega_keys)))

    for key in rng.sample(omega_keys, k=num_to_corrupt):
        original_target = original_omega[key]
        # Only corrupt if there's an alternative available.
        available_targets = [t for t in all_targets if t != original_target]
        if available_targets:
            contaminated_omega[key] = rng.choice(available_targets)

    return contaminated_omega

In [None]:
# Define misalignment levels to test: 0, 1, ..., len(Ill) - 1 corrupted entries
misalignment_levels = range(0, len(Ill))
omega_plot_results  = {method: {'means': [], 'stds': []} for method in T_results_emp.keys()}

# contaminate_omega_map draws from the stdlib `random` module, which
# np.random.seed does not cover; seed it here so re-running the cell
# reproduces the same misalignments.
random.seed(0)

# The low-level data depends neither on ω nor on the method: compute it once.
base_by_iota = {iota: LLmodels[iota].F @ (hat_dict['L'].T) for iota in Ill}

for num_mis in tqdm(misalignment_levels):
    abstraction_error = {name: [] for name in T_results_emp.keys()}

    for _ in range(1):
        # Contaminate the omega map
        omega_cont = contaminate_omega_map(omega, num_mis)

        # High-level data under the contaminated map, shared by all methods.
        abst_by_iota = {iota: HLmodels[omega_cont[iota]].F @ (hat_dict['H'].T) for iota in Ill}

        for name, res in T_results_emp.items():
            T = res['T_matrix']

            total = 0
            for iota in Ill:
                tau_base = T @ base_by_iota[iota]
                dist = evut.compute_empirical_distance(tau_base, abst_by_iota[iota], 'fro')
                total += dist

            # Average error over interventions for this run
            iter_avg = total / len(Ill)
            abstraction_error[name].append(iter_avg)

    for method in T_results_emp.keys():
        mean_e = np.mean(abstraction_error[method])
        std_e = np.std(abstraction_error[method])
        omega_plot_results[method]['means'].append(mean_e)
        omega_plot_results[method]['stds'].append(std_e)


print("\n" + "="*100)
print("AVERAGE ERROR ACROSS ALL OMEGA MISALIGNMENTS (EMPIRICAL)")
print("="*100)
print(f"{'Method':<15} {'Mean ± std':<35}")
print("-"*100)

# Mean and std of the per-level means, for each method
method_averages = []
for method in T_results_emp.keys():
    all_means = omega_plot_results[method]['means']
    overall_mean = np.mean(all_means)
    overall_std = np.std(all_means)
    method_averages.append((method, overall_mean, overall_std))

# Worst (largest mean error) first
method_averages.sort(key=lambda x: x[1], reverse=True)
for method, mean, std in method_averages:
    ci = std
    print(f"{method:<15} {mean:>8.4f} ± {ci:<8.4f}")

print("="*100)