Skip to content

Commit

Permalink
numba: provide jit overloads for basic multivector binary operators (#…
Browse files Browse the repository at this point in the history
…332)

This provides all of the `-+*^|` operators.
  • Loading branch information
hugohadfield committed Jun 10, 2020
1 parent 6bbd6cd commit 2550e83
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
112 changes: 112 additions & 0 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
For now, this just supports .value wrapping / unwrapping
"""
import numba
import operator
import numpy as np
from numba.extending import NativeValue
import llvmlite.ir

Expand Down Expand Up @@ -135,3 +137,113 @@ def box_MultiVector(typ: MultiVectorType, val: llvmlite.ir.Value, c) -> MultiVec

numba.extending.make_attribute_wrapper(MultiVectorType, 'value', 'value')
numba.extending.make_attribute_wrapper(MultiVectorType, 'layout', 'layout')


@numba.extending.overload(operator.add)
def ga_add(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, MultiVectorType):
if a.layout_type != b.layout_type:
raise numba.TypingError("MultiVector objects belong to different layouts")
def impl(a, b):
return a.layout.MultiVector(a.value + b.value)
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, MultiVectorType):
scalar_index = b.layout_type.obj._basis_blade_order.bitmap_to_index[0]
ret_type = np.result_type(_numpy_support.as_dtype(a), _numpy_support.as_dtype(b.value_type.dtype))
def impl(a, b):
op = b.value.astype(ret_type)
op[scalar_index] += a
return b.layout.MultiVector(op)
return impl
elif isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
scalar_index = a.layout_type.obj._basis_blade_order.bitmap_to_index[0]
ret_type = np.result_type(_numpy_support.as_dtype(a.value_type.dtype), _numpy_support.as_dtype(b))
def impl(a, b):
op = a.value.astype(ret_type)
op[scalar_index] += b
return a.layout.MultiVector(op)
return impl


@numba.extending.overload(operator.sub)
def ga_sub(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, MultiVectorType):
if a.layout_type != b.layout_type:
raise numba.TypingError("MultiVector objects belong to different layouts")
def impl(a, b):
return a.layout.MultiVector(a.value - b.value)
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, MultiVectorType):
scalar_index = b.layout_type.obj._basis_blade_order.bitmap_to_index[0]
ret_type = np.result_type(_numpy_support.as_dtype(a), _numpy_support.as_dtype(b.value_type.dtype))
def impl(a, b):
op = -b.value.astype(ret_type)
op[scalar_index] += a
return b.layout.MultiVector(op)
return impl
elif isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
scalar_index = a.layout_type.obj._basis_blade_order.bitmap_to_index[0]
ret_type = np.result_type(_numpy_support.as_dtype(a.value_type.dtype), _numpy_support.as_dtype(b))
def impl(a, b):
op = a.value.astype(ret_type)
op[scalar_index] -= b
return a.layout.MultiVector(op)
return impl


@numba.extending.overload(operator.mul)
def ga_mul(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, MultiVectorType):
if a.layout_type != b.layout_type:
raise numba.TypingError("MultiVector objects belong to different layouts")
gmt_func = a.layout_type.obj.gmt_func
def impl(a, b):
return a.layout.MultiVector(gmt_func(a.value, b.value))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, MultiVectorType):
def impl(a, b):
return b.layout.MultiVector(a*b.value)
return impl
elif isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
def impl(a, b):
return a.layout.MultiVector(a.value*b)
return impl


@numba.extending.overload(operator.xor)
def ga_xor(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, MultiVectorType):
if a.layout_type != b.layout_type:
raise numba.TypingError("MultiVector objects belong to different layouts")
omt_func = a.layout_type.obj.omt_func
def impl(a, b):
return a.layout.MultiVector(omt_func(a.value, b.value))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, MultiVectorType):
def impl(a, b):
return b.layout.MultiVector(b.value*a)
return impl
elif isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
def impl(a, b):
return a.layout.MultiVector(a.value*b)
return impl

@numba.extending.overload(operator.or_)
def ga_or(a, b):
if isinstance(a, MultiVectorType) and isinstance(b, MultiVectorType):
if a.layout_type != b.layout_type:
raise numba.TypingError("MultiVector objects belong to different layouts")
imt_func = a.layout_type.obj.imt_func
def impl(a, b):
return a.layout.MultiVector(imt_func(a.value, b.value))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, MultiVectorType):
ret_type = np.result_type(_numpy_support.as_dtype(a), _numpy_support.as_dtype(b.value_type.dtype))
def impl(a, b):
return b.layout.MultiVector(np.zeros_like(b.value, dtype=ret_type))
return impl
elif isinstance(a, MultiVectorType) and isinstance(b, types.abstract.Number):
ret_type = np.result_type(_numpy_support.as_dtype(a.value_type.dtype), _numpy_support.as_dtype(b))
def impl(a, b):
return a.layout.MultiVector(np.zeros_like(a.value, dtype=ret_type))
return impl
26 changes: 26 additions & 0 deletions clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numba
import operator

from clifford.g3c import layout, e1, e2
import clifford as cf
import pytest


@numba.njit
Expand Down Expand Up @@ -55,3 +57,27 @@ def double(a):
return a.layout.MultiVector(a.value*2)

assert double(e2) == 2 * e2


class TestOverloads:
@pytest.mark.parametrize("op", [
pytest.param(getattr(operator, op), id=op)
for op in ['add', 'sub', 'mul', 'xor', 'or_']
])
@pytest.mark.parametrize("a,b", [(e1, e2), (1, e1), (e1, 1),
(0.5, 0.5 * e1), (0.5 * e1, 0.5),
(e1, 0.5), (0.5, e1),
(1, 0.5*e1), (0.5*e1, 1)])
def test_overload(self, op, a, b):
@numba.njit
def overload(a, b):
return op(a, b)

ab = op(a, b)
ab_alt = overload(a, b)
assert ab == ab_alt
assert ab.layout is ab_alt.layout
# numba disagrees with numpy about what type `int` is on windows, so we
# can't directly compare the dtypes. We only care that the float / int
# state is kept anyway.
assert ab.value.dtype.kind == ab_alt.value.dtype.kind

0 comments on commit 2550e83

Please sign in to comment.