Merged
31 changes: 21 additions & 10 deletions tests/TestGranger.py
@@ -12,14 +12,26 @@

from causallearn.search.Granger.Granger import Granger

######################################### Test Notes ###########################################
# All the benchmark results (p_value_matrix_truth, adj_matrix_truth, coeff_truth) #
# are obtained from the code of causal-learn as of commit #
# https://github.com/cmu-phil/causal-learn/commit/b49980d046607baaaa66ff8dc0ceb98452ab8616 #
# (b49980d). #
# #
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
# So if you find your tests failed, it means that your modified code is logically inconsistent #
# with the code as of b49980d, but not necessarily means that your code is "wrong". #
# If you are sure that your modification is "correct" (e.g. fixed some bugs in b49980d), #
# please report it to us. We will then modify these benchmark results accordingly. Thanks :) #
######################################### Test Notes ###########################################


class TestGranger(unittest.TestCase):
# simulate data from a VAR model
def syn_data_3d(self):
# generate transition matrix, time lag 2
np.random.seed(0)
A = 0.2 * np.random.rand(3,6)
print('True matrix is \n {}'.format(A))
# generate time series
T = 1000
data = np.random.rand(3, T)
@@ -35,7 +35,6 @@ def syn_data_2d(self):
A = 0.5*np.random.rand(2,4)
A[0,1] = 0
A[0,3] = 0
print('True matrix is \n {}'.format(A))
# generate time series
T = 100
data = np.random.rand(2, T)
@@ -54,8 +65,10 @@ def test_granger_test(self):
dataset = self.syn_data_2d()
G = Granger()
p_value_matrix, adj_matrix = G.granger_test_2d(data=dataset)
print('P-value matrix is \n {}'.format(p_value_matrix))
print('Adjacency matrix is \n {}'.format(adj_matrix))
p_value_matrix_truth = np.array([[0, 0.5989, 0, 0.5397], [0.0006, 0, 0.0014, 0]])
adj_matrix_truth = np.array([[1, 0, 1, 0], [1, 1, 1, 1]])
Comment on lines +68 to +69

Contributor: could you remind me how you get this truth_value?

Contributor (Author): It was computed from the current version 1ebf232

Contributor: cool, please add comments on how the p-values are generated --- same as all other PRs (it's a good practice to check other PRs before writing the PR so that you know what are the best practices we follow currently :) ) This can reduce the review overhead.

Contributor (Author): Cool, thanks.

self.assertEqual((np.round(p_value_matrix, 4) - p_value_matrix_truth).all(), 0)
self.assertEqual((adj_matrix - adj_matrix_truth).all(), 0)

# example2
# for data with multi-dimensional variables, granger lasso regression.
@@ -66,10 +79,8 @@ def test_granger_lasso(self):
dataset = self.syn_data_3d()
G = Granger()
coeff = G.granger_lasso(data=dataset)
print('Estimated matrix is \n {}'.format(coeff))

coeff_truth = np.array([[0.09, 0.1101, 0.1527, 0.1127, 0.0226, 0.1538],
Contributor: and what about this?

Contributor (Author): Same, computed from the current version 1ebf232.

Contributor: same, please add some comments :)

A good mental practice: write code in a way that, if you check the data a year later, do you know what the code means and how the values are generated?
[0.1004, 0.15, 0.1757, 0.1037, 0.1612, 0.0987],
[0.1155, 0.1485, 0, 0.039, -0., 0.1085]])
self.assertEqual((np.round(coeff, 4) - coeff_truth).all(), 0)
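(For readers unfamiliar with `granger_lasso`: the underlying idea is a lagged lasso regression — regress each variable's present on all variables' lagged values with an L1 penalty. A standalone sketch of that technique using scikit-learn, assumed for illustration and not the PR's exact implementation or regularization strength:)

```python
import numpy as np
from sklearn.linear_model import Lasso

np.random.seed(0)
n_vars, lag, T = 3, 2, 1000
A = 0.2 * np.random.rand(n_vars, n_vars * lag)   # true transition matrix [A1 | A2]

# simulate a lag-2 VAR, as syn_data_3d does
data = np.zeros((n_vars, T))
data[:, :lag] = np.random.rand(n_vars, lag)
for t in range(lag, T):
    data[:, t] = (A[:, :n_vars] @ data[:, t - 1]
                  + A[:, n_vars:] @ data[:, t - 2]
                  + 0.1 * np.random.randn(n_vars))

# stack lagged values into a design matrix and fit one lasso per variable
Y = data[:, lag:].T                                # (T-2, 3): X_t
X = np.hstack([data[:, 1:-1].T, data[:, :-2].T])   # (T-2, 6): [X_{t-1}, X_{t-2}]
coeff = np.vstack([Lasso(alpha=1e-4).fit(X, Y[:, i]).coef_
                   for i in range(n_vars)])        # roughly recovers A
```

(Small entries of the recovered matrix get shrunk toward zero by the L1 penalty, which is why the `coeff_truth` benchmark above contains exact zeros.)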

if __name__ == '__main__':
test = TestGranger()
test.test_granger_test()
test.test_granger_lasso()
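(A side note on the comparison idiom used in these tests: `(a - b).all()` is true only when every entry of the difference is nonzero, so `assertEqual((a - b).all(), 0)` passes as soon as any single entry matches. A stricter equivalent check on the rounded values, shown here with the benchmark from `test_granger_test` and hypothetical unrounded p-values for illustration:)

```python
import numpy as np

# hypothetical raw p-values, before rounding
p = np.array([[0.0, 0.59891, 0.0, 0.53969],
              [0.00058, 0.0, 0.00142, 0.0]])
# benchmark from the test above
truth = np.array([[0, 0.5989, 0, 0.5397],
                  [0.0006, 0, 0.0014, 0]])

# every rounded entry must match the benchmark, not just one
match = np.array_equal(np.round(p, 4), truth)
```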