Skip to content

Commit

Permalink
Progress toward management of disjoint-union implementation (Github #12
Browse files Browse the repository at this point in the history
…).

Implements the following checklist:
    #12 (comment)
  • Loading branch information
Feras A. Saad committed Feb 19, 2020
1 parent 985be46 commit da75fcd
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 149 deletions.
67 changes: 64 additions & 3 deletions src/dnf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from spn.transforms import EventAnd
from spn.transforms import EventBasic
from spn.transforms import EventOr
from itertools import combinations

from sympy import Intersection

from .sym_util import EmptySet

from .transforms import EventAnd
from .transforms import EventBasic
from .transforms import EventOr

def factor_dnf(event):
lookup = {s:s for s in event.symbols()}
Expand Down Expand Up @@ -61,3 +67,58 @@ def factor_dnf_symbols(event, lookup):
return events

assert False, 'Invalid DNF event: %s' % (event,)

def solve_dnf_symbolwise(dnf_factor, indexes=None):
    # Solve every per-symbol event of a factored DNF event.
    #
    # dnf_factor is a list with one dictionary per DNF clause; each
    # dictionary maps a symbol to the event over that symbol in the
    # clause (distinct symbols have distinct keys).
    #
    # Returns a list R of dictionaries, aligned with dnf_factor, where
    # R[i][s] is the solution (ev.solve()) of the event in the i-th DNF
    # clause with symbol s.
    #
    # If indexes is not None, only clauses whose position i satisfies
    # `i in indexes` are solved; R[i] is left as None for every other
    # clause, so R always has len(dnf_factor) entries.
    #
    # For example, if e is any predicate
    #   event = (e(X0) & e(X1) & ~e(X2)) | (~e(X1) & e(X2) & e(X3) & ~e(X3))
    # The output is
    #   R = [
    #       { // First clause
    #           X0: solve(e(X0)),
    #           X1: solve(e(X1)),
    #           X2: solve(~e(X2))},
    #       { // Second clause
    #           X1: solve(~e(X1)),
    #           X2: solve(e(X2)),
    #           X3: solve(e(X3) & ~e(X3))},
    #   ]
    solutions = [None] * len(dnf_factor)
    for i, event_mapping in enumerate(dnf_factor):
        if indexes is not None and i not in indexes:
            # Clause not requested: keep the None placeholder.
            continue
        solutions[i] = {
            symbol: ev.solve() for symbol, ev in event_mapping.items()
        }
    return solutions

def _intersect_clause_solutions(solutions_i, solutions_j):
    # Per-symbol solution sets of the conjunction of two DNF clauses.
    merged = {}
    for symbol, solution in solutions_i.items():
        if symbol in solutions_j:
            # Symbol constrained by both clauses: intersect the solutions.
            merged[symbol] = Intersection(solution, solutions_j[symbol])
        else:
            merged[symbol] = solution
    for symbol, solution in solutions_j.items():
        if symbol not in solutions_i:
            # Symbol constrained only by the second clause.
            merged[symbol] = solution
    return merged

def find_dnf_non_disjoint_clauses(event, indexes=None):
    # Given an event in DNF, returns the list of pairs (i, j) of clause
    # positions whose solutions have a non-empty intersection.
    #
    # If indexes is not None, only pairs drawn from indexes (in the order
    # given) are considered. Entries of indexes that are not valid clause
    # positions are ignored, matching solve_dnf_symbolwise, which never
    # solves them; previously they raised as soon as indexes contained
    # two or more entries.
    dnf_factor = factor_dnf(event)
    solutions = solve_dnf_symbolwise(dnf_factor, indexes)

    n_clauses = len(dnf_factor)
    if indexes is None:
        clauses = range(n_clauses)
    else:
        clauses = [i for i in indexes if 0 <= i < n_clauses]

    non_disjoint = []
    for i, j in combinations(clauses, 2):
        intersections = _intersect_clause_solutions(solutions[i], solutions[j])
        # Clauses are non disjoint if all intersections are non empty.
        # NOTE(review): the identity test assumes EmptySet (from sym_util)
        # is a singleton that Intersection returns for empty results --
        # confirm against sym_util.
        if all(s is not EmptySet for s in intersections.values()):
            non_disjoint.append((i, j))

    return non_disjoint
76 changes: 26 additions & 50 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sympy import Union

from .dnf import factor_dnf_symbols
from .dnf import find_dnf_non_disjoint_clauses

from .math_util import allclose
from .math_util import flip
Expand Down Expand Up @@ -281,7 +282,7 @@ def logpdf(self, x):
return logsumexp(logps)

def logprob(self, event):
return self.logprob_disjoint_union(event)
return self.logprob_inclusion_exclusion(event)

def logprob_inclusion_exclusion(self, event):
# Adopting Inclusion--Exclusion principle:
Expand Down Expand Up @@ -313,57 +314,32 @@ def logprob_inclusion_exclusion(self, event):
logp_neg = logsumexp(logps_neg) if logps_neg else -inf
return logdiffexp(logp_pos, logp_neg)

def logprob_disjoint_union(self, event):
    # Disjoint union principle: rewrite the event as a DNF whose clauses
    # are made mutually disjoint, then combine the per-clause weights
    # with logsumexp.
    factored = factor_dnf_symbols(event.to_dnf(), self.lookup)
    # Build the disjoint version of every clause first ...
    disjoint_clauses = []
    for index in range(len(factored)):
        disjoint_clauses.append(
            self.make_disjoint_conjunction(factored, index))
    # ... then obtain one weight per disjoint clause.
    clause_weights = list(map(self.get_clause_weight, disjoint_clauses))
    return logsumexp(clause_weights)

def condition(self, event):
# Disjoint union algorithm (yields mixture of products).
expr_dnf = event.to_dnf()
dnf_factor = factor_dnf_symbols(expr_dnf, self.lookup)
# Obtain the n disjoint clauses.
n_clauses = len(dnf_factor)
clauses = [
self.make_disjoint_conjunction(dnf_factor, i)
for i in range(n_clauses)
]
# Construct the ProductSPN weights.
ws = [self.get_clause_weight(clause) for clause in clauses]
indexes = [i for (i, w) in enumerate(ws) if not isinf_neg(w)]
event_dnf = event.to_dnf()

# Discard all probability zero clauses or fail if all are.
dnf_factor = factor_dnf_symbols(event_dnf, self.lookup)
logps = [self.get_clause_weight(clause) for clause in dnf_factor]
indexes = [i for (i, lp) in enumerate(logps) if not isinf_neg(lp)]
if not indexes:
raise ValueError('Conditioning event "%s" has probability zero' %
(event,))
weights = lognorm([ws[i] for i in indexes])
# Construct the new ProductSPNs.
ds = [self.get_clause_conditioned(clauses[i]) for i in indexes]
products = [ProductSPN(d) for d in ds]
if len(products) == 1:
return products[0]
# Return SumSPN of the products.
return SumSPN(products, weights)

def make_disjoint_conjunction(self, dnf_factor, i):
    # Return a copy of clause i in which, for every earlier clause j < i
    # and every key of that clause, the negated factor ~dnf_factor[j][key]
    # is and-ed in (or stored as-is when clause i has no factor for key).
    # The entries of dnf_factor are not replaced; only the copy is updated.
    conjunction = dict(dnf_factor[i])
    for j in range(i):
        for key, factor in dnf_factor[j].items():
            negated = ~factor
            if key not in conjunction:
                conjunction[key] = negated
            else:
                conjunction[key] &= negated
    return conjunction

def get_clause_conditioned(self, clause):
raise ValueError('Conditioning event "%s" has probability zero'
% (str(event),))

# Fail if remaining clauses are not pairwise disjoint (Github #12).
non_disjoint_clauses = find_dnf_non_disjoint_clauses(event, indexes)
if non_disjoint_clauses:
    # The two literals are concatenated implicitly, so the first one
    # must end with a space; otherwise the message reads
    # "disjunctionwith".
    raise ValueError('Cannot condition Product on a disjunction '
        'with non-disjoint clauses: %s, %s'
        % (str(event), str(non_disjoint_clauses)))

# Return a sum of products.
assert allclose(logsumexp(logps), self.logprob(event))
weights = lognorm([logps[i] for i in indexes])
childrens = [self.get_clause_children(dnf_factor[i]) for i in indexes]
products = [ProductSPN(children) for children in childrens]
return SumSPN(products, weights) if len(products) > 1 else products[0]

def get_clause_children(self, clause):
# Return children conditioned on a clause (one conjunction).
return [
spn.condition(clause[k]) if (k in clause) else spn
Expand Down
39 changes: 39 additions & 0 deletions tests/test_dnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from spn.dnf import factor_dnf
from spn.dnf import factor_dnf_symbols
from spn.dnf import find_dnf_non_disjoint_clauses

from spn.transforms import ExpNat
from spn.transforms import Identity
Expand Down Expand Up @@ -171,3 +172,41 @@ def test_factor_dnf_symbols_3():
assert dnf[1][0] == E
assert dnf[1][1] == G
assert dnf[1][2] == F

def test_find_dnf_non_disjoint_clauses():
    # Checks which pairs of DNF clauses are reported as overlapping, both
    # over all clauses and when restricted through `indexes`.
    X = Identity('X')
    Y = Identity('Y')
    Z = Identity('Z')

    # Two clauses constraining different symbols.
    event = (X > 0) | (Y < 0)
    assert find_dnf_non_disjoint_clauses(event) == [(0, 1)]
    # A single index (valid or not) can never form a pair.
    assert find_dnf_non_disjoint_clauses(event, indexes=[1]) == []
    assert find_dnf_non_disjoint_clauses(event, indexes=[2]) == []

    # X > 0 and X < 0 cannot hold together.
    event = (X > 0) | ((X < 0) & (Y < 0))
    assert find_dnf_non_disjoint_clauses(event) == []

    # Only the first and third clauses share a region of X.
    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1))
    assert find_dnf_non_disjoint_clauses(event) == [(0, 2)]
    assert find_dnf_non_disjoint_clauses(event, indexes=[1, 2]) == []

    # Adding Z > 1 to the third clause separates it from Z < 0.
    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1) & (Z > 1))
    assert find_dnf_non_disjoint_clauses(event) == []

    # Non-linear constraint on X overlapping 1 < X.
    event = ((X**2 < 9)) | (1 < X)
    assert find_dnf_non_disjoint_clauses(event) == [(0, 1)]

    # NOTE(review): Python evaluates `0 < X < 1` as `(0 < X) and (X < 1)`;
    # confirm the event objects support chained comparisons, otherwise
    # only one of the two bounds may be used here.
    event = ((X**2 < 9) & (0 < X < 1)) | (1 < X)
    assert find_dnf_non_disjoint_clauses(event) == []

0 comments on commit da75fcd

Please sign in to comment.