1- import sys
2-
1+ import itertools
2+ import unittest
3+ import hashlib
4+ import numpy as np
5+ from causallearn .graph .Dag import Dag
6+ from causallearn .graph .GraphNode import GraphNode
37from causallearn .search .ScoreBased .ExactSearch import bic_exact_search
8+ from causallearn .utils .DAG2CPDAG import dag2cpdag
49
5- sys .path .append ("" )
6- import unittest
7- from pickle import load
810
9- import numpy as np
11+
12+ ######################################### Test Notes ###########################################
13+ # All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
14+ # are obtained from the code of causal-learn as of commit #
15+ # https://github.com/cmu-phil/causal-learn/commit/8badb41 (07-12-2022). #
16+ # #
17+ # We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
18+ # So if you find your tests failed, it means that your modified code is logically inconsistent #
19+ # with the code as of 8badb41, but not necessarily means that your code is "wrong". #
20+ # If you are sure that your modification is "correct" (e.g. fixed some bugs in 8badb41), #
21+ # please report it to us. We will then modify these benchmark results accordingly. Thanks :) #
22+ ######################################### Test Notes ###########################################
23+
24+
25+ BENCHMARK_TXTFILE_TO_MD5 = {
26+ "tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt" : "1ec70464e4fc68c312adfb7143bd240b" ,
27+ "tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt" : "52a6d3c5db269d5e212edcbb8283aca9" ,
28+ }
29+ # verify files integrity first
30+ for file_path , expected_MD5 in BENCHMARK_TXTFILE_TO_MD5 .items ():
31+ with open (file_path , 'rb' ) as fin :
32+ assert hashlib .md5 (fin .read ()).hexdigest () == expected_MD5 ,\
33+ f'{ file_path } is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/8badb41/tests/TestData'
1034
1135
1236class TestAstar (unittest .TestCase ):
37+ # Load data and run Astar with default parameters.
38+ def test_astar_simulate_linear_gaussian_with_local_score_BIC (self ):
39+ # The data and ground truth loaded in this test case is generated by the function
40+ # simulate_linear_gaussian_data_for_exact_search commented below
41+ print ('Now start test_astar_simulate_linear_gaussian_with_local_score_BIC ...' )
42+ truth_CPDAG_matrix = np .loadtxt ("tests/TestData/test_exact_search_simulated_linear_gaussian_CPDAG.txt" )
43+ data = np .loadtxt ("tests/TestData/test_exact_search_simulated_linear_gaussian_data.txt" )
44+ assert truth_CPDAG_matrix .shape [0 ] == truth_CPDAG_matrix .shape [1 ], "Should be a square numpy matrix"
45+ num_of_nodes = len (truth_CPDAG_matrix )
46+ assert data .shape [1 ] == num_of_nodes , "The second dimension of data should be same as number of nodes"
47+ data = data - data .mean (axis = 0 , keepdims = True ) # Center the data
48+ # Iterate over different configurations of path extension and k-cycle heuristic
49+ # to make sure they are working fine
50+ for use_path_extension , use_k_cycle_heuristic in itertools .product ([False , True ], repeat = 2 ):
51+ DAG_matrix , _ = bic_exact_search (data , search_method = 'astar' , use_path_extension = use_path_extension ,
52+ use_k_cycle_heuristic = use_k_cycle_heuristic , k = 3 )
53+ # Convert DAG adjacency matrix to Dag object
54+ nodes = [GraphNode (str (i )) for i in range (num_of_nodes )]
55+ DAG = Dag (nodes )
56+ for i , j in zip (* np .where (DAG_matrix == 1 )):
57+ DAG .add_directed_edge (nodes [i ], nodes [j ])
58+ CPDAG = dag2cpdag (DAG ) # Convert DAG to CPDAG
59+ self .assertTrue (np .all (CPDAG .graph == truth_CPDAG_matrix ))
60+ print ('test_astar_simulate_linear_gaussian_with_local_score_BIC passed!\n ' )
61+
1362
14- # example1
15- # for data with single-variate dimensions, astar.
16- def test_single_astar (self ):
17- with open ("example_data1.pk" , 'rb' ) as example_data1 :
18- # example_data1 = load(open("example_data1.pk", 'rb'))
19- example_data1 = load (example_data1 )
20- X = example_data1 ['X' ]
21- X = X - np .tile (np .mean (X , axis = 0 ), (X .shape [0 ], 1 ))
22- X = np .dot (X , np .diag (1 / np .std (X , axis = 0 )))
23- X = X [:50 , :]
24- dag_est , search_stats = bic_exact_search (X , search_method = 'astar' )
25- print (dag_est )
26- print (search_stats )
27-
28- # example2
29- # for data with multi-variate dimensions, astar.
30- def test_multi_astar (self ):
31- with open ("example_data2.pk" , 'rb' ) as example_data :
32- # example_data = load(open("example_data2.pk", 'rb'))
33- example_data = load (example_data )
34- Data_save = example_data ['Data_save' ]
35- trial = 0
36- X = Data_save [trial ]
37- X = X - np .tile (np .mean (X , axis = 0 ), (X .shape [0 ], 1 ))
38- X = np .dot (X , np .diag (1 / np .std (X , axis = 0 )))
39- X = X [:50 , :]
40- dag_est , search_stats = bic_exact_search (X , search_method = 'astar' )
41- print (dag_est )
42- print (search_stats )
63+ # def simulate_linear_gaussian_data_for_exact_search():
64+ # import pandas as pd
65+ # import random
66+ # random.seed(1) # Reproducibility
67+ # np.random.seed(1) # Reproducibility
68+ # num_of_samples = 100000
69+ # # Generate linear Gaussian data
70+ # X0 = np.random.normal(scale=1.0, size=num_of_samples)
71+ # X1 = 0.5 * X0 + np.random.normal(scale=2.0, size=num_of_samples)
72+ # X3 = np.random.normal(scale=0.5, size=num_of_samples)
73+ # X2 = 0.4 * X1 + 0.7 * X3 + np.random.normal(scale=1.5, size=num_of_samples)
74+ # data_df = pd.DataFrame(data={'X0': X0, 'X1': X1, 'X2': X2, 'X3': X3})
75+ # # Ground truth DAG: X0 -> X1 -> X2 <- X3
76+ # # Ground truth CPDAG: X0 -- X1 -> X2 <- X3
77+ # truth_CPDAG_matrix = np.array([[ 0, -1, 0, 0],
78+ # [-1, 0, -1, 0],
79+ # [ 0, 1, 0, 1],
80+ # [ 0, 0, -1, 0]])
81+ # truth_CPDAG_df = pd.DataFrame(data=truth_CPDAG_matrix)
82+ # # Save data and ground truth
83+ # truth_CPDAG_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_CPDAG.txt', sep=' ', index=False, header=False)
84+ # data_df.to_csv('./TestData/test_dp_simulated_linear_gaussian_data.txt', sep=' ', index=False, header=False)
0 commit comments