Skip to content
34 changes: 7 additions & 27 deletions causallearn/graph/GraphClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from causallearn.graph.Node import Node
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.Helper import list_union, powerset
from causallearn.utils.cit import CIT


class CausalGraph:
Expand All @@ -34,10 +35,7 @@ def __init__(self, no_of_var: int, node_names: List[str] | None = None):
for i in range(no_of_var):
for j in range(i + 1, no_of_var):
self.G.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL))

self.data = None # store the data
self.test = None # store the name of the conditional independence test
self.corr_mat = None # store the correlation matrix of the data
Comment on lines -38 to -40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the attributes you deleted, did you make sure that no code referenced it anymore? :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified this by PyCharm -> Find Usages and there is no other usages in the project.

Is this a reliable way to find usages or do you have any other recommendations?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. Maybe try it using other variables to see whether it did a great job? Normally these kind of things should rely on simple tests (if we have 100% test coverage hhh) && smart IDE.

self.test: CIT | None = None
self.sepset = np.empty((no_of_var, no_of_var), object) # store the collection of sepsets
self.definite_UC = [] # store the list of definite unshielded colliders
self.definite_non_UC = [] # store the list of definite unshielded non-colliders
Expand All @@ -47,35 +45,17 @@ def __init__(self, no_of_var: int, node_names: List[str] | None = None):
self.nx_skel = nx.Graph() # store the undirected graph
self.labels = {}
self.prt_m = {} # store the parents of missingness indicators
self.mvpc = False
self.cardinalities = None # only works when self.data is discrete, i.e. self.test is chisq or gsq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

self.is_discrete = False
self.citest_cache = dict()
self.data_hash_key = None
self.ci_test_hash_key = None

def set_ind_test(self, indep_test, mvpc=False):


def set_ind_test(self, indep_test):
"""Set the conditional independence test that will be used"""
# assert name_of_test in ["Fisher_Z", "Chi_sq", "G_sq"]
self.mvpc = mvpc
self.test = indep_test
self.ci_test_hash_key = hash(indep_test)

def ci_test(self, i: int, j: int, S) -> float:
"""Define the conditional independence test"""
# assert i != j and not i in S and not j in S
if self.mvpc:
return self.test(self.data, self.nx_skel, self.prt_m, i, j, S)

i, j = (i, j) if (i < j) else (j, i)
ijS_key = (i, j, frozenset(S), self.data_hash_key, self.ci_test_hash_key)
if ijS_key in self.citest_cache:
return self.citest_cache[ijS_key]
# if discrete, assert self.test is chisq or gsq, pass into cardinalities
pValue = self.test(self.data, i, j, S, self.cardinalities) if self.is_discrete \
else self.test(self.data, i, j, S)
self.citest_cache[ijS_key] = pValue
return pValue
if self.test.method == 'mc_fisherz': return self.test(i, j, S, self.nx_skel, self.prt_m)
return self.test(i, j, S)

def neighbors(self, i: int):
"""Find the neighbors of node i in adjmat"""
Expand Down
270 changes: 9 additions & 261 deletions causallearn/search/ConstraintBased/CDNOD.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.utils.PCUtils.BackgroundKnowledgeOrientUtils import orient_by_background_knowledge
from causallearn.utils.cit import *
from causallearn.search.ConstraintBased.PC import get_parent_missingness_pairs, skeleton_correction


def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisherz, stable: bool = True,
uc_rule: int = 0, uc_priority: int = 2, mvcdnod: bool = False, correction_name: str = 'MV_Crtn_Fisher_Z',
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True,
uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z',
background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False,
show_progress: bool = True) -> CausalGraph:
"""
Causal discovery from nonstationary/heterogeneous data
Expand Down Expand Up @@ -43,7 +44,7 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisher
show_progress=show_progress)


def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int,
def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int,
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
"""
Expand Down Expand Up @@ -84,6 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in

"""
start = time.time()
indep_test = CIT(data, indep_test)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable)

# orient the direction from c_indx to X, if there is an edge between c_indx and X
Expand Down Expand Up @@ -124,7 +126,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in
return cg


def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int,
def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int,
uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph:
"""
:param data: data set (numpy ndarray)
Expand Down Expand Up @@ -154,9 +156,9 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s
"""

start = time.time()

indep_test = CIT(data, indep_test)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_prt_mpairs(data, alpha, indep_test, stable)
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
# print('Finish detecting the parents of missingness indicators. ')

## Step 2:
Expand Down Expand Up @@ -204,257 +206,3 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s
cg.PC_elapsed = end - start

return cg


#######################################################################################################################
## *********** Functions for Step 1 ***********
def get_prt_mpairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]:
"""
Detect the parents of missingness indicators
If a missingness indicator has no parent, it will not be included in the result
:param data: data set (numpy ndarray)
:param alpha: desired significance level in (0, 1) (float)
:param indep_test: name of the test-wise deletion independence test being used
- "MV_Fisher_Z": Fisher's Z conditional independence test
- "MV_G_sq": G-squared conditional independence test (TODO: under development)
:param stable: run stabilized skeleton discovery if True (default = True)
:return:
cg: a CausalGraph object
"""
prt_m = {'prt': [], 'm': []}

## Get the index of missingness indicators
m_indx = get_mindx(data)

## Get the index of parents of missingness indicators
# If the missingness indicator has no parent, then it will not be collected in prt_m
for r in m_indx:
prt_r = detect_parent(r, data, alpha, indep_test, stable)
if isempty(prt_r):
pass
else:
prt_m['prt'].append(prt_r)
prt_m['m'].append(r)
return prt_m


def isempty(prt_r: ndarray) -> bool:
"""Test whether the parent of a missingness indicator is empty"""
return len(prt_r) == 0


def get_mindx(data: ndarray) -> List[int]:
"""Detect the parents of missingness indicators
:param data: data set (numpy ndarray)
:return:
m_indx: list, the index of missingness indicators
"""

m_indx = []
_, ncol = np.shape(data)
for i in range(ncol):
if np.isnan(data[:, i]).any():
m_indx.append(i)
return m_indx


def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray:
"""Detect the parents of a missingness indicator
:param r: the missingness indicator
:param data_: data set (numpy ndarray)
:param alpha: desired significance level in (0, 1) (float)
:param indep_test: name of the test-wise deletion independence test being used
- "MV_Fisher_Z": Fisher's Z conditional independence test
- "MV_G_sq": G-squared conditional independence test (TODO: under development)
:param stable: run stabilized skeleton discovery if True (default = True)
: return:
prt: parent of the missingness indicator, r
"""
## TODO: in the test-wise deletion CI test, if test between a binary and a continuous variable,
# there can be the case where the binary variable only take one value after deletion.
# It is because the assumption is violated.

## *********** Adaptation 0 ***********
# For avoid changing the original data
data = data_.copy()
## *********** End ***********

assert type(data) == np.ndarray
assert 0 < alpha < 1

## *********** Adaptation 1 ***********
# data
## Replace the variable r with its missingness indicator
## If r is not a missingness indicator, return [].
data[:, r] = np.isnan(data[:, r]).astype(float) # True is missing; false is not missing
if sum(data[:, r]) == 0 or sum(data[:, r]) == len(data[:, r]):
return np.empty(0)
## *********** End ***********

no_of_var = data.shape[1]
cg = CausalGraph(no_of_var)
cg.data = data
cg.set_ind_test(indep_test)
cg.corr_mat = np.corrcoef(data, rowvar=False) if indep_test == fisherz else []

node_ids = range(no_of_var)
pair_of_variables = list(permutations(node_ids, 2))

depth = -1
while cg.max_degree() - 1 > depth:
depth += 1
edge_removal = []
for (x, y) in pair_of_variables:

## *********** Adaptation 2 ***********
# the skeleton search
## Only test which variable is the neighbor of r
if x != r:
continue
## *********** End ***********

Neigh_x = cg.neighbors(x)
if y not in Neigh_x:
continue
else:
Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))

if len(Neigh_x) >= depth:
for S in combinations(Neigh_x, depth):
p = cg.ci_test(x, y, S)
if p > alpha:
if not stable: # Unstable: Remove x---y right away
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
if edge1 is not None:
cg.G.remove_edge(edge1)
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
if edge2 is not None:
cg.G.remove_edge(edge2)
else: # Stable: x---y will be removed only
edge_removal.append((x, y)) # after all conditioning sets at
edge_removal.append((y, x)) # depth l have been considered
Helper.append_value(cg.sepset, x, y, S)
Helper.append_value(cg.sepset, y, x, S)
break

for (x, y) in list(set(edge_removal)):
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
if edge1 is not None:
cg.G.remove_edge(edge1)

## *********** Adaptation 3 ***********
## extract the parent of r from the graph
cg.to_nx_skeleton()
cg_skel_adj: ndarray = nx.to_numpy_array(cg.nx_skel).astype(int)
prt = get_parent(r, cg_skel_adj)
## *********** End ***********

return prt


def get_parent(r: int, cg_skel_adj: ndarray) -> ndarray:
"""Get the neighbors of missingness indicators which are the parents
:param r: the missingness indicator index
:param cg_skel_adj: adjacancy matrix of a causal skeleton
:return:
prt: list, parents of the missingness indicator r
"""
num_var = len(cg_skel_adj[0, :])
indx = np.array([i for i in range(num_var)])
prt = indx[cg_skel_adj[r, :] == 1]
return prt


## *********** END ***********
#######################################################################################################################

def skeleton_correction(data: ndarray, alpha: float, test_with_correction_name: str,
init_cg: CausalGraph, prt_m: dict, stable: bool = True) -> CausalGraph:
"""Perform skeleton discovery
:param data: data set (numpy ndarray)
:param alpha: desired significance level in (0, 1) (float)
:param test_with_correction_name: name of the independence test being used
- "MV_Crtn_Fisher_Z": Fisher's Z conditional independence test
- "MV_Crtn_G_sq": G-squared conditional independence test
:param stable: run stabilized skeleton discovery if True (default = True)
:return:
cg: a CausalGraph object
"""

assert type(data) == np.ndarray
assert 0 < alpha < 1
assert test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]

## *********** Adaption 1 ***********
no_of_var = data.shape[1]

## Initialize the graph with the result of test-wise deletion skeletion search
cg = init_cg

cg.data = data
if test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]:
cg.set_ind_test(mc_fisherz, True)
# No need of the correlation matrix if using test-wise deletion test
cg.corr_mat = np.corrcoef(data, rowvar=False) if test_with_correction_name == "MV_Crtn_Fisher_Z" else []
cg.prt_m = prt_m
## *********** Adaption 1 ***********

node_ids = range(no_of_var)
pair_of_variables = list(permutations(node_ids, 2))

depth = -1
while cg.max_degree() - 1 > depth:
depth += 1
edge_removal = []
for (x, y) in pair_of_variables:
Neigh_x = cg.neighbors(x)
if y not in Neigh_x:
continue
else:
Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))

if len(Neigh_x) >= depth:
for S in combinations(Neigh_x, depth):
p = cg.ci_test(x, y, S)
if p > alpha:
if not stable: # Unstable: Remove x---y right away
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
if edge1 is not None:
cg.G.remove_edge(edge1)
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
if edge2 is not None:
cg.G.remove_edge(edge2)
else: # Stable: x---y will be removed only
edge_removal.append((x, y)) # after all conditioning sets at
edge_removal.append((y, x)) # depth l have been considered
Helper.append_value(cg.sepset, x, y, S)
Helper.append_value(cg.sepset, y, x, S)
break

for (x, y) in list(set(edge_removal)):
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
if edge1 is not None:
cg.G.remove_edge(edge1)

return cg


#######################################################################################################################

# *********** Evaluation util ***********

def get_adjacancy_matrix(g: CausalGraph):
return nx.to_numpy_array(g.nx_graph).astype(int)


def matrix_diff(cg1: CausalGraph, cg2: CausalGraph):
adj1 = get_adjacancy_matrix(cg1)
adj2 = get_adjacancy_matrix(cg2)
count = 0
diff_ls = []
for i in range(len(adj1[:, ])):
for j in range(len(adj2[:, ])):
if adj1[i, j] != adj2[i, j]:
diff_ls.append((i, j))
count += 1
return count / 2, diff_ls
Loading