Skip to content

Commit

Permalink
Implement canonicalize() algorithm (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 22, 2021
1 parent 580c5ad commit 1140eca
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/y0/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,10 @@ def _iter_variables(self) -> Iterable[Variable]:


# Convenience single-letter variables: one Variable per character of the string below.
A, B, C, D, Q, S, T, W, X, Y, Z = (Variable(letter) for letter in 'ABCDQSTWXYZ')  # type: ignore


def _upgrade_ordering(variables: Sequence[Union[str, Variable]]) -> Sequence[Variable]:
    """Return the ordering as a tuple in which every string has been wrapped in a variable.

    :param variables: A sequence whose elements are either strings or variables
    :return: A tuple with the same elements in the same order, where each string
        element has been replaced by ``Variable(element)``
    """
    upgraded = []
    for entry in variables:
        if isinstance(entry, str):
            entry = Variable(entry)
        upgraded.append(entry)
    return tuple(upgraded)
3 changes: 3 additions & 0 deletions src/y0/mutate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

"""Functions that mutate probability expressions."""
116 changes: 116 additions & 0 deletions src/y0/mutate/canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-

"""Implementation of the canonicalization algorithm."""

from typing import Sequence, Union

from ..dsl import Distribution, Expression, Fraction, Probability, Product, Sum, Variable, _upgrade_ordering
from ..predicates import has_markov_postcondition

__all__ = [
'canonicalize',
]


def canonicalize(expression: Expression, ordering: Sequence[Union[str, Variable]]) -> Expression:
    """Canonicalize an expression that meets the markov condition with respect to the given ordering.

    :param expression: An expression to canonicalize
    :param ordering: A topological ordering, given as strings and/or variables
    :return: A canonical expression
    :raises ValueError: if the expression does not pass the markov postcondition
    :raises ValueError: if the ordering has duplicates
    """
    if not has_markov_postcondition(expression):
        raise ValueError(f'can not canonicalize expression that does not have the markov postcondition: {expression}')

    # Strings are promoted to variables before checking for duplicates so that
    # a string and the variable of the same name are compared as variables.
    upgraded = _upgrade_ordering(ordering)
    unique_count = len(set(upgraded))
    if unique_count != len(upgraded):
        raise ValueError(f'ordering has duplicates: {upgraded}')

    return Canonicalizer(upgraded).canonicalize(expression)


def _sort_probability_key(probability: Probability) -> str:
    """Return the name of the first child of the probability's distribution, for use as a sort key."""
    first_child = probability.distribution.children[0]
    return first_child.name


class Canonicalizer:
    """A data structure to support application of the canonicalize algorithm.

    It stores the ordering together with a lookup from each variable to its
    position in that ordering, which is used to sort the parents of probabilities.
    """

    def __init__(self, ordering: Sequence[Variable]) -> None:
        """Initialize the canonicalizer.

        :param ordering: A topological ordering over the variables appearing in the expression.
        """
        self.ordering = ordering
        # Position of each variable in the ordering; used as the sort key for parents.
        self.ordering_level = {
            variable: level
            for level, variable in enumerate(self.ordering)
        }

    def _canonicalize_probability(self, expression: Probability) -> Probability:
        """Rebuild a probability with its parents sorted by their position in the ordering.

        :param expression: A probability expression
        :return: A probability with the same children and with the parents sorted
        :raises KeyError: if one of the parents does not appear in the ordering
        """
        return Probability(Distribution(
            children=expression.distribution.children,
            parents=tuple(sorted(expression.distribution.parents, key=self.ordering_level.__getitem__)),
        ))

    def canonicalize(self, expression: Expression) -> Expression:
        """Canonicalize an expression.

        :param expression: An uncanonicalized expression
        :return: A canonicalized expression
        :raises TypeError: if an object with an invalid type is passed
        """
        if isinstance(expression, Probability):  # atomic
            return self._canonicalize_probability(expression)
        elif isinstance(expression, Sum):
            # Always recurse into the summand, even when it is a single probability,
            # so that the parents of that probability get sorted as well.
            return Sum(
                expression=self.canonicalize(expression.expression),
                ranges=expression.ranges,
            )
        elif isinstance(expression, Product):
            probabilities = []
            other = []
            for subexpr in expression.expressions:
                subexpr = self.canonicalize(subexpr)
                if isinstance(subexpr, Probability):
                    probabilities.append(subexpr)
                else:
                    other.append(subexpr)

            # Probabilities come first, sorted by the name of their first child,
            # followed by the remaining (non-atomic) factors sorted by structural key.
            # If other is empty, this is also atomic
            probabilities.sort(key=_sort_probability_key)
            other.sort(key=self._nonatomic_key)
            return Product((*probabilities, *other))
        elif isinstance(expression, Fraction):
            return Fraction(
                numerator=self.canonicalize(expression.numerator),
                denominator=self.canonicalize(expression.denominator),
            )
        else:
            raise TypeError(f'can not canonicalize object of type {type(expression).__name__}')

    def _nonatomic_key(self, expression: Expression) -> tuple:
        """Generate a sort key for a *canonical* expression.

        :param expression: A canonical expression
        :returns: A tuple in which the first element is the integer priority for the expression
            and the rest depends on the expression type.
        :raises TypeError: if an invalid expression type is given
        """
        if isinstance(expression, Probability):
            return 0, expression.distribution.children[0].name
        elif isinstance(expression, Sum):
            # The key of the summand is flattened into the key of the sum.
            return (1, *self._nonatomic_key(expression.expression))
        elif isinstance(expression, Product):
            # Each factor contributes its own key as a nested tuple.
            return (2, *(self._nonatomic_key(sexpr) for sexpr in expression.expressions))
        elif isinstance(expression, Fraction):
            return 3, self._nonatomic_key(expression.numerator), self._nonatomic_key(expression.denominator)
        else:
            raise TypeError(f'can not generate a sort key for object of type {type(expression).__name__}')
3 changes: 3 additions & 0 deletions tests/test_mutate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

"""Tests for functions that mutate probability expressions."""
105 changes: 105 additions & 0 deletions tests/test_mutate/test_canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-

"""Tests for the canonicalization algorithm."""

import itertools as itt
import unittest
from typing import Sequence

from y0.dsl import A, B, C, D, Expression, P, Sum, Variable, X, Y, Z
from y0.mutate.canonicalize import canonicalize


class TestCanonicalize(unittest.TestCase):
    """Tests for the canonicalization of a simplified algorithm."""

    def test_canonicalize_raises(self):
        """Test a value error is raised for non markov-conditioning expressions."""
        with self.assertRaises(ValueError):
            canonicalize(P(A, B, C), [A, B, C])

    def assert_canonicalize(self, expected: Expression, expression: Expression, ordering: Sequence[Variable]) -> None:
        """Check that the expression is canonicalized properly given an ordering.

        :param expected: The expression that canonicalization should produce
        :param expression: The expression to canonicalize
        :param ordering: The ordering of variables passed to :func:`canonicalize`
        """
        with self.subTest(expr=str(expression), ordering=', '.join(variable.name for variable in ordering)):
            actual = canonicalize(expression, ordering)
            # The failure message reports the expected expression, not the input expression.
            self.assertEqual(
                expected, actual,
                msg=f'\nExpected: {str(expected)}\nActual: {str(actual)}',
            )

    def test_atomic(self):
        """Test canonicalization of atomic expressions."""
        for expected, expression, ordering in [
            (P(A), P(A), [A]),
            (P(A | B), P(A | B), [A, B]),
            (P(A | (B, C)), P(A | (B, C)), [A, B, C]),
            (P(A | (B, C)), P(A | (C, B)), [A, B, C]),
        ]:
            self.assert_canonicalize(expected, expression, ordering)

        # Every permutation of the parents should give the same canonical form
        expected = P(A | (B, C, D))
        for b, c, d in itt.permutations((B, C, D)):
            expression = P(A | (b, c, d))
            self.assert_canonicalize(expected, expression, [A, B, C, D])

    def test_derived_atomic(self):
        """Test canonicalizing."""
        # Sum
        expected = expression = Sum(P(A))
        self.assert_canonicalize(expected, expression, [A])

        # Simple product (only atomic)
        expected = P(A) * P(B) * P(C)
        for a, b, c in itt.permutations((P(A), P(B), P(C))):
            expression = a * b * c
            self.assert_canonicalize(expected, expression, [A, B, C])

        # Sum with simple product (only atomic)
        expected = Sum(P(A) * P(B) * P(C))
        for a, b, c in itt.permutations((P(A), P(B), P(C))):
            expression = Sum(a * b * c)
            self.assert_canonicalize(expected, expression, [A, B, C])

        # Fraction
        expected = expression = P(A) / P(B)
        self.assert_canonicalize(expected, expression, [A, B])

        # Fraction with simple products (only atomic)
        expected = (P(A) * P(B) * P(C)) / (P(X) * P(Y) * P(Z))
        for (a, b, c), (x, y, z) in itt.product(
            itt.permutations((P(A), P(B), P(C))),
            itt.permutations((P(X), P(Y), P(Z))),
        ):
            expression = (a * b * c) / (x * y * z)
            self.assert_canonicalize(expected, expression, [A, B, C, X, Y, Z])

    def test_mixed(self):
        """Test mixed expressions."""
        expected = expression = P(A) * Sum(P(B))
        self.assert_canonicalize(expected, expression, [A, B])

        expected = P(A) * Sum(P(B)) * Sum(P(C))
        for a, b, c in itt.permutations((P(A), Sum(P(B)), Sum(P(C)))):
            expression = a * b * c
            self.assert_canonicalize(expected, expression, [A, B, C])

        expected = P(D) * Sum(P(A) * P(B) * P(C))
        for a, b, c in itt.permutations((P(A), P(B), P(C))):
            sum_expr = Sum(a * b * c)
            for left, right in itt.permutations((P(D), sum_expr)):
                self.assert_canonicalize(expected, left * right, [A, B, C, D])

        expected = P(X) * Sum(P(A) * P(B)) * Sum(P(C) * P(D))
        for (a, b), (c, d) in itt.product(
            itt.permutations((P(A), P(B))),
            itt.permutations((P(C), P(D))),
        ):
            sexpr = Sum(a * b) * Sum(c * d)
            self.assert_canonicalize(expected, sexpr * P(X), [A, B, C, D])
            self.assert_canonicalize(expected, P(X) * sexpr, [A, B, C, D])

        expected = expression = Sum(P(A) / P(B))
        self.assert_canonicalize(expected, expression, [A, B])

        expected = expression = Sum(P(A) / Sum(P(B))) * Sum(P(A) / Sum(P(B) / P(C)))
        self.assert_canonicalize(expected, expression, [A, B, C])

0 comments on commit 1140eca

Please sign in to comment.