Skip to content

Commit

Permalink
Improvements for causal workflow (#205)
Browse files Browse the repository at this point in the history
This PR makes two improvements:

1. Provides an efficient implementation to find nodes on causal paths
between sources and targets
2. Provides a first-party function to apply Evans' simplification rules
to an ADMG
  • Loading branch information
cthoyt committed Jan 27, 2024
1 parent 1674709 commit 5ca9db1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 6 deletions.
41 changes: 35 additions & 6 deletions src/y0/algorithm/simplify_latent.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# -*- coding: utf-8 -*-

"""Implement Robin Evans' simplification algorithms.
"""Implement Robin Evans' simplification algorithms from [evans2012]_ and [evans2016]_.
.. seealso:: https://www.fields.utoronto.ca/programs/scientific/11-12/graphicmodels/Evans.pdf slides 34-43
.. [evans2016] `Graphs for margins of Bayesian networks <https://arxiv.org/abs/1408.1809>`_
.. [evans2012] `Constraints on marginalised DAGs
<https://www.fields.utoronto.ca/programs/scientific/11-12/graphicmodels/Evans.pdf>`_
"""

import itertools as itt
import logging
from typing import Iterable, Mapping, NamedTuple, Optional, Set, Tuple
from typing import Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Union

import networkx as nx

from ..dsl import Variable
from ..graph import DEFAULT_TAG
from ..graph import DEFAULT_TAG, NxMixedGraph, _ensure_set

__all__ = [
"evans_simplify",
"simplify_latent_dag",
"SimplifyResults",
"remove_widow_latents",
Expand All @@ -28,6 +31,32 @@
DEFAULT_SUFFIX = "_prime"


def evans_simplify(
graph: NxMixedGraph,
*,
latents: Union[None, Variable, Iterable[Variable]] = None,
tag: Optional[str] = None,
) -> NxMixedGraph:
"""Reduce the ADMG based on Evans' simplification rules in [evans2012]_ and [evans2016]_.
:param graph: an NxMixedGraph
:param latents: Additional variables to mark as latent, in addition to the
ones created by undirected edges
:param tag: The tag for which variables are latent
:return: the new graph after simplification
"""
if tag is None:
tag = DEFAULT_TAG
lv_dag = NxMixedGraph.to_latent_variable_dag(graph, tag=tag)
if latents is not None:
latents = _ensure_set(latents)
for node, data in lv_dag.nodes(data=True):
if Variable(node) in latents:
data[tag] = True
simplify_results = simplify_latent_dag(lv_dag, tag=tag)
return NxMixedGraph.from_latent_variable_dag(simplify_results.graph, tag=tag)


class SimplifyResults(NamedTuple):
"""Results from the simplification of a LV-DAG."""

Expand All @@ -37,8 +66,8 @@ class SimplifyResults(NamedTuple):
unidirectional_latents: Set[Variable]


def simplify_latent_dag(graph: nx.DiGraph, tag: Optional[str] = None):
"""Apply Robin Evans' four rules in succession."""
def simplify_latent_dag(graph: nx.DiGraph, *, tag: Optional[str] = None) -> SimplifyResults:
"""Apply Robin Evans' four rules in succession, in place from [evans2012]_ and [evans2016]_."""
if tag is None:
tag = DEFAULT_TAG

Expand Down
53 changes: 53 additions & 0 deletions src/y0/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,56 @@ def iter_moral_links(graph: NxMixedGraph) -> Iterable[Tuple[Variable, Variable]]
yield from chain.from_iterable(
combinations(graph.directed.predecessors(node), 2) for node in graph.nodes()
)


def get_nodes_in_directed_paths(
graph: NxMixedGraph,
sources: Union[Variable, Set[Variable]],
targets: Union[Variable, Set[Variable]],
) -> Set[Variable]:
"""Get all nodes appearing in directed paths from sources to targets.
:param graph: an NxMixedGraph
:param sources: source nodes
:param targets: target nodes
:return: the nodes on all causal paths from sources to targets
"""
sources = _ensure_set(sources)
targets = _ensure_set(targets)
if nx.is_directed_acyclic_graph(graph.directed):
return _get_nodes_in_directed_paths_dag(graph.directed, sources, targets)
else:
# note, this is a simpler implementation can use :func:`nx.all_simple_paths`,
# but it is less efficient since it requires potentially calculating the same
# paths over and over again.
return _get_nodes_in_directed_paths_cyclic(graph.directed, sources, targets)


def _get_nodes_in_directed_paths_dag(
graph: nx.DiGraph, sources: set[Variable], targets: set[Variable]
) -> set[Variable]:
tc: nx.DiGraph = nx.transitive_closure_dag(graph)
rv = {
node
for node in graph.nodes()
if any(
tc.has_edge(source, node) and tc.has_edge(node, target)
for source, target in itt.product(sources, targets)
)
}
for source, target in itt.product(sources, targets):
if tc.has_edge(source, target):
rv.add(source)
rv.add(target)
return rv


def _get_nodes_in_directed_paths_cyclic(
graph: nx.DiGraph, sources: set[Variable], targets: set[Variable]
) -> set[Variable]:
return {
node
for source, target in itt.product(sources, targets)
for causal_path in nx.all_simple_paths(graph, source, target)
for node in causal_path
}
15 changes: 15 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DEFAULT_TAG,
DEFULT_PREFIX,
NxMixedGraph,
get_nodes_in_directed_paths,
is_a_fixable,
is_markov_blanket_shielded,
is_p_fixable,
Expand Down Expand Up @@ -657,3 +658,17 @@ def test_graph_without_latents(self):
expected = BayesianNetwork(ebunch=[("X", "Y")])
actual = graph.to_pgmpy_bayesian_network()
self.assert_bayesian_equal(expected, actual)


class TestUtilities(unittest.TestCase):
"""Test utility functions."""

def test_nodes_in_paths(self):
"""Test getting nodes in paths."""
graph = NxMixedGraph.from_edges(directed=[(X, Z), (Z, Y)])
self.assertEqual({X, Y, Z}, get_nodes_in_directed_paths(graph, X, Y))
self.assertEqual({X, Z}, get_nodes_in_directed_paths(graph, X, Z))
self.assertEqual({Z, Y}, get_nodes_in_directed_paths(graph, Z, Y))
self.assertEqual(set(), get_nodes_in_directed_paths(graph, Z, X))
self.assertEqual(set(), get_nodes_in_directed_paths(graph, Y, Z))
self.assertEqual(set(), get_nodes_in_directed_paths(graph, Y, X))

0 comments on commit 5ca9db1

Please sign in to comment.