Skip to content

Commit

Permalink
Merge pull request #14 from brandonwillard/fix-constraint-goals
Browse files Browse the repository at this point in the history
Copy constraint stores and fix arguments overwrite bug
  • Loading branch information
brandonwillard committed Dec 28, 2019
2 parents 92c3159 + 3f08cd6 commit 63803f4
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 35 deletions.
89 changes: 62 additions & 27 deletions kanren/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class ConstraintStore(ABC):
"""

__slots__ = ("lvar_constraints", "op_str")
__slots__ = ("lvar_constraints",)
op_str = None

def __init__(self, op_str, lvar_constraints=None):
self.op_str = op_str
def __init__(self, lvar_constraints=None):
# self.lvar_constraints = weakref.WeakKeyDictionary(lvar_constraints)
self.lvar_constraints = lvar_constraints or dict()

Expand All @@ -39,7 +39,10 @@ def pre_unify_check(self, lvar_map, lvar=None, value=None):

@abstractmethod
def post_unify_check(self, lvar_map, lvar=None, value=None, old_state=None):
"""Check a key-value pair after they're added to a ConstrainedState."""
"""Check a key-value pair after they're added to a ConstrainedState.
XXX: This method may alter the internal constraints, so make a copy!
"""
raise NotImplementedError()

def add(self, lvar, lvar_constraint, **kwargs):
Expand All @@ -56,6 +59,11 @@ def constraints_str(self, lvar):
else:
return ""

def copy(self):
return type(self)(
lvar_constraints={k: v.copy() for k, v in self.lvar_constraints.items()},
)

def __contains__(self, lvar):
return lvar in self.lvar_constraints

Expand All @@ -80,15 +88,34 @@ def __init__(self, *s, constraints=None):
self.constraints = dict(constraints or [])

def pre_unify_checks(self, lvar, value):
"""Check the constraints before unification."""
return all(
cstore.pre_unify_check(self.data, lvar, value)
for cstore in self.constraints.values()
)

def post_unify_checks(self, lvar_map, lvar, value):
return all(
cstore.post_unify_check(lvar_map, lvar, value, old_state=self)
for cstore in self.constraints.values()
"""Check constraints and return an updated state and constraints.
Returns
-------
A new `ConstrainedState` and `False`.
"""
S = self.copy(data=lvar_map)
if any(
not cstore.post_unify_check(lvar_map, lvar, value, old_state=S)
for cstore in S.constraints.values()
):
return False

return S

def copy(self, data=None):
if data is None:
data = self.data.copy()
return type(self)(
data, constraints={k: v.copy() for k, v in self.constraints.items()}
)

def __eq__(self, other):
Expand All @@ -107,8 +134,10 @@ def __repr__(self):
def unify_ConstrainedState(u, v, S):
if S.pre_unify_checks(u, v):
s = unify(u, v, S.data)
if s is not False and S.post_unify_checks(s, u, v):
return ConstrainedState(s, constraints=S.constraints)
if s is not False:
S = S.post_unify_checks(s, u, v)
if S is not False:
return S

return False

Expand Down Expand Up @@ -165,8 +194,10 @@ def reify_ConstrainedState(u, S):
class DisequalityStore(ConstraintStore):
"""A disequality constraint (i.e. two things do not unify)."""

op_str = "neq"

def __init__(self, lvar_constraints=None):
super().__init__("=/=", lvar_constraints)
super().__init__(lvar_constraints)

def post_unify_check(self, lvar_map, lvar=None, value=None, old_state=None):

Expand Down Expand Up @@ -209,11 +240,11 @@ def neq(u, v):
def neq_goal(S):
nonlocal u, v

u, v = reify((u, v), S)
u_rf, v_rf = reify((u, v), S)

# Get the unground logic variables that would unify the two objects;
# these are all the logic variables that we can't let unify.
s_uv = unify(u, v, {})
s_uv = unify(u_rf, v_rf, {})

if s_uv is False:
# They don't unify and have no unground logic variables, so the
Expand Down Expand Up @@ -315,8 +346,10 @@ def pre_unify_check(self, lvar_map, lvar=None, value=None):
class TypeStore(PredicateStore):
"""A constraint store for asserting object types."""

op_str = "typeo"

def __init__(self, lvar_constraints=None):
super().__init__("typeo", lvar_constraints)
super().__init__(lvar_constraints)

# def cterm_type_check(self, lvt):
# return True
Expand All @@ -334,26 +367,26 @@ def typeo(u, u_type):
def typeo_goal(S):
nonlocal u, u_type

u, u_type = reify((u, u_type), S)
u_rf, u_type_rf = reify((u, u_type), S)

if not isground(u, S) or not isground(u_type, S):
if not isground(u_rf, S) or not isground(u_type, S):

if not isinstance(S, ConstrainedState):
S = ConstrainedState(S)

cs = S.constraints.setdefault(TypeStore, TypeStore())

try:
cs.add(u, u_type)
cs.add(u_rf, u_type_rf)
except TypeError:
# If the instance object can't be hashed, we can simply use a
# logic variable to uniquely identify it.
u_lv = var()
S[u_lv] = u
cs.add(u_lv, u_type)
S[u_lv] = u_rf
cs.add(u_lv, u_type_rf)

yield S
elif isinstance(u_type, type) and type(u) == u_type:
elif isinstance(u_type_rf, type) and type(u_rf) == u_type_rf:
yield S

return typeo_goal
Expand All @@ -362,8 +395,10 @@ def typeo_goal(S):
class IsinstanceStore(PredicateStore):
"""A constraint store for asserting object instance types."""

op_str = "isinstanceo"

def __init__(self, lvar_constraints=None):
super().__init__("isinstanceo", lvar_constraints)
super().__init__(lvar_constraints)

# def cterm_type_check(self, lvt):
# return True
Expand Down Expand Up @@ -392,35 +427,35 @@ def isinstanceo(u, u_type):
def isinstanceo_goal(S):
nonlocal u, u_type

u, u_type = reify((u, u_type), S)
u_rf, u_type_rf = reify((u, u_type), S)

if not isground(u, S) or not isground(u_type, S):
if not isground(u_rf, S) or not isground(u_type_rf, S):

if not isinstance(S, ConstrainedState):
S = ConstrainedState(S)

cs = S.constraints.setdefault(IsinstanceStore, IsinstanceStore())

try:
cs.add(u, u_type)
cs.add(u_rf, u_type_rf)
except TypeError:
# If the instance object can't be hashed, we can simply use a
# logic variable to uniquely identify it.
u_lv = var()
S[u_lv] = u
cs.add(u_lv, u_type)
S[u_lv] = u_rf
cs.add(u_lv, u_type_rf)

yield S

# elif isground(u_type, S):
# yield from lany(eq(u_type, u_t) for u_t in type(u).mro())(S)
elif (
isinstance(u_type, type)
isinstance(u_type_rf, type)
# or (
# isinstance(u_type, Iterable)
# and all(isinstance(t, type) for t in u_type)
# )
) and isinstance(u, u_type):
) and isinstance(u_rf, u_type_rf):
yield S

return isinstanceo_goal
13 changes: 10 additions & 3 deletions kanren/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ class FlexibleSet(MutableSet):

__slots__ = ("set", "list")

def __init__(self, iterable):
def __init__(self, iterable=None):

self.set = set()
self.list = []

for i in iterable:
self.add(i)
if iterable is not None:
for i in iterable:
self.add(i)

def add(self, item):
try:
Expand Down Expand Up @@ -52,6 +53,12 @@ def remove(self, item):
except ValueError:
raise KeyError()

def copy(self):
res = type(self)()
res.set = self.set.copy()
res.list = self.list.copy()
return res

def __le__(self, other):
raise NotImplementedError()

Expand Down
66 changes: 61 additions & 5 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from cons import cons

from kanren import run, eq
from kanren import run, eq, conde
from kanren.core import lall, goaleval
from kanren.constraints import (
ConstrainedState,
Expand All @@ -21,7 +21,7 @@ def lconj(*goals):
return goaleval(lall(*goals))


def test_kanrenstate():
def test_ConstrainedState():

a_lv, b_lv = var(), var()

Expand Down Expand Up @@ -54,6 +54,23 @@ def test_kanrenstate():
assert unify(a_lv, b_lv, ks)
assert unify(a_lv, b_lv, ks)

ks = ConstrainedState(
{a_lv: 1}, constraints={DisequalityStore: DisequalityStore({b_lv: {1}})}
)
ks_2 = ks.copy()
assert ks == ks_2
assert ks is not ks_2
assert ks.constraints is not ks_2.constraints
assert ks.constraints[DisequalityStore] is not ks_2.constraints[DisequalityStore]
assert (
ks.constraints[DisequalityStore].lvar_constraints[b_lv]
== ks_2.constraints[DisequalityStore].lvar_constraints[b_lv]
)
assert (
ks.constraints[DisequalityStore].lvar_constraints[b_lv]
is not ks_2.constraints[DisequalityStore].lvar_constraints[b_lv]
)


def test_reify():
var_a = var("a")
Expand All @@ -64,14 +81,14 @@ def test_reify():
de = DisequalityStore({var_a: {1, 2}})
ks.constraints[DisequalityStore] = de

assert repr(de) == "ConstraintStore(=/=: {~a: {1, 2}})"
assert repr(de) == "ConstraintStore(neq: {~a: {1, 2}})"
assert de.constraints_str(var()) == ""

assert repr(ConstrainedVar(var_a, ks)) == "~a: {=/= {1, 2}}"
assert repr(ConstrainedVar(var_a, ks)) == "~a: {neq {1, 2}}"

# TODO: Make this work with `reify` when `var('a')` isn't in `ks`.
assert isinstance(_reify(var_a, ks), ConstrainedVar)
assert repr(_reify(var_a, ks)) == "~a: {=/= {1, 2}}"
assert repr(_reify(var_a, ks)) == "~a: {neq {1, 2}}"


def test_ConstraintStore():
Expand Down Expand Up @@ -151,8 +168,27 @@ def test_disequality():
([neq(cons(1, a_lv), [1]), eq(a_lv, [])], 0),
([neq(cons(1, a_lv), [1]), eq(a_lv, b_lv), eq(b_lv, [])], 0),
([neq([1], cons(1, a_lv)), eq(a_lv, b_lv), eq(b_lv, [])], 0),
# TODO FIXME: This one won't work due to an ambiguity in `cons`.
# (
# [
# neq([1], cons(1, a_lv)),
# eq(a_lv, b_lv),
# # Both make `cons` produce a list
# conde([eq(b_lv, None)], [eq(b_lv, [])]),
# ],
# 0,
# ),
([neq(cons(1, a_lv), [1]), eq(a_lv, b_lv), eq(b_lv, tuple())], 1),
([neq([1], cons(1, a_lv)), eq(a_lv, b_lv), eq(b_lv, tuple())], 1),
(
[
neq([1], cons(1, a_lv)),
eq(a_lv, b_lv),
# The first should fail, the second should succeed
conde([eq(b_lv, [])], [eq(b_lv, tuple())]),
],
1,
),
([neq(a_lv, 1), eq(a_lv, 1)], 0),
([neq(a_lv, 1), eq(b_lv, 1), eq(a_lv, b_lv)], 0),
([neq(a_lv, 1), eq(b_lv, 1), eq(a_lv, b_lv)], 0),
Expand Down Expand Up @@ -198,6 +234,16 @@ def test_typeo():
([typeo(cons(1, a_lv), list), eq(a_lv, [])], (q_lv,)),
# Logic variable instance and type arguments
([typeo(q_lv, int), eq(b_lv, 1), eq(b_lv, q_lv)], (1,)),
# The same, but with `conde`
(
[
typeo(q_lv, int),
# One succeeds, one fails
conde([eq(b_lv, 1)], [eq(b_lv, "hi")]),
eq(b_lv, q_lv),
],
(1,),
),
# Logic variable instance argument that's eventually grounded to a
# mismatched instance type through another logic variable
([typeo(q_lv, int), eq(b_lv, 1.0), eq(b_lv, q_lv)], ()),
Expand Down Expand Up @@ -255,6 +301,16 @@ def test_instanceo():
# Logic variable instance argument that's eventually grounded through
# another logic variable
([isinstanceo(q_lv, int), eq(b_lv, 1), eq(b_lv, q_lv)], (1,)),
# The same, but with `conde`
(
[
isinstanceo(q_lv, int),
# One succeeds, one fails
conde([eq(b_lv, 1)], [eq(b_lv, "hi")]),
eq(b_lv, q_lv),
],
(1,),
),
# Logic variable instance argument that's eventually grounded to a
# mismatched instance type through another logic variable
([isinstanceo(q_lv, int), eq(b_lv, 1.0), eq(b_lv, q_lv)], ()),
Expand Down

0 comments on commit 63803f4

Please sign in to comment.