In [1]:
# import sys

# !{sys.executable} -m pip install networkx

In [2]:
import pandas as pd
import datetime, time
import numpy as np
import os, re, json, psutil
import copy
import warnings
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from joblib import Parallel, delayed
import torch
import torch.nn as nn

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, f1_score
from sklearn.mixture import GaussianMixture
from sklearn.neural_network import MLPRegressor
from sklearn.utils import resample

from scipy.stats import beta

In [3]:
from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.cit import kci, CIT
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from causallearn.utils.PDAG2DAG import pdag2dag

from dowhy import gcm
from dowhy.gcm import interventional_samples, AdditiveNoiseModel
from dowhy.gcm.causal_mechanisms import StochasticModel
from dowhy.gcm.ml import SklearnRegressionModel, SklearnClassificationModel
from dowhy.gcm.auto import AssignmentQuality

from notears.linear import notears_linear
from notears.nonlinear import notears_nonlinear, NotearsMLP

from lingam import LiM
import pydot

import networkx as nx
from networkx.drawing.nx_pydot import to_pydot

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Raise pandas' display limits so wide/long frames are not truncated in cell output.
pd.set_option("display.max_rows", 1000)
pd.set_option("display.max_columns", 1000)

# Process-wide environment switches (values must be strings).
# NOTE(review): these are set after numpy/torch/joblib were imported above, so
# libraries that read them at import time may not see them in this process -- confirm.
os.environ.update({
    "OMP_NUM_THREADS": "1",          # prevent inner OpenMP threads
    "MKL_NUM_THREADS": "1",          # same for MKL
    "JOBLIB_MULTIPROCESSING": "0",
    "JOBLIB_TEMP_FOLDER": "/tmp",    # optional safety for temp files
})

# Loading model dataframe

In [5]:
# The variant tag is embedded in the CSV file name that gets loaded.
file_type = 'loose'
file_name_std = f'robust_std_data_for_model_{file_type}.csv'

# Read the CSV and preview its first two rows as the cell output.
final_df_std = pd.read_csv(file_name_std)
final_df_std.head(2)

Unnamed: 0,age,is_female,long_covid,me_cfs,fibromyalgia,dysautonomia,infection_episode,period_at_covid_start,acute_num_symp_prop,post_num_symp_prop,pre_num_symp_prop,acute_symp_sev_prop,post_symp_sev_prop,pre_symp_sev_prop,acute_symp_freq_prop,post_symp_freq_prop,pre_symp_freq_prop,pre_emotionally_stressful,pre_hr_variability,pre_mentally_demanding,pre_physically_active,pre_resting_hr,pre_sleep,acute_emotionally_stressful,acute_hr_variability,acute_mentally_demanding,acute_physically_active,acute_resting_hr,acute_sleep,post_emotionally_stressful,post_hr_variability,post_mentally_demanding,post_physically_active,post_resting_hr,post_sleep,pre_funcap_score,post_funcap_score,pre_crash,acute_crash,post_crash,acute_noncovid_infection,post_noncovid_infection,pre_noncovid_infection
0,0.529412,1,1,1,0,0,0.0,0.0,-0.007278,0.037508,0.880572,0.488509,0.952535,0.344357,-0.109665,0.167581,-0.158928,-1.560641,0.660555,-1.53429,-2.018847,0.690702,0.982801,-1.156195,0.650428,-1.545455,-1.69697,0.962482,1.0,-1.303714,0.668402,-1.536009,-2.0125,1.078761,1.257576,0.0,0.0,1.0,1.0,1.0,0,0,0
1,0.411765,0,1,0,0,0,0.0,0.0,0.892934,0.861503,0.880572,1.428591,0.971924,0.816596,1.05765,1.067037,1.224172,0.747051,-0.461927,0.730861,1.453629,0.49089,0.942207,1.030534,-0.448321,0.512397,0.501377,0.097147,2.276923,1.153667,-0.487816,0.777484,1.05,0.507806,1.369464,1.0,1.0,0.0,0.0,0.0,0,0,0


# Distributions

In [6]:
# --- Optional helper distributions ---
class GaussianDistribution(StochasticModel):
    """Univariate normal distribution described by a mean and a standard deviation.

    Both parameters are ``None`` until :meth:`fit` has been called.
    """

    def __init__(self):
        self.mean_ = None
        self.std_ = None

    def fit(self, X: np.ndarray):
        """Estimate mean and standard deviation from ``X`` (flattened to 1-D first)."""
        values = np.array(X).flatten()
        spread = np.std(values)
        self.mean_ = np.mean(values)
        # A zero spread is replaced by a tiny positive value to avoid a degenerate distribution.
        self.std_ = 1e-6 if spread == 0 else spread

    def parameters(self):
        """Return the fitted parameters as ``{"mean": ..., "std": ...}``."""
        return {"mean": self.mean_, "std": self.std_}

    def draw_samples(self, n):
        """Draw ``n`` samples from the fitted normal; the result has shape ``(n, 1)``."""
        return np.random.normal(loc=self.mean_, scale=self.std_, size=(n, 1))

    def evaluate(self, X: np.ndarray) -> np.ndarray:
        """Return ``X`` unchanged (identity for stochastic-only nodes without parents)."""
        return X

    def clone(self):
        """Return a new instance carrying the same mean and standard deviation."""
        duplicate = GaussianDistribution()
        duplicate.mean_ = self.mean_
        duplicate.std_ = self.std_
        return duplicate
    
class GaussianMixtureDistribution:
    """Thin wrapper around sklearn's ``GaussianMixture`` with fit/sample/clone helpers."""

    # Fitted sklearn attributes that clone() carries over when they exist.
    _CLONED_ATTRS = ("means_", "covariances_", "weights_")

    def __init__(self, n_components=2):
        self.model = GaussianMixture(n_components=n_components)

    def fit(self, X):
        """Fit the mixture to ``X``; 1-D input is treated as a single feature column."""
        X = np.asarray(X)
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        self.model.fit(X)

    def draw_samples(self, n):
        """Return ``n`` samples (first element of ``GaussianMixture.sample``)."""
        return self.model.sample(n)[0]

    def clone(self):
        """Return a new wrapper with the same ``n_components`` and copied fitted arrays.

        Means, covariances and weights are copied only if the wrapped model has
        them, so cloning an instance that was never fitted yields an unfitted
        clone instead of raising ``AttributeError`` (this mirrors the
        ``hasattr`` guards in :meth:`parameters`).
        """
        new = GaussianMixtureDistribution(n_components=self.model.n_components)
        for attr in self._CLONED_ATTRS:
            if hasattr(self.model, attr):
                setattr(new.model, attr, np.copy(getattr(self.model, attr)))
        return new

    def parameters(self):
        """Return weights/means/covariances as nested lists, or ``None`` where unfitted."""
        return {
            "weights": self.model.weights_.tolist() if hasattr(self.model, "weights_") else None,
            "means": self.model.means_.tolist() if hasattr(self.model, "means_") else None,
            "covariances": self.model.covariances_.tolist() if hasattr(self.model, "covariances_") else None
        }

class BetaDistribution:
    """Beta distribution on (0, 1) with shape parameters ``a`` and ``b``.

    ``a`` and ``b`` are ``None`` until :meth:`fit` has been called.
    """

    def __init__(self):
        self.beta = beta
        self.a = None
        self.b = None

    def fit(self, X):
        """Fit ``a`` and ``b`` by ``scipy.stats.beta.fit`` with location 0 and scale 1.

        ``X`` may be any array-like (list, ndarray of any shape, ...). It is
        converted to an array *before* flattening, so inputs without a
        ``.flatten`` method no longer fail. Values are clipped into
        ``[1e-6, 1 - 1e-6]`` because the fit requires data strictly inside (0, 1).
        """
        values = np.asarray(X, dtype=float).ravel()
        values = np.clip(values, 1e-6, 1 - 1e-6)
        self.a, self.b, _, _ = self.beta.fit(values, floc=0, fscale=1)

    def draw_samples(self, n):
        """Draw ``n`` samples from Beta(a, b); the result has shape ``(n, 1)``."""
        return self.beta.rvs(self.a, self.b, size=n).reshape(-1, 1)

    def clone(self):
        """Return a new instance carrying the same ``a`` and ``b``."""
        new = BetaDistribution()
        new.a, new.b = self.a, self.b
        return new

    def parameters(self):
        """Return the shape parameters as ``{"a": ..., "b": ...}``."""
        return {"a": self.a, "b": self.b}

class BernoulliDistribution:
    """Bernoulli distribution with success probability ``p`` (``None`` until fitted)."""

    def __init__(self):
        self.p = None

    def fit(self, X):
        """Set ``p`` to the mean of the flattened input."""
        self.p = np.asarray(X).flatten().mean()

    def draw_samples(self, n):
        """Draw ``n`` samples of 0/1 with success probability ``p``; shape ``(n, 1)``."""
        draws = np.random.binomial(n=1, p=self.p, size=n)
        return draws.reshape(-1, 1)

    def clone(self):
        """Return a new instance carrying the same ``p``."""
        duplicate = BernoulliDistribution()
        duplicate.p = self.p
        return duplicate

    def parameters(self):
        """Return the parameter as ``{"p": ...}``."""
        return {"p": self.p}

# Functions

In [7]:
# === Background Knowledge Builder ===
def build_background_knowledge_cl(var_names, added_required=None, added_forbidden=None):
    """Build a causal-learn ``BackgroundKnowledge`` object from temporal and domain rules.

    Parameters
    ----------
    var_names : list of str
        Variable names; one ``GraphNode`` is created per name. The root names and
        ``period_at_covid_start`` referenced below must be present, otherwise the
        node lookup raises ``KeyError``.
    added_required : iterable of (src, dst), optional
        Extra edges to mark as required.
    added_forbidden : iterable of (src, dst), optional
        Extra edges to mark as forbidden.

    Returns
    -------
    BackgroundKnowledge
        The populated knowledge object. The number of forbidden and required
        rules is printed before returning.
    """
    nodes_by_name = {name: GraphNode(name) for name in var_names}
    knowledge = BackgroundKnowledge()

    def forbid(src, dst):
        # Register "src -> dst" as a forbidden edge.
        knowledge.add_forbidden_by_node(nodes_by_name[src], nodes_by_name[dst])

    def require(src, dst):
        # Register "src -> dst" as a required edge.
        knowledge.add_required_by_node(nodes_by_name[src], nodes_by_name[dst])

    # --- Variable groups (by name prefix) ---
    post_vars = [name for name in var_names if name.startswith("post_")]
    acute_vars = [name for name in var_names if name.startswith("acute_")]
    pre_vars = [name for name in var_names if name.startswith("pre_")]

    # --- Root nodes ---
    all_roots = ['age', 'is_female', 'me_cfs', 'long_covid', 'fibromyalgia', 'dysautonomia', 'infection_episode']
    strict_roots = ['age', 'is_female', 'infection_episode']
    relaxed_roots = [root for root in all_roots if root not in strict_roots]

    # Strict roots: nothing may point into them.
    for src in var_names:
        for root in strict_roots:
            if src != root:
                forbid(src, root)

    # Relaxed roots: only other roots may point into them.
    for src in var_names:
        if src in all_roots:
            continue
        for root in relaxed_roots:
            if src != root:
                forbid(src, root)

    # infection_episode may not point into any relaxed root.
    for root in relaxed_roots:
        forbid('infection_episode', root)

    # --- Required edges ---
    require('is_female', 'me_cfs')
    require('is_female', 'period_at_covid_start')
    require('age', 'period_at_covid_start')

    # Suggestive relationships from evaluation
    for edge in (added_required or []):
        require(edge[0], edge[1])

    # --- Temporal forbidden edges: later phases cannot point to earlier ones ---
    for later in post_vars:
        for earlier in pre_vars + acute_vars:
            forbid(later, earlier)
    for later in acute_vars:
        for earlier in pre_vars:
            forbid(later, earlier)
    # post_/acute_ variables may not influence period_at_covid_start ...
    for later in post_vars + acute_vars:
        forbid(later, 'period_at_covid_start')
    # ... and period_at_covid_start may not influence pre_ variables.
    for earlier in pre_vars:
        forbid('period_at_covid_start', earlier)

    # Suggestive relationships from evaluation
    for edge in (added_forbidden or []):
        forbid(edge[0], edge[1])

    print(f"# Forbidden edges: {len(knowledge.forbidden_rules_specs)}")
    print(f"# Required edges: {len(knowledge.required_rules_specs)}")

    return knowledge

def build_background_knowledge(columns, added_required=None, added_forbidden=None):
    """
    Build forbidden and required edge lists based on temporal
    and domain constraints.

    Parameters
    ----------
    columns : list
        Variable names. Names starting with ``pre_``, ``acute_`` and ``post_``
        are treated as the three temporal phases.
    added_required : iterable of (src, dst), optional
        Extra edges appended to the required list.
    added_forbidden : iterable of (src, dst), optional
        Extra edges appended to the forbidden list.

    Returns
    -------
    (forbidden, required) : tuple of list
        Two lists of ``(src, dst)`` tuples. ``forbidden`` may contain
        duplicates; callers that need uniqueness should wrap it in ``set``.
    """
    forbidden = []
    required = []

    # --- Variable groups ---
    post_vars = [c for c in columns if c.startswith("post_")]
    acute_vars = [c for c in columns if c.startswith("acute_")]
    pre_vars = [c for c in columns if c.startswith("pre_")]
    pre_or_acute_vars = pre_vars + acute_vars
    post_or_acute_vars = post_vars + acute_vars

    # --- Root nodes ---
    all_roots = ['age', 'is_female', 'me_cfs', 'long_covid', 'fibromyalgia', 'dysautonomia', 'infection_episode']
    strict_roots = ['age', 'is_female', 'infection_episode']
    relaxed_roots = [r for r in all_roots if r not in strict_roots]

    # Strict roots: forbid all incoming edges
    for src in columns:
        for tgt in strict_roots:
            if src != tgt:
                forbidden.append((src, tgt))

    # Relaxed roots: forbid incoming edges from non-root nodes
    for src in columns:
        for tgt in relaxed_roots:
            if src != tgt and src not in all_roots:
                forbidden.append((src, tgt))

    # Treatment root node: forbid outgoing edges to relaxed root nodes
    for tgt in relaxed_roots:
        forbidden.append(("infection_episode", tgt))

    # --- Required edges ---
    required += [
        ("is_female", "me_cfs"),
        ("is_female", "period_at_covid_start"),
        ("age", "period_at_covid_start"),
    ]

    # Suggestive relationships from evaluation
    if added_required:
        for edge in added_required:
            required.append((edge[0], edge[1]))

    # --- Temporal forbidden edges: later phases cannot point to earlier ones ---
    for src in post_vars:
        for tgt in pre_or_acute_vars:
            forbidden.append((src, tgt))

    for src in acute_vars:
        for tgt in pre_vars:
            forbidden.append((src, tgt))

    for src in post_or_acute_vars:
        forbidden.append((src, "period_at_covid_start"))

    for tgt in pre_vars:
        forbidden.append(("period_at_covid_start", tgt))

    # Suggestive relationships from evaluation
    if added_forbidden:
        for edge in added_forbidden:
            # list.append takes a single argument: the edge must be passed as one tuple.
            forbidden.append((edge[0], edge[1]))

    return forbidden, required

In [8]:
# === Model Runner ===
def run_model(model, df, outdir="graph_plots", alpha=0.075, added_required=None, added_forbidden=None):
    """Run one causal-discovery algorithm on ``df`` and return the directed edges it found.

    Parameters
    ----------
    model : str
        One of ``"pc"``, ``"fci"``, ``"lingam"``, ``"notears_linear"`` or
        ``"notears_nonlinear"``.
    df : pandas.DataFrame
        Input data; column names are used as variable names and the values are
        converted to a float64 array before being passed to the algorithm.
    outdir : str
        Directory that is created if missing. It is not otherwise used in this function.
    alpha : float
        Significance level passed to PC / FCI (unused by the other models).
    added_required, added_forbidden : iterable of (src, dst), optional
        Extra edges forwarded to the background-knowledge builders.

    Returns
    -------
    dict
        ``{"edges": set, "latent_confounders": set}`` where each element is
        ``(node_a, node_b, (endpoint_a, endpoint_b))``. ``latent_confounders`` is
        only populated by the FCI branch (ARROW/ARROW edges).

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported names.
    """
    os.makedirs(outdir, exist_ok=True)

    # Initialize variables
    var_names = df.columns.tolist()
    # NOTE(review): fitted_model is assigned in several branches but never returned.
    fitted_model = None
    latent_confounders = set()

    # --- PC ---
    if model=="pc":
        # Build background knowledge
        bk = build_background_knowledge_cl(var_names, added_required=added_required, added_forbidden=added_forbidden)

        # Run model
        data_np = np.asarray(df, dtype=np.float64)
        result = pc(
            data_np,
            indep_test="kci",
            alpha=alpha,
            node_names=var_names,
            background_knowledge=bk
        )

        # Handle tuple vs. direct graph
        graph = result[0] if isinstance(result, tuple) else result
        graph = graph.G if hasattr(graph, "G") else graph  # unwrap to GeneralGraph

        # Convert CPDAG to DAG
        dag_graph = pdag2dag(graph)

        # Every edge of the DAG is kept, together with its endpoint types as strings
        edges = {(e.get_node1().get_name(),
                e.get_node2().get_name(),
                (str(e.get_endpoint1()), str(e.get_endpoint2())))
                for e in dag_graph.get_graph_edges()}

    # --- FCI ---
    elif model=="fci":
        # Build background knowledge
        bk = build_background_knowledge_cl(var_names, added_required=added_required, added_forbidden=added_forbidden)

        # Run model
        data_np = np.asarray(df, dtype=np.float64)
        result = fci(
            data_np,
            depth=2,
            independence_test_method="kci",  # fisherz is fast and gives interpretable partial correlations, but can only be used on continuous vars
            # kci models nonlinear or complex interactions
            alpha=alpha,
            max_path_length=3,
            node_names=var_names,
            background_knowledge=bk
        )

        # Handle tuple vs. direct graph
        graph = result[0] if isinstance(result, tuple) else result
        graph = graph.G if hasattr(graph, "G") else graph  # unwrap to GeneralGraph

        # Extract edges
        edges = set()
        for e in graph.get_graph_edges():
            endpoint1 = str(e.get_endpoint1())
            endpoint2 = str(e.get_endpoint2())

            # Only include if it's a directed edge (TAIL -> ARROW)
            if endpoint1 == "TAIL" and endpoint2 == "ARROW":
                edges.add((e.get_node1().get_name(),
                                e.get_node2().get_name(), (endpoint1, endpoint2)))
            # Reversed orientation: swap the nodes so the stored tuple still reads src -> dst
            elif endpoint1 == "ARROW" and endpoint2 == "TAIL":
                edges.add((e.get_node2().get_name(),
                                e.get_node1().get_name(), (endpoint2, endpoint1)))
            # Save latent confounders (ARROW on both ends)
            elif endpoint1 == "ARROW" and endpoint2 == "ARROW":
                latent_confounders.add((e.get_node1().get_name(),
                                e.get_node2().get_name(), (endpoint1, endpoint2)))
            # Edges with any other endpoint combination are dropped

    # --- LiM ---
    elif model=="lingam":
        # Identify discrete vs continuous variables
        # Continuous (float) = 1, Discrete/binary/int = 0
        dis_con = np.ones((1, len(var_names)))
        for i, col in enumerate(var_names):
            # Columns with at most 4 distinct values are treated as discrete
            if df[col].nunique() <= 4:
                dis_con[0, i] = 0  # discrete/binary variable
            else:
                dis_con[0, i] = 1  # continuous variable

        # Build background knowledge
        forbidden, required = build_background_knowledge(var_names, added_required=added_required, added_forbidden=added_forbidden)
        forbidden_edges = set(forbidden)  # convert to set for fast lookup

        # Run model
        data_np = np.asarray(df, dtype=np.float64)
        lingam_model = LiM( # best for mixed data types
            lambda1=0.085,     # weaker L1 (default is 0.1) -> more edges allowed
            max_iter=150,      # keep as is
            h_tol=1e-8,        # keep acyclicity tolerance
            rho_max=1e16,      # keep default rho
            w_threshold=1e-6   # preserve as many small weights as possible (not 0 for speed)
        )
        lingam_model.fit(data_np, dis_con, only_global=True) # only_global = False allows local search and usually adds more edges
        adj_matrix = lingam_model.adjacency_matrix_
        fitted_model = lingam_model

        # Zero out forbidden edges in adjacency (prune)
        # NOTE(review): this writes into the array returned by adjacency_matrix_ in place,
        # and treats adj_matrix[i, j] as the edge var_names[i] -> var_names[j] -- confirm
        # that this matches LiM's adjacency-matrix orientation.
        name_to_idx = {n: i for i, n in enumerate(var_names)}
        for src, dst in list(forbidden_edges):
            i = name_to_idx[src]; j = name_to_idx[dst]
            adj_matrix[i, j] = 0.0

        # Extract edges
        edges = set()
        for i, src in enumerate(var_names):
            for j, dst in enumerate(var_names):
                if adj_matrix[i, j] != 0:
                    edges.add((src, dst, ("TAIL", "ARROW")))

        # Save fitted model and pruned adjacency so you can use for diagnostics/SEM
        fitted_model = {"lingam": lingam_model, "adj_matrix_pruned": adj_matrix}

    # --- NOTEARS ---
    elif model in ["notears_linear", "notears_nonlinear"]:

        # Build background knowledge
        forbidden, required = build_background_knowledge(var_names, added_required=added_required, added_forbidden=added_forbidden)
        forbidden_edges = set(forbidden)  # convert to set for fast lookup

        # Run model
        df_for_notears = np.asarray(df, dtype=np.float64)
        if model == "notears_linear":
            # Run NOTEARS linear
            W_est = notears_linear(
                df_for_notears,
                lambda1=0.02,       # the higher the value, the more edges are pruned
                loss_type='l2',
                max_iter=100,       # keep as is for accuracy
                h_tol=1e-8,         # don't loosen -- will reduce accuracy
                rho_max=1e16,       # default 1e16 -> decreasing will stop earlier if diverging
                w_threshold=1e-6    # preserve as many small weights as possible (not 0 for speed)
            )
        else:
            # Run NOTEARS nonlinear
            d = df_for_notears.shape[1] # number of features
            # NOTE: this changes torch's process-wide default dtype, not just this call
            torch.set_default_dtype(torch.double)
            model_1 = NotearsMLP(dims=[d, 15, 1], bias=True) # dims = [d inputs, 15 hidden units, 1 output per variable]
            W_est = notears_nonlinear(
                model_1,
                df_for_notears.astype(np.double),
                lambda1=0.002,      # the higher the value, the more edges are pruned
                lambda2=0.07,       # add a bit of L2 regularization to stabilize edges across runs -- very sensitive, will affect run time
                max_iter=100,       # default 100 -> fewer dual steps for efficiency
                h_tol=1e-8,         # don't loosen -- will reduce accuracy
                rho_max=1e16,       # keep as is
                w_threshold=1e-6    # preserve as many small weights as possible (not 0 for speed)
            )
            # Checking thresholding: count non-zero weights with magnitude below 0.05 (diagnostic only)
            low_edges = [(var_names[i], var_names[j], W_est[i,j])
                for i in range(len(var_names)) for j in range(len(var_names))
                if abs(W_est[i,j]) < 0.05 and W_est[i,j] != 0]
            print(f"Number of low-magnitude edges (<0.05): {len(low_edges)}")
        fitted_model = W_est

        # Adding edges: a non-zero W_est[i, j] becomes var_names[i] -> var_names[j]
        # unless that pair is forbidden, in which case it is only counted as pruned
        edges_before = np.sum(W_est != 0)
        edges_pruned = 0
        edges = set()
        for i, src in enumerate(var_names):
            for j, dst in enumerate(var_names):
                if W_est[i, j] != 0 and (src, dst) not in forbidden_edges:
                    edges.add((src, dst, ("TAIL", "ARROW")))
                elif W_est[i, j] != 0 and (src, dst) in forbidden_edges:
                    edges_pruned += 1
        print(f"Edges before pruning: {edges_before}, pruned forbidden: {edges_pruned}")

    else:
        raise ValueError(f"Unsupported model: {model}")

    return {"edges": edges, "latent_confounders": latent_confounders}

def run_bootstrap_discovery(model, data, alpha, bootstrap, stage_level, required=None, forbidden=None):
    """Run causal discovery on bootstrap resamples and return edge stability.

    Parameters
    ----------
    model : str
        Model name forwarded to ``run_model``.
    data : pandas.DataFrame
        Full dataset; each run resamples its rows with replacement.
    alpha : float
        Significance level forwarded to ``run_model``.
    bootstrap : int
        Number of bootstrap runs; also the denominator of the frequencies.
    stage_level
        Stored unchanged under the ``"stage"`` key of the result record.
    required, forbidden : iterable of (src, dst), optional
        Extra background-knowledge edges forwarded to ``run_model``.

    Returns
    -------
    (edge_results, stability) : (list of dict, dict)
        ``edge_results`` holds a single summary record; ``stability`` maps each
        ``(src, dst)`` edge to the fraction of runs in which it appeared.
    """

    edge_results = []
    edge_presence, conf_presence = defaultdict(list), defaultdict(list)

    def single_run(seed):
        print("Starting single causal discovery run")
        print("Worker memory (MB):", psutil.Process(os.getpid()).memory_info().rss / 1e6)

        # (1) Bootstrap: sample with replacement, producing the same number of rows
        df_boot = data.sample(frac=1, replace=True, random_state=int(seed)).copy()
        df_boot = df_boot.reset_index(drop=True)

        # (2) Subsample: cap the sample size to speed up KCI
        if model in ["pc", "fci"]:
            max_samples = 200
            if len(df_boot) > max_samples:
                df_boot = df_boot.sample(n=max_samples, random_state=int(seed))

        # Run the discovery model
        return run_model(model, df_boot, alpha=alpha, added_required=required, added_forbidden=forbidden)

    # Create unique seeds for each job to avoid RNG collisions.
    # The children are spawned once; previously the SeedSequence was rebuilt and
    # re-spawned for every index, which produced the same seeds in O(bootstrap^2).
    child_sequences = np.random.SeedSequence(42).spawn(bootstrap)
    seeds = [child.generate_state(1)[0] for child in child_sequences]

    # Run in parallel
    # NOTE(review): confirm the installed joblib version accepts max_tasks_per_child here.
    run_outputs = Parallel(n_jobs=6, backend="loky", max_tasks_per_child=1)( # define max so that bootstraps don't overrun memory
        delayed(single_run)(int(seed)) for seed in seeds
    )
    print("Bootstraps finished. Now aggregating...")
    print(len(run_outputs))

    # Aggregate results: one 1.0 entry per run in which an edge (or confounder pair) appeared
    for run_out in run_outputs:
        for (u, v, _) in run_out["edges"]:
            edge_presence[(u, v)].append(1.0)
        if isinstance(run_out.get("latent_confounders"), set):
            for (u, v, _) in run_out["latent_confounders"]:
                conf_presence[(u, v)].append(1.0)

    stability = {e: np.sum(pres) / bootstrap for e, pres in edge_presence.items()}
    confounders = {e: np.sum(pres) / bootstrap for e, pres in conf_presence.items()}

    # Every edge seen at least once is reported; thresholding is left to the caller
    final_edges = list(stability)

    edge_results.append({
        "model": model,
        "stage": stage_level,
        "alpha": alpha,
        "n_bootstrap": bootstrap,
        "num_edges": len(final_edges),
        "edges": final_edges,
        "stability": stability,
        "confounders": confounders
    })
    return edge_results, stability

In [9]:
# === Edge Stability ===
def summarize_edge_stability(results_df):
    """
    Flatten per-run stability dictionaries into long-format tables.

    Each row of ``results_df`` must provide ``model``, ``n_bootstrap``,
    ``stability`` and ``confounders``; the last two are dicts mapping an edge
    to its frequency.

    Returns a pair of DataFrames (edges, confounders), each with one row per
    edge and the columns ``model``, ``n_bootstrap``, ``edge`` and ``freq``.
    """
    edge_rows = []
    confounder_rows = []

    for _, row in results_df.iterrows():
        edge_freqs = row["stability"]
        confounder_freqs = row["confounders"]
        model_name = row["model"]

        edge_rows.extend(
            {"model": model_name, "n_bootstrap": row["n_bootstrap"], "edge": edge, "freq": freq}
            for edge, freq in edge_freqs.items()
        )
        confounder_rows.extend(
            {"model": model_name, "n_bootstrap": row["n_bootstrap"], "edge": pair, "freq": freq}
            for pair, freq in confounder_freqs.items()
        )

    return pd.DataFrame(edge_rows), pd.DataFrame(confounder_rows)

# === Consensus Graphs ===
def build_consensus_graph(model, stability, required_edges, freq_threshold=0.5, added_required=None):
    """Build an acyclic consensus graph from bootstrap edge frequencies.

    Parameters
    ----------
    model : str
        Not used by this function; kept for interface compatibility.
    stability : dict
        Maps ``(src, dst)`` to the edge's bootstrap frequency.
    required_edges : iterable of (src, dst)
        Edges that are always present in the graph and flagged ``forced=True``.
    freq_threshold : float
        Minimum frequency for a discovered edge to be kept.
    added_required : iterable of (src, dst), optional
        Required edges that may be removed before the other required edges
        when a cycle consists of required edges only.

    Returns
    -------
    networkx.DiGraph
        A DAG whose node names have been passed through ``sanitize_node_name``.
        Edges carry ``weight`` and, for required edges, ``forced=True``.

    Raises
    ------
    RuntimeError
        If the removal cap is reached or the graph is still cyclic afterwards.
    """
    G = nx.DiGraph()

    for (u, v), freq in stability.items():
        if freq >= freq_threshold:
            G.add_edge(u, v, weight=freq)

    # Normalise to tuples so edges given as lists can still be looked up in the sets.
    required_set = {(src, tgt) for src, tgt in required_edges}
    added_required_set = {(e[0], e[1]) for e in added_required} if added_required else set()

    # Add required edges if missing, and note them
    forced_edges = []
    for src, tgt in required_edges:
        if not G.has_edge(src, tgt):
            G.add_edge(src, tgt, weight=1, forced=True)
            forced_edges.append((src, tgt))
        else:
            # Even if discovered, mark it
            G[src][tgt]["forced"] = True

    if forced_edges:
        print(f"⚠️ Forced edges added: {forced_edges}")

    if not nx.is_directed_acyclic_graph(G):
        print("⚠️ Cycle detected in consensus graph!")

    # Break one cycle at a time. The cap is a safety net only: each pass
    # removes exactly one edge, so the loop always makes progress.
    max_removals = max(1000, 2 * G.number_of_edges())
    removals = 0

    while True:
        try:
            # Returns a list of (u, v, direction) tuples; raises if the graph is acyclic
            cycle_edges = nx.find_cycle(G, orientation="original")
        except nx.exception.NetworkXNoCycle:
            break

        # Node sequence of the cycle, used only for the log message
        cycle_nodes = [u for u, _, _ in cycle_edges]
        if cycle_nodes and cycle_nodes[-1] != cycle_edges[-1][1]:
            cycle_nodes.append(cycle_edges[-1][1])

        # Candidates are exactly the edges of the reported cycle. They all exist
        # in G, so the list is never empty and no "skip this cycle" path is needed.
        candidates = []
        for u, v, _ in cycle_edges:
            attr = G[u][v]
            candidates.append({
                "edge": (u, v),
                "weight": attr.get("weight", 0.0),
                "is_required": (u, v) in required_set,
                "is_added_required": (u, v) in added_required_set,
                "forced": attr.get("forced", False)
            })

        print(f"Total edges BEFORE removal: {G.number_of_edges()}")

        # 1. Prefer to remove non-required edges first
        non_required = [e for e in candidates if not e["is_required"]]
        if non_required:
            edge_to_remove = min(non_required, key=lambda x: x["weight"])
        else:
            # 2. If all are required, prefer to remove those in added_required
            in_added_required = [e for e in candidates if e["is_added_required"]]
            if in_added_required:
                edge_to_remove = min(in_added_required, key=lambda x: x["weight"]) # TO DO: choose based on smallest feature importance?
            else:
                # 3. If all are biologically required, just pick the lowest weight
                edge_to_remove = min(candidates, key=lambda x: x["weight"])

        if edge_to_remove["is_required"] and not edge_to_remove["is_added_required"]:
            print("🚨 WARNING: Removing a biologically required edge to break cycle!")
        u, v = edge_to_remove["edge"]
        w = G[u][v].get("weight", None)
        print(f"Removing edge {(u, v, w)} to break cycle {cycle_nodes}")
        G.remove_edge(u, v)
        removals += 1

        print(f"Final DAG edges: {G.number_of_edges()}")
        print(f"   (removed {removals} edges to break cycles)")

        if removals >= max_removals:
            raise RuntimeError(f"Exceeded max removals ({max_removals}) while breaking cycles — aborting to avoid infinite loop.")

        # continue until no cycle
        if nx.is_directed_acyclic_graph(G):
            print("No more cycles detected in graph")
            break

    # final check
    if not nx.is_directed_acyclic_graph(G):
        raise RuntimeError("Failed to produce DAG after cycle-breaking loop.")

    # Sanitize node names before the graph is converted to pydot.
    # NOTE(review): two names that sanitize to the same string would be merged
    # by relabel_nodes -- confirm the input names stay distinct.
    mapping = {node: sanitize_node_name(node) for node in G.nodes()}
    G = nx.relabel_nodes(G, mapping)

    return G

def sanitize_node_name(node: str) -> str:
    """Return ``node`` reduced to the characters ``A-Z``, ``a-z``, ``0-9`` and ``_``.

    Semicolons are removed and surrounding whitespace is stripped; every other
    character outside the allowed set is replaced by an underscore.
    """
    trimmed = node.replace(";", "").strip()
    return re.sub(r'[^A-Za-z0-9_]', '_', trimmed)

def plot_consensus_graph(G, model, threshold=0.5, outdir="graph_plots", stage_level="initial", alpha=0.01, n_bootstrap=5, dpi=250):
    """
    Render a consensus graph to PDF through pydot (Graphviz).

    Nodes are filled according to the first of "pre", "acute", "post" found in
    their name (gray otherwise). Edges flagged ``forced`` are drawn bold red;
    all others are black with a pen width that grows with their ``weight``.
    The file is written to ``outdir`` under a name built from the model,
    alpha (PC/FCI only), bootstrap count and threshold, with a ``_2`` suffix
    for any ``stage_level`` other than ``"initial"``. ``dpi`` is accepted but
    not used.
    """
    pydot_graph = to_pydot(G)

    # Node styling: colour by phase keyword contained in the node name
    phase_fill = (("pre", "skyblue"), ("acute", "lightgreen"), ("post", "salmon"))
    for node in pydot_graph.get_nodes():
        label = node.get_name().strip('"')
        node.set_label(label)  # keeps readable name
        fill = next((colour for tag, colour in phase_fill if tag in label), "gray")
        node.set_fillcolor(fill)
        node.set_style("filled")
        node.set_fontsize(10)

    # Edge styling: attributes are looked up on G via the sanitized endpoint names
    for edge in pydot_graph.get_edges():
        src = sanitize_node_name(edge.get_source().strip('"'))
        tgt = sanitize_node_name(edge.get_destination().strip('"'))
        forced = G[src][tgt].get("forced", False)
        weight = G[src][tgt].get("weight", 0.5)

        if not forced:
            # Pen width scales with bootstrap stability
            edge.set_color("black")
            edge.set_penwidth(1 + 2*weight)
            edge.set_style("solid")
        else:
            # Prior-knowledge edges stand out in bold red
            edge.set_color("red")
            edge.set_penwidth(2)
            edge.set_style("bold")

    # Only the constraint-based models report alpha in the title and file name
    uses_alpha = model in ["pc", "fci"]

    # Title shown as the graph label, centred at the top
    title = (
        f"{model.upper()} | alpha={alpha} | bootstraps={n_bootstrap} | thr={threshold}"
        if uses_alpha
        else f"{model.upper()} | bootstraps={n_bootstrap} | thr={threshold}"
    )
    pydot_graph.set_label(title)
    pydot_graph.set_labelloc("t")
    pydot_graph.set_labeljust("c")

    # File naming convention
    if uses_alpha:
        fname_base = f"{model}_a{alpha}_b{n_bootstrap}_thr{threshold}"
    else:
        fname_base = f"{model}_b{n_bootstrap}_thr{threshold}"
    if stage_level != "initial":
        fname_base += "_2"

    os.makedirs(outdir, exist_ok=True)

    # Export
    pdf_path = os.path.join(outdir, fname_base + ".pdf")
    pydot_graph.write_pdf(pdf_path)

    print(f"Saved Graphviz consensus graph to {pdf_path}")

In [10]:
# === Node Evaluation ===
def evaluate_nodes(scm, data, nodes, expand, causal_mech=True, mech_base=True, overall_kl=False, invert_assumpt=False, causal_struct=False):
    """Evaluate a causal model restricted to a chosen set of nodes.

    ``nodes`` may be a single name or a list of names. With ``expand`` set,
    all graph ancestors of those nodes are included as well. A sub-model is
    built on the induced subgraph, reusing the mechanisms of ``scm``, and is
    passed to ``gcm.evaluate_causal_model`` together with the matching columns
    of ``data``. The boolean flags select which evaluations are run. Returns
    the evaluation result object.
    """
    if isinstance(nodes, str):
        nodes = [nodes]

    if not expand:
        expanded_nodes = nodes
    else:
        # Ancestors are required for non-root mechanisms
        upstream = set()
        for target in nodes:
            upstream.update(nx.ancestors(scm.graph, target))
        expanded_nodes = list(set(nodes) | upstream)

    # Sub-model on the induced subgraph, sharing the original mechanisms
    print("Setting causal mechanisms")
    induced_graph = scm.graph.subgraph(expanded_nodes).copy()
    scm_sub = gcm.InvertibleStructuralCausalModel(induced_graph)
    for name in expanded_nodes:
        scm_sub.set_causal_mechanism(name, scm.causal_mechanism(name))

    return gcm.evaluate_causal_model(
        scm_sub,
        data[expanded_nodes],
        evaluate_causal_mechanisms=causal_mech,  # calculates normalized continuous ranked probability score (all nodes), MSE, normalized MSE, and R2 (continuous), and F1 (categorical)
        compare_mechanism_baselines=mech_base,  # compares the causal mechanisms with baseline models to see if there are model choices that perform significantly better
        evaluate_invertibility_assumptions=invert_assumpt,  # tests statistical independence between the reconstructed noise and the used input samples
        evaluate_overall_kl_divergence=overall_kl,  # tests KL divergence between the generated and the observed data
        evaluate_causal_structure=causal_struct  # evaluates to find substantial evidence to refute the causal graph based on the provided data
    )

def _unwrap_inner_model(wrapper, attr_names):
    """Return ``(model_name, model_params)`` for the estimator held by ``wrapper``.

    The first attribute in ``attr_names`` that exists on ``wrapper`` and is not
    None is treated as the wrapped estimator. If none is found, the wrapper's
    own class name is reported with empty parameters.
    """
    inner_model = None
    for attr in attr_names:
        candidate = getattr(wrapper, attr, None)
        if candidate is not None:
            inner_model = candidate
            break

    if inner_model is None:
        return type(wrapper).__name__, {}

    if hasattr(inner_model, "get_params"):
        return type(inner_model).__name__, inner_model.get_params()

    # No get_params(): fall back to the estimator's public instance attributes.
    public_attrs = {k: v for k, v in vars(inner_model).items() if not k.startswith("_")}
    return type(inner_model).__name__, public_attrs


def describe_mechanism(mech, verbose=False):
    """Return a human-readable description of a causal mechanism.

    Args:
        mech: the mechanism object (or None).
        verbose: when true, print a one-line summary of what was found.

    Returns:
        ``(mech_type, model_name, model_params)`` where ``mech_type`` is the
        mechanism's class name, ``model_name`` is the class name of the wrapped
        model or distribution, and ``model_params`` holds its parameters.
        ``("None", "None", {})`` is returned for ``mech is None``.
    """
    from collections.abc import Mapping

    if mech is None:
        return ("None", "None", {})

    mech_type = type(mech).__name__
    model_name, model_params = "None", {}

    # --- Case 1: Regression-based mechanisms ---
    if hasattr(mech, "prediction_model") and mech.prediction_model is not None:
        model_name, model_params = _unwrap_inner_model(
            mech.prediction_model,
            ("model", "_model", "_sklearn_mdl", "sklearn_model", "estimator", "clf", "regressor"),
        )

    # --- Case 2: Classification-based mechanisms ---
    elif hasattr(mech, "classifier_model") and mech.classifier_model is not None:
        model_name, model_params = _unwrap_inner_model(
            mech.classifier_model,
            ("model", "_model", "_sklearn_mdl", "sklearn_model", "estimator", "clf"),
        )

    # --- Case 3: Custom distribution with .parameters (attribute or method) ---
    elif hasattr(mech, "parameters"):
        param_attr = mech.parameters
        if callable(param_attr):
            try:
                model_params = param_attr()
            except Exception:
                model_params = {}
        else:
            model_params = param_attr
        model_name = type(mech).__name__

    # --- Case 4: Distribution object stored under _parameters["dist"] ---
    # Deliberately a separate `if`: it overrides whatever an earlier case found.
    # The Mapping guard avoids a TypeError when _parameters is None or not a container.
    stored_params = getattr(mech, "_parameters", None)
    if isinstance(stored_params, Mapping) and "dist" in stored_params:
        dist = stored_params["dist"]
        model_name = type(dist).__name__  # e.g., GaussianDistribution
        if hasattr(dist, "__dict__"):
            model_params = {k: v for k, v in vars(dist).items() if not k.startswith("_")}

    # --- Truncate large parameter sets for readability ---
    # Only mappings can be truncated by key; other values are passed through untouched.
    params_is_mapping = isinstance(model_params, Mapping)
    if params_is_mapping and len(model_params) > 20:
        model_params = {k: model_params[k] for k in list(model_params)[:15]}

    if verbose:
        param_keys = list(model_params.keys()) if params_is_mapping else []
        print(f"[describe_mechanism] {mech_type} -> {model_name}, params keys={param_keys}")

    return mech_type, model_name, model_params

def summarize_and_flag_nodes(model, eval_result, scm, kl_thr, crps_thr, f1_thr, nodes=None, causal_mech=True, invert_assumpt=False):
    """
    Summarize an SCM evaluation into a node-level table.

    Each row contains:
      - Mechanism type, wrapped model/distribution name and its parameters
      - Number of parents
      - MSE / NMSE / R² / CRPS / F1 / KL where available
      - Invertibility test p-value and whether the assumption was rejected
      - ``passed_metric_threshold``: the first available of KL (<= kl_thr),
        CRPS (<= crps_thr), F1 (>= f1_thr); None when no metric is available

    Args:
        model: label stored in the ``model`` column.
        eval_result: evaluation result; ``mechanism_performances`` and
            ``pnl_assumptions`` are read from it when present.
        scm: causal model exposing ``graph`` and ``causal_mechanism(node)``.
        kl_thr, crps_thr, f1_thr: pass/fail thresholds.
        nodes: node name, iterable of node names, or None for all graph nodes.
        causal_mech: include mechanism performance metrics.
        invert_assumpt: include invertibility test results.

    Returns:
        pandas DataFrame with one row per node.
    """
    if isinstance(nodes, str):
        nodes = [nodes]

    if nodes is None:
        nodes = scm.graph.nodes

    # `or {}` covers both a missing attribute and one that is present but None.
    performances = (getattr(eval_result, "mechanism_performances", None) or {}) if causal_mech else {}
    pnl_results = (getattr(eval_result, "pnl_assumptions", None) or {}) if invert_assumpt else {}

    rows = []
    for node in nodes:
        # Mechanism type
        mech = scm.causal_mechanism(node)
        mech_type, pred_model, model_params = describe_mechanism(mech)

        # Parents count
        n_parents = len(list(scm.graph.predecessors(node)))

        # Reset every metric per node (including kl) so a node without a
        # performance entry never inherits values from the previous node.
        mse = nmse = r2 = crps = f1 = kl = None
        passed_metric_threshold = None

        if causal_mech:
            perf = performances.get(node)
            if perf:
                mse = getattr(perf, "mse", None)
                nmse = getattr(perf, "nmse", None)
                r2 = getattr(perf, "r2", None)
                crps = getattr(perf, "crps", None)
                f1 = getattr(perf, "f1", None)
                kl = getattr(perf, "kl_divergence", None)

            # Threshold check on the first metric that is available.
            if kl is not None:
                passed_metric_threshold = (kl <= kl_thr)
            elif crps is not None:
                passed_metric_threshold = (crps <= crps_thr)
            elif f1 is not None:
                passed_metric_threshold = (f1 >= f1_thr)

        # Invertibility check
        p_value = None
        rejected = None
        if invert_assumpt:
            pnl = pnl_results.get(node, (None, False, None))
            if isinstance(pnl, tuple):
                p_value, rejected, _ = pnl

        rows.append({
            "model": model,
            "node": node,
            "n_parents": n_parents,
            "mechanism": mech_type,
            "distribution/model": pred_model,
            "model_params": model_params,
            "mse": mse,
            "nmse": nmse,
            "r2": r2,
            "crps": crps,
            "f1": f1,
            "kl": kl,
            "p_value": p_value,
            "invert_assumption_rejected": rejected,
            "passed_metric_threshold": passed_metric_threshold
        })

    df_summary = pd.DataFrame(rows)

    return df_summary

In [11]:
# === Node Refinement ===
def try_mechanism(mech_name, mech, scm, model, data, nodes, summary_fn, metric, kl_thr, crps_thr, f1_thr):
    """Fit and score one candidate mechanism on a private copy of the SCM.

    The candidate ``mech`` is assigned to every node in ``nodes`` on a deep copy
    of ``scm``, the copy is fitted and evaluated (ancestors included), and the
    mean of ``metric`` over the summarised nodes is reported.

    Returns:
        ``(mech_name, mech, metric, metric_average, summary_df)``; the last two
        entries are None when any step raises.
    """
    print("Refining nodes:", nodes)
    failure = (mech_name, mech, metric, None, None)

    try:
        started = time.time()
        trial_scm = copy.deepcopy(scm)
        for target in nodes:
            trial_scm.set_causal_mechanism(target, mech)
        gcm.fit(trial_scm, data)
        fitted_at = time.time()
        print(f"‚è±Ô∏è fit time: {fitted_at - started:.2f}s")

        # Score the requested nodes together with their ancestors
        node_eval = evaluate_nodes(trial_scm, data, nodes, expand=True)
        print(f"‚è±Ô∏è eval time: {time.time() - fitted_at:.2f}s")

        # Node-level summary and its average for the chosen metric
        fit_summary = summary_fn(
            model, node_eval, trial_scm, kl_thr, crps_thr, f1_thr, nodes,
            causal_mech=True, invert_assumpt=False,
        )
        metric_avg_nodes = fit_summary[metric].mean()
        print(f"{mech_name.upper()}: {metric} average = {metric_avg_nodes}")
    except Exception as e:
        # Best effort: a failing candidate is reported and skipped by the caller.
        print(f"‚ö†Ô∏è nodes failed with {mech_name} ({e})")
        return failure

    return mech_name, mech, metric, metric_avg_nodes, fit_summary

def refine_node_mechanisms(
    scm, model, data, init_eval, kl_thr, crps_thr, f1_thr, nodes, mechanism_candidates, metric, summary_fn, stage_label, parallel=True):
    """
    Try every candidate mechanism on a group of nodes and keep the best one.

    Each candidate is assigned to all ``nodes`` on a copy of the SCM and scored
    by the mean of ``metric`` over those nodes (see ``try_mechanism``). "kl" and
    "crps" are minimised; any other metric is maximised. If no candidate beats
    the initial average, the SCM's mechanisms are left as they were. ``scm`` is
    refitted in place before returning in either case.

    Args:
        scm: fitted causal model; modified in place when a candidate wins.
        model: label stored in the log rows.
        data: DataFrame used for fitting and evaluation.
        init_eval: node summary DataFrame providing the baseline ``metric``.
        kl_thr, crps_thr, f1_thr: thresholds forwarded to ``summary_fn``.
        nodes: nodes whose mechanism is being refined.
        mechanism_candidates: dict of candidate name -> mechanism object.
        metric: summary column to optimise.
        summary_fn: function building the node summary from an evaluation.
        stage_label: label stored in the log rows.
        parallel: evaluate candidates with joblib when true, serially otherwise.

    Returns:
        List of log dicts, one per candidate plus one extra entry on reversion.
    """
    lower_is_better = metric in ["kl", "crps"]

    # Initialize
    node_log = []
    best_mech = None
    init_metric_avg = init_eval[init_eval['node'].isin(nodes)][metric].mean()
    best_metric_avg = init_metric_avg

    # Parallel or serial evaluation
    call_args = [
        (name, mech, scm, model, data, nodes, summary_fn, metric, kl_thr, crps_thr, f1_thr)
        for name, mech in mechanism_candidates.items()
    ]
    if parallel:
        results = Parallel(n_jobs=min(6, max(1, len(mechanism_candidates))), backend="loky", verbose=5)(
            delayed(try_mechanism)(*args) for args in call_args
        )
    else:
        results = [try_mechanism(*args) for args in call_args]

    # Find best mechanism by metric direction
    for mech_name, mech, _, new_metric_avg, _ in results:

        current_mech_type, current_model_name, current_params = describe_mechanism(mech)
        if lower_is_better:
            # For kl/crps runs the candidate key is logged as the model name.
            current_model_name = mech_name

        if new_metric_avg is None:
            # try_mechanism reports None when the candidate failed to fit/evaluate;
            # a failed candidate can never be an improvement.
            improved = False
        elif lower_is_better:
            improved = bool(new_metric_avg < best_metric_avg)
        else:
            improved = bool(new_metric_avg > best_metric_avg)

        node_log.append({
            "model": model,
            "nodes": nodes,
            "metric": metric,
            "initial_metric_avg": init_metric_avg,
            "new_metric_avg": new_metric_avg,
            "stage": stage_label,
            "current_mechanism": current_mech_type,
            "current_model": current_model_name,
            "current_params": json.dumps(current_params, default=str),
            "improved": improved
        })

        if improved:
            print(f"‚úÖ {metric} nodes improved ({init_metric_avg} ‚Üí {new_metric_avg})")
            best_metric_avg, best_mech = new_metric_avg, mech

    # Final mechanism assignment
    # `best_mech is None` means no candidate beat the baseline. This also covers
    # a NaN baseline, where numeric comparisons are always False.
    if best_mech is None:
        print(f"‚ö†Ô∏è {metric} nodes: performance worsened or did not improve ‚Äî reverting.")

        initial_mech_type, initial_model_name, initial_params = [], [], []
        for node in nodes:
            mech_type, model_name, params = describe_mechanism(scm.causal_mechanism(node))
            initial_mech_type.append(mech_type)
            initial_model_name.append(model_name)
            initial_params.append(params)

        node_log.append({
            "model": model,
            "nodes": nodes,
            "metric": metric,
            "initial_metric_avg": init_metric_avg,
            "new_metric_avg": init_metric_avg,
            "stage": stage_label,
            "current_mechanism": initial_mech_type,
            "current_model": initial_model_name,
            "current_params": json.dumps(initial_params, default=str),
            "improved": False
        })
    else:
        # The trial assignment happened on a copy of the SCM, so apply it here.
        print(f"üîß {metric} nodes: final mechanism set to {type(best_mech).__name__} "
            f"({type(getattr(best_mech, 'prediction_model', getattr(best_mech, 'dist', None))).__name__})")
        for node in nodes:
            # Each node gets its own copy: a shared instance would be refitted
            # once per node and every node would end up with the last fit.
            scm.set_causal_mechanism(node, copy.deepcopy(best_mech))

    # Refit after assigning mechanisms
    gcm.fit(scm, data)

    return node_log

def generate_continuous_mechs():
    """Build the candidate mechanisms for continuous nodes.

    Returns a dict of candidate name -> ``gcm.AdditiveNoiseModel``: one linear
    baseline, then random forests (by tree count), gradient boosting (by
    learning rate) and MLPs (by hidden layer width).
    """
    def _additive(regressor):
        # Every candidate is an additive noise model around a regressor.
        return gcm.AdditiveNoiseModel(prediction_model=regressor)

    # Linear baseline
    configs = {"linear": _additive(gcm.ml.create_linear_regressor())}

    # Random forests; depth and leaf size are capped to avoid overly deep trees
    configs.update({
        f"rf_{n_trees}": _additive(SklearnRegressionModel(
            RandomForestRegressor(n_estimators=n_trees, max_depth=8, min_samples_leaf=5, n_jobs=-1, random_state=0)
        ))
        for n_trees in (50, 100, 200)
    })

    # Gradient boosting with varying learning rate
    configs.update({
        f"gbm_lr{rate}": _additive(SklearnRegressionModel(
            GradientBoostingRegressor(learning_rate=rate, n_estimators=150, max_depth=3, random_state=0)
        ))
        for rate in (0.01, 0.05, 0.1)
    })

    # Single-hidden-layer MLPs with varying width
    configs.update({
        f"mlp_{width}": _additive(SklearnRegressionModel(
            MLPRegressor(hidden_layer_sizes=(width,), max_iter=150, solver='adam', early_stopping=True,
                         n_iter_no_change=5, tol=1e-3, random_state=0)
        ))
        for width in (20, 50, 100)
    })

    return configs

def generate_discrete_mechs():
    """Build the candidate mechanisms for discrete nodes.

    Returns a dict of candidate name -> ``gcm.ClassifierFCM``: logistic
    regressions (by regularisation constant C) followed by random forest
    classifiers (by tree count).
    """
    def _classifier_fcm(estimator):
        # Every candidate is a classifier FCM around a wrapped sklearn classifier.
        return gcm.ClassifierFCM(classifier_model=SklearnClassificationModel(estimator))

    # Logistic regression with varying C
    configs = {
        f"lg_{c_value}": _classifier_fcm(LogisticRegression(max_iter=500, C=c_value))
        for c_value in [0.01, 0.1, 1.0, 10]
    }

    # Random forest classifiers with varying number of trees
    for n_trees in [50, 100, 200]:
        configs[f"rf_{n_trees}"] = _classifier_fcm(
            RandomForestClassifier(n_estimators=n_trees, random_state=0)
        )

    return configs

def evaluate_and_refine_scm(G, model, data, discrete_nodes, stage_label, kl_thr=1.0, crps_thr=0.35, f1_thr=0.6):
    """Fit an SCM on graph ``G``, evaluate every node and refine the weak ones.

    Nodes are grouped by the metric they are judged on:
      - root nodes (no parents) with KL above ``kl_thr`` try the root distributions;
      - non-root continuous nodes with CRPS above ``crps_thr`` try the continuous models;
      - non-root discrete nodes with F1 below ``f1_thr`` try the classifier models.

    Args:
        G: directed graph exposing ``predecessors(node)``.
        model: label stored in the summary and log rows.
        data: DataFrame with one column per node.
        discrete_nodes: names of the discrete nodes.
        stage_label: label stored in the ``stage`` column / log rows.
        kl_thr, crps_thr, f1_thr: thresholds that trigger refinement.

    Returns:
        ``(scm, eval_summary, node_log)``: the (possibly refined) fitted SCM, the
        node summary computed BEFORE refinement, and a DataFrame logging every
        refinement attempt (empty when nothing was flagged).
    """
    # Create SCM
    scm = gcm.InvertibleStructuralCausalModel(G)
    gcm.auto.assign_causal_mechanisms(scm, data, quality = AssignmentQuality.BETTER)
    gcm.fit(scm, data)

    # Run initial evaluation
    eval_result = evaluate_nodes(scm, data, data.columns, expand=False)
    eval_summary = summarize_and_flag_nodes(model, eval_result, scm, kl_thr, crps_thr, f1_thr)
    eval_summary['stage'] = stage_label

    # Determine which nodes need refinement, categorized by metric
    nodes = eval_summary['node'].values
    parents_count = np.array([len(list(G.predecessors(n))) for n in nodes])
    is_root_mask = parents_count == 0
    is_discrete_mask = np.isin(nodes, discrete_nodes)

    def _flagged(column, exceeds):
        # Boolean array: metric present and on the wrong side of its threshold.
        values = eval_summary[column]
        return (values.notna() & exceeds(values)).to_numpy(dtype=bool)

    kl_mask = is_root_mask & _flagged('kl', lambda v: v > kl_thr)
    # The CRPS/F1 candidates are conditional models that need parents, so root
    # nodes are excluded here; roots are refined through the KL branch only.
    crps_mask = (~is_root_mask) & (~is_discrete_mask) & _flagged('crps', lambda v: v > crps_thr)
    f1_mask = (~is_root_mask) & is_discrete_mask & _flagged('f1', lambda v: v < f1_thr)

    kl_nodes = nodes[kl_mask].tolist()
    crps_nodes = nodes[crps_mask].tolist()
    f1_nodes = nodes[f1_mask].tolist()

    node_log = pd.DataFrame()
    if not (len(kl_nodes) or len(crps_nodes) or len(f1_nodes)):
        return scm, eval_summary, node_log

    print(f"Nodes flagged for KL refinement ({len(kl_nodes)}): {kl_nodes}")
    print(f"Nodes flagged for CRPS refinement ({len(crps_nodes)}): {crps_nodes}")
    print(f"Nodes flagged for F1 refinement ({len(f1_nodes)}): {f1_nodes}")

    # Populate mechanism candidates
    root_mechs = {
        "Bernoulli": gcm.ScipyDistribution(dist=BernoulliDistribution()),
        "Beta": gcm.ScipyDistribution(dist=BetaDistribution()),
        "GMM": gcm.ScipyDistribution(dist=GaussianMixtureDistribution()),
        "Gaussian": gcm.ScipyDistribution(dist=GaussianDistribution())
    }
    continuous_mechs = generate_continuous_mechs()
    discrete_mechs = generate_discrete_mechs()

    # Initialize logs in case no nodes trigger refinement
    node_log_kl, node_log_crps, node_log_f1 = [], [], []

    # Refine nodes in each list
    print(f"Refining nodes via KL")
    if len(kl_nodes) > 0:
        node_log_kl = refine_node_mechanisms(
            scm, model, data, eval_summary, kl_thr, crps_thr, f1_thr, kl_nodes, root_mechs, metric="kl",
            summary_fn=summarize_and_flag_nodes, stage_label=f"{stage_label}_root_refine", parallel=True
        )

    print(f"Refining nodes via CRPS")
    if len(crps_nodes) > 0:
        node_log_crps = refine_node_mechanisms(
            scm, model, data, eval_summary, kl_thr, crps_thr, f1_thr, crps_nodes, continuous_mechs, metric="crps",
            summary_fn=summarize_and_flag_nodes, stage_label=f"{stage_label}_cont_refine", parallel=True
        )

    print(f"Refining nodes via F1")
    if len(f1_nodes) > 0:
        node_log_f1 = refine_node_mechanisms(
            scm, model, data, eval_summary, kl_thr, crps_thr, f1_thr, f1_nodes, discrete_mechs, metric="f1",
            summary_fn=summarize_and_flag_nodes, stage_label=f"{stage_label}_disc_refine", parallel=True
        )
    node_log = pd.concat([pd.DataFrame(node_log_kl), pd.DataFrame(node_log_crps), pd.DataFrame(node_log_f1)])

    return scm, eval_summary, node_log

In [12]:
# === Benchmark Evaluation
def run_benchmark_evaluation(scm, model, data, discrete_nodes, stage_level, forbidden_edges, post_fit_summary, n_bootstrap=20, importance_threshold=0.15):
    """
    Evaluate node-level predictive performance using a benchmark model
    trained with all possible (non-forbidden) parents.

    For every node a random forest is fitted on ``n_bootstrap`` resamples using
    every other node that is not forbidden as a parent; feature importances are
    averaged over the resamples. Candidates whose average importance exceeds
    ``importance_threshold`` and that are not yet parents of the node are
    proposed as required edges.

    Args:
        scm: fitted causal model exposing ``graph`` and ``causal_mechanism``.
        model: label stored in the result rows.
        data: DataFrame with one column per node.
        discrete_nodes: names of the discrete nodes (classifier + F1 benchmark).
        stage_level: label stored in the ``stage`` field.
        forbidden_edges: collection of ``(parent, child)`` tuples to exclude.
        post_fit_summary: node summary DataFrame providing the SCM score.
        n_bootstrap: number of bootstrap resamples.
        importance_threshold: minimum average feature importance for a
            candidate to count as a top parent (default 0.15).

    Returns:
        benchmark_results: list of dicts with per-node benchmark metrics
        added_required: list of (p, node) edges to add
        added_forbidden: list of (p, node) edges to forbid (currently never populated)
    """
    benchmark_results = []
    added_required = []
    added_forbidden = []

    def _weighted_f1(y_true, y_pred):
        return f1_score(y_true, y_pred, average="weighted")

    for node in scm.graph.nodes:
        print(f"benchmark for {node}")
        t0 = time.time()
        parents = list(scm.graph.predecessors(node))

        # Build simulated full-parent graph for this node
        candidate_parents = [
            p for p in scm.graph.nodes
            if p != node and (p, node) not in forbidden_edges
        ]
        if not candidate_parents:
            continue

        X = data[candidate_parents]
        y = data[node]
        importance_accum = {p: 0.0 for p in X.columns}

        # Select model + metric. Root nodes are checked first, so a root node
        # always gets the regressor branch and is compared against the SCM's KL.
        if not parents:
            bench_model = RandomForestRegressor(n_estimators=200, random_state=0)
            metric_fn, scm_metric, metric = r2_score, "kl", "r2"
        elif node in discrete_nodes:
            bench_model = RandomForestClassifier(n_estimators=200, random_state=0)
            metric_fn, scm_metric, metric = _weighted_f1, "f1", "f1"
        else:
            bench_model = RandomForestRegressor(n_estimators=200, random_state=0)
            metric_fn, scm_metric, metric = r2_score, "crps", "r2"

        # Bootstrap importance accumulation
        for _ in range(n_bootstrap):
            X_bs, y_bs = resample(X, y, random_state=None)
            bench_model.fit(X_bs, y_bs)
            for p, imp in zip(X.columns, bench_model.feature_importances_):
                importance_accum[p] += imp
        print(f"‚è±Ô∏è fit time: {(time.time()-t0):.2f}s")
        t1 = time.time()

        # Compute average importance
        feat_importance_dict = {p: imp / n_bootstrap for p, imp in importance_accum.items()}
        # NOTE(review): the score below uses the model from the LAST resample and
        # predicts on the full data it was resampled from, so it is partly
        # in-sample. Treat it as an optimistic reference, not a held-out score.
        y_pred = bench_model.predict(X)
        metric_val = metric_fn(y, y_pred)
        print(f"‚è±Ô∏è evaluation time: {(time.time()-t1):.2f}s")

        # Compare to SCM performance
        scm_score = post_fit_summary.loc[
            post_fit_summary["node"] == node, scm_metric
        ].values[0] if node in post_fit_summary["node"].values else None

        # Get mechanism of SCM
        mech = scm.causal_mechanism(node)
        mech_type, model_name, model_params = describe_mechanism(mech)

        # Candidates above the importance threshold, most important first
        ranked = sorted(feat_importance_dict.items(), key=lambda item: item[1], reverse=True)
        top_parents_dict = {p: imp for p, imp in ranked if imp > importance_threshold}

        # Propose an edge for every important candidate that is not yet a parent
        current_parents = set(parents)
        added_required.extend((p, node) for p in top_parents_dict if p not in current_parents)

        # Save results
        benchmark_results.append({
            "model": model,
            "node": node,
            "stage": stage_level,
            "n_scm_parents": len(parents),
            "n_all_poss_parents": len(candidate_parents),
            "scm_metric": scm_metric,
            "scm_mechanism": mech_type,
            "scm_model": model_name,
            "scm_model_params": json.dumps(model_params, default=str),
            "scm_score": scm_score,
            "metric": metric,
            "benchmark_score": metric_val,
            "benchmark_top_parents": top_parents_dict
        })

    return benchmark_results, added_required, added_forbidden

# Pipeline

In [13]:
# Pipeline configuration: discovery settings and refinement thresholds used by the run cell below.
models = ["pc", "fci", "lingam", "notears_linear", "notears_nonlinear"]  # discovery methods run in turn
alpha = 0.1  # default 0.05, the higher the value the more edges
bootstrap = 100  # number of bootstrap discovery runs per model
freq_thresholds = [0.2, 0.35, 0.5]  # edge-frequency cut-offs; one consensus graph is built per value
# Node-quality thresholds: a node is flagged when KL or CRPS is above, or F1 is below, its threshold.
# These match the defaults declared on evaluate_and_refine_scm.
KL_THRESHOLD = 1.0
CRPS_THRESHOLD = 0.35
F1_THRESHOLD = 0.6

# Fix the DoWhy-GCM random seed so repeated runs are comparable.
gcm.util.general.set_random_seed(0)

In [None]:
# End-to-end run: for every discovery model, bootstrap the causal discovery once,
# then for every edge-frequency threshold build a consensus graph, fit/refine an
# SCM on it, benchmark each node and collect the summaries for export.
start = time.time()
print("‚è±Ô∏è started at:", datetime.datetime.fromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S"))


def _find_weak_nodes(summary_df):
    """Return the nodes whose CRPS or KL is above, or whose F1 is below, its threshold.

    Lower CRPS/KL is better and higher F1 is better, matching the flags used in
    evaluate_and_refine_scm. Missing metrics compare as False and are ignored.
    """
    weak_mask = (
        (summary_df["crps"] > CRPS_THRESHOLD)
        | (summary_df["kl"] > KL_THRESHOLD)
        | (summary_df["f1"] < F1_THRESHOLD)
    )
    return summary_df.loc[weak_mask, "node"].tolist()


# Initialize variables
discrete_nodes = ['is_female', 'long_covid', 'me_cfs', 'fibromyalgia', 'dysautonomia', 'period_at_covid_start', 'pre_crash', 
                 'acute_crash', 'post_crash', 'pre_noncovid_infection', 'acute_noncovid_infection', 'post_noncovid_infection']
consensus_graphs = {}
fit_summary, benchmark_summary, node_log_summary, edge_summary = [], [], [], []
model_datasets = {}

for model in models:

    model_start = time.time()
    print("‚è±Ô∏è model loop started at:", datetime.datetime.fromtimestamp(model_start).strftime("%Y-%m-%d %H:%M:%S"))
    print(f"\n=== Running {model.upper()} | Œ±={alpha} | boot={bootstrap} ===")

    # Per-model accumulators (exported once the threshold loop has finished)
    fit_summary_model = []
    benchmark_summary_model = []
    node_log_summary_model = []

    # --- Run causal discovery sequentially ---
    features = final_df_std.columns.tolist()
    edge_results, stability = run_bootstrap_discovery(model, final_df_std, alpha, bootstrap, stage_level="initial", required=None, forbidden=None)
    edge_results_df = pd.DataFrame(edge_results)
    causal_disc_time = time.time()
    print(f"‚è±Ô∏è causal discovery time: {(causal_disc_time-model_start):.2f}s")

    # Save results
    initial_edge_summary, initial_conf_summary = summarize_edge_stability(edge_results_df)
    initial_edge_summary['stage'] = 'initial'
    initial_conf_summary['stage'] = 'initial'
    edge_summary.append(initial_edge_summary)

    # --- Iterate through results to analyze and manipulate models/graphs ---
    for freq_threshold in freq_thresholds:

        threshold_time = time.time()
        print("‚è±Ô∏è threshold loop started at:", datetime.datetime.fromtimestamp(threshold_time).strftime("%Y-%m-%d %H:%M:%S"))

        # --- Create graphical representations of discovery results ---
        # Calculate stability
        subset = edge_results_df[(edge_results_df["model"] == model) & (edge_results_df["alpha"] == alpha) & (edge_results_df["n_bootstrap"] == bootstrap)]
        stability_final = subset.iloc[0]["stability"]

        # Populate background knowledge
        forbidden, required = build_background_knowledge(features)
        forbidden_edges = set(forbidden)  # convert to set for fast lookup
        required_edges = set(required)

        # Build graph with stability scores
        G = build_consensus_graph(model, stability_final, required_edges, freq_threshold=freq_threshold)

        # Add isolated nodes before plotting
        all_nodes = set(final_df_std.columns)
        graph_nodes = set(G.nodes)
        missing_nodes = all_nodes - graph_nodes
        if missing_nodes:
            print(f"Adding {len(missing_nodes)} isolated nodes to graph: {missing_nodes}")
            G.add_nodes_from(missing_nodes)

        # Plot and store graph
        key = (model, alpha, bootstrap, freq_threshold)
        consensus_graphs[key] = G
        graph_dir = Path("graph_objects")
        graph_dir.mkdir(exist_ok=True)
        key_str = f"{model}_b{bootstrap}_{freq_threshold}".replace(".", "_")
        with open(graph_dir / f"{key_str}_initial.pkl", "wb") as f:
            pickle.dump(G, f)
        plot_consensus_graph(G, model=model, threshold=freq_threshold, outdir="graph_plots", stage_level="initial", alpha=alpha, n_bootstrap=bootstrap)

        graph_time = time.time()
        print(f"‚è±Ô∏è graph building time: {(graph_time-threshold_time):.2f}s")

        # --- Graph evaluation and refinement ---
        print(f"\n===== Evaluating Graph ({key[0]}): alpha={key[1]}, bootstraps={key[2]}, threshold={key[3]} =====")
        scm_post_nr, init_fit_summary, node_log = evaluate_and_refine_scm(
            G, model, final_df_std, discrete_nodes, stage_label='initial',
            kl_thr=KL_THRESHOLD, crps_thr=CRPS_THRESHOLD, f1_thr=F1_THRESHOLD)

        # Weak = CRPS/KL above or F1 below its threshold (before refinement)
        weak_nodes = _find_weak_nodes(init_fit_summary)

        print(f"‚ö†Ô∏è Initial weak nodes {len(weak_nodes)}: {weak_nodes}")

        # Post node refinement evaluation
        print("\nüìä Running post node refinement evaluation...")
        post_nr_eval = evaluate_nodes(scm_post_nr, final_df_std, final_df_std.columns, expand=False)
        post_nr_fit_summary = summarize_and_flag_nodes(model, post_nr_eval, scm_post_nr, kl_thr=KL_THRESHOLD, crps_thr=CRPS_THRESHOLD, f1_thr=F1_THRESHOLD)
        post_nr_fit_summary['stage'] = 'post node refinement'

        node_eval_time = time.time()
        print(f"‚è±Ô∏è node evaluation time: {(node_eval_time-graph_time):.2f}s") 

        # Save the refined graph structure and SCM (structure + mechanisms) separately
        G_refined = scm_post_nr.graph
        with open(graph_dir / f"{key_str}_refined.pkl", "wb") as f:
            pickle.dump(G_refined, f)
        print(f"‚úì Saved refined graph structure to {key_str}_refined.pkl")
        with open(graph_dir / f"{key_str}_refined_scm.pkl", "wb") as f:
            pickle.dump(scm_post_nr, f)
        print(f"‚úì Saved refined SCM to {key_str}_refined_scm.pkl")

        # Same weak-node rule, applied after refinement
        weak_nodes = _find_weak_nodes(post_nr_fit_summary)
        print(f"‚ö†Ô∏è Remaining weak nodes {len(weak_nodes)}: {weak_nodes}")

        # --- Benchmark evaluation for nodes ---
        benchmark_results, added_required, added_forbidden = run_benchmark_evaluation(
            scm_post_nr, model, final_df_std, discrete_nodes, stage_level='initial',
            forbidden_edges=forbidden_edges, post_fit_summary=post_nr_fit_summary, n_bootstrap=20)

        bench_eval_time = time.time()
        print(f"‚è±Ô∏è benchmark evaluation time: {(bench_eval_time-node_eval_time):.2f}s")

        # --- Save results ---
        thr_fit_summary = pd.concat([init_fit_summary, post_nr_fit_summary])
        thr_fit_summary['threshold'] = freq_threshold
        fit_summary_model.append(thr_fit_summary)
        fit_summary.append(thr_fit_summary)

        thr_benchmark_summary = pd.DataFrame(benchmark_results)
        thr_benchmark_summary['threshold'] = freq_threshold
        benchmark_summary_model.append(thr_benchmark_summary)
        benchmark_summary.append(thr_benchmark_summary)

        thr_node_log_summary = pd.DataFrame(node_log)
        thr_node_log_summary['threshold'] = freq_threshold
        node_log_summary_model.append(thr_node_log_summary)
        node_log_summary.append(thr_node_log_summary)

    # Export per-model summaries
    os.makedirs("summaries", exist_ok=True)
    initial_edge_summary.to_csv(f"summaries/causal_disc_edge_summary_{model}_b{bootstrap}.csv", index=False)
    if model=="fci":
        initial_conf_summary.to_csv(f"summaries/causal_disc_conf_summary_{model}_b{bootstrap}.csv", index=False)
    pd.concat(fit_summary_model).to_csv(f"summaries/scm_fit_summary_{model}_b{bootstrap}.csv", index=False)
    pd.concat(benchmark_summary_model).to_csv(f"summaries/benchmark_comparison_{model}_b{bootstrap}.csv", index=False)
    pd.concat(node_log_summary_model).to_csv(f"summaries/scm_node_assignment_log_{model}_b{bootstrap}.csv", index=False)

# Export combined summaries across all models
pd.concat(edge_summary).to_csv(f"summaries/causal_disc_edge_summary_b{bootstrap}.csv", index=False)   
pd.concat(fit_summary).to_csv(f"summaries/scm_fit_summary_b{bootstrap}.csv", index=False)
pd.concat(benchmark_summary).to_csv(f"summaries/benchmark_comparison_b{bootstrap}.csv", index=False)
pd.concat(node_log_summary).to_csv(f"summaries/scm_node_assignment_log_b{bootstrap}.csv", index=False)
print(f"‚úÖ Results saved")

‚è±Ô∏è started at: 2025-11-28 11:44:49
‚è±Ô∏è model loop started at: 2025-11-28 11:44:49

=== Running PC | Œ±=0.1 | boot=100 ===


Starting single causal discovery run
Worker memory (MB): 392.085504
# Forbidden edges: 717
# Required edges: 3
Starting single causal discovery run
Worker memory (MB): 391.90528
# Forbidden edges: 717
# Required edges: 3
Starting single causal discovery run
Worker memory (MB): 393.13408
# Forbidden edges: 717
# Required edges: 3
Starting single causal discovery run
Worker memory (MB): 393.101312
# Forbidden edges: 717
# Required edges: 3
Starting single causal discovery run
Worker memory (MB): 390.283264
# Forbidden edges: 717
# Required edges: 3
Starting single causal discovery run
Worker memory (MB): 392.101888
# Forbidden edges: 717
# Required edges: 3


Depth=1, working on node 10:  26%|‚ñà‚ñà‚ñå       | 11/43 [00:17<01:31,  2.87s/it] 