Skip to content

Commit

Permalink
Merge pull request #336 from eric-wieser/project-jit
Browse files Browse the repository at this point in the history
numba: Add support for the projection operator
  • Loading branch information
eric-wieser committed Jun 19, 2020
2 parents bb0e2d1 + 581a135 commit a00a872
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ x-clifford-templates:
pip install . --prefer-binary;
fi
# always install with pip, conda has too old a version
- pip install pytest pytest-cov pytest-benchmark
- pip install --upgrade pytest pytest-cov pytest-benchmark
- pip install codecov
script:
- |
Expand Down
25 changes: 25 additions & 0 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .._multivector import MultiVector

from ._layout import LayoutType
from ._overload_call import overload_call

__all__ = ['MultiVectorType']

Expand Down Expand Up @@ -281,3 +282,27 @@ def ga_neg(a):
def impl(a):
return a.layout.MultiVector(-a.value)
return impl


@overload_call(MultiVectorType)
def ga_call(self, arg):
# grade projection
grades = self.layout_type.obj._basis_blade_order.grades
if isinstance(arg, types.IntegerLiteral):
# Optimized case where the mask can be computed at compile-time.
# using `nonzero` makes the resulting array smaller.
inds, = (grades == arg.literal_value).nonzero()
def impl(self, arg):
mv = self.layout.MultiVector(np.zeros_like(self.value))
mv.value[inds] = self.value[inds]
return mv
return impl
elif isinstance(arg, types.Integer):
# runtime grade selection - should be less common
def impl(self, arg):
# probably faster to not call nonzero here
inds = grades == arg
mv = self.layout.MultiVector(np.zeros_like(self.value))
mv.value[inds] = self.value[inds]
return mv
return impl
95 changes: 95 additions & 0 deletions clifford/numba/_overload_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Numba support for overloading the `__call__` operator.
This is a workaround until https://github.com/numba/numba/issues/5885 is
resolved.
"""
import numba
import numba.extending

try:
# module locations as of numba 0.49.0
from numba.core.typing.templates import (
AbstractTemplate, _OverloadAttributeTemplate, make_overload_attribute_template)
from numba.core import types
except ImportError:
# module locations prior to numba 0.49.0
from numba.typing.templates import (
AbstractTemplate, _OverloadAttributeTemplate, make_overload_attribute_template)
from numba import types

__all__ = ['overload_call']


class _OverloadCallTemplate(_OverloadAttributeTemplate):
"""
Modified version of _OverloadMethodTemplate for overloading `__call__`.
When typing, numba requires a `__call__` attribute to be provided as a
`BoundFunction` instance.
When lowering, the code in `numba.core.base.BaseContext.get_function`
expects to find the implementation under the key `NumbaType` - but
overload_method uses the key `(NumbaType, '__call__')`.
The only change in this class is to fix up they keys.
"""
is_method = True

@classmethod
def do_class_init(cls):
"""
Register generic method implementation.
"""

# this line is changed for __call__
@numba.extending.lower_builtin(cls.key, cls.key, types.VarArg(types.Any))
def method_impl(context, builder, sig, args):
typ = sig.args[0]
typing_context = context.typing_context
fnty = cls._get_function_type(typing_context, typ)
sig = cls._get_signature(typing_context, fnty, sig.args, {})
call = context.get_function(fnty, sig)
# Link dependent library
context.add_linking_libs(getattr(call, 'libs', ()))
return call(builder, args)

def _resolve(self, typ, attr):
if self._attr != attr:
return None

assert isinstance(typ, self.key)

class MethodTemplate(AbstractTemplate):
key = self.key # this line is changed for __call__
_inline = self._inline
_overload_func = staticmethod(self._overload_func)
_inline_overloads = self._inline_overloads

def generic(_, args, kws):
args = (typ,) + tuple(args)
fnty = self._get_function_type(self.context, typ)
sig = self._get_signature(self.context, fnty, args, kws)
sig = sig.replace(pysig=numba.extending.utils.pysignature(self._overload_func))
for template in fnty.templates:
self._inline_overloads.update(template._inline_overloads)
if sig is not None:
return sig.as_method()

return types.BoundFunction(MethodTemplate, typ)



def overload_call(typ, **kwargs):

def decorate(overload_func):
template = make_overload_attribute_template(
typ, '__call__', overload_func,
inline=kwargs.get('inline', 'never'),
base=_OverloadCallTemplate
)
numba.extending.infer_getattr(template)
numba.extending.overload(overload_func, **kwargs)(overload_func)
return overload_func

return decorate
24 changes: 23 additions & 1 deletion clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import operator

from clifford.g3c import layout, e1, e2
from clifford.g3c import layout, e1, e2, e3, e4, e5
import clifford as cf
import pytest

Expand Down Expand Up @@ -115,3 +115,25 @@ def overload(a):
assert ret == ret_alt
assert ret.layout is ret_alt.layout
assert ret.value.dtype == ret_alt.value.dtype

# We have a special overload for literal arguments for speed
def literal_grade_func(a, grade):
@numba.njit
def grade_func(a):
return a(grade)
return grade_func(a)

@numba.njit
def grade_func(a, grade):
return a(grade)

@pytest.mark.parametrize('func', [literal_grade_func, grade_func])
def test_grade_projection(self, func):
a = 1 + e1 + (e1^e2) + (e1^e2^e3) + (e1^e2^e3^e4) + (e1^e2^e3^e4^e5)

assert func(a, 0) == 1
assert func(a, 1) == e1
assert func(a, 2) == e1^e2
assert func(a, 3) == e1^e2^e3
assert func(a, 4) == e1^e2^e3^e4
assert func(a, 5) == e1^e2^e3^e4^e5
41 changes: 11 additions & 30 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def get_line_intersection(L3, Ldd):
P = -(Pd * ninf * Pd)
imt = Pd | ninf
P_denominator = 2*(imt * imt).value[0]
return _project(P/P_denominator, 1)
return (P/P_denominator)(1)


@numba.njit
Expand All @@ -692,7 +692,7 @@ def midpoint_between_lines(L1, L2):
L3 = normalised(L1 + L2)
Ldd = normalised(L1 - L2)
S = get_line_intersection(L3, Ldd)
return normalise_n_minus_1(_project(S * ninf * S, 1))
return normalise_n_minus_1((S * ninf * S)(1))


@numba.njit
Expand Down Expand Up @@ -854,26 +854,7 @@ def random_rotation_translation_rotor(maximum_translation=10.0, maximum_angle=np
@numba.njit
@_defunct_wrapper
def project_val(val, grade):
return _project(layout.MultiVector(val), grade).value


@numba.njit
def _project(mv, grade):
""" fast grade projection """
output = np.zeros_like(mv.value)
if grade == 0:
output[0] = mv.value[0]
elif grade == 1:
output[1:6] = mv.value[1:6]
elif grade == 2:
output[6:16] = mv.value[6:16]
elif grade == 3:
output[16:26] = mv.value[16:26]
elif grade == 4:
output[26:31] = mv.value[26:31]
elif grade == 5:
output[31] = mv.value[31]
return mv.layout.MultiVector(output)
return layout.MultiVector(val)(grade).value


def random_conformal_point(l_max=10):
Expand Down Expand Up @@ -1047,7 +1028,7 @@ def dorst_norm_val(sigma_val):
@numba.jit
def dorst_norm(sigma):
""" Square Root of Rotors - Implements the norm of a rotor"""
sigma_4 = _project(sigma, 4)
sigma_4 = sigma(4)
sqrd_ans = sigma.value[0] ** 2 - (sigma_4 * sigma_4).value[0]
return math.sqrt(sqrd_ans)

Expand Down Expand Up @@ -1155,7 +1136,7 @@ def val_annihilate_k(K_val, C_val):
@numba.njit
def annihilate_k(K, C):
""" Removes K from C = KX via (K[0] - K[4])*C """
k_4 = K.value[0] - _project(K, 4)
k_4 = K.value[0] - K(4)
return normalised(k_4 * C)


Expand Down Expand Up @@ -1352,8 +1333,8 @@ def motor_between_rounds(X1, X2):
R = rotor_between_objects_root(F1, F2)
X3 = apply_rotor(X1, R)

C1 = normalise_n_minus_1(_project(X3 * ninf * X3, 1))
C2 = normalise_n_minus_1(_project(X2 * ninf * X2, 1))
C1 = normalise_n_minus_1((X3 * ninf * X3)(1))
C2 = normalise_n_minus_1((X2 * ninf * X2)(1))

t = layout.MultiVector(np.zeros(32))
t.value[1:4] = (C2 - C1).value[1:4]
Expand Down Expand Up @@ -1441,14 +1422,14 @@ def rotor_between_objects_root(X1, X2):
if gamma > 0:
C = 1 + gamma*(X2 * X1)
if abs(C.value[0]) < 1E-6:
R = normalised(_project(I5eo * X21, 2))
R = normalised((I5eo * X21)(2))
return normalised(R * rotor_between_objects_root(X1, -X2))
return normalised(pos_twiddle_root(C)[0])
else:
C = 1 - X21
if abs(C.value[0]) < 1E-6:
R = _project(I5eo * X21, 2)
R = normalised(_project(R * biv3dmask, 2))
R = (I5eo * X21)(2)
R = normalised((R * biv3dmask)(2))
R2 = normalised(rotor_between_objects_root(apply_rotor(X1, R), X2))
return normalised(R2 * R)
else:
Expand Down Expand Up @@ -1570,7 +1551,7 @@ def rotor_between_lines(L1, L2):
L21 = layout.MultiVector(sparse_line_gmt(L2.value, L1.value))
L12 = layout.MultiVector(sparse_line_gmt(L1.value, L2.value))
K = L21 + L12 + 2.0
beta = _project(K, 4)
beta = K(4)
alpha = 2 * K.value[0]

denominator = np.sqrt(alpha / 2)
Expand Down

0 comments on commit a00a872

Please sign in to comment.