Skip to content

Commit

Permalink
numba: Add support for some more operators (#333)
Browse files Browse the repository at this point in the history
This adds:
* MultiVector / scalar
* -MultiVector
* +MultiVector
* ~MultiVector
  • Loading branch information
eric-wieser committed Jun 11, 2020
1 parent 2550e83 commit 1881dac
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
34 changes: 34 additions & 0 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,37 @@ def impl(a, b):
def impl(a, b):
return a.layout.MultiVector(np.zeros_like(a.value, dtype=ret_type))
return impl


@numba.extending.overload(operator.truediv)
def ga_truediv(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
def impl(a, b):
return a.layout.MultiVector(a.value / b)
return impl
# TODO: implement inversion for the other pairs


@numba.extending.overload(operator.invert)
def ga_invert(a):
if isinstance(a, MultiVectorType):
adjoint_func = a.layout_type.obj.adjoint_func
def impl(a):
return a.layout.MultiVector(adjoint_func(a.value))
return impl


@numba.extending.overload(operator.pos)
def ga_pos(a):
if isinstance(a, MultiVectorType):
def impl(a):
return a.layout.MultiVector(a.value.copy())
return impl


@numba.extending.overload(operator.neg)
def ga_neg(a):
if isinstance(a, MultiVectorType):
def impl(a):
return a.layout.MultiVector(-a.value)
return impl
38 changes: 36 additions & 2 deletions clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def double(a):
assert double(e2) == 2 * e2


class TestOverloads:
class TestOperators:
@pytest.mark.parametrize("op", [
pytest.param(getattr(operator, op), id=op)
for op in ['add', 'sub', 'mul', 'xor', 'or_']
Expand All @@ -68,7 +68,7 @@ class TestOverloads:
(0.5, 0.5 * e1), (0.5 * e1, 0.5),
(e1, 0.5), (0.5, e1),
(1, 0.5*e1), (0.5*e1, 1)])
def test_overload(self, op, a, b):
def test_binary(self, op, a, b):
@numba.njit
def overload(a, b):
return op(a, b)
Expand All @@ -81,3 +81,37 @@ def overload(a, b):
# can't directly compare the dtypes. We only care that the float / int
# state is kept anyway.
assert ab.value.dtype.kind == ab_alt.value.dtype.kind

# `op` is not parametrized, for simplicity we only support MultiVector / scalar.
@pytest.mark.parametrize("a,b", [(e1, 2), (2.0*e1, 2)])
def test_truediv(self, a, b):
op = operator.truediv

@numba.njit
def overload(a, b):
return op(a, b)

ab = op(a, b)
ab_alt = overload(a, b)
assert ab == ab_alt
assert ab.layout is ab_alt.layout
# numba disagrees with numpy about what type `int` is on windows, so we
# can't directly compare the dtypes. We only care that the float / int
# state is kept anyway.
assert ab.value.dtype.kind == ab_alt.value.dtype.kind

@pytest.mark.parametrize("op", [
pytest.param(getattr(operator, op), id=op)
for op in ['pos', 'neg', 'invert']
])
@pytest.mark.parametrize("a", [layout.scalar, e1, 0.5*e1, e1^e2, 1 + (e1^e2)])
def test_unary(self, op, a):
@numba.njit
def overload(a):
return op(a)

ret = op(a)
ret_alt = overload(a)
assert ret == ret_alt
assert ret.layout is ret_alt.layout
assert ret.value.dtype == ret_alt.value.dtype

0 comments on commit 1881dac

Please sign in to comment.