Skip to content

Commit

Permalink
Fix condition on DiscreteReal Atomic for non-Range data type [fix #77].
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed Aug 21, 2020
1 parent e5acaf7 commit 704a722
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
37 changes: 19 additions & 18 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .sym_util import are_identical
from .sym_util import get_union
from .sym_util import partition_list_blocks
from .sym_util import partition_finite_real_contiguous
from .sym_util import powerset
from .sym_util import sympify_number

Expand Down Expand Up @@ -722,35 +723,36 @@ def logprob_finite__(self, values):
def logprob_interval__(self, values):
raise NotImplementedError()

def values_to_support(self, values):
def flatten_values_contiguous(self, values):
if isinstance(values, Interval):
return values
return [values]
if isinstance(values, FiniteReal):
assert isinstance(self, DiscreteLeaf)
(low, high) = (min(values), max(values))
# https://github.com/probcomp/sum-product-dsl/issues/77
if sorted(values) != list(range(low, high+1)):
assert False, 'Cannot handle non-contiguous condition'
return Range(low, high)
blocks = partition_finite_real_contiguous(values)
return [Range(min(v), max(v)) for v in blocks]
if isinstance(values, Union):
subvalues = (self.flatten_values_contiguous(v) for v in values)
return list(chain(*subvalues))
assert False

def condition__(self, event):
interval = event.solve()
values = self.support & interval
weight = self.logprob_values__(values)
values_set = self.support & interval
weight = self.logprob_values__(values_set)
# Probability zero event.
if isinf_neg(weight):
raise ValueError('Conditioning event "%s" has probability zero'
% (str(event)))
# Condition on support.
if values == self.support:
if values_set == self.support:
return self
# Condition on Interval.
if isinstance(values, (FiniteReal, Interval)):
support = self.values_to_support(values)
return (type(self))(self.symbol, self.dist, support, True, self.env)
# Condition on union of sets.
if isinstance(values, Union):
# Flatten the set.
values = self.flatten_values_contiguous(values_set)
# Condition on a single contiguous set.
if len(values) == 0:
return (type(self))(self.symbol, self.dist, values[0], True, self.env)
# Condition on a union of contiguous set.
else:
weights_unorm = [self.logprob_values__(v) for v in values]
indexes = [i for i, w in enumerate(weights_unorm) if not isinf_neg(w)]
if not indexes:
Expand All @@ -759,9 +761,8 @@ def condition__(self, event):
# TODO: Normalize the weights with greater precision, e.g.,
# https://stats.stackexchange.com/questions/66616/converting-normalizing-very-small-likelihood-values-to-probability
weights = lognorm([weights_unorm[i] for i in indexes])
supports = [self.values_to_support(v) for v in values]
children = [
(type(self))(self.symbol, self.dist, supports[i], True, self.env)
(type(self))(self.symbol, self.dist, values[i], True, self.env)
for i in indexes
]
return SumSPN(children, weights) if 1 < len(indexes) else children[0]
Expand Down
15 changes: 13 additions & 2 deletions tests/test_real_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,22 @@ def test_poisson():
# Condition on single point.
assert allclose(0, spn.condition(X << {2}).logprob(X<<{2}))

@pytest.mark.xfail(strict=True, reason='https://github.com/probcomp/sum-product-dsl/issues/77')
def test_condition_non_contiguous():
X = Id('X')
spn = X >> poisson(mu=5)
spn.condition(X << {1,2,3,5})
# FiniteSet.
for c in [{0,2,3}, {-1,0,2,3}, {-1,0,2,3,'z'}]:
spn_condition = spn.condition((X << c))
assert isinstance(spn_condition, SumSPN)
assert allclose(0, spn_condition.children[0].logprob(X<<{0}))
assert allclose(0, spn_condition.children[1].logprob(X<<{2,3}))
# FiniteSet or Interval.
spn_condition = spn.condition((X << {-1,'x',0,2,3}) | (X > 7))
assert isinstance(spn_condition, SumSPN)
assert len(spn_condition.children) == 3
assert allclose(0, spn_condition.children[0].logprob(X<<{0}))
assert allclose(0, spn_condition.children[1].logprob(X<<{2,3}))
assert allclose(0, spn_condition.children[2].logprob(X>7))

def test_randint():
X = Id('X')
Expand Down

0 comments on commit 704a722

Please sign in to comment.