diff --git a/causallearn/search/HiddenCausal/GIN/GIN.py b/causallearn/search/HiddenCausal/GIN/GIN.py
index e6592c38..17f31b20 100644
--- a/causallearn/search/HiddenCausal/GIN/GIN.py
+++ b/causallearn/search/HiddenCausal/GIN/GIN.py
@@ -1,11 +1,3 @@
-'''
-    File name: GIN.py
-    Discription: Learning Hidden Causal Representation with GIN condition
-    Author: ZhiyiHuang@DMIRLab, RuichuCai@DMIRLab
-    From DMIRLab: https://dmir.gdut.edu.cn/
-'''
-
-import random
 from collections import deque
 from itertools import combinations
 
@@ -18,7 +10,7 @@
 from causallearn.graph.Edge import Edge
 from causallearn.graph.Endpoint import Endpoint
 from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma
-from causallearn.utils.cit import kci
+from causallearn.utils.KCI.KCI import KCI_UInd
 
 
 def fisher_test(pvals):
@@ -27,29 +19,42 @@
     return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
 
 
-def GIN(data, indep_test=kci, alpha=0.05):
+def GIN(data, indep_test_method='kci', alpha=0.05):
     '''
     Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
     with Generalized Independent Noise Condition
-
     Parameters
     ----------
     data : numpy ndarray
            data set
-    indep_test : callable, default=kci
-        the function of the independence test being used
+    indep_test_method : str, default='kci'
+        the name of the independence test being used
    alpha : float, default=0.05
        desired significance level of independence tests (p_value) in (0,1)
 
    Returns
    -------
    G : general graph
        causal graph
-    K : list
+    causal_order : list
        causal order
    '''
    n = data.shape[1]
    cov = np.cov(data.T)
 
+    if indep_test_method == 'kci':
+        kci = KCI_UInd()
+
+    if indep_test_method not in ['kci', 'hsic']:
+        raise NotImplementedError((f"Independent test method {indep_test_method} is not implemented."))
+
+    def indep_test(x, y, method):
+        if method == 'kci':
+            return kci.compute_pvalue(x, y)[0]
+        elif method == 'hsic':
+            return hsic_test_gamma(x, y)[1]
+        else:
+            raise NotImplementedError((f"Independent test method {indep_test_method} is not implemented."))
+
    var_set = set(range(n))
    cluster_size = 2
    clusters_list = []
@@ -59,9 +64,8 @@ def GIN(data, indep_test=kci, alpha=0.05):
             remain_var_set = var_set - set(cluster)
             e = cal_e_with_gin(data, cov, list(cluster), list(remain_var_set))
             pvals = []
-            tmp_data = np.concatenate([data[:, list(remain_var_set)], e.reshape(-1, 1)], axis=1)
             for z in range(len(remain_var_set)):
-                pvals.append(indep_test(tmp_data, z, - 1))
+                pvals.append(indep_test(data[:, [z]], e[:, None], method=indep_test_method))
             fisher_pval = fisher_test(pvals)
             if fisher_pval >= alpha:
                 tmp_clusters_list.append(cluster)
@@ -71,37 +75,34 @@ def GIN(data, indep_test=kci, alpha=0.05):
             var_set -= set(cluster)
         cluster_size += 1
 
-    K = []
+    causal_order = []  # this variable corresponds to K in paper
     updated = True
     while updated:
         updated = False
         X = []
         Z = []
-        for cluster_k in K:
+        for cluster_k in causal_order:
             cluster_k1, cluster_k2 = array_split(cluster_k, 2)
             X += cluster_k1
             Z += cluster_k2
 
         for i, cluster_i in enumerate(clusters_list):
             is_root = True
-            random.shuffle(cluster_i)
             cluster_i1, cluster_i2 = array_split(cluster_i, 2)
             for j, cluster_j in enumerate(clusters_list):
                 if i == j:
                     continue
-                random.shuffle(cluster_j)
                 cluster_j1, cluster_j2 = array_split(cluster_j, 2)
                 e = cal_e_with_gin(data, cov, X + cluster_i1 + cluster_j1, Z + cluster_i2)
                 pvals = []
-                tmp_data = np.concatenate([data[:, Z + cluster_i2], e.reshape(-1, 1)], axis=1)
                 for z in range(len(Z + cluster_i2)):
-                    pvals.append(indep_test(tmp_data, z, - 1))
+                    pvals.append(indep_test(data[:, [z]], e[:, None], method=indep_test_method))
                 fisher_pval = fisher_test(pvals)
                 if fisher_pval < alpha:
                     is_root = False
                     break
             if is_root:
-                K.append(cluster_i)
+                causal_order.append(cluster_i)
                 clusters_list.remove(cluster_i)
                 updated = True
                 break
@@ -114,7 +115,7 @@ def GIN(data, indep_test=kci, alpha=0.05):
     latent_id = 1
     l_nodes = []
 
-    for cluster in K:
+    for cluster in causal_order:
         l_node = GraphNode(f"L{latent_id}")
         l_node.set_node_type(NodeType.LATENT)
         G.add_node(l_node)
@@ -148,7 +149,7 @@ def GIN(data, indep_test=kci, alpha=0.05):
             G.add_directed_edge(l_node, o_node)
         latent_id += 1
 
-    return G, K
+    return G, causal_order
 
 
 def GIN_MI(data):
@@ -165,7 +166,7 @@
     -------
     G : general graph
         causal graph
-    K : list
+    causal_order : list
        causal order
     '''
     v_labels = list(range(data.shape[1]))
@@ -190,16 +191,16 @@ def GIN_MI(data):
     cluster_list = merge_overlaping_cluster(cluster_list)
 
     # Step 2: Learning the Causal Order of Latent Variables
-    K = []
+    causal_order = []  # this variable corresponds to K in paper
     while (len(cluster_list) != 0):
-        root = find_root(data, cov, cluster_list, K)
-        K.append(root)
+        root = find_root(data, cov, cluster_list, causal_order)
+        causal_order.append(root)
         cluster_list.remove(root)
 
     latent_id = 1
     l_nodes = []
     G = GeneralGraph([])
-    for cluster in K:
+    for cluster in causal_order:
         l_node = GraphNode(f"L{latent_id}")
         l_node.set_node_type(NodeType.LATENT)
         l_nodes.append(l_node)
@@ -213,14 +214,14 @@ def GIN_MI(data):
             G.add_directed_edge(l_node, o_node)
         latent_id += 1
 
-    return G, K
+    return G, causal_order
 
 
 def cal_e_with_gin(data, cov, X, Z):
     cov_m = cov[np.ix_(Z, X)]
     _, _, v = np.linalg.svd(cov_m)
     omega = v.T[:, -1]
-    return np.dot(omega, data[:, X].T)
+    return np.dot(data[:, X], omega)
 
 
 def cal_dep_for_gin(data, cov, X, Z):
@@ -248,17 +249,15 @@
     return sta
 
 
-def find_root(data, cov, clusters, K):
+def find_root(data, cov, clusters, causal_order):
     '''
     Find the causal order by statistics of dependence
-
     Parameters
     ----------
     data : data set (numpy ndarray)
     cov : covariance matrix
     clusters : clusters of observed variables
-    K : causal order
-
+    causal_order : causal order
     Returns
     -------
     root : latent root cause
     '''
@@ -276,8 +275,8 @@
         for k in range(1, len(i)):
             Z.append(i[k])
 
-        if K:
-            for k in K:
+        if causal_order:
+            for k in causal_order:
                 X.append(k[0])
                 Z.append(k[1])
 
@@ -355,4 +354,4 @@
     for i in range(k - extra):
         sub_arys.append(x[start:start + section_len])
         start = start + section_len
-    return sub_arys
+    return sub_arys
\ No newline at end of file
diff --git a/tests/TestGIN.py b/tests/TestGIN.py
index aac71024..16c65f37 100644
--- a/tests/TestGIN.py
+++ b/tests/TestGIN.py
@@ -1,104 +1,88 @@
 import random
-import sys
-import io
-
-sys.path.append("")
 import unittest
 
 import numpy as np
-import matplotlib.image as mpimg
-import matplotlib.pyplot as plt
 
 from causallearn.search.HiddenCausal.GIN.GIN import GIN
 
 
 class TestGIN(unittest.TestCase):
-    def test_case1(self):
-        sample_size = 1000
-        np.random.seed(0)
-        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
-        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X3 = np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X4 = np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+    indep_test_methods = ['kci', 'hsic']
+    def test_case1(self):
+        sample_size = 500
+        random.seed(42)
+        np.random.seed(42)
+        L1 = np.random.uniform(-1, 1, size=sample_size)
+        L2 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(-1, 1, size=sample_size)
+        X1 = np.random.uniform(1.2, 1.8) * L1 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X2 = np.random.uniform(1.2, 1.8) * L1 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X3 = np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X4 = np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
         data = np.array([X1, X2, X3, X4]).T
         data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
-        g, k = GIN(data)
-        print(g, k)
-
-        # Visualization using pydot
-        from causallearn.utils.GraphUtils import GraphUtils
-        pyd = GraphUtils.to_pydot(g)
-        tmp_png = pyd.create_png(f="png")
-        fp = io.BytesIO(tmp_png)
-        img = mpimg.imread(fp, format='png')
-        plt.axis('off')
-        plt.imshow(img)
-        plt.show()
 
-    def test_case2(self):
-        sample_size = 1000
-        np.random.seed(0)
-        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
-        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        L3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1,
-                                                                                                     size=sample_size) ** 5
-        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X4 = np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X5 = np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X6 = np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X7 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X8 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X9 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        ground_truth = [[0, 1], [2, 3]]
 
-        data = np.array([X1, X2, X3, X4, X5, X6, X7, X8, X9]).T
-        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
-        g, k = GIN(data)
-        print(g, k)
-
-        # Visualization using pydot
-        from causallearn.utils.GraphUtils import GraphUtils
-        pyd = GraphUtils.to_pydot(g)
-        tmp_png = pyd.create_png(f="png")
-        fp = io.BytesIO(tmp_png)
-        img = mpimg.imread(fp, format='png')
-        plt.axis('off')
-        plt.imshow(img)
-        plt.show()
+        TestGIN.run_gin_test(data, ground_truth, 0.05)
 
-    def test_case3(self):
-        sample_size = 1000
+
+    def test_case2(self):
+        sample_size = 500
         random.seed(42)
         np.random.seed(42)
-        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
-        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        L3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        L4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
-
-        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1,
-                                                                                                     size=sample_size) ** 5
-        X3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X5 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X6 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X7 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
-        X8 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        L1 = np.random.uniform(-1, 1, size=sample_size)
+        L2 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(-1, 1, size=sample_size)
+        L3 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(1.2, 1.8) * L2 + np.random.uniform(-1, 1, size=sample_size)
+        X1 = np.random.uniform(1.2, 1.8) * L1 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X2 = np.random.uniform(1.2, 1.8) * L1 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X3 = np.random.uniform(1.2, 1.8) * L1 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X4 = np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X5 = np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X6 = np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X7 = np.random.uniform(1.2, 1.8) * L3 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X8 = np.random.uniform(1.2, 1.8) * L3 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X9 = np.random.uniform(1.2, 1.8) * L3 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        data = np.array([X1, X2, X3, X4, X5, X6, X7, X8, X9]).T
+        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
+
+        ground_truth = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+
+        TestGIN.run_gin_test(data, ground_truth, 0.05)
+
+    def test_case3(self):
+        sample_size = 500
+        random.seed(0)
+        np.random.seed(0)
+        L1 = np.random.uniform(-1, 1, size=sample_size)
+        L2 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(-1, 1, size=sample_size)
+        L3 = np.random.uniform(0.5, 0.8) * L1 + np.random.uniform(0.5, 0.8) * L2 + np.random.uniform(-1, 1, size=sample_size)
+        L4 = np.random.uniform(0.5, 0.8) * L1 + np.random.uniform(0.5, 0.8) * L2 + np.random.uniform(1.2, 1.8) * L3 + np.random.uniform(-1, 1, size=sample_size)
+        X1 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X2 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X3 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X4 = np.random.uniform(1.2, 1.8) * L1 + np.random.uniform(1.2, 1.8) * L2 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X5 = np.random.uniform(1.2, 1.8) * L3 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X6 = np.random.uniform(1.2, 1.8) * L3 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X7 = np.random.uniform(1.2, 1.8) * L4 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
+        X8 = np.random.uniform(1.2, 1.8) * L4 + 0.2 * np.random.uniform(-1, 1, size=sample_size)
         data = np.array([X1, X2, X3, X4, X5, X6, X7, X8]).T
         data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
-        g, k = GIN(data)
-        print(g, k)
-
-        # Visualization using pydot
-        from causallearn.utils.GraphUtils import GraphUtils
-        pyd = GraphUtils.to_pydot(g)
-        tmp_png = pyd.create_png(f="png")
-        fp = io.BytesIO(tmp_png)
-        img = mpimg.imread(fp, format='png')
-        plt.axis('off')
-        plt.imshow(img)
-        plt.show()
\ No newline at end of file
+
+        ground_truth = [[0, 1, 2, 3], [4, 5], [6, 7]]
+
+        TestGIN.run_gin_test(data, ground_truth, 0.05)
+
+    @staticmethod
+    def run_gin_test(data, ground_truth, alpha):
+        for indep_test_method in TestGIN.indep_test_methods:
+            _, causal_order = GIN(data, indep_test_method=indep_test_method, alpha=alpha)
+            causal_order = [sorted(cluster_i) for cluster_i in causal_order]
+            TestGIN.validate_result(ground_truth, causal_order)
+
+    @staticmethod
+    def validate_result(ground_truth, estimated_result):
+        assert len(ground_truth) == len(estimated_result)
+        for i in range(len(estimated_result)):
+            assert estimated_result[i] == ground_truth[i]
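
Usage sketch (not part of the patch): a minimal, hypothetical example of calling the patched GIN() entry point with the new indep_test_method string argument. Only GIN, indep_test_method, alpha, and the (G, causal_order) return pair come from the change above; the toy data-generating process below is an illustrative assumption that mirrors the style of the updated tests.

    import numpy as np
    from causallearn.search.HiddenCausal.GIN.GIN import GIN

    # Illustrative two-latent, four-observed linear non-Gaussian model (assumed, not from the patch).
    rng = np.random.default_rng(0)
    n = 500
    L1 = rng.uniform(-1, 1, size=n)
    L2 = 1.5 * L1 + rng.uniform(-1, 1, size=n)
    data = np.stack([1.4 * L1 + 0.2 * rng.uniform(-1, 1, size=n),
                     1.6 * L1 + 0.2 * rng.uniform(-1, 1, size=n),
                     1.3 * L2 + 0.2 * rng.uniform(-1, 1, size=n),
                     1.7 * L2 + 0.2 * rng.uniform(-1, 1, size=n)], axis=1)
    data = (data - data.mean(axis=0)) / data.std(axis=0)  # standardize, as the updated tests do

    # indep_test_method accepts 'kci' (default) or 'hsic' after this change.
    G, causal_order = GIN(data, indep_test_method='hsic', alpha=0.05)
    print(G)             # GeneralGraph with latent nodes and their observed children
    print(causal_order)  # clusters of observed-variable indices in estimated causal order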