Skip to content

Commit

Permalink
Merge pull request #189 from eric-wieser/numba-extension
Browse files Browse the repository at this point in the history
Add primitive numba support for Layout and MultiVector
  • Loading branch information
eric-wieser committed Jun 10, 2020
2 parents bf74bad + 8f4f247 commit d45b964
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 2 deletions.
7 changes: 5 additions & 2 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@

# Major library imports.
import numpy as np
import numba
import numba as _numba # to avoid clashing with clifford.numba
import sparse
try:
from numba.np import numpy_support as _numpy_support
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_mult_function(mt: sparse.COO, gradeList,
return _get_mult_function_runtime_sparse(mt)


def _get_mult_function_result_type(a: numba.types.Type, b: numba.types.Type, mt: np.dtype):
def _get_mult_function_result_type(a: _numba.types.Type, b: _numba.types.Type, mt: np.dtype):
a_dt = _numpy_support.as_dtype(getattr(a, 'dtype', a))
b_dt = _numpy_support.as_dtype(getattr(b, 'dtype', b))
return np.result_type(a_dt, mt, b_dt)
Expand Down Expand Up @@ -325,6 +325,9 @@ def val_get_right_gmt_matrix(mt: sparse.COO, x):
from ._layout_helpers import BasisVectorIds, BasisBladeOrder # noqa: F401
from ._mvarray import MVArray, array # noqa: F401
from ._frame import Frame # noqa: F401

# this registers the extension type
from . import numba # noqa: F401
from ._blademap import BladeMap # noqa: F401


Expand Down
2 changes: 2 additions & 0 deletions clifford/numba/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._multivector import MultiVectorType
from ._layout import LayoutType
58 changes: 58 additions & 0 deletions clifford/numba/_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numba
import numba.extending
try:
# module locations as of numba 0.49.0
from numba.core import cgutils, types
from numba.core.imputils import lower_constant
except ImportError:
# module locations prior to numba 0.49.0
from numba import cgutils, types
from numba.targets.imputils import lower_constant

from .._layout import Layout


opaque_layout = types.Opaque('Opaque(Layout)')


class LayoutType(types.Type):
def __init__(self):
super().__init__("LayoutType")


@numba.extending.register_model(LayoutType)
class LayoutModel(numba.extending.models.StructModel):
def __init__(self, dmm, fe_typ):
members = [
('obj', opaque_layout),
]
super().__init__(dmm, fe_typ, members)


@numba.extending.typeof_impl.register(Layout)
def _typeof_Layout(val: Layout, c) -> LayoutType:
return LayoutType()


# Derived from the `Dispatcher` boxing

@lower_constant(LayoutType)
def lower_constant_dispatcher(context, builder, typ, pyval):
layout = cgutils.create_struct_proxy(typ)(context, builder)
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__)
return layout._getvalue()


@numba.extending.unbox(LayoutType)
def unbox_Layout(typ, obj, context):
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder)
layout.obj = obj
return numba.extending.NativeValue(layout._getvalue())


@numba.extending.box(LayoutType)
def box_Layout(typ, val, context):
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val)
obj = val.obj
context.pyapi.incref(obj)
return obj
136 changes: 136 additions & 0 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""
Numba support for MultiVector objects.
For now, this just supports .value wrapping / unwrapping
"""
import numpy as np
import numba
from numba.extending import NativeValue

try:
# module locations as of numba 0.49.0
import numba.np.numpy_support as _numpy_support
from numba.core.imputils import impl_ret_borrowed, lower_constant
from numba.core import cgutils, types
except ImportError:
# module locations prior to numba 0.49.0
import numba.numpy_support as _numpy_support
from numba.targets.imputils import impl_ret_borrowed, lower_constant
from numba import cgutils, types

from .._multivector import MultiVector

from ._layout import LayoutType

__all__ = ['MultiVectorType']


class MultiVectorType(types.Type):
def __init__(self, dtype: np.dtype):
assert isinstance(dtype, np.dtype)
self.dtype = dtype
super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype)))

@property
def key(self):
return self.dtype

@property
def value_type(self):
return numba.from_dtype(self.dtype)[:]

@property
def layout_type(self):
return LayoutType()


# The docs say we should use register a function to determine the numba type
# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
# undocumented `_numba_type_` attribute, and use our own cache. In future
# this may need to be a weak cache, but for now the objects are tiny anyway.
_cache = {}

@property
def _numba_type_(self):
dt = self.value.dtype
try:
return _cache[dt]
except KeyError:
ret = _cache[dt] = MultiVectorType(dtype=dt)
return ret

MultiVector._numba_type_ = _numba_type_


@numba.extending.register_model(MultiVectorType)
class MultiVectorModel(numba.extending.models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('layout', fe_type.layout_type),
('value', fe_type.value_type),
]
super().__init__(dmm, fe_type, members)


@numba.extending.type_callable(MultiVector)
def type_MultiVector(context):
def typer(layout, value):
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
return MultiVectorType(_numpy_support.as_dtype(value.dtype))
return typer


@numba.extending.lower_builtin(MultiVector, LayoutType, types.Any)
def impl_MultiVector(context, builder, sig, args):
typ = sig.return_type
layout, value = args
mv = cgutils.create_struct_proxy(typ)(context, builder)
mv.layout = layout
mv.value = value
return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue())


@lower_constant(MultiVectorType)
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector):
value = context.get_constant_generic(builder, typ.value_type, pyval.value)
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
return impl_ret_borrowed(
context,
builder,
typ,
cgutils.pack_struct(builder, (layout, value)),
)


@numba.extending.unbox(MultiVectorType)
def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue:
value = c.pyapi.object_getattr_string(obj, "value")
layout = c.pyapi.object_getattr_string(obj, "layout")
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder)
mv.layout = c.unbox(typ.layout_type, layout).value
mv.value = c.unbox(typ.value_type, value).value
c.pyapi.decref(value)
c.pyapi.decref(layout)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(mv._getvalue(), is_error=is_error)


@numba.extending.box(MultiVectorType)
def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector:
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
mv_obj = c.box(typ.value_type, mv.value)
layout_obj = c.box(typ.layout_type, mv.layout)

# All the examples use `c.pyapi.unserialize(c.pyapi.serialize_object(MultiVector))` here.
# Doing so is much slower, as it incurs pickle. This is probably safe.
class_obj_ptr = c.context.add_dynamic_addr(c.builder, id(MultiVector), info=MultiVector.__name__)
class_obj = c.builder.bitcast(class_obj_ptr, c.pyapi.pyobj)
res = c.pyapi.call_function_objargs(class_obj, (layout_obj, mv_obj))
c.pyapi.decref(mv_obj)
c.pyapi.decref(layout_obj)
return res


numba.extending.make_attribute_wrapper(MultiVectorType, 'value', 'value')
numba.extending.make_attribute_wrapper(MultiVectorType, 'layout', 'layout')
50 changes: 50 additions & 0 deletions clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numba

from clifford.g3c import layout, e1, e2
import clifford as cf


@numba.njit
def identity(x):
return x


class TestBasic:
""" Test very simple construction and field access """

def test_roundtrip_layout(self):
layout_r = identity(layout)
assert type(layout_r) is type(layout)
assert layout_r is layout

def test_roundtrip_mv(self):
e1_r = identity(e1)
assert type(e1_r) is type(e1_r)

# mvs are values, and not preserved by identity
assert e1_r.layout is e1.layout
assert e1_r == e1

def test_piecewise_construction(self):
@numba.njit
def negate(a):
return cf.MultiVector(a.layout, -a.value)

n_e1 = negate(e1)
assert n_e1.layout is e1.layout
assert n_e1 == -e1

@numba.njit
def add(a, b):
return cf.MultiVector(a.layout, a.value + b.value)

ab = add(e1, e2)
assert ab == e1 + e2
assert ab.layout is e1.layout

def test_constant_multivector(self):
@numba.njit
def add_e1(a):
return cf.MultiVector(a.layout, a.value + e1.value)

assert add_e1(e2) == e1 + e2

0 comments on commit d45b964

Please sign in to comment.