Skip to content

Commit

Permalink
Merge pull request #125 from potassco/113-non-determinism-in-controls…
Browse files Browse the repository at this point in the history
…olve-assumptions

clorm.clingo.Control.solve() with assumptions is now deterministic
  • Loading branch information
daveraja committed Sep 20, 2023
2 parents ed70d51 + 59c45cb commit f309c76
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
16 changes: 4 additions & 12 deletions clorm/_clingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import clingo as oclingo

from .orm import FactBase, Predicate, Symbol, SymbolPredicateUnifier, control_add_facts
from .util.oset import OrderedSet
from .util.wrapper import init_wrapper, make_class_wrapper

__all__ = ["ClormControl", "ClormModel", "ClormSolveHandle", "_expand_assumptions"]
Expand Down Expand Up @@ -246,16 +247,11 @@ def _expand_assumptions(
Tuple[Union[Iterable[Union[Predicate, Symbol]], Predicate, Symbol], bool]
]
) -> List[Tuple[Symbol, bool]]:
pos_assump = set()
neg_assump = set()
clingo_assump = []

def _add_fact(fact: Union[Predicate, Symbol], bval: bool) -> None:
nonlocal pos_assump, neg_assump
raw = fact.raw if isinstance(fact, Predicate) else fact
if bval:
pos_assump.add(raw)
else:
neg_assump.add(raw)
clingo_assump.append((raw, bool(bval)))

try:
for (arg, bval) in assumptions:
Expand All @@ -274,11 +270,7 @@ def _add_fact(fact: Union[Predicate, Symbol], bval: bool) -> None:
"of raw-symbols/predicates). Got: {}"
).format(assumptions)
)

# Now returned a list of raw assumptions combining pos and neg
pos = [(raw, True) for raw in pos_assump]
neg = [(raw, False) for raw in neg_assump]
return list(itertools.chain(pos, neg))
return clingo_assump


# ------------------------------------------------------------------------------
Expand Down
19 changes: 10 additions & 9 deletions tests/test_clingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,10 @@ class Meta:
num_models += 1
self.assertEqual(num_models, 3)

# --------------------------------------------------------------------------
# Test the solvehandle
# --------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------
# Test expanding the assumptions - note: the order matters so for an input list of
# predicates or symbols the output list is the corresponding symbols in the same order.
# ----------------------------------------------------------------------------------------
def test_expand_assumptions(self):
class F(Predicate):
num1 = IntegerField()
Expand All @@ -796,20 +797,20 @@ class G(Predicate):
f2 = F(2)
g1 = G(1)

r = set(_expand_assumptions([(f1, True), (g1, False)]))
self.assertEqual(r, set([(f1.raw, True), (g1.raw, False)]))
r = _expand_assumptions([(f1, True), (g1, False)])
self.assertEqual(r, [(f1.raw, True), (g1.raw, False)])

r = set(_expand_assumptions([(FactBase([f1, f2]), True), (set([g1]), False)]))
self.assertEqual(r, set([(f1.raw, True), (f2.raw, True), (g1.raw, False)]))
r = _expand_assumptions([(FactBase([f1, f2]), True), (set([g1]), False)])
self.assertEqual(r, [(f1.raw, True), (f2.raw, True), (g1.raw, False)])

with self.assertRaises(TypeError) as ctx:
_expand_assumptions([g1])
with self.assertRaises(TypeError) as ctx:
_expand_assumptions(g1)

# --------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------
# Test the solvehandle
# --------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------
def test_solve_with_assumptions_simple(self):
spu = SymbolPredicateUnifier()

Expand Down

0 comments on commit f309c76

Please sign in to comment.