## Optimization inspired by TreeSHAP

In [1]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from causal_inference import CausalInference
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.preprocessing import LabelEncoder

base_dir = '../../'
result_dir = base_dir + 'result/R/'
data_path = base_dir + 'dataset/' + 'Real_World_IBS.xlsx'
df = pd.read_excel(data_path)
df = df.drop(columns=['HAD_Anxiety', 'Patient', 'Batch_metabolomics', 'BH', 'Sex', 'Age', 'BMI','Race','Education','HAD_Depression','STAI_Tanxiety', 'Diet_Category','Diet_Pattern'])
label_encoder = LabelEncoder()
df['Group'] = label_encoder.fit_transform(df['Group'])
df_encoded = df

X = df_encoded.drop(columns=['Group'])
y = df_encoded['Group']

X = X[["xylose", "xanthosine", "uracil", "ribulose/xylulose", "valylglutamine",
           "tryptophylglycine", "succinate", "valine betaine", "ursodeoxycholate sulfate (1)",
           "tricarballylate", "succinimide", "thymine", "syringic acid", "serotonin", "ribitol"]]

param_dist = {
        'n_estimators': [100, 200, 300],
        'max_depth': [10, 20, 30, None],
        'min_samples_split': [2, 5, 7],
        'min_samples_leaf': [1, 2, 4]
    }
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

rf = RandomForestClassifier(random_state=42)
random_search = RandomizedSearchCV(
estimator=rf, param_distributions=param_dist, n_iter=50, cv=3, n_jobs=-1, verbose=2, random_state=42)
random_search.fit(X_train, y_train)
model = random_search.best_estimator_
best_params = random_search.best_params_

  from .autonotebook import tqdm as notebook_tqdm


Fitting 3 folds for each of 50 candidates, totalling 150 fits


  _data = np.array(data, dtype=dtype, copy=copy,


In [85]:
ci = CausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)

In [59]:
# Original Causal SHAP
from line_profiler import LineProfiler

ci = CausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

profiler = LineProfiler()
profiler.add_function(ci.compute_modified_shap_proba)
profiler.add_function(ci.compute_v_do)
profiler.add_function(ci.sample_conditional)

profiler.run('phi_normalized = ci.compute_modified_shap_proba(x_instance)')
profiler.print_stats()

Timer unit: 1e-07 s

Total time: 61.2101 s
File: c:\Users\snorl\Desktop\FYP\code\python\causal_inference.py
Function: sample_conditional at line 110

Line #      Hits         Time  Per Hit   % Time  Line Contents
   110                                               def sample_conditional(self, feature, parent_values):
   111                                                   """
   112                                                   Sample a value for a feature conditioned on its parent features using precomputed regression model.
   113                                                   """
   114    461200    7621266.0     16.5      1.2          effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
   115    168500     745582.0      4.4      0.1          if not effective_parents:
   116                                                       return self.sample_marginal(feature)
   117    168500    2127240.0     12.6      0.3          model_key = (featu

### Optimization with Causal Tree Path (Reduced Feature Size)

In [None]:
import pandas as pd
import networkx as nx
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.cit import fisherz
import numpy as np
import json
import random
from math import factorial
from sklearn.linear_model import LinearRegression
from collections import defaultdict

class FastCausalInference:
    def __init__(self, data, model, target_variable):
        self.data = data  
        self.pc_graph = None
        self.model = model  
        self.gamma = None  
        self.target_variable = target_variable 
        self.ida_graph = None
        self.regression_models = {}
        self.feature_depths = {}  # Store feature depths for optimization

    def run_pc_algorithm(self, alpha=0.05):
        data_np = self.data.to_numpy()
        pc_result = pc(data_np, alpha, fisherz)
        self.pc_graph = pc_result.G
        return self.pc_graph

    def draw_graph(self, file_path):
        pyd = GraphUtils.to_pydot(self.pc_graph)
        pyd.write_png(file_path)

    def load_causal_strengths(self, json_file_path):
        """Load causal strengths (beta_i) from JSON file and compute gamma_i."""
        with open(json_file_path, 'r') as f:
            causal_effects_list = json.load(f)
        
        G = nx.DiGraph()
        nodes = list(self.data.columns)
        G.add_nodes_from(nodes)

        for item in causal_effects_list:
            pair = item['Pair']
            mean_causal_effect = item['Mean_Causal_Effect']
            if mean_causal_effect is None:
                continue  
            source, target = pair.split('->')
            source = source.strip()
            target = target.strip()
            G.add_edge(source, target, weight=mean_causal_effect)
        
        self.ida_graph = G.copy()
        
        # Compute feature depths for optimization
        self._compute_feature_depths()
        
        features = self.data.columns.tolist()
        beta_dict = {}

        for feature in features:
            if feature == self.target_variable:
                continue
            try:
                paths = list(nx.all_simple_paths(G, source=feature, target=self.target_variable))
            except nx.NetworkXNoPath:
                continue  
            total_effect = 0
            for path in paths:
                effect = 1
                for i in range(len(path)-1):
                    edge_weight = G[path[i]][path[i+1]]['weight']
                    effect *= edge_weight
                total_effect += effect
            if total_effect != 0:
                beta_dict[feature] = total_effect

        total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
        if total_causal_effect == 0:
            self.gamma = {k: 0.0 for k in features}
        else:
            self.gamma = {k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features}
        return self.gamma
    
    def _compute_feature_depths(self):
        """Compute minimum depth of each feature to target in causal graph."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                all_paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                min_depth = float('inf')
                for path in all_paths:
                    depth = len(path) - 1  # Exclude target
                    min_depth = min(min_depth, depth)
                if min_depth != float('inf'):
                    self.feature_depths[feature] = min_depth
            except nx.NetworkXNoPath:
                continue

    def get_topological_order(self, S):
        """Returns the topological order of variables after intervening on subset S."""
        G_intervened = self.ida_graph.copy()
        for feature in S:
            G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
        missing_nodes = set(self.data.columns) - set(G_intervened.nodes)
        G_intervened.add_nodes_from(missing_nodes)

        try:
            order = list(nx.topological_sort(G_intervened))
        except nx.NetworkXUnfeasible:
            raise ValueError("The causal graph contains cycles.")
        
        return order
    
    def get_parents(self, feature):
        """Returns the list of parent features for a given feature in the causal graph."""
        return list(self.ida_graph.predecessors(feature))

    def sample_marginal(self, feature):
        """Sample a value from the marginal distribution of the specified feature."""
        return self.data[feature].sample(1).iloc[0]

    def sample_conditional(self, feature, parent_values):
        """Sample a value for a feature conditioned on its parent features."""
        effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
        if not effective_parents:
            return self.sample_marginal(feature)
        model_key = (feature, tuple(sorted(effective_parents))) 
        if model_key not in self.regression_models:
            X = self.data[effective_parents].values
            y = self.data[feature].values
            reg = LinearRegression()
            reg.fit(X, y)
            residuals = y - reg.predict(X)
            std = residuals.std()
            self.regression_models[model_key] = (reg, std)
        reg, std = self.regression_models[model_key]
        parent_values_array = np.array([parent_values[parent] for parent in effective_parents]).reshape(1, -1)
        mean = reg.predict(parent_values_array)[0]
        sampled_value = np.random.normal(mean, std)
        return sampled_value

    def compute_v_do(self, S, x_S, num_samples=50, is_classifier=False):
        """Compute interventional expectations."""
        samples_list = []
        variables_order = self.get_topological_order(S)
        
        for _ in range(num_samples):
            sample = {}
            for feature in S:
                sample[feature] = x_S[feature]
            for feature in variables_order:
                if feature in S or feature == self.target_variable:
                    continue
                parents = self.get_parents(feature)
                parent_values = {p: x_S[p] if p in S else sample[p] for p in parents if p != self.target_variable}
                if not parent_values:
                    sample[feature] = self.sample_marginal(feature)
                else:
                    sample[feature] = self.sample_conditional(feature, parent_values)
            samples_list.append(sample)
        
        intervened_data = pd.DataFrame(samples_list)
        intervened_data = intervened_data[self.model.feature_names_in_]
        if is_classifier:
            probas = self.model.predict_proba(intervened_data)[:, 1]
        else:
            probas = self.model.predict(intervened_data)
        return np.mean(probas)

    def compute_modified_shap_proba(self, x, num_samples=50, shap_num_samples=50, is_classifier=False):
        """Compute modified SHAP values using depth-based optimization."""
        features = [col for col in self.data.columns if col != self.target_variable]
        n_features = len(features)
        phi_causal = {feature: 0.0 for feature in features}

        data_without_target = self.data.drop(columns=[self.target_variable], errors='ignore')
        data_without_target = data_without_target[self.model.feature_names_in_]
        if is_classifier:
            E_fX = self.model.predict_proba(data_without_target)[:, 1].mean() 
        else:
            E_fX = self.model.predict(data_without_target).mean()

        x_ordered = x[self.model.feature_names_in_]
        if is_classifier:
            f_x = self.model.predict_proba(x_ordered.to_frame().T)[0][1]  
        else:
            f_x = self.model.predict(x_ordered.to_frame().T)[0]

        features_by_depth = defaultdict(list)
        for feature in features:
            depth = self.feature_depths.get(feature, float('inf'))
            features_by_depth[depth].append(feature)

        for depth in sorted(features_by_depth.keys()):
            depth_features = features_by_depth[depth]
            
            for feature in depth_features:
                valid_features = set()
                for d in range(depth + 1):
                    valid_features.update(features_by_depth[d])
                valid_features.discard(feature)
                
        for _ in range(shap_num_samples):
            S_size = random.randint(0, len(valid_features))
            S = random.sample(list(valid_features), S_size)
            for i in valid_features:
                if i in S:
                    continue 
                S_without_i = S.copy()
                S_with_i = S + [i]
                x_S = x[S_without_i] if S_without_i else pd.Series(dtype=float)
                x_Si = x[S_with_i] if S_with_i else pd.Series(dtype=float)
                v_S = self.compute_v_do(S_without_i, x_S, num_samples=num_samples, is_classifier=is_classifier)
                v_Si = self.compute_v_do(S_with_i, x_Si, num_samples=num_samples, is_classifier=is_classifier)
                weight = (factorial(len(S_without_i)) * factorial(n_features - len(S_without_i) - 1)) / factorial(n_features)
                gamma_i = self.gamma.get(i, 0.0)
                weight *= gamma_i
                delta_v = v_Si - v_S
                phi_causal[i] += weight * delta_v

        sum_phi_causal = sum(phi_causal.values())
        if sum_phi_causal == 0:
            phi_normalized = {k: 0.0 for k in phi_causal.keys()}
        else:
            scaling_factor = (f_x - E_fX) / sum_phi_causal
            phi_normalized = {k: v * scaling_factor for k, v in phi_causal.items()}

        return phi_normalized

In [None]:
ci = FastCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)

In [56]:
# Optimized Causal SHAP
from line_profiler import LineProfiler

ci = FastCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

profiler = LineProfiler()
profiler.add_function(ci.compute_modified_shap_proba)
profiler.add_function(ci.compute_v_do)
profiler.add_function(ci.sample_conditional)

profiler.run('phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)')
profiler.print_stats()

Timer unit: 1e-07 s

Total time: 63.6335 s
File: C:\Users\snorl\AppData\Local\Temp\ipykernel_1332\343667890.py
Function: sample_conditional at line 123

Line #      Hits         Time  Per Hit   % Time  Line Contents
   123                                               def sample_conditional(self, feature, parent_values):
   124                                                   """Sample a value for a feature conditioned on its parent features."""
   125    446550    7975652.0     17.9      1.3          effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
   126    162800     782189.0      4.8      0.1          if not effective_parents:
   127                                                       return self.sample_marginal(feature)
   128    162800    2283014.0     14.0      0.4          model_key = (feature, tuple(sorted(effective_parents))) 
   129    162800    1948488.0     12.0      0.3          if model_key not in self.regression_models:
   130  

### Add Path Cache

In [78]:
import pandas as pd
import networkx as nx
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.cit import fisherz
import numpy as np
import json
import random
from math import factorial
from sklearn.linear_model import LinearRegression
from collections import defaultdict

class OptimizedCausalInference:
    def __init__(self, data, model, target_variable):
        self.data = data  
        self.pc_graph = None
        self.model = model  
        self.gamma = None  
        self.target_variable = target_variable 
        self.ida_graph = None
        self.regression_models = {}
        self.feature_depths = {}  
        self.path_cache = {}

    def run_pc_algorithm(self, alpha=0.05):
        data_np = self.data.to_numpy()
        pc_result = pc(data_np, alpha, fisherz)
        self.pc_graph = pc_result.G
        return self.pc_graph

    def draw_graph(self, file_path):
        pyd = GraphUtils.to_pydot(self.pc_graph)
        pyd.write_png(file_path)

    def load_causal_strengths(self, json_file_path):
        """Load causal strengths (beta_i) from JSON file and compute gamma_i."""
        with open(json_file_path, 'r') as f:
            causal_effects_list = json.load(f)
        
        G = nx.DiGraph()
        nodes = list(self.data.columns)
        G.add_nodes_from(nodes)

        for item in causal_effects_list:
            pair = item['Pair']
            mean_causal_effect = item['Mean_Causal_Effect']
            if mean_causal_effect is None:
                continue  
            source, target = pair.split('->')
            source = source.strip()
            target = target.strip()
            G.add_edge(source, target, weight=mean_causal_effect)
        
        self.ida_graph = G.copy()
        self._compute_feature_depths()
        
        features = self.data.columns.tolist()
        beta_dict = {}

        for feature in features:
            if feature == self.target_variable:
                continue
            try:
                paths = list(nx.all_simple_paths(G, source=feature, target=self.target_variable))
            except nx.NetworkXNoPath:
                continue  
            total_effect = 0
            for path in paths:
                effect = 1
                for i in range(len(path)-1):
                    edge_weight = G[path[i]][path[i+1]]['weight']
                    effect *= edge_weight
                total_effect += effect
            if total_effect != 0:
                beta_dict[feature] = total_effect

        total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
        if total_causal_effect == 0:
            self.gamma = {k: 0.0 for k in features}
        else:
            self.gamma = {k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features}
        return self.gamma
    
    def _compute_feature_depths(self):
        """Compute minimum depth of each feature to target in causal graph."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                all_paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                min_depth = float('inf')
                for path in all_paths:
                    depth = len(path) - 1  # Exclude target
                    min_depth = min(min_depth, depth)
                if min_depth != float('inf'):
                    self.feature_depths[feature] = min_depth
            except nx.NetworkXNoPath:
                continue

    def get_topological_order(self, S):
        """Returns the topological order of variables after intervening on subset S."""
        G_intervened = self.ida_graph.copy()
        for feature in S:
            G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
        missing_nodes = set(self.data.columns) - set(G_intervened.nodes)
        G_intervened.add_nodes_from(missing_nodes)

        try:
            order = list(nx.topological_sort(G_intervened))
        except nx.NetworkXUnfeasible:
            raise ValueError("The causal graph contains cycles.")
        
        return order
    
    def get_parents(self, feature):
        """Returns the list of parent features for a given feature in the causal graph."""
        return list(self.ida_graph.predecessors(feature))

    def sample_marginal(self, feature):
        """Sample a value from the marginal distribution of the specified feature."""
        return self.data[feature].sample(1).iloc[0]

    def sample_conditional(self, feature, parent_values):
        """Sample a value for a feature conditioned on its parent features."""
        effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
        if not effective_parents:
            return self.sample_marginal(feature)
        model_key = (feature, tuple(sorted(effective_parents))) 
        if model_key not in self.regression_models:
            X = self.data[effective_parents].values
            y = self.data[feature].values
            reg = LinearRegression()
            reg.fit(X, y)
            residuals = y - reg.predict(X)
            std = residuals.std()
            self.regression_models[model_key] = (reg, std)
        reg, std = self.regression_models[model_key]
        parent_values_array = np.array([parent_values[parent] for parent in effective_parents]).reshape(1, -1)
        mean = reg.predict(parent_values_array)[0]
        sampled_value = np.random.normal(mean, std)
        return sampled_value

    def compute_v_do(self, S, x_S, num_samples=50, is_classifier=False):
        """Compute interventional expectations with caching."""
        cache_key = (frozenset(S), tuple(sorted(x_S.items())) if len(x_S) > 0 else tuple())
        
        if cache_key in self.path_cache:
            return self.path_cache[cache_key]
        
        samples_list = []
        variables_order = self.get_topological_order(S)
        
        for _ in range(num_samples):
            sample = {}
            for feature in S:
                sample[feature] = x_S[feature]
            for feature in variables_order:
                if feature in S or feature == self.target_variable:
                    continue
                parents = self.get_parents(feature)
                parent_values = {p: x_S[p] if p in S else sample[p] for p in parents if p != self.target_variable}
                if not parent_values:
                    sample[feature] = self.sample_marginal(feature)
                else:
                    sample[feature] = self.sample_conditional(feature, parent_values)
            samples_list.append(sample)
        
        intervened_data = pd.DataFrame(samples_list)
        intervened_data = intervened_data[self.model.feature_names_in_]
        if is_classifier:
            probas = self.model.predict_proba(intervened_data)[:, 1]
        else:
            probas = self.model.predict(intervened_data)
        
        result = np.mean(probas)
        self.path_cache[cache_key] = result
        return result

    def compute_modified_shap_proba(self, x, num_samples=50, shap_num_samples=50, is_classifier=False):
        """Compute modified SHAP values using depth-based optimization."""
        features = [col for col in self.data.columns if col != self.target_variable]
        n_features = len(features)
        phi_causal = {feature: 0.0 for feature in features}

        data_without_target = self.data.drop(columns=[self.target_variable], errors='ignore')
        data_without_target = data_without_target[self.model.feature_names_in_]
        if is_classifier:
            E_fX = self.model.predict_proba(data_without_target)[:, 1].mean() 
        else:
            E_fX = self.model.predict(data_without_target).mean()

        x_ordered = x[self.model.feature_names_in_]
        if is_classifier:
            f_x = self.model.predict_proba(x_ordered.to_frame().T)[0][1]  
        else:
            f_x = self.model.predict(x_ordered.to_frame().T)[0]

        features_by_depth = defaultdict(list)
        for feature in features:
            depth = self.feature_depths.get(feature, float('inf'))
            features_by_depth[depth].append(feature)

        for depth in sorted(features_by_depth.keys()):
            depth_features = features_by_depth[depth]
            
            for feature in depth_features:
                valid_features = set()
                for d in range(depth + 1):
                    valid_features.update(features_by_depth[d])
                valid_features.discard(feature)
                
        for _ in range(shap_num_samples):
            S_size = random.randint(0, len(valid_features))
            S = random.sample(list(valid_features), S_size)
            for i in valid_features:
                if i in S:
                    continue 
                S_without_i = S.copy()
                S_with_i = S + [i]
                x_S = x[S_without_i] if S_without_i else pd.Series(dtype=float)
                x_Si = x[S_with_i] if S_with_i else pd.Series(dtype=float)
                v_S = self.compute_v_do(S_without_i, x_S, num_samples=num_samples, is_classifier=is_classifier)
                v_Si = self.compute_v_do(S_with_i, x_Si, num_samples=num_samples, is_classifier=is_classifier)
                weight = (factorial(len(S_without_i)) * factorial(n_features - len(S_without_i) - 1)) / factorial(n_features)
                gamma_i = self.gamma.get(i, 0.0)
                weight *= gamma_i
                delta_v = v_Si - v_S
                phi_causal[i] += weight * delta_v

        sum_phi_causal = sum(phi_causal.values())
        if sum_phi_causal == 0:
            phi_normalized = {k: 0.0 for k in phi_causal.keys()}
        else:
            scaling_factor = (f_x - E_fX) / sum_phi_causal
            phi_normalized = {k: v * scaling_factor for k, v in phi_causal.items()}

        return phi_normalized

In [79]:
ci = OptimizedCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)

### Optimization with TreeSHAP idea

In [2]:
import pandas as pd
import networkx as nx
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.cit import fisherz
import numpy as np
import json
import random
from math import factorial
from sklearn.linear_model import LinearRegression
from collections import defaultdict

class TreeShapCausalInference:
    def __init__(self, data, model, target_variable):
        self.data = data  
        self.model = model  
        self.gamma = None  
        self.target_variable = target_variable 
        self.ida_graph = None
        self.regression_models = {}
        self.feature_depths = {}
        self.path_cache = {}
        self.causal_paths = {}  
        
    def _compute_causal_paths(self):
        """Compute and store all causal paths to target for each feature."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                path_features = set()
                for path in paths:
                    path_features.update(path[:-1]) 
                self.causal_paths[feature] = path_features
            except nx.NetworkXNoPath:
                self.causal_paths[feature] = set()

    def load_causal_strengths(self, json_file_path):
        with open(json_file_path, 'r') as f:
            causal_effects_list = json.load(f)
        
        G = nx.DiGraph()
        nodes = list(self.data.columns)
        G.add_nodes_from(nodes)

        for item in causal_effects_list:
            pair = item['Pair']
            mean_causal_effect = item['Mean_Causal_Effect']
            if mean_causal_effect is None:
                continue  
            source, target = pair.split('->')
            source = source.strip()
            target = target.strip()
            G.add_edge(source, target, weight=mean_causal_effect)
        self.ida_graph = G.copy()
        self._compute_feature_depths()
        self._compute_causal_paths()
        features = self.data.columns.tolist()
        beta_dict = {}

        for feature in features:
            if feature == self.target_variable:
                continue
            try:
                paths = list(nx.all_simple_paths(G, source=feature, target=self.target_variable))
            except nx.NetworkXNoPath:
                continue  
            total_effect = 0
            for path in paths:
                effect = 1
                for i in range(len(path)-1):
                    edge_weight = G[path[i]][path[i+1]]['weight']
                    effect *= edge_weight
                total_effect += effect
            if total_effect != 0:
                beta_dict[feature] = total_effect

        total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
        if total_causal_effect == 0:
            self.gamma = {k: 0.0 for k in features}
        else:
            self.gamma = {k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features}
        return self.gamma
    
    def _compute_feature_depths(self):
        """Compute minimum depth of each feature to target in causal graph."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                all_paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                min_depth = float('inf')
                for path in all_paths:
                    depth = len(path) - 1  
                    min_depth = min(min_depth, depth)
                if min_depth != float('inf'):
                    self.feature_depths[feature] = min_depth
            except nx.NetworkXNoPath:
                continue

    def get_topological_order(self, S):
        """Returns the topological order of variables after intervening on subset S."""
        G_intervened = self.ida_graph.copy()
        for feature in S:
            G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
        missing_nodes = set(self.data.columns) - set(G_intervened.nodes)
        G_intervened.add_nodes_from(missing_nodes)

        try:
            order = list(nx.topological_sort(G_intervened))
        except nx.NetworkXUnfeasible:
            raise ValueError("The causal graph contains cycles.")
        
        return order
    
    def get_parents(self, feature):
        """Returns the list of parent features for a given feature in the causal graph."""
        return list(self.ida_graph.predecessors(feature))

    def sample_marginal(self, feature):
        """Sample a value from the marginal distribution of the specified feature."""
        return self.data[feature].sample(1).iloc[0]

    def sample_conditional(self, feature, parent_values):
        """Sample a value for a feature conditioned on its parent features."""
        effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
        if not effective_parents:
            return self.sample_marginal(feature)
        model_key = (feature, tuple(sorted(effective_parents))) 
        if model_key not in self.regression_models:
            X = self.data[effective_parents].values
            y = self.data[feature].values
            reg = LinearRegression()
            reg.fit(X, y)
            residuals = y - reg.predict(X)
            std = residuals.std()
            self.regression_models[model_key] = (reg, std)
        reg, std = self.regression_models[model_key]
        parent_values_array = np.array([parent_values[parent] for parent in effective_parents]).reshape(1, -1)
        mean = reg.predict(parent_values_array)[0]
        sampled_value = np.random.normal(mean, std)
        return sampled_value

    def compute_v_do(self, S, x_S, num_samples=50, is_classifier=False):
        """Compute interventional expectations with caching."""
        cache_key = (frozenset(S), tuple(sorted(x_S.items())) if len(x_S) > 0 else tuple())
        
        if cache_key in self.path_cache:
            return self.path_cache[cache_key]
        
        samples_list = []
        variables_order = self.get_topological_order(S)
        
        for _ in range(num_samples):
            sample = {}
            for feature in S:
                sample[feature] = x_S[feature]
            for feature in variables_order:
                if feature in S or feature == self.target_variable:
                    continue
                parents = self.get_parents(feature)
                parent_values = {p: x_S[p] if p in S else sample[p] for p in parents if p != self.target_variable}
                if not parent_values:
                    sample[feature] = self.sample_marginal(feature)
                else:
                    sample[feature] = self.sample_conditional(feature, parent_values)
            samples_list.append(sample)
        
        intervened_data = pd.DataFrame(samples_list)
        intervened_data = intervened_data[self.model.feature_names_in_]
        if is_classifier:
            probas = self.model.predict_proba(intervened_data)[:, 1]
        else:
            probas = self.model.predict(intervened_data)
        
        result = np.mean(probas)
        self.path_cache[cache_key] = result
        return result

    def is_on_causal_path(self, feature, S, target_feature):
        """Check if feature is on any causal path from S to target_feature."""
        if target_feature not in self.causal_paths:
            return False
        path_features = self.causal_paths[target_feature]
        return feature in path_features

    def compute_modified_shap_proba(self, x, num_samples=50, shap_num_samples=50, is_classifier=False):
        """TreeSHAP-inspired computation of SHAP values."""
        features = [col for col in self.data.columns if col != self.target_variable]
        n_features = len(features)
        phi_causal = {feature: 0.0 for feature in features}

        data_without_target = self.data.drop(columns=[self.target_variable], errors='ignore')
        data_without_target = data_without_target[self.model.feature_names_in_]
        if is_classifier:
            E_fX = self.model.predict_proba(data_without_target)[:, 1].mean() 
        else:
            E_fX = self.model.predict(data_without_target).mean()

        x_ordered = x[self.model.feature_names_in_]
        if is_classifier:
            f_x = self.model.predict_proba(x_ordered.to_frame().T)[0][1]  
        else:
            f_x = self.model.predict(x_ordered.to_frame().T)[0]

        features_by_depth = defaultdict(list)
        for feature in features:
            depth = self.feature_depths.get(feature, float('inf'))
            features_by_depth[depth].append(feature)

        for depth in sorted(features_by_depth.keys()):
            depth_features = features_by_depth[depth]
            
            for feature in depth_features:
                valid_features = set()
                for d in range(depth + 1):
                    for f in features_by_depth[d]:
                        if f != feature and self.is_on_causal_path(f, set(), feature):
                            valid_features.add(f)

                for _ in range(shap_num_samples):
                    S_size = random.randint(0, len(valid_features))
                    S = random.sample(list(valid_features), S_size)
                    relevant_features = {f for f in valid_features 
                                      if self.is_on_causal_path(f, S, feature)}
                    
                    if not relevant_features:
                        continue
                    S_without_i = S.copy()
                    S_with_i = S + [feature]
                    x_S = x[S_without_i] if S_without_i else pd.Series(dtype=float)
                    x_Si = x[S_with_i] if S_with_i else pd.Series(dtype=float)
                    
                    v_S = self.compute_v_do(S_without_i, x_S, num_samples=num_samples, 
                                          is_classifier=is_classifier)
                    v_Si = self.compute_v_do(S_with_i, x_Si, num_samples=num_samples, 
                                           is_classifier=is_classifier)
                    
                    weight = (factorial(len(relevant_features)) * 
                            factorial(n_features - len(relevant_features) - 1)) / factorial(n_features)
                    weight *= self.gamma.get(feature, 0.0)
                    
                    delta_v = v_Si - v_S
                    phi_causal[feature] += weight * delta_v

        sum_phi_causal = sum(phi_causal.values())
        if sum_phi_causal == 0:
            phi_normalized = {k: 0.0 for k in phi_causal.keys()}
        else:
            scaling_factor = (f_x - E_fX) / sum_phi_causal
            phi_normalized = {k: v * scaling_factor for k, v in phi_causal.items()}

        return phi_normalized

In [5]:
ci = TreeShapCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)

### Final Optimization with TreeSHAP idea?

In [None]:
import pandas as pd
import networkx as nx
import numpy as np
import json
from math import factorial
from sklearn.linear_model import LinearRegression
from collections import defaultdict

class FastCausalInference:
    def __init__(self, data, model, target_variable):
        self.data = data  
        self.model = model  
        self.gamma = None  
        self.target_variable = target_variable 
        self.ida_graph = None
        self.regression_models = {}
        self.feature_depths = {}
        self.path_cache = {}
        self.causal_paths = {}  
        
    def _compute_causal_paths(self):
        """Compute and store all causal paths to target for each feature."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                # Store the actual paths instead of just the features
                paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                self.causal_paths[feature] = paths
            except nx.NetworkXNoPath:
                self.causal_paths[feature] = []

    def load_causal_strengths(self, json_file_path):
        with open(json_file_path, 'r') as f:
            causal_effects_list = json.load(f)
        
        G = nx.DiGraph()
        nodes = list(self.data.columns)
        G.add_nodes_from(nodes)

        for item in causal_effects_list:
            pair = item['Pair']
            mean_causal_effect = item['Mean_Causal_Effect']
            if mean_causal_effect is None:
                continue  
            source, target = pair.split('->')
            source = source.strip()
            target = target.strip()
            G.add_edge(source, target, weight=mean_causal_effect)
        self.ida_graph = G.copy()
        self._compute_feature_depths()
        self._compute_causal_paths()
        features = self.data.columns.tolist()
        beta_dict = {}

        for feature in features:
            if feature == self.target_variable:
                continue
            try:
                paths = list(nx.all_simple_paths(G, source=feature, target=self.target_variable))
            except nx.NetworkXNoPath:
                continue  
            total_effect = 0
            for path in paths:
                effect = 1
                for i in range(len(path)-1):
                    edge_weight = G[path[i]][path[i+1]]['weight']
                    effect *= edge_weight
                total_effect += effect
            if total_effect != 0:
                beta_dict[feature] = total_effect

        total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
        if total_causal_effect == 0:
            self.gamma = {k: 0.0 for k in features}
        else:
            self.gamma = {k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features}
        return self.gamma
    
    def _compute_feature_depths(self):
        """Compute minimum depth of each feature to target in causal graph."""
        features = [col for col in self.data.columns if col != self.target_variable]
        for feature in features:
            try:
                all_paths = list(nx.all_simple_paths(self.ida_graph, feature, self.target_variable))
                min_depth = float('inf')
                for path in all_paths:
                    depth = len(path) - 1  
                    min_depth = min(min_depth, depth)
                if min_depth != float('inf'):
                    self.feature_depths[feature] = min_depth
            except nx.NetworkXNoPath:
                continue

    def get_topological_order(self, S):
        """Returns the topological order of variables after intervening on subset S."""
        G_intervened = self.ida_graph.copy()
        for feature in S:
            G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
        missing_nodes = set(self.data.columns) - set(G_intervened.nodes)
        G_intervened.add_nodes_from(missing_nodes)

        try:
            order = list(nx.topological_sort(G_intervened))
        except nx.NetworkXUnfeasible:
            raise ValueError("The causal graph contains cycles.")
        
        return order
    
    def get_parents(self, feature):
        """Returns the list of parent features for a given feature in the causal graph."""
        return list(self.ida_graph.predecessors(feature))

    def sample_marginal(self, feature):
        """Sample a value from the marginal distribution of the specified feature."""
        return self.data[feature].sample(1).iloc[0]

    def sample_conditional(self, feature, parent_values):
        """Sample a value for a feature conditioned on its parent features."""
        effective_parents = [p for p in self.get_parents(feature) if p != self.target_variable]
        if not effective_parents:
            return self.sample_marginal(feature)
        model_key = (feature, tuple(sorted(effective_parents))) 
        if model_key not in self.regression_models:
            X = self.data[effective_parents].values
            y = self.data[feature].values
            reg = LinearRegression()
            reg.fit(X, y)
            residuals = y - reg.predict(X)
            std = residuals.std()
            self.regression_models[model_key] = (reg, std)
        reg, std = self.regression_models[model_key]
        parent_values_array = np.array([parent_values[parent] for parent in effective_parents]).reshape(1, -1)
        mean = reg.predict(parent_values_array)[0]
        sampled_value = np.random.normal(mean, std)
        return sampled_value

    def compute_v_do(self, S, x_S, is_classifier=False):
        """Compute interventional expectations with caching."""
        cache_key = (frozenset(S), tuple(sorted(x_S.items())) if len(x_S) > 0 else tuple())
        
        if cache_key in self.path_cache:
            return self.path_cache[cache_key]
        
        variables_order = self.get_topological_order(S)
        
        sample = {}
        for feature in S:
            sample[feature] = x_S[feature]
        for feature in variables_order:
            if feature in S or feature == self.target_variable:
                continue
            parents = self.get_parents(feature)
            parent_values = {p: x_S[p] if p in S else sample[p] for p in parents if p != self.target_variable}
            if not parent_values:
                sample[feature] = self.sample_marginal(feature)
            else:
                sample[feature] = self.sample_conditional(feature, parent_values)
               
        intervened_data = pd.DataFrame([sample])
        intervened_data = intervened_data[self.model.feature_names_in_]
        if is_classifier:
            probas = self.model.predict_proba(intervened_data)[:, 1]
        else:
            probas = self.model.predict(intervened_data)
        
        result = np.mean(probas)
        self.path_cache[cache_key] = result
        return result

    def is_on_causal_path(self, feature, S, target_feature):
        """Check if feature is on any causal path from S to target_feature."""
        if target_feature not in self.causal_paths:
            return False
        path_features = self.causal_paths[target_feature]
        return feature in path_features

    def compute_modified_shap_proba(self, x, num_samples=50, is_classifier=False):
        """TreeSHAP-inspired computation using causal paths and dynamic programming."""
        features = [col for col in self.data.columns if col != self.target_variable]
        n_features = len(features)
        phi_causal = {feature: 0.0 for feature in features}

        # Precompute baseline expectation
        data_without_target = self.data.drop(columns=[self.target_variable])
        if is_classifier:
            E_fX = self.model.predict_proba(data_without_target)[:, 1].mean()
        else:
            E_fX = self.model.predict(data_without_target).mean()

        # Get prediction for instance x
        x_ordered = x[self.model.feature_names_in_]
        if is_classifier:
            f_x = self.model.predict_proba(x_ordered.to_frame().T)[0][1]
        else:
            f_x = self.model.predict(x_ordered.to_frame().T)[0]

        # Process features in topological order (root to target)
        sorted_features = sorted(features, key=lambda f: self.feature_depths.get(f, 0))

        # Precompute Shapley weights for all possible subset sizes
        max_path_length = max(self.feature_depths.values(), default=0)
        shapley_weights = {}
        for m in range(max_path_length + 1):
            for d in range(m + 1, max_path_length + 1):
                shapley_weights[(m, d)] = (factorial(m) * factorial(d - m - 1)) / factorial(d)

        # Track contributions using dynamic programming (EXTEND-like logic)
        for feature in sorted_features:
            if feature not in self.causal_paths:
                continue

            for path in self.causal_paths[feature]:
                path_features = [n for n in path if n != self.target_variable]
                d = len(path_features)
                m_values = defaultdict(float)

                # Initialize with empty subset
                m_values[0] = 1.0

                # Traverse path features (EXTEND)
                for node in path_features:
                    if node == feature:
                        continue  # Skip the feature itself (handled later)

                    new_m_values = defaultdict(float)
                    for m, val in m_values.items():
                        # Include the feature
                        new_m_values[m + 1] += val * self._get_indicator(node, x)
                        # Exclude the feature
                        new_m_values[m] += val * self._get_ratio(node)
                    m_values = new_m_values

                # Compute contributions for all subset sizes
                for m in m_values:
                    weight = shapley_weights.get((m, d), 0) * self.gamma.get(feature, 0)
                    delta_v = self._compute_path_delta_v(feature, path, m, x, is_classifier)
                    phi_causal[feature] += weight * delta_v * m_values[m]

        # Normalize to match f(x) - E[f(X)]
        sum_phi = sum(phi_causal.values())
        if sum_phi != 0:
            scaling_factor = (f_x - E_fX) / sum_phi
            phi_causal = {k: v * scaling_factor for k, v in phi_causal.items()}

        return phi_causal

    def _get_indicator(self, node, x):
        """Check if node's value matches the intervention (simplified for illustration)."""
        return 1.0 if node in x else 0.0  # Replace with actual causal condition check

    def _get_ratio(self, node):
        """Get covering ratio R (from causal graph parameters)."""
        return self.ida_graph.nodes[node].get('ratio', 1.0)  # Precompute ratios during graph setup
        

    def _compute_path_delta_v(self, feature, path, m, x, is_classifier):
        """Compute Δv for a causal path using precomputed expectations."""
        S = [n for n in path[:m] if n != feature]
        x_S = {n: x[n] for n in S if n in x}
        v_S = self.compute_v_do(S, x_S, is_classifier)

        S_with_i = S + [feature]
        x_Si = {**x_S, feature: x[feature]}
        v_Si = self.compute_v_do(S_with_i, x_Si, is_classifier)

        return v_Si - v_S

In [56]:
data_path = base_dir + 'dataset/' + 'Real_World_IBS_Predicted_Probabilities.xlsx'
df_prob = pd.read_excel(data_path)
X_train['Prob_Class_1'] = df_prob['Prob_Class_1']

ci = FastCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')
x_instance = X_test.iloc[33]

In [57]:
phi = ci.compute_modified_shap_proba(x_instance, is_classifier=True)


   valylglutamine  valine betaine  ursodeoxycholate sulfate (1)  \
0       -0.513533        0.054241                     -0.017948   

   tricarballylate   thymine  syringic acid  serotonin   ribitol  \
0        -0.039505 -0.935984        1.08951  -0.124574 -0.018983   

   tryptophylglycine  succinimide  xanthosine  succinate    uracil  \
0          -0.685612     0.094881   -0.414384  -0.500207 -0.418864   

   ribulose/xylulose    xylose  
0          -0.897152 -0.383827  
-----------------------------------------line break----------
     xylose  valylglutamine  valine betaine  ursodeoxycholate sulfate (1)  \
0 -0.651071       -0.307074       -0.279434                      0.391271   

   tricarballylate   thymine  syringic acid  serotonin   ribitol  \
0        -0.382151 -0.813522      -0.348835  -0.365496 -0.239429   

   tryptophylglycine  succinimide  xanthosine  succinate    uracil  \
0          -0.664029     0.685179   -0.731309  -0.316879 -1.152711   

   ribulose/xylulose  
0  

## Test Accuracy

In [46]:
#ci = FastCausalInference(data=X_train, model=model, target_variable='Prob_Class_1')
#ci.load_causal_strengths(result_dir + 'Mean_Causal_Effect_IBS.json')

phi = []
for i in range(len(X_test)):
    phi.append(ci.compute_modified_shap_proba(X_test.iloc[i], is_classifier=True))

In [47]:
phi_df = pd.DataFrame(phi)
mean_values = phi_df.abs().mean()
global_importance = mean_values.sort_values(ascending=False)

In [None]:
from evaluation import evaluate_global_shap_scores
from sklearn.preprocessing import StandardScaler


seeds = [42, 123, 456, 789, 1010]

X = X[["xylose", "xanthosine", "uracil", "ribulose/xylulose", "valylglutamine", "tryptophylglycine", "succinate", "valine betaine", "ursodeoxycholate sulfate (1)", "tricarballylate","succinimide", "thymine", "syringic acid", "serotonin", "ribitol" ]]

y = df_encoded['Group']

scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X),columns=X.columns,index=X.index)

all_scores = {
    'deletion': {
        'auroc': [],
        'cross_entropy': [],
        'brier': []
    },
    'insertion': {
        'auroc': [],
        'cross_entropy': [],
        'brier': []
    }
}

for i in seeds:
    print("Training Random Forest model...")
    param_dist = {
        'n_estimators': [100, 200, 300],
        'max_depth': [10, 20, 30, None],
        'min_samples_split': [2, 5, 7],
        'min_samples_leaf': [1, 2, 4]
    }
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=i)

    rf = RandomForestClassifier(random_state=42)
    random_search = RandomizedSearchCV(
    estimator=rf, param_distributions=param_dist, n_iter=50,
            cv=3, n_jobs=-1, verbose=2, random_state=42)
    random_search.fit(X_train, y_train)
    model = random_search.best_estimator_
    best_params = random_search.best_params_

    shap_values = global_importance

    result = evaluate_global_shap_scores(model, X_test, y_test, shap_values, causal=True)

    for method in ['deletion', 'insertion']:
        for metric in ['auroc', 'cross_entropy', 'brier']:
            all_scores[method][metric].append(result[method]["average_scores"][metric])

    import json

    print(json.dumps(result, indent=4))

Training Random Forest model...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{
    "deletion": {
        "stepwise_metrics": {
            "auroc": [
                0.8020833333333333,
                0.7131410256410257,
                0.6466346153846153,
                0.6145833333333333,
                0.6185897435897436,
                0.5657051282051282,
                0.5817307692307693,
                0.594551282051282,
                0.6386217948717949,
                0.6073717948717949,
                0.5945512820512819,
                0.6378205128205129,
                0.6422275641025641,
                0.4911858974358975,
                0.5
            ],
            "cross_entropy": [
                0.5333619228350892,
                0.6176319966059503,
                0.6528667436699289,
                0.7458745657755351,
                0.7544846458433045,
                0.8056460519544488,
                0.821400464724071,
                0.847067025492969,
                0.8662189324127925,
                0.8891361917078303,


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{
    "deletion": {
        "stepwise_metrics": {
            "auroc": [
                0.7875,
                0.7458333333333332,
                0.6716666666666666,
                0.56,
                0.5683333333333334,
                0.5341666666666667,
                0.4741666666666666,
                0.4916666666666667,
                0.5083333333333333,
                0.4991666666666666,
                0.46333333333333326,
                0.47583333333333333,
                0.48625,
                0.6029166666666667,
                0.5
            ],
            "cross_entropy": [
                0.5893880935369175,
                0.6939687843745195,
                0.739482983649083,
                0.8458936405573276,
                0.845303842443165,
                0.8545702198313575,
                0.8846286320184191,
                0.8939189794493904,
                0.9521357012719036,
                0.9731932698454661,
                1.0225135070200213

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{
    "deletion": {
        "stepwise_metrics": {
            "auroc": [
                0.8387096774193548,
                0.7741935483870968,
                0.7599399849962492,
                0.7351837959489873,
                0.741185296324081,
                0.6954238559639909,
                0.695423855963991,
                0.6969242310577645,
                0.6264066016504126,
                0.6271567891972993,
                0.568642160540135,
                0.632033008252063,
                0.6219054763690922,
                0.5626406601650413,
                0.5
            ],
            "cross_entropy": [
                0.5193068224943925,
                0.5884276777641769,
                0.606017463133041,
                0.6769923018390795,
                0.6740822002959469,
                0.7092773128958064,
                0.7261480090260554,
                0.7375841733114205,
                0.7708949145704211,
                0.7821190731505933,
  

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{
    "deletion": {
        "stepwise_metrics": {
            "auroc": [
                0.8286445012787723,
                0.7468030690537085,
                0.7178175618073316,
                0.6317135549872123,
                0.6351236146632566,
                0.5524296675191817,
                0.4876385336743393,
                0.5831202046035806,
                0.5950554134697357,
                0.5507246376811594,
                0.463768115942029,
                0.46675191815856787,
                0.48508098891730606,
                0.4671781756180733,
                0.5
            ],
            "cross_entropy": [
                0.5247611733780392,
                0.5992980891784528,
                0.6364483622529632,
                0.7030805887135134,
                0.7123429769697089,
                0.7792986083853992,
                0.8037872866271865,
                0.8157944287928084,
                0.8312201302583456,
                0.84997989081101

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{
    "deletion": {
        "stepwise_metrics": {
            "auroc": [
                0.77344336084021,
                0.7531882970742686,
                0.6946736684171043,
                0.6721680420105026,
                0.6984246061515379,
                0.6631657914478619,
                0.5836459114778695,
                0.5656414103525882,
                0.5468867216804201,
                0.5142535633908477,
                0.4343585896474118,
                0.46061515378844714,
                0.672543135783946,
                0.5112528132033008,
                0.5
            ],
            "cross_entropy": [
                0.5584202674130104,
                0.6152823978620233,
                0.6579896115602043,
                0.7018821964848209,
                0.7015595087141857,
                0.7361052111439195,
                0.809865226540508,
                0.816834065043128,
                0.8645194716902611,
                0.8899206691652569,
 

In [49]:
final_results = {
    'deletion': {
        metric: {
            'mean': np.mean(scores),
            'std': np.std(scores)
        }
        for metric, scores in all_scores['deletion'].items()
    },
    'insertion': {
        metric: {
            'mean': np.mean(scores),
            'std': np.std(scores)
        }
        for metric, scores in all_scores['insertion'].items()
    }
}

print("\nFinal Results:")
print(json.dumps(final_results, indent=4))


Final Results:
{
    "deletion": {
        "auroc": {
            "mean": 0.6059979294462116,
            "std": 0.03843201710634803
        },
        "cross_entropy": {
            "mean": 0.7982817151671735,
            "std": 0.045024520350426006
        },
        "brier": {
            "mean": 0.2980920691073037,
            "std": 0.020046173807333888
        }
    },
    "insertion": {
        "auroc": {
            "mean": 0.8602505705878402,
            "std": 0.0224046478922976
        },
        "cross_entropy": {
            "mean": 0.4628575192469359,
            "std": 0.02927429345970162
        },
        "brier": {
            "mean": 0.14556011287089998,
            "std": 0.012544187841407898
        }
    }
}
