Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions causallearn/search/ConstraintBased/CDNOD.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True,
uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z',
background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False,
show_progress: bool = True) -> CausalGraph:
show_progress: bool = True, **kwargs) -> CausalGraph:
"""
Causal discovery from nonstationary/heterogeneous data
phase 1: learning causal skeleton,
Expand All @@ -37,16 +37,16 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fis
if mvcdnod:
return mvcdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, correction_name=correction_name,
stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)
else:
return cdnod_alg(data=data_aug, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)


def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int,
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
show_progress: bool = True, **kwargs) -> CausalGraph:
"""
Perform Peter-Clark algorithm for causal discovery on the augmented data set that captures the unobserved changing factors

Expand Down Expand Up @@ -85,7 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul

"""
start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable)

# orient the direction from c_indx to X, if there is an edge between c_indx and X
Expand Down Expand Up @@ -127,7 +127,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rul


def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int,
uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph:
uc_priority: int, verbose: bool, show_progress: bool, **kwargs) -> CausalGraph:
"""
:param data: data set (numpy ndarray)
:param alpha: desired significance level (float) in (0, 1)
Expand Down Expand Up @@ -156,7 +156,7 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: s
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
# print('Finish detecting the parents of missingness indicators. ')
Expand Down
5 changes: 3 additions & 2 deletions causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ def get_color_edges(graph: Graph) -> List[Edge]:


def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None) -> Tuple[Graph, List[Edge]]:
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None,
**kwargs) -> Tuple[Graph, List[Edge]]:
"""
Perform Fast Causal Inference (FCI) algorithm for causal discovery

Expand Down Expand Up @@ -766,7 +767,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if dataset.shape[0] < dataset.shape[1]:
warnings.warn("The number of features is much larger than the sample size!")

independence_test_method = CIT(dataset, method=independence_test_method)
independence_test_method = CIT(dataset, method=independence_test_method, **kwargs)

## ------- check parameters ------------
if (depth is None) or type(depth) != int:
Expand Down
13 changes: 8 additions & 5 deletions causallearn/search/ConstraintBased/PC.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def pc(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
node_names: List[str] | None = None,
node_names: List[str] | None = None,
**kwargs
):
if data.shape[0] < data.shape[1]:
warnings.warn("The number of features is much larger than the sample size!")
Expand All @@ -40,11 +41,11 @@ def pc(
return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge,
verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)
else:
return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, **kwargs)


def pc_alg(
Expand All @@ -58,6 +59,7 @@ def pc_alg(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
**kwargs
) -> CausalGraph:
"""
Perform Peter-Clark (PC) algorithm for causal discovery
Expand Down Expand Up @@ -98,7 +100,7 @@ def pc_alg(
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress, node_names=node_names)
Expand Down Expand Up @@ -148,6 +150,7 @@ def mvpc_alg(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
**kwargs,
) -> CausalGraph:
"""
Perform missing value Peter-Clark (PC) algorithm for causal discovery
Expand Down Expand Up @@ -192,7 +195,7 @@ def mvpc_alg(
"""

start = time.time()
indep_test = CIT(data, indep_test)
indep_test = CIT(data, indep_test, **kwargs)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
# print('Finish detecting the parents of missingness indicators. ')
Expand Down
38 changes: 38 additions & 0 deletions causallearn/utils/cit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
kci = "kci"
chisq = "chisq"
gsq = "gsq"
d_separation = "d_separation"


def CIT(data, method='fisherz', **kwargs):
Expand All @@ -37,6 +38,8 @@ def CIT(data, method='fisherz', **kwargs):
return MV_FisherZ(data, **kwargs)
elif method == mc_fisherz:
return MC_FisherZ(data, **kwargs)
elif method == d_separation:
return D_Separation(data, **kwargs)
else:
raise ValueError("Unknown method: {}".format(method))

Expand Down Expand Up @@ -450,3 +453,38 @@ def __call__(self, X, Y, condition_set, skel, prt_m):

virtual_cit = MV_FisherZ(data_vir)
return virtual_cit(0, 1, tuple(cond_set_bgn_0))


class D_Separation(CIT_Base):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is really hacky... Normally we shouldn't write hacky code like this.

D_separation is not a Conditional Independence Test, right? We probably shouldn't inherit CIT_base class.

You can think about this OOP design --- normally we need to follow the logic, if D-Separation and CIT has something in common (like here, data is the same, and you call want to return some value), you can add another layer of abstraction under CIT_base.

We need to strictly follow the logical structure in code whenever possible. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much. I see your point! (Sorry I just saw your message..

My current codes are mainly for convenience - so that we can call d-separation just as if we call a citest (same as fisherz or kci). D-separation indeed has many things in common with citest (e.g., i/o), though yes, logically it is not a citest.

By "another layer of abstraction under CIT_base", are you suggesting something like CIT_base -> D_separation_base -> D_separation? Then this looks almost the same as CIT_base -> D_separation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOP is not just chain, it should be a DAG.

For example, you can design things like:

Data_base -> CIT_base -> ....
Data_base -> D_separation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see!

A bit confused: then maybe we'll have to move all of our functionalities in CIT_base (e.g., input check, cache, etc) to Data_base - while they are not attributes about data, and D_separation requires no data?

To me, here D_separation is more like a duck type? Though in definition, it is NOT a citest (not a statistical one but a graphical one), in our context (to test the algorithm's correctness), we call it, use it and evaluate it all like a citest.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

I really didn't think much --- yeah, you are right, D-separation requires no data.

My point is just: in OOP, just think about what needs to be abstracted and shared.

So what's common between D_separation and CIT_base? The cache, and other related utils. Then probably make a base named Cache_base, and maybe you can design things like:

Cache_base -> CIT_base -> ....
Cache_base -> D_separation

Usually don't inherit things that you don't need and don't make OOP design not consistent with the logical structure (hacky like this will usually create troubles in the future.).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx @tofuwen. Cool! Cache_base -> CIT_base -> .... and Cache_base -> D_separation now looks logically reasonable. Though practically I still have this concern:

What is the difference set CIT_base\Cache_base? In other words, what is something shared by FisherZ and Chisq, but not used in D_separation? There is only one thing, data.

Therefore, Cache_base should contain cache-related utilities and input/output checks, and CIT_base\Cache_base should be only about data. However, if we do so, some problems arise:

  • Cache-related utilities and data are relatively coupled (e.g., data hash check, parameters check). It may not be clean or easy to decouple the two.
  • Cache_base - or naming it more accurately, e.g., Cache_for_constraint_base - is it something deserving a base class treatment, or just some utility functions belonging to the CITs? It's natural to understand KCI as a child class of CIT_base, but it seems weird to see CIT_base as a child class of Cache_base. Even without cache, CIT is still CIT.
  • After all, d-separation is here implemented only to check the algorithms' correctness in tests. It is not some key function for users. So, is it really necessary to do the above refactor (separate out a Cache_base) only for d-separation - while sacrificing all the other main-function parts (mentioned above in the 2nd point)?
  • And, can d-separation be seen as a citest? Maybe it depends. On how general we see them. For example,

    d-separation is a criterion for deciding, from a given a causal graph, whether a set X of variables is independent of another set Y, given a third set Z.
    (http://bayes.cs.ucla.edu/BOOK-2K/d-sep.html)

I will think more about how to put d-separation in our package in a both logically reasonable and functionally clean way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree with you for the most part. I think you convinced me: I agree that my suggestions seem to add lots of extra work to make the (not very necessary design) better, which I don't think justify the increasing complexity here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Thanks so much for this! The separated class for d-separation that you suggested would still be the perfect one, as long as we had enough time - maybe in a future refactor on citest.

def __init__(self, data, true_dag=None, **kwargs):
'''
Use d-separation as CI test, to ensure the correctness of constraint-based methods. (only used for tests)
Parameters
----------
data: numpy.ndarray, just a placeholder, not used in D_Separation
true_dag: nx.DiGraph object, the true DAG
'''
super().__init__(data, **kwargs) # data is just a placeholder, not used in D_Separation
self.check_cache_method_consistent('d_separation', NO_SPECIFIED_PARAMETERS_MSG)
self.true_dag = true_dag
import networkx as nx; global nx
# import networkx here violates PEP8; but we want to prevent unnecessary import at the top (it's only used here)

def __call__(self, X, Y, condition_set=None):
Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key]
p = float(nx.d_separated(self.true_dag, {Xs[0]}, {Ys[0]}, set(condition_set)))
# pvalue is bool here: 1 if is_d_separated and 0 otherwise. So heuristic comparison-based uc_rules will not work.

# here we use networkx's d_separation implementation.
# an alternative is to use causal-learn's own d_separation implementation in graph class:
# self.true_dag.is_dseparated_from(
# self.true_dag.nodes[Xs[0]], self.true_dag.nodes[Ys[0]], [self.true_dag.nodes[_] for _ in condition_set])
# where self.true_dag is an instance of GeneralGrpah class.
# I have checked the two implementations: they are equivalent (when the graph is DAG),
# and generally causal-learn's implementation is faster.
# but just for now, I still use networkx's, for two reasons:
# 1. causal-learn's implementation sometimes stops working during run (haven't check detailed reasons)
# 2. GeneralGraph class will be hugely refactored in the near future.
self.pvalue_cache[cache_key] = p
return p
57 changes: 55 additions & 2 deletions tests/TestPC.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import os, time
import sys
sys.path.append("")
import unittest
import hashlib
import numpy as np
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz
from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz, d_separation
from causallearn.graph.SHD import SHD
from causallearn.utils.DAG2CPDAG import dag2cpdag
from causallearn.utils.TXT2GeneralGraph import txt2generalgraph
Expand Down Expand Up @@ -330,3 +330,56 @@ def test_pc_load_bnlearn_discrete_datasets(self):
print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}')

print('test_pc_load_bnlearn_discrete_datasets passed!\n')

# Test the usage of local cache checkpoint (check speed).
def test_pc_with_citest_local_checkpoint(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the tests here!

Cheers! Now we can guarantee our PC algorithm is indeed written correctly. Great work!! :)

print('Now start test_pc_with_citest_local_checkpoint ...')
data_path = "./TestData/data_linear_10.txt"
citest_cache_file = "./TestData/citest_cache_linear_10_first_500_kci.json"

tic = time.time()
data = np.loadtxt(data_path, skiprows=1)[:500]
cg1 = pc(data, 0.05, kci, cache_path=citest_cache_file)
tac = time.time()
print(f'First pc run takes {tac - tic:.3f}s.') # First pc run takes 125.663s.
assert os.path.exists(citest_cache_file), 'Cache file should exist.'

tic = time.time()
data = np.loadtxt(data_path, skiprows=1)[:500]
cg2 = pc(data, 0.05, kci, cache_path=citest_cache_file)
# you might also try other rules of PC, e.g., pc(data, 0.05, kci, True, 0, -1, cache_path=citest_cache_file)
tac = time.time()
print(f'Second pc run takes {tac - tic:.3f}s.') # Second pc run takes 27.316s.
assert np.all(cg1.G.graph == cg2.G.graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

print('test_pc_with_citest_local_checkpoint passed!\n')

# Test graphs in bnlearn repository with d-separation as cit. Ensure PC's correctness.
def test_pc_load_bnlearn_graphs_with_d_separation(self):
import networkx as nx
print('Now start test_pc_load_bnlearn_graphs_with_d_separation ...')
benchmark_names = [
"asia", "cancer", "earthquake", "sachs", "survey",
"alarm", "barley", "child", "insurance", "water",
"hailfinder", "hepar2", "win95pts",
]
bnlearn_truth_dag_graph_dir = './TestData/bnlearn_discrete_10000/truth_dag_graph'
for bname in benchmark_names:
truth_dag = txt2generalgraph(os.path.join(bnlearn_truth_dag_graph_dir, f'{bname}.graph.txt'))
truth_cpdag = dag2cpdag(truth_dag)
num_edges_in_truth = truth_dag.get_num_edges()
num_nodes_in_truth = truth_dag.get_num_nodes()

true_dag_netx = nx.DiGraph()
true_dag_netx.add_nodes_from(list(range(num_nodes_in_truth)))
true_dag_netx.add_edges_from(set(map(tuple, np.argwhere(truth_dag.graph.T > 0))))

data = np.zeros((100, len(truth_dag.nodes))) # just a placeholder
cg = pc(data, 0.05, d_separation, True, 0, -1, true_dag=true_dag_netx)
shd = SHD(truth_cpdag, cg.G)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one final questions: why we don't have assert here?

Should we assert shd = 0 here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, we should! Thanks for pointing this out!

self.assertEqual(0, shd, "PC with d-separation as CIT returns an inaccurate CPDAG.")
print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}')

print('test_pc_load_bnlearn_graphs_with_d_separation passed!\n')