In [1]:
from fast_causal_inference import FastCausalInference
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.preprocessing import LabelEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
base_dir = '../../'
result_dir = base_dir + 'result/R/'
data_path = base_dir + 'dataset/' + 'Real_World_IBS.xlsx'
df = pd.read_excel(data_path)
df = df.drop(columns=['HAD_Anxiety', 'Patient', 'Batch_metabolomics', 'BH', 'Sex', 'Age', 'BMI','Race','Education','HAD_Depression','STAI_Tanxiety', 'Diet_Category','Diet_Pattern'])
label_encoder = LabelEncoder()
df['Group'] = label_encoder.fit_transform(df['Group'])
df_encoded = df

X = df_encoded.drop(columns=['Group'])
y = df_encoded['Group']

X = X[["xylose", "xanthosine", "uracil", "ribulose/xylulose", "valylglutamine",
           "tryptophylglycine", "succinate", "valine betaine", "ursodeoxycholate sulfate (1)",
           "tricarballylate", "succinimide", "thymine", "syringic acid", "serotonin", "ribitol"]]

param_dist = {
        'n_estimators': [100, 200, 300],
        'max_depth': [10, 20, 30, None],
        'min_samples_split': [2, 5, 7],
        'min_samples_leaf': [1, 2, 4]
    }
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

rf = RandomForestClassifier(random_state=42)
random_search = RandomizedSearchCV(
estimator=rf, param_distributions=param_dist, n_iter=50, cv=3, n_jobs=-1, verbose=2, random_state=42)
random_search.fit(X_train, y_train)
model = random_search.best_estimator_
best_params = random_search.best_params_

Fitting 3 folds for each of 50 candidates, totalling 150 fits


In [60]:
base_dir = '../../'
data_path = base_dir + 'dataset/' + 'Real_World_IBS_Predicted_Probabilities.xlsx'
df_prob = pd.read_excel(data_path)
X_train['Prob_Class_1'] = df_prob['Prob_Class_1']

In [57]:
import pandas as pd
import networkx as nx
import numpy as np
import json
from math import factorial
from sklearn.linear_model import LinearRegression
from collections import defaultdict

class FastCausalInference:
    def __init__(self, data, model, target_variable):
        self.data = data  
        self.model = model  
        self.gamma = None  
        self.target_variable = target_variable 
        self.ida_graph = None
        self.regression_models = {}
        self.feature_depths = {}
        self.path_cache = {}
        self.causal_paths = {}  
        
    def _compute_causal_paths(self):
        """Compute and store all causal paths to target for each feature."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                # Store the actual paths instead of just the features
                paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                self.causal_paths[feature] = paths
            except nx.NetworkXNoPath:
                self.causal_paths[feature] = []

    def load_causal_strengths(self, json_file_path):
        with open(json_file_path, 'r') as f:
            causal_effects_list = json.load(f)
        
        G = nx.DiGraph()
        nodes = list(self.data.columns)
        G.add_nodes_from(nodes)

        for item in causal_effects_list:
            pair = item['Pair']
            mean_causal_effect = item['Mean_Causal_Effect']
            if mean_causal_effect is None:
                continue  
            source, target = pair.split('->')
            source = source.strip()
            target = target.strip()
            G.add_edge(source, target, weight=mean_causal_effect)
        self.ida_graph = G.copy()
        self._compute_feature_depths()
        self._compute_causal_paths()
        features = self.data.columns.tolist()
        beta_dict = {}

        for feature in features:
            if feature == self.target_variable:
                continue
            try:
                paths = list(nx.all_simple_paths(G, source=feature, target=self.target_variable))
            except nx.NetworkXNoPath:
                continue  
            total_effect = 0
            for path in paths:
                effect = 1
                for i in range(len(path)-1):
                    edge_weight = G[path[i]][path[i+1]]['weight']
                    effect *= edge_weight
                total_effect += effect
            if total_effect != 0:
                beta_dict[feature] = total_effect

        total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
        if total_causal_effect == 0:
            self.gamma = {k: 0.0 for k in features}
        else:
            self.gamma = {k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features}
        return self.gamma
    
    def _compute_feature_depths(self):
        """Compute minimum depth of each feature to target in causal graph."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                all_paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                min_depth = float('inf')
                for path in all_paths:
                    depth = len(path) - 1  
                    min_depth = min(min_depth, depth)
                if min_depth != float('inf'):
                    self.feature_depths[feature] = min_depth
            except nx.NetworkXNoPath:
                continue

    def get_topological_order(self, S):
        """Returns the topological order of variables after intervening on subset S."""
        G_intervened = self.ida_graph.copy()
        for feature in S:
            G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
        missing_nodes = set(self.data.columns) - set(G_intervened.nodes)
        G_intervened.add_nodes_from(missing_nodes)

        try:
            order = list(nx.topological_sort(G_intervened))
        except nx.NetworkXUnfeasible:
            raise ValueError("The causal graph contains cycles.")
        
        return order
    
    def get_parents(self, feature):
        """Returns the list of parent features for a given feature in the causal graph."""
        return list(self.ida_graph.predecessors(feature))

    def sample_marginal(self, feature):
        """Sample a value from the marginal distribution of the specified feature."""
        return self.data[feature].sample(1).iloc[0]

    def sample_conditional(self, feature, parent_values):
        """Sample a value for a feature conditioned on its parent features."""
        effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
        if not effective_parents:
            return self.sample_marginal(feature)
        model_key = (feature, tuple(sorted(effective_parents))) 
        if model_key not in self.regression_models:
            X = self.data[effective_parents].values
            y = self.data[feature].values
            reg = LinearRegression()
            reg.fit(X, y)
            residuals = y - reg.predict(X)
            std = residuals.std()
            self.regression_models[model_key] = (reg, std)
        reg, std = self.regression_models[model_key]
        parent_values_array = np.array([parent_values[parent] for parent in effective_parents]).reshape(1, -1)
        mean = reg.predict(parent_values_array)[0]
        sampled_value = np.random.normal(mean, std)
        return sampled_value

    def compute_v_do(self, S, x_S, is_classifier=False):
        """Compute interventional expectations with caching."""
        cache_key = (frozenset(S), tuple(sorted(x_S.items())) if len(x_S) > 0 else tuple())
        
        if cache_key in self.path_cache:
            return self.path_cache[cache_key]
        
        variables_order = self.get_topological_order(S)
        
        sample = {}
        for feature in S:
            sample[feature] = x_S[feature]
        for feature in variables_order:
            if feature in S or feature == self.target_variable:
                continue
            parents = self.get_parents(feature)
            parent_values = {p: x_S[p] if p in S else sample[p] for p in parents if p != self.target_variable}
            if not parent_values:
                sample[feature] = self.sample_marginal(feature)
            else:
                sample[feature] = self.sample_conditional(feature, parent_values)
               
        intervened_data = pd.DataFrame([sample])
        intervened_data = intervened_data[self.model.feature_names_in_]
        if is_classifier:
            probas = self.model.predict_proba(intervened_data)[:, 1]
        else:
            probas = self.model.predict(intervened_data)
        
        result = np.mean(probas)
        self.path_cache[cache_key] = result
        return result

    def is_on_causal_path(self, feature, S, target_feature):
        """Check if feature is on any causal path from S to target_feature."""
        if target_feature not in self.causal_paths:
            return False
        path_features = self.causal_paths[target_feature]
        return feature in path_features

    def compute_modified_shap_proba(self, x, is_classifier=False):
        """TreeSHAP-inspired computation using causal paths and dynamic programming."""
        features = [col for col in self.data.columns if col != self.target_variable]
        phi_causal = {feature: 0.0 for feature in features}

        data_without_target = self.data.drop(columns=[self.target_variable])
        if is_classifier:
            E_fX = self.model.predict_proba(data_without_target)[:, 1].mean()
        else:
            E_fX = self.model.predict(data_without_target).mean()

        x_ordered = x[self.model.feature_names_in_]
        if is_classifier:
            f_x = self.model.predict_proba(x_ordered.to_frame().T)[0][1]
        else:
            f_x = self.model.predict(x_ordered.to_frame().T)[0]

        sorted_features = sorted(features, key=lambda f: self.feature_depths.get(f, 0))
        max_path_length = max(self.feature_depths.values(), default=0)
        shapley_weights = {}
        for m in range(max_path_length + 1):
            for d in range(m + 1, max_path_length + 1):
                shapley_weights[(m, d)] = (factorial(m) * factorial(d - m - 1)) / factorial(d)

        # Track contributions using dynamic programming (EXTEND-like logic in TreeSHAP)
        # m_values will accumulate contributions from subsets (use combinatorial logic)
        # Essentially, values in m_values[k] represent how many ways there are to select k nodes from the path seen so far.
        for feature in sorted_features:
            if feature not in self.causal_paths:
                continue
            for path in self.causal_paths[feature]:
                path_features = [n for n in path if n != self.target_variable]
                d = len(path_features)
                m_values = defaultdict(float)
                m_values[0] = 1.0

                for node in path_features:
                    if node == feature:
                        continue  

                    new_m_values = defaultdict(float)
                    for m, val in m_values.items():
                        new_m_values[m + 1] += val 
                        new_m_values[m] += val 
                    m_values = new_m_values

                for m in m_values:
                    weight = shapley_weights.get((m, d), 0) * self.gamma.get(feature, 0)
                    delta_v = self._compute_path_delta_v(feature, path, m, x, is_classifier)
                    phi_causal[feature] += weight * delta_v * m_values[m]

        sum_phi = sum(phi_causal.values())
        if sum_phi != 0:
            scaling_factor = (f_x - E_fX) / sum_phi
            phi_causal = {k: v * scaling_factor for k, v in phi_causal.items()}

        return phi_causal
        
    def _compute_path_delta_v(self, feature, path, m, x, is_classifier):
        """Compute Δv for a causal path using precomputed expectations."""
        S = [n for n in path[:m] if n != feature]
        x_S = {n: x[n] for n in S if n in x}
        v_S = self.compute_v_do(S, x_S, is_classifier)

        S_with_i = S + [feature]
        x_Si = {**x_S, feature: x[feature]}
        v_Si = self.compute_v_do(S_with_i, x_Si, is_classifier)

        return v_Si - v_S

## Inference

In [58]:
ci = FastCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

In [61]:
phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)
print(phi)

{'xylose': np.float64(0.11866853374454497), 'xanthosine': np.float64(-0.018385377073444314), 'uracil': np.float64(0.09265538907924285), 'ribulose/xylulose': np.float64(-0.01608475582038705), 'valylglutamine': np.float64(0.0015482242274281917), 'tryptophylglycine': np.float64(-0.006556190949174268), 'succinate': np.float64(0.09336936802891681), 'valine betaine': np.float64(-0.031943003398665654), 'ursodeoxycholate sulfate (1)': np.float64(0.002086056952346538), 'tricarballylate': np.float64(0.002178932709545898), 'succinimide': np.float64(-0.015894412241578462), 'thymine': np.float64(0.03755382803723222), 'syringic acid': np.float64(0.0033322343721636136), 'serotonin': np.float64(-0.07955677350938707), 'ribitol': np.float64(-0.019355355098207676)}


## Eval

In [62]:
phi = []
for i in range(len(X_test)):
    phi.append(ci.compute_modified_shap_proba(X_test.iloc[i], is_classifier=True))

phi_df = pd.DataFrame(phi)
mean_values = phi_df.abs().mean()
global_importance = mean_values.sort_values(ascending=False)

In [63]:
from evaluation import evaluate_global_shap_scores
from sklearn.preprocessing import StandardScaler
import numpy as np


seeds = [42, 123, 456, 789, 1010]

X = X[["xylose", "xanthosine", "uracil", "ribulose/xylulose", "valylglutamine", "tryptophylglycine", "succinate", "valine betaine", "ursodeoxycholate sulfate (1)", "tricarballylate","succinimide", "thymine", "syringic acid", "serotonin", "ribitol" ]]

y = df_encoded['Group']

scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X),columns=X.columns,index=X.index)

all_scores = {
    'deletion': {
        'auroc': [],
        'cross_entropy': [],
        'brier': []
    },
    'insertion': {
        'auroc': [],
        'cross_entropy': [],
        'brier': []
    }
}

for i in seeds:
    print("Training Random Forest model...")
    param_dist = {
        'n_estimators': [100, 200, 300],
        'max_depth': [10, 20, 30, None],
        'min_samples_split': [2, 5, 7],
        'min_samples_leaf': [1, 2, 4]
    }
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=i)

    rf = RandomForestClassifier(random_state=42)
    random_search = RandomizedSearchCV(
    estimator=rf, param_distributions=param_dist, n_iter=50,
            cv=3, n_jobs=-1, verbose=2, random_state=42)
    random_search.fit(X_train, y_train)
    model = random_search.best_estimator_
    best_params = random_search.best_params_

    shap_values = global_importance

    result = evaluate_global_shap_scores(model, X_test, y_test, shap_values, causal=True)

    for method in ['deletion', 'insertion']:
        for metric in ['auroc', 'cross_entropy', 'brier']:
            all_scores[method][metric].append(result[method]["average_scores"][metric])


Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [64]:
import json
final_results = {
    'deletion': {
        metric: {
            'mean': np.mean(scores),
            'std': np.std(scores)
        }
        for metric, scores in all_scores['deletion'].items()
    },
    'insertion': {
        metric: {
            'mean': np.mean(scores),
            'std': np.std(scores)
        }
        for metric, scores in all_scores['insertion'].items()
    }
}

print("\nFinal Results:")
print(json.dumps(final_results, indent=4))


Final Results:
{
    "deletion": {
        "auroc": {
            "mean": 0.6072804297593022,
            "std": 0.033756452674940034
        },
        "cross_entropy": {
            "mean": 0.7810383556603926,
            "std": 0.04125080773121551
        },
        "brier": {
            "mean": 0.29054916153982324,
            "std": 0.018640170603287803
        }
    },
    "insertion": {
        "auroc": {
            "mean": 0.8531238679722886,
            "std": 0.01749174187306101
        },
        "cross_entropy": {
            "mean": 0.4902830011310785,
            "std": 0.02738522128428254
        },
        "brier": {
            "mean": 0.15868453686398365,
            "std": 0.011631364668156507
        }
    }
}
