Skip to content

Commit

Permalink
Checkpoint some inconclusive work on patch for Inclusion-Exclusion (#12
Browse files Browse the repository at this point in the history
…).
  • Loading branch information
Feras A. Saad committed Feb 18, 2020
1 parent 906674f commit 3981c98
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 27 deletions.
27 changes: 27 additions & 0 deletions src/dnf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from functools import reduce
from spn.transforms import EventAnd
from spn.transforms import EventBasic
from spn.transforms import EventOr
Expand Down Expand Up @@ -61,3 +62,29 @@ def factor_dnf_symbols(event, lookup):
return events

assert False, 'Invalid DNF event: %s' % (event,)

def dnf_to_disjoint_union(event):
# Given an event in DNF, return a list L of conjunctions that are
# disjoint from one another and whose disjunction is equal to event.
#
# For example, if A, B, C are conjunctions
# event = A or B or C
# The output is
# L = [A, B and ~A, C and ~A and ~B, ...]
if isinstance(event, (EventBasic, EventAnd)):
return [event]
return [
reduce(lambda state, event: state & ~event, event.subexprs[:i],
initial=event.subexprs[i])
for i in range(len(event.subexprs))
]

def make_disjoint_conjunction_factored(dnf_factor, i):
clause = dict(dnf_factor[i])
for j in range(i):
for k in dnf_factor[j]:
if k in clause:
clause[k] &= (~dnf_factor[j][k])
else:
clause[k] = (~dnf_factor[j][k])
return clause
71 changes: 44 additions & 27 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sympy import Range
from sympy import Union

from .dnf import dnf_to_disjoint_union
from .dnf import factor_dnf_symbols

from .math_util import allclose
Expand Down Expand Up @@ -316,42 +317,58 @@ def logprob_inclusion_exclusion(self, event):
def logprob_disjoint_union(self, event):
# Adopting disjoint union principle.
# Disjoint union algorithm (yields mixture of products).

# Yields A or B or C
expr_dnf = event.to_dnf()

# Yields [A, B & ~A, C and ~A and ~B]
exprs_disjoint = dnf_to_disjoint_union(expr_dnf)

# Convert each item in exprs_disjoint to dnf.
exprs_disjoint_dnf = [e.to_dnf() for e in exprs_disjoint]

# Factor each DNF expression.
exprs_disjoint_dnf_factors = [
factor_dnf_symbols(e, self.lookup) for e in exprs_disjoint_dnf]

# Obtain the clauses in each DNF expression.

dnf_factor = factor_dnf_symbols(expr_dnf, self.lookup)
# Obtain the n disjoint clauses.
clauses = [
self.make_disjoint_conjunction(dnf_factor, i)
for i in dnf_factor
]
# clauses = [
# self.make_disjoint_conjunction(dnf_factor, i)
# for i in dnf_factor
# ]
# Construct the ProductSPN weights.
ws = [self.get_clause_weight(clause) for clause in clauses]
return logsumexp(ws)

def condition(self, event):
pass
# Disjoint union algorithm (yields mixture of products).
expr_dnf = event.to_dnf()
dnf_factor = factor_dnf_symbols(expr_dnf, self.lookup)
# Obtain the n disjoint clauses.
clauses = [
self.make_disjoint_conjunction(dnf_factor, i)
for i in dnf_factor
]
# Construct the ProductSPN weights.
ws = [self.get_clause_weight(clause) for clause in clauses]
indexes = [i for (i, w) in enumerate(ws) if not isinf_neg(w)]
if not indexes:
raise ValueError('Conditioning event "%s" has probability zero' %
(event,))
weights = lognorm([ws[i] for i in indexes])
# Construct the new ProductSPNs.
ds = [self.get_clause_conditioned(clauses[i]) for i in indexes]
products = [ProductSPN(d) for d in ds]
if len(products) == 1:
return products[0]
# Return SumSPN of the products.
return SumSPN(products, weights)

def make_disjoint_conjunction(self, dnf_factor, i):
# expr_dnf = event.to_dnf()
# dnf_factor = factor_dnf_symbols(expr_dnf, self.lookup)
# # Obtain the n disjoint clauses.
# clauses = [
# self.make_disjoint_conjunction(dnf_factor, i)
# for i in dnf_factor
# ]
# # Construct the ProductSPN weights.
# ws = [self.get_clause_weight(clause) for clause in clauses]
# indexes = [i for (i, w) in enumerate(ws) if not isinf_neg(w)]
# if not indexes:
# raise ValueError('Conditioning event "%s" has probability zero' %
# (event,))
# weights = lognorm([ws[i] for i in indexes])
# # Construct the new ProductSPNs.
# ds = [self.get_clause_conditioned(clauses[i]) for i in indexes]
# products = [ProductSPN(d) for d in ds]
# if len(products) == 1:
# return products[0]
# # Return SumSPN of the products.
# return SumSPN(products, weights)

def make_disjoint_conjunction_factored(self, dnf_factor, i):
clause = dict(dnf_factor[i])
for j in range(i):
for k in dnf_factor[j]:
Expand Down

0 comments on commit 3981c98

Please sign in to comment.