Skip to content

Commit 129dcdf

Browse files
authored
Merge pull request #57 from ignavierng/exact_search
Added some unit tests for Astar and DP
2 parents d775822 + 3eb1284 commit 129dcdf

File tree

5 files changed

+100156
-70
lines changed

5 files changed

+100156
-70
lines changed

tests/TestAstar.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,84 @@
1-
import sys
2-
1+
import itertools
2+
import unittest
3+
import hashlib
4+
import numpy as np
5+
from causallearn.graph.Dag import Dag
6+
from causallearn.graph.GraphNode import GraphNode
37
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
8+
from causallearn.utils.DAG2CPDAG import dag2cpdag
49

5-
sys.path.append("")
6-
import unittest
7-
from pickle import load
810

9-
import numpy as np
11+
12+
######################################### Test Notes ###########################################
13+
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
14+
# are obtained from the code of causal-learn as of commit #
15+
# https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). #
16+
# #
17+
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
18+
# So if you find your tests failed, it means that your modified code is logically inconsistent #
19+
# with the code as of 8badb41, but it does not necessarily mean that your code is "wrong". #
20+
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), #
21+
# please report it to us. We will then modify these benchmark results accordingly. Thanks :) #
22+
######################################### Test Notes ###########################################
23+
24+
25+
BENCHMARK_TXTFILE_TO_MD5 = {
26+
"tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt": "1ec70464e4fc68c312adfb7143bd240b",
27+
"tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt": "52a6d3c5db269d5e212edcbb8283aca9",
28+
}
29+
# verify files integrity first
30+
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
31+
with open(file_path, 'rb') as fin:
32+
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\
33+
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData'
1034

1135

1236
class TestAstar(unittest.TestCase):
37+
# Load data and run Astar under every combination of path extension and k-cycle heuristic (k=3).
38+
def test_astar_simulate_linear_gaussian_with_local_score_BIC(self):
39+
# The data and ground truth loaded in this test case are generated by the function
40+
# simulate_linear_gaussian_data_for_exact_search, which is commented out below
41+
print('Now start test_astar_simulate_linear_gaussian_with_local_score_BIC ...')
42+
truth_CPDAG_matrix = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt")
43+
data = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt")
44+
assert truth_CPDAG_matrix.shape[0] == truth_CPDAG_matrix.shape[1], "Should be a square numpy matrix"
45+
num_of_nodes = len(truth_CPDAG_matrix)
46+
assert data.shape[1] == num_of_nodes, "The second dimension of data should be same as number of nodes"
47+
data = data - data.mean(axis=0, keepdims=True) # Center the data
48+
# Iterate over different configurations of path extension and k-cycle heuristic
49+
# to make sure they are working fine
50+
for use_path_extension, use_k_cycle_heuristic in itertools.product([False, True], repeat=2):
51+
DAG_matrix, _ = bic_exact_search(data, search_method='astar', use_path_extension=use_path_extension,
52+
use_k_cycle_heuristic=use_k_cycle_heuristic, k=3)
53+
# Convert DAG adjacency matrix to Dag object
54+
nodes = [GraphNode(str(i)) for i in range(num_of_nodes)]
55+
DAG = Dag(nodes)
56+
for i, j in zip(*np.where(DAG_matrix == 1)):
57+
DAG.add_directed_edge(nodes[i], nodes[j])
58+
CPDAG = dag2cpdag(DAG) # Convert DAG to CPDAG
59+
self.assertTrue(np.all(CPDAG.graph == truth_CPDAG_matrix))
60+
print('test_astar_simulate_linear_gaussian_with_local_score_BIC passed!\n')
61+
1362

14-
# example1
15-
# for data with single-variate dimensions, astar.
16-
def test_single_astar(self):
17-
with open("example_data1.pk", 'rb') as example_data1:
18-
# example_data1 = load(open("example_data1.pk", 'rb'))
19-
example_data1 = load(example_data1)
20-
X = example_data1['X']
21-
X = X - np.tile(np.mean(X, axis=0), (X.shape[0], 1))
22-
X = np.dot(X, np.diag(1 / np.std(X, axis=0)))
23-
X = X[:50, :]
24-
dag_est, search_stats = bic_exact_search(X, search_method='astar')
25-
print(dag_est)
26-
print(search_stats)
27-
28-
# example2
29-
# for data with multi-variate dimensions, astar.
30-
def test_multi_astar(self):
31-
with open("example_data2.pk", 'rb') as example_data:
32-
# example_data = load(open("example_data2.pk", 'rb'))
33-
example_data = load(example_data)
34-
Data_save = example_data['Data_save']
35-
trial = 0
36-
X = Data_save[trial]
37-
X = X - np.tile(np.mean(X, axis=0), (X.shape[0], 1))
38-
X = np.dot(X, np.diag(1 / np.std(X, axis=0)))
39-
X = X[:50, :]
40-
dag_est, search_stats = bic_exact_search(X, search_method='astar')
41-
print(dag_est)
42-
print(search_stats)
63+
# def simulate_linear_gaussian_data_for_exact_search():
64+
# import pandas as pd
65+
# import random
66+
# random.seed(1) # Reproducibility
67+
# np.random.seed(1) # Reproducibility
68+
# num_of_samples = 100000
69+
# # Generate linear Gaussian data
70+
# X0 = np.random.normal(scale=1.0, size=num_of_samples)
71+
# X1 = 0.5 * X0 + np.random.normal(scale=2.0, size=num_of_samples)
72+
# X3 = np.random.normal(scale=0.5, size=num_of_samples)
73+
# X2 = 0.4 * X1 + 0.7 * X3 + np.random.normal(scale=1.5, size=num_of_samples)
74+
# data_df = pd.DataFrame(data={'X0': X0, 'X1': X1, 'X2': X2, 'X3': X3})
75+
# # Ground truth DAG: X0 -> X1 -> X2 <- X3
76+
# # Ground truth CPDAG: X0 -- X1 -> X2 <- X3
77+
# truth_CPDAG_matrix = np.array([[ 0, -1, 0, 0],
78+
# [-1, 0, -1, 0],
79+
# [ 0, 1, 0, 1],
80+
# [ 0, 0, -1, 0]])
81+
# truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix)
82+
# # Save data and ground truth
83+
# truth_CPDAG_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False)
84+
# data_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False)

tests/TestDP.py

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,81 @@
1-
import sys
2-
1+
import unittest
2+
import hashlib
3+
import numpy as np
4+
from causallearn.graph.Dag import Dag
5+
from causallearn.graph.GraphNode import GraphNode
36
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
7+
from causallearn.utils.DAG2CPDAG import dag2cpdag
48

5-
sys.path.append("")
6-
import unittest
7-
from pickle import load
89

9-
import numpy as np
10+
11+
######################################### Test Notes ###########################################
12+
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
13+
# are obtained from the code of causal-learn as of commit #
14+
# https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). #
15+
# #
16+
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
17+
# So if you find your tests failed, it means that your modified code is logically inconsistent #
18+
# with the code as of 8badb41, but it does not necessarily mean that your code is "wrong". #
19+
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), #
20+
# please report it to us. We will then modify these benchmark results accordingly. Thanks :) #
21+
######################################### Test Notes ###########################################
22+
23+
24+
BENCHMARK_TXTFILE_TO_MD5 = {
25+
"tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt": "1ec70464e4fc68c312adfb7143bd240b",
26+
"tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt": "52a6d3c5db269d5e212edcbb8283aca9",
27+
}
28+
# verify files integrity first
29+
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
30+
with open(file_path, 'rb') as fin:
31+
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\
32+
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData'
1033

1134

1235
class TestDP(unittest.TestCase):
13-
# example3
14-
# for data with single-variate dimensions, dp.
15-
def test_single_dp(self):
16-
with open("example_data1.pk", 'rb') as example_data1:
17-
# example_data1 = load(open("example_data1.pk", 'rb'))
18-
example_data1 = load(example_data1)
19-
X = example_data1['X']
20-
X = X - np.tile(np.mean(X, axis=0), (X.shape[0], 1))
21-
X = np.dot(X, np.diag(1 / np.std(X, axis=0)))
22-
X = X[:50, :]
23-
dag_est, search_stats = bic_exact_search(X, search_method='dp')
24-
print(dag_est)
25-
print(search_stats)
26-
27-
# example4
28-
# for data with multi-variate dimensions, dp.
29-
def test_multi_dp(self):
30-
with open("example_data2.pk", 'rb') as example_data:
31-
# example_data = load(open("example_data2.pk", 'rb'))
32-
example_data = load(example_data)
33-
Data_save = example_data['Data_save']
34-
trial = 0
35-
X = Data_save[trial]
36-
X = X - np.tile(np.mean(X, axis=0), (X.shape[0], 1))
37-
X = np.dot(X, np.diag(1 / np.std(X, axis=0)))
38-
X = X[:50, :]
39-
dag_est, search_stats = bic_exact_search(X, search_method='dp')
40-
print(dag_est)
41-
print(search_stats)
36+
# Load data and run DP with path extension both disabled and enabled.
37+
def test_dp_simulate_linear_gaussian_with_local_score_BIC(self):
38+
# The data and ground truth loaded in this test case are generated by the function
39+
# simulate_linear_gaussian_data_for_exact_search, which is commented out below
40+
print('Now start test_dp_simulate_linear_gaussian_with_local_score_BIC ...')
41+
truth_CPDAG_matrix = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt")
42+
data = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt")
43+
assert truth_CPDAG_matrix.shape[0] == truth_CPDAG_matrix.shape[1], "Should be a square numpy matrix"
44+
num_of_nodes = len(truth_CPDAG_matrix)
45+
assert data.shape[1] == num_of_nodes, "The second dimension of data should be same as number of nodes"
46+
data = data - data.mean(axis=0, keepdims=True) # Center the data
47+
# Iterate over different configurations of path extension to make sure they are working fine
48+
for use_path_extension in [False, True]:
49+
DAG_matrix, _ = bic_exact_search(data, search_method='dp', use_path_extension=use_path_extension)
50+
# Convert DAG adjacency matrix to Dag object
51+
nodes = [GraphNode(str(i)) for i in range(num_of_nodes)]
52+
DAG = Dag(nodes)
53+
for i, j in zip(*np.where(DAG_matrix == 1)):
54+
DAG.add_directed_edge(nodes[i], nodes[j])
55+
CPDAG = dag2cpdag(DAG) # Convert DAG to CPDAG
56+
self.assertTrue(np.all(CPDAG.graph == truth_CPDAG_matrix))
57+
print('test_dp_simulate_linear_gaussian_with_local_score_BIC passed!\n')
58+
59+
60+
# def simulate_linear_gaussian_data_for_exact_search():
61+
# import pandas as pd
62+
# import random
63+
# random.seed(1) # Reproducibility
64+
# np.random.seed(1) # Reproducibility
65+
# num_of_samples = 100000
66+
# # Generate linear Gaussian data
67+
# X0 = np.random.normal(scale=1.0, size=num_of_samples)
68+
# X1 = 0.5 * X0 + np.random.normal(scale=2.0, size=num_of_samples)
69+
# X3 = np.random.normal(scale=0.5, size=num_of_samples)
70+
# X2 = 0.4 * X1 + 0.7 * X3 + np.random.normal(scale=1.5, size=num_of_samples)
71+
# data_df = pd.DataFrame(data={'X0': X0, 'X1': X1, 'X2': X2, 'X3': X3})
72+
# # Ground truth DAG: X0 -> X1 -> X2 <- X3
73+
# # Ground truth CPDAG: X0 -- X1 -> X2 <- X3
74+
# truth_CPDAG_matrix = np.array([[ 0, -1, 0, 0],
75+
# [-1, 0, -1, 0],
76+
# [ 0, 1, 0, 1],
77+
# [ 0, 0, -1, 0]])
78+
# truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix)
79+
# # Save data and ground truth
80+
# truth_CPDAG_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False)
81+
# data_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0 -1 0 0
2+
-1 0 -1 0
3+
0 1 0 1
4+
0 0 -1 0

0 commit comments

Comments
 (0)