In [None]:
import numpy as np
import torch
import joblib
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from scipy.stats import wilcoxon

# Local modules
import modularised_utils as mut
import opt_utils as oput 
import evaluation_utils as evut
import Linear_Additive_Noise_Models as lanm
import params
import random

from math_utils import compute_wasserstein

In [43]:
# Name of the experiment; the cells below load their artifacts from data/<experiment>/.
experiment = 'lucas6x3'

In [44]:
# Trained results for this experiment. Later cells iterate over its items and read
# each entry's 'T_matrix' (and, for the worst-case method, 'optimization_params').
# NOTE(review): joblib.load unpickles the file, so only load artifacts from a trusted source.
T_results = joblib.load(f"data/{experiment}/diroca_train_results.pkl")

In [45]:
# True  -> estimate the structural coefficients from the loaded test data.
# False -> load the coefficients stored for the experiment.
coeff_estimation = True

# Test data sets for the low-level (LL) and high-level (HL) models.
Dll_obs = joblib.load(f"data/{experiment}/Dll_obs_test.pkl")
Dhl_obs = joblib.load(f"data/{experiment}/Dhl_obs_test.pkl")
num_llsamples = Dll_obs.shape[0]
num_hlsamples = Dhl_obs.shape[0]

# Per-intervention model objects; later cells read their `.F` attribute.
LLmodels = joblib.load(f"data/{experiment}/LLmodels.pkl")
HLmodels = joblib.load(f"data/{experiment}/HLmodels.pkl")

# Graph and intervention collection for each level.
Gll, Ill = mut.load_model(experiment, 'LL')
Ghl, Ihl = mut.load_model(experiment, 'HL')
n_varsll = len(Gll.nodes())
n_varshl = len(Ghl.nodes())

# Mapping from low-level interventions to high-level interventions.
omega = mut.load_omega_map(experiment)

if not coeff_estimation:
    ll_coeffs = mut.load_coeffs(experiment, 'LL')
    hl_coeffs = mut.load_coeffs(experiment, 'HL')
else:
    ll_coeffs = mut.get_coefficients(Dll_obs, Gll)
    hl_coeffs = mut.get_coefficients(Dhl_obs, Ghl)

# Abduction step on each level; the returned triple is unpacked as
# (noise estimate, its mean, its covariance) -- presumably; confirm against mut.lan_abduction.
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll_obs, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl_obs, Ghl, hl_coeffs)

data = evut.generate_data(
    LLmodels, HLmodels, omega,
    num_llsamples, num_hlsamples,
    mu_U_ll_hat, Sigma_U_ll_hat,
    mu_U_hl_hat, Sigma_U_hl_hat,
)

In [46]:
# Evaluation switches: which parts of `data` end up in `test_data`.
test_observ        = True
test_interv        = True
num_iter           = 100
metric             = 'wass'

# `data` is keyed by intervention; the key None is the observational entry.
if test_observ and test_interv:
    test_data = data

elif test_observ:
    test_data = {None: data[None]}

elif test_interv:
    test_data = {k: v for k, v in data.items() if k is not None}

else:
    # Without this branch `test_data` would stay undefined (or keep a stale value
    # from an earlier run of the cell) when both switches are off.
    raise ValueError("At least one of test_observ / test_interv must be True.")

## 0-shift

In [None]:
# Has to be learnt with coeff = True as well; check opt_ell.ipynb.
#
# 0-shift evaluation: for every method in T_results, compute the square root of
# mut.compute_wasserstein between the pushed-forward low-level Gaussian
# (T @ L_i applied to the estimated LL noise moments) and the high-level Gaussian
# (H_i applied to the estimated HL noise moments), once per intervention in Ill.

# Defined up front so the report below also works on the `else` branch
# (previously it raised NameError or printed a stale table there).
results_single = {}

if coeff_estimation:

    # Not used in this cell; kept (hoisted out of the loops) in case later cells rely on it.
    scale_factor = 1/np.sqrt(len(Ill))

    for name, res in T_results.items():
        T = res['T_matrix']
        errors = []  # One distance per intervention
        for iota in Ill:
            L_i = LLmodels[iota].F
            V_i = T @ L_i
            H_i = HLmodels[omega[iota]].F

            muV    = V_i @ mu_U_ll_hat
            sigmaV = V_i @ Sigma_U_ll_hat @ V_i.T
            muH    = H_i @ mu_U_hl_hat
            sigmaH = H_i @ Sigma_U_hl_hat @ H_i.T

            wass_dist = np.sqrt(mut.compute_wasserstein(muV, sigmaV, muH, sigmaH))
            errors.append(wass_dist)

        # Mean and spread over interventions. The 'ci' key is kept for
        # compatibility, but the value is one standard deviation.
        mean_error = np.mean(errors)
        std_error = np.std(errors)
        ci = std_error

        results_single[name] = {
            'errors': errors,
            'mean': mean_error,
            'ci': ci
        }

    # Order methods from lowest to highest mean error.
    results_single = dict(sorted(results_single.items(), key=lambda x: x[1]['mean']))

    # NOTE(review): this replaces the estimated coefficients with the stored ones for
    # everything that runs after this cell -- confirm that this is intended.
    ll_coeffs = mut.load_coeffs(experiment, 'LL')
    hl_coeffs = mut.load_coeffs(experiment, 'HL')
else:
    print('No coeff estimation')

# Print results (the spread column is one standard deviation, see above)
print("\n" + "="*100)
print(f"{'Method':<15} {'Error (mean ± std)':<35}")
print("="*100)

for method, stats in results_single.items():
    print(f"{method:<15} {stats['mean']:>8.4f} ± {stats['ci']:<8.4f}")


## ρ-shift

In [49]:
# Radii 0.05, 10.05, ..., 90.05 as a plain Python list.
rad_values = np.arange(0.05, 100.05, 10).tolist()
sample_forms = ['boundary', 'sample']

center = 'worst'
coverage_type = 'uniform'

# Estimated noise moments per level: [mean, covariance].
hat_dict = {
    'L': [mu_U_ll_hat, Sigma_U_ll_hat],
    'H': [mu_U_hl_hat, Sigma_U_hl_hat],
}

# Method whose stored optimisation parameters are used as the "worst" centre.
worst = 'T_8'
mu_worst_L, Sigma_worst_L = (
    T_results[worst]['optimization_params']['L'][field] for field in ('mu_U', 'Sigma_U')
)
mu_worst_H, Sigma_worst_H = (
    T_results[worst]['optimization_params']['H'][field] for field in ('mu_U', 'Sigma_U')
)

worst_dict = {
    'L': [mu_worst_L, Sigma_worst_L],
    'H': [mu_worst_H, Sigma_worst_H],
}

In [None]:
# Sweep the covariance-shift radius r_sigma around the worst-case noise moments and
# record, per method, the mean over interventions of sqrt(compute_wasserstein(...)).
# (numpy / matplotlib are already imported in the notebook's import cell.)

# Define the r_sigma values to sweep over
sigma_values = np.linspace(0, 1, 100)
methods_to_track = list(T_results.keys())

# Storage for plotting and for mean/CI across all sigmas
error_evolution = {method: [] for method in methods_to_track}
mean_across_sigmas = {method: [] for method in methods_to_track}
ci_across_sigmas = {method: [] for method in methods_to_track}

for r_sigma in sigma_values:
    # Generate shifted Gaussian families for this sigma.
    # NOTE(review): seed=None is passed, so the sweep is presumably not reproducible
    # between runs -- confirm against mut.generate_shifted_gaussian_family.
    shift_family_L = mut.generate_shifted_gaussian_family(
        mu_worst_L, Sigma_worst_L, 1, r_mu=0, r_sigma=r_sigma, coverage='rand', seed=None)
    shift_family_H = mut.generate_shifted_gaussian_family(
        mu_worst_H, Sigma_worst_H, 1, r_mu=0, r_sigma=r_sigma, coverage='rand', seed=None)

    # Initialize results for this sigma
    results = {method: [] for method in methods_to_track}

    for shift_L, shift_H in zip(shift_family_L, shift_family_H):
        noise_muL, noise_SigmaL = shift_L
        noise_muH, noise_SigmaH = shift_H
        # Convert anything exposing .numpy() (e.g. tensors) to a NumPy array.
        noise_muL = noise_muL.numpy() if hasattr(noise_muL, 'numpy') else noise_muL
        noise_muH = noise_muH.numpy() if hasattr(noise_muH, 'numpy') else noise_muH
        noise_SigmaL = noise_SigmaL.numpy() if hasattr(noise_SigmaL, 'numpy') else noise_SigmaL
        noise_SigmaH = noise_SigmaH.numpy() if hasattr(noise_SigmaH, 'numpy') else noise_SigmaH

        for name in methods_to_track:
            res = T_results[name]
            T = res['T_matrix']
            wass_total = 0
            for iota in Ill:
                L_i = LLmodels[iota].F
                V_i = T @ L_i
                H_i = HLmodels[omega[iota]].F
                muV = V_i @ noise_muL
                sigmaV = V_i @ noise_SigmaL @ V_i.T
                muH = H_i @ noise_muH
                sigmaH = H_i @ noise_SigmaH @ H_i.T
                wass_dist = np.sqrt(compute_wasserstein(muV, sigmaV, muH, sigmaH))
                wass_total += wass_dist
            results[name].append(wass_total / len(Ill))

    # Store mean and spread for each method for this sigma.
    # NOTE(review): if the family really contains a single draw (third argument is 1),
    # `std` -- and therefore `ci` -- is always 0 here; the /10 scaling is unexplained.
    for method in methods_to_track:
        mean = np.mean(results[method])
        std = np.std(results[method])
        ci = std/10
        error_evolution[method].append(mean)
        mean_across_sigmas[method].append(mean)
        ci_across_sigmas[method].append(ci)

# Print mean and spread across all sigmas for each method.
# The spread is one standard deviation of the per-sigma means (no 1.96 factor is
# applied), so the header says exactly that.
print(f"\n{'Method':<15} {'Mean across sigmas ± Std':<35}")
print("="*50)
method_stats = []
for method in methods_to_track:
    mean_over_sigmas = np.mean(mean_across_sigmas[method])
    std_over_sigmas = np.std(mean_across_sigmas[method])
    ci_over_sigmas = std_over_sigmas
    method_stats.append((method, mean_over_sigmas, ci_over_sigmas))
# Sort by mean, descending (worst to best)
method_stats.sort(key=lambda x: x[1], reverse=True)
for method, mean, ci in method_stats:
    print(f"{method:<15} {mean:8.4f} ± {ci:<8.4f}")
print("="*50)

## F-contamination

In [52]:
def contaminate_structural_matrix(M, contamination_fraction, contamination_type, num_segments=10, seed=None):
    """
    Contaminate a linear transformation matrix M to break its strict linearity.

    The input is never modified; a new array is returned. Integer (or boolean)
    input is promoted to float first so that fractional perturbations are not
    truncated when written back.

    Args:
        M (array-like): Original linear transformation matrix, 2-D (n x m).
        contamination_fraction (float): Magnitude of contamination (e.g., between 0.05 and 1.0).
        contamination_type (str): Type of contamination to apply. Options are:
            'multiplicative' -- scale each upper-triangular entry (diagonal included)
                by an independent factor drawn uniformly from
                [1 - contamination_fraction, 1 + contamination_fraction);
                entries below the diagonal are left unchanged.
            'nonlinear' -- return M + contamination_fraction * sin(M).
            'piecewise' -- split every row into contiguous segments at random
                breakpoints and scale each segment by its own factor
                1 + U(-contamination_fraction, contamination_fraction).
        num_segments (int): Number of segments per row for 'piecewise' (default: 10).
            Values below 2 leave the rows unchanged.
        seed (int, optional): Random seed for reproducibility.

    Returns:
        np.ndarray: The contaminated matrix, same shape as M.

    Raises:
        ValueError: If contamination_type is not one of the three options, or
            if M is not 2-D.
    """
    rng = np.random.default_rng(seed)

    M = np.asarray(M)
    if not np.issubdtype(M.dtype, np.inexact):
        # Integer storage would silently truncate the perturbed values.
        M = M.astype(float)
    n, m = M.shape  # also enforces that M is 2-D

    if contamination_type == "multiplicative":
        noise = rng.uniform(low=1.0 - contamination_fraction, high=1.0 + contamination_fraction, size=M.shape)
        # Mask selecting the upper triangle (including the diagonal).
        mask = np.triu(np.ones_like(M))
        # Factor is `noise` inside the mask and exactly 1 outside it.
        return M * (1 - mask + mask * noise)

    if contamination_type == "nonlinear":
        # Sine-based additive perturbation.
        return M + contamination_fraction * np.sin(M)

    if contamination_type == "piecewise":

        def piecewise_contaminate_row(row, cont_frac, segments, rng):
            n_elem = len(row)
            if segments < 2:
                return row  # nothing to do
            if n_elem < 2:
                # A row this short has no interior index to break at
                # (rng.integers(1, n_elem) would raise); use a single segment.
                interior = np.empty(0, dtype=int)
            else:
                # Random interior breakpoints; duplicates just yield empty segments.
                interior = np.sort(rng.integers(low=1, high=n_elem, size=segments - 1))
            breakpoints = np.concatenate(([0], interior, [n_elem]))
            contaminated_row = np.empty_like(row)
            # Each segment gets its own random multiplicative factor.
            for start, end in zip(breakpoints[:-1], breakpoints[1:]):
                factor = 1.0 + rng.uniform(low=-cont_frac, high=cont_frac)
                contaminated_row[start:end] = row[start:end] * factor
            return contaminated_row

        M_cont = M.copy()
        for i in range(n):
            M_cont[i, :] = piecewise_contaminate_row(M[i, :], contamination_fraction, num_segments, rng)
        return M_cont

    raise ValueError("Unknown contamination type. Choose among 'multiplicative', 'nonlinear', or 'piecewise'.")


In [None]:
# F-contamination experiment: perturb the structural matrices of both levels with
# contaminate_structural_matrix and track sqrt(compute_wasserstein(...)) per method.

# Define contamination levels to test
contamination_levels = np.linspace(0.0, 1.0, 100)

# The noise moments do not change during the experiment, so convert them to NumPy
# once instead of inside every loop iteration.
noise_muL, noise_SigmaL = mu_U_ll_hat, Sigma_U_ll_hat
noise_muH, noise_SigmaH = mu_U_hl_hat, Sigma_U_hl_hat

noise_muL    = noise_muL.numpy() if torch.is_tensor(noise_muL) else noise_muL
noise_muH    = noise_muH.numpy() if torch.is_tensor(noise_muH) else noise_muH
noise_SigmaL = noise_SigmaL.numpy() if torch.is_tensor(noise_SigmaL) else noise_SigmaL
noise_SigmaH = noise_SigmaH.numpy() if torch.is_tensor(noise_SigmaH) else noise_SigmaH

for cont_type in ['piecewise']:
    print(f"Contamination type: {cont_type}")
    # Store results for plotting
    plot_results = {method: {'means': [], 'stds': []} for method in T_results.keys()}

    # Run experiment for each contamination level
    for cont_frac in tqdm(contamination_levels):
        abstraction_error = {name: [] for name in T_results.keys()}

        # Single repetition per level; with one sample the per-level std below is 0.
        for _ in range(1):
            for name, res in T_results.items():
                T = res['T_matrix']
                total = 0
                for iota in Ill:
                    # NOTE(review): the contamination is drawn without a seed and
                    # separately for every method, so methods are compared under
                    # different random perturbations and runs are not reproducible.
                    L_i = LLmodels[iota].F
                    L_i = contaminate_structural_matrix(L_i, contamination_fraction=cont_frac, contamination_type=cont_type)
                    V_i = T @ L_i
                    H_i = HLmodels[omega[iota]].F
                    H_i = contaminate_structural_matrix(H_i, contamination_fraction=cont_frac, contamination_type=cont_type)

                    muV    = V_i @ noise_muL
                    sigmaV = V_i @ noise_SigmaL @ V_i.T
                    muH    = H_i @ noise_muH
                    sigmaH = H_i @ noise_SigmaH @ H_i.T

                    dist = np.sqrt(compute_wasserstein(muV, sigmaV, muH, sigmaH))

                    total += dist

                # Average over interventions for this method.
                iter_avg = total / len(Ill)
                abstraction_error[name].append(iter_avg)

        # Store results for this contamination level
        for method in T_results.keys():
            mean_e = np.mean(abstraction_error[method])
            std_e = np.std(abstraction_error[method])
            plot_results[method]['means'].append(mean_e)
            plot_results[method]['stds'].append(std_e)

    # Compute averages across all contamination levels for each method
    method_averages = {}

    for method in T_results.keys():
        # Get all means across contamination levels
        all_means = plot_results[method]['means']
        # Compute the mean and std across all contamination levels
        overall_mean = np.mean(all_means)
        overall_std = np.std(all_means)
        method_averages[method] = (overall_mean, overall_std)

    # Sort methods by average (worst to best)
    sorted_methods = sorted(method_averages.items(), key=lambda x: x[1][0], reverse=True)

    # Print results. The spread column is one standard deviation of the per-level
    # means, not a 95% confidence interval, so the header is labelled accordingly.
    print("\n" + "="*100)
    print("AVERAGE WASSERSTEIN DISTANCE ACROSS ALL CONTAMINATION LEVELS (0.0 to 1.0)")
    print("="*100)
    print(f"{'Method':<15} {'Mean ± Std':<35}")
    print("-"*100)

    for method, (mean, std) in sorted_methods:
        ci = std
        print(f"{method:<15} {mean:>8.4f} ± {ci:<8.4f}")

    print("="*100)

## ω-contamination

In [55]:
def contaminate_omega_map(original_omega, num_misalignments, seed=None):
    """
    Randomly corrupt a subset of entries in the ω map to simulate mapping misspecification.

    Keys equal to None are never corrupted. A selected key is re-pointed to a
    target drawn from the other non-None targets present in the map; if no
    alternative target exists, that key is left unchanged. The input mapping is
    not modified.

    Args:
        original_omega (dict): Original intervention mapping.
            For example: {None: None, iota1: H_i1, iota2: H_i1, iota3: H_i2, ...}
        num_misalignments (int): Desired number of misaligned mappings. Values
            above the number of eligible keys are capped; negative values are
            treated as 0.
        seed (int, optional): If given, a private random.Random(seed) is used so
            the result is reproducible. If None (default), the global `random`
            module state is used, as before.

    Returns:
        dict: A new ω mapping with up to num_misalignments entries altered.
    """
    rng = random if seed is None else random.Random(seed)

    # Eligible keys: everything except the None (observational) key.
    omega_keys = [k for k in original_omega if k is not None]

    # Distinct non-None targets in first-seen order. An ordered de-duplication is
    # used instead of set() so the candidate order (and hence the seeded choice)
    # does not depend on hash ordering.
    all_targets = list(dict.fromkeys(
        original_omega[k] for k in omega_keys if original_omega[k] is not None
    ))

    # Start with a copy of the original mapping.
    contaminated_omega = original_omega.copy()

    # Clamp to [0, number of eligible keys]; sample() rejects negative sizes.
    num_to_corrupt = max(0, min(num_misalignments, len(omega_keys)))

    for key in rng.sample(omega_keys, k=num_to_corrupt):
        original_target = original_omega[key]
        # Only corrupt if there's an alternative available.
        available_targets = [t for t in all_targets if t != original_target]
        if available_targets:
            contaminated_omega[key] = rng.choice(available_targets)

    return contaminated_omega

In [None]:
# ω-contamination experiment: corrupt an increasing number of ω entries and record,
# per method, the mean over interventions of sqrt(compute_wasserstein(...)).

# Number of corrupted entries to test: 0, 1, ..., len(Ill) - 1.
misalignment_levels = range(0, len(Ill))

# Per-method curves (one value per misalignment level).
omega_plot_results = {method: {'means': [], 'stds': []} for method in T_results.keys()}

for num_mis in tqdm(misalignment_levels):
    abstraction_error = {name: [] for name in T_results.keys()}

    # A single repetition per level.
    for _ in range(1):
        noise_muL, noise_SigmaL = mu_U_ll_hat, Sigma_U_ll_hat
        noise_muH, noise_SigmaH = mu_U_hl_hat, Sigma_U_hl_hat

        # Torch tensors are converted to NumPy arrays; anything else passes through.
        noise_muL    = noise_muL.numpy() if torch.is_tensor(noise_muL) else noise_muL
        noise_muH    = noise_muH.numpy() if torch.is_tensor(noise_muH) else noise_muH
        noise_SigmaL = noise_SigmaL.numpy() if torch.is_tensor(noise_SigmaL) else noise_SigmaL
        noise_SigmaH = noise_SigmaH.numpy() if torch.is_tensor(noise_SigmaH) else noise_SigmaH

        # One corrupted map per repetition, shared by all methods.
        omega_cont = contaminate_omega_map(omega, num_mis)

        for name, res in T_results.items():
            T = res['T_matrix']
            total = 0
            for iota in Ill:
                L_i = LLmodels[iota].F
                V_i = T @ L_i
                H_i = HLmodels[omega_cont[iota]].F

                muV    = V_i @ noise_muL
                sigmaV = V_i @ noise_SigmaL @ V_i.T
                muH    = H_i @ noise_muH
                sigmaH = H_i @ noise_SigmaH @ H_i.T

                dist = np.sqrt(compute_wasserstein(muV, sigmaV, muH, sigmaH))
                total += dist

            # Average over interventions for this method.
            iter_avg = total / len(Ill)
            abstraction_error[name].append(iter_avg)

    # Summarise this misalignment level.
    for method in T_results.keys():
        mean_e = np.mean(abstraction_error[method])
        std_e = np.std(abstraction_error[method])
        omega_plot_results[method]['means'].append(mean_e)
        omega_plot_results[method]['stds'].append(std_e)

# Aggregate each method over all misalignment levels.
method_averages = []
for method in T_results.keys():
    all_means = omega_plot_results[method]['means']
    overall_mean = np.mean(all_means)
    overall_std = np.std(all_means)
    method_averages.append((method, overall_mean, overall_std))

# Highest mean first (worst to best).
method_averages.sort(key=lambda x: x[1], reverse=True)

# The printed spread is the standard deviation across levels.
for method, mean, std in method_averages:
    ci = std
    print(f"{method:<15} {mean:>8.4f} ± {ci:<8.4f}")

print("="*100)