Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions causallearn/graph/GraphClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@


class CausalGraph:
def __init__(self, no_of_var: int):
node_names: List[str] = [("X%d" % (i + 1)) for i in range(no_of_var)]
def __init__(self, no_of_var: int, node_names: List[str] | None = None):
if node_names is None:
node_names = [("X%d" % (i + 1)) for i in range(no_of_var)]
assert len(node_names) == no_of_var, "number of node_names must match number of variables"
assert len(node_names) == len(set(node_names)), "node_names must be unique"
nodes: List[Node] = []
for name in node_names:
node = GraphNode(name)
Expand Down
62 changes: 46 additions & 16 deletions causallearn/search/ConstraintBased/PC.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,56 @@
orient_by_background_knowledge


def pc(
    data: ndarray,
    alpha=0.05,
    indep_test=fisherz,
    stable: bool = True,
    uc_rule: int = 0,
    uc_priority: int = 2,
    mvpc: bool = False,
    correction_name: str = 'MV_Crtn_Fisher_Z',
    background_knowledge: BackgroundKnowledge | None = None,
    verbose: bool = False,
    show_progress: bool = True,
    node_names: List[str] | None = None,
) -> CausalGraph:
    """
    Run the PC algorithm, or its missing-value variant, on a data set.

    This is a thin dispatcher: it emits a warning when the data has fewer
    rows than columns, then forwards every argument to ``mvpc_alg`` (when
    ``mvpc`` is true) or to ``pc_alg`` (otherwise) and returns their result.

    Parameters
    ----------
    data : data set (numpy ndarray), shape (n_samples, n_features).
    alpha : desired significance level of the independence tests, in (0, 1).
    indep_test : the independence test to use. When ``mvpc`` is true and this
        is ``fisherz``, it is replaced by ``mv_fisherz`` before dispatching.
    stable : forwarded unchanged to the selected algorithm.
    uc_rule : forwarded unchanged to the selected algorithm.
    uc_priority : forwarded unchanged to the selected algorithm.
    mvpc : if True, dispatch to the missing-value PC algorithm (``mvpc_alg``).
    correction_name : forwarded to ``mvpc_alg`` only; unused when ``mvpc`` is false.
    background_knowledge : background knowledge, forwarded unchanged.
    verbose : forwarded unchanged to the selected algorithm.
    show_progress : forwarded unchanged to the selected algorithm.
    node_names : optional list with one name per feature; forwarded unchanged.
        ``None`` lets the callee pick its own default names.

    Returns
    -------
    The ``CausalGraph`` returned by ``pc_alg`` or ``mvpc_alg``.
    """
    # Only warns; the run still proceeds with fewer samples than features.
    if data.shape[0] < data.shape[1]:
        warnings.warn("The number of features is much larger than the sample size!")

    if mvpc:  # missing value PC
        # The default test is swapped for its missing-value counterpart;
        # any other test chosen by the caller is passed through untouched.
        if indep_test == fisherz:
            indep_test = mv_fisherz
        return mvpc_alg(
            data=data,
            node_names=node_names,
            alpha=alpha,
            indep_test=indep_test,
            correction_name=correction_name,
            stable=stable,
            uc_rule=uc_rule,
            uc_priority=uc_priority,
            background_knowledge=background_knowledge,
            verbose=verbose,
            show_progress=show_progress,
        )

    return pc_alg(
        data=data,
        node_names=node_names,
        alpha=alpha,
        indep_test=indep_test,
        stable=stable,
        uc_rule=uc_rule,
        uc_priority=uc_priority,
        background_knowledge=background_knowledge,
        verbose=verbose,
        show_progress=show_progress,
    )


def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int,
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
def pc_alg(
data: ndarray,
node_names: List[str] | None,
alpha: float,
indep_test,
stable: bool,
uc_rule: int,
uc_priority: int,
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
) -> CausalGraph:
"""
Perform Peter-Clark (PC) algorithm for causal discovery

Parameters
----------
data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
alpha : float, desired significance level of independence tests (p_value) in (0,1)
node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)
alpha : float, desired significance level of independence tests (p_value) in (0, 1)
indep_test : the function of the independence test being used
[fisherz, chisq, gsq, kci]
- fisherz: Fisher's Z conditional independence test
Expand Down Expand Up @@ -79,7 +99,7 @@ def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int,
start = time.time()
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress)
show_progress=show_progress, node_names=node_names)

if background_knowledge is not None:
orient_by_background_knowledge(cg_1, background_knowledge)
Expand Down Expand Up @@ -114,16 +134,26 @@ def pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int,
return cg


def mvpc_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int,
uc_priority: int, background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
def mvpc_alg(
data: ndarray,
node_names: List[str] | None,
alpha: float,
indep_test,
correction_name: str,
stable: bool,
uc_rule: int,
uc_priority: int,
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
) -> CausalGraph:
"""
Perform missing value Peter-Clark (PC) algorithm for causal discovery

Parameters
----------
data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)
alpha : float, desired significance level of independence tests (p_value) in (0,1)
indep_test : name of the test-wise deletion independence test being used
[mv_fisherz, mv_g_sq]
Expand Down Expand Up @@ -169,7 +199,7 @@ def mvpc_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stab
## a) Run PC algorithm with the 1st step skeleton;
cg_pre = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge,
verbose=verbose, show_progress=show_progress)
verbose=verbose, show_progress=show_progress, node_names=node_names)
if background_knowledge is not None:
orient_by_background_knowledge(cg_pre, background_knowledge)

Expand Down
17 changes: 13 additions & 4 deletions causallearn/utils/PCUtils/SkeletonDiscovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from numpy import ndarray
from typing import List
from tqdm.auto import tqdm

from causallearn.graph.GraphClass import CausalGraph
Expand All @@ -12,9 +13,16 @@
from causallearn.utils.PCUtils.Helper import append_value


def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = True,
background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False,
show_progress: bool = True) -> CausalGraph:
def skeleton_discovery(
data: ndarray,
alpha: float,
indep_test,
stable: bool = True,
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
node_names: List[str] | None = None,
) -> CausalGraph:
"""
Perform skeleton discovery

Expand All @@ -34,6 +42,7 @@ def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = T
background_knowledge : background knowledge
verbose : True iff verbose output should be printed.
show_progress : True iff the algorithm progress should be shown in console.
node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name)

Returns
-------
Expand All @@ -47,7 +56,7 @@ def skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = T
assert 0 < alpha < 1

no_of_var = data.shape[1]
cg = CausalGraph(no_of_var)
cg = CausalGraph(no_of_var, node_names)
cg.set_ind_test(indep_test)
cg.data_hash_key = hash(str(data))
if indep_test == chisq or indep_test == gsq:
Expand Down
5 changes: 2 additions & 3 deletions tests/TestMVPC.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import numpy as np
import pandas as pd

from causallearn.search.ConstraintBased.PC import (get_adjacancy_matrix,
mvpc_alg, pc, pc_alg)
from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz
from causallearn.search.ConstraintBased.PC import get_adjacancy_matrix, pc
from causallearn.utils.cit import fisherz, mv_fisherz


def load(filename):
Expand Down