diff --git a/causallearn/graph/GraphClass.py b/causallearn/graph/GraphClass.py
index 175dd8ed..f65211e5 100644
--- a/causallearn/graph/GraphClass.py
+++ b/causallearn/graph/GraphClass.py
@@ -21,8 +21,11 @@ class CausalGraph:
-    def __init__(self, no_of_var: int):
-        node_names: List[str] = [("X%d" % (i + 1)) for i in range(no_of_var)]
+    def __init__(self, no_of_var: int, node_names: List[str] | None = None):
+        if node_names is None:
+            node_names = [("X%d" % (i + 1)) for i in range(no_of_var)]
+        assert len(node_names) == no_of_var, "number of node_names must match number of variables"
+        assert len(node_names) == len(set(node_names)), "node_names must be unique"
         nodes: List[Node] = []
         for name in node_names:
             node = GraphNode(name)
diff --git a/causallearn/search/ConstraintBased/PC.py b/causallearn/search/ConstraintBased/PC.py
index 18f7d5a7..b6080df9 100644
--- a/causallearn/search/ConstraintBased/PC.py
+++ b/causallearn/search/ConstraintBased/PC.py
@@ -16,36 +16,56 @@ orient_by_background_knowledge
-def pc(data: ndarray, alpha=0.05, indep_test=fisherz, stable: bool = True, uc_rule: int = 0, uc_priority: int = 2,
-       mvpc: bool = False, correction_name: str = 'MV_Crtn_Fisher_Z',
-       background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True):
+def pc(
+    data: ndarray,
+    alpha=0.05,
+    indep_test=fisherz,
+    stable: bool = True,
+    uc_rule: int = 0,
+    uc_priority: int = 2,
+    mvpc: bool = False,
+    correction_name: str = 'MV_Crtn_Fisher_Z',
+    background_knowledge: BackgroundKnowledge | None = None,
+    verbose: bool = False,
+    show_progress: bool = True,
+    node_names: List[str] | None = None,
+):
     if data.shape[0] < data.shape[1]:
         warnings.warn("The number of features is much larger than the sample size!")
 
     if mvpc:  # missing value PC
         if indep_test == fisherz:
             indep_test = mv_fisherz
-        return mvpc_alg(data=data, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
+        return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
                         uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge,
                         verbose=verbose,
                         show_progress=show_progress)
     else:
-        return pc_alg(data=data, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
+        return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
                       uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
                       show_progress=show_progress)
 
 
-def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int,
-           background_knowledge: BackgroundKnowledge | None = None,
-           verbose: bool = False,
-           show_progress: bool = True) -> CausalGraph:
+def pc_alg(
+    data: ndarray,
+    node_names: List[str] | None,
+    alpha: float,
+    indep_test,
+    stable: bool,
+    uc_rule: int,
+    uc_priority: int,
+    background_knowledge: BackgroundKnowledge | None = None,
+    verbose: bool = False,
+    show_progress: bool = True,
+) -> CausalGraph:
     """
     Perform Peter-Clark (PC) algorithm for causal discovery
 
     Parameters
     ----------
     data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples
            and n_features is the number of features.
-    alpha : float, desired significance level of independence tests (p_value) in (0,1)
+    node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)
+    alpha : float, desired significance level of independence tests (p_value) in (0, 1)
     indep_test : the function of the independence test being used
            [fisherz, chisq, gsq, kci]
            - fisherz: Fisher's Z conditional independence test
@@ -79,7 +99,7 @@ def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int,
     start = time.time()
     cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
                                                 background_knowledge=background_knowledge, verbose=verbose,
-                                                show_progress=show_progress)
+                                                show_progress=show_progress, node_names=node_names)
 
     if background_knowledge is not None:
         orient_by_background_knowledge(cg_1, background_knowledge)
@@ -114,16 +134,26 @@ def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int,
     return cg
 
 
-def mvpc_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int,
-             uc_priority: int, background_knowledge: BackgroundKnowledge | None = None,
-             verbose: bool = False,
-             show_progress: bool = True) -> CausalGraph:
+def mvpc_alg(
+    data: ndarray,
+    node_names: List[str] | None,
+    alpha: float,
+    indep_test,
+    correction_name: str,
+    stable: bool,
+    uc_rule: int,
+    uc_priority: int,
+    background_knowledge: BackgroundKnowledge | None = None,
+    verbose: bool = False,
+    show_progress: bool = True,
+) -> CausalGraph:
     """
     Perform missing value Peter-Clark (PC) algorithm for causal discovery
 
     Parameters
     ----------
     data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples
            and n_features is the number of features.
+    node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)
     alpha : float, desired significance level of independence tests (p_value) in (0,1)
     indep_test : name of the test-wise deletion independence test being used
            [mv_fisherz, mv_g_sq]
@@ -169,7 +199,7 @@ def mvpc_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stab
     ## a) Run PC algorithm with the 1st step skeleton;
     cg_pre = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
                                                   background_knowledge=background_knowledge,
-                                                  verbose=verbose, show_progress=show_progress)
+                                                  verbose=verbose, show_progress=show_progress, node_names=node_names)
     if background_knowledge is not None:
         orient_by_background_knowledge(cg_pre, background_knowledge)
diff --git a/causallearn/utils/PCUtils/SkeletonDiscovery.py b/causallearn/utils/PCUtils/SkeletonDiscovery.py
index bd2f2341..043e3a8c 100644
--- a/causallearn/utils/PCUtils/SkeletonDiscovery.py
+++ b/causallearn/utils/PCUtils/SkeletonDiscovery.py
@@ -4,6 +4,7 @@
 import numpy as np
 from numpy import ndarray
+from typing import List
 from tqdm.auto import tqdm
 
 from causallearn.graph.GraphClass import CausalGraph
@@ -12,9 +13,16 @@ from causallearn.utils.PCUtils.Helper import append_value
 
 
-def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = True,
-                       background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False,
-                       show_progress: bool = True) -> CausalGraph:
+def skeleton_discovery(
+    data: ndarray,
+    alpha: float,
+    indep_test,
+    stable: bool = True,
+    background_knowledge: BackgroundKnowledge | None = None,
+    verbose: bool = False,
+    show_progress: bool = True,
+    node_names: List[str] | None = None,
+) -> CausalGraph:
     """
     Perform skeleton discovery
@@ -34,6 +42,7 @@ def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = T
     background_knowledge : background knowledge
     verbose : True iff verbose output should be printed.
     show_progress : True iff the algorithm progress should be show in console.
+    node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)
 
     Returns
     -------
@@ -47,7 +56,7 @@ def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = T
     assert 0 < alpha < 1
 
     no_of_var = data.shape[1]
-    cg = CausalGraph(no_of_var)
+    cg = CausalGraph(no_of_var, node_names)
     cg.set_ind_test(indep_test)
     cg.data_hash_key = hash(str(data))
     if indep_test == chisq or indep_test == gsq:
diff --git a/tests/TestMVPC.py b/tests/TestMVPC.py
index 56fe90ef..836093f5 100644
--- a/tests/TestMVPC.py
+++ b/tests/TestMVPC.py
@@ -8,9 +8,8 @@
 import numpy as np
 import pandas as pd
 
-from causallearn.search.ConstraintBased.PC import (get_adjacancy_matrix,
-                                                   mvpc_alg, pc, pc_alg)
-from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz
+from causallearn.search.ConstraintBased.PC import get_adjacancy_matrix, pc
+from causallearn.utils.cit import fisherz, mv_fisherz
 
 
 def load(filename):
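
Usage note (not part of the diff): the sketch below shows how the node_names parameter added above could be threaded through pc() so that graph nodes carry dataframe column names instead of the default "X1", "X2", ... labels. The CSV file name is a hypothetical placeholder, and the cg.G.get_nodes() / get_name() accessors are assumptions based on the existing causal-learn graph classes, not part of this change.

# Minimal usage sketch for the new node_names parameter (assumptions noted above).
import pandas as pd
from causallearn.search.ConstraintBased.PC import pc

df = pd.read_csv("example_data.csv")  # hypothetical input file; columns are the variables
cg = pc(df.to_numpy(), alpha=0.05, node_names=list(df.columns))

# Nodes in the learned graph are now labeled with the column names
# rather than the auto-generated "X1", "X2", ... defaults.
print([node.get_name() for node in cg.G.get_nodes()])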