Skip to content

Commit

Permalink
Add more Mapping, Set types and use dispatch register
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 26, 2019
1 parent baca94f commit 88326cb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 36 deletions.
24 changes: 17 additions & 7 deletions tests/test_core.py
@@ -1,4 +1,5 @@
from types import MappingProxyType
from collections import OrderedDict

from unification import var
from unification.core import reify, unify, unground_lvars
Expand All @@ -15,11 +16,19 @@ def test_reify():
assert reify(z, MappingProxyType(s)) == (1, 2)


def test_reify_dict():
def test_reify_Mapping():
x, y = var(), var()
s = {x: 2, y: 4}
e = {1: x, 3: {5: y}}
assert reify(e, s) == {1: 2, 3: {5: 4}}
e = [(1, x), (3, {5: y})]
expected_res = [(1, 2), (3, {5: 4})]
assert reify(dict(e), s) == dict(expected_res)
assert reify(OrderedDict(e), s) == OrderedDict(expected_res)


def test_reify_Set():
x, y = var(), var()
assert reify({1, 2, x, y}, {x: 3}) == {1, 2, 3, y}
assert reify(frozenset({1, 2, x, y}), {x: 3}) == frozenset({1, 2, 3, y})


def test_reify_list():
Expand Down Expand Up @@ -72,10 +81,11 @@ def test_unify_seq():


def test_unify_set():
x = var("x")
assert unify(set((1, 2)), set((1, 2)), {}) == {}
assert unify(set((1, x)), set((1, 2)), {}) == {x: 2}
assert unify(set((x, 2)), set((1, 2)), {}) == {x: 1}
x, y = var(), var()
assert unify({1, 2}, {1, 2}, {}) == {}
assert unify({1, x}, {1, 2}, {}) == {x: 2}
assert unify({x, 2}, {1, 2}, {}) == {x: 1}
assert unify({1, y, x}, {2, 1}, {x: 2}) is False


def test_unify_dict():
Expand Down
62 changes: 33 additions & 29 deletions unification/core.py
@@ -1,39 +1,43 @@
from toolz import assoc
from operator import length_hint
from collections.abc import Iterator, Mapping
from functools import partial
from collections import OrderedDict
from collections.abc import Iterator, Mapping, Set

from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch


@dispatch(Iterator, Mapping)
def _reify(t, s):
return iter(reify(arg, s) for arg in t)
@dispatch(object, Mapping)
def _reify(o, s):
return o


@dispatch(tuple, Mapping)
def _reify(t, s):
return tuple(reify(iter(t), s))
def _reify_Iterable(type_ctor, t, s):
return type_ctor(reify(a, s) for a in t)


@dispatch(list, Mapping)
def _reify(t, s):
return list(reify(iter(t), s))
for seq, ctor in (
(tuple, tuple),
(list, list),
(Iterator, iter),
(set, set),
(frozenset, frozenset),
):
_reify.add((seq, Mapping), partial(_reify_Iterable, ctor))


@dispatch(dict, Mapping)
def _reify(d, s):
return dict((k, reify(v, s)) for k, v in d.items())
def _reify_Mapping(ctor, d, s):
return ctor((k, reify(v, s)) for k, v in d.items())


@dispatch(object, Mapping)
def _reify(o, s):
return o
for seq in (dict, OrderedDict):
_reify.add((seq, Mapping), partial(_reify_Mapping, seq))


@dispatch(slice, Mapping)
def _reify(o, s):
@_reify.register(slice, Mapping)
def _reify_slice(o, s):
return slice(*reify((o.start, o.stop, o.step), s))


Expand Down Expand Up @@ -61,7 +65,7 @@ def _unify(u, v, s):
return False


def _unify_seq(u, v, s):
def _unify_Sequence(u, v, s):
len_u = length_hint(u, -1)
len_v = length_hint(v, -1)

Expand All @@ -76,19 +80,19 @@ def _unify_seq(u, v, s):


for seq in (tuple, list, Iterator):
_unify.add((seq, seq, Mapping), _unify_seq)
_unify.add((seq, seq, Mapping), _unify_Sequence)


@dispatch((set, frozenset), (set, frozenset), Mapping)
def _unify(u, v, s):
@_unify.register(Set, Set, Mapping)
def _unify_Set(u, v, s):
i = u & v
u = u - i
v = v - i
return _unify(sorted(u), sorted(v), s)
return _unify(iter(u), iter(v), s)


@dispatch(dict, dict, Mapping)
def _unify(u, v, s):
@_unify.register(Mapping, Mapping, Mapping)
def _unify_Mapping(u, v, s):
if len(u) != len(v):
return False
for key, uval in u.items():
Expand All @@ -100,8 +104,8 @@ def _unify(u, v, s):
return s


@dispatch(slice, slice, dict)
def _unify(u, v, s):
@_unify.register(slice, slice, dict)
def _unify_slice(u, v, s):
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)


Expand All @@ -124,8 +128,8 @@ def unify(u, v, s):
return _unify(u, v, s)


@dispatch(object, object)
def unify(u, v):
@unify.register(object, object)
def unify_NoMap(u, v):
return unify(u, v, {})


Expand Down

0 comments on commit 88326cb

Please sign in to comment.