Skip to content

Commit

Permalink
Fix #12, implementation of Product.logprob_disjoint_union.
Browse files Browse the repository at this point in the history
This solution is based on a novel algorithm in dnf.event_to_disjoint_union
Test cases are also updated.
Should test these on NominalVariables as well.
  • Loading branch information
Feras A. Saad committed Feb 20, 2020
1 parent 4ce0caa commit b64f0cb
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 56 deletions.
28 changes: 28 additions & 0 deletions src/dnf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from functools import reduce
from itertools import chain
from itertools import combinations

from sympy import Intersection
Expand Down Expand Up @@ -122,3 +124,29 @@ def find_dnf_non_disjoint_clauses(event, indexes=None):
non_disjoint.append((i, j))

return non_disjoint

def event_to_disjoint_union(event):
event_dnf = event.to_dnf()
# Base case.
if isinstance(event_dnf, (EventBasic, EventAnd)):
return event_dnf
# Find indexes of pairs of clauses that overlap.
overlap = find_dnf_non_disjoint_clauses(event_dnf)
if not overlap:
return event_dnf
# Create the cascading negated clauses.
n_clauses = len(event_dnf.subexprs)
overlap_dict = {i : [prev for (prev, j) in overlap if (j == i)]
for i in range(n_clauses)
}
clauses_disjoint = [
reduce(
lambda state, event: state & ~event,
(event_dnf.subexprs[j] for j in overlap_dict[i]),
event_dnf.subexprs[i])
for i in range(n_clauses)
]
# Recursively find the solutions for each clause.
solutions = [event_to_disjoint_union(clause) for clause in clauses_disjoint]
# Return the merged solution.
return reduce(lambda a, b: a|b, solutions)
16 changes: 4 additions & 12 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sympy import Range
from sympy import Union

from .dnf import event_to_disjoint_union
from .dnf import factor_dnf_symbols
from .dnf import find_dnf_non_disjoint_clauses

Expand Down Expand Up @@ -313,25 +314,16 @@ def logprob_inclusion_exclusion(self, event):
return logdiffexp(logp_pos, logp_neg)

def condition(self, event):
event_dnf = event.to_dnf()

# Discard all probability zero clauses or fail if all are.
dnf_factor = factor_dnf_symbols(event_dnf, self.lookup)
clauses = event_to_disjoint_union(event)
dnf_factor = factor_dnf_symbols(clauses, self.lookup)
logps = [self.get_clause_weight(clause) for clause in dnf_factor]
assert allclose(logsumexp(logps), self.logprob(event))
indexes = [i for (i, lp) in enumerate(logps) if not isinf_neg(lp)]
if not indexes:
raise ValueError('Conditioning event "%s" has probability zero'
% (str(event),))

# Fail if remaining clauses are not pairwise disjoint (Github #12).
non_disjoint_clauses = find_dnf_non_disjoint_clauses(event, indexes)
if non_disjoint_clauses:
raise ValueError('Cannot condition Product on a disjunction'
'with non-disjoint clauses: %s, %s'
% (str(event), str(non_disjoint_clauses)))

# Return a sum of products.
assert allclose(logsumexp(logps), self.logprob(event))
weights = lognorm([logps[i] for i in indexes])
childrens = [self.get_clause_children(dnf_factor[i]) for i in indexes]
products = [ProductSPN(children) for children in childrens]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_dnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from spn.dnf import event_to_disjoint_union
from spn.dnf import factor_dnf
from spn.dnf import factor_dnf_symbols
from spn.dnf import find_dnf_non_disjoint_clauses
Expand Down Expand Up @@ -231,3 +232,17 @@ def test_find_dnf_non_disjoint_clauses():
event = ((X**2 < 9) & (0 < X < 1)) | (1 < X)
overlaps = find_dnf_non_disjoint_clauses(event)
assert overlaps == []

def test_event_to_disjiont_union():
X = Identity('X')
Y = Identity('Y')
Z = Identity('Z')

for event in [
(X > 0) | (X < 3),
(X > 0) | (Y < 3),
((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X <<{1, 3}),
]:
event_dnf = event_to_disjoint_union(event)
assert not find_dnf_non_disjoint_clauses(event_dnf)
92 changes: 48 additions & 44 deletions tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy
import sympy

from spn.dnf import event_to_disjoint_union
from spn.math_util import allclose
from spn.math_util import isinf_neg
from spn.math_util import logdiffexp
Expand Down Expand Up @@ -146,17 +147,15 @@ def test_product_condition_basic():
assert dXY_and.children[1].support == sympy.Interval.Ropen(0, 0.5)

# Condition on (X > 0) | (Y < 0.5)
# Cannot condition on non-disjoint union (Github #12).
event = (X > 0) | (Y < 0.5)
with pytest.raises(ValueError):
spn.condition((X > 0) | (Y < 0.5))
# assert isinstance(dXY_or, SumSPN)
# assert all(isinstance(d, ProductSPN) for d in dXY_or.children)
# assert allclose(dXY_or.logprob(X > 0),dXY_or.weights[0])
# samples = dXY_or.sample(100, rng)
# assert all(event.evaluate(sample) for sample in samples)

# Condition on a kosher disjoint union with one term in second clause.
dXY_or = spn.condition((X > 0) | (Y < 0.5))
assert isinstance(dXY_or, SumSPN)
assert all(isinstance(d, ProductSPN) for d in dXY_or.children)
assert allclose(dXY_or.logprob(X > 0),dXY_or.weights[0])
samples = dXY_or.sample(100, rng)
assert all(event.evaluate(sample) for sample in samples)

# Condition on a disjoint union with one term in second clause.
dXY_disjoint_one = spn.condition((X > 0) & (Y < 0.5) | (X <= 0))
assert isinstance(dXY_disjoint_one, SumSPN)
component_0 = dXY_disjoint_one.children[0]
Expand All @@ -173,7 +172,7 @@ def test_product_condition_basic():
assert component_1.children[1].symbol == Identity('Y')
assert not component_1.children[1].conditioned

# Condition on a kosher disjoint union with two terms in each clause
# Condition on a disjoint union with two terms in each clause
dXY_disjoint_two = spn.condition((X > 0) & (Y < 0.5) | ((X <= 0) & ~(Y < 3)))
assert isinstance(dXY_disjoint_two, SumSPN)
component_0 = dXY_disjoint_two.children[0]
Expand All @@ -191,10 +190,9 @@ def test_product_condition_basic():
assert component_1.children[1].conditioned
assert component_1.children[1].support == sympy.Interval(3, sympy.oo)

spn.condition((X > 0) & (Y < 0.5) | ((X <= 1) & ~(Y < 3)))

with pytest.raises(ValueError):
spn.condition((X > 0) & (Y < 0.5) | ((X <= 1) & (Y < 3)))
# Some various conditioning.
spn.condition((X > 0) & (Y < 0.5) | ((X <= 1) | ~(Y < 3)))
spn.condition((X > 0) & (Y < 0.5) | ((X <= 1) & (Y < 3)))

def test_product_condition_or_probabilithy_zero():
X = Identity('X')
Expand Down Expand Up @@ -224,34 +222,40 @@ def test_product_condition_or_probabilithy_zero():
assert spn_condition.children[1].conditioned
assert spn_condition.children[0].support == sympy.Interval(1, sympy.oo)

# This is a very important case we should be able to handle
# after fixing Github #12.
# Specifically, we have (X < 2) & ~(1 < exp(|3X**2|) is empty.
# Thus Y remains unconditioned
# and X is partitioned into (-oo, 0) U (0, oo) with equal weight.
with pytest.raises(ValueError):
event = (Exp(abs(3*X**2)) > 1) | ((Log(Y) < 0.5) & (X < 2))
spn_condition = spn.condition(event)
assert isinstance(spn_condition, ProductSPN)
assert isinstance(spn_condition.children[0], SumSPN)
assert spn_condition.children[0].weights == (-log(2), -log(2))
assert spn_condition.children[0].children[0].conditioned
assert spn_condition.children[0].children[1].conditioned
assert spn_condition.children[0].children[0].support \
== sympy.Interval.Ropen(-sympy.oo, 0)
assert spn_condition.children[0].children[1].support \
== sympy.Interval.Lopen(0, sympy.oo)
assert spn_condition.children[1].symbol == Y
assert not spn_condition.children[1].conditioned

@pytest.mark.xfail(
reason='https://github.com/probcomp/sum-product-dsl/issues/12',
strict=True)
def test_condition_non_disjoint_union_xfail():
# We have (X < 2) & ~(1 < exp(|3X**2|) is empty.
# Thus Y remains unconditioned,
# and X is partitioned into (-oo, 0) U (0, oo) with equal weight.
event = (Exp(abs(3*X**2)) > 1) | ((Log(Y) < 0.5) & (X < 2))
spn_condition = spn.condition(event)
assert isinstance(spn_condition, ProductSPN)
assert isinstance(spn_condition.children[0], SumSPN)
assert spn_condition.children[0].weights == (-log(2), -log(2))
assert spn_condition.children[0].children[0].conditioned
assert spn_condition.children[0].children[1].conditioned
assert spn_condition.children[0].children[0].support \
== sympy.Interval.Ropen(-sympy.oo, 0)
assert spn_condition.children[0].children[1].support \
== sympy.Interval.Lopen(0, sympy.oo)
assert spn_condition.children[1].symbol == Y
assert not spn_condition.children[1].conditioned

def test_product_disjoint_union_properties():
X = Identity('X')
Y = Identity('Y')
spn = ProductSPN([Norm(X, loc=0, scale=1), Norm(Y, loc=0, scale=2)])

event = ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3))
assert spn.logprob_disjoint_union(event) < 0
spn.condition(event)
Z = Identity('Z')
spn = ProductSPN([
Norm(X, loc=0, scale=1),
Norm(Y, loc=0, scale=2),
Norm(Z, loc=0, scale=2),
])

for event in [
(X > 0) | (X < 3),
(X > 0) | (Y < 3),
((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | ~(X<<{1, 2}),
((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X <<{1, 3}),
]:
clauses = event_to_disjoint_union(event)
logps = [spn.logprob(s) for s in clauses.subexprs]
assert allclose(logsumexp(logps), spn.logprob(event))

0 comments on commit b64f0cb

Please sign in to comment.