Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 0 additions & 81 deletions tests/TestDP.py

This file was deleted.

33 changes: 18 additions & 15 deletions tests/TestAstar.py → tests/TestExactSearch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import unittest
import hashlib
import numpy as np
Expand All @@ -12,12 +11,12 @@
######################################### Test Notes ###########################################
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
# are obtained from the code of causal-learn as of commit #
# https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). #
# https://github.com/cmu-phil/causal-learn/commit/129dcdf (07-14-2022). #
# #
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
# So if you find your tests failed, it means that your modified code is logically inconsistent #
# with the code as of 8badb41, but not necessarily means that your code is "wrong". #
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), #
# with the code as of 129dcdf, but not necessarily means that your code is "wrong". #
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 129dcdf), #
# please report it to us. We will then modify these benchmark results accordingly. Thanks :) #
######################################### Test Notes ###########################################

Expand All @@ -30,25 +29,29 @@
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
with open(file_path, 'rb') as fin:
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData'
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/129dcdf/tests/TestData'


class TestAstar(unittest.TestCase):
# Load data and run Astar with default parameters.
def test_astar_simulate_linear_gaussian_with_local_score_BIC(self):
class TestExactSearch(unittest.TestCase):
# Load data and run exact search.
def test_exact_search_simulate_linear_gaussian_with_local_score_BIC(self):
# The data and ground truth loaded in this test case is generated by the function
# simulate_linear_gaussian_data_for_exact_search commented below
print('Now start test_astar_simulate_linear_gaussian_with_local_score_BIC ...')
print('Now start test_exact_search_simulate_linear_gaussian_with_local_score_BIC ...')
truth_CPDAG_matrix = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt")
data = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt")
assert truth_CPDAG_matrix.shape[0] == truth_CPDAG_matrix.shape[1], "Should be a square numpy matrix"
num_of_nodes = len(truth_CPDAG_matrix)
assert data.shape[1] == num_of_nodes, "The second dimension of data should be same as number of nodes"
data = data - data.mean(axis=0, keepdims=True) # Center the data
# Iterate over different configurations of path extension and k-cycle heuristic
# Iterate over different configurations of search method, path extension, and k-cycle heuristic
# to make sure they are working fine
for use_path_extension, use_k_cycle_heuristic in itertools.product([False, True], repeat=2):
DAG_matrix, _ = bic_exact_search(data, search_method='astar', use_path_extension=use_path_extension,
configs = [('astar', False, False), ('astar', True, False),
('astar', False, True), ('astar', True, True),
('dp', False, False), ('dp', True, False)]
for search_method, use_path_extension, use_k_cycle_heuristic in configs:
DAG_matrix, _ = bic_exact_search(data, search_method=search_method,
use_path_extension=use_path_extension,
use_k_cycle_heuristic=use_k_cycle_heuristic, k=3)
# Convert DAG adjacency matrix to Dag object
nodes = [GraphNode(str(i)) for i in range(num_of_nodes)]
Expand All @@ -57,7 +60,7 @@ def test_astar_simulate_linear_gaussian_with_local_score_BIC(self):
DAG.add_directed_edge(nodes[i], nodes[j])
CPDAG = dag2cpdag(DAG) # Convert DAG to CPDAG
self.assertTrue(np.all(CPDAG.graph == truth_CPDAG_matrix))
print('test_astar_simulate_linear_gaussian_with_local_score_BIC passed!\n')
print('test_exact_search_simulate_linear_gaussian_with_local_score_BIC passed!\n')


# def simulate_linear_gaussian_data_for_exact_search():
Expand All @@ -80,5 +83,5 @@ def test_astar_simulate_linear_gaussian_with_local_score_BIC(self):
# [ 0, 0, -1, 0]])
# truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix)
# # Save data and ground truth
# truth_CPDAG_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False)
# data_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False)
# truth_CPDAG_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False)
# data_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False)