Skip to content

Commit

Permalink
Improve API for conditional independency tests (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 26, 2024
1 parent 6572540 commit 1674709
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 24 deletions.
85 changes: 83 additions & 2 deletions src/y0/algorithm/conditional_independencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,104 @@

from functools import partial
from itertools import combinations, groupby
from typing import Callable, Iterable, Optional, Sequence, Set, Tuple
from typing import Callable, Iterable, List, Optional, Sequence, Set, Tuple, Union

import networkx as nx
import pandas as pd
from tqdm.auto import tqdm

from ..dsl import Variable
from ..graph import NxMixedGraph
from ..struct import DSeparationJudgement
from ..struct import (
DEFAULT_SIGNIFICANCE,
CITest,
CITestTuple,
DSeparationJudgement,
_ensure_method,
)
from ..util.combinatorics import powerset

# Public names of this module (what ``from ... import *`` exposes).
__all__ = [
"are_d_separated",
"minimal",
"get_conditional_independencies",
"test_conditional_independencies",
"add_ci_undirected_edges",
]


def add_ci_undirected_edges(
    graph: NxMixedGraph,
    data: pd.DataFrame,
    *,
    method: Optional[CITest] = None,
    significance_level: Optional[float] = None,
) -> NxMixedGraph:
    """Add undirected edges between d-separated nodes that fail a data-driven conditional independency test.

    :param graph: An acyclic directed mixed graph
    :param data: observational data corresponding to the graph
    :param method:
        The conditional independency test to use. If None, defaults to
        :data:`y0.struct.DEFAULT_CONTINUOUS_CI_TEST` for continuous data
        or :data:`y0.struct.DEFAULT_DISCRETE_CI_TEST` for discrete data.
    :param significance_level: The statistical tests employ this value for
        comparison with the p-value of the test to determine the independence of
        the tested variables. If None, defaults to
        :data:`y0.struct.DEFAULT_SIGNIFICANCE`.
    :returns: A copy of the input graph potentially with new undirected edges added
    """
    # Work on a copy so the caller's graph is left untouched
    rv = graph.copy()
    # boolean=True makes every test result a plain pass/fail flag
    judged = test_conditional_independencies(
        graph=graph,
        data=data,
        method=method,
        boolean=True,
        significance_level=significance_level,
    )
    for judgement, result in judged:
        # A falsy result means the pair failed the test against the data
        if not result:
            rv.add_undirected_edge(judgement.left, judgement.right)
    return rv


def test_conditional_independencies(
    graph: NxMixedGraph,
    data: pd.DataFrame,
    *,
    method: Optional[CITest] = None,
    boolean: bool = False,
    significance_level: Optional[float] = None,
    _method_checked: bool = False,
) -> List[Tuple[DSeparationJudgement, Union[bool, CITestTuple]]]:
    """Get CIs with :func:`get_conditional_independencies` then test them against data.

    :param graph: An acyclic directed mixed graph
    :param data: observational data corresponding to the graph
    :param method:
        The conditional independency test to use. If None, defaults to
        :data:`y0.struct.DEFAULT_CONTINUOUS_CI_TEST` for continuous data
        or :data:`y0.struct.DEFAULT_DISCRETE_CI_TEST` for discrete data.
    :param boolean:
        If set to true, switches the test return type to be a pre-computed
        boolean based on the significance level (see parameter below)
    :param significance_level: The statistical tests employ this value for
        comparison with the p-value of the test to determine the independence of
        the tested variables. If None, defaults to
        :data:`y0.struct.DEFAULT_SIGNIFICANCE`.
    :param _method_checked:
        Private flag forwarded to ``_ensure_method`` as ``skip``. When true, the
        given ``method`` is used as-is, so it must not be None.
    :returns:
        A list of pairs, one for each judgement produced by
        :func:`get_conditional_independencies`. Each pair holds the judgement and
        its test result, which is a boolean if ``boolean`` is true and a
        :class:`y0.struct.CITestTuple` otherwise.
    """
    if significance_level is None:
        significance_level = DEFAULT_SIGNIFICANCE
    # Resolve the method once, against the whole dataframe ...
    method = _ensure_method(method, data, skip=_method_checked)
    return [
        (
            judgement,
            judgement.test(
                data,
                boolean=boolean,
                method=method,
                significance_level=significance_level,
                # ... so each judgement can skip resolving it again
                _method_checked=True,
            ),
        )
        for judgement in get_conditional_independencies(graph)
    ]


def get_conditional_independencies(
graph: NxMixedGraph,
*,
Expand Down
17 changes: 4 additions & 13 deletions src/y0/algorithm/falsification.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,15 @@ def get_falsifications(
results = []
method = _ensure_method(method, df)
for judgement in tqdm(judgements, disable=not verbose, desc="Checking conditionals"):
result = judgement.test(df, method=method)
# Pearson's correlation returns a pair with the first element being the Pearson's correlation
# and the second being the p-value. The other methods return a triple with the first element
# being the Chi^2 statistic, the second being the p-value, and the third being the degrees of
# freedom.
if method == "pearson":
stat, p_value = result
dof = None
else:
stat, p_value, dof = result
result = judgement.test(df, method=method, boolean=False)
results.append(
(
judgement.left.name,
judgement.right.name,
"|".join(c.name for c in judgement.conditions),
stat,
p_value,
dof,
result.statistic,
result.p_value,
result.dof,
)
)
evidence_df = pd.DataFrame(
Expand Down
7 changes: 7 additions & 0 deletions src/y0/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def __contains__(self, item: Variable) -> bool:
"""Check if the given item is a node in the graph."""
return item in self.directed

def copy(self):
    """Return a new graph of the same class built from copies of both edge sets."""
    # Copy each underlying graph so the new instance does not share them
    directed_copy = self.directed.copy()
    undirected_copy = self.undirected.copy()
    # Use the runtime class so subclasses get an instance of their own type
    return self.__class__(directed=directed_copy, undirected=undirected_copy)

def is_counterfactual(self) -> bool:
"""Check if this is a counterfactual graph."""
return any(isinstance(n, CounterfactualVariable) for n in self.nodes())
Expand Down
50 changes: 44 additions & 6 deletions src/y0/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Iterable, Literal, NamedTuple, Optional, Tuple, Union
from typing import Callable, Iterable, Literal, NamedTuple, Optional, Tuple, Union, cast

import pandas as pd

Expand All @@ -17,6 +17,8 @@
"DSeparationJudgement",
]

# Significance level used by the conditional independency tests when the caller passes None.
DEFAULT_SIGNIFICANCE = 0.01


class VermaConstraint(NamedTuple):
"""Represent a Verma constraint."""
Expand Down Expand Up @@ -86,6 +88,22 @@ def get_conditional_independence_tests() -> dict[CITest, CITestFunc]:
}


class CITestTuple(NamedTuple):
    """The non-boolean result of a PGMPy conditional independency test.

    Continuous tests such as :func:`pgmpy.estimators.CITests.pearsonr`
    have no associated degrees of freedom (dof), so ``dof`` is None
    for those tests.
    """

    # The test statistic
    statistic: float
    # The p-value of the test
    p_value: float
    # Degrees of freedom; None when the test does not report one
    dof: Optional[float] = None


# A conditional independency test yields either the full result tuple or a plain boolean.
CITestResult = Union[CITestTuple, bool]


@dataclass(frozen=True)
class DSeparationJudgement:
"""
Expand Down Expand Up @@ -130,10 +148,12 @@ def is_canonical(self) -> bool:
def test(
self,
df: pd.DataFrame,
*,
boolean: bool = False,
method: Optional[CITest] = None,
significance_level: Optional[float] = None,
) -> Union[Tuple[float, int], Tuple[float, int, float], bool]:
_method_checked: bool = False,
) -> Union[bool, CITestTuple]:
"""Test for conditional independence, given some data.
:param df: A dataframe.
Expand Down Expand Up @@ -168,24 +188,42 @@ def test(
f"conditional {c.name} ({type(c.name)}) not in columns {df.columns}"
)
if significance_level is None:
significance_level = 0.01
significance_level = DEFAULT_SIGNIFICANCE

method = _ensure_method(
method, df[[self.left.name, self.right.name, *(c.name for c in self.conditions)]]
method,
df[[self.left.name, self.right.name, *(c.name for c in self.conditions)]],
skip=_method_checked,
)
tests: dict[CITest, CITestFunc] = get_conditional_independence_tests()
func: CITestFunc = tests[method]
return func(
result = func(
X=self.left.name,
Y=self.right.name,
Z={condition.name for condition in self.conditions},
data=df,
boolean=boolean,
significance_level=significance_level,
)
if boolean:
return cast(bool, result)
# Pearson's correlation returns a pair with the first element being the Pearson's correlation
# and the second being the p-value. The other methods return a triple with the first element
# being the Chi^2 statistic, the second being the p-value, and the third being the degrees of
# freedom.
if method == "pearson":
statistic, p_value = result
dof = None
else:
statistic, p_value, dof = result
return CITestTuple(statistic=statistic, p_value=p_value, dof=dof)


def _ensure_method(method: Optional[CITest], df: pd.DataFrame) -> CITest:
def _ensure_method(method: Optional[CITest], df: pd.DataFrame, skip: bool = False) -> CITest:
if skip:
if method is None:
raise RuntimeError
return method
# TODO extend to discrete but more than 2.
# see https://stats.stackexchange.com/questions/12273/how-to-test-if-my-data-is-discrete-or-continuous
# TODO what happens when some variables are binary but others are continuous?
Expand Down
56 changes: 53 additions & 3 deletions tests/test_algorithm/test_conditional_independencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@

"""Test getting conditional independencies (and related)."""

import typing
import unittest
from typing import Iterable, Set

from pgmpy.estimators import CITests

from y0.algorithm.conditional_independencies import (
are_d_separated,
get_conditional_independencies,
)
from y0.dsl import AA, B, C, D, E, F, G, Variable
from y0.examples import Example, d_separation_example, examples
from y0.dsl import AA, B, C, D, E, F, G, Variable, X, Y
from y0.examples import (
Example,
d_separation_example,
examples,
frontdoor_backdoor_example,
frontdoor_example,
)
from y0.graph import NxMixedGraph, iter_moral_links
from y0.struct import DSeparationJudgement
from y0.struct import CITestTuple, DSeparationJudgement


class TestDSeparation(unittest.TestCase):
Expand Down Expand Up @@ -218,3 +227,44 @@ def test_examples(self):
with self.subTest(name=example.name):
self.maxDiff = None
self.assert_example_has_judgements(example)

def test_ci_test_continuous(self):
    """Check both return modes of the Pearson test on continuous data."""
    df = frontdoor_example.generate_data(500)  # continuous
    x_y_judgement = DSeparationJudgement(left=X, right=Y, separated=..., conditions=())

    # boolean=True must give back a plain bool
    as_bool = x_y_judgement.test(df, method="pearson", boolean=True)
    self.assertIsInstance(as_bool, bool)

    # boolean=False must give back the result tuple, without degrees of freedom
    as_tuple = x_y_judgement.test(df, method="pearson", boolean=False)
    self.assertIsInstance(as_tuple, CITestTuple)
    self.assertIsNone(as_tuple.dof)

    # A discrete test must be rejected when the data are continuous
    with self.assertRaises(ValueError):
        x_y_judgement.test(df, method="chi-square", boolean=True)

def test_ci_test_discrete(self):
    """Test conditional independency test on discrete data."""
    # The registry of supported tests lives in y0.struct. The previous
    # ``typing.get_args(CITests)`` inspected the object imported from pgmpy,
    # which is not a ``typing.Literal``, so it yielded no methods and the
    # loop body below never ran.
    from y0.struct import get_conditional_independence_tests

    data = frontdoor_backdoor_example.generate_data(500)  # discrete
    judgement = DSeparationJudgement(
        left=X,
        right=Y,
        separated=...,
        conditions=(),
    )
    # Every registered method other than Pearson's correlation is treated as discrete
    discrete_methods = [
        method for method in get_conditional_independence_tests() if method != "pearson"
    ]
    # Guard against silently checking nothing
    self.assertTrue(discrete_methods, msg="no discrete conditional independency tests found")
    for method in discrete_methods:
        # subTest reports which method failed without stopping at the first failure
        with self.subTest(method=method):
            test_result_bool = judgement.test(data, method=method, boolean=True)
            self.assertIsInstance(test_result_bool, bool)

            test_result_tuple = judgement.test(data, method=method, boolean=False)
            self.assertIsInstance(test_result_tuple, CITestTuple)
            self.assertIsNotNone(test_result_tuple.dof)

    # Test that an error is thrown if using a continuous test on discrete data
    with self.assertRaises(ValueError):
        judgement.test(data, method="pearson", boolean=True)

0 comments on commit 1674709

Please sign in to comment.