From 5dc96f2f54c99864a831e1557a7cf849f3aa61dd Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Wed, 27 Apr 2022 02:20:57 -0400 Subject: [PATCH] slight modification for speedup --- causallearn/search/FCMBased/lingam/hsic.py | 63 +++++++++++++++------- 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/causallearn/search/FCMBased/lingam/hsic.py b/causallearn/search/FCMBased/lingam/hsic.py index a10828d0..cdf49678 100644 --- a/causallearn/search/FCMBased/lingam/hsic.py +++ b/causallearn/search/FCMBased/lingam/hsic.py @@ -1,7 +1,9 @@ """ Python implementation of the LiNGAM algorithms. + * some slight modification for speedup, 04/26/2022 The LiNGAM Project: https://sites.google.com/site/sshimizu06/lingam """ +import time import numpy as np from scipy.stats import gamma @@ -34,16 +36,12 @@ def get_kernel_width(X): X_med = X G = np.sum(X_med * X_med, 1).reshape(n_samples, 1) - Q = np.tile(G, (1, n_samples)) - R = np.tile(G.T, (n_samples, 1)) - - dists = Q + R - 2 * np.dot(X_med, X_med.T) + dists = G + G.T - 2 * np.dot(X_med, X_med.T) dists = dists - np.tril(dists) dists = dists.reshape(n_samples ** 2, 1) return np.sqrt(0.5 * np.median(dists[dists > 0])) - def _rbf_dot(X, Y, width): """Compute the inner product of radial basis functions.""" n_samples_X = X.shape[0] @@ -57,6 +55,11 @@ def _rbf_dot(X, Y, width): return np.exp(-H / 2 / (width ** 2)) +def _rbf_dot_XX(X, width): + """rbf dot, in special case with X dot X""" + G = np.sum(X * X, axis=1) + H = G[None, :] + G[:, None] - 2 * np.dot(X, X.T) + return np.exp(-H / 2 / (width ** 2)) def get_gram_matrix(X, width): """Get the centered gram matrices. @@ -76,11 +79,13 @@ def get_gram_matrix(X, width): the centered gram matrices. """ n = X.shape[0] - H = np.eye(n) - 1 / n * np.ones((n, n)) - - K = _rbf_dot(X, X, width) - Kc = np.dot(np.dot(H, K), H) + K = _rbf_dot_XX(X, width) + K_colsums = K.sum(axis=0) + K_rowsums = K.sum(axis=1) + K_allsum = K_rowsums.sum() + Kc = K - (K_colsums[None, :] + K_rowsums[:, None]) / n + np.ones((n, n)) * (K_allsum / n ** 2) + # equivalent to H @ K @ H, where H = np.eye(n) - 1 / n * np.ones((n, n)). return K, Kc @@ -101,7 +106,7 @@ def hsic_teststat(Kc, Lc, n): the HSIC statistic. """ # test statistic m*HSICb under H1 - return 1 / n * np.sum(np.sum(Kc.T * Lc)) + return 1 / n * np.sum(Kc.T * Lc) def hsic_test_gamma(X, Y, bw_method='mdbs'): @@ -148,25 +153,47 @@ def hsic_test_gamma(X, Y, bw_method='mdbs'): # test statistic m*HSICb under H1 n = X.shape[0] - bone = np.ones((n, 1)) test_stat = hsic_teststat(Kc, Lc, n) var = (1 / 6 * Kc * Lc) ** 2 # second subtracted term is bias correction - var = 1 / n / (n - 1) * (np.sum(np.sum(var)) - np.sum(np.diag(var))) + var = 1 / n / (n - 1) * (np.sum(var) - np.trace(var)) # variance under H0 var = 72 * (n - 4) * (n - 5) / n / (n - 1) / (n - 2) / (n - 3) * var - K = K - np.diag(np.diag(K)) - L = L - np.diag(np.diag(L)) - mu_X = 1 / n / (n - 1) * np.dot(bone.T, np.dot(K, bone)) - mu_Y = 1 / n / (n - 1) * np.dot(bone.T, np.dot(L, bone)) + K[np.diag_indices(n)] = 0 + L[np.diag_indices(n)] = 0 + mu_X = 1 / n / (n - 1) * K.sum() + mu_Y = 1 / n / (n - 1) * L.sum() # mean under H0 mean = 1 / n * (1 + mu_X * mu_Y - mu_X - mu_Y) alpha = mean ** 2 / var # threshold for hsicArr*m - beta = np.dot(var, n) / mean - p = 1 - gamma.cdf(test_stat, alpha, scale=beta)[0][0] + beta = var * n / mean + p = 1 - gamma.cdf(test_stat, alpha, scale=beta) return test_stat, p + + +if __name__ == '__main__': + X = np.random.uniform(0, 1, (15000,)) + Y = X ** 2 + np.random.uniform(0, 1, (15000,)) + tic = time.time() + test_stat, p = hsic_test_gamma(X, Y) + print(f'now used: {time.time() - tic: .5f}s') + + from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma as hsic_test_gamma_old + tic = time.time() + test_stat_old, p_old = hsic_test_gamma_old(X, Y) + print(f'originally used: {time.time() - tic: .5f}s') + + assert np.isclose(test_stat, test_stat_old) + assert np.isclose(p, p_old) + print('equivalent test passed.') + + ''' + now used: 6.78904s + originally used: 65.28648s + equivalent test passed. + ''' \ No newline at end of file