Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions causallearn/utils/KCI/KCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,13 @@ def kernel_matrix(self, data_x, data_y):
else:
raise Exception('Undefined kernel function')

data_x = stats.zscore(data_x, axis=0)
data_x[np.isnan(data_x)] = 0.

data_y = stats.zscore(data_y, axis=0)
data_x = stats.zscore(data_x, ddof=1, axis=0)
data_x[np.isnan(data_x)] = 0. # in case some dim of data_x is constant
data_y = stats.zscore(data_y, ddof=1, axis=0)
data_y[np.isnan(data_y)] = 0.

# We set 'ddof=1' to match the normalization used in the original Matlab implementation at
# http://people.tuebingen.mpg.de/kzhang/KCI-test.zip

Kx = kernelX.kernel(data_x)
Ky = kernelY.kernel(data_y)
return Kx, Ky
Expand Down Expand Up @@ -278,8 +279,8 @@ def __init__(self, kernelX='Gaussian', kernelY='Gaussian', kernelZ='Gaussian', n
self.kwidthy = kwidthy
self.kwidthz = kwidthz
self.nullss = nullss
self.epsilon_x = 0.01
self.epsilon_y = 0.01
self.epsilon_x = 1e-3 # To conform to the original Matlab implementation.
self.epsilon_y = 1e-3
self.use_gp = use_gp
self.thresh = 1e-5
self.approx = approx
Expand Down Expand Up @@ -326,14 +327,16 @@ def kernel_matrix(self, data_x, data_y, data_z):
kzy: centering kernel matrix for data_y (nxn)
"""
# normalize the data
data_x = stats.zscore(data_x, axis=0)
data_x = stats.zscore(data_x, ddof=1, axis=0)
data_x[np.isnan(data_x)] = 0.

data_y = stats.zscore(data_y, axis=0)
data_y = stats.zscore(data_y, ddof=1, axis=0)
data_y[np.isnan(data_y)] = 0.

data_z = stats.zscore(data_z, axis=0)
data_z = stats.zscore(data_z, ddof=1, axis=0)
data_z[np.isnan(data_z)] = 0.
# We set 'ddof=1' to match the normalization used in the original Matlab implementation at
# http://people.tuebingen.mpg.de/kzhang/KCI-test.zip

# concatenate x and z
data_x = np.concatenate((data_x, 0.5 * data_z), axis=1)
Expand All @@ -348,7 +351,10 @@ def kernel_matrix(self, data_x, data_y, data_z):
if self.est_width == 'median':
kernelX.set_width_median(data_x)
elif self.est_width == 'empirical':
kernelX.set_width_empirical_kci(data_x)
# kernelX's empirical width is determined by data_z's shape; see the original code
# (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip), file
# 'algorithms/CInd_test_new_withGP.m', lines 37 to 52.
kernelX.set_width_empirical_kci(data_z)
else:
raise Exception('Undefined kernel width estimation method')
elif self.kernelX == 'Polynomial':
Expand All @@ -369,7 +375,10 @@ def kernel_matrix(self, data_x, data_y, data_z):
if self.est_width == 'median':
kernelY.set_width_median(data_y)
elif self.est_width == 'empirical':
kernelY.set_width_empirical_kci(data_y)
# kernelY's empirical width is determined by data_z's shape; see the original code
# (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip), file
# 'algorithms/CInd_test_new_withGP.m', lines 37 to 52.
kernelY.set_width_empirical_kci(data_z)
else:
raise Exception('Undefined kernel width estimation method')
elif self.kernelY == 'Polynomial':
Expand Down Expand Up @@ -400,6 +409,9 @@ def kernel_matrix(self, data_x, data_y, data_z):
elif self.est_width == 'empirical':
kernelZ.set_width_empirical_kci(data_z)
Kzx = kernelZ.kernel(data_z)
Kzx = Kernel.center_kernel_matrix(Kzx)
# Center the kernel matrix to conform to the original Matlab implementation,
# specifically line 100 of the file 'algorithms/CInd_test_new_withGP.m'.
Kzy = Kzx
else:
# learning the kernel width of Kz using Gaussian process
Expand Down Expand Up @@ -450,10 +462,12 @@ def kernel_matrix(self, data_x, data_y, data_z):
elif self.kernelZ == 'Polynomial':
kernelZ = PolynomialKernel(self.polyd)
Kzx = kernelZ.kernel(data_z)
Kzx = Kernel.center_kernel_matrix(Kzx)
Kzy = Kzx
elif self.kernelZ == 'Linear':
kernelZ = LinearKernel()
Kzx = kernelZ.kernel(data_z)
Kzx = Kernel.center_kernel_matrix(Kzx)
Kzy = Kzx
else:
raise Exception('Undefined kernel function')
Expand All @@ -477,9 +491,11 @@ def KCI_V_statistic(self, Kx, Ky, Kzx, Kzy):

[Updated @Haoyue 06/24/2022]
1. Kx, Ky, Kzx, Kzy are all symmetric matrices.
- Kx, Ky are with diagonal elements of 1 (because of exp(-0.5 * sq_dists * self.width)).
- If (self.kernelZ == 'Gaussian' and self.use_gp), then Kzx (Kzy) has all the same diagonal elements (not necessarily 1).
Otherwise Kzx, Kzy are with diagonal elements of 1.
- * Kx's diagonal elements are not the same, because the kernel Kx is centered.
* Before centering, all of Kx's diagonal elements are 1 (because of exp(-0.5 * sq_dists * self.width)).
* The same applies to Ky.
- * If (self.kernelZ == 'Gaussian' and self.use_gp), then all of Kzx's diagonal elements are equal (not necessarily 1).
* The same applies to Kzy.
2. If not (self.kernelZ == 'Gaussian' and self.use_gp): assert (Kzx == Kzy).all()
With this we can skip one repeated calculation of pinv(Kzy + \epsilon I), which consumes most of the time.
"""
Expand Down
Loading