Skip to content

Commit

Permalink
Merge ebba96d into 1f846de
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored Nov 14, 2021
2 parents 1f846de + ebba96d commit 35e20a7
Show file tree
Hide file tree
Showing 11 changed files with 1,250 additions and 525 deletions.
301 changes: 205 additions & 96 deletions kanren/assoccomm.py

Large diffs are not rendered by default.

113 changes: 89 additions & 24 deletions kanren/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections.abc import Generator, Sequence
from collections.abc import Generator, Mapping, Sequence
from functools import partial, reduce
from itertools import tee
from operator import length_hint
from statistics import mean

from cons.core import ConsPair
from cons.core import ConsPair, car, cdr
from multipledispatch import dispatch
from toolz import interleave, take
from unification import isvar, reify, unify
from unification.core import isground


def fail(s):
Expand Down Expand Up @@ -113,18 +114,25 @@ def conde(*goals):
lany = ldisj


def ground_order_key(S, x):
@dispatch(Mapping, object)
def shallow_ground_order_key(S, x):
if isvar(x):
return 2
elif isground(x, S):
return -1
elif issubclass(type(x), ConsPair):
return 1
else:
return 0


def ground_order(in_args, out_args):
return 10
elif isinstance(x, ConsPair):
val = 0
val += 1 if isvar(car(x)) else 0
cdr_x = cdr(x)
if issubclass(type(x), ConsPair):
val += 2 if isvar(cdr_x) else 0
elif len(cdr_x) == 1:
val += 1 if isvar(cdr_x[0]) else 0
elif len(cdr_x) > 1:
val += mean(1.0 if isvar(i) else 0.0 for i in cdr_x)
return val
return 0


def ground_order(in_args, out_args, key_fn=shallow_ground_order_key):
"""Construct a non-relational goal that orders a list of terms based on groundedness (grounded precede ungrounded).""" # noqa: E501

def ground_order_goal(S):
Expand All @@ -134,7 +142,7 @@ def ground_order_goal(S):

S_new = unify(
list(out_args_rf) if isinstance(out_args_rf, Sequence) else out_args_rf,
sorted(in_args_rf, key=partial(ground_order_key, S)),
sorted(in_args_rf, key=partial(key_fn, S)),
S,
)

Expand All @@ -144,6 +152,48 @@ def ground_order_goal(S):
return ground_order_goal


def ground_order_seqs(in_seqs, out_seqs, key_fn=shallow_ground_order_key):
"""Construct a non-relational goal that orders lists of sequences based on the groundedness of their corresponding terms. # noqa: E501
>>> from unification import var
>>> x, y = var('x'), var('y')
>>> a, b = var('a'), var('b')
>>> run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y]))
(((~b, ~a), (2, ~b)),)
"""

def ground_order_seqs_goal(S):
nonlocal in_seqs, out_seqs, key_fn

in_seqs_rf, out_seqs_rf = reify((in_seqs, out_seqs), S)

if (
not any(isinstance(s, str) for s in in_seqs_rf)
and reduce(
lambda x, y: x == y and y, (length_hint(s, -1) for s in in_seqs_rf)
)
> 0
):

in_seqs_ord = zip(*sorted(zip(*in_seqs_rf), key=partial(key_fn, S)))
S_new = unify(
list(out_seqs_rf),
[type(j)(i) for i, j in zip(in_seqs_ord, in_seqs_rf)],
S,
)

if S_new is not False:
yield S_new
else:

S_new = unify(out_seqs_rf, in_seqs_rf, S)

if S_new is not False:
yield S_new

return ground_order_seqs_goal


def ifa(g1, g2):
"""Create a goal operator that returns the first stream unless it fails."""

Expand Down Expand Up @@ -204,22 +254,37 @@ def run(n, x, *goals, results_filter=None):
return tuple(take(n, results))


def dbgo(*args, msg=None): # pragma: no cover
"""Construct a goal that sets a debug trace and prints reified arguments."""
def dbgo(*args, msg=None, pdb=False, print_asap=True, trace=True): # pragma: no cover
"""Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" # noqa: E501
from pprint import pprint

from unification import var

trace_var = var("__dbgo_trace")

def dbgo_goal(S):
nonlocal args
args = reify(args, S)
nonlocal args, msg, pdb, print_asap, trace_var, trace

args_rf, trace_rf = reify((args, trace_var), S)

if trace:
S = S.copy()
if isvar(trace_rf):
S[trace_var] = [(msg, tuple(str(a) for a in args_rf))]
else:
trace_rf.append((msg, tuple(str(a) for a in args_rf)))
S[trace_var] = trace_rf

if msg is not None:
print(msg)
if print_asap:
if msg is not None:
print(msg)
pprint(args_rf)

pprint(args)
if pdb:
import pdb

import pdb
pdb.set_trace()

pdb.set_trace()
yield S

return dbgo_goal
Loading

0 comments on commit 35e20a7

Please sign in to comment.