diff --git a/causallearn/search/HiddenCausal/GIN/GIN.py b/causallearn/search/HiddenCausal/GIN/GIN.py
index a609982d..b8bbc1d5 100644
--- a/causallearn/search/HiddenCausal/GIN/GIN.py
+++ b/causallearn/search/HiddenCausal/GIN/GIN.py
@@ -5,18 +5,154 @@ From DMIRLab: https://dmir.gdut.edu.cn/
 '''
 
+import random
 from collections import deque
 from itertools import combinations
 
 import numpy as np
+from scipy.stats import chi2
 
 from causallearn.graph.GeneralGraph import GeneralGraph
 from causallearn.graph.GraphNode import GraphNode
 from causallearn.graph.NodeType import NodeType
+from causallearn.graph.Edge import Edge
+from causallearn.graph.Endpoint import Endpoint
 from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma
+from causallearn.utils.cit import kci
 
 
-def GIN(data):
+def fisher_test(pvals):
+    pvals = [pval if pval >= 1e-5 else 1e-5 for pval in pvals]
+    return min(pvals)
+    # fisher_stat = -2.0 * np.sum(np.log(pvals))
+    # return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
+
+
+def GIN(data, indep_test=kci, alpha=0.05):
+    '''
+    Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
+    with Generalized Independent Noise Condition
+
+    Parameters
+    ----------
+    data : numpy ndarray
+        data set
+    indep_test : callable, default=kci
+        the function of the independence test being used
+    alpha : float, default=0.05
+        desired significance level of independence tests (p_value) in (0,1)
+
+    Returns
+    -------
+    G : general graph
+        causal graph
+    K : list
+        causal order
+    '''
+    n = data.shape[1]
+    cov = np.cov(data.T)
+
+    var_set = set(range(n))
+    cluster_size = 2
+    clusters_list = []
+    while cluster_size < len(var_set):
+        tmp_clusters_list = []
+        for cluster in combinations(var_set, cluster_size):
+            remain_var_set = var_set - set(cluster)
+            e = cal_e_with_gin(data, cov, list(cluster), list(remain_var_set))
+            pvals = []
+            tmp_data = np.concatenate([data[:, list(remain_var_set)], e.reshape(-1, 1)], axis=1)
+            for z in range(len(remain_var_set)):
+                pvals.append(indep_test(tmp_data, z, -1))
+            fisher_pval = fisher_test(pvals)
+            if fisher_pval >= alpha:
+                tmp_clusters_list.append(cluster)
+        tmp_clusters_list = merge_overlaping_cluster(tmp_clusters_list)
+        clusters_list = clusters_list + tmp_clusters_list
+        for cluster in tmp_clusters_list:
+            var_set -= set(cluster)
+        cluster_size += 1
+
+    K = []
+    updated = True
+    while updated:
+        updated = False
+        X = []
+        Z = []
+        for cluster_k in K:
+            cluster_k1, cluster_k2 = array_split(cluster_k, 2)
+            X += cluster_k1
+            Z += cluster_k2
+
+        for i, cluster_i in enumerate(clusters_list):
+            is_root = True
+            random.shuffle(cluster_i)
+            cluster_i1, cluster_i2 = array_split(cluster_i, 2)
+            for j, cluster_j in enumerate(clusters_list):
+                if i == j:
+                    continue
+                random.shuffle(cluster_j)
+                cluster_j1, cluster_j2 = array_split(cluster_j, 2)
+                e = cal_e_with_gin(data, cov, X + cluster_i1 + cluster_j1, Z + cluster_i2)
+                pvals = []
+                tmp_data = np.concatenate([data[:, Z + cluster_i2], e.reshape(-1, 1)], axis=1)
+                for z in range(len(Z + cluster_i2)):
+                    pvals.append(indep_test(tmp_data, z, -1))
+                fisher_pval = fisher_test(pvals)
+                if fisher_pval < alpha:
+                    is_root = False
+                    break
+            if is_root:
+                K.append(cluster_i)
+                clusters_list.remove(cluster_i)
+                updated = True
+                break
+
+    G = GeneralGraph([])
+    for var in var_set:
+        o_node = GraphNode(f"X{var + 1}")
+        G.add_node(o_node)
+
+    latent_id = 1
+    l_nodes = []
+
+    for cluster in K:
+        l_node = GraphNode(f"L{latent_id}")
+        l_node.set_node_type(NodeType.LATENT)
+        G.add_node(l_node)
+        for l in l_nodes:
+            G.add_directed_edge(l, l_node)
+        l_nodes.append(l_node)
+
+        for o in cluster:
+            o_node = GraphNode(f"X{o + 1}")
+            G.add_node(o_node)
+            G.add_directed_edge(l_node, o_node)
+        latent_id += 1
+
+    undirected_l_nodes = []
+
+    for cluster in clusters_list:
+        l_node = GraphNode(f"L{latent_id}")
+        l_node.set_node_type(NodeType.LATENT)
+        G.add_node(l_node)
+        for l in l_nodes:
+            G.add_directed_edge(l, l_node)
+
+        for l in undirected_l_nodes:
+            G.add_edge(Edge(l, l_node, Endpoint.TAIL, Endpoint.TAIL))
+
+        undirected_l_nodes.append(l_node)
+
+        for o in cluster:
+            o_node = GraphNode(f"X{o + 1}")
+            G.add_node(o_node)
+            G.add_directed_edge(l_node, o_node)
+        latent_id += 1
+
+    return G, K
+
+
+def GIN_MI(data):
     '''
     Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
     with Generalized Independent Noise Condition
@@ -81,6 +217,13 @@ def GIN(data):
     return G, K
 
 
+def cal_e_with_gin(data, cov, X, Z):
+    cov_m = cov[np.ix_(Z, X)]
+    _, _, v = np.linalg.svd(cov_m)
+    omega = v.T[:, -1]
+    return np.dot(omega, data[:, X].T)
+
+
 def cal_dep_for_gin(data, cov, X, Z):
     '''
     Calculate the statistics of dependence via Generalized Independent Noise Condition
@@ -96,10 +239,8 @@ def cal_dep_for_gin(data, cov, X, Z):
     -------
     sta : test statistic
     '''
-    cov_m = cov[np.ix_(Z, X)]
-    _, _, v = np.linalg.svd(cov_m)
-    omega = v.T[:, -1]
-    e_xz = np.dot(omega, data[:, X].T)
+
+    e_xz = cal_e_with_gin(data, cov, X, Z)
 
     sta = 0
     for i in Z:
@@ -160,6 +301,8 @@ def _get_all_elements(S):
 # merging cluster
 def merge_overlaping_cluster(cluster_list):
     v_labels = _get_all_elements(cluster_list)
+    if len(v_labels) == 0:
+        return []
     cluster_dict = {i: -1 for i in v_labels}
     cluster_b = {i: [] for i in v_labels}
     cluster_len = 0
@@ -197,3 +340,20 @@ def merge_overlaping_cluster(cluster_list):
             cluster[cluster_dict[i]].append(i)
 
     return cluster
+
+
+def array_split(x, k):
+    x_len = len(x)
+    # div_points = []
+    sub_arys = []
+    start = 0
+    section_len = x_len // k
+    extra = x_len % k
+    for i in range(extra):
+        sub_arys.append(x[start:start + section_len + 1])
+        start = start + section_len + 1
+
+    for i in range(k - extra):
+        sub_arys.append(x[start:start + section_len])
+        start = start + section_len
+    return sub_arys
diff --git a/tests/TestGIN.py b/tests/TestGIN.py
index c1b702b5..f8617fea 100644
--- a/tests/TestGIN.py
+++ b/tests/TestGIN.py
@@ -1,3 +1,4 @@
+import random
 import sys
 
 sys.path.append("")
@@ -45,3 +46,26 @@ def test_case2(self):
         data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
         g, k = GIN(data)
         print(g, k)
+
+    def test_case3(self):
+        sample_size = 1000
+        random.seed(42)
+        np.random.seed(42)
+        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
+        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        L3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        L4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+
+        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X5 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X6 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X7 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X8 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
+
+        data = np.array([X1, X2, X3, X4, X5, X6, X7, X8]).T
+        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
+        g, k = GIN(data)
+        print(g, k)
\ No newline at end of file
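
Note on the API change: the previous `GIN(data)` implementation is kept verbatim under the new name `GIN_MI`, while the new `GIN(data, indep_test=kci, alpha=0.05)` tests the GIN condition with a pluggable conditional independence test and aggregates the per-variable p-values through `fisher_test` (which, as written, returns the minimum p-value; the chi-square Fisher combination is left commented out). Below is a minimal usage sketch of the new signature; it assumes this branch is installed, and the seed, coefficients, and variable names are illustrative rather than taken from the test suite.

import numpy as np

from causallearn.search.HiddenCausal.GIN.GIN import GIN

# Two latent confounders, each with two pure measurements. Raising uniform
# noise to the fifth power keeps the model non-Gaussian, as GIN requires.
sample_size = 1000
np.random.seed(0)
L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
L2 = 0.8 * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
data = np.stack([
    1.2 * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5,  # X1
    0.7 * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5,  # X2
    1.5 * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5,  # X3
    0.9 * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5,  # X4
], axis=1)
data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)

g, k = GIN(data)  # defaults: indep_test=kci, alpha=0.05
# With enough samples the expected result is the clustering {X1, X2} and
# {X3, X4}, one latent parent per cluster, and a causal order k over them.
print(g, k)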
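Both variants now share `cal_e_with_gin`, the heart of the GIN condition: for a candidate cluster X and remaining variables Z it returns the surrogate noise E_{X||Z} = omega^T X, where omega is the right singular vector of cov(Z, X) associated with the smallest singular value, i.e. the least-squares solution of E[Z X^T] omega = 0. The independence tests then check that this residual is independent of every variable in Z. A standalone restatement of the same computation (the function name is hypothetical, not part of the PR):

import numpy as np

def gin_residual(data, x_idx, z_idx):
    # Cross-covariance cov(Z, X): rows indexed by Z, columns by X.
    cov_zx = np.cov(data.T)[np.ix_(z_idx, x_idx)]
    # np.linalg.svd returns V^T; its last row is the right singular vector
    # with the smallest singular value, so cov_zx @ omega is close to zero.
    _, _, vt = np.linalg.svd(cov_zx)
    omega = vt[-1, :]
    # One residual value per sample: E_{X||Z} = omega^T X.
    return data[:, x_idx] @ omega

The PR's `omega = v.T[:, -1]` (last column of V) and `vt[-1, :]` here (last row of V^T) select the same vector.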
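Finally, the new `array_split` helper mirrors `numpy.array_split` for plain Python lists, splitting as evenly as possible with the longer pieces first; for example, `array_split([1, 2, 3, 4, 5], 2)` returns `[[1, 2, 3], [4, 5]]`. The root-finding loop uses it to halve each shuffled cluster into the X-side and Z-side arguments of `cal_e_with_gin`.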