From 37a6a9cf4fb0053a7ad91826cb27780c40fbf03b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 1 Jan 2020 20:28:00 -0600 Subject: [PATCH 01/15] Fix appendo and add rembero --- kanren/core.py | 9 ++-- kanren/goals.py | 103 +++++++++++++++++++++++++++++++++++--------- tests/test_goals.py | 93 ++++++++++++++++++++++++++++++++++----- 3 files changed, 169 insertions(+), 36 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index 562e178..f13cc79 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -1,10 +1,11 @@ -import itertools as it +from itertools import tee from functools import partial -from .util import dicthash, interleave, take, multihash, unique, evalt -from toolz import groupby, map +from toolz import groupby, map from unification import reify, unify +from .util import dicthash, interleave, take, multihash, unique, evalt + def fail(s): return iter(()) @@ -182,7 +183,7 @@ def lanyseq(goals): """Construct a logical any with a possibly infinite number of goals.""" def anygoal(s): - anygoal.goals, local_goals = it.tee(anygoal.goals) + anygoal.goals, local_goals = tee(anygoal.goals) def f(goals): for goal in goals: diff --git a/kanren/goals.py b/kanren/goals.py index e62eafb..05946fa 100644 --- a/kanren/goals.py +++ b/kanren/goals.py @@ -1,5 +1,6 @@ import collections +from functools import partial from itertools import permutations from collections.abc import Sequence @@ -13,7 +14,6 @@ conde, condeseq, lany, - lallgreedy, lall, fail, success, @@ -49,7 +49,7 @@ def conso(h, t, l): return eq(cons(h, t), l) -def nullo(*args, default_ConsNull=list): +def nullo(*args, refs=None, default_ConsNull=list): """Create a goal asserting that one or more terms are a/the same `ConsNull` type. `ConsNull` types return proper Python collections when used as a CDR value @@ -62,19 +62,34 @@ def nullo(*args, default_ConsNull=list): walking distinct lists that do not necessarily terminate on the same iteration. - Unground logic variables will be set to the value of the `default_ConsNull` kwarg. + Parameters + ---------- + args: tuple of objects + The terms to consider as an instance of the `ConsNull` type + refs: tuple of objects + The terms to use as reference types. These are not unified with the + `ConsNull` type, instead they are used to constrain the `ConsNull` + types considered valid. + default_ConsNull: type + The sequence type to use when all logic variables are unground. + """ - def eqnullo_goal(s): + def nullo_goal(s): nonlocal args, default_ConsNull + if refs is not None: + refs_rf = reify(refs, s) + else: + refs_rf = () + args_rf = reify(args, s) arg_null_types = set( # Get an empty instance of the type type(a) - for a in args_rf + for a in args_rf + refs_rf # `ConsPair` and `ConsNull` types that are not literally `ConsPair`s if isinstance(a, (ConsPair, ConsNull)) and not issubclass(type(a), ConsPair) ) @@ -92,7 +107,7 @@ def eqnullo_goal(s): yield from goaleval(g)(s) - return eqnullo_goal + return nullo_goal def itero(l, default_ConsNull=list): @@ -232,22 +247,68 @@ def membero(x, coll): raise EarlyGoalError() -def appendo(l, s, ls, base_type=tuple): - """Construct a goal stating ls = l + s. +def appendo(l, s, out, default_ConsNull=list): + """Construct a goal for the relation l + s = ls. See Byrd thesis pg. 247 https://scholarworks.iu.edu/dspace/bitstream/handle/2022/8777/Byrd_indiana_0093A_10344.pdf - - Parameters - ---------- - base_type: type - The empty collection type to use when all terms are logic variables. """ - if all(map(isvar, (l, s, ls))): - raise EarlyGoalError() - a, d, res = [var() for i in range(3)] - return ( - lany, - (lallgreedy, (eq, l, base_type()), (eq, s, ls)), - (lall, (conso, a, d, l), (conso, a, res, ls), (appendo, d, s, res)), - ) + + def appendo_goal(S): + nonlocal l, s, out + + l_rf, s_rf, out_rf = reify((l, s, out), S) + + a, d, res = var(prefix="a"), var(prefix="d"), var(prefix="res") + + _nullo = partial(nullo, default_ConsNull=default_ConsNull) + + g = conde( + [ + # All empty + _nullo(s_rf, l_rf, out_rf), + ], + [ + # `l` is empty + conso(a, d, out_rf), + eq(s_rf, out_rf), + _nullo(l_rf, refs=(s_rf, out_rf)), + ], + [ + conso(a, d, l_rf), + conso(a, res, out_rf), + appendo(d, s_rf, res, default_ConsNull=default_ConsNull), + ], + ) + + yield from goaleval(g)(S) + + return appendo_goal + + +def rembero(x, l, o, default_ConsNull=list): + """Remove the first occurrence of `x` in `l` resulting in `o`.""" + + from .constraints import neq + + def rembero_goal(s): + nonlocal x, l, o + + x_rf, l_rf, o_rf = reify((x, l, o), s) + + l_car, l_cdr, r = var(), var(), var() + + g = conde( + [nullo(l_rf, o_rf, default_ConsNull=default_ConsNull),], + [conso(l_car, l_cdr, l_rf), eq(x_rf, l_car), eq(l_cdr, o_rf),], + [ + conso(l_car, l_cdr, l_rf), + neq(l_car, x), + conso(l_car, r, o_rf), + rembero(x_rf, l_cdr, r, default_ConsNull=default_ConsNull), + ], + ) + + yield from goaleval(g)(s) + + return rembero_goal diff --git a/tests/test_goals.py b/tests/test_goals.py index e4cdc82..280f146 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -1,8 +1,8 @@ -from __future__ import absolute_import - import pytest -from unification import var, isvar +from unification import var, isvar, unify + +from cons import cons from kanren.goals import ( tailo, @@ -14,6 +14,7 @@ itero, permuteq, membero, + rembero, ) from kanren.core import run, eq, goaleval, lall, lallgreedy, EarlyGoalError @@ -154,20 +155,73 @@ def test_permuteq(): def test_appendo(): - assert results(appendo((), (1, 2), (1, 2))) == ({},) - assert results(appendo((), (1, 2), (1))) == () - assert results(appendo((1, 2), (3, 4), (1, 2, 3, 4))) - assert run(5, x, appendo((1, 2, 3), x, (1, 2, 3, 4, 5))) == ((4, 5),) - assert run(5, x, appendo(x, (4, 5), (1, 2, 3, 4, 5))) == ((1, 2, 3),) - assert run(5, x, appendo((1, 2, 3), (4, 5), x)) == ((1, 2, 3, 4, 5),) + q_lv = var() + assert run(0, q_lv, appendo((), (1, 2), (1, 2))) == (q_lv,) + assert run(0, q_lv, appendo((), (1, 2), 1)) == () + assert run(0, q_lv, appendo((), (1, 2), (1,))) == () + assert run(0, q_lv, appendo((1, 2), (3, 4), (1, 2, 3, 4))) == (q_lv,) + assert run(5, q_lv, appendo((1, 2, 3), q_lv, (1, 2, 3, 4, 5))) == ((4, 5),) + assert run(5, q_lv, appendo(q_lv, (4, 5), (1, 2, 3, 4, 5))) == ((1, 2, 3),) + assert run(5, q_lv, appendo((1, 2, 3), (4, 5), q_lv)) == ((1, 2, 3, 4, 5),) + + q_lv, r_lv = var(), var() + + assert ([1, 2, 3, 4],) == run(0, q_lv, appendo([1, 2], [3, 4], q_lv)) + assert ([3, 4],) == run(0, q_lv, appendo([1, 2], q_lv, [1, 2, 3, 4])) + assert ([1, 2],) == run(0, q_lv, appendo(q_lv, [3, 4], [1, 2, 3, 4])) + + expected_res = set( + [ + ((), (1, 2, 3, 4)), + ((1,), (2, 3, 4)), + ((1, 2), (3, 4)), + ((1, 2, 3), (4,)), + ((1, 2, 3, 4), ()), + ] + ) + assert expected_res == set(run(0, (q_lv, r_lv), appendo(q_lv, r_lv, (1, 2, 3, 4)))) + + res = run(3, (q_lv, r_lv), appendo(q_lv, [3, 4], r_lv)) + assert len(res) == 3 + assert any(len(a) > 0 and isvar(a[0]) for a, b in res) + assert all(a + [3, 4] == b for a, b in res) + + res = run(0, (q_lv, r_lv), appendo([3, 4], q_lv, r_lv)) + assert len(res) == 2 + assert ([], [3, 4]) == res[0] + assert all( + type(v) == cons for v in unify((var(), cons(3, 4, var())), res[1]).values() + ) +@pytest.mark.skip("Misspecified test") def test_appendo2(): + # XXX: This test generates goal conjunctions that are non-terminating given + # the specified goal ordering. More specifically, it generates + # `lall(appendo(x, y, w), appendo(w, z, ()))`, for which the first + # `appendo` produces an infinite stream of results and the second + # necessarily fails for all values of the first `appendo` yielding + # non-empty `w` unifications. + # + # The only reason it worked before is the `EarlyGoalError` + # and it's implicit goal reordering, which made this case an out-of-place + # test for a goal reordering feature that has nothing to do with `appendo`. + # Furthermore, the `EarlyGoalError` mechanics do *not* fix this general + # problem, and it's trivial to generate an equivalent situation in which + # an `EarlyGoalError` is never thrown. + # + # In other words, it seems like a nice side effect of `EarlyGoalError`, but + # it's actually a very costly approach that masks a bigger issue; one that + # all miniKanren programmers need to think about when developing. + + x, y, z, w = var(), var(), var(), var() for t in [tuple(range(i)) for i in range(5)]: + print(t) for xi, yi in run(0, (x, y), appendo(x, y, t)): assert xi + yi == t - results = run(0, (x, y, z), (appendo, x, y, w), (appendo, w, z, t)) - for xi, yi, zi in results: + + results = run(2, (x, y, z, w), appendo(x, y, w), appendo(w, z, t)) + for xi, yi, zi, wi in results: assert xi + yi + zi == t @@ -202,3 +256,20 @@ def lefto(q, p, lst): (solution,) = run(1, vals, rules_greedy) assert solution == ("green", "white") + + +def test_rembero(): + + q_lv = var() + assert ([],) == run(0, q_lv, rembero(1, [1], q_lv)) + assert ([], [1]) == run(0, q_lv, rembero(1, q_lv, [])) + + expected_res = ( + [5, 1, 2, 3, 4], + [1, 5, 2, 3, 4], + [1, 2, 5, 3, 4], + [1, 2, 3, 5, 4], + [1, 2, 3, 4], + [1, 2, 3, 4, 5], + ) + assert expected_res == run(0, q_lv, rembero(5, q_lv, [1, 2, 3, 4])) From 1eddbd5b857835a5980e36520f63adccb6f4035d Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Jan 2020 21:43:18 -0600 Subject: [PATCH 02/15] Move goal ordering tests to core --- tests/test_core.py | 33 +++++++++++++++++++++++++++++++++ tests/test_goals.py | 35 +---------------------------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index b0c055c..4b41398 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -175,3 +175,36 @@ def results(g, s=None): def test_dict(): x = var() assert run(0, x, eq({1: x}, {1: 2})) == (2,) + + +def test_goal_ordering(): + # Regression test for https://github.com/logpy/logpy/issues/58 + + def lefto(q, p, lst): + if isvar(lst): + raise EarlyGoalError() + return ege_membero((q, p), zip(lst, lst[1:])) + + vals = var() + + # Verify the solution can be computed when we specify the execution + # ordering. + rules_greedy = ( + lallgreedy, + (eq, (var(), var()), vals), + (lefto, "green", "white", vals), + ) + + (solution,) = run(1, vals, rules_greedy) + assert solution == ("green", "white") + + # Verify that attempting to compute the "safe" order does not itself cause + # the evaluation to fail. + rules_greedy = ( + lall, + (eq, (var(), var()), vals), + (lefto, "green", "white", vals), + ) + + (solution,) = run(1, vals, rules_greedy) + assert solution == ("green", "white") diff --git a/tests/test_goals.py b/tests/test_goals.py index 280f146..3d7bf1c 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -16,7 +16,7 @@ membero, rembero, ) -from kanren.core import run, eq, goaleval, lall, lallgreedy, EarlyGoalError +from kanren.core import eq, goaleval, run x, y, z, w = var("x"), var("y"), var("z"), var("w") @@ -225,39 +225,6 @@ def test_appendo2(): assert xi + yi + zi == t -def test_goal_ordering(): - # Regression test for https://github.com/logpy/logpy/issues/58 - - def lefto(q, p, lst): - if isvar(lst): - raise EarlyGoalError() - return membero((q, p), zip(lst, lst[1:])) - - vals = var() - - # Verify the solution can be computed when we specify the execution - # ordering. - rules_greedy = ( - lallgreedy, - (eq, (var(), var()), vals), - (lefto, "green", "white", vals), - ) - - (solution,) = run(1, vals, rules_greedy) - assert solution == ("green", "white") - - # Verify that attempting to compute the "safe" order does not itself cause - # the evaluation to fail. - rules_greedy = ( - lall, - (eq, (var(), var()), vals), - (lefto, "green", "white", vals), - ) - - (solution,) = run(1, vals, rules_greedy) - assert solution == ("green", "white") - - def test_rembero(): q_lv = var() From a726719a404edb010283c2d875c9f692340b50f7 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 2 Jan 2020 12:00:41 -0600 Subject: [PATCH 03/15] Implement a fully functional membero --- kanren/core.py | 33 +++---------------------- kanren/goals.py | 18 ++++++++++---- tests/test_core.py | 60 ++++++++++++++++++++++++++++----------------- tests/test_goals.py | 39 +++++++++++++++++++++-------- 4 files changed, 84 insertions(+), 66 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index f13cc79..04fdced 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -50,18 +50,9 @@ def lall(*goals): def lallgreedy(*goals): """Construct a logical all that greedily evaluates each goals in the order provided. - Note that this may raise EarlyGoalError when the ordering of the - goals is incorrect. It is faster than lall, but should be used - with care. + Note that this may raise EarlyGoalError when the ordering of the goals is + incorrect. It is faster than lall, but should be used with care. - >>> from kanren import eq, run, membero, var - >>> x, y = var('x'), var('y') - >>> run(0, x, lallgreedy((eq, y, set([1]))), (membero, x, y)) - (1,) - >>> run(0, x, lallgreedy((membero, x, y), (eq, y, {1}))) # doctest: +SKIP - Traceback (most recent call last): - ... - kanren.core.EarlyGoalError """ if not goals: return success @@ -81,16 +72,7 @@ def allgoal(s): def lallfirst(*goals): - """Construct a logical all that runs goals one at a time. - - >>> from kanren import membero, var - >>> x = var('x') - >>> g = lallfirst(membero(x, (1,2,3)), membero(x, (2,3,4))) - >>> tuple(g({})) - ({~x: 2}, {~x: 3}) - >>> tuple(lallfirst()({})) - ({},) - """ + """Construct a logical all that runs goals one at a time.""" if not goals: return success if len(goals) == 1: @@ -117,14 +99,7 @@ def allgoal(s): def lany(*goals): - """Construct a logical any goal. - - >>> from kanren import lany, membero, var - >>> x = var('x') - >>> g = lany(membero(x, (1,2,3)), membero(x, (2,3,4))) - >>> tuple(g({})) - ({~x: 1}, {~x: 2}, {~x: 3}, {~x: 4}) - """ + """Construct a logical any goal.""" if len(goals) == 1: return goals[0] return lanyseq(goals) diff --git a/kanren/goals.py b/kanren/goals.py index 05946fa..8b10cc1 100644 --- a/kanren/goals.py +++ b/kanren/goals.py @@ -13,7 +13,6 @@ EarlyGoalError, conde, condeseq, - lany, lall, fail, success, @@ -240,11 +239,20 @@ def funco(inputs, out): # pragma: noqa return funco -def membero(x, coll): +def membero(x, ls): """Construct a goal stating that x is an item of coll.""" - if not isvar(coll): - return (lany,) + tuple((eq, x, item) for item in coll) - raise EarlyGoalError() + + def membero_goal(S): + nonlocal x, ls + + x_rf, ls_rf = reify((x, ls), S) + a, d = var(), var() + + g = lall(conso(a, d, ls), conde([eq(a, x)], [membero(x, d)])) + + yield from goaleval(g)(S) + + return membero_goal def appendo(l, s, out, default_ConsNull=list): diff --git a/tests/test_core.py b/tests/test_core.py index 4b41398..ad168a3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,12 +1,9 @@ -from __future__ import absolute_import - from itertools import count -import pytest -from pytest import raises -from unification import var +from pytest import raises, mark + +from unification import var, isvar -from kanren.goals import membero from kanren.core import ( run, fail, @@ -25,7 +22,11 @@ ) from kanren.util import evalt -w, x, y, z = "wxyz" + +def ege_membero(x, coll): + if not isvar(coll): + return (lany,) + tuple((eq, x, item) for item in coll) + raise EarlyGoalError() def test_eq(): @@ -39,41 +40,57 @@ def test_lany(): assert len(tuple(lany(eq(x, 2), eq(x, 3))({}))) == 2 assert len(tuple(lany((eq, x, 2), (eq, x, 3))({}))) == 2 + g = lany(ege_membero(x, (1, 2, 3)), ege_membero(x, (2, 3, 4))) + assert tuple(g({})) == ({x: 1}, {x: 2}, {x: 3}, {x: 4}) + + +def test_lallfirst(): + x = var("x") + g = lallfirst(ege_membero(x, (1, 2, 3)), ege_membero(x, (2, 3, 4))) + assert tuple(g({})) == ({x: 2}, {x: 3}) + assert tuple(lallfirst()({})) == ({},) + -# Test that all three implementations of lallgreedy behave identically for -# correctly ordered goals. -@pytest.mark.parametrize("lall_impl", [lallgreedy, lall, lallfirst]) +def test_lallgreedy(): + x, y = var("x"), var("y") + assert run(0, x, lallgreedy((eq, y, set([1]))), (ege_membero, x, y)) == (1,) + with raises(EarlyGoalError): + run(0, x, lallgreedy((ege_membero, x, y), (eq, y, {1}))) + + +@mark.parametrize("lall_impl", [lallgreedy, lall, lallfirst]) def test_lall(lall_impl): + """Test that all three implementations of lallgreedy behave identically for correctly ordered goals.""" x, y = var("x"), var("y") assert results(lall_impl((eq, x, 2))) == ({x: 2},) assert results(lall_impl((eq, x, 2), (eq, x, 3))) == () assert results(lall_impl()) == ({},) - assert run(0, x, lall_impl((eq, y, (1, 2)), (membero, x, y))) + assert run(0, x, lall_impl((eq, y, (1, 2)), (ege_membero, x, y))) assert run(0, x, lall_impl()) == (x,) - with pytest.raises(EarlyGoalError): - run(0, x, lall_impl(membero(x, y))) + with raises(EarlyGoalError): + run(0, x, lall_impl(ege_membero(x, y))) -@pytest.mark.parametrize("lall_impl", [lall, lallfirst]) +@mark.parametrize("lall_impl", [lall, lallfirst]) def test_safe_reordering_lall(lall_impl): x, y = var("x"), var("y") - assert run(0, x, lall_impl((membero, x, y), (eq, y, (1, 2)))) == (1, 2) + assert run(0, x, lall_impl((ege_membero, x, y), (eq, y, (1, 2)))) == (1, 2) def test_earlysafe(): x, y = var("x"), var("y") assert earlysafe((eq, 2, 2)) assert earlysafe((eq, 2, 3)) - assert earlysafe((membero, x, (1, 2, 3))) - assert not earlysafe((membero, x, y)) + assert earlysafe((ege_membero, x, (1, 2, 3))) + assert not earlysafe((ege_membero, x, y)) def test_earlyorder(): x, y = var(), var() assert earlyorder((eq, 2, x)) == ((eq, 2, x),) assert earlyorder((eq, 2, x), (eq, 3, x)) == ((eq, 2, x), (eq, 3, x)) - assert earlyorder((membero, x, y), (eq, y, (1, 2, 3)))[0] == (eq, y, (1, 2, 3)) + assert earlyorder((ege_membero, x, y), (eq, y, (1, 2, 3)))[0] == (eq, y, (1, 2, 3)) def test_conde(): @@ -138,7 +155,7 @@ def test_goaleval(): assert goaleval(g) == g assert callable(goaleval((eq, x, 2))) with raises(EarlyGoalError): - goaleval((membero, x, y)) + goaleval((ege_membero, x, y)) assert callable(goaleval((lallgreedy, (eq, x, 2)))) @@ -161,9 +178,8 @@ def _bad_relation(): def test_lany_is_early_safe(): - x = var() - y = var() - assert run(0, x, lany((membero, x, y), (eq, x, 2))) == (2,) + x, y = var(), var() + assert run(0, x, lany((ege_membero, x, y), (eq, x, 2))) == (2,) def results(g, s=None): diff --git a/tests/test_goals.py b/tests/test_goals.py index 3d7bf1c..66672ee 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -3,6 +3,7 @@ from unification import var, isvar, unify from cons import cons +from cons.core import ConsPair from kanren.goals import ( tailo, @@ -18,8 +19,6 @@ ) from kanren.core import eq, goaleval, run -x, y, z, w = var("x"), var("y"), var("z"), var("w") - def results(g, s=None): if s is None: @@ -28,6 +27,7 @@ def results(g, s=None): def test_heado(): + x, y, z = var(), var(), var() assert (x, 1) in results(heado(x, (1, 2, 3)))[0].items() assert (x, 1) in results(heado(1, (x, 2, 3)))[0].items() assert results(heado(x, ())) == () @@ -36,6 +36,8 @@ def test_heado(): def test_tailo(): + x, y, z = var(), var(), var() + assert (x, (2, 3)) in results((tailo, x, (1, 2, 3)))[0].items() assert (x, ()) in results((tailo, x, (1,)))[0].items() assert results((tailo, x, ())) == () @@ -44,6 +46,8 @@ def test_tailo(): def test_conso(): + x, y, z = var(), var(), var() + assert not results(conso(x, y, ())) assert results(conso(1, (2, 3), (1, 2, 3))) assert results(conso(x, (2, 3), (1, 2, 3))) == ({x: 1},) @@ -62,6 +66,7 @@ def __add__(self, other): def test_nullo_itero(): + x, y, z = var(), var(), var() q_lv, a_lv = var(), var() assert run(0, q_lv, conso(1, q_lv, [1]), nullo(q_lv)) @@ -98,21 +103,32 @@ def test_nullo_itero(): def test_membero(): - x = var("x") + x, y = var(), var() + assert set(run(5, x, membero(x, (1, 2, 3)), membero(x, (2, 3, 4)))) == {2, 3} assert run(5, x, membero(2, (1, x, 3))) == (2,) - assert run(0, x, (membero, 1, (1, 2, 3))) == (x,) - assert run(0, x, (membero, 1, (2, 3))) == () - + assert run(0, x, membero(1, (1, 2, 3))) == (x,) + assert run(0, x, membero(1, (2, 3))) == () -def test_membero_can_be_reused(): g = membero(x, (0, 1, 2)) - assert list(goaleval(g)({})) == [{x: 0}, {x: 1}, {x: 2}] - assert list(goaleval(g)({})) == [{x: 0}, {x: 1}, {x: 2}] + assert tuple(r[x] for r in goaleval(g)({})) == (0, 1, 2) + + def in_cons(x, y): + if issubclass(type(y), ConsPair): + return x == y.car or in_cons(x, y.cdr) + else: + return False + + res = run(4, x, membero(1, x)) + assert all(in_cons(1, r) for r in res) + + res = run(4, (x, y), membero(x, y)) + assert all(in_cons(i, r) for i, r in res) def test_uneval_membero(): + x, y = var(), var() assert set(run(100, x, (membero, y, ((1, 2, 3), (4, 5, 6))), (membero, x, y))) == { 1, 2, @@ -124,6 +140,8 @@ def test_uneval_membero(): def test_seteq(): + + x, y = var(), var() abc = tuple("abc") bca = tuple("bca") assert results(seteq(abc, bca)) @@ -149,6 +167,7 @@ def test_permuteq(): assert not results(permuteq((1, 2, 1), (2, 1, 2))) assert not results(permuteq([1, 2, 1], (2, 1, 2))) + x = var() assert set(run(0, x, permuteq(x, (1, 2, 2)))) == set( ((1, 2, 2), (2, 1, 2), (2, 2, 1)) ) @@ -195,7 +214,7 @@ def test_appendo(): @pytest.mark.skip("Misspecified test") -def test_appendo2(): +def test_appendo_reorder(): # XXX: This test generates goal conjunctions that are non-terminating given # the specified goal ordering. More specifically, it generates # `lall(appendo(x, y, w), appendo(w, z, ()))`, for which the first From 18c4debaa38460fc7ef129f1ef1e884784075145 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 2 Jan 2020 23:11:45 -0600 Subject: [PATCH 04/15] Add mapo and permuteo XXX: `eq_assoccomm` is broken and needs to be updated to handle the now relational `permuteq`/`permuteo`, which it relies on via `eq_comm`. --- kanren/__init__.py | 2 +- kanren/assoccomm.py | 34 ++++-- kanren/goals.py | 224 +++++++++++++++++++++++----------------- kanren/graph.py | 24 +++++ tests/test_assoccomm.py | 17 ++- tests/test_goals.py | 116 ++++++++++++++------- tests/test_graph.py | 25 ++++- 7 files changed, 300 insertions(+), 142 deletions(-) diff --git a/kanren/__init__.py b/kanren/__init__.py index cf4e782..372e7a9 100644 --- a/kanren/__init__.py +++ b/kanren/__init__.py @@ -5,7 +5,7 @@ from unification import unify, reify, unifiable, var, isvar, vars, variables, Var from .core import run, eq, conde, lall, lany -from .goals import seteq, permuteq, goalify, membero +from .goals import seteq, permuteo, permuteq, goalify, membero from .facts import Relation, fact, facts from .term import arguments, operator, term, unifiable_with_term diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index e89676d..89b1e86 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -49,6 +49,7 @@ from .facts import Relation from .util import groupsizes, index from .term import term, arguments, operator +from .graph import mapo associative = Relation("associative") commutative = Relation("commutative") @@ -173,7 +174,7 @@ def eq_assoc(u, v, eq=core.eq, n=None): return (core.eq, u, v) -def eq_comm(u, v, eq=None): +def eq_comm(u, v, inner_eq=None): """Create a goal for commutative equality. >>> from kanren import run, var, fact @@ -187,21 +188,40 @@ def eq_comm(u, v, eq=None): >>> run(0, x, eq(('add', 1, 2, 3), ('add', 2, x, 1))) (3,) """ - eq = eq or eq_comm - vtail = var() + inner_eq = inner_eq or eq_comm + vtail, vhead = var(), var() + if isvar(u) and isvar(v): - return (core.eq, u, v) + return eq(u, v) + uop, uargs = op_args(u) vop, vargs = op_args(v) + if not uop and not vop: - return (core.eq, u, v) + return eq(u, v) + if vop and not uop: uop, uargs = vop, vargs v, u = u, v + return ( conde, - ((core.eq, u, v),), - ((commutative, uop), (buildo, uop, vtail, v), (permuteq, uargs, vtail, eq)), + [eq(u, v)], + [ + (buildo, vhead, vtail, v), + ( + conde, + [ + (inner_eq, uop, vhead), + (mapo, lambda a, b: (eq_comm, a, b, inner_eq), uargs, vtail), + ], + [ + eq(uop, vhead), + (commutative, uop), + (permuteq, uargs, vtail, inner_eq), + ], + ), + ], ) diff --git a/kanren/goals.py b/kanren/goals.py index 8b10cc1..d2ff470 100644 --- a/kanren/goals.py +++ b/kanren/goals.py @@ -1,24 +1,22 @@ -import collections - +from operator import length_hint from functools import partial from itertools import permutations +from collections import Counter from collections.abc import Sequence from cons import cons from cons.core import ConsNull, ConsPair from unification import isvar, reify, var +from unification.core import isground from .core import ( eq, EarlyGoalError, conde, - condeseq, lall, - fail, - success, + lanyseq, goaleval, ) -from .util import unique def heado(head, coll): @@ -109,7 +107,7 @@ def nullo_goal(s): return nullo_goal -def itero(l, default_ConsNull=list): +def itero(l, nullo_refs=None, default_ConsNull=list): """Construct a goal asserting that a term is an iterable type. This is a generic version of the standard `listo` that accounts for @@ -119,11 +117,11 @@ def itero(l, default_ConsNull=list): """ def itero_goal(S): - nonlocal l + nonlocal l, nullo_refs, default_ConsNull l_rf = reify(l, S) c, d = var(), var() g = conde( - [nullo(l_rf, default_ConsNull=default_ConsNull)], + [nullo(l_rf, refs=nullo_refs, default_ConsNull=default_ConsNull)], [conso(c, d, l_rf), itero(d, default_ConsNull=default_ConsNull)], ) yield from goaleval(g)(S) @@ -131,58 +129,6 @@ def itero_goal(S): return itero_goal -def permuteq(a, b, eq2=eq): - """Construct a goal asserting equality under permutation. - - For example, (1, 2, 2) equates to (2, 1, 2) under permutation - >>> from kanren import var, run, permuteq - >>> x = var() - >>> run(0, x, permuteq(x, (1, 2))) - ((1, 2), (2, 1)) - - >>> run(0, x, permuteq((2, 1, x), (2, 1, 2))) - (2,) - """ - if isinstance(a, Sequence) and isinstance(b, Sequence): - if len(a) != len(b): - return fail - elif collections.Counter(a) == collections.Counter(b): - return success - else: - c, d = list(a), list(b) - for x in list(c): - # TODO: This is quadratic in the number items in the sequence. - # Need something like a multiset. Maybe use - # collections.Counter? - try: - d.remove(x) - c.remove(x) - except ValueError: - pass - - if len(c) == 1: - return (eq2, c[0], d[0]) - return condeseq( - ((eq2, x, d[0]), (permuteq, c[0:i] + c[i + 1 :], d[1:], eq2)) - for i, x in enumerate(c) - ) - elif not (isinstance(a, Sequence) or isinstance(b, Sequence)): - raise ValueError( - "Neither a nor b is a Sequence: {}, {}".format(type(a), type(b)) - ) - - if isvar(a) and isvar(b): - raise EarlyGoalError() - - if isvar(a) or isvar(b): - if isinstance(b, Sequence): - c, d = a, b - elif isinstance(a, Sequence): - c, d = b, a - - return (condeseq, ([eq(c, perm)] for perm in unique(permutations(d, len(d))))) - - def seteq(a, b, eq2=eq): """Construct a goal asserting set equality. @@ -205,40 +151,6 @@ def seteq(a, b, eq2=eq): return permuteq(a, ts(b), eq2) -def goalify(func, name=None): # pragma: noqa - """Convert Python function into kanren goal. - - >>> from kanren import run, goalify, var, membero - >>> typo = goalify(type) - >>> x = var('x') - >>> run(0, x, membero(x, (1, 'cat', 'hat', 2)), (typo, x, str)) - ('cat', 'hat') - - Goals go both ways. Here are all of the types in the collection - - >>> typ = var('typ') - >>> results = run(0, typ, membero(x, (1, 'cat', 'hat', 2)), (typo, x, typ)) - >>> print([result.__name__ for result in results]) - ['int', 'str'] - """ - - def funco(inputs, out): # pragma: noqa - if isvar(inputs): - raise EarlyGoalError() - else: - if isinstance(inputs, (tuple, list)): - if any(map(isvar, inputs)): - raise EarlyGoalError() - return (eq, func(*inputs), out) - else: - return (eq, func(inputs), out) - - name = name or (func.__name__ + "o") - funco.__name__ = name - - return funco - - def membero(x, ls): """Construct a goal stating that x is an item of coll.""" @@ -320,3 +232,125 @@ def rembero_goal(s): yield from goaleval(g)(s) return rembero_goal + + +def permuteo(a, b, inner_eq=eq, default_ConsNull=list): + """Construct a goal asserting equality or sequences under permutation. + + For example, (1, 2, 2) equates to (2, 1, 2) under permutation + >>> from kanren import var, run, permuteo + >>> x = var() + >>> run(0, x, permuteo(x, (1, 2))) + ((1, 2), (2, 1)) + + >>> run(0, x, permuteo((2, 1, x), (2, 1, 2))) + (2,) + """ + + def permuteo_goal(S): + nonlocal a, b, default_ConsNull, inner_eq + + a_rf, b_rf = reify((a, b), S) + + # If the lengths differ, then fail + a_len, b_len = length_hint(a_rf, -1), length_hint(b_rf, -1) + if a_len > 0 and b_len > 0 and a_len != b_len: + return + + if isinstance(a_rf, Sequence): + + a_type = type(a_rf) + + if isinstance(b_rf, Sequence): + + # `a` and `b` are sequences, so let's see if we can + # pull out all the equal elements using their hashes. + + b_type = type(b_rf) + + if a_type != b_type: + return + + try: + cntr_a, cntr_b = Counter(a_rf), Counter(b_rf) + rdcd_a, rdcd_b = cntr_a - cntr_b, cntr_b - cntr_a + a_rf, b_rf = tuple(rdcd_a.elements()), b_type(rdcd_b.elements()) + except TypeError: + # TODO: We could probably get more coverage for this case + # by using `HashableForm`. + pass + + # If they're both ground, then simply check that one is a + # permutation of the other and be done + if isground(a_rf, S) and isground(b_rf, S): + if a_rf in permutations(b_rf): + yield S + return + else: + return + + # Unify all permutations of the sequence `a` with `b` + yield from lanyseq(inner_eq(b_rf, a_type(i)) for i in permutations(a_rf))(S) + + elif isinstance(b_rf, Sequence): + + b_type = type(b_rf) + + # Unify all permutations of the sequence `b` with `a` + yield from lanyseq(inner_eq(a_rf, b_type(i)) for i in permutations(b_rf))(S) + + else: + + # None of the arguments are proper sequences, so state that one + # should be and apply `permuteo` to that. + + a_itero_g = itero( + a_rf, nullo_refs=(b_rf,), default_ConsNull=default_ConsNull + ) + + for S_new in a_itero_g(S): + a_new = reify(a_rf, S_new) + a_type = type(a_new) + yield from lanyseq( + inner_eq(b_rf, a_type(i)) for i in permutations(a_new) + )(S_new) + + return permuteo_goal + + +# For backward compatibility +permuteq = permuteo + + +def goalify(func, name=None): # pragma: noqa + """Convert a Python function into kanren goal. + + >>> from kanren import run, goalify, var, membero + >>> typo = goalify(type) + >>> x = var('x') + >>> run(0, x, membero(x, (1, 'cat', 'hat', 2)), (typo, x, str)) + ('cat', 'hat') + + Goals go both ways. Here are all of the types in the collection + + >>> typ = var('typ') + >>> results = run(0, typ, membero(x, (1, 'cat', 'hat', 2)), (typo, x, typ)) + >>> print([result.__name__ for result in results]) + ['int', 'str'] + """ + + def funco(inputs, out): # pragma: noqa + if isvar(inputs): + raise EarlyGoalError() + else: + if isinstance(inputs, (tuple, list)): + if any(map(isvar, inputs)): + raise EarlyGoalError() + return (eq, func(*inputs), out) + else: + return (eq, func(inputs), out) + + name = name or (func.__name__ + "o") + funco.__name__ = name + + return funco diff --git a/kanren/graph.py b/kanren/graph.py index 8e04c42..4799398 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -57,6 +57,30 @@ def applyo_goal(S): return applyo_goal +def mapo(relation, a, b, null_type=list): + """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all pairs.""" + + def mapo_goal(S): + + a_rf, b_rf = reify((a, b), S) + b_car, b_cdr = var(), var() + a_car, a_cdr = var(), var() + + g = conde( + [nullo(a_rf, b_rf, default_ConsNull=null_type)], + [ + conso(a_car, a_cdr, a_rf), + conso(b_car, b_cdr, b_rf), + relation(a_car, b_car), + mapo(relation, a_cdr, b_cdr, null_type=null_type), + ], + ) + + yield from goaleval(g)(S) + + return mapo_goal + + def map_anyo(relation, l_in, l_out, null_type=list): """Apply a relation to corresponding elements in two sequences and succeed if at least one pair succeeds. diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index a9c206e..0851a4d 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -20,9 +20,9 @@ ) from kanren.term import operator, arguments, term -a = "assoc_op" -c = "comm_op" x, y = var("x"), var("y") + +a, c = "assoc_op", "comm_op" fact(associative, a) fact(commutative, c) @@ -128,9 +128,9 @@ def test_eq_assoccomm(): assert results(eqac((1,), (1,))) assert results(eqac(x, (1,))) assert results(eqac((1,), x)) - assert results(eqac((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) assert results((eqac, 1, 1)) assert results(eqac((a, (a, 1, 2), 3), (a, 1, 2, 3))) + assert results(eqac((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) assert results(eqac((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) assert results(eqac((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) assert not results(eqac((ac, 1, 1), ("other_op", 1, 1))) @@ -155,8 +155,19 @@ def test_expr(): def test_deep_commutativity(): x, y = var("x"), var("y") + e1 = ((c, 3, 1),) + e2 = ((c, 1, x),) + + assert run(0, x, eq_comm(e1, e2)) == (3,) + + e1 = (2, (c, 3, 1)) + e2 = (y, (c, 1, x)) + + assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) + e1 = (c, (c, 1, x), y) e2 = (c, 2, (c, 3, 1)) + assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) diff --git a/tests/test_goals.py b/tests/test_goals.py index 66672ee..a02507a 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -13,9 +13,9 @@ conso, nullo, itero, - permuteq, membero, rembero, + permuteo, ) from kanren.core import eq, goaleval, run @@ -139,40 +139,6 @@ def test_uneval_membero(): } -def test_seteq(): - - x, y = var(), var() - abc = tuple("abc") - bca = tuple("bca") - assert results(seteq(abc, bca)) - assert len(results(seteq(abc, x))) == 6 - assert len(results(seteq(x, abc))) == 6 - assert bca in run(0, x, seteq(abc, x)) - assert results(seteq((1, 2, 3), (3, x, 1))) == ({x: 2},) - - assert run(0, (x, y), seteq((1, 2, x), (2, 3, y)))[0] == (3, 1) - assert not run(0, (x, y), seteq((4, 5, x), (2, 3, y))) - - -def test_permuteq(): - assert results(permuteq((1, 2), (2, 1))) - assert results(permuteq([1, 2], [2, 1])) - assert results(permuteq((1, 2, 2), (2, 1, 2))) - - with pytest.raises(ValueError): - permuteq(set((1, 2, 2)), set((2, 1, 2))) - - assert not results(permuteq((1, 2), (2, 1, 2))) - assert not results(permuteq((1, 2, 3), (2, 1, 2))) - assert not results(permuteq((1, 2, 1), (2, 1, 2))) - assert not results(permuteq([1, 2, 1], (2, 1, 2))) - - x = var() - assert set(run(0, x, permuteq(x, (1, 2, 2)))) == set( - ((1, 2, 2), (2, 1, 2), (2, 2, 1)) - ) - - def test_appendo(): q_lv = var() assert run(0, q_lv, appendo((), (1, 2), (1, 2))) == (q_lv,) @@ -259,3 +225,83 @@ def test_rembero(): [1, 2, 3, 4, 5], ) assert expected_res == run(0, q_lv, rembero(5, q_lv, [1, 2, 3, 4])) + + +def test_permuteo(): + + from itertools import permutations + + a_lv = var() + q_lv = var() + + class Blah: + def __hash__(self): + raise TypeError() + + # An unhashable sequence with an unhashable object in it + obj_1 = [Blah()] + + assert results(permuteo((1, 2), (2, 1))) == ({},) + assert results(permuteo((1, obj_1), (obj_1, 1))) == ({},) + assert results(permuteo([1, 2], [2, 1])) == ({},) + assert results(permuteo((1, 2, 2), (2, 1, 2))) == ({},) + + # (1, obj_1, a_lv) == (1, obj_1, a_lv) ==> {a_lv: a_lv} + # (1, obj_1, a_lv) == (1, a_lv, obj_1) ==> {a_lv: obj_1} + # (1, obj_1, a_lv) == (a_lv, obj_1, 1) ==> {a_lv: 1} + assert run(0, a_lv, permuteo((1, obj_1, a_lv), (obj_1, a_lv, 1))) == ( + 1, + a_lv, + obj_1, + ) + + assert not results(permuteo((1, 2), (2, 1, 2))) + assert not results(permuteo((1, 2), (2, 1, 2))) + assert not results(permuteo((1, 2, 3), (2, 1, 2))) + assert not results(permuteo((1, 2, 1), (2, 1, 2))) + assert not results(permuteo([1, 2, 1], (2, 1, 2))) + + x = var() + assert set(run(0, x, permuteo(x, (1, 2, 2)))) == set( + ((1, 2, 2), (2, 1, 2), (2, 2, 1)) + ) + q_lv = var() + + assert run(0, q_lv, permuteo((1, 2, 3), (q_lv, 2, 1))) == (3,) + + assert run(0, q_lv, permuteo([1, 2, 3], [3, 2, 1])) + assert run(0, q_lv, permuteo((1, 2, 3), (3, 2, 1))) + assert run(0, q_lv, permuteo([1, 2, 3], [2, 1])) == () + assert run(0, q_lv, permuteo([1, 2, 3], (3, 2, 1))) == () + + col = [1, 2, 3] + exp_res = set(tuple(i) for i in permutations(col)) + + # The first term is ground + res = run(0, q_lv, permuteo(col, q_lv)) + assert all(type(r) == type(col) for r in res) + + res = set(tuple(r) for r in res) + assert res == exp_res + + # The second term is ground + res = run(0, q_lv, permuteo(q_lv, col)) + assert all(type(r) == type(col) for r in res) + + res = set(tuple(r) for r in res) + assert res == exp_res + + a_lv = var() + # Neither terms are ground + bi_res = run(5, [q_lv, a_lv], permuteo(q_lv, a_lv)) + + assert bi_res[0] == [[], []] + bi_var_1 = bi_res[1][0][0] + assert isvar(bi_var_1) + assert bi_res[1][0] == bi_res[1][1] == [bi_var_1] + bi_var_2 = bi_res[2][0][1] + assert isvar(bi_var_2) and bi_var_1 is not bi_var_2 + assert bi_res[2][0] == bi_res[2][1] == [bi_var_1, bi_var_2] + assert bi_res[3][0] != bi_res[3][1] == [bi_var_2, bi_var_1] + bi_var_3 = bi_res[4][0][2] + assert bi_res[4][0] == bi_res[4][1] == [bi_var_1, bi_var_2, bi_var_3] diff --git a/tests/test_graph.py b/tests/test_graph.py index 74a44c8..0dd8463 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -13,7 +13,7 @@ from kanren import run, eq, conde, lall from kanren.constraints import isinstanceo -from kanren.graph import applyo, reduceo, map_anyo, walko +from kanren.graph import applyo, reduceo, map_anyo, walko, mapo class OrderedFunction(object): @@ -147,6 +147,29 @@ def test_reduceo(): assert res[1] == etuple(log, etuple(exp, etuple(log, etuple(exp, 1)))) +def test_mapo(): + q_lv = var() + + def blah(x, y): + return conde([eq(x, 1), eq(y, "a")], [eq(x, 3), eq(y, "b")]) + + assert run(0, q_lv, mapo(blah, [], q_lv)) == ([],) + assert run(0, q_lv, mapo(blah, [1, 2, 3], q_lv)) == () + assert run(0, q_lv, mapo(blah, [1, 1, 3], q_lv)) == (["a", "a", "b"],) + assert run(0, q_lv, mapo(blah, q_lv, ["a", "a", "b"])) == ([1, 1, 3],) + + exp_res = ( + [[], []], + [[1], ["a"]], + [[3], ["b"]], + [[1, 1], ["a", "a"]], + [[3, 1], ["b", "a"]], + ) + + a_lv = var() + assert run(5, [q_lv, a_lv], mapo(blah, q_lv, a_lv)) == exp_res + + def test_map_anyo_types(): """Make sure that `applyo` preserves the types between its arguments.""" q_lv = var() From 9089aac068701ead17c49601e5572e00adc775f6 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Jan 2020 21:36:02 -0600 Subject: [PATCH 05/15] Remove seteq and useless future imports --- kanren/__init__.py | 16 +++++++++++++--- kanren/goals.py | 22 ---------------------- tests/test_facts.py | 2 -- tests/test_goals.py | 1 - tests/test_sudoku.py | 2 -- 5 files changed, 13 insertions(+), 30 deletions(-) diff --git a/kanren/__init__.py b/kanren/__init__.py index 372e7a9..37a0504 100644 --- a/kanren/__init__.py +++ b/kanren/__init__.py @@ -1,11 +1,21 @@ # flake8: noqa """kanren is a Python library for logic and relational programming.""" -from __future__ import absolute_import - from unification import unify, reify, unifiable, var, isvar, vars, variables, Var from .core import run, eq, conde, lall, lany -from .goals import seteq, permuteo, permuteq, goalify, membero +from .goals import ( + heado, + tailo, + conso, + nullo, + itero, + appendo, + rembero, + permuteo, + permuteq, + membero, + goalify, +) from .facts import Relation, fact, facts from .term import arguments, operator, term, unifiable_with_term diff --git a/kanren/goals.py b/kanren/goals.py index d2ff470..3dba1af 100644 --- a/kanren/goals.py +++ b/kanren/goals.py @@ -129,28 +129,6 @@ def itero_goal(S): return itero_goal -def seteq(a, b, eq2=eq): - """Construct a goal asserting set equality. - - For example (1, 2, 3) set equates to (2, 1, 3) - - >>> from kanren import var, run, seteq - >>> x = var() - >>> run(0, x, seteq(x, (1, 2))) - ((1, 2), (2, 1)) - - >>> run(0, x, seteq((2, 1, x), (3, 1, 2))) - (3,) - """ - ts = lambda x: tuple(set(x)) - if not isvar(a) and not isvar(b): - return permuteq(ts(a), ts(b), eq2) - elif not isvar(a): - return permuteq(ts(a), b, eq2) - else: # not isvar(b) - return permuteq(a, ts(b), eq2) - - def membero(x, ls): """Construct a goal stating that x is an item of coll.""" diff --git a/tests/test_facts.py b/tests/test_facts.py index 9de2e9c..955b2fa 100644 --- a/tests/test_facts.py +++ b/tests/test_facts.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from unification import var from kanren.core import run, conde diff --git a/tests/test_goals.py b/tests/test_goals.py index a02507a..ec64532 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -9,7 +9,6 @@ tailo, heado, appendo, - seteq, conso, nullo, itero, diff --git a/tests/test_sudoku.py b/tests/test_sudoku.py index 3dc2d0b..714c0a9 100644 --- a/tests/test_sudoku.py +++ b/tests/test_sudoku.py @@ -2,8 +2,6 @@ Based off https://github.com/holtchesley/embedded-logic/blob/master/kanren/sudoku.ipynb """ -from __future__ import absolute_import - from unification import var from kanren import run From 7da4e73dc7b64029e680e1aa1ae14ac7c6f9fb65 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Jan 2020 21:37:09 -0600 Subject: [PATCH 06/15] Separate unique filter from run function --- kanren/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index 04fdced..b024d1d 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -186,7 +186,7 @@ def everyg(predicate, coll): return (lall,) + tuple((predicate, x) for x in coll) -def run(n, x, *goals): +def run_all(n, x, *goals, results_filter=None): """Run a logic program and obtain n solutions that satisfy the given goals. >>> from kanren import run, var, eq @@ -207,7 +207,12 @@ def run(n, x, *goals): (i.e. `lall`). """ results = map(partial(reify, x), goaleval(lall(*goals))({})) - return take(n, unique(results, key=multihash)) + if results_filter is not None: + results = results_filter(results) + return take(n, results) + + +run = partial(run_all, results_filter=partial(unique, key=multihash)) class EarlyGoalError(Exception): From 6b6f5f79dd191d14a39c83289578b0eebcfed292 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Jan 2020 21:44:16 -0600 Subject: [PATCH 07/15] Make term preserve sequence types --- kanren/assoccomm.py | 6 +++--- kanren/term.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index 89b1e86..48b56bd 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -107,7 +107,7 @@ def makeops(op, lists): >>> from kanren.assoccomm import makeops >>> makeops('add', [(1, 2), (3, 4, 5)]) - (ExpressionTuple(('add', 1, 2)), ExpressionTuple(('add', 3, 4, 5))) + (('add', 1, 2), ('add', 3, 4, 5)) """ return tuple(l[0] if len(l) == 1 else build(op, l) for l in lists) @@ -151,7 +151,7 @@ def eq_assoc(u, v, eq=core.eq, n=None): >>> x = var() >>> run(0, x, eq(('add', 1, 2, 3), ('add', 1, x))) - (ExpressionTuple(('add', 2, 3)),) + (('add', 2, 3),) """ uop, _ = op_args(u) vop, _ = op_args(v) @@ -282,7 +282,7 @@ def eq_assoccomm(u, v): >>> e1 = ('add', 1, 2, 3) >>> e2 = ('add', 1, x) >>> run(0, x, eq(e1, e2)) - (ExpressionTuple(('add', 2, 3)), ExpressionTuple(('add', 3, 2))) + (('add', 2, 3), ('add', 3, 2)) """ uop, uargs = op_args(u) vop, vargs = op_args(v) diff --git a/kanren/term.py b/kanren/term.py index 6d95d32..a2dd814 100644 --- a/kanren/term.py +++ b/kanren/term.py @@ -1,9 +1,21 @@ +from collections.abc import Sequence + from unification import unify, reify from unification.core import _unify, _reify +from cons.core import cons + from etuples import rator as operator, rands as arguments, apply as term +@term.register(object, Sequence) +def term_Sequence(rator, rands): + # Overwrite the default `apply` dispatch function and make it preserve + # types + res = cons(rator, rands) + return res + + def unifiable_with_term(cls): _reify.add((cls, dict), reify_term) _unify.add((cls, cls, dict), unify_term) From 2439f93b600458d93ec11b956d6b3059059fc850 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Jan 2020 21:46:37 -0600 Subject: [PATCH 08/15] Factor-out logic that converts objects to hashable forms --- kanren/util.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/kanren/util.py b/kanren/util.py index 4269d2b..dd24468 100644 --- a/kanren/util.py +++ b/kanren/util.py @@ -1,5 +1,9 @@ from itertools import chain, islice -from collections.abc import Hashable, MutableSet, Set +from collections import namedtuple +from collections.abc import Hashable, MutableSet, Set, Mapping, Iterable + + +HashableForm = namedtuple("HashableForm", ["type", "data"]) class FlexibleSet(MutableSet): @@ -20,7 +24,7 @@ def add(self, item): try: self.set.add(item) except TypeError: - # TODO: Could try `multihash`. + # TODO: Could try `make_hashable`. # TODO: Use `bisect` for unhashable but orderable elements if item not in self.list: self.list.append(item) @@ -113,17 +117,21 @@ def dicthash(d): return hash(frozenset(d.items())) +def make_hashable(x): + # TODO: Better as a dispatch function? + if hashable(x): + return x + if isinstance(x, slice): + return HashableForm(type(x), (x.start, x.stop, x.step)) + if isinstance(x, Mapping): + return HashableForm(type(x), frozenset(tuple(multihash(i) for i in x.items()))) + if isinstance(x, Iterable): + return HashableForm(type(x), tuple(multihash(i) for i in x)) + raise TypeError(f"Hashing not covered for {x}") + + def multihash(x): - try: - return hash(x) - except TypeError: - if isinstance(x, (list, tuple, set, frozenset)): - return hash(tuple(multihash(i) for i in x)) - if type(x) is dict: - return hash(frozenset(tuple(multihash(i) for i in x.items()))) - if type(x) is slice: - return hash((x.start, x.stop, x.step)) - raise TypeError("Hashing not covered for " + str(x)) + return hash(make_hashable(x)) def unique(seq, key=lambda x: x): From bbdf3b43d2691d328241be45aa9a8b8061ff8b52 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 26 Jan 2020 19:25:32 -0600 Subject: [PATCH 09/15] Introduce Zzz, eq_length and refactor mapo goals --- kanren/core.py | 9 +++ kanren/graph.py | 185 +++++++++++++++++++++----------------------- tests/test_graph.py | 51 +++++++++++- 3 files changed, 148 insertions(+), 97 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index b024d1d..ff21a86 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -186,6 +186,15 @@ def everyg(predicate, coll): return (lall,) + tuple((predicate, x) for x in coll) +def Zzz(gctor, *args, **kwargs): + """Create an inverse-η-delay for a goal.""" + + def Zzz_goal(S): + yield from goaleval(gctor(*args, **kwargs))(S) + + return Zzz_goal + + def run_all(n, x, *goals, results_filter=None): """Run a logic program and obtain n solutions that satisfy the given goals. diff --git a/kanren/graph.py b/kanren/graph.py index 4799398..35652c1 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -5,9 +5,9 @@ from cons.core import ConsError -from etuples import apply, rands, rator +from etuples import etuple, apply, rands, rator -from .core import eq, conde, lall, goaleval +from .core import eq, conde, lall, goaleval, success, Zzz, fail from .goals import conso, nullo @@ -57,35 +57,28 @@ def applyo_goal(S): return applyo_goal -def mapo(relation, a, b, null_type=list): +def mapo(relation, a, b, null_type=list, null_res=True, first=True): """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all pairs.""" - def mapo_goal(S): + b_car, b_cdr = var(), var() + a_car, a_cdr = var(), var() - a_rf, b_rf = reify((a, b), S) - b_car, b_cdr = var(), var() - a_car, a_cdr = var(), var() + return conde( + [nullo(a, b, default_ConsNull=null_type) if (not first or null_res) else fail], + [ + conso(a_car, a_cdr, a), + conso(b_car, b_cdr, b), + Zzz(relation, a_car, b_car), + Zzz(mapo, relation, a_cdr, b_cdr, null_type=null_type, first=False), + ], + ) - g = conde( - [nullo(a_rf, b_rf, default_ConsNull=null_type)], - [ - conso(a_car, a_cdr, a_rf), - conso(b_car, b_cdr, b_rf), - relation(a_car, b_car), - mapo(relation, a_cdr, b_cdr, null_type=null_type), - ], - ) - yield from goaleval(g)(S) - - return mapo_goal - - -def map_anyo(relation, l_in, l_out, null_type=list): +def map_anyo( + relation, a, b, null_type=list, null_res=False, first=True, any_succeed=False +): """Apply a relation to corresponding elements in two sequences and succeed if at least one pair succeeds. - Empty `l_in` and/or `l_out` will fail--i.e. `relation` must succeed *at least once*. - Parameters ---------- null_type: optional @@ -94,69 +87,59 @@ def map_anyo(relation, l_in, l_out, null_type=list): inputs, or defaults to an empty list. """ - def _map_anyo(relation, l_in, l_out, i_any): - def map_anyo_goal(s): - - nonlocal relation, l_in, l_out, i_any, null_type - - l_in_rf, l_out_rf = reify((l_in, l_out), s) - - i_car, i_cdr = var(), var() - o_car, o_cdr = var(), var() - - conde_branches = [] - - if i_any or (isvar(l_in_rf) and isvar(l_out_rf)): - # Consider terminating the sequences when we've had at least - # one successful goal or when both sequences are logic variables. - conde_branches.append( - [nullo(l_in_rf, l_out_rf, default_ConsNull=null_type)] - ) - - # Extract the CAR and CDR of each argument sequence; this is how we - # iterate through elements of the two sequences. - cons_parts_branch = [ - goaleval(conso(i_car, i_cdr, l_in_rf)), - goaleval(conso(o_car, o_cdr, l_out_rf)), - ] - - conde_branches.append(cons_parts_branch) - - conde_relation_branches = [] - - relation_branch = [ - # This case tries the relation and continues on. - relation(i_car, o_car), - # In this conde clause, we can tell future calls to - # `map_anyo` that we've had at least one successful - # application of the relation (otherwise, this clause - # would fail due to the above goal). - _map_anyo(relation, i_cdr, o_cdr, True), - ] - - conde_relation_branches.append(relation_branch) - - base_branch = [ - # This is the "base" case; it is used when, for example, - # the given relation isn't satisfied. - eq(i_car, o_car), - _map_anyo(relation, i_cdr, o_cdr, i_any), - ] - - conde_relation_branches.append(base_branch) - - cons_parts_branch.append(conde(*conde_relation_branches)) - - g = conde(*conde_branches) - - yield from goaleval(g)(s) - - return map_anyo_goal - - return _map_anyo(relation, l_in, l_out, False) - - -def reduceo(relation, in_term, out_term): + b_car, b_cdr = var(), var() + a_car, a_cdr = var(), var() + + return conde( + [ + nullo(a, b, default_ConsNull=null_type) + if (any_succeed or (first and null_res)) + else fail + ], + [ + conso(a_car, a_cdr, a), + conso(b_car, b_cdr, b), + conde( + [ + Zzz(relation, a_car, b_car), + Zzz( + map_anyo, + relation, + a_cdr, + b_cdr, + null_type=null_type, + any_succeed=True, + first=False, + ), + ], + [ + eq(a_car, b_car), + Zzz( + map_anyo, + relation, + a_cdr, + b_cdr, + null_type=null_type, + any_succeed=any_succeed, + first=False, + ), + ], + ), + ], + ) + + +def vararg_success(*args): + return success + + +def eq_length(u, v, default_ConsNull=list): + """Construct a goal stating that two sequences are the same length and type.""" + + return mapo(vararg_success, u, v, null_type=default_ConsNull) + + +def reduceo(relation, in_term, out_term, *args, **kwargs): """Relate a term and the fixed-point of that term under a given relation. This includes the "identity" relation. @@ -164,7 +147,7 @@ def reduceo(relation, in_term, out_term): def reduceo_goal(s): - nonlocal in_term, out_term, relation + nonlocal in_term, out_term, relation, args, kwargs in_term_rf, out_term_rf = reify((in_term, out_term), s) @@ -176,7 +159,7 @@ def reduceo_goal(s): is_expanding = isvar(in_term_rf) # One application of the relation assigned to `term_rdcd` - single_apply_g = relation(in_term_rf, term_rdcd) + single_apply_g = relation(in_term_rf, term_rdcd, *args, **kwargs) # Assign/equate (unify, really) the result of a single application to # the "output" term. @@ -184,7 +167,7 @@ def reduceo_goal(s): # Recurse into applications of the relation (well, produce a goal that # will do that) - another_apply_g = reduceo(relation, term_rdcd, out_term_rf) + another_apply_g = reduceo(relation, term_rdcd, out_term_rf, *args, **kwargs) # We want the fixed-point value to show up in the stream output # *first*, but that requires some checks. @@ -212,7 +195,14 @@ def reduceo_goal(s): return reduceo_goal -def walko(goal, graph_in, graph_out, rator_goal=None, null_type=False): +def walko( + goal, + graph_in, + graph_out, + rator_goal=None, + null_type=etuple, + map_rel=partial(map_anyo, null_res=True), +): """Apply a binary relation between all nodes in two graphs. When `rator_goal` is used, the graphs are treated as term graphs, and the @@ -234,30 +224,35 @@ def walko(goal, graph_in, graph_out, rator_goal=None, null_type=False): null_type: type The collection type used when it is not fully determined by the graph arguments. + map_rel: callable + The map relation used to apply `goal` to a sub-graph. """ def walko_goal(s): - nonlocal goal, rator_goal, graph_in, graph_out, null_type + nonlocal goal, rator_goal, graph_in, graph_out, null_type, map_rel graph_in_rf, graph_out_rf = reify((graph_in, graph_out), s) rator_in, rands_in, rator_out, rands_out = var(), var(), var(), var() - _walko = partial(walko, goal, rator_goal=rator_goal, null_type=null_type) + _walko = partial( + walko, goal, rator_goal=rator_goal, null_type=null_type, map_rel=map_rel + ) g = conde( + # TODO: Use `Zzz`, if needed. [goal(graph_in_rf, graph_out_rf),], [ lall( applyo(rator_in, rands_in, graph_in_rf), applyo(rator_out, rands_out, graph_out_rf), rator_goal(rator_in, rator_out), - map_anyo(_walko, rands_in, rands_out, null_type=null_type), + map_rel(_walko, rands_in, rands_out, null_type=null_type), ) if rator_goal is not None else lall( - map_anyo(_walko, graph_in_rf, graph_out_rf, null_type=null_type) + map_rel(_walko, graph_in_rf, graph_out_rf, null_type=null_type) ), ], ) diff --git a/tests/test_graph.py b/tests/test_graph.py index 0dd8463..fc5f154 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -4,7 +4,7 @@ from functools import partial from math import log, exp -from unification import var, unify +from unification import var, unify, isvar, reify from etuples.dispatch import rator, rands, apply from etuples.core import etuple, ExpressionTuple @@ -13,7 +13,7 @@ from kanren import run, eq, conde, lall from kanren.constraints import isinstanceo -from kanren.graph import applyo, reduceo, map_anyo, walko, mapo +from kanren.graph import applyo, reduceo, map_anyo, walko, mapo, eq_length class OrderedFunction(object): @@ -170,6 +170,29 @@ def blah(x, y): assert run(5, [q_lv, a_lv], mapo(blah, q_lv, a_lv)) == exp_res +def test_eq_length(): + q_lv = var() + + res = run(0, q_lv, eq_length([1, 2, 3], q_lv)) + assert len(res) == 1 and len(res[0]) == 3 and all(isvar(q) for q in res[0]) + + res = run(0, q_lv, eq_length(q_lv, [1, 2, 3])) + assert len(res) == 1 and len(res[0]) == 3 and all(isvar(q) for q in res[0]) + + res = run(0, q_lv, eq_length(cons(1, q_lv), [1, 2, 3])) + assert len(res) == 1 and len(res[0]) == 2 and all(isvar(q) for q in res[0]) + + v_lv = var() + res = run(3, (q_lv, v_lv), eq_length(q_lv, v_lv, default_ConsNull=tuple)) + assert len(res) == 3 and all( + isinstance(a, tuple) + and len(a) == len(b) + and (len(a) == 0 or a != b) + and all(isvar(r) for r in a) + for a, b in res + ) + + def test_map_anyo_types(): """Make sure that `applyo` preserves the types between its arguments.""" q_lv = var() @@ -219,6 +242,30 @@ def one_to_threeo(x, y): test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, var("z"), tuple)) assert all(isinstance(r, tuple) for r in test_res) + x, y, z = var(), var(), var() + + def test_bin(a, b): + return conde([eq(a, 1), eq(b, 2)]) + + res = run(10, (x, y), map_anyo(test_bin, x, y, null_type=tuple)) + exp_res_form = ( + ((1,), (2,)), + ((x, 1), (x, 2)), + ((1, 1), (2, 2)), + ((x, y, 1), (x, y, 2)), + ((1, x), (2, x)), + ((x, 1, 1), (x, 2, 2)), + ((1, 1, 1), (2, 2, 2)), + ((x, y, z, 1), (x, y, z, 2)), + ((1, x, 1), (2, x, 2)), + ((x, 1, y), (x, 2, y)), + ) + + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False + assert all(isvar(i) for i in reify((x, y, z), s)) + @pytest.mark.parametrize( "test_input, test_output", From 18893854112aa8c497b95d8d41353c7e61421857 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 26 Jan 2020 19:32:27 -0600 Subject: [PATCH 10/15] Implement non-relation goal ifa --- kanren/core.py | 16 ++++++++++++++++ tests/test_core.py | 26 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/kanren/core.py b/kanren/core.py index ff21a86..4e41ae8 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -186,6 +186,22 @@ def everyg(predicate, coll): return (lall,) + tuple((predicate, x) for x in coll) +def ifa(g1, g2): + """Create a goal operator that returns the first stream unless it fails.""" + + def ifa_goal(S): + g1_stream = goaleval(g1)(S) + S_new = next(g1_stream, None) + + if S_new is None: + yield from goaleval(g2)(S) + else: + yield S_new + yield from g1_stream + + return ifa_goal + + def Zzz(gctor, *args, **kwargs): """Create an inverse-η-delay for a goal.""" diff --git a/tests/test_core.py b/tests/test_core.py index ad168a3..ce1da7e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -19,6 +19,7 @@ earlysafe, lallfirst, condeseq, + ifa, ) from kanren.util import evalt @@ -224,3 +225,28 @@ def lefto(q, p, lst): (solution,) = run(1, vals, rules_greedy) assert solution == ("green", "white") + + +def test_ifa(): + x, y = var(), var() + + assert run(0, (x, y), ifa(lall(eq(x, True), eq(y, 1)), eq(y, 2))) == ((True, 1),) + assert run( + 0, y, eq(x, False), ifa(lall(eq(x, True), eq(y, 1)), lall(eq(y, 2))) + ) == (2,) + assert ( + run( + 0, + y, + eq(x, False), + ifa(lall(eq(x, True), eq(y, 1)), lall(eq(x, True), eq(y, 2))), + ) + == () + ) + + assert run( + 0, + y, + eq(x, True), + ifa(lall(eq(x, True), eq(y, 1)), lall(eq(x, True), eq(y, 2))), + ) == (1,) From f1534e6c01e76220115682f117025e27a8d6ff18 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 26 Jan 2020 19:35:20 -0600 Subject: [PATCH 11/15] Add option to omit identity results from permuteo --- kanren/goals.py | 64 +++++++++++++++++++++++++++++++++------------ tests/test_goals.py | 20 +++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/kanren/goals.py b/kanren/goals.py index 3dba1af..3cf922c 100644 --- a/kanren/goals.py +++ b/kanren/goals.py @@ -212,7 +212,7 @@ def rembero_goal(s): return rembero_goal -def permuteo(a, b, inner_eq=eq, default_ConsNull=list): +def permuteo(a, b, inner_eq=eq, default_ConsNull=list, no_ident=False): """Construct a goal asserting equality or sequences under permutation. For example, (1, 2, 2) equates to (2, 1, 2) under permutation @@ -239,43 +239,70 @@ def permuteo_goal(S): a_type = type(a_rf) - if isinstance(b_rf, Sequence): + a_perms = permutations(a_rf) + + if no_ident: + next(a_perms) - # `a` and `b` are sequences, so let's see if we can - # pull out all the equal elements using their hashes. + if isinstance(b_rf, Sequence): b_type = type(b_rf) - if a_type != b_type: + # Fail on mismatched types or straight equality (when + # `no_ident` is enabled) + if a_type != b_type or (no_ident and a_rf == b_rf): return try: + # `a` and `b` are sequences, so let's see if we can pull out + # all the (hash-)equivalent elements. + # XXX: Use of this requires that the equivalence relation + # implied by `inner_eq` be a *superset* of `eq`. + cntr_a, cntr_b = Counter(a_rf), Counter(b_rf) rdcd_a, rdcd_b = cntr_a - cntr_b, cntr_b - cntr_a - a_rf, b_rf = tuple(rdcd_a.elements()), b_type(rdcd_b.elements()) + + if len(rdcd_a) == len(rdcd_b) == 0: + yield S + return + elif len(rdcd_a) < len(cntr_a): + a_rf, b_rf = tuple(rdcd_a.elements()), b_type(rdcd_b.elements()) + a_perms = permutations(a_rf) + except TypeError: # TODO: We could probably get more coverage for this case # by using `HashableForm`. pass - # If they're both ground, then simply check that one is a - # permutation of the other and be done - if isground(a_rf, S) and isground(b_rf, S): - if a_rf in permutations(b_rf): + # If they're both ground and we're using basic unification, + # then simply check that one is a permutation of the other and + # be done. No need to create and evaluate a bunch of goals in + # order to do something that can be done right here. + # Naturally, this assumes that the `isground` checks aren't + # nearly as costly as all that other stuff. If the gains + # depend on the sizes of `a` and `b`, then we could do + # `length_hint` checks first. + if inner_eq == eq and isground(a_rf, S) and isground(b_rf, S): + if tuple(b_rf) in a_perms: yield S return else: + # This has to be a definitive check, since we can only + # use the `a_perms` generator once; plus, we don't want + # to iterate over it more than once! return - # Unify all permutations of the sequence `a` with `b` - yield from lanyseq(inner_eq(b_rf, a_type(i)) for i in permutations(a_rf))(S) + yield from lanyseq(inner_eq(b_rf, a_type(i)) for i in a_perms)(S) elif isinstance(b_rf, Sequence): b_type = type(b_rf) + b_perms = permutations(b_rf) + + if no_ident: + next(b_perms) - # Unify all permutations of the sequence `b` with `a` - yield from lanyseq(inner_eq(a_rf, b_type(i)) for i in permutations(b_rf))(S) + yield from lanyseq(inner_eq(a_rf, b_type(i)) for i in b_perms)(S) else: @@ -289,9 +316,12 @@ def permuteo_goal(S): for S_new in a_itero_g(S): a_new = reify(a_rf, S_new) a_type = type(a_new) - yield from lanyseq( - inner_eq(b_rf, a_type(i)) for i in permutations(a_new) - )(S_new) + a_perms = permutations(a_new) + + if no_ident: + next(a_perms) + + yield from lanyseq(inner_eq(b_rf, a_type(i)) for i in a_perms)(S_new) return permuteo_goal diff --git a/tests/test_goals.py b/tests/test_goals.py index ec64532..75d1ce3 100644 --- a/tests/test_goals.py +++ b/tests/test_goals.py @@ -16,7 +16,7 @@ rembero, permuteo, ) -from kanren.core import eq, goaleval, run +from kanren.core import eq, goaleval, run, conde def results(g, s=None): @@ -304,3 +304,21 @@ def __hash__(self): assert bi_res[3][0] != bi_res[3][1] == [bi_var_2, bi_var_1] bi_var_3 = bi_res[4][0][2] assert bi_res[4][0] == bi_res[4][1] == [bi_var_1, bi_var_2, bi_var_3] + + assert run(0, x, permuteo((1, 2), (1, 2), no_ident=True)) == () + assert run(0, True, permuteo((1, 2), (2, 1), no_ident=True)) == (True,) + assert run(0, x, permuteo((), x, no_ident=True)) == () + assert run(0, x, permuteo(x, (), no_ident=True)) == () + assert run(0, x, permuteo((1,), x, no_ident=True)) == () + assert run(0, x, permuteo(x, (1,), no_ident=True)) == () + assert (1, 2, 3) not in run(0, x, permuteo((1, 2, 3), x, no_ident=True)) + assert (1, 2, 3) not in run(0, x, permuteo(x, (1, 2, 3), no_ident=True)) + y = var() + assert all(a != b for a, b in run(6, [x, y], permuteo(x, y, no_ident=True))) + + def eq_permute(x, y): + return conde([eq(x, y)], [permuteo(a, b) for a, b in zip(x, y)]) + + assert run( + 0, True, permuteo((1, (2, 3)), ((3, 2), 1), inner_eq=eq_permute, no_ident=True) + ) == (True,) From dd40fbee4743e63c12da92d2db76bc60218c90ec Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 26 Jan 2020 19:41:08 -0600 Subject: [PATCH 12/15] Add a debugger goal --- kanren/core.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/kanren/core.py b/kanren/core.py index 4e41ae8..224d32d 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -293,3 +293,24 @@ def goaleval(goal): if isinstance(goal, tuple): # goal is not yet evaluated like (eq, x, 1) return find_fixed_point(evalt, goal) raise TypeError("Expected either function or tuple") + + +def dbgo(*args, msg=None): # pragma: no cover + """Construct a goal that sets a debug trace and prints reified arguments.""" + from pprint import pprint + + def dbgo_goal(S): + nonlocal args + args = reify(args, S) + + if msg is not None: + print(msg) + + pprint(args) + + import pdb + + pdb.set_trace() + yield S + + return dbgo_goal From 248f6947fe8a58c0c5e2e9deb7dd64bc4b89a50a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 29 Jan 2020 14:02:15 -0600 Subject: [PATCH 13/15] Rename success to succeed --- examples/prime.py | 6 +++--- kanren/core.py | 6 +++--- kanren/graph.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/prime.py b/examples/prime.py index b070ba6..76be58d 100644 --- a/examples/prime.py +++ b/examples/prime.py @@ -6,8 +6,8 @@ from unification import isvar from kanren import membero -from kanren.core import (success, fail, var, run, - condeseq, eq) +from kanren.core import succeed, fail, var, run, condeseq, eq + try: import sympy.ntheory.generate as sg except ImportError: @@ -19,7 +19,7 @@ def primo(x): if isvar(x): return condeseq([(eq, x, p)] for p in map(sg.prime, it.count(1))) else: - return success if sg.isprime(x) else fail + return succeed if sg.isprime(x) else fail def test_primo(): diff --git a/kanren/core.py b/kanren/core.py index 224d32d..8439cbf 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -11,7 +11,7 @@ def fail(s): return iter(()) -def success(s): +def succeed(s): return iter((s,)) @@ -55,7 +55,7 @@ def lallgreedy(*goals): """ if not goals: - return success + return succeed if len(goals) == 1: return goals[0] @@ -74,7 +74,7 @@ def allgoal(s): def lallfirst(*goals): """Construct a logical all that runs goals one at a time.""" if not goals: - return success + return succeed if len(goals) == 1: return goals[0] diff --git a/kanren/graph.py b/kanren/graph.py index 35652c1..8a809df 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -7,7 +7,7 @@ from etuples import etuple, apply, rands, rator -from .core import eq, conde, lall, goaleval, success, Zzz, fail +from .core import eq, conde, lall, goaleval, succeed, Zzz, fail from .goals import conso, nullo @@ -130,7 +130,7 @@ def map_anyo( def vararg_success(*args): - return success + return succeed def eq_length(u, v, default_ConsNull=list): From 55f8a405e3c04bd4c3587c3d89ba2e3667dd62f5 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 29 Jan 2020 14:29:13 -0600 Subject: [PATCH 14/15] Add a non-relational groundedness ordering goal --- kanren/core.py | 36 +++++++++++++++++++++++++++++++++++- tests/test_core.py | 13 +++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/kanren/core.py b/kanren/core.py index 8439cbf..d4e5c03 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -1,8 +1,11 @@ from itertools import tee from functools import partial +from collections.abc import Sequence from toolz import groupby, map -from unification import reify, unify +from cons.core import ConsPair +from unification import reify, unify, isvar +from unification.core import isground from .util import dicthash, interleave, take, multihash, unique, evalt @@ -186,6 +189,37 @@ def everyg(predicate, coll): return (lall,) + tuple((predicate, x) for x in coll) +def ground_order_key(S, x): + if isvar(x): + return 2 + elif isground(x, S): + return -1 + elif issubclass(type(x), ConsPair): + return 1 + else: + return 0 + + +def ground_order(in_args, out_args): + """Construct a non-relational goal that orders a list of terms based on groundedness (grounded precede ungrounded).""" + + def ground_order_goal(S): + nonlocal in_args, out_args + + in_args_rf, out_args_rf = reify((in_args, out_args), S) + + S_new = unify( + list(out_args_rf) if isinstance(out_args_rf, Sequence) else out_args_rf, + sorted(in_args_rf, key=partial(ground_order_key, S)), + S, + ) + + if S_new is not False: + yield S_new + + return ground_order_goal + + def ifa(g1, g2): """Create a goal operator that returns the first stream unless it fails.""" diff --git a/tests/test_core.py b/tests/test_core.py index ce1da7e..a5b0212 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,6 +2,7 @@ from pytest import raises, mark +from cons import cons from unification import var, isvar from kanren.core import ( @@ -20,6 +21,7 @@ lallfirst, condeseq, ifa, + ground_order, ) from kanren.util import evalt @@ -250,3 +252,14 @@ def test_ifa(): eq(x, True), ifa(lall(eq(x, True), eq(y, 1)), lall(eq(x, True), eq(y, 2))), ) == (1,) + + +def test_ground_order(): + x, y, z = var(), var(), var() + assert run(0, x, ground_order((y, [1, z], 1), x)) == ([1, [1, z], y],) + a, b, c = var(), var(), var() + assert run(0, (a, b, c), ground_order((y, [1, z], 1), (a, b, c))) == ( + (1, [1, z], y), + ) + assert run(0, z, ground_order([cons(x, y), (x, y)], z)) == ([(x, y), cons(x, y)],) + assert run(0, z, ground_order([(x, y), cons(x, y)], z)) == ([(x, y), cons(x, y)],) From 0ea3200816b5a184877d484f65d85ca52d5b49e2 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 31 Jan 2020 00:29:57 -0600 Subject: [PATCH 15/15] Completely refactor associative and commutative functionality This approach uses associative flattening (e.g. `(op, 1, (op, 2, 3))` is flattened to `(op, 1, 2, 3)`), which now makes nested associative operations possible. --- kanren/assoccomm.py | 393 +++++++++++++-------------- kanren/graph.py | 68 ++++- tests/test_assoccomm.py | 587 ++++++++++++++++++++++++++++++---------- 3 files changed, 703 insertions(+), 345 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index 48b56bd..e9f1fe6 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -28,120 +28,168 @@ >>> print(run(0, (x,y), eq(pattern, expr))) ((3, 2),) """ +from functools import partial +from operator import length_hint, eq as equal +from collections.abc import Sequence -from unification.utils import transitive_get as walk -from unification import isvar, var - -from cons.core import ConsError - -from . import core -from .core import ( - unify, - conde, - eq, - fail, - lallgreedy, - EarlyGoalError, - condeseq, - goaleval, -) -from .goals import permuteq -from .facts import Relation -from .util import groupsizes, index -from .term import term, arguments, operator -from .graph import mapo +from toolz import sliding_window -associative = Relation("associative") -commutative = Relation("commutative") +from unification import isvar, var, reify, unify +from cons.core import ConsError, ConsPair, car, cdr -def assocunify(u, v, s, eq=core.eq, n=None): - """Perform associative unification. +from etuples import etuple - See Also - -------- - eq_assoccomm - """ - uop, uargs = op_args(u) - vop, vargs = op_args(v) - - if not uop and not vop: - res = unify(u, v, s) - if res is not False: - return (res,) # TODO: iterate through all possibilities - - if uop and vop: - s = unify(uop, vop, s) - if s is False: - return ().__iter__() - op = walk(uop, s) +from .core import conde, condeseq, eq, goaleval, ground_order, lall, succeed +from .goals import itero, permuteo +from .facts import Relation +from .graph import applyo, term_walko +from .term import term, operator, arguments - sm, lg = (uargs, vargs) if len(uargs) <= len(vargs) else (vargs, uargs) - ops = assocsized(op, lg, len(sm)) - goal = condeseq([(eq, a, b) for a, b, in zip(sm, lg2)] for lg2 in ops) - return goaleval(goal)(s) +associative = Relation("associative") +commutative = Relation("commutative") - if uop: - op, tail = uop, uargs - b = v - if vop: - op, tail = vop, vargs - b = u +# For backward compatibility +buildo = applyo - ns = [n] if n else range(2, len(tail) + 1) - knowns = (build(op, x) for n in ns for x in assocsized(op, tail, n)) - goal = condeseq([(core.eq, b, k)] for k in knowns) - return goaleval(goal)(s) +def op_args(x): + """Break apart x into an operation and tuple of args.""" + if isvar(x): + return None, None + try: + return operator(x), arguments(x) + except (ConsError, NotImplementedError): + return None, None -def assocsized(op, tail, n): - """Produce all associative combinations of x in n groups.""" - gsizess = groupsizes(len(tail), n) - partitions = (groupsizes_to_partition(*gsizes) for gsizes in gsizess) - return (makeops(op, partition(tail, part)) for part in partitions) +def flatten_assoc_args(op_predicate, items): + for i in items: + if isinstance(i, ConsPair) and op_predicate(car(i)): + i_cdr = cdr(i) + if length_hint(i_cdr) > 0: + yield from flatten_assoc_args(op_predicate, i_cdr) + else: + yield i + else: + yield i -def makeops(op, lists): - """Construct operations from an op and parition lists. +def assoc_args(rator, rands, n, ctor=None): + """Produce all associative argument combinations of rator + rands in n-sized rand groupings. - >>> from kanren.assoccomm import makeops - >>> makeops('add', [(1, 2), (3, 4, 5)]) - (('add', 1, 2), ('add', 3, 4, 5)) + >>> from kanren.assoccomm import assoc_args + >>> list(assoc_args('op', [1, 2, 3], 2)) + [[['op', 1, 2], 3], [1, ['op', 2, 3]]] """ - return tuple(l[0] if len(l) == 1 else build(op, l) for l in lists) + assert n > 0 + rands_l = list(rands) -def partition(tup, part): - """Partition a tuple. + if ctor is None: + ctor = type(rands) - >>> from kanren.assoccomm import partition - >>> partition("abcde", [[0,1], [4,3,2]]) - [('a', 'b'), ('e', 'd', 'c')] - """ - return [index(tup, ind) for ind in part] + if n == len(rands_l): + yield ctor(rands) + return + + for i, new_rands in enumerate(sliding_window(n, rands_l)): + prefix = rands_l[:i] + new_term = term(rator, ctor(new_rands)) + suffix = rands_l[n + i :] + res = ctor(prefix + [new_term] + suffix) + yield res -def groupsizes_to_partition(*gsizes): - """Create a list of ranges from their sizes. +def eq_assoc_args( + op, a_args, b_args, n=None, inner_eq=eq, no_ident=False, null_type=etuple +): + """Create a goal that applies associative unification to an operator and two sets of arguments. - >>> from kanren.assoccomm import groupsizes_to_partition - >>> groupsizes_to_partition(2, 3) - [[0, 1], [2, 3, 4]] + This is a non-relational utility goal. It does assumes that the op and at + least one set of arguments are ground under the state in which it is + evaluated. """ - idx = 0 - part = [] - for gs in gsizes: - l_ = [] - for i in range(gs): - l_.append(idx) - idx += 1 - part.append(l_) - return part + u_args, v_args = var(), var() + + def eq_assoc_args_goal(S): + nonlocal op, u_args, v_args, n + + (op_rf, u_args_rf, v_args_rf, n_rf) = reify((op, u_args, v_args, n), S) + + if isinstance(v_args_rf, Sequence): + u_args_rf, v_args_rf = v_args_rf, u_args_rf + + if isinstance(u_args_rf, Sequence) and isinstance(v_args_rf, Sequence): + # TODO: We just ignore `n` when both are sequences? + + if type(u_args_rf) != type(v_args_rf): + return + + if no_ident and unify(u_args_rf, v_args_rf, S) is not False: + return + + op_pred = partial(equal, op_rf) + u_args_flat = type(u_args_rf)(flatten_assoc_args(op_pred, u_args_rf)) + v_args_flat = type(v_args_rf)(flatten_assoc_args(op_pred, v_args_rf)) + + if len(u_args_flat) == len(v_args_flat): + g = inner_eq(u_args_flat, v_args_flat) + else: + if len(u_args_flat) < len(v_args_flat): + sm_args, lg_args = u_args_flat, v_args_flat + else: + sm_args, lg_args = v_args_flat, u_args_flat + + grp_sizes = len(lg_args) - len(sm_args) + 1 + assoc_terms = assoc_args( + op_rf, lg_args, grp_sizes, ctor=type(u_args_rf) + ) + + g = condeseq([inner_eq(sm_args, a_args)] for a_args in assoc_terms) + + yield from goaleval(g)(S) + + elif isinstance(u_args_rf, Sequence): + # TODO: We really need to know the arity (ranges) for the operator + # in order to make good choices here. + # For instance, does `(op, 1, 2) == (op, (op, 1, 2))` make sense? + # If so, the lower-bound on this range should actually be `1`. + if len(u_args_rf) == 1: + if not no_ident and (n_rf == 1 or n_rf is None): + g = inner_eq(u_args_rf, v_args_rf) + else: + return + else: + + u_args_flat = list(flatten_assoc_args(partial(equal, op_rf), u_args_rf)) + + if n_rf is not None: + arg_sizes = [n_rf] + else: + arg_sizes = range(2, len(u_args_flat) + (not no_ident)) + + v_ac_args = ( + v_ac_arg + for n_i in arg_sizes + for v_ac_arg in assoc_args( + op_rf, u_args_flat, n_i, ctor=type(u_args_rf) + ) + if not no_ident or v_ac_arg != u_args_rf + ) + g = condeseq([inner_eq(v_args_rf, v_ac_arg)] for v_ac_arg in v_ac_args) + + yield from goaleval(g)(S) + + return lall( + ground_order((a_args, b_args), (u_args, v_args)), + itero(u_args, nullo_refs=(v_args,), default_ConsNull=null_type), + eq_assoc_args_goal, + ) -def eq_assoc(u, v, eq=core.eq, n=None): - """Create a goal for associative equality. +def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): + """Create a goal for associative unification of two terms. >>> from kanren import run, var, fact >>> from kanren.assoccomm import eq_assoc as eq @@ -153,28 +201,14 @@ def eq_assoc(u, v, eq=core.eq, n=None): >>> run(0, x, eq(('add', 1, 2, 3), ('add', 1, x))) (('add', 2, 3),) """ - uop, _ = op_args(u) - vop, _ = op_args(v) - if uop and vop: - return conde( - [(core.eq, u, v)], - [(eq, uop, vop), (associative, uop), lambda s: assocunify(u, v, s, eq, n)], - ) - - if uop or vop: - if vop: - uop, vop = vop, uop - v, u = u, v - return conde( - [(core.eq, u, v)], - [(associative, uop), lambda s: assocunify(u, v, s, eq, n)], - ) + def assoc_args_unique(a, b, op, **kwargs): + return eq_assoc_args(op, a, b, no_ident=True, null_type=null_type) - return (core.eq, u, v) + return term_walko(op_predicate, assoc_args_unique, u, v, n=n) -def eq_comm(u, v, inner_eq=None): +def eq_comm(u, v, op_predicate=commutative, null_type=etuple): """Create a goal for commutative equality. >>> from kanren import run, var, fact @@ -188,88 +222,37 @@ def eq_comm(u, v, inner_eq=None): >>> run(0, x, eq(('add', 1, 2, 3), ('add', 2, x, 1))) (3,) """ - inner_eq = inner_eq or eq_comm - vtail, vhead = var(), var() - - if isvar(u) and isvar(v): - return eq(u, v) - - uop, uargs = op_args(u) - vop, vargs = op_args(v) - if not uop and not vop: - return eq(u, v) + def permuteo_unique(x, y, op, **kwargs): + return permuteo(x, y, no_ident=True, default_ConsNull=null_type) - if vop and not uop: - uop, uargs = vop, vargs - v, u = u, v + return term_walko(op_predicate, permuteo_unique, u, v) - return ( - conde, - [eq(u, v)], - [ - (buildo, vhead, vtail, v), - ( - conde, - [ - (inner_eq, uop, vhead), - (mapo, lambda a, b: (eq_comm, a, b, inner_eq), uargs, vtail), - ], - [ - eq(uop, vhead), - (commutative, uop), - (permuteq, uargs, vtail, inner_eq), - ], - ), - ], - ) +def assoc_flatten(a, a_flat): + def assoc_flatten_goal(S): + nonlocal a, a_flat -def buildo(op, args, obj): - """Construct a goal that relates an object to its op and args. + a_rf = reify(a, S) - For example, in `add(1,2,3)`, `add` is the op and `(1,2,3)` are the args. + if isinstance(a_rf, Sequence) and (a_rf[0],) in associative.facts: - Checks op_regsitry for functions to define op/arg relationships - """ - if not isvar(obj): - oop, oargs = op_args(obj) - # TODO: Is greedy correct? - return lallgreedy((eq, op, oop), (eq, args, oargs)) - else: - try: - return eq(obj, build(op, args)) - except TypeError: - raise EarlyGoalError() - raise EarlyGoalError() - - -def build(op, args): - try: - return term(op, args) - except NotImplementedError: - raise EarlyGoalError() + def op_pred(sub_op): + nonlocal S + sub_op_rf = reify(sub_op, S) + return sub_op_rf == a_rf[0] + a_flat_rf = type(a_rf)(flatten_assoc_args(op_pred, a_rf)) + else: + a_flat_rf = a_rf -def op_args(x): - """Break apart x into an operation and tuple of args.""" - if isvar(x): - return None, None - try: - return operator(x), arguments(x) - except (ConsError, NotImplementedError): - return None, None + yield from eq(a_flat, a_flat_rf)(S) + return assoc_flatten_goal -def eq_assoccomm(u, v): - """Construct a goal for associative and commutative eq. - Works like logic.core.eq but supports associative/commutative expr trees - - tree-format: (op, *args) - example: (add, 1, 2, 3) - - State that operations are associative or commutative with relations +def eq_assoccomm(u, v, null_type=etuple): + """Construct a goal for associative and commutative unification. >>> from kanren.assoccomm import eq_assoccomm as eq >>> from kanren.assoccomm import commutative, associative @@ -282,33 +265,31 @@ def eq_assoccomm(u, v): >>> e1 = ('add', 1, 2, 3) >>> e2 = ('add', 1, x) >>> run(0, x, eq(e1, e2)) - (('add', 2, 3), ('add', 3, 2)) + (('add', 3, 2), ('add', 2, 3)) """ - uop, uargs = op_args(u) - vop, vargs = op_args(v) - - if uop and not vop and not isvar(v): - return fail - if vop and not uop and not isvar(u): - return fail - if uop and vop and uop != vop: - return fail - if uop and not (uop,) in associative.facts: - return (eq, u, v) - if vop and not (vop,) in associative.facts: - return (eq, u, v) - - if uop and vop: - u, v = (u, v) if len(uargs) >= len(vargs) else (v, u) - n = min(map(len, (uargs, vargs))) # length of shorter tail - else: - n = None - if vop and not uop: - u, v = v, u - w = var() - # TODO: Is greedy correct? - return ( - lallgreedy, - (eq_assoc, u, w, eq_assoccomm, n), - (eq_comm, v, w, eq_assoccomm), + + def eq_assoccomm_step(a, b, op): + z = var() + return lall( + # Permute + conde( + [ + commutative(op), + permuteo(a, z, no_ident=True, default_ConsNull=etuple), + ], + [eq(a, z)], + ), + # Generate associative combinations + conde( + [associative(op), eq_assoc_args(op, z, b, no_ident=True)], [eq(z, b)] + ), + ) + + return term_walko( + lambda x: succeed, + eq_assoccomm_step, + u, + v, + format_step=assoc_flatten, + no_ident=False, ) diff --git a/kanren/graph.py b/kanren/graph.py index 8a809df..49adc2c 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -7,7 +7,7 @@ from etuples import etuple, apply, rands, rator -from .core import eq, conde, lall, goaleval, succeed, Zzz, fail +from .core import eq, conde, lall, goaleval, succeed, Zzz, fail, ground_order from .goals import conso, nullo @@ -260,3 +260,69 @@ def walko_goal(s): yield from goaleval(g)(s) return walko_goal + + +def term_walko( + rator_goal, + rands_goal, + a, + b, + null_type=etuple, + no_ident=False, + format_step=None, + **kwargs +): + """Construct a goal for walking a term graph. + + This implementation is somewhat specific to the needs of `eq_comm` and + `eq_assoc`, but it could be transferred to `kanren.graph`. + + XXX: Make sure `rator_goal` will succeed for unground logic variables; + otherwise, this will diverge. + XXX: `rands_goal` should not be contain `eq`, i.e. `rands_goal(x, x)` + should always fail! + """ + + def single_step(s, t): + u, v = var(), var() + u_rator, u_rands = var(), var() + v_rands = var() + + return lall( + ground_order((s, t), (u, v)), + applyo(u_rator, u_rands, u), + applyo(u_rator, v_rands, v), + rator_goal(u_rator), + # These make sure that there are at least two rands, which + # makes sense for commutativity and associativity, at least. + conso(var(), var(), u_rands), + conso(var(), var(), v_rands), + Zzz(rands_goal, u_rands, v_rands, u_rator, **kwargs), + ) + + def term_walko_step(s, t): + nonlocal rator_goal, rands_goal, null_type + u, v = var(), var() + z, w = var(), var() + + return lall( + ground_order((s, t), (u, v)), + format_step(u, w) if format_step is not None else eq(u, w), + conde( + [ + # Apply, then walk or return + single_step(w, v), + ], + [ + # Walk, then apply or return + map_anyo(term_walko_step, w, z, null_type=null_type), + conde([eq(z, v)], [single_step(z, v)]), + ], + ), + ) + + return lall( + term_walko_step(a, b) + if no_ident + else conde([term_walko_step(a, b)], [eq(a, b)]), + ) diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 0851a4d..08a33b9 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -1,31 +1,30 @@ -from __future__ import absolute_import - import pytest -from unification import reify, var, variables +from collections.abc import Sequence + +from etuples.core import etuple + +from cons import cons -from kanren.core import run, goaleval +from unification import reify, var, variables, isvar, unify + +from kanren.core import goaleval, run_all as run from kanren.facts import fact from kanren.assoccomm import ( associative, commutative, - groupsizes_to_partition, - assocunify, + eq_assoc_args, eq_comm, eq_assoc, eq_assoccomm, - assocsized, + assoc_args, buildo, op_args, + flatten_assoc_args, + assoc_flatten, ) from kanren.term import operator, arguments, term -x, y = var("x"), var("y") - -a, c = "assoc_op", "comm_op" -fact(associative, a) -fact(commutative, c) - class Node(object): def __init__(self, op, args): @@ -65,7 +64,7 @@ def mul(*args): return Node(Mul, args) -@term.register(Operator, (tuple, list)) +@term.register(Operator, Sequence) def term_Operator(op, args): return Node(op, args) @@ -86,182 +85,494 @@ def results(g, s=None): return tuple(goaleval(g)(s)) +def test_op_args(): + assert op_args(var()) == (None, None) + assert op_args(add(1, 2, 3)) == (Add, (1, 2, 3)) + assert op_args("foo") == (None, None) + + +def test_buildo(): + x = var() + assert run(0, x, buildo("add", (1, 2, 3), x)) == (("add", 1, 2, 3),) + assert run(0, x, buildo(x, (1, 2, 3), ("add", 1, 2, 3))) == ("add",) + assert run(0, x, buildo("add", x, ("add", 1, 2, 3))) == ((1, 2, 3),) + + +def test_buildo_object(): + x = var() + assert run(0, x, buildo(Add, (1, 2, 3), x)) == (add(1, 2, 3),) + assert run(0, x, buildo(x, (1, 2, 3), add(1, 2, 3))) == (Add,) + assert run(0, x, buildo(Add, x, add(1, 2, 3))) == ((1, 2, 3),) + + def test_eq_comm(): - assert results(eq_comm(1, 1)) - assert results(eq_comm((c, 1, 2, 3), (c, 1, 2, 3))) - assert results(eq_comm((c, 3, 2, 1), (c, 1, 2, 3))) - assert not results(eq_comm((a, 3, 2, 1), (a, 1, 2, 3))) # not commutative - assert not results(eq_comm((3, c, 2, 1), (c, 1, 2, 3))) - assert not results(eq_comm((c, 1, 2, 1), (c, 1, 2, 3))) - assert not results(eq_comm((a, 1, 2, 3), (c, 1, 2, 3))) - assert len(results(eq_comm((c, 3, 2, 1), x))) >= 6 - assert results(eq_comm(x, y)) == ({x: y},) + x, y, z = var(), var(), var() + commutative.facts.clear() + commutative.index.clear() -def test_eq_assoc(): - assert results(eq_assoc(1, 1)) - assert results(eq_assoc((a, 1, 2, 3), (a, 1, 2, 3))) - assert not results(eq_assoc((a, 3, 2, 1), (a, 1, 2, 3))) - assert results(eq_assoc((a, (a, 1, 2), 3), (a, 1, 2, 3))) - assert results(eq_assoc((a, 1, 2, 3), (a, (a, 1, 2), 3))) - o = "op" - assert not results(eq_assoc((o, 1, 2, 3), (o, (o, 1, 2), 3))) + comm_op = "comm_op" + + fact(commutative, comm_op) - # See TODO in assocunify - gen = results(eq_assoc((a, 1, 2, 3), x, n=2)) - assert set(g[x] for g in gen).issuperset( - set([(a, (a, 1, 2), 3), (a, 1, (a, 2, 3))]) + assert run(0, True, eq_comm(1, 1)) == (True,) + assert run(0, True, eq_comm((comm_op, 1, 2, 3), (comm_op, 1, 2, 3))) == (True,) + + assert run(0, True, eq_comm((comm_op, 3, 2, 1), (comm_op, 1, 2, 3))) == (True,) + assert run(0, y, eq_comm((comm_op, 3, y, 1), (comm_op, 1, 2, 3))) == (2,) + assert run(0, (x, y), eq_comm((comm_op, x, y, 1), (comm_op, 1, 2, 3))) == ( + (2, 3), + (3, 2), ) - gen = results(eq_assoc(x, (a, 1, 2, 3), n=2)) - assert set(g[x] for g in gen).issuperset( - set([(a, (a, 1, 2), 3), (a, 1, (a, 2, 3))]) + assert run(0, (x, y), eq_comm((comm_op, 2, 3, 1), (comm_op, 1, x, y))) == ( + (2, 3), + (3, 2), ) + assert not run( + 0, True, eq_comm(("op", 3, 2, 1), ("op", 1, 2, 3)) + ) # not commutative + assert not run(0, True, eq_comm((3, comm_op, 2, 1), (comm_op, 1, 2, 3))) + assert not run(0, True, eq_comm((comm_op, 1, 2, 1), (comm_op, 1, 2, 3))) + assert not run(0, True, eq_comm(("op", 1, 2, 3), (comm_op, 1, 2, 3))) + + # Test for variable args + res = run(4, (x, y), eq_comm(x, y)) + exp_res_form = ( + (etuple(comm_op, x, y), etuple(comm_op, y, x)), + (x, y), + (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))), + (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)), + ) -def test_eq_assoccomm(): - x, y = var(), var() - eqac = eq_assoccomm - ac = "commassoc_op" - fact(commutative, ac) - fact(associative, ac) - assert results(eqac(1, 1)) - assert results(eqac((1,), (1,))) - assert results(eqac(x, (1,))) - assert results(eqac((1,), x)) - assert results((eqac, 1, 1)) - assert results(eqac((a, (a, 1, 2), 3), (a, 1, 2, 3))) - assert results(eqac((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) - assert results(eqac((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) - assert results(eqac((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) - assert not results(eqac((ac, 1, 1), ("other_op", 1, 1))) - assert run(0, x, eqac((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) - - -def test_expr(): - add = "add" - mul = "mul" - fact(commutative, add) - fact(associative, add) - fact(commutative, mul) - fact(associative, mul) + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False + assert all(isvar(i) for i in reify((x, y, z), s)) - x, y = var("x"), var("y") + # Make sure it can unify single elements + assert (3,) == run(0, x, eq_comm((comm_op, 1, 2, 3), (comm_op, 2, x, 1))) - pattern = (mul, (add, 1, x), y) # (1 + x) * y - expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) - assert run(0, (x, y), eq_assoccomm(pattern, expr)) == ((3, 2),) + # `eq_comm` should propagate through + assert (3,) == run( + 0, x, eq_comm(("div", 1, (comm_op, 1, 2, 3)), ("div", 1, (comm_op, 2, x, 1))) + ) + # Now it should not + assert () == run( + 0, x, eq_comm(("div", 1, ("div", 1, 2, 3)), ("div", 1, ("div", 2, x, 1))) + ) + expected_res = {(1, 2, 3), (2, 1, 3), (3, 1, 2), (1, 3, 2), (2, 3, 1), (3, 2, 1)} + assert expected_res == set( + run(0, (x, y, z), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, z))) + ) + assert expected_res == set( + run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))) + ) + assert expected_res == set( + run( + 0, + (x, y, z), + eq_comm(("div", 1, (comm_op, 1, 2, 3)), ("div", 1, (comm_op, x, y, z))), + ) + ) -def test_deep_commutativity(): - x, y = var("x"), var("y") + e1 = (comm_op, (comm_op, 1, x), y) + e2 = (comm_op, 2, (comm_op, 3, 1)) + assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) - e1 = ((c, 3, 1),) - e2 = ((c, 1, x),) + e1 = ((comm_op, 3, 1),) + e2 = ((comm_op, 1, x),) assert run(0, x, eq_comm(e1, e2)) == (3,) - e1 = (2, (c, 3, 1)) - e2 = (y, (c, 1, x)) + e1 = (2, (comm_op, 3, 1)) + e2 = (y, (comm_op, 1, x)) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) - e1 = (c, (c, 1, x), y) - e2 = (c, 2, (c, 3, 1)) + e1 = (comm_op, (comm_op, 1, x), y) + e2 = (comm_op, 2, (comm_op, 3, 1)) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) -def test_groupsizes_to_parition(): - assert groupsizes_to_partition(2, 3) == [[0, 1], [2, 3, 4]] +@pytest.mark.xfail(reason="`applyo`/`buildo` needs to be a constraint.", strict=True) +def test_eq_comm_object(): + x = var("x") + + fact(commutative, Add) + fact(associative, Add) + + assert run(0, x, eq_comm(add(1, 2, 3), add(3, 1, x))) == (2,) + assert set(run(0, x, eq_comm(add(1, 2), x))) == set((add(1, 2), add(2, 1))) + assert set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x)))) == set( + (add(2, 3), add(3, 2)) + ) -def test_assocunify(): - assert tuple(assocunify(1, 1, {})) - assert tuple(assocunify((a, 1, 1), (a, 1, 1), {})) - assert tuple(assocunify((a, 1, 2, 3), (a, 1, (a, 2, 3)), {})) - assert tuple(assocunify((a, 1, (a, 2, 3)), (a, 1, 2, 3), {})) - assert tuple(assocunify((a, 1, (a, 2, 3), 4), (a, 1, 2, 3, 4), {})) - assert tuple(assocunify((a, 1, x, 4), (a, 1, 2, 3, 4), {})) == ({x: (a, 2, 3)},) - assert tuple(assocunify((a, 1, 1), ("other_op", 1, 1), {})) == () +def test_flatten_assoc_args(): + op = "add" - assert tuple(assocunify((a, 1, 1), (x, 1, 1), {})) == ({x: a},) - assert tuple(assocunify((x, 1, 1), (a, 1, 1), {})) == ({x: a},) + def op_pred(x): + return x == op - gen = assocunify((a, 1, 2, 3), x, {}, n=2) - assert set(g[x] for g in gen) == set([(a, (a, 1, 2), 3), (a, 1, (a, 2, 3))]) - gen = assocunify(x, (a, 1, 2, 3), {}, n=2) - assert set(g[x] for g in gen) == set([(a, (a, 1, 2), 3), (a, 1, (a, 2, 3))]) + assert list(flatten_assoc_args(op_pred, [op, 1, 2, 3, 4])) == [op, 1, 2, 3, 4] + assert list(flatten_assoc_args(op_pred, [op, 1, 2, [op]])) == [op, 1, 2, [op]] + assert list(flatten_assoc_args(op_pred, [[op, 1, 2, [op]]])) == [1, 2, [op]] - gen = assocunify((a, 1, 2, 3), x, {}) - assert set(g[x] for g in gen) == set( - [(a, 1, 2, 3), (a, (a, 1, 2), 3), (a, 1, (a, 2, 3))] + res = list( + flatten_assoc_args( + op_pred, [[1, 2, op], 3, [op, 4, [op, [op]]], [op, 5], 6, op, 7] + ) ) + exp_res = [[1, 2, op], 3, 4, [op], 5, 6, op, 7] + assert res == exp_res -def test_assocsized(): - add = "add" - assert set(assocsized(add, (1, 2, 3), 2)) == set( - (((add, 1, 2), 3), (1, (add, 2, 3))) +def test_assoc_args(): + op = "add" + + def op_pred(x): + return x == op + + assert tuple(assoc_args(op, (1, 2, 3), 2)) == (((op, 1, 2), 3), (1, (op, 2, 3)),) + assert tuple(assoc_args(op, [1, 2, 3], 2)) == ([[op, 1, 2], 3], [1, [op, 2, 3]],) + assert tuple(assoc_args(op, (1, 2, 3), 1)) == ( + ((op, 1), 2, 3), + (1, (op, 2), 3), + (1, 2, (op, 3)), ) - assert set(assocsized(add, (1, 2, 3), 1)) == set((((add, 1, 2, 3),),)) + assert tuple(assoc_args(op, (1, 2, 3), 3)) == ((1, 2, 3),) + f_rands = flatten_assoc_args(op_pred, (1, (op, 2, 3))) + assert tuple(assoc_args(op, f_rands, 2, ctor=tuple)) == ( + ((op, 1, 2), 3), + (1, (op, 2, 3)), + ) -def test_objects(): - fact(commutative, Add) - fact(associative, Add) - assert tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(3, 1, 2)))({})) - assert tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(3, 1, 2)))({})) - x = var("x") +def test_eq_assoc_args(): + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + assert not run(0, True, eq_assoc_args(assoc_op, (1,), [1], n=None)) + assert run(0, True, eq_assoc_args(assoc_op, (1,), (1,), n=None)) == (True,) + assert run(0, True, eq_assoc_args(assoc_op, (1, 1), (1, 1))) == (True,) + assert run(0, True, eq_assoc_args(assoc_op, (1, 2, 3), (1, (assoc_op, 2, 3)))) == ( + True, + ) + assert run(0, True, eq_assoc_args(assoc_op, (1, (assoc_op, 2, 3)), (1, 2, 3))) == ( + True, + ) + assert run( + 0, True, eq_assoc_args(assoc_op, (1, (assoc_op, 2, 3), 4), (1, 2, 3, 4)) + ) == (True,) + assert not run( + 0, True, eq_assoc_args(assoc_op, (1, 2, 3), (1, (assoc_op, 2, 3), 4)) + ) + + x, y = var(), var() + + assert run(0, True, eq_assoc_args(assoc_op, (x,), (x,), n=None)) == (True,) + assert run(0, x, eq_assoc_args(assoc_op, x, (y,), n=None)) == ((y,),) + assert run(0, x, eq_assoc_args(assoc_op, (y,), x, n=None)) == ((y,),) + + assert run(0, x, eq_assoc_args(assoc_op, (1, x, 4), (1, 2, 3, 4))) == ( + (assoc_op, 2, 3), + ) + assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3, 4), (1, x, 4))) == ( + (assoc_op, 2, 3), + ) + assert run(0, x, eq_assoc_args(assoc_op, [1, x, 4], [1, 2, 3, 4])) == ( + [assoc_op, 2, 3], + ) + assert run(0, True, eq_assoc_args(assoc_op, (1, 1), ("other_op", 1, 1))) == () + + assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x, n=2)) == ( + ((assoc_op, 1, 2), 3), + (1, (assoc_op, 2, 3)), + ) + assert run(0, x, eq_assoc_args(assoc_op, x, (1, 2, 3), n=2)) == ( + ((assoc_op, 1, 2), 3), + (1, (assoc_op, 2, 3)), + ) + + assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x)) == ( + ((assoc_op, 1, 2), 3), + (1, (assoc_op, 2, 3)), + (1, 2, 3), + ) + + assert () not in run(0, x, eq_assoc_args(assoc_op, (), x, no_ident=True)) + assert (1,) not in run(0, x, eq_assoc_args(assoc_op, (1,), x, no_ident=True)) + assert (1, 2, 3) not in run( + 0, x, eq_assoc_args(assoc_op, (1, 2, 3), x, no_ident=True) + ) assert ( - reify(x, tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(1, 2, x)))({}))[0]) == 3 + run( + 0, + True, + eq_assoc_args( + assoc_op, (1, (assoc_op, 2, 3)), (1, (assoc_op, 2, 3)), no_ident=True, + ), + ) + == () ) - assert reify(x, next(goaleval(eq_assoccomm(add(1, 2, 3), add(x, 2, 1)))({}))) == 3 + assert run( + 0, + True, + eq_assoc_args( + assoc_op, (1, (assoc_op, 2, 3)), ((assoc_op, 1, 2), 3), no_ident=True, + ), + ) == (True,) - v = add(1, 2, 3) - with variables(v): - x = add(5, 6) - assert reify(v, next(goaleval(eq_assoccomm(v, x))({}))) == x +def test_eq_assoc(): -@pytest.mark.xfail(reason="This would work if we flattened first.", strict=True) -def test_deep_associativity(): - expr1 = (a, 1, 2, (a, x, 5, 6)) - expr2 = (a, (a, 1, 2), 3, 4, 5, 6) - result = {x: (a, 3, 4)} - assert tuple(assocunify(expr1, expr2, {})) == result + assoc_op = "assoc_op" + associative.index.clear() + associative.facts.clear() -def test_buildo(): - x = var("x") - assert results(buildo("add", (1, 2, 3), x), {}) == ({x: ("add", 1, 2, 3)},) - assert results(buildo(x, (1, 2, 3), ("add", 1, 2, 3)), {}) == ({x: "add"},) - assert results(buildo("add", x, ("add", 1, 2, 3)), {}) == ({x: (1, 2, 3)},) + fact(associative, assoc_op) + assert run(0, True, eq_assoc(1, 1)) == (True,) + assert run(0, True, eq_assoc((assoc_op, 1, 2, 3), (assoc_op, 1, 2, 3))) == (True,) + assert not run(0, True, eq_assoc((assoc_op, 3, 2, 1), (assoc_op, 1, 2, 3))) + assert run( + 0, True, eq_assoc((assoc_op, (assoc_op, 1, 2), 3), (assoc_op, 1, 2, 3)) + ) == (True,) + assert run( + 0, True, eq_assoc((assoc_op, 1, 2, 3), (assoc_op, (assoc_op, 1, 2), 3)) + ) == (True,) + o = "op" + assert not run(0, True, eq_assoc((o, 1, 2, 3), (o, (o, 1, 2), 3))) + + x = var() + res = run(0, x, eq_assoc((assoc_op, 1, 2, 3), x, n=2)) + assert res == ( + (assoc_op, (assoc_op, 1, 2), 3), + (assoc_op, 1, 2, 3), + (assoc_op, 1, (assoc_op, 2, 3)), + ) -def test_op_args(): - assert op_args(add(1, 2, 3)) == (Add, (1, 2, 3)) - assert op_args("foo") == (None, None) + res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) + assert res == ( + (assoc_op, (assoc_op, 1, 2), 3), + (assoc_op, 1, 2, 3), + (assoc_op, 1, (assoc_op, 2, 3)), + ) + y, z = var(), var() + + # Check results when both arguments are variables + res = run(3, (x, y), eq_assoc(x, y)) + exp_res_form = ( + (etuple(assoc_op, x, y, z), etuple(assoc_op, etuple(assoc_op, x, y), z)), + (x, y), + ( + etuple(etuple(assoc_op, x, y, z)), + etuple(etuple(assoc_op, etuple(assoc_op, x, y), z)), + ), + ) -def test_buildo_object(): - x = var("x") - assert results(buildo(Add, (1, 2, 3), x), {}) == ({x: add(1, 2, 3)},) - assert results(buildo(x, (1, 2, 3), add(1, 2, 3)), {}) == ({x: Add},) - assert results(buildo(Add, x, add(1, 2, 3)), {}) == ({x: (1, 2, 3)},) + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False, (a, b) + assert all(isvar(i) for i in reify((x, y, z), s)) + + # Make sure it works with `cons` + res = run(0, (x, y), eq_assoc(cons(x, y), (assoc_op, 1, 2, 3))) + assert res == ( + (assoc_op, ((assoc_op, 1, 2), 3)), + (assoc_op, (1, 2, 3)), + (assoc_op, (1, (assoc_op, 2, 3))), + ) + res = run(1, (x, y), eq_assoc(cons(x, y), (x, z, 2, 3))) + assert res == ((assoc_op, ((assoc_op, z, 2), 3)),) + + # Don't use a predicate that can never succeed, e.g. + # associative_2 = Relation("associative_2") + # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) + + # Nested expressions should work now + expr1 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) + expr2 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) + assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) + + +def test_assoc_flatten(): + + add = "add" + mul = "mul" + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) + + assert run( + 0, + True, + assoc_flatten((mul, 1, (add, 2, 3), (mul, 4, 5)), (mul, 1, (add, 2, 3), 4, 5)), + ) == (True,) + + x = var() + assert run(0, x, assoc_flatten((mul, 1, (add, 2, 3), (mul, 4, 5)), x),) == ( + (mul, 1, (add, 2, 3), 4, 5), + ) + + assert run( + 0, + True, + assoc_flatten( + ("op", 1, (add, 2, 3), (mul, 4, 5)), ("op", 1, (add, 2, 3), (mul, 4, 5)) + ), + ) == (True,) + + assert run(0, x, assoc_flatten(("op", 1, (add, 2, 3), (mul, 4, 5)), x)) == ( + ("op", 1, (add, 2, 3), (mul, 4, 5)), + ) + + +def test_eq_assoccomm(): + x, y = var(), var() + + ac = "commassoc_op" + + commutative.index.clear() + commutative.facts.clear() + + fact(commutative, ac) + fact(associative, ac) + + assert run(0, True, eq_assoccomm(1, 1)) == (True,) + assert run(0, True, eq_assoccomm((1,), (1,))) == (True,) + assert run(0, True, eq_assoccomm(x, (1,))) == (True,) + assert run(0, True, eq_assoccomm((1,), x)) == (True,) + + # Assoc only + assert run(0, True, eq_assoccomm((ac, 1, (ac, 2, 3)), (ac, (ac, 1, 2), 3))) == ( + True, + ) + # Commute only + assert run(0, True, eq_assoccomm((ac, 1, (ac, 2, 3)), (ac, (ac, 3, 2), 1))) == ( + True, + ) + # Both + assert run(0, True, eq_assoccomm((ac, 1, (ac, 3, 2)), (ac, (ac, 1, 2), 3))) == ( + True, + ) + + exp_res = set( + ( + (ac, 1, 3, 2), + (ac, 1, 2, 3), + (ac, 2, 1, 3), + (ac, 2, 3, 1), + (ac, 3, 1, 2), + (ac, 3, 2, 1), + (ac, 1, (ac, 2, 3)), + (ac, 1, (ac, 3, 2)), + (ac, 2, (ac, 1, 3)), + (ac, 2, (ac, 3, 1)), + (ac, 3, (ac, 1, 2)), + (ac, 3, (ac, 2, 1)), + (ac, (ac, 2, 3), 1), + (ac, (ac, 3, 2), 1), + (ac, (ac, 1, 3), 2), + (ac, (ac, 3, 1), 2), + (ac, (ac, 1, 2), 3), + (ac, (ac, 2, 1), 3), + ) + ) + assert set(run(0, x, eq_assoccomm((ac, 1, (ac, 2, 3)), x))) == exp_res + assert set(run(0, x, eq_assoccomm((ac, 1, 3, 2), x))) == exp_res + assert set(run(0, x, eq_assoccomm((ac, 2, (ac, 3, 1)), x))) == exp_res + # LHS variations + assert set(run(0, x, eq_assoccomm(x, (ac, 1, (ac, 2, 3))))) == exp_res + + assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( + (2, 3), + (3, 2), + ) + + assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 1, 1), ("other_op", 1, 1))) == () + + assert run(0, x, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) + + # Both arguments unground + op_lv = var() + z = var() + res = run(4, (x, y), eq_assoccomm(x, y)) + exp_res_form = ( + (etuple(op_lv, x, y), etuple(op_lv, y, x)), + (y, y), + (etuple(etuple(op_lv, x, y)), etuple(etuple(op_lv, y, x)),), + (etuple(op_lv, x, y, z), etuple(op_lv, etuple(op_lv, x, y), z),), + ) + + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert ( + op_lv not in s + or (s[op_lv],) in associative.facts + or (s[op_lv],) in commutative.facts + ) + assert s is not False, (a, b) + assert all(isvar(i) for i in reify((x, y, z), s)) + + +def test_assoccomm_algebra(): + + add = "add" + mul = "mul" + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) + + x, y = var(), var() + + pattern = (mul, (add, 1, x), y) # (1 + x) * y + expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) + + assert run(0, (x, y), eq_assoccomm(pattern, expr)) == ((3, 2),) + + +def test_assoccomm_objects(): + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() -def test_eq_comm_object(): - x = var("x") fact(commutative, Add) fact(associative, Add) - assert run(0, x, eq_comm(add(1, 2, 3), add(3, 1, x))) == (2,) + x = var() - assert set(run(0, x, eq_comm(add(1, 2), x))) == set((add(1, 2), add(2, 1))) + assert run(0, True, eq_assoccomm(add(1, 2, 3), add(3, 1, 2))) == (True,) + assert run(0, x, eq_assoccomm(add(1, 2, 3), add(1, 2, x))) == (3,) + assert run(0, x, eq_assoccomm(add(1, 2, 3), add(x, 2, 1))) == (3,) - assert set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x)))) == set( - (add(2, 3), add(3, 2)) - ) + v = add(1, 2, 3) + with variables(v): + x = add(5, 6) + # TODO: There are two more cases here, but they're in tuple form. + # See `test_eq_comm_object`. + assert x in run(0, v, eq_assoccomm(v, x))