Skip to content

Commit

Permalink
API: make optional arg nvars in method count
Browse files Browse the repository at this point in the history
This change is on the methods `dd.cudd.BDD.count`
and `dd.cudd_zdd.ZDD.count`. The tests of method
`count` and the CHANGELOG are updated.
  • Loading branch information
johnyf committed Sep 2, 2020
1 parent 890d191 commit 2cb2093
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ API:
to `list` of `dd.cudd.Function`
- multiple roots supported in `dd.cudd.BDD.dump` for
file types other than DDDMP
- method `count` of the classes
`dd.cudd.BDD` and `dd.cudd_zdd.ZDD`:
- make optional the argument `nvars`
- `dd.autoref.BDD.load`:
require file extension `.p` for pickle files

Expand Down
6 changes: 5 additions & 1 deletion dd/cudd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -854,14 +854,18 @@ cdef class BDD(object):
r = cuddUniqueInter(self.manager, index, high.node, low.node)
return wrap(self, r)

def count(self, Function u, int nvars):
def count(self, Function u, nvars=None):
"""Return number of models of node `u`.
@param nvars: regard `u` as an operator that
depends on `nvars` many variables.
If omitted, then assume those in `support(u)`.
"""
assert u.manager == self.manager
n = len(self.support(u))
if nvars is None:
nvars = n
assert nvars >= n, (nvars, n)
r = Cudd_CountMinterm(self.manager, u.node, nvars)
assert r != CUDD_OUT_OF_MEM
Expand Down
6 changes: 5 additions & 1 deletion dd/cudd_zdd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1030,17 +1030,21 @@ cdef class ZDD(object):
assert w is not None
return (v, w)

def count(self, Function u, int nvars):
def count(self, Function u, nvars=None):
"""Return nuber of models of node `u`.
@param nvars: regard `u` as an operator that
depends on `nvars` many variables.
If omitted, then assume those in `support(u)`.
"""
logger.debug('count')
assert u.manager == self.manager
support = self.support(u)
r = self._count(0, u, support, cache=dict())
n_support = len(support)
if nvars == None:
nvars = n_support
assert nvars >= n_support, (nvars, n_support)
return r * 2**(nvars - n_support)

Expand Down
24 changes: 22 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def test_cofactor(self):

def test_count(self):
b = self.DD()
b.add_var('x')
# x
b.declare('x')
u = b.add_expr('x')
with assert_raises(AssertionError):
b.count(u, 0)
Expand All @@ -264,7 +265,10 @@ def test_count(self):
assert n == 2, n
n = b.count(u, 3)
assert n == 4, n
b.add_var('y')
n = b.count(u)
assert n == 1, n
# x /\ y
b.declare('y')
u = b.add_expr('x /\ y')
with assert_raises(AssertionError):
b.count(u, 0)
Expand All @@ -276,6 +280,22 @@ def test_count(self):
assert n == 2, n
n = b.count(u, 5)
assert n == 8, n
n = b.count(u)
assert n == 1, n
# x \/ ~ y
u = b.add_expr('x \/ ~ y')
with assert_raises(AssertionError):
b.count(u, 0)
with assert_raises(AssertionError):
b.count(u, 1)
n = b.count(u, 2)
assert n == 3, n
n = b.count(u, 3)
assert n == 6, n
n = b.count(u, 4)
assert n == 12, n
n = b.count(u)
assert n == 3, n

def test_pick_iter(self):
b = self.DD()
Expand Down

0 comments on commit 2cb2093

Please sign in to comment.