Skip to content

Commit

Permalink
Updated types for skel
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Aug 24, 2022
1 parent 6457d8d commit befb6d5
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions dodiscover/constraint/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from dodiscover.ci import BaseConditionalIndependenceTest
from dodiscover.typing import Column

from ..context import Context

Expand All @@ -24,9 +25,9 @@ def __contains__(cls, item):


class SkeletonMethods(Enum, metaclass=MetaEnum):
complete = "complete"
neighbors = "neighbors"
neighbors_path = "neighbors_path"
COMPLETE = "complete"
NBRS = "neighbors"
NBRS_PATH = "neighbors_path"


def _iter_conditioning_set(
Expand Down Expand Up @@ -82,7 +83,7 @@ def _assign_weight(u, v, edge_attr):
path = nx.shortest_path(G, source=node, target=end, weight=_assign_weight)
if len(path) > 0:
if start in path:
raise RuntimeError("wtf?")
raise RuntimeError("There is an error with the input. This is not possible.")
nbrs.add(node)
return nbrs

Expand Down Expand Up @@ -111,7 +112,7 @@ class LearnSkeleton:
parents still, by default None. If None, then will not be used. If set, then
the conditioning set will be chosen lexicographically based on the sorted
test statistic values of 'ith Pa(X) -> X', for each possible parent node of 'X'.
skeleton_method : str
skeleton_method : SkeletonMethods
The method to use for testing conditional independence. Must be one of
('complete', 'neighbors', 'neighbors_path'). See Notes for more details.
keep_sorted : bool
Expand Down Expand Up @@ -183,7 +184,7 @@ class LearnSkeleton:
"""

adj_graph_: nx.Graph
sep_set_: Dict[str, Dict[str, List[Set[Any]]]]
sep_set_: Dict[Column, Dict[Column, List[Set[Column]]]]
remove_edges: Set
min_cond_set_size_: int
max_cond_set_size_: int
Expand All @@ -192,12 +193,12 @@ class LearnSkeleton:
def __init__(
self,
ci_estimator: BaseConditionalIndependenceTest,
sep_set: Optional[Dict[str, Dict[str, List[Set[Any]]]]] = None,
sep_set: Optional[Dict[Column, Dict[Column, List[Set[Column]]]]] = None,
alpha: float = 0.05,
min_cond_set_size: int = 0,
max_cond_set_size: Optional[int] = None,
max_combinations: Optional[int] = None,
skeleton_method: str = "neighbors",
skeleton_method: SkeletonMethods = SkeletonMethods.NBRS,
keep_sorted: bool = False,
**ci_estimator_kwargs,
) -> None:
Expand Down Expand Up @@ -407,7 +408,9 @@ def fit(self, context: Context) -> None:

self.adj_graph_ = adj_graph

def _summarize_xy_comparison(self, x_var, y_var, removed_edge: bool, pvalue: float) -> None:
def _summarize_xy_comparison(
self, x_var: Column, y_var: Column, removed_edge: bool, pvalue: float
) -> None:
# exit loop if we have found an independency and removed the edge
if removed_edge:
remove_edge_str = "Removing edge"
Expand All @@ -421,7 +424,7 @@ def _summarize_xy_comparison(self, x_var, y_var, removed_edge: bool, pvalue: flo
)

def _compute_candidate_conditioning_sets(
self, adj_graph: nx.Graph, x_var, y_var, skeleton_method: str
self, adj_graph: nx.Graph, x_var: Column, y_var: Column, skeleton_method: SkeletonMethods
) -> Set:
"""Compute candidate conditioning sets.
Expand All @@ -433,7 +436,7 @@ def _compute_candidate_conditioning_sets(
The 'X' node.
y_var : node
The 'Y' node.
skeleton_method : str
skeleton_method : SkeletonMethods
The skeleton method, which dictates how we choose the corresponding
conditioning sets.
Expand All @@ -443,12 +446,12 @@ def _compute_candidate_conditioning_sets(
The set of nodes in 'adj_graph' that are candidates for the
conditioning set.
"""
if skeleton_method == "complete":
if skeleton_method == SkeletonMethods.COMPLETE:
possible_variables = set(adj_graph.nodes)
elif skeleton_method == "neighbors":
elif skeleton_method == SkeletonMethods.NBRS:
possible_variables = set(adj_graph.neighbors(x_var))
# possible_adjacencies.copy()
elif skeleton_method == "neighbors_path":
elif skeleton_method == SkeletonMethods.NBRS_PATH:
# constrain adjacency set to ones with a path from x_var to y_var
possible_variables = _find_neighbors_along_path(adj_graph, start=x_var, end=y_var)

Expand All @@ -469,7 +472,13 @@ def _compute_candidate_conditioning_sets(
return possible_variables

def _postprocess_ci_test(
self, adj_graph: nx.Graph, x_var, y_var, cond_set: Set, test_stat: float, pvalue: float
self,
adj_graph: nx.Graph,
x_var: Column,
y_var: Column,
cond_set: Set[Column],
test_stat: float,
pvalue: float,
) -> bool:
# keep track of the smallest test statistic, meaning the highest pvalue
# meaning the "most" independent. keep track of the maximum pvalue as well
Expand Down

0 comments on commit befb6d5

Please sign in to comment.