63 changes: 45 additions & 18 deletions causallearn/search/FCMBased/lingam/hsic.py
@@ -1,7 +1,9 @@
"""
Python implementation of the LiNGAM algorithms.
+ * some slight modification for speedup, 04/26/2022
The LiNGAM Project: https://sites.google.com/site/sshimizu06/lingam
"""
+ import time

import numpy as np
from scipy.stats import gamma
@@ -34,16 +36,12 @@ def get_kernel_width(X):
X_med = X

G = np.sum(X_med * X_med, 1).reshape(n_samples, 1)
- Q = np.tile(G, (1, n_samples))
- R = np.tile(G.T, (n_samples, 1))

- dists = Q + R - 2 * np.dot(X_med, X_med.T)
+ dists = G + G.T - 2 * np.dot(X_med, X_med.T)
dists = dists - np.tril(dists)
dists = dists.reshape(n_samples ** 2, 1)

return np.sqrt(0.5 * np.median(dists[dists > 0]))


def _rbf_dot(X, Y, width):
"""Compute the inner product of radial basis functions."""
n_samples_X = X.shape[0]
@@ -57,6 +55,11 @@ def _rbf_dot(X, Y, width):

return np.exp(-H / 2 / (width ** 2))

+ def _rbf_dot_XX(X, width):
+ """Compute the RBF kernel Gram matrix for the special case Y = X."""
+ G = np.sum(X * X, axis=1)
+ H = G[None, :] + G[:, None] - 2 * np.dot(X, X.T)
+ return np.exp(-H / 2 / (width ** 2))

def get_gram_matrix(X, width):
"""Get the centered gram matrices.
@@ -76,11 +79,13 @@ def get_gram_matrix(X, width):
the centered gram matrices.
"""
n = X.shape[0]
- H = np.eye(n) - 1 / n * np.ones((n, n))

- K = _rbf_dot(X, X, width)
- Kc = np.dot(np.dot(H, K), H)

+ K = _rbf_dot_XX(X, width)
+ K_colsums = K.sum(axis=0)
+ K_rowsums = K.sum(axis=1)
+ K_allsum = K_rowsums.sum()
+ Kc = K - (K_colsums[None, :] + K_rowsums[:, None]) / n + np.ones((n, n)) * (K_allsum / n ** 2)
+ # equivalent to H @ K @ H, where H = np.eye(n) - 1 / n * np.ones((n, n)).
return K, Kc
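
For reference, the two rewrites above can be checked numerically: _rbf_dot_XX(X, width) should reproduce _rbf_dot(X, X, width), and the broadcast centering should reproduce the explicit H @ K @ H product. A minimal sketch, not part of the diff, assuming the merged module is importable:

import numpy as np
from causallearn.search.FCMBased.lingam.hsic import _rbf_dot, _rbf_dot_XX, get_kernel_width

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 1))
width = get_kernel_width(X)

# the specialised kernel agrees with the generic one evaluated at (X, X)
K = _rbf_dot_XX(X, width)
assert np.allclose(_rbf_dot(X, X, width), K)

# broadcast centering agrees with H @ K @ H
n = X.shape[0]
H = np.eye(n) - np.ones((n, n)) / n
Kc_broadcast = (K - (K.sum(axis=0)[None, :] + K.sum(axis=1)[:, None]) / n
                + np.ones((n, n)) * (K.sum() / n ** 2))
assert np.allclose(H @ K @ H, Kc_broadcast)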


@@ -101,7 +106,7 @@ def hsic_teststat(Kc, Lc, n):
the HSIC statistic.
"""
# test statistic m*HSICb under H1
- return 1 / n * np.sum(np.sum(Kc.T * Lc))
+ return 1 / n * np.sum(Kc.T * Lc)


def hsic_test_gamma(X, Y, bw_method='mdbs'):
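
On the hsic_teststat change above: Kc.T * Lc is a single 2-D array, so one np.sum already reduces it to a scalar and the inner np.sum was redundant. The value is the scaled biased statistic m * HSIC_b referred to in the comment, which equals trace(Kc @ Lc) / n. A tiny check of that identity, illustration only and not part of the PR:

import numpy as np

rng = np.random.default_rng(1)
n = 50
Kc = rng.normal(size=(n, n))
Lc = rng.normal(size=(n, n))

# the sum of the elementwise product of Kc.T and Lc is the trace of Kc @ Lc
assert np.isclose(np.sum(Kc.T * Lc) / n, np.trace(Kc @ Lc) / n)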
@@ -148,25 +153,47 @@ def hsic_test_gamma(X, Y, bw_method='mdbs'):

# test statistic m*HSICb under H1
n = X.shape[0]
- bone = np.ones((n, 1))
test_stat = hsic_teststat(Kc, Lc, n)

var = (1 / 6 * Kc * Lc) ** 2
# second subtracted term is bias correction
- var = 1 / n / (n - 1) * (np.sum(np.sum(var)) - np.sum(np.diag(var)))
+ var = 1 / n / (n - 1) * (np.sum(var) - np.trace(var))
# variance under H0
var = 72 * (n - 4) * (n - 5) / n / (n - 1) / (n - 2) / (n - 3) * var

- K = K - np.diag(np.diag(K))
- L = L - np.diag(np.diag(L))
- mu_X = 1 / n / (n - 1) * np.dot(bone.T, np.dot(K, bone))
- mu_Y = 1 / n / (n - 1) * np.dot(bone.T, np.dot(L, bone))
+ K[np.diag_indices(n)] = 0
+ L[np.diag_indices(n)] = 0
+ mu_X = 1 / n / (n - 1) * K.sum()
+ mu_Y = 1 / n / (n - 1) * L.sum()
# mean under H0
mean = 1 / n * (1 + mu_X * mu_Y - mu_X - mu_Y)

alpha = mean ** 2 / var
# threshold for hsicArr*m
- beta = np.dot(var, n) / mean
- p = 1 - gamma.cdf(test_stat, alpha, scale=beta)[0][0]
+ beta = var * n / mean
+ p = 1 - gamma.cdf(test_stat, alpha, scale=beta)

return test_stat, p


+ if __name__ == '__main__':
+ X = np.random.uniform(0, 1, (15000,))
+ Y = X ** 2 + np.random.uniform(0, 1, (15000,))
+ tic = time.time()
+ test_stat, p = hsic_test_gamma(X, Y)
+ print(f'now used: {time.time() - tic: .5f}s')

+ from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma as hsic_test_gamma_old
+ tic = time.time()
+ test_stat_old, p_old = hsic_test_gamma_old(X, Y)
+ print(f'originally used: {time.time() - tic: .5f}s')

+ assert np.isclose(test_stat, test_stat_old)
+ assert np.isclose(p, p_old)
+ print('equivalent test passed.')

+ '''
+ now used: 6.78904s
+ originally used: 65.28648s
+ equivalent test passed.
+ '''
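
For context on the statistics that the PR leaves unchanged: hsic_test_gamma uses the standard two-moment gamma approximation to the null distribution of the scaled statistic. With mean and var the H0 mean and variance of HSIC_b computed in the function, moment matching gives shape alpha = mean ** 2 / var and scale beta = n * var / mean, and the p-value is the upper gamma tail at the observed statistic; since mu_X and mu_Y are now plain scalars, the old [0][0] indexing on the gamma CDF result is no longer needed. A minimal restatement of that final step, with a hypothetical helper name and gamma.sf (the survival function, i.e. 1 - cdf) in place of the explicit subtraction:

from scipy.stats import gamma

def hsic_gamma_pvalue(test_stat, mean, var, n):
    # Upper-tail p-value of the moment-matched gamma approximation (sketch only).
    alpha = mean ** 2 / var   # shape parameter from matching the first two moments of n * HSIC_b
    beta = var * n / mean     # scale parameter, as computed in hsic_test_gamma above
    return gamma.sf(test_stat, alpha, scale=beta)  # equivalent to 1 - gamma.cdf(...)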