diff --git a/causallearn/utils/KCI/KCI.py b/causallearn/utils/KCI/KCI.py
index 379894e6..8eda6257 100644
--- a/causallearn/utils/KCI/KCI.py
+++ b/causallearn/utils/KCI/KCI.py
@@ -145,12 +145,13 @@ def kernel_matrix(self, data_x, data_y):
         else:
             raise Exception('Undefined kernel function')
 
-        data_x = stats.zscore(data_x, axis=0)
-        data_x[np.isnan(data_x)] = 0.
-
-        data_y = stats.zscore(data_y, axis=0)
+        data_x = stats.zscore(data_x, ddof=1, axis=0)
+        data_x[np.isnan(data_x)] = 0.  # in case some dim of data_x is constant
+        data_y = stats.zscore(data_y, ddof=1, axis=0)
         data_y[np.isnan(data_y)] = 0.
-
+        # We set 'ddof=1' to conform to the normalization used in the original Matlab implementation at
+        # http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
+
         Kx = kernelX.kernel(data_x)
         Ky = kernelY.kernel(data_y)
         return Kx, Ky
@@ -278,8 +279,8 @@ def __init__(self, kernelX='Gaussian', kernelY='Gaussian', kernelZ='Gaussian', n
         self.kwidthy = kwidthy
         self.kwidthz = kwidthz
         self.nullss = nullss
-        self.epsilon_x = 0.01
-        self.epsilon_y = 0.01
+        self.epsilon_x = 1e-3  # To conform to the original Matlab implementation.
+        self.epsilon_y = 1e-3
         self.use_gp = use_gp
         self.thresh = 1e-5
         self.approx = approx
@@ -326,14 +327,16 @@ def kernel_matrix(self, data_x, data_y, data_z):
         kzy: centering kernel matrix for data_y (nxn)
         """
         # normalize the data
-        data_x = stats.zscore(data_x, axis=0)
+        data_x = stats.zscore(data_x, ddof=1, axis=0)
         data_x[np.isnan(data_x)] = 0.
 
-        data_y = stats.zscore(data_y, axis=0)
+        data_y = stats.zscore(data_y, ddof=1, axis=0)
         data_y[np.isnan(data_y)] = 0.
 
-        data_z = stats.zscore(data_z, axis=0)
+        data_z = stats.zscore(data_z, ddof=1, axis=0)
         data_z[np.isnan(data_z)] = 0.
+        # We set 'ddof=1' to conform to the normalization used in the original Matlab implementation at
+        # http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
 
         # concatenate x and z
         data_x = np.concatenate((data_x, 0.5 * data_z), axis=1)
@@ -348,7 +351,10 @@ def kernel_matrix(self, data_x, data_y, data_z):
             if self.est_width == 'median':
                 kernelX.set_width_median(data_x)
             elif self.est_width == 'empirical':
-                kernelX.set_width_empirical_kci(data_x)
+                # kernelX's empirical width is determined by data_z's shape; please refer to the original code
+                # (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip), in the file
+                # 'algorithms/CInd_test_new_withGP.m', Lines 37 to 52.
+                kernelX.set_width_empirical_kci(data_z)
             else:
                 raise Exception('Undefined kernel width estimation method')
         elif self.kernelX == 'Polynomial':
@@ -369,7 +375,10 @@ def kernel_matrix(self, data_x, data_y, data_z):
             if self.est_width == 'median':
                 kernelY.set_width_median(data_y)
             elif self.est_width == 'empirical':
-                kernelY.set_width_empirical_kci(data_y)
+                # kernelY's empirical width is determined by data_z's shape; please refer to the original code
+                # (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip), in the file
+                # 'algorithms/CInd_test_new_withGP.m', Lines 37 to 52.
+                kernelY.set_width_empirical_kci(data_z)
             else:
                 raise Exception('Undefined kernel width estimation method')
         elif self.kernelY == 'Polynomial':
@@ -400,6 +409,9 @@ def kernel_matrix(self, data_x, data_y, data_z):
                 elif self.est_width == 'empirical':
                     kernelZ.set_width_empirical_kci(data_z)
                 Kzx = kernelZ.kernel(data_z)
+                Kzx = Kernel.center_kernel_matrix(Kzx)
+                # centering the kernel matrix to conform to the original Matlab implementation,
+                # specifically, Line 100 in the file 'algorithms/CInd_test_new_withGP.m'
                 Kzy = Kzx
             else:
                 # learning the kernel width of Kz using Gaussian process
@@ -450,10 +462,12 @@ def kernel_matrix(self, data_x, data_y, data_z):
         elif self.kernelZ == 'Polynomial':
             kernelZ = PolynomialKernel(self.polyd)
             Kzx = kernelZ.kernel(data_z)
+            Kzx = Kernel.center_kernel_matrix(Kzx)
             Kzy = Kzx
         elif self.kernelZ == 'Linear':
             kernelZ = LinearKernel()
             Kzx = kernelZ.kernel(data_z)
+            Kzx = Kernel.center_kernel_matrix(Kzx)
             Kzy = Kzx
         else:
             raise Exception('Undefined kernel function')
@@ -477,9 +491,11 @@ def KCI_V_statistic(self, Kx, Ky, Kzx, Kzy):
         [Updated @Haoyue 06/24/2022]
         1. Kx, Ky, Kzx, Kzy are all symmetric matrices.
-           - Kx, Ky are with diagonal elements of 1 (because of exp(-0.5 * sq_dists * self.width)).
-           - If (self.kernelZ == 'Gaussian' and self.use_gp), then Kzx (Kzy) has all the same diagonal elements (not necessarily 1).
-             Otherwise Kzx, Kzy are with diagonal elements of 1.
+           - * Kx's diagonal elements are not all the same, because the kernel matrix Kx is centered.
+             * Before centering, all of Kx's diagonal elements are 1 (because of exp(-0.5 * sq_dists * self.width)).
+             * The same applies to Ky.
+           - * If (self.kernelZ == 'Gaussian' and self.use_gp), then Kzx has all the same diagonal elements (not necessarily 1).
+             * The same applies to Kzy.
         2. If not (self.kernelZ == 'Gaussian' and self.use_gp):
            assert (Kzx == Kzy).all()
            With this we can save one repeated calculation of pinv(Kzy + \epsilon I), which consumes the most time.
         """
diff --git a/tests/TestCIT_KCI.py b/tests/TestCIT_KCI.py
new file mode 100644
index 00000000..4e4167bc
--- /dev/null
+++ b/tests/TestCIT_KCI.py
@@ -0,0 +1,297 @@
+import unittest
+
+import numpy as np
+
+import causallearn.utils.cit as cit
+
+
+# TODO: Design more comprehensive test cases, including datasets of corner cases.
+class TestCIT_KCI(unittest.TestCase): + def test_Gaussian_dist(self): + np.random.seed(10) + X = np.random.randn(300, 1) + X_prime = np.random.randn(300, 1) + Y = X + 0.5 * np.random.randn(300, 1) + Z = Y + 0.5 * np.random.randn(300, 1) + data = np.hstack((X, X_prime, Y, Z)) + + pvalue01 = [] + pvalue03 = [] + pvalue032 = [] + for kernelname in ['Gaussian', 'Polynomial', 'Linear']: + for est_width in ['empirical', 'median', 'manual']: + for kwidth in [0.5, 1.0, 2.0]: + for use_gp in [True, False]: + for approx in [True, False]: + for polyd in [1, 2]: + cit_CIT = cit.CIT(data, 'kci', kernelX=kernelname, kernelY=kernelname, + kernelZ=kernelname, est_width=est_width, use_gp=use_gp, approx=approx, + polyd=polyd, kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth) + pvalue01.append(round(cit_CIT(0, 1), 4)) + # X and X_prime are independent, pvalue01 should be expected larger than 0.01 + pvalue03.append(round(cit_CIT(0, 3), 4)) + # X and Z are dependent, pvalue03 should be expected smaller than 0.01 + pvalue032.append(round(cit_CIT(0, 3, {2}), 4)) + # X and Z are independent conditional on Y, pvalue032 should be expected larger than + # 0.01 + pvalue01_truth = [0.5404, 0.5404, 0.507, 0.501, 0.5404, 0.5404, 0.516, 0.536, 0.5404, 0.5404, 0.506, 0.517, + 0.5404, 0.5404, 0.526, 0.526, 0.5404, 0.5404, 0.492, 0.507, 0.5404, 0.5404, 0.529, 0.511, + 0.6106, 0.6106, 0.633, 0.594, 0.6106, 0.6106, 0.612, 0.59, 0.6106, 0.6106, 0.595, 0.59, + 0.6106, 0.6106, 0.606, 0.589, 0.6106, 0.6106, 0.616, 0.587, 0.6106, 0.6106, 0.595, 0.596, + 0.5404, 0.5404, 0.522, 0.501, 0.5404, 0.5404, 0.524, 0.53, 0.5864, 0.5864, 0.574, 0.574, + 0.5864, 0.5864, 0.575, 0.603, 0.4901, 0.4901, 0.487, 0.463, 0.4901, 0.4901, 0.47, 0.493, + 0.2745, 0.1613, 0.251, 0.167, 0.2745, 0.1613, 0.274, 0.15, 0.2745, 0.1613, 0.272, 0.143, + 0.2745, 0.1613, 0.276, 0.158, 0.2745, 0.1613, 0.272, 0.149, 0.2745, 0.1613, 0.268, 0.142, + 0.2745, 0.1613, 0.279, 0.141, 0.2745, 0.1613, 0.279, 0.166, 0.2745, 0.1613, 0.26, 0.152, + 0.2745, 0.1613, 0.27, 0.16, 0.2745, 0.1613, 0.262, 0.145, 0.2745, 0.1613, 0.291, 0.154, + 0.2745, 0.1613, 0.254, 0.14, 0.2745, 0.1613, 0.253, 0.16, 0.2745, 0.1613, 0.272, 0.17, + 0.2745, 0.1613, 0.268, 0.168, 0.2745, 0.1613, 0.285, 0.165, 0.2745, 0.1613, 0.276, 0.147, + 0.2745, 0.2745, 0.274, 0.272, 0.2745, 0.2745, 0.279, 0.277, 0.2745, 0.2745, 0.299, 0.28, + 0.2745, 0.2745, 0.263, 0.258, 0.2745, 0.2745, 0.258, 0.269, 0.2745, 0.2745, 0.289, 0.295, + 0.2745, 0.2745, 0.294, 0.283, 0.2745, 0.2745, 0.286, 0.272, 0.2745, 0.2745, 0.267, 0.273, + 0.2745, 0.2745, 0.27, 0.276, 0.2745, 0.2745, 0.257, 0.269, 0.2745, 0.2745, 0.274, 0.264, + 0.2745, 0.2745, 0.249, 0.302, 0.2745, 0.2745, 0.282, 0.259, 0.2745, 0.2745, 0.262, 0.265, + 0.2745, 0.2745, 0.244, 0.264, 0.2745, 0.2745, 0.295, 0.275, 0.2745, 0.2745, 0.261, 0.265] + pvalue03_truth = [0.0] * (3 * 3 * 3 * 2 * 2 * 2) + pvalue032_truth = [0.6087, 0.6087, 0.5956, 0.6, 0.5807, 0.5807, 0.583, 0.5612, 0.6087, 0.6087, 0.5952, 0.5918, + 0.5807, 0.5807, 0.567, 0.5744, 0.6087, 0.6087, 0.5944, 0.6074, 0.5807, 0.5807, 0.5878, + 0.5558, 0.6164, 0.6164, 0.6252, 0.628, 0.6179, 0.6179, 0.6158, 0.6076, 0.6164, 0.6164, + 0.617, 0.6208, 0.6179, 0.6179, 0.6152, 0.6154, 0.6164, 0.6164, 0.6108, 0.6196, 0.6179, + 0.6179, 0.6384, 0.6198, 0.729, 0.729, 0.7334, 0.7246, 0.6899, 0.6899, 0.6918, 0.6874, + 0.6079, 0.6079, 0.6016, 0.6068, 0.5938, 0.5938, 0.598, 0.5752, 0.571, 0.571, 0.5638, + 0.5714, 0.5737, 0.5737, 0.5702, 0.5608, 0.9111, 0.247, 0.9098, 0.2272, 0.9111, 0.247, + 0.9048, 0.2262, 0.9111, 0.247, 0.9106, 0.2488, 
0.9111, 0.247, 0.9106, 0.2312, 0.9111, + 0.247, 0.9122, 0.224, 0.9111, 0.247, 0.9224, 0.2218, 0.9111, 0.247, 0.9148, 0.222, 0.9111, + 0.247, 0.9082, 0.216, 0.9111, 0.247, 0.9154, 0.2294, 0.9111, 0.247, 0.9024, 0.2218, 0.9111, + 0.247, 0.9142, 0.2224, 0.9111, 0.247, 0.9178, 0.2292, 0.9111, 0.247, 0.9098, 0.23, 0.9111, + 0.247, 0.9192, 0.224, 0.9111, 0.247, 0.9066, 0.2316, 0.9111, 0.247, 0.917, 0.2302, 0.9111, + 0.247, 0.9134, 0.2392, 0.9111, 0.247, 0.912, 0.2376, 0.9111, 0.9111, 0.8996, 0.9074, 0.9111, + 0.9111, 0.9124, 0.915, 0.9111, 0.9111, 0.9102, 0.9106, 0.9111, 0.9111, 0.912, 0.9082, + 0.9111, 0.9111, 0.9134, 0.9104, 0.9111, 0.9111, 0.9196, 0.9114, 0.9111, 0.9111, 0.908, + 0.912, 0.9111, 0.9111, 0.9114, 0.9116, 0.9111, 0.9111, 0.9074, 0.9066, 0.9111, 0.9111, + 0.9062, 0.9116, 0.9111, 0.9111, 0.9156, 0.907, 0.9111, 0.9111, 0.9116, 0.9078, 0.9111, + 0.9111, 0.9052, 0.916, 0.9111, 0.9111, 0.912, 0.9098, 0.9111, 0.9111, 0.9068, 0.9162, + 0.9111, 0.9111, 0.9098, 0.9098, 0.9111, 0.9111, 0.9132, 0.9136, 0.9111, 0.9111, 0.9136, + 0.9118] + self.assertEqual(pvalue01, pvalue01_truth) + self.assertEqual(pvalue03, pvalue03_truth) + self.assertEqual(pvalue032, pvalue032_truth) + + def test_Exponential_dist(self): + np.random.seed(10) + X = np.random.exponential(size=(300, 1)) + X_prime = np.random.exponential(size=(300, 1)) + Y = X + 0.5 * np.random.exponential(size=(300, 1)) + Z = Y + 0.5 * np.random.exponential(size=(300, 1)) + data = np.hstack((X, X_prime, Y, Z)) + + pvalue01 = [] + pvalue03 = [] + pvalue032 = [] + for kernelname in ['Gaussian', 'Polynomial', 'Linear']: + for est_width in ['empirical', 'median', 'manual']: + for kwidth in [0.5, 1.0, 2.0]: + for use_gp in [True, False]: + for approx in [True, False]: + for polyd in [1, 2]: + cit_CIT = cit.CIT(data, 'kci', kernelX=kernelname, kernelY=kernelname, + kernelZ=kernelname, est_width=est_width, use_gp=use_gp, approx=approx, + polyd=polyd, kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth) + pvalue01.append(round(cit_CIT(0, 1), 4)) + # X and X_prime are independent, pvalue01 should be expected larger than 0.01 + pvalue03.append(round(cit_CIT(0, 3), 4)) + # X and Z are dependent, pvalue03 should be expected smaller than 0.01 + pvalue032.append(round(cit_CIT(0, 3, {2}), 4)) + # X and Z are independent conditional on Y, pvalue032 should be expected larger than + # 0.01 + pvalue01_truth = [0.8513, 0.8513, 0.872, 0.873, 0.8513, 0.8513, 0.871, 0.897, 0.8513, 0.8513, 0.879, 0.886, + 0.8513, 0.8513, 0.889, 0.891, 0.8513, 0.8513, 0.872, 0.876, 0.8513, 0.8513, 0.853, 0.866, + 0.5809, 0.5809, 0.573, 0.568, 0.5809, 0.5809, 0.548, 0.571, 0.5809, 0.5809, 0.593, 0.588, + 0.5809, 0.5809, 0.577, 0.577, 0.5809, 0.5809, 0.581, 0.57, 0.5809, 0.5809, 0.596, 0.598, + 0.8513, 0.8513, 0.866, 0.877, 0.8513, 0.8513, 0.876, 0.874, 0.5604, 0.5604, 0.565, 0.562, + 0.5604, 0.5604, 0.522, 0.526, 0.5048, 0.5048, 0.48, 0.49, 0.5048, 0.5048, 0.488, 0.49, + 0.8219, 0.5496, 0.807, 0.553, 0.8219, 0.5496, 0.825, 0.562, 0.8219, 0.5496, 0.801, 0.542, + 0.8219, 0.5496, 0.823, 0.548, 0.8219, 0.5496, 0.824, 0.549, 0.8219, 0.5496, 0.83, 0.548, + 0.8219, 0.5496, 0.795, 0.557, 0.8219, 0.5496, 0.818, 0.547, 0.8219, 0.5496, 0.821, 0.57, + 0.8219, 0.5496, 0.823, 0.539, 0.8219, 0.5496, 0.843, 0.564, 0.8219, 0.5496, 0.823, 0.531, + 0.8219, 0.5496, 0.802, 0.538, 0.8219, 0.5496, 0.811, 0.544, 0.8219, 0.5496, 0.796, 0.572, + 0.8219, 0.5496, 0.822, 0.586, 0.8219, 0.5496, 0.818, 0.565, 0.8219, 0.5496, 0.822, 0.569, + 0.8219, 0.8219, 0.814, 0.811, 0.8219, 0.8219, 0.861, 0.78, 0.8219, 0.8219, 0.85, 
0.855, + 0.8219, 0.8219, 0.815, 0.818, 0.8219, 0.8219, 0.829, 0.818, 0.8219, 0.8219, 0.825, 0.818, + 0.8219, 0.8219, 0.839, 0.821, 0.8219, 0.8219, 0.83, 0.812, 0.8219, 0.8219, 0.828, 0.83, + 0.8219, 0.8219, 0.824, 0.806, 0.8219, 0.8219, 0.833, 0.844, 0.8219, 0.8219, 0.824, 0.825, + 0.8219, 0.8219, 0.827, 0.817, 0.8219, 0.8219, 0.827, 0.826, 0.8219, 0.8219, 0.817, 0.835, + 0.8219, 0.8219, 0.829, 0.821, 0.8219, 0.8219, 0.832, 0.814, 0.8219, 0.8219, 0.835, 0.8] + pvalue03_truth = [0.0] * (3 * 3 * 3 * 2 * 2 * 2) + pvalue032_truth = [0.4088, 0.4088, 0.3792, 0.3764, 0.4076, 0.4076, 0.3732, 0.3746, 0.4088, 0.4088, 0.3834, + 0.374, 0.4076, 0.4076, 0.3822, 0.375, 0.4088, 0.4088, 0.3702, 0.3806, 0.4076, 0.4076, + 0.3674, 0.3638, 0.627, 0.627, 0.6232, 0.6236, 0.6756, 0.6756, 0.6788, 0.6806, 0.627, 0.627, + 0.622, 0.6254, 0.6756, 0.6756, 0.6872, 0.6812, 0.627, 0.627, 0.6196, 0.6076, 0.6756, 0.6756, + 0.6858, 0.6656, 0.4087, 0.4087, 0.3898, 0.3886, 0.3398, 0.3398, 0.3092, 0.3042, 0.5165, + 0.5165, 0.4958, 0.4912, 0.5326, 0.5326, 0.5226, 0.5288, 0.8561, 0.8561, 0.8962, 0.8864, + 0.8749, 0.8749, 0.915, 0.9118, 0.7353, 0.515, 0.735, 0.511, 0.7353, 0.515, 0.7274, 0.507, + 0.7353, 0.515, 0.737, 0.509, 0.7353, 0.515, 0.731, 0.5084, 0.7353, 0.515, 0.7338, 0.4996, + 0.7353, 0.515, 0.7312, 0.5156, 0.7353, 0.515, 0.7414, 0.5224, 0.7353, 0.515, 0.7312, 0.519, + 0.7353, 0.515, 0.737, 0.5046, 0.7353, 0.515, 0.7328, 0.5204, 0.7353, 0.515, 0.738, 0.5058, + 0.7353, 0.515, 0.728, 0.5016, 0.7353, 0.515, 0.7416, 0.514, 0.7353, 0.515, 0.724, 0.5174, + 0.7353, 0.515, 0.7342, 0.5118, 0.7353, 0.515, 0.7338, 0.5156, 0.7353, 0.515, 0.7388, 0.5016, + 0.7353, 0.515, 0.737, 0.5102, 0.7353, 0.7353, 0.7348, 0.738, 0.7353, 0.7353, 0.7392, 0.732, + 0.7353, 0.7353, 0.7328, 0.7268, 0.7353, 0.7353, 0.737, 0.7398, 0.7353, 0.7353, 0.736, 0.7416, + 0.7353, 0.7353, 0.7398, 0.7374, 0.7353, 0.7353, 0.7314, 0.737, 0.7353, 0.7353, 0.7338, 0.7354, + 0.7353, 0.7353, 0.7328, 0.7372, 0.7353, 0.7353, 0.7352, 0.7356, 0.7353, 0.7353, 0.739, 0.7336, + 0.7353, 0.7353, 0.7404, 0.7386, 0.7353, 0.7353, 0.7398, 0.7432, 0.7353, 0.7353, 0.7386, + 0.7402, 0.7353, 0.7353, 0.728, 0.7288, 0.7353, 0.7353, 0.7328, 0.7324, 0.7353, 0.7353, + 0.7412, 0.7398, 0.7353, 0.7353, 0.7412, 0.7272] + self.assertEqual(pvalue01, pvalue01_truth) + self.assertEqual(pvalue03, pvalue03_truth) + self.assertEqual(pvalue032, pvalue032_truth) + + def test_Uniform_dist(self): + np.random.seed(10) + X = np.random.uniform(size=(300, 1)) + X_prime = np.random.uniform(size=(300, 1)) + Y = X + 0.5 * np.random.uniform(size=(300, 1)) + Z = Y + 0.5 * np.random.uniform(size=(300, 1)) + data = np.hstack((X, X_prime, Y, Z)) + + pvalue01 = [] + pvalue03 = [] + pvalue032 = [] + for kernelname in ['Gaussian', 'Polynomial', 'Linear']: + for est_width in ['empirical', 'median', 'manual']: + for kwidth in [0.5, 1.0, 2.0]: + for use_gp in [True, False]: + for approx in [True, False]: + for polyd in [1, 2]: + cit_CIT = cit.CIT(data, 'kci', kernelX=kernelname, kernelY=kernelname, + kernelZ=kernelname, est_width=est_width, use_gp=use_gp, approx=approx, + polyd=polyd, kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth) + pvalue01.append(round(cit_CIT(0, 1), 4)) + # X and X_prime are independent, pvalue01 should be expected larger than 0.01 + pvalue03.append(round(cit_CIT(0, 3), 4)) + # X and Z are dependent, pvalue03 should be expected smaller than 0.01 + pvalue032.append(round(cit_CIT(0, 3, {2}), 4)) + # X and Z are independent conditional on Y, pvalue032 should be expected larger than + # 0.01 + pvalue01_truth 
= [0.8099, 0.8099, 0.815, 0.827, 0.8099, 0.8099, 0.821, 0.82, 0.8099, 0.8099, 0.852, 0.828, + 0.8099, 0.8099, 0.83, 0.825, 0.8099, 0.8099, 0.831, 0.83, 0.8099, 0.8099, 0.813, 0.817, + 0.7897, 0.7897, 0.809, 0.814, 0.7897, 0.7897, 0.788, 0.789, 0.7897, 0.7897, 0.798, 0.805, + 0.7897, 0.7897, 0.803, 0.786, 0.7897, 0.7897, 0.804, 0.802, 0.7897, 0.7897, 0.793, 0.794, + 0.8099, 0.8099, 0.815, 0.815, 0.8099, 0.8099, 0.816, 0.826, 0.5796, 0.5796, 0.551, 0.572, + 0.5796, 0.5796, 0.565, 0.593, 0.5546, 0.5546, 0.568, 0.546, 0.5546, 0.5546, 0.546, 0.564, + 0.7155, 0.4159, 0.7, 0.401, 0.7155, 0.4159, 0.719, 0.402, 0.7155, 0.4159, 0.721, 0.419, + 0.7155, 0.4159, 0.708, 0.412, 0.7155, 0.4159, 0.708, 0.416, 0.7155, 0.4159, 0.704, 0.394, + 0.7155, 0.4159, 0.71, 0.426, 0.7155, 0.4159, 0.707, 0.419, 0.7155, 0.4159, 0.728, 0.406, + 0.7155, 0.4159, 0.712, 0.384, 0.7155, 0.4159, 0.718, 0.392, 0.7155, 0.4159, 0.715, 0.365, + 0.7155, 0.4159, 0.723, 0.387, 0.7155, 0.4159, 0.72, 0.398, 0.7155, 0.4159, 0.705, 0.406, + 0.7155, 0.4159, 0.704, 0.379, 0.7155, 0.4159, 0.709, 0.392, 0.7155, 0.4159, 0.717, 0.406, + 0.7155, 0.7155, 0.712, 0.705, 0.7155, 0.7155, 0.719, 0.722, 0.7155, 0.7155, 0.684, 0.715, + 0.7155, 0.7155, 0.702, 0.705, 0.7155, 0.7155, 0.732, 0.729, 0.7155, 0.7155, 0.688, 0.701, + 0.7155, 0.7155, 0.729, 0.737, 0.7155, 0.7155, 0.701, 0.722, 0.7155, 0.7155, 0.717, 0.714, + 0.7155, 0.7155, 0.732, 0.711, 0.7155, 0.7155, 0.709, 0.708, 0.7155, 0.7155, 0.709, 0.708, + 0.7155, 0.7155, 0.702, 0.723, 0.7155, 0.7155, 0.722, 0.708, 0.7155, 0.7155, 0.72, 0.695, + 0.7155, 0.7155, 0.7, 0.71, 0.7155, 0.7155, 0.705, 0.73, 0.7155, 0.7155, 0.736, 0.699] + pvalue03_truth = [0.0] * (3 * 3 * 3 * 2 * 2 * 2) + pvalue032_truth = [0.6393, 0.6393, 0.6354, 0.6396, 0.6124, 0.6124, 0.5972, 0.618, 0.6393, 0.6393, 0.639, + 0.6288, 0.6124, 0.6124, 0.607, 0.6064, 0.6393, 0.6393, 0.6438, 0.6442, 0.6124, 0.6124, + 0.613, 0.595, 0.8899, 0.8899, 0.9352, 0.934, 0.8887, 0.8887, 0.9308, 0.9344, 0.8899, 0.8899, + 0.935, 0.9354, 0.8887, 0.8887, 0.9336, 0.939, 0.8899, 0.8899, 0.9382, 0.9378, 0.8887, 0.8887, + 0.9312, 0.9352, 0.5581, 0.5581, 0.5508, 0.5398, 0.5246, 0.5246, 0.5184, 0.514, 0.833, 0.833, + 0.8634, 0.859, 0.8296, 0.8296, 0.859, 0.8538, 0.8861, 0.8861, 0.9372, 0.9456, 0.8705, 0.8705, + 0.9246, 0.9294, 0.3721, 0.6544, 0.3794, 0.656, 0.3721, 0.6544, 0.3832, 0.6548, 0.3721, 0.6544, + 0.3696, 0.6606, 0.3721, 0.6544, 0.373, 0.6516, 0.3721, 0.6544, 0.3814, 0.653, 0.3721, 0.6544, + 0.3706, 0.6638, 0.3721, 0.6544, 0.373, 0.6624, 0.3721, 0.6544, 0.3728, 0.6566, 0.3721, 0.6544, + 0.373, 0.6752, 0.3721, 0.6544, 0.3686, 0.6644, 0.3721, 0.6544, 0.3744, 0.6684, 0.3721, 0.6544, + 0.3758, 0.663, 0.3721, 0.6544, 0.3752, 0.6462, 0.3721, 0.6544, 0.3694, 0.6738, 0.3721, 0.6544, + 0.3684, 0.6602, 0.3721, 0.6544, 0.3682, 0.6734, 0.3721, 0.6544, 0.3732, 0.662, 0.3721, 0.6544, + 0.3718, 0.6638, 0.3721, 0.3721, 0.3688, 0.372, 0.3721, 0.3721, 0.3596, 0.3722, 0.3721, 0.3721, + 0.3662, 0.361, 0.3721, 0.3721, 0.379, 0.3716, 0.3721, 0.3721, 0.379, 0.3684, 0.3721, 0.3721, + 0.3754, 0.3674, 0.3721, 0.3721, 0.3636, 0.38, 0.3721, 0.3721, 0.3756, 0.3662, 0.3721, 0.3721, + 0.373, 0.3702, 0.3721, 0.3721, 0.3704, 0.3746, 0.3721, 0.3721, 0.3574, 0.3616, 0.3721, 0.3721, + 0.359, 0.3702, 0.3721, 0.3721, 0.366, 0.3704, 0.3721, 0.3721, 0.3682, 0.3732, 0.3721, 0.3721, + 0.3836, 0.3806, 0.3721, 0.3721, 0.371, 0.3596, 0.3721, 0.3721, 0.3728, 0.3744, 0.3721, 0.3721, + 0.3708, 0.388] + self.assertEqual(pvalue01, pvalue01_truth) + self.assertEqual(pvalue03, pvalue03_truth) + 
self.assertEqual(pvalue032, pvalue032_truth) + + def test_Mixed_dist(self): + np.random.seed(10) + X = np.random.uniform(size=(300, 1)) + X_prime = np.random.randn(300, 1) + Y = X + 0.5 * np.random.exponential(size=(300, 1)) + Z = Y + 0.5 * np.random.randn(300, 1) + data = np.hstack((X, X_prime, Y, Z)) + + pvalue01 = [] + pvalue03 = [] + pvalue032 = [] + for kernelname in ['Gaussian', 'Polynomial', 'Linear']: + for est_width in ['empirical', 'median', 'manual']: + for kwidth in [0.5, 1.0, 2.0]: + for use_gp in [True, False]: + for approx in [True, False]: + for polyd in [1, 2]: + cit_CIT = cit.CIT(data, 'kci', kernelX=kernelname, kernelY=kernelname, + kernelZ=kernelname, est_width=est_width, use_gp=use_gp, approx=approx, + polyd=polyd, kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth) + pvalue01.append(round(cit_CIT(0, 1), 4)) + # X and X_prime are independent, pvalue01 should be expected larger than 0.01 + pvalue03.append(round(cit_CIT(0, 3), 4)) + # X and Z are dependent, pvalue03 should be expected smaller than 0.01 + pvalue032.append(round(cit_CIT(0, 3, {2}), 4)) + # X and Z are independent conditional on Y, pvalue032 should be expected larger than + # 0.01 + pvalue01_truth = [0.6565, 0.6565, 0.637, 0.668, 0.6565, 0.6565, 0.64, 0.659, 0.6565, 0.6565, 0.646, 0.632, + 0.6565, 0.6565, 0.67, 0.646, 0.6565, 0.6565, 0.655, 0.668, 0.6565, 0.6565, 0.661, 0.663, + 0.5346, 0.5346, 0.524, 0.507, 0.5346, 0.5346, 0.517, 0.511, 0.5346, 0.5346, 0.535, 0.514, + 0.5346, 0.5346, 0.505, 0.526, 0.5346, 0.5346, 0.534, 0.518, 0.5346, 0.5346, 0.517, 0.507, + 0.6565, 0.6565, 0.633, 0.642, 0.6565, 0.6565, 0.659, 0.64, 0.6557, 0.6557, 0.668, 0.68, + 0.6557, 0.6557, 0.654, 0.66, 0.6663, 0.6663, 0.701, 0.693, 0.6663, 0.6663, 0.704, 0.698, + 0.7537, 0.5882, 0.74, 0.618, 0.7537, 0.5882, 0.768, 0.572, 0.7537, 0.5882, 0.778, 0.581, + 0.7537, 0.5882, 0.755, 0.624, 0.7537, 0.5882, 0.772, 0.569, 0.7537, 0.5882, 0.757, 0.585, + 0.7537, 0.5882, 0.738, 0.612, 0.7537, 0.5882, 0.77, 0.602, 0.7537, 0.5882, 0.74, 0.562, + 0.7537, 0.5882, 0.754, 0.609, 0.7537, 0.5882, 0.749, 0.574, 0.7537, 0.5882, 0.775, 0.573, + 0.7537, 0.5882, 0.76, 0.6, 0.7537, 0.5882, 0.735, 0.613, 0.7537, 0.5882, 0.749, 0.577, + 0.7537, 0.5882, 0.765, 0.612, 0.7537, 0.5882, 0.758, 0.588, 0.7537, 0.5882, 0.763, 0.57, + 0.7537, 0.7537, 0.787, 0.743, 0.7537, 0.7537, 0.762, 0.762, 0.7537, 0.7537, 0.769, 0.76, + 0.7537, 0.7537, 0.746, 0.733, 0.7537, 0.7537, 0.755, 0.723, 0.7537, 0.7537, 0.75, 0.748, + 0.7537, 0.7537, 0.747, 0.754, 0.7537, 0.7537, 0.774, 0.754, 0.7537, 0.7537, 0.734, 0.767, + 0.7537, 0.7537, 0.747, 0.772, 0.7537, 0.7537, 0.76, 0.74, 0.7537, 0.7537, 0.722, 0.738, + 0.7537, 0.7537, 0.732, 0.754, 0.7537, 0.7537, 0.782, 0.745, 0.7537, 0.7537, 0.744, 0.767, + 0.7537, 0.7537, 0.756, 0.738, 0.7537, 0.7537, 0.718, 0.77, 0.7537, 0.7537, 0.771, 0.757] + pvalue03_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + pvalue032_truth = [0.1326, 0.1326, 0.1206, 0.1302, 0.1561, 0.1561, 0.1552, 0.151, 0.1326, 0.1326, 0.119, 0.1292, + 0.1561, 0.1561, 0.144, 0.1454, 0.1326, 0.1326, 0.131, 0.1238, 0.1561, 0.1561, 0.1488, 0.1484, + 0.1196, 0.1196, 0.1072, 0.1066, 0.1217, 0.1217, 0.1142, 0.1102, 0.1196, 0.1196, 0.1084, + 0.1032, 0.1217, 0.1217, 0.1078, 0.112, 0.1196, 0.1196, 0.1092, 0.1102, 0.1217, 0.1217, + 0.112, 0.114, 0.2332, 0.2332, 0.2248, 0.2156, 0.2847, 0.2847, 0.2602, 0.265, 0.1033, 0.1033, + 0.0994, 0.1026, 0.0946, 0.0946, 0.0892, 0.1, 0.2605, 0.2605, 0.2266, 0.2212, 0.2599, 0.2599, + 0.2356, 0.2324, 0.8637, 0.1417, 0.8672, 0.1382, 0.8637, 0.1417, 0.8692, 0.1338, 0.8637, + 0.1417, 0.8676, 0.1414, 0.8637, 0.1417, 0.8702, 0.1394, 0.8637, 0.1417, 0.8626, 0.1336, + 0.8637, 0.1417, 0.8614, 0.1354, 0.8637, 0.1417, 0.8666, 0.127, 0.8637, 0.1417, 0.8568, + 0.1314, 0.8637, 0.1417, 0.8632, 0.1334, 0.8637, 0.1417, 0.863, 0.1386, 0.8637, 0.1417, + 0.8616, 0.1424, 0.8637, 0.1417, 0.8622, 0.1404, 0.8637, 0.1417, 0.8584, 0.13, 0.8637, + 0.1417, 0.8584, 0.1382, 0.8637, 0.1417, 0.8748, 0.1234, 0.8637, 0.1417, 0.856, 0.1414, + 0.8637, 0.1417, 0.8664, 0.1364, 0.8637, 0.1417, 0.8552, 0.1372, 0.8637, 0.8637, 0.858, + 0.8652, 0.8637, 0.8637, 0.8558, 0.8666, 0.8637, 0.8637, 0.8584, 0.8644, 0.8637, 0.8637, + 0.8614, 0.8678, 0.8637, 0.8637, 0.8696, 0.8682, 0.8637, 0.8637, 0.869, 0.8624, 0.8637, + 0.8637, 0.8642, 0.8648, 0.8637, 0.8637, 0.8644, 0.8648, 0.8637, 0.8637, 0.8552, 0.8648, + 0.8637, 0.8637, 0.8642, 0.86, 0.8637, 0.8637, 0.86, 0.8612, 0.8637, 0.8637, 0.8586, 0.8702, + 0.8637, 0.8637, 0.8612, 0.8652, 0.8637, 0.8637, 0.8602, 0.8684, 0.8637, 0.8637, 0.8596, + 0.859, 0.8637, 0.8637, 0.8622, 0.8512, 0.8637, 0.8637, 0.8594, 0.8672, 0.8637, 0.8637, + 0.8626, 0.8716] + self.assertEqual(pvalue01, pvalue01_truth) + self.assertEqual(pvalue03, pvalue03_truth) + self.assertEqual(pvalue032, pvalue032_truth) diff --git a/tests/TestKCI.py b/tests/TestKCI.py deleted file mode 100644 index 09586043..00000000 --- a/tests/TestKCI.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import sys -# BASE_DIR = os.path.join(os.path.dirname(__file__), '..') -# sys.path.append(BASE_DIR) -import unittest - -import numpy as np - -import causallearn.utils.cit as cit -from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd - - -class TestKCI(unittest.TestCase): - def test_Gaussian(self): - X = np.random.randn(300, 1) - X1 = np.random.randn(300, 1) - Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2) - Z = Y + 0.5 * np.random.randn(300, 2) - - kci_uind = KCI_UInd() - pvalue, _ = kci_uind.compute_pvalue(X, X1) - print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue)) - - pvalue, _ = kci_uind.compute_pvalue(X, Z) - print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue)) - - kci_cind = KCI_CInd() - pvalue, _ = kci_cind.compute_pvalue(X, Z, Y) - print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue)) - - def test_Polynomial(self): - X = np.random.randn(300, 1) - X1 = np.random.randn(300, 1) - Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2) - Z = Y + 0.5 * np.random.randn(300, 
2)
-
-        kci_uind = KCI_UInd(kernelX='Polynomial', kernelY='Polynomial')
-        pvalue, _ = kci_uind.compute_pvalue(X, X1)
-        print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue, _ = kci_uind.compute_pvalue(X, Z)
-        print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue))
-
-        kci_cind = KCI_CInd(kernelX='Polynomial', kernelY='Polynomial', kernelZ='Polynomial')
-        pvalue, _ = kci_cind.compute_pvalue(X, Z, Y)
-        print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue))
-
-    def test_Linear(self):
-        X = np.random.randn(300, 1)
-        X1 = np.random.randn(300, 1)
-        Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2)
-        Z = Y + 0.5 * np.random.randn(300, 2)
-
-        kci_uind = KCI_UInd(kernelX='Linear', kernelY='Linear')
-        pvalue, _ = kci_uind.compute_pvalue(X, X1)
-        print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue, _ = kci_uind.compute_pvalue(X, Z)
-        print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue))
-
-        kci_cind = KCI_CInd(kernelX='Linear', kernelY='Linear', kernelZ='Linear')
-        pvalue, _ = kci_cind.compute_pvalue(X, Z, Y)
-        print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue))
-
-
-class TestCIT_KCI(unittest.TestCase):
-    def test_Gaussian(self):
-        X = np.random.randn(300, 1)
-        X1 = np.random.randn(300, 1)
-        Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2)
-        Z = Y + 0.5 * np.random.randn(300, 2)
-
-        pvalue = cit.kci_ui(X, X1)
-        print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ui(X, Z)
-        print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ci(X, Z, Y)
-        print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue))
-
-    def test_Polynomial(self):
-        X = np.random.randn(300, 1)
-        X1 = np.random.randn(300, 1)
-        Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2)
-        Z = Y + 0.5 * np.random.randn(300, 2)
-
-        pvalue = cit.kci_ui(X, X1, kernelX='Polynomial', kernelY='Polynomial')
-        print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ui(X, Z, kernelX='Polynomial', kernelY='Polynomial')
-        print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ci(X, Z, Y, kernelX='Polynomial', kernelY='Polynomial', kernelZ='Polynomial')
-        print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue))
-
-    def test_Linear(self):
-        X = np.random.randn(300, 1)
-        X1 = np.random.randn(300, 1)
-        Y = np.concatenate((X, X), axis=1) + 0.5 * np.random.randn(300, 2)
-        Z = Y + 0.5 * np.random.randn(300, 2)
-
-        pvalue = cit.kci_ui(X, X1, kernelX='Linear', kernelY='Linear')
-        print('X and X1 are independent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ui(X, Z, kernelX='Linear', kernelY='Linear')
-        print('X and Z are dependent, pvalue is {:.2f}'.format(pvalue))
-
-        pvalue = cit.kci_ci(X, Z, Y, kernelX='Linear', kernelY='Linear', kernelZ='Linear')
-        print('X and Z are independent conditional on Y, pvalue is {:.2f}'.format(pvalue))
-
-
-if __name__ == '__main__':
-    test = TestKCI()
-    print('------------------------------')
-    print('Test KCI with Gaussian kernel')
-    test.test_Gaussian()
-    print('------------------------------')
-    print('Test KCI with Polynomial kernel')
-    test.test_Polynomial()
-    print('------------------------------')
-    print('Test KCI with Linear kernel')
-    test.test_Linear()
-
-    test = TestCIT_KCI()
-    print('------------------------------')
-    print('Test CIT_KCI with Gaussian kernel')
-    test.test_Gaussian()
-    print('------------------------------')
-    print('Test CIT_KCI with Polynomial kernel')
-    test.test_Polynomial()
-    print('------------------------------')
-    print('Test CIT_KCI with Linear kernel')
-    test.test_Linear()
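
Note (illustrative only, not part of the patch): the two behavioral changes above — sample standardization with ddof=1 and the centering of Kz — can each be sanity-checked in isolation. The sketch below uses only NumPy/SciPy; center_kernel_matrix here is a hypothetical stand-in for causallearn's Kernel.center_kernel_matrix, assumed to implement the standard H @ K @ H centering.

import numpy as np
from scipy import stats

x = np.hstack([np.random.randn(5, 1), np.ones((5, 1))])  # second column is constant

# scipy defaults to ddof=0 (population std); ddof=1 matches Matlab's default std().
z = stats.zscore(x, ddof=1, axis=0)
z[np.isnan(z)] = 0.  # a constant column has std 0, so zscore yields NaN; zero it out

# Kernel centering: K -> H @ K @ H with H = I - (1/n) * ones((n, n)).
# (Assumed equivalent to causallearn's Kernel.center_kernel_matrix.)
def center_kernel_matrix(K):
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return H @ K @ H

K = z @ z.T  # a linear kernel, for illustration
Kc = center_kernel_matrix(K)
assert np.allclose(Kc.sum(axis=0), 0.)  # rows/columns of a centered kernel sum to zero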