From 0722dbe44ffc9a717b87d3faccc905f9d2d56d47 Mon Sep 17 00:00:00 2001 From: ignavierng Date: Fri, 15 Jul 2022 02:06:27 +0000 Subject: [PATCH 1/2] Refactored unit tests for Astar and DP --- tests/TestDP.py | 81 ---------------------- tests/{TestAstar.py => TestExactSearch.py} | 25 ++++--- 2 files changed, 14 insertions(+), 92 deletions(-) delete mode 100644 tests/TestDP.py rename tests/{TestAstar.py => TestExactSearch.py} (79%) diff --git a/tests/TestDP.py b/tests/TestDP.py deleted file mode 100644 index 4c4bf27a..00000000 --- a/tests/TestDP.py +++ /dev/null @@ -1,81 +0,0 @@ -import unittest -import hashlib -import numpy as np -from causallearn.graph.Dag import Dag -from causallearn.graph.GraphNode import GraphNode -from causallearn.search.ScoreBased.ExactSearch import bic_exact_search -from causallearn.utils.DAG2CPDAG import dag2cpdag - - - -######################################### Test Notes ########################################### -# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") # -# are obtained from the code of causal-learn as of commit # -# https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). # -# # -# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. # -# So if you find your tests failed, it means that your modified code is logically inconsistent # -# with the code as of 8badb41, but not necessarily means that your code is "wrong". # -# If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), # -# please report it to us. We will then modify these benchmark results accordingly. Thanks :) # -######################################### Test Notes ########################################### - - -BENCHMARK_TXTFILE_TO_MD5 = { - "tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt": "1ec70464e4fc68c312adfb7143bd240b", - "tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt": "52a6d3c5db269d5e212edcbb8283aca9", -} -# verify files integrity first -for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items(): - with open(file_path, 'rb') as fin: - assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\ - f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData' - - -class TestDP(unittest.TestCase): - # Load data and run DP with default parameters. - def test_dp_simulate_linear_gaussian_with_local_score_BIC(self): - # The data and ground truth loaded in this test case is generated by the function - # simulate_linear_gaussian_data_for_exact_search commented below - print('Now start test_dp_simulate_linear_gaussian_with_local_score_BIC ...') - truth_CPDAG_matrix = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt") - data = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt") - assert truth_CPDAG_matrix.shape[0] == truth_CPDAG_matrix.shape[1], "Should be a square numpy matrix" - num_of_nodes = len(truth_CPDAG_matrix) - assert data.shape[1] == num_of_nodes, "The second dimension of data should be same as number of nodes" - data = data - data.mean(axis=0, keepdims=True) # Center the data - # Iterate over different configurations of path extension to make sure they are working fine - for use_path_extension in [False, True]: - DAG_matrix, _ = bic_exact_search(data, search_method='dp', use_path_extension=use_path_extension) - # Convert DAG adjacency matrix to Dag object - nodes = [GraphNode(str(i)) for i in range(num_of_nodes)] - DAG = Dag(nodes) - for i, j in zip(*np.where(DAG_matrix == 1)): - DAG.add_directed_edge(nodes[i], nodes[j]) - CPDAG = dag2cpdag(DAG) # Convert DAG to CPDAG - self.assertTrue(np.all(CPDAG.graph == truth_CPDAG_matrix)) - print('test_dp_simulate_linear_gaussian_with_local_score_BIC passed!\n') - - -# def simulate_linear_gaussian_data_for_exact_search(): -# import pandas as pd -# import random -# random.seed(1) # Reproducibility -# np.random.seed(1) # Reproducibility -# num_of_samples = 100000 -# # Generate linear Gaussian data -# X0 = np.random.normal(scale=1.0, size=num_of_samples) -# X1 = 0.5 * X0 + np.random.normal(scale=2.0, size=num_of_samples) -# X3 = np.random.normal(scale=0.5, size=num_of_samples) -# X2 = 0.4 * X1 + 0.7 * X3 + np.random.normal(scale=1.5, size=num_of_samples) -# data_df = pd.DataFrame(data={'X0': X0, 'X1': X1, 'X2': X2, 'X3': X3}) -# # Ground truth DAG: X0 -> X1 -> X2 <- X3 -# # Ground truth CPDAG: X0 -- X1 -> X2 <- X3 -# truth_CPDAG_matrix = np.array([[ 0, -1, 0, 0], -# [-1, 0, -1, 0], -# [ 0, 1, 0, 1], -# [ 0, 0, -1, 0]]) -# truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix) -# # Save data and ground truth -# truth_CPDAG_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False) -# data_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False) \ No newline at end of file diff --git a/tests/TestAstar.py b/tests/TestExactSearch.py similarity index 79% rename from tests/TestAstar.py rename to tests/TestExactSearch.py index 59481aca..b601d85c 100644 --- a/tests/TestAstar.py +++ b/tests/TestExactSearch.py @@ -1,4 +1,3 @@ -import itertools import unittest import hashlib import numpy as np @@ -33,22 +32,26 @@ f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData' -class TestAstar(unittest.TestCase): - # Load data and run Astar with default parameters. - def test_astar_simulate_linear_gaussian_with_local_score_BIC(self): +class TestExactSearch(unittest.TestCase): + # Load data and run exact search. + def test_exact_search_simulate_linear_gaussian_with_local_score_BIC(self): # The data and ground truth loaded in this test case is generated by the function # simulate_linear_gaussian_data_for_exact_search commented below - print('Now start test_astar_simulate_linear_gaussian_with_local_score_BIC ...') + print('Now start test_exact_search_simulate_linear_gaussian_with_local_score_BIC ...') truth_CPDAG_matrix = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt") data = np.loadtxt("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt") assert truth_CPDAG_matrix.shape[0] == truth_CPDAG_matrix.shape[1], "Should be a square numpy matrix" num_of_nodes = len(truth_CPDAG_matrix) assert data.shape[1] == num_of_nodes, "The second dimension of data should be same as number of nodes" data = data - data.mean(axis=0, keepdims=True) # Center the data - # Iterate over different configurations of path extension and k-cycle heuristic + # Iterate over different configurations of search method, path extension, and k-cycle heuristic # to make sure they are working fine - for use_path_extension, use_k_cycle_heuristic in itertools.product([False, True], repeat=2): - DAG_matrix, _ = bic_exact_search(data, search_method='astar', use_path_extension=use_path_extension, + configs = [('astar', False, False), ('astar', True, False), + ('astar', False, True), ('astar', True, True), + ('dp', False, False), ('dp', True, False)] + for search_method, use_path_extension, use_k_cycle_heuristic in configs: + DAG_matrix, _ = bic_exact_search(data, search_method=search_method, + use_path_extension=use_path_extension, use_k_cycle_heuristic=use_k_cycle_heuristic, k=3) # Convert DAG adjacency matrix to Dag object nodes = [GraphNode(str(i)) for i in range(num_of_nodes)] @@ -57,7 +60,7 @@ def test_astar_simulate_linear_gaussian_with_local_score_BIC(self): DAG.add_directed_edge(nodes[i], nodes[j]) CPDAG = dag2cpdag(DAG) # Convert DAG to CPDAG self.assertTrue(np.all(CPDAG.graph == truth_CPDAG_matrix)) - print('test_astar_simulate_linear_gaussian_with_local_score_BIC passed!\n') + print('test_exact_search_simulate_linear_gaussian_with_local_score_BIC passed!\n') # def simulate_linear_gaussian_data_for_exact_search(): @@ -80,5 +83,5 @@ def test_astar_simulate_linear_gaussian_with_local_score_BIC(self): # [ 0, 0, -1, 0]]) # truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix) # # Save data and ground truth -# truth_CPDAG_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False) -# data_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False) \ No newline at end of file +# truth_CPDAG_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False) +# data_df.to_csv('./TestData/test_exact_search_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False) \ No newline at end of file From 4d15f3ab20afb7f20334bcbff66c04fb86419325 Mon Sep 17 00:00:00 2001 From: ignavierng Date: Fri, 15 Jul 2022 02:12:12 +0000 Subject: [PATCH 2/2] Updated documentation --- tests/TestExactSearch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/TestExactSearch.py b/tests/TestExactSearch.py index b601d85c..b9daea24 100644 --- a/tests/TestExactSearch.py +++ b/tests/TestExactSearch.py @@ -11,12 +11,12 @@ ######################################### Test Notes ########################################### # All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") # # are obtained from the code of causal-learn as of commit # -# https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). # +# https://github.com/cmu-phil/causal-learn/commit/129dcdf (07-14-2022). # # # # We are not sure if the results are completely "correct" (reflect ground truth graph) or not. # # So if you find your tests failed, it means that your modified code is logically inconsistent # -# with the code as of 8badb41, but not necessarily means that your code is "wrong". # -# If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), # +# with the code as of 129dcdf, but not necessarily means that your code is "wrong". # +# If you are sure that your modification is "correct" (e.g. fixed some bugs in 129dcdf), # # please report it to us. We will then modify these benchmark results accordingly. Thanks :) # ######################################### Test Notes ########################################### @@ -29,7 +29,7 @@ for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items(): with open(file_path, 'rb') as fin: assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\ - f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData' + f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/129dcdf/tests/TestData' class TestExactSearch(unittest.TestCase):