170 changes: 165 additions & 5 deletions causallearn/search/HiddenCausal/GIN/GIN.py
@@ -5,18 +5,154 @@
From DMIRLab: https://dmir.gdut.edu.cn/
'''

import random
from collections import deque
from itertools import combinations

import numpy as np
from scipy.stats import chi2

from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from causallearn.graph.NodeType import NodeType
from causallearn.graph.Edge import Edge
from causallearn.graph.Endpoint import Endpoint
from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma
from causallearn.utils.cit import kci


def GIN(data):
def fisher_test(pvals):
    pvals = [pval if pval >= 1e-5 else 1e-5 for pval in pvals]
    return min(pvals)
    # fisher_stat = -2.0 * np.sum(np.log(pvals))
    # return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
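Note that despite its name, fisher_test as committed returns the minimum p-value (floored at 1e-5); the commented-out lines correspond to the actual Fisher combination. A minimal sketch of that combination, outside this diff and purely for reference:

import numpy as np
from scipy.stats import chi2

def fisher_combine(pvals):
    # Fisher's method: -2 * sum(log p_i) follows a chi-square
    # distribution with 2k degrees of freedom under the null.
    pvals = [max(p, 1e-5) for p in pvals]  # floor to avoid log(0)
    fisher_stat = -2.0 * np.sum(np.log(pvals))
    return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))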


def GIN(data, indep_test=kci, alpha=0.05):
    '''
    Learn the causal structure among latent variables of a linear non-Gaussian
    latent variable model with the Generalized Independent Noise (GIN) condition.

    Parameters
    ----------
    data : numpy ndarray
        data set
    indep_test : callable, default=kci
        the independence test function to use
    alpha : float, default=0.05
        desired significance level of the independence tests (p-value), in (0, 1)

    Returns
    -------
    G : general graph
        causal graph
    K : list
        causal order
    '''
    n = data.shape[1]
    cov = np.cov(data.T)

    var_set = set(range(n))
    cluster_size = 2
    clusters_list = []
    while cluster_size < len(var_set):
        tmp_clusters_list = []
        for cluster in combinations(var_set, cluster_size):
            remain_var_set = var_set - set(cluster)
            e = cal_e_with_gin(data, cov, list(cluster), list(remain_var_set))
            pvals = []
            tmp_data = np.concatenate([data[:, list(remain_var_set)], e.reshape(-1, 1)], axis=1)
            for z in range(len(remain_var_set)):
                pvals.append(indep_test(tmp_data, z, -1))
            fisher_pval = fisher_test(pvals)
            if fisher_pval >= alpha:
                tmp_clusters_list.append(cluster)
        tmp_clusters_list = merge_overlaping_cluster(tmp_clusters_list)
        clusters_list = clusters_list + tmp_clusters_list
        for cluster in tmp_clusters_list:
            var_set -= set(cluster)
        cluster_size += 1

    K = []
    updated = True
    while updated:
        updated = False
        X = []
        Z = []
        for cluster_k in K:
            cluster_k1, cluster_k2 = array_split(cluster_k, 2)
            X += cluster_k1
            Z += cluster_k2

        for i, cluster_i in enumerate(clusters_list):
            is_root = True
            random.shuffle(cluster_i)
            cluster_i1, cluster_i2 = array_split(cluster_i, 2)
            for j, cluster_j in enumerate(clusters_list):
                if i == j:
                    continue
                random.shuffle(cluster_j)
                cluster_j1, cluster_j2 = array_split(cluster_j, 2)
                e = cal_e_with_gin(data, cov, X + cluster_i1 + cluster_j1, Z + cluster_i2)
                pvals = []
                tmp_data = np.concatenate([data[:, Z + cluster_i2], e.reshape(-1, 1)], axis=1)
                for z in range(len(Z + cluster_i2)):
                    pvals.append(indep_test(tmp_data, z, -1))
                fisher_pval = fisher_test(pvals)
                if fisher_pval < alpha:
                    is_root = False
                    break
            if is_root:
                K.append(cluster_i)
                clusters_list.remove(cluster_i)
                updated = True
                break

    G = GeneralGraph([])
    for var in var_set:
        o_node = GraphNode(f"X{var + 1}")
        G.add_node(o_node)

    latent_id = 1
    l_nodes = []

    for cluster in K:
        l_node = GraphNode(f"L{latent_id}")
        l_node.set_node_type(NodeType.LATENT)
        G.add_node(l_node)
        for l in l_nodes:
            G.add_directed_edge(l, l_node)
        l_nodes.append(l_node)

        for o in cluster:
            o_node = GraphNode(f"X{o + 1}")
            G.add_node(o_node)
            G.add_directed_edge(l_node, o_node)
        latent_id += 1

    undirected_l_nodes = []

    for cluster in clusters_list:
        l_node = GraphNode(f"L{latent_id}")
        l_node.set_node_type(NodeType.LATENT)
        G.add_node(l_node)
        for l in l_nodes:
            G.add_directed_edge(l, l_node)

        for l in undirected_l_nodes:
            G.add_edge(Edge(l, l_node, Endpoint.TAIL, Endpoint.TAIL))

        undirected_l_nodes.append(l_node)

        for o in cluster:
            o_node = GraphNode(f"X{o + 1}")
            G.add_node(o_node)
            G.add_directed_edge(l_node, o_node)
        latent_id += 1

    return G, K
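For context, a minimal usage sketch of the new signature (illustrative only; data is assumed to be an (n_samples, n_variables) ndarray):

import numpy as np
from causallearn.search.HiddenCausal.GIN.GIN import GIN

data = np.random.uniform(-1, 1, size=(1000, 8)) ** 5  # toy non-Gaussian data
g, k = GIN(data)  # equivalent to GIN(data, indep_test=kci, alpha=0.05)
print(g)  # GeneralGraph over latent L-nodes and observed X-nodes
print(k)  # causal order: a list of clusters of column indices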


def GIN_MI(data):
    '''
    Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
    with Generalized Independent Noise Condition
@@ -81,6 +217,13 @@ def GIN(data):
    return G, K


def cal_e_with_gin(data, cov, X, Z):
    cov_m = cov[np.ix_(Z, X)]
    _, _, v = np.linalg.svd(cov_m)
    omega = v.T[:, -1]
    return np.dot(omega, data[:, X].T)
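This extracted helper forms the GIN residual e = omega^T X, with omega the right singular vector for the smallest singular value of Cov(Z, X), i.e. the direction minimizing ||Cov(Z, X) omega||. A small sanity-check sketch, with hypothetical column indices:

import numpy as np

rng = np.random.default_rng(0)
data = rng.uniform(-1, 1, size=(1000, 5)) ** 5
cov = np.cov(data.T)
X, Z = [0, 1], [2, 3, 4]
e = cal_e_with_gin(data, cov, X, Z)
print(e.shape)  # (1000,): one residual value per sample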


def cal_dep_for_gin(data, cov, X, Z):
    '''
    Calculate the statistics of dependence via Generalized Independent Noise Condition
@@ -96,10 +239,8 @@ def cal_dep_for_gin(data, cov, X, Z):
    -------
    sta : test statistic
    '''
    cov_m = cov[np.ix_(Z, X)]
    _, _, v = np.linalg.svd(cov_m)
    omega = v.T[:, -1]
    e_xz = np.dot(omega, data[:, X].T)

    e_xz = cal_e_with_gin(data, cov, X, Z)

    sta = 0
    for i in Z:
@@ -160,6 +301,8 @@ def _get_all_elements(S):
# merge overlapping clusters
def merge_overlaping_cluster(cluster_list):
    v_labels = _get_all_elements(cluster_list)
    if len(v_labels) == 0:
        return []
    cluster_dict = {i: -1 for i in v_labels}
    cluster_b = {i: [] for i in v_labels}
    cluster_len = 0
@@ -197,3 +340,20 @@ def merge_overlaping_cluster(cluster_list):
        cluster[cluster_dict[i]].append(i)

    return cluster
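The new guard makes the merge a no-op on an empty candidate list. On a non-empty list, clusters sharing any variable end up in the same group; an illustrative call with hypothetical input (exact list ordering may differ):

clusters = merge_overlaping_cluster([(0, 1), (1, 2), (4, 5)])
# (0, 1) and (1, 2) share variable 1 and are merged, e.g. [[0, 1, 2], [4, 5]]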


def array_split(x, k):
    x_len = len(x)
    # div_points = []
    sub_arys = []
    start = 0
    section_len = x_len // k
    extra = x_len % k
    for i in range(extra):
        sub_arys.append(x[start:start + section_len + 1])
        start = start + section_len + 1

    for i in range(k - extra):
        sub_arys.append(x[start:start + section_len])
        start = start + section_len
    return sub_arys
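This helper mirrors numpy.array_split for plain Python lists: the first x_len % k sub-lists receive one extra element. For example:

print(array_split([1, 2, 3, 4, 5], 2))  # [[1, 2, 3], [4, 5]]
print(array_split([1, 2, 3, 4, 5], 3))  # [[1, 2], [3, 4], [5]]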
24 changes: 24 additions & 0 deletions tests/TestGIN.py
@@ -1,3 +1,4 @@
import random
import sys

sys.path.append("")
@@ -45,3 +46,26 @@ def test_case2(self):
        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
        g, k = GIN(data)
        print(g, k)

    def test_case3(self):
        sample_size = 1000
        random.seed(42)
        np.random.seed(42)
        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
        L3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
        L4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5

        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X5 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X6 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X7 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
        X8 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5

        data = np.array([X1, X2, X3, X4, X5, X6, X7, X8]).T
        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
        g, k = GIN(data)
        print(g, k)
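The new test builds four latents L1..L4, each depending on its predecessors, with X1..X4 loading on {L1, L2}, X5/X6 on L3, and X7/X8 on L4; the uniform**5 noise keeps the sources non-Gaussian, as GIN requires. It only prints the output; a stricter variant (a sketch, not part of this PR) could assert structural properties instead:

        # Sketch (not in the PR): clusters in the causal order should be
        # disjoint sets of column indices.
        all_vars = [v for cluster in k for v in cluster]
        self.assertEqual(len(all_vars), len(set(all_vars)))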