diff --git a/tests/test_core.py b/tests/test_core.py index 680ce6e..5249b0b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,7 +1,8 @@ from types import MappingProxyType +from collections import OrderedDict from unification import var -from unification.core import reify, unify +from unification.core import reify, unify, unground_lvars def test_reify(): @@ -15,11 +16,19 @@ def test_reify(): assert reify(z, MappingProxyType(s)) == (1, 2) -def test_reify_dict(): +def test_reify_Mapping(): x, y = var(), var() s = {x: 2, y: 4} - e = {1: x, 3: {5: y}} - assert reify(e, s) == {1: 2, 3: {5: 4}} + e = [(1, x), (3, {5: y})] + expected_res = [(1, 2), (3, {5: 4})] + assert reify(dict(e), s) == dict(expected_res) + assert reify(OrderedDict(e), s) == OrderedDict(expected_res) + + +def test_reify_Set(): + x, y = var(), var() + assert reify({1, 2, x, y}, {x: 3}) == {1, 2, 3, y} + assert reify(frozenset({1, 2, x, y}), {x: 3}) == frozenset({1, 2, 3, y}) def test_reify_list(): @@ -37,6 +46,19 @@ def test_reify_complex(): assert reify(e, s) == {1: [2], 3: (4, 5)} +def test_unify_slice(): + x = var("x") + y = var("y") + + assert unify(slice(1), slice(1), {}) == {} + assert unify(slice(1, 2, 3), x, {}) == {x: slice(1, 2, 3)} + assert unify(slice(1, 2, None), slice(x, y), {}) == {x: 1, y: 2} + + +def test_reify_slice(): + assert reify(slice(1, var(2), 3), {var(2): 10}) == slice(1, 10, 3) + + def test_unify(): assert unify(1, 1, {}) == {} assert unify(1, 2, {}) is False @@ -59,10 +81,11 @@ def test_unify_seq(): def test_unify_set(): - x = var("x") - assert unify(set((1, 2)), set((1, 2)), {}) == {} - assert unify(set((1, x)), set((1, 2)), {}) == {x: 2} - assert unify(set((x, 2)), set((1, 2)), {}) == {x: 1} + x, y = var(), var() + assert unify({1, 2}, {1, 2}, {}) == {} + assert unify({1, x}, {1, 2}, {}) == {x: 2} + assert unify({x, 2}, {1, 2}, {}) == {x: 1} + assert unify({1, y, x}, {2, 1}, {x: 2}) is False def test_unify_dict(): @@ -80,3 +103,15 @@ def test_unify_complex(): assert unify({1: (2, 3)}, {1: (2, var(5))}, {}) == {var(5): 3} assert unify({1: [2, 3]}, {1: [2, var(5)]}, {}) == {var(5): 3} + + +def test_unground_lvars(): + assert unground_lvars((1, 2), {}) == set() + assert unground_lvars((1, [var("a"), [var("b"), 2], 3]), {}) == {var("a"), var("b")} + assert unground_lvars((1, [var("a"), [var("b"), 2], 3]), {var("a"): 4}) == { + var("b") + } + assert ( + unground_lvars((1, [var("a"), [var("b"), 2], 3]), {var("a"): 4, var("b"): 5}) + == set() + ) diff --git a/tests/test_more.py b/tests/test_more.py index ce676e3..cd1b808 100644 --- a/tests/test_more.py +++ b/tests/test_more.py @@ -61,20 +61,6 @@ def test_objects_full(): ) == Foo(1, Bar(Foo(2, 3))) -def test_unify_slice(): - x = var("x") - y = var("y") - - assert unify(slice(1), slice(1), {}) == {} - assert unify(slice(1, 2, 3), x, {}) == {x: slice(1, 2, 3)} - assert unify(slice(1, 2, None), slice(x, y), {}) == {x: 1, y: 2} - - -def test_reify_slice(): - x = var("x") - assert reify(slice(1, var(2), 3), {var(2): 10}) == slice(1, 10, 3) - - @unifiable class A(object): def __init__(self, a, b): @@ -105,7 +91,7 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ -def test_unifiable(): +def test_unifiable_slots(): x = var("x") f = Aslot(1, 2) g = Aslot(1, x) diff --git a/unification/core.py b/unification/core.py index ce8a7b8..2bb2916 100644 --- a/unification/core.py +++ b/unification/core.py @@ -1,37 +1,47 @@ from toolz import assoc from operator import length_hint -from collections.abc import Iterator, Mapping +from functools import partial +from collections import OrderedDict +from collections.abc import Iterator, Mapping, Set from .utils import transitive_get as walk from .variable import isvar from .dispatch import dispatch -@dispatch(Iterator, Mapping) -def _reify(t, s): - return iter(reify(arg, s) for arg in t) +@dispatch(object, Mapping) +def _reify(o, s): + return o -@dispatch(tuple, Mapping) -def _reify(t, s): - return tuple(reify(iter(t), s)) +def _reify_Iterable(type_ctor, t, s): + return type_ctor(reify(a, s) for a in t) -@dispatch(list, Mapping) -def _reify(t, s): - return list(reify(iter(t), s)) +for seq, ctor in ( + (tuple, tuple), + (list, list), + (Iterator, iter), + (set, set), + (frozenset, frozenset), +): + _reify.add((seq, Mapping), partial(_reify_Iterable, ctor)) -@dispatch(dict, Mapping) -def _reify(d, s): - return dict((k, reify(v, s)) for k, v in d.items()) +def _reify_Mapping(ctor, d, s): + return ctor((k, reify(v, s)) for k, v in d.items()) -@dispatch(object, Mapping) -def _reify(o, s): - return o # catch all, just return the object +for seq in (dict, OrderedDict): + _reify.add((seq, Mapping), partial(_reify_Mapping, seq)) + +@_reify.register(slice, Mapping) +def _reify_slice(o, s): + return slice(*reify((o.start, o.stop, o.step), s)) + +@dispatch(object, Mapping) def reify(e, s): """Replace variables of an expression with their substitutions. @@ -46,16 +56,16 @@ def reify(e, s): {1: 2, 3: (4, 5)} """ if isvar(e): - return reify(s[e], s) if e in s else e + e = walk(e, s) return _reify(e, s) @dispatch(object, object, Mapping) def _unify(u, v, s): - return False # catch all + return False -def _unify_seq(u, v, s): +def _unify_Sequence(u, v, s): len_u = length_hint(u, -1) len_v = length_hint(v, -1) @@ -70,19 +80,19 @@ def _unify_seq(u, v, s): for seq in (tuple, list, Iterator): - _unify.add((seq, seq, Mapping), _unify_seq) + _unify.add((seq, seq, Mapping), _unify_Sequence) -@dispatch((set, frozenset), (set, frozenset), Mapping) -def _unify(u, v, s): +@_unify.register(Set, Set, Mapping) +def _unify_Set(u, v, s): i = u & v u = u - i v = v - i - return _unify(sorted(u), sorted(v), s) + return _unify(iter(u), iter(v), s) -@dispatch(dict, dict, Mapping) -def _unify(u, v, s): +@_unify.register(Mapping, Mapping, Mapping) +def _unify_Mapping(u, v, s): if len(u) != len(v): return False for key, uval in u.items(): @@ -94,6 +104,11 @@ def _unify(u, v, s): return s +@_unify.register(slice, slice, dict) +def _unify_slice(u, v, s): + return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) + + @dispatch(object, object, Mapping) def unify(u, v, s): """Find substitution so that u == v while satisfying s. @@ -113,6 +128,28 @@ def unify(u, v, s): return _unify(u, v, s) -@dispatch(object, object) -def unify(u, v): +@unify.register(object, object) +def unify_NoMap(u, v): return unify(u, v, {}) + + +def unground_lvars(u, s): + """Return the unground logic variables from a term and state.""" + + lvars = set() + _reify_object = _reify.dispatch(object, Mapping) + + def _reify_var(u, s): + nonlocal lvars + + if isvar(u): + lvars.add(u) + return u + + _reify.add((object, Mapping), _reify_var) + try: + reify(u, s) + finally: + _reify.add((object, Mapping), _reify_object) + + return lvars diff --git a/unification/more.py b/unification/more.py index a2acee2..916e7f7 100644 --- a/unification/more.py +++ b/unification/more.py @@ -1,5 +1,6 @@ -from .core import unify, reify -from .dispatch import dispatch +from collections.abc import Mapping + +from .core import unify, reify, _unify, _reify def unifiable(cls): @@ -24,8 +25,8 @@ def unifiable(cls): >>> unify(a, b, {}) {~x: 2} """ - _unify.add((cls, cls, dict), unify_object) - _reify.add((cls, dict), reify_object) + _unify.add((cls, cls, Mapping), unify_object) + _reify.add((cls, Mapping), reify_object) return cls @@ -74,12 +75,6 @@ def _reify_object_slots(o, s): return newobj -@dispatch(slice, dict) -def _reify(o, s): - """Reify a Python ``slice`` object.""" - return slice(*reify((o.start, o.stop, o.step), s)) - - def unify_object(u, v, s): """Unify two Python objects. @@ -108,9 +103,3 @@ def unify_object(u, v, s): ) else: return unify(u.__dict__, v.__dict__, s) - - -@dispatch(slice, slice, dict) -def _unify(u, v, s): - """Unify a Python ``slice`` object.""" - return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/unification/utils.py b/unification/utils.py index 317a820..6f19930 100644 --- a/unification/utils.py +++ b/unification/utils.py @@ -73,14 +73,6 @@ def reverse_dict(d): return result -def xfail(func): - try: - func() - raise Exception("XFailed test passed") # pragma:nocover - except Exception: - pass - - def freeze(d): """Freeze container to hashable a form.