Skip to content

Commit

Permalink
Use variables for all LV-DAGs (#219)
Browse files Browse the repository at this point in the history
Closes #218
  • Loading branch information
cthoyt committed Apr 25, 2024
1 parent d2e4498 commit 2cb5fc2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
14 changes: 13 additions & 1 deletion src/y0/algorithm/simplify_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def evans_simplify(
if latents is not None:
latents = _ensure_set(latents)
for node, data in lv_dag.nodes(data=True):
if Variable(node) in latents:
if node in latents:
data[tag] = True
simplify_results = simplify_latent_dag(lv_dag, tag=tag)
return NxMixedGraph.from_latent_variable_dag(simplify_results.graph, tag=tag)
Expand All @@ -70,6 +70,8 @@ def simplify_latent_dag(graph: nx.DiGraph, *, tag: Optional[str] = None) -> Simp
if tag is None:
tag = DEFAULT_TAG

_assert_variable_nodes(graph)

_ = transform_latents_with_parents(graph, tag=tag)
_, widows = remove_widow_latents(graph, tag=tag)
_, unidirectional_latents = remove_unidirectional_latents(graph, tag=tag)
Expand Down Expand Up @@ -213,6 +215,7 @@ def remove_redundant_latents(
:param tag: The tag for which variables are latent
:returns: The graph, modified in place
"""
_assert_variable_nodes(graph)
remove = set(_iter_redundant_latents(graph, tag=tag))
graph.remove_nodes_from(remove)
return graph, remove
Expand All @@ -229,3 +232,12 @@ def _iter_redundant_latents(graph: nx.DiGraph, *, tag: Optional[str] = None) ->
elif left_children < right_children:
# if left's children are a proper subset of right's children, we don't need left
yield left


def _assert_variable_nodes(graph: nx.DiGraph) -> None:
"""Assert that all nodes in the graph are variables."""
for node in graph.nodes:
if not isinstance(node, Variable):
raise TypeError(
f"latent variable dags must contain Variable objects as nodes. Got {type(node)}"
)
2 changes: 2 additions & 0 deletions src/y0/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class Variable(Element):
star: Optional[bool] = None

def __post_init__(self):
if not isinstance(self.name, str):
raise TypeError(f"Names must be strings: {self.name}")
if self.name in {"P", "Q", "PP"}:
raise ValueError(f"trust me, {self.name} is a bad variable name.")

Expand Down
11 changes: 5 additions & 6 deletions src/y0/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,15 +742,14 @@ def _latent_dag(
if prefix is None:
prefix = DEFULT_PREFIX

str_di_edges = [(u.name, v.name) for u, v in di_edges]
str_bi_edges = [(u.name, v.name) for u, v in bi_edges]
bi_edges_list = list(bi_edges)

rv = nx.DiGraph()
rv.add_nodes_from(itt.chain.from_iterable(str_bi_edges))
rv.add_edges_from(str_di_edges)
rv.add_nodes_from(itt.chain.from_iterable(bi_edges_list))
rv.add_edges_from(di_edges)
nx.set_node_attributes(rv, False, tag)
for i, (u, v) in enumerate(sorted(str_bi_edges), start=start):
latent_node = f"{prefix}{i}"
for i, (u, v) in enumerate(sorted(bi_edges_list), start=start):
latent_node = Variable(f"{prefix}{i}")
rv.add_node(latent_node, **{tag: True})
rv.add_edge(latent_node, u)
rv.add_edge(latent_node, v)
Expand Down
15 changes: 8 additions & 7 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import networkx as nx
from pgmpy.models import BayesianNetwork

from y0.dsl import A, B, C, D, M, Variable, X, Y, Z
from y0.dsl import V1, V2, V3, V4, A, B, C, D, M, Variable, X, Y, Z
from y0.examples import SARS_SMALL_GRAPH, Example, examples, napkin, verma_1
from y0.graph import (
DEFAULT_TAG,
Expand Down Expand Up @@ -71,7 +71,8 @@ def assert_labeled_convertable(
labeled_dag = graph.to_latent_variable_dag(prefix=prefix, tag=tag)
for node in labeled_dag:
self.assertIn(tag, labeled_dag.nodes[node], msg=f"Node: {node}")
self.assertEqual(node.startswith(prefix), labeled_dag.nodes[node][tag])
self.assertIsInstance(node, Variable)
self.assertEqual(node.name.startswith(prefix), labeled_dag.nodes[node][tag])

self.assertEqual(labeled_edges, set(labeled_dag.edges()))

Expand All @@ -84,11 +85,11 @@ def test_convertable(self):
(
verma_1,
{
("V1", "V2"),
("V2", "V3"),
("V3", "V4"),
(f"{DEFULT_PREFIX}0", "V2"),
(f"{DEFULT_PREFIX}0", "V4"),
(V1, V2),
(V2, V3),
(V3, V4),
(Variable(f"{DEFULT_PREFIX}0"), V2),
(Variable(f"{DEFULT_PREFIX}0"), V4),
},
),
]:
Expand Down

0 comments on commit 2cb5fc2

Please sign in to comment.