In [25]:
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import scipy.stats as stats
import random
import re
import utilities as ut
import modularised_utils as mut

from matplotlib.animation import FuncAnimation
from IPython.display import HTML

sns.set_theme(style="whitegrid")
seed = 42
np.random.seed(seed)

In [None]:
experiment = 'slc'
setting    = 'empirical'

if setting == 'gaussian':
    path = f"data/{experiment}/results"

elif setting == 'empirical':
    path = f"data/{experiment}/results_empirical"

saved_folds = joblib.load(f"data/{experiment}/cv_folds.pkl")

# Load the original data dictionary
all_data      = ut.load_all_data(experiment)

Dll_samples   = all_data['LLmodel']['data']
Dhl_samples   = all_data['HLmodel']['data']
I_ll_relevant = all_data['LLmodel']['intervention_set']
omega         = all_data['abstraction_data']['omega']
ll_var_names  = list(all_data['LLmodel']['graph'].nodes())
hl_var_names  = list(all_data['HLmodel']['graph'].nodes())

Data loaded for 'slc'.


In [37]:
# Load dictionaries containing the results for each optimization method
if setting == 'gaussian':
    diroca_results = joblib.load(f"{path}/diroca_cv_results.pkl")
    gradca_results = joblib.load(f"{path}/gradca_cv_results.pkl")
    baryca_results = joblib.load(f"{path}/baryca_cv_results.pkl")

elif setting == 'empirical':
    diroca_results = joblib.load(f"{path}/diroca_cv_results_empirical.pkl")
    gradca_results = joblib.load(f"{path}/gradca_cv_results_empirical.pkl")
    baryca_results = joblib.load(f"{path}/baryca_cv_results_empirical.pkl")
    abslingam_results = joblib.load(f"{path}/abslingam_cv_results_empirical.pkl")

results_to_evaluate = {}

if setting == 'empirical':
    if abslingam_results:
        first_fold_key = list(abslingam_results.keys())[0]
        for style in abslingam_results[first_fold_key].keys():
            method_name = f"Abs-LiNGAM ({style})"
            new_abslingam_dict = {}
            for fold_key, fold_results in abslingam_results.items():
                if style in fold_results:
                    new_abslingam_dict[fold_key] = {style: fold_results[style]}
            results_to_evaluate[method_name] = new_abslingam_dict
    
    def create_diroca_label(run_id):
        """Parses a run_id and creates a simplified label if epsilon and delta are equal."""
        # Use regular expression to find numbers for epsilon and delta
        matches = re.findall(r'(\d+\.?\d*)', run_id)
        if len(matches) == 2:
            eps, delta = matches
            # If they are the same, use the simplified format
            if eps == delta:
                # Handle integer conversion for clean labels like '1' instead of '1.0'
                val = int(float(eps)) if float(eps).is_integer() else float(eps)
                return f"DIROCA (eps_delta_{val})"
        # Otherwise, or if parsing fails, use the full original name
        return f"DIROCA ({run_id})"

    # Unpack each DIROCA hyperparameter run with the new, clean label
    if diroca_results:
        first_fold_key = list(diroca_results.keys())[0]
        for run_id in diroca_results[first_fold_key].keys():
            method_name = create_diroca_label(run_id) # Use the new helper to create the name
            new_diroca_dict = {}
            for fold_key, fold_results in diroca_results.items():
                if run_id in fold_results:
                    new_diroca_dict[fold_key] = {run_id: fold_results[run_id]}
            results_to_evaluate[method_name] = new_diroca_dict

    results_to_evaluate['GradCA'] = gradca_results
    results_to_evaluate['BARYCA'] = baryca_results

elif setting == 'gaussian':
    results_to_evaluate['GradCA'] = gradca_results
    results_to_evaluate['BARYCA'] = baryca_results

    if diroca_results:
        first_fold_key = list(diroca_results.keys())[0]
        diroca_run_ids = list(diroca_results[first_fold_key].keys())

        # create a separate entry for each DIROCA run
        for run_id in diroca_run_ids:
            method_name = f"DIROCA ({run_id})"
            
            new_diroca_dict = {}
            for fold_key, fold_results in diroca_results.items():
                # For each fold grab the data for the current run_id
                if run_id in fold_results:
                    new_diroca_dict[fold_key] = {run_id: fold_results[run_id]}
            
            results_to_evaluate[method_name] = new_diroca_dict

label_map_gaussian = {
                        'DIROCA (eps_delta_0.111)': 'DiRoCA_star',
                        'DIROCA (eps_delta_1)': 'DIROCA_1',
                        'DIROCA (eps_delta_2)': 'DIROCA_2',
                        'DIROCA (eps_delta_4)': 'DIROCA_4',
                        'DIROCA (eps_delta_8)': 'DIROCA_8',
                        'GradCA': 'GradCA',
                        'BARYCA': 'BARYCA'
                    }

label_map_empirical = {
                        'DIROCA (eps_0.328_delta_0.107)': 'DiRoCA_star',
                        'DIROCA (eps_delta_1)': 'DIROCA_1',
                        'DIROCA (eps_delta_2)': 'DIROCA_2',
                        'DIROCA (eps_delta_4)': 'DIROCA_4',
                        'DIROCA (eps_delta_8)': 'DIROCA_8',
                        'GradCA': 'GradCA',
                        'BARYCA': 'BARYCA',
                        'Abs-LiNGAM (Perfect)': 'Abslin_p',
                        'Abs-LiNGAM (Noisy)': 'Abslin_n'
                    }

if setting == 'empirical':
    results_to_evaluate = {label_map_empirical.get(key, key): value for key, value in results_to_evaluate.items()}

elif setting == 'gaussian':
    results_to_evaluate = {label_map_gaussian.get(key, key): value for key, value in results_to_evaluate.items()}

print("\nMethods available for evaluation:")
for key in results_to_evaluate.keys():
    print(f"  - {key}")


Methods available for evaluation:
  - Abslin_p
  - Abslin_n
  - DIROCA (eps_0.107_delta_0.035)
  - DIROCA_1
  - DIROCA_2
  - DIROCA_4
  - DIROCA_8
  - GradCA
  - BARYCA


## F-contamination

In [38]:
def contaminate_data(data, strength, contamination_type, num_segments=10, seed=None):
    """
    Applies a specified contamination to data samples to simulate model misspecification.

    Args:
        data (np.ndarray): The original data samples.
        strength (float): The magnitude of the contamination.
        contamination_type (str): 'piecewise', 'multiplicative', or 'nonlinear'.
        num_segments (int): Number of segments for the 'piecewise' type.
        seed (int, optional): Random seed for reproducibility.
        
    Returns:
        np.ndarray: The contaminated data.
    """
    rng = np.random.default_rng(seed)
    data_cont = data.copy()
    
    if contamination_type == "multiplicative":
        # Apply element-wise multiplicative noise
        noise = rng.uniform(low=1.0 - strength, high=1.0 + strength, size=data.shape)
        data_cont *= noise
  
    elif contamination_type == "nonlinear":
        # Apply a sine-based non-linear distortion
        data_cont += strength * np.sin(data_cont)
  
    elif contamination_type == "piecewise":
        # Apply piecewise contamination to each column (variable)
        for col_idx in range(data.shape[1]):
            column = data_cont[:, col_idx]
            breakpoints = np.quantile(column, q=np.linspace(0, 1, num_segments + 1))
            breakpoints[-1] += 1e-6 # Ensure the last element is included
            
            for i in range(num_segments):
                factor = 1.0 + rng.uniform(low=-strength, high=strength)
                mask = (column >= breakpoints[i]) & (column < breakpoints[i+1])
                data_cont[mask, col_idx] *= factor
  
    else:
        raise ValueError(f"Unknown contamination type: {contamination_type}")
  
    return data_cont

In [39]:
contamination_type_to_run = 'multiplicative' # Options: 'piecewise', 'multiplicative', 'nonlinear'

# Range of contamination strengths to test
contamination_strengths = np.linspace(0, 1, 10) 

# Number of random contaminations to average over for each setting
num_trials = 10

f_spec_records = []
print(f"F '{contamination_type_to_run}' misspecification evaluation")

for strength in tqdm(contamination_strengths, desc="Contamination Strength"):
    for trial in range(num_trials):
        for i, fold_info in enumerate(saved_folds):
            for method_name, results_dict in results_to_evaluate.items():
                fold_results = results_dict.get(f'fold_{i}', {})
                for run_key, run_data in fold_results.items():

                    if 'DIROCA' in method_name:
                        method_label = method_name
                    else:
                        method_label = method_name

                    T_learned = run_data['T_matrix']
                    test_indices = run_data['test_indices']
                    
                    errors_per_intervention = []
                    for iota in I_ll_relevant:
                        Dll_test_clean = Dll_samples[iota][test_indices]
                        Dhl_test_clean = Dhl_samples[omega[iota]][test_indices]
                        
                        Dll_test_cont = contaminate_data(Dll_test_clean, strength, contamination_type_to_run, seed=trial)
                        Dhl_test_cont = contaminate_data(Dhl_test_clean, strength, contamination_type_to_run, seed=trial)
                        if setting == 'gaussian':
                            error = ut.calculate_abstraction_error(T_learned, Dll_test_cont, Dhl_test_cont)
                        elif setting == 'empirical':
                            error = ut.calculate_empirical_error(T_learned, Dll_test_cont, Dhl_test_cont)
                        if not np.isnan(error):
                            errors_per_intervention.append(error)
                    
                    avg_error = np.mean(errors_per_intervention) if errors_per_intervention else np.nan
                    
                    record = {
                        'method': method_label, 
                        'contamination': strength,
                        'trial': trial,
                        'fold': i,
                        'error': avg_error
                    }
                    f_spec_records.append(record)

f_spec_df = pd.DataFrame(f_spec_records)
print("--- F-Misspecification Evaluation Complete ---")

F 'multiplicative' misspecification evaluation


Contamination Strength: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]

--- F-Misspecification Evaluation Complete ---





In [40]:
print("\n" + "="*65)
print(f"Overall Performance (Averaged Across All '{contamination_type_to_run}' Strengths)")
print("="*65)
print(f"{'Method/Run':<35} | {'Mean ± Std'}")
print("="*65)

summary = f_spec_df.groupby('method')['error'].agg(['mean', 'std', 'count'])
summary['sem'] = summary['std'] #/ np.sqrt(summary['count'])
# summary['ci95'] = 1.96 * summary['sem']
summary['ci95'] = summary['sem']

for method_name, row in summary.sort_values('mean').iterrows():
    print(f"{method_name:<35} | {row['mean']:.4f} ± {row['ci95']:.4f}")
print("="*65)


Overall Performance (Averaged Across All 'multiplicative' Strengths)
Method/Run                          | Mean ± Std
GradCA                              | 63.1590 ± 5.4351
DIROCA_1                            | 69.7332 ± 4.9434
DIROCA_2                            | 70.7777 ± 4.0304
DIROCA_4                            | 70.7777 ± 4.0304
DIROCA_8                            | 70.7777 ± 4.0304
BARYCA                              | 95.4702 ± 8.7866
DIROCA (eps_0.107_delta_0.035)      | 97.6368 ± 9.0061
Abslin_n                            | 101.4545 ± 8.6654
Abslin_p                            | 103.8440 ± 10.8534


## ω contamination

In [42]:
def contaminate_omega_map(original_omega, num_misalignments):
    
    """Randomly re-wires a subset of entries in the omega map."""
    
    omega_keys = [k for k in original_omega.keys() if k is not None]
    omega_vals = [original_omega[k] for k in omega_keys if original_omega[k] is not None]

    contaminated_omega = original_omega.copy()
    
    # Bound the number of misalignments by the number of eligible keys.
    num_to_corrupt = min(num_misalignments, len(omega_keys))
    # Randomly select keys to corrupt.
    to_corrupt = random.sample(omega_keys, k=num_to_corrupt)
    
    # Create a random permutation of available targets (ensuring change)
    # Use the set of targets from eligible keys.
    all_targets = list(set(omega_vals))

    for key in to_corrupt:
        original_target = original_omega[key]
        # Only corrupt if there's an alternative available.
        available_targets = [t for t in all_targets if t != original_target]
        if available_targets:
            new_target = random.choice(available_targets)
            contaminated_omega[key] = new_target
            
    return contaminated_omega

In [None]:
max_misalignments = len(I_ll_relevant) 
misalignment_levels = range(0, max_misalignments)

# Number of random contaminations to average over for each setting
num_trials = 20

omega_spec_records = []
print("Omega-misspecification evaluation")

for num_misalignments in tqdm(misalignment_levels, desc="Misalignment Level"):
    for trial in range(num_trials):
        # Create a new scrambled omega map
        omega_cont = contaminate_omega_map(omega, num_misalignments)
        
        for i, fold_info in enumerate(saved_folds):
            for method_name, results_dict in results_to_evaluate.items():
                fold_results = results_dict.get(f'fold_{i}', {})
                for run_key, run_data in fold_results.items():

                    if 'DIROCA' in method_name:
                        method_label = method_name
                    else:
                        method_label = method_name

                    T_learned = run_data['T_matrix']
                    test_indices = run_data['test_indices']
                    
                    errors_per_intervention = []
                    for iota in I_ll_relevant:
                        Dll_test = Dll_samples[iota][test_indices]
                        # Use the contaminated omega map
                        Dhl_test = Dhl_samples[omega_cont[iota]][test_indices]
                        
                        if setting == 'gaussian':
                            error = ut.calculate_abstraction_error(T_learned, Dll_test, Dhl_test)
                        elif setting == 'empirical':
                            error = ut.calculate_empirical_error(T_learned, Dll_test, Dhl_test)
                            
                        if not np.isnan(error): errors_per_intervention.append(error)
                    
                    avg_error = np.mean(errors_per_intervention) if errors_per_intervention else np.nan
                    
                    record = {
                                'method': method_label, 
                                'misalignments': num_misalignments,
                                'trial': trial,
                                'fold': i,
                                'error': avg_error
                            }
                    omega_spec_records.append(record)

omega_spec_df = pd.DataFrame(omega_spec_records)
print("\n\n--- Omega-Misspecification Evaluation Complete ---")


6
Omega-misspecification evaluation


Misalignment Level: 100%|██████████| 6/6 [00:02<00:00,  2.13it/s]



--- Omega-Misspecification Evaluation Complete ---





In [44]:
print("\n" + "="*65)
print("Overall Performance (Averaged Across All Misalignment Levels)")
print("="*65)
print(f"{'Method/Run':<35} | {'Mean ± Std'}")
print("="*65)

summary = omega_spec_df.groupby('method')['error'].agg(['mean', 'std', 'count'])
summary['sem'] = summary['std'] #/ np.sqrt(summary['count'])
# summary['ci95'] = 1.96 * summary['sem']
summary['ci95'] = summary['sem']


for method_name, row in summary.sort_values('mean').iterrows():
    print(f"{method_name:<35} | {row['mean']:.4f} ± {row['ci95']:.4f}")
print("="*65)


Overall Performance (Averaged Across All Misalignment Levels)
Method/Run                          | Mean ± Std
GradCA                              | 58.1176 ± 2.5301
DIROCA_2                            | 60.6816 ± 4.6320
DIROCA_4                            | 60.6816 ± 4.6320
DIROCA_8                            | 60.6816 ± 4.6320
DIROCA_1                            | 61.0272 ± 3.3619
BARYCA                              | 88.2573 ± 4.0329
Abslin_p                            | 91.0618 ± 3.5338
DIROCA (eps_0.107_delta_0.035)      | 91.7016 ± 4.8322
Abslin_n                            | 93.7674 ± 3.3887
