diff --git a/etuples/dispatch.py b/etuples/dispatch.py index 8d4e238..c17a93c 100644 --- a/etuples/dispatch.py +++ b/etuples/dispatch.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable, Sequence, Mapping from multipledispatch import dispatch @@ -6,6 +6,22 @@ from .core import etuple, ExpressionTuple +try: + from unification.core import _reify, _unify +except ModuleNotFoundError: + pass +else: + + def _unify_ExpressionTuple(u, v, s): + return _unify(u._tuple, v._tuple, s) + + _unify.add((ExpressionTuple, ExpressionTuple, Mapping), _unify_ExpressionTuple) + + def _reify_ExpressionTuple(u, s): + return etuple(*_reify(u._tuple, s)) + + _reify.add((ExpressionTuple, Mapping), _reify_ExpressionTuple) + @dispatch(object) def rator(x): diff --git a/requirements.txt b/requirements.txt index 1342188..c6525e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -e ./ +unification ipython coveralls pydocstyle>=3.0.0 diff --git a/setup.cfg b/setup.cfg index 91cd485..e3f7b27 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,4 +28,5 @@ exclude_lines = def __repr__ raise NotImplementedError if __name__ == .__main__.: - assert False \ No newline at end of file + assert False + ModuleNotFoundError \ No newline at end of file diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 4a6cc9c..4b77936 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -1,4 +1,4 @@ -from pytest import raises +from pytest import raises, importorskip from operator import add from collections.abc import Sequence @@ -91,3 +91,21 @@ def test_etuplize(): assert etuplize(node_2) == etuple(op_1, etuple(op_2, 1, 2), 3) assert etuplize(node_2, shallow=True) == etuple(op_1, node_1, 3) + + +def test_unification(): + from cons import cons + + uni = importorskip("unification") + + var, unify, reify = uni.var, uni.unify, uni.reify + + a_lv, b_lv = var(), var() + assert unify(etuple(add, 1, 2), etuple(add, 1, 2), {}) == {} + assert unify(etuple(add, 1, 2), etuple(a_lv, 1, 2), {}) == {a_lv: add} + assert reify(etuple(a_lv, 1, 2), {a_lv: add}) == etuple(add, 1, 2) + + res = unify(etuple(add, 1, 2), cons(a_lv, b_lv), {}) + assert res == {a_lv: add, b_lv: etuple(1, 2)} + + assert reify(cons(a_lv, b_lv), res) == etuple(add, 1, 2)