diff --git a/causallearn/search/ConstraintBased/CDNOD.py b/causallearn/search/ConstraintBased/CDNOD.py index b6f93444..98b18579 100644 --- a/causallearn/search/ConstraintBased/CDNOD.py +++ b/causallearn/search/ConstraintBased/CDNOD.py @@ -16,7 +16,7 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True, uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z', background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False, - show_progress: bool = True) -> CausalGraph: + show_progress: bool = True, **kwargs) -> CausalGraph: """ Causal discovery from nonstationary/heterogeneous data phase 1: learning causal skeleton, @@ -37,16 +37,16 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fis if mvcdnod: return mvcdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, verbose=verbose, - show_progress=show_progress) + show_progress=show_progress, **kwargs) else: return cdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress) + show_progress=show_progress, **kwargs) def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int, background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False, - show_progress: bool = True) -> CausalGraph: + show_progress: bool = True, **kwargs) -> CausalGraph: """ Perform Peter-Clark algorithm for causal discovery on the augmented data set that captures the unobserved changing factors @@ -85,7 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul """ start = time.time() - indep_test = CIT(data, indep_test) + indep_test = CIT(data, indep_test, **kwargs) cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable) # orient the direction from c_indx to X, if there is an edge between c_indx and X @@ -127,7 +127,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int, - uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph: + uc_priority: int, verbose: bool, show_progress: bool, **kwargs) -> CausalGraph: """ :param data: data set (numpy ndarray) :param alpha: desired significance level (float) in (0, 1) @@ -156,7 +156,7 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: s """ start = time.time() - indep_test = CIT(data, indep_test) + indep_test = CIT(data, indep_test, **kwargs) ## Step 1: detect the direct causes of missingness indicators prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable) # print('Finish detecting the parents of missingness indicators. ') diff --git a/causallearn/search/ConstraintBased/FCI.py b/causallearn/search/ConstraintBased/FCI.py index 2705a485..ac8d1232 100644 --- a/causallearn/search/ConstraintBased/FCI.py +++ b/causallearn/search/ConstraintBased/FCI.py @@ -729,7 +729,8 @@ def get_color_edges(graph: Graph) -> List[Edge]: def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1, - max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None) -> Tuple[Graph, List[Edge]]: + max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, + **kwargs) -> Tuple[Graph, List[Edge]]: """ Perform Fast Causal Inference (FCI) algorithm for causal discovery @@ -766,7 +767,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = if dataset.shape[0] < dataset.shape[1]: warnings.warn("The number of features is much larger than the sample size!") - independence_test_method = CIT(dataset, method=independence_test_method) + independence_test_method = CIT(dataset, method=independence_test_method, **kwargs) ## ------- check parameters ------------ if (depth is None) or type(depth) != int: diff --git a/causallearn/search/ConstraintBased/PC.py b/causallearn/search/ConstraintBased/PC.py index 17cdeddf..f48a778d 100644 --- a/causallearn/search/ConstraintBased/PC.py +++ b/causallearn/search/ConstraintBased/PC.py @@ -29,7 +29,8 @@ def pc( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, - node_names: List[str] | None = None, + node_names: List[str] | None = None, + **kwargs ): if data.shape[0] < data.shape[1]: warnings.warn("The number of features is much larger than the sample size!") @@ -40,11 +41,11 @@ def pc( return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress) + show_progress=show_progress, **kwargs) else: return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress) + show_progress=show_progress, **kwargs) def pc_alg( @@ -58,6 +59,7 @@ def pc_alg( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, + **kwargs ) -> CausalGraph: """ Perform Peter-Clark (PC) algorithm for causal discovery @@ -98,7 +100,7 @@ def pc_alg( """ start = time.time() - indep_test = CIT(data, indep_test) + indep_test = CIT(data, indep_test, **kwargs) cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable, background_knowledge=background_knowledge, verbose=verbose, show_progress=show_progress, node_names=node_names) @@ -148,6 +150,7 @@ def mvpc_alg( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, + **kwargs, ) -> CausalGraph: """ Perform missing value Peter-Clark (PC) algorithm for causal discovery @@ -192,7 +195,7 @@ def mvpc_alg( """ start = time.time() - indep_test = CIT(data, indep_test) + indep_test = CIT(data, indep_test, **kwargs) ## Step 1: detect the direct causes of missingness indicators prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable) # print('Finish detecting the parents of missingness indicators. ') diff --git a/causallearn/utils/cit.py b/causallearn/utils/cit.py index 174cdd2a..9c412c4c 100644 --- a/causallearn/utils/cit.py +++ b/causallearn/utils/cit.py @@ -15,6 +15,7 @@ kci = "kci" chisq = "chisq" gsq = "gsq" +d_separation = "d_separation" def CIT(data, method='fisherz', **kwargs): @@ -37,6 +38,8 @@ def CIT(data, method='fisherz', **kwargs): return MV_FisherZ(data, **kwargs) elif method == mc_fisherz: return MC_FisherZ(data, **kwargs) + elif method == d_separation: + return D_Separation(data, **kwargs) else: raise ValueError("Unknown method: {}".format(method)) @@ -450,3 +453,38 @@ def __call__(self, X, Y, condition_set, skel, prt_m): virtual_cit = MV_FisherZ(data_vir) return virtual_cit(0, 1, tuple(cond_set_bgn_0)) + + +class D_Separation(CIT_Base): + def __init__(self, data, true_dag=None, **kwargs): + ''' + Use d-separation as CI test, to ensure the correctness of constraint-based methods. (only used for tests) + Parameters + ---------- + data: numpy.ndarray, just a placeholder, not used in D_Separation + true_dag: nx.DiGraph object, the true DAG + ''' + super().__init__(data, **kwargs) # data is just a placeholder, not used in D_Separation + self.check_cache_method_consistent('d_separation', NO_SPECIFIED_PARAMETERS_MSG) + self.true_dag = true_dag + import networkx as nx; global nx + # import networkx here violates PEP8; but we want to prevent unnecessary import at the top (it's only used here) + + def __call__(self, X, Y, condition_set=None): + Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set) + if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key] + p = float(nx.d_separated(self.true_dag, {Xs[0]}, {Ys[0]}, set(condition_set))) + # pvalue is bool here: 1 if is_d_separated and 0 otherwise. So heuristic comparison-based uc_rules will not work. + + # here we use networkx's d_separation implementation. + # an alternative is to use causal-learn's own d_separation implementation in graph class: + # self.true_dag.is_dseparated_from( + # self.true_dag.nodes[Xs[0]], self.true_dag.nodes[Ys[0]], [self.true_dag.nodes[_] for _ in condition_set]) + # where self.true_dag is an instance of GeneralGrpah class. + # I have checked the two implementations: they are equivalent (when the graph is DAG), + # and generally causal-learn's implementation is faster. + # but just for now, I still use networkx's, for two reasons: + # 1. causal-learn's implementation sometimes stops working during run (haven't check detailed reasons) + # 2. GeneralGraph class will be hugely refactored in the near future. + self.pvalue_cache[cache_key] = p + return p diff --git a/tests/TestPC.py b/tests/TestPC.py index 9450835a..54f6f973 100644 --- a/tests/TestPC.py +++ b/tests/TestPC.py @@ -1,11 +1,11 @@ -import os +import os, time import sys sys.path.append("") import unittest import hashlib import numpy as np from causallearn.search.ConstraintBased.PC import pc -from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz +from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz, d_separation from causallearn.graph.SHD import SHD from causallearn.utils.DAG2CPDAG import dag2cpdag from causallearn.utils.TXT2GeneralGraph import txt2generalgraph @@ -330,3 +330,56 @@ def test_pc_load_bnlearn_discrete_datasets(self): print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}') print('test_pc_load_bnlearn_discrete_datasets passed!\n') + + # Test the usage of local cache checkpoint (check speed). + def test_pc_with_citest_local_checkpoint(self): + print('Now start test_pc_with_citest_local_checkpoint ...') + data_path = "./TestData/data_linear_10.txt" + citest_cache_file = "./TestData/citest_cache_linear_10_first_500_kci.json" + + tic = time.time() + data = np.loadtxt(data_path, skiprows=1)[:500] + cg1 = pc(data, 0.05, kci, cache_path=citest_cache_file) + tac = time.time() + print(f'First pc run takes {tac - tic:.3f}s.') # First pc run takes 125.663s. + assert os.path.exists(citest_cache_file), 'Cache file should exist.' + + tic = time.time() + data = np.loadtxt(data_path, skiprows=1)[:500] + cg2 = pc(data, 0.05, kci, cache_path=citest_cache_file) + # you might also try other rules of PC, e.g., pc(data, 0.05, kci, True, 0, -1, cache_path=citest_cache_file) + tac = time.time() + print(f'Second pc run takes {tac - tic:.3f}s.') # Second pc run takes 27.316s. + assert np.all(cg1.G.graph == cg2.G.graph), INCONSISTENT_RESULT_GRAPH_ERRMSG + + print('test_pc_with_citest_local_checkpoint passed!\n') + + # Test graphs in bnlearn repository with d-separation as cit. Ensure PC's correctness. + def test_pc_load_bnlearn_graphs_with_d_separation(self): + import networkx as nx + print('Now start test_pc_load_bnlearn_graphs_with_d_separation ...') + benchmark_names = [ + "asia", "cancer", "earthquake", "sachs", "survey", + "alarm", "barley", "child", "insurance", "water", + "hailfinder", "hepar2", "win95pts", + ] + bnlearn_truth_dag_graph_dir = './TestData/bnlearn_discrete_10000/truth_dag_graph' + for bname in benchmark_names: + truth_dag = txt2generalgraph(os.path.join(bnlearn_truth_dag_graph_dir, f'{bname}.graph.txt')) + truth_cpdag = dag2cpdag(truth_dag) + num_edges_in_truth = truth_dag.get_num_edges() + num_nodes_in_truth = truth_dag.get_num_nodes() + + true_dag_netx = nx.DiGraph() + true_dag_netx.add_nodes_from(list(range(num_nodes_in_truth))) + true_dag_netx.add_edges_from(set(map(tuple, np.argwhere(truth_dag.graph.T > 0)))) + + data = np.zeros((100, len(truth_dag.nodes))) # just a placeholder + cg = pc(data, 0.05, d_separation, True, 0, -1, true_dag=true_dag_netx) + shd = SHD(truth_cpdag, cg.G) + self.assertEqual(0, shd, "PC with d-separation as CIT returns an inaccurate CPDAG.") + print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}') + + print('test_pc_load_bnlearn_graphs_with_d_separation passed!\n') + +