Use d-separation as CIT in tests, to ensure PC's correctness #65
Changes from all commits
1068ffe
9716b59
d75b215
aa7f3c2
9ca56c9
25bdf4e
2e1c265
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,11 @@ | ||
| import os | ||
| import os, time | ||
| import sys | ||
| sys.path.append("") | ||
| import unittest | ||
| import hashlib | ||
| import numpy as np | ||
| from causallearn.search.ConstraintBased.PC import pc | ||
| from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz | ||
| from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz, d_separation | ||
| from causallearn.graph.SHD import SHD | ||
| from causallearn.utils.DAG2CPDAG import dag2cpdag | ||
| from causallearn.utils.TXT2GeneralGraph import txt2generalgraph | ||
|
|
@@ -330,3 +330,56 @@ def test_pc_load_bnlearn_discrete_datasets(self): | |
| print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}') | ||
|
|
||
| print('test_pc_load_bnlearn_discrete_datasets passed!\n') | ||
|
|
||
| # Test the usage of local cache checkpoint (check speed). | ||
| def test_pc_with_citest_local_checkpoint(self): | ||
|
Contributor
I really like the tests here! Cheers! Now we can guarantee our PC algorithm is indeed written correctly. Great work!! :) |
||
| print('Now start test_pc_with_citest_local_checkpoint ...') | ||
| data_path = "./TestData/data_linear_10.txt" | ||
| citest_cache_file = "./TestData/citest_cache_linear_10_first_500_kci.json" | ||
|
|
||
| tic = time.time() | ||
| data = np.loadtxt(data_path, skiprows=1)[:500] | ||
| cg1 = pc(data, 0.05, kci, cache_path=citest_cache_file) | ||
| tac = time.time() | ||
| print(f'First pc run takes {tac - tic:.3f}s.') # First pc run takes 125.663s. | ||
| assert os.path.exists(citest_cache_file), 'Cache file should exist.' | ||
|
|
||
| tic = time.time() | ||
| data = np.loadtxt(data_path, skiprows=1)[:500] | ||
| cg2 = pc(data, 0.05, kci, cache_path=citest_cache_file) | ||
| # you might also try other rules of PC, e.g., pc(data, 0.05, kci, True, 0, -1, cache_path=citest_cache_file) | ||
| tac = time.time() | ||
| print(f'Second pc run takes {tac - tic:.3f}s.') # Second pc run takes 27.316s. | ||
| assert np.all(cg1.G.graph == cg2.G.graph), INCONSISTENT_RESULT_GRAPH_ERRMSG | ||
|
|
||
| print('test_pc_with_citest_local_checkpoint passed!\n') | ||
|
|
||
| # Test graphs in bnlearn repository with d-separation as cit. Ensure PC's correctness. | ||
| def test_pc_load_bnlearn_graphs_with_d_separation(self): | ||
| import networkx as nx | ||
| print('Now start test_pc_load_bnlearn_graphs_with_d_separation ...') | ||
| benchmark_names = [ | ||
| "asia", "cancer", "earthquake", "sachs", "survey", | ||
| "alarm", "barley", "child", "insurance", "water", | ||
| "hailfinder", "hepar2", "win95pts", | ||
| ] | ||
| bnlearn_truth_dag_graph_dir = './TestData/bnlearn_discrete_10000/truth_dag_graph' | ||
| for bname in benchmark_names: | ||
| truth_dag = txt2generalgraph(os.path.join(bnlearn_truth_dag_graph_dir, f'{bname}.graph.txt')) | ||
| truth_cpdag = dag2cpdag(truth_dag) | ||
| num_edges_in_truth = truth_dag.get_num_edges() | ||
| num_nodes_in_truth = truth_dag.get_num_nodes() | ||
|
|
||
| true_dag_netx = nx.DiGraph() | ||
| true_dag_netx.add_nodes_from(list(range(num_nodes_in_truth))) | ||
| true_dag_netx.add_edges_from(set(map(tuple, np.argwhere(truth_dag.graph.T > 0)))) | ||
|
|
||
| data = np.zeros((100, len(truth_dag.nodes))) # just a placeholder | ||
| cg = pc(data, 0.05, d_separation, True, 0, -1, true_dag=true_dag_netx) | ||
| shd = SHD(truth_cpdag, cg.G) | ||
|
Contributor
One final question: why don't we have an assert here? Should we assert shd == 0 here?
Collaborator
Author
Oh yes, we should! Thanks for pointing this out! |
||
| self.assertEqual(0, shd.get_shd(), "PC with d-separation as CIT returns an inaccurate CPDAG.") | ||
| print(f'{bname} ({num_nodes_in_truth} nodes/{num_edges_in_truth} edges): used {cg.PC_elapsed:.5f}s, SHD: {shd.get_shd()}') | ||
|
|
||
| print('test_pc_load_bnlearn_graphs_with_d_separation passed!\n') | ||
|
|
||
|
|
||
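The new test swaps the statistical CI test for d-separation on the known true DAG, so PC receives perfect independence answers and should recover the true CPDAG exactly (SHD of 0). Below is a minimal, library-free sketch of such an oracle, using the ancestral moral graph criterion. The function name `is_d_separated` and the edge-list input format are illustrative assumptions, not the causal-learn API.

```python
from collections import deque


def is_d_separated(edges, x, y, z):
    """Return True if x and y are d-separated given the set z.

    edges: iterable of (parent, child) pairs describing a DAG.
    Hypothetical helper for illustration; not causal-learn's implementation.
    """
    z = set(z)
    parents = {}
    for p, c in edges:
        parents.setdefault(c, set()).add(p)

    # Step 1: keep only x, y, z and all of their ancestors.
    keep, stack = set(), [x, y, *z]
    while stack:
        node = stack.pop()
        if node not in keep:
            keep.add(node)
            stack.extend(parents.get(node, ()))

    # Step 2: moralize. Connect each node to its parents, connect
    # co-parents of a common child, and drop edge directions.
    adj = {n: set() for n in keep}
    for child in keep:
        ps = list(parents.get(child, ()))
        for p in ps:
            adj[child].add(p)
            adj[p].add(child)
        for i, p in enumerate(ps):
            for q in ps[i + 1:]:
                adj[p].add(q)
                adj[q].add(p)

    # Step 3: x and y are d-separated iff z blocks every undirected path.
    seen, queue = {x}, deque([x])
    while queue:
        node = queue.popleft()
        if node == y:
            return False
        for nb in adj[node]:
            if nb not in seen and nb not in z:
                seen.add(nb)
                queue.append(nb)
    return True
```

For a chain 0 -> 1 -> 2, conditioning on node 1 separates 0 and 2; for a collider 0 -> 2 <- 1, conditioning on node 2 (or on a descendant of it) connects 0 and 1.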
This is really hacky... Normally we shouldn't write hacky code like this.
D_separation is not a Conditional Independence Test, right? We probably shouldn't inherit from the CIT_base class.
Think about the OOP design: normally the class hierarchy needs to follow the logic. If D-Separation and CIT have something in common (like here: the data is the same, and you also want to return some value), you can add another layer of abstraction under CIT_base.
We need to strictly follow the logical structure in code whenever possible. :)
Thanks so much. I see your point! (Sorry, I just saw your message.)
My current code is mainly for convenience, so that we can call d-separation just as we call a citest (same as fisherz or kci). D-separation indeed has many things in common with a citest (e.g., I/O), though yes, logically it is not a citest.
By "another layer of abstraction under CIT_base", are you suggesting something like
CIT_base -> D_separation_base -> D_separation? Then this looks almost the same as CIT_base -> D_separation?
OOP inheritance is not just a chain; it can be a DAG.
For example, you can design things like:
Data_base -> CIT_base -> ....
Data_base -> D_separation
I see!
I'm a bit confused: then maybe we'd have to move all of the functionality in CIT_base (e.g., input checks, cache, etc.) to Data_base, even though those are not attributes of the data, and D_separation requires no data?
To me, D_separation here is more like a duck type. Though by definition it is NOT a citest (it is a graphical criterion, not a statistical one), in our context (testing the algorithm's correctness) we call it, use it, and evaluate it just like a citest.
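The duck-typing argument can be sketched roughly as follows: the search routine only needs a callable that maps (X, Y, condition set) to a p-value, so a graphical oracle can stand in for a statistical test without sharing a base class. All names below (`CITest`, `DSeparationOracle`, `edge_removable`) are hypothetical and not part of causal-learn, and the oracle here just looks up a hand-written table of independences.

```python
from typing import Callable, Iterable, Tuple

# Anything with this call signature "quacks" like a CI test (hypothetical alias).
CITest = Callable[[int, int, Tuple[int, ...]], float]


class DSeparationOracle:
    """Graphical stand-in for a CI test, backed by a table of known independences."""

    def __init__(self, independences: Iterable[Tuple[int, int, Tuple[int, ...]]]):
        # Normalize keys so that (x, y) and (y, x) hit the same entry.
        self._known = {
            (min(x, y), max(x, y), tuple(sorted(s))) for x, y, s in independences
        }

    def __call__(self, x: int, y: int, cond: Tuple[int, ...]) -> float:
        key = (min(x, y), max(x, y), tuple(sorted(cond)))
        # Return 1.0 for "independent" and 0.0 for "dependent", like a p-value.
        return 1.0 if key in self._known else 0.0


def edge_removable(ci_test: CITest, x: int, y: int, cond: Tuple[int, ...],
                   alpha: float = 0.05) -> bool:
    """Caller-side logic: it never checks what kind of object ci_test is."""
    return ci_test(x, y, cond) > alpha
```

With this shape, `edge_removable` accepts the oracle, a plain function, or a statistical test object interchangeably, which is the convenience the author describes.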
I see.
I really didn't think it through. Yeah, you are right: D-separation requires no data.
My point is just: in OOP, just think about what needs to be abstracted and shared.
So what's common between D_separation and CIT_base? The cache, and other related utils. Then probably make a base named Cache_base, and maybe you can design things like:
Cache_base -> CIT_base -> ....
Cache_base -> D_separation
In general, don't inherit things that you don't need, and don't make the OOP design inconsistent with the logical structure (hacks like this usually create trouble in the future).
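A minimal sketch of this proposed hierarchy, assuming the shared part is only a result cache. The class names Cache_base, CIT_base, and D_separation follow the comment; every method name and body (`cached`, `num_computed`, the placeholder `ToyCIT`) is an illustrative assumption, not the actual causal-learn code.

```python
class Cache_base:
    """Shared utility: memoize results keyed by (X, Y, condition set)."""

    def __init__(self):
        self.pvalue_cache = {}
        self.num_computed = 0  # number of cache misses so far

    def cached(self, x, y, cond, compute):
        # Symmetric key: (x, y) and (y, x) share one cache entry.
        key = (min(x, y), max(x, y), frozenset(cond))
        if key not in self.pvalue_cache:
            self.pvalue_cache[key] = compute()
            self.num_computed += 1
        return self.pvalue_cache[key]


class CIT_base(Cache_base):
    """Statistical tests: the only addition over Cache_base is the data."""

    def __init__(self, data):
        super().__init__()
        self.data = data


class ToyCIT(CIT_base):
    """Placeholder subclass; returns a constant instead of a real statistic."""

    def __call__(self, x, y, cond=()):
        return self.cached(x, y, cond, lambda: 0.5)


class D_separation(Cache_base):
    """Graphical oracle: needs a d-separation predicate but no data."""

    def __init__(self, is_d_separated):
        super().__init__()
        self.is_d_separated = is_d_separated  # callable (x, y, cond) -> bool

    def __call__(self, x, y, cond=()):
        return self.cached(
            x, y, cond, lambda: float(self.is_d_separated(x, y, cond))
        )
```

In this layout D_separation carries no `data` attribute, while both branches reuse the same caching code.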
Thx @tofuwen. Cool!
Cache_base -> CIT_base -> .... and Cache_base -> D_separation now looks logically reasonable. Though practically I still have this concern:
What is the difference set CIT_base\Cache_base? In other words, what is something shared by FisherZ and Chisq, but not used in D_separation? There is only one thing, data.
Therefore, Cache_base should contain cache-related utilities and input/output checks, and CIT_base\Cache_base should be only about data. However, if we do so, some problems arise:
Cache_base - or naming it more accurately, e.g., Cache_for_constraint_base - is it something deserving a base class treatment, or just some utility functions belonging to the CITs? It's natural to understand KCI as a child class of CIT_base, but it seems weird to see CIT_base as a child class of Cache_base. Even without cache, CIT is still CIT.
... Cache_base) only for d-separation - while sacrificing all the other main-function parts (mentioned above in the 2nd point)?
I will think more about how to put d-separation in our package in a way that is both logically reasonable and functionally clean.
Yeah, I agree with you for the most part. I think you convinced me: my suggestions seem to add a lot of extra work to improve a design that is not strictly necessary, and I don't think that justifies the increased complexity here.
Cool. Thanks so much for this! The separate class for d-separation that you suggested would still be the ideal design if we had enough time; maybe in a future refactor of citest.