Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions causallearn/search/ConstraintBased/CDNOD.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True,
uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z',
background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False,
show_progress: bool = True) -> CausalGraph:
show_progress: bool = True, **kwargs) -> CausalGraph:
"""
Causal discovery from nonstationary/heterogeneous data
phase 1: learning causal skeleton,
Expand All @@ -37,16 +37,16 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fis
if mvcdnod:
return mvcdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, correction_name=correction_name,
stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)
else:
return cdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)


def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int,
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
show_progress: bool = True, **kwargs) -> CausalGraph:
"""
Perform Peter-Clark algorithm for causal discovery on the augmented data set that captures the unobserved changing factors

Expand Down Expand Up @@ -85,7 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul

"""
start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable)

# orient the direction from c_indx to X, if there is an edge between c_indx and X
Expand Down Expand Up @@ -127,7 +127,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul


def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int,
uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph:
uc_priority: int, verbose: bool, show_progress: bool, **kwargs) -> CausalGraph:
"""
:param data: data set (numpy ndarray)
:param alpha: desired significance level (float) in (0, 1)
Expand Down Expand Up @@ -156,7 +156,7 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: s
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
# print('Finish detecting the parents of missingness indicators. ')
Expand Down
5 changes: 3 additions & 2 deletions causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ def get_color_edges(graph: Graph) -> List[Edge]:


def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None) -> Tuple[Graph, List[Edge]]:
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None,
**kwargs) -> Tuple[Graph, List[Edge]]:
"""
Perform Fast Causal Inference (FCI) algorithm for causal discovery

Expand Down Expand Up @@ -766,7 +767,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if dataset.shape[0] < dataset.shape[1]:
warnings.warn("The number of features is much larger than the sample size!")

independence_test_method = CIT(dataset, method=independence_test_method)
independence_test_method = CIT(dataset, method=independence_test_method, **kwargs)

## ------- check parameters ------------
if (depth is None) or type(depth) != int:
Expand Down
13 changes: 8 additions & 5 deletions causallearn/search/ConstraintBased/PC.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def pc(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
node_names: List[str] | None = None,
node_names: List[str] | None = None,
**kwargs
):
if data.shape[0] < data.shape[1]:
warnings.warn("The number of features is much larger than the sample size!")
Expand All @@ -40,11 +41,11 @@ def pc(
return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge,
verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)
else:
return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)


def pc_alg(
Expand All @@ -58,6 +59,7 @@ def pc_alg(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
**kwargs
) -> CausalGraph:
"""
Perform Peter-Clark (PC) algorithm for causal discovery
Expand Down Expand Up @@ -98,7 +100,7 @@ def pc_alg(
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress, node_names=node_names)
Expand Down Expand Up @@ -148,6 +150,7 @@ def mvpc_alg(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
**kwargs,
) -> CausalGraph:
"""
Perform missing value Peter-Clark (PC) algorithm for causal discovery
Expand Down Expand Up @@ -192,7 +195,7 @@ def mvpc_alg(
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
# print('Finish detecting the parents of missingness indicators. ')
Expand Down
25 changes: 24 additions & 1 deletion tests/TestPC.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, time
import sys
sys.path.append("")
import unittest
Expand Down Expand Up @@ -330,3 +330,26 @@ def test_pc_load_bnlearn_discrete_datasets(self):
print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}')

print('test_pc_load_bnlearn_discrete_datasets passed!\n')

# Test the usage of local cache checkpoint (check speed).
def test_pc_with_citest_local_checkpoint(self):
print('Now start test_pc_with_citest_local_checkpoint ...')
data_path = "./TestData/data_linear_10.txt"
citest_cache_file = "./TestData/citest_cache_linear_10_first_500_kci.json"

tic = time.time()
data = np.loadtxt(data_path, skiprows=1)[:500]
cg1 = pc(data, 0.05, kci, cache_path=citest_cache_file)
tac = time.time()
print(f'First pc run takes {tac - tic:.3f}s.') # First pc run takes 125.663s.
assert os.path.exists(citest_cache_file), 'Cache file should exist.'

tic = time.time()
data = np.loadtxt(data_path, skiprows=1)[:500]
cg2 = pc(data, 0.05, kci, cache_path=citest_cache_file)
# you might also try other rules of PC, e.g., pc(data, 0.05, kci, True, 0, -1, cache_path=citest_cache_file)
tac = time.time()
print(f'Second pc run takes {tac - tic:.3f}s.') # Second pc run takes 27.316s.
assert np.all(cg1.G.graph == cg2.G.graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

print('test_pc_with_citest_local_checkpoint passed!\n')