-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #189 from eric-wieser/numba-extension
Add primitive numba support for Layout and MultiVector
- Loading branch information
Showing
5 changed files
with
251 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from ._multivector import MultiVectorType | ||
from ._layout import LayoutType |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numba | ||
import numba.extending | ||
try: | ||
# module locations as of numba 0.49.0 | ||
from numba.core import cgutils, types | ||
from numba.core.imputils import lower_constant | ||
except ImportError: | ||
# module locations prior to numba 0.49.0 | ||
from numba import cgutils, types | ||
from numba.targets.imputils import lower_constant | ||
|
||
from .._layout import Layout | ||
|
||
|
||
opaque_layout = types.Opaque('Opaque(Layout)') | ||
|
||
|
||
class LayoutType(types.Type): | ||
def __init__(self): | ||
super().__init__("LayoutType") | ||
|
||
|
||
@numba.extending.register_model(LayoutType) | ||
class LayoutModel(numba.extending.models.StructModel): | ||
def __init__(self, dmm, fe_typ): | ||
members = [ | ||
('obj', opaque_layout), | ||
] | ||
super().__init__(dmm, fe_typ, members) | ||
|
||
|
||
@numba.extending.typeof_impl.register(Layout) | ||
def _typeof_Layout(val: Layout, c) -> LayoutType: | ||
return LayoutType() | ||
|
||
|
||
# Derived from the `Dispatcher` boxing | ||
|
||
@lower_constant(LayoutType) | ||
def lower_constant_dispatcher(context, builder, typ, pyval): | ||
layout = cgutils.create_struct_proxy(typ)(context, builder) | ||
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__) | ||
return layout._getvalue() | ||
|
||
|
||
@numba.extending.unbox(LayoutType) | ||
def unbox_Layout(typ, obj, context): | ||
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder) | ||
layout.obj = obj | ||
return numba.extending.NativeValue(layout._getvalue()) | ||
|
||
|
||
@numba.extending.box(LayoutType) | ||
def box_Layout(typ, val, context): | ||
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val) | ||
obj = val.obj | ||
context.pyapi.incref(obj) | ||
return obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
""" | ||
Numba support for MultiVector objects. | ||
For now, this just supports .value wrapping / unwrapping | ||
""" | ||
import numpy as np | ||
import numba | ||
from numba.extending import NativeValue | ||
|
||
try: | ||
# module locations as of numba 0.49.0 | ||
import numba.np.numpy_support as _numpy_support | ||
from numba.core.imputils import impl_ret_borrowed, lower_constant | ||
from numba.core import cgutils, types | ||
except ImportError: | ||
# module locations prior to numba 0.49.0 | ||
import numba.numpy_support as _numpy_support | ||
from numba.targets.imputils import impl_ret_borrowed, lower_constant | ||
from numba import cgutils, types | ||
|
||
from .._multivector import MultiVector | ||
|
||
from ._layout import LayoutType | ||
|
||
__all__ = ['MultiVectorType'] | ||
|
||
|
||
class MultiVectorType(types.Type): | ||
def __init__(self, dtype: np.dtype): | ||
assert isinstance(dtype, np.dtype) | ||
self.dtype = dtype | ||
super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype))) | ||
|
||
@property | ||
def key(self): | ||
return self.dtype | ||
|
||
@property | ||
def value_type(self): | ||
return numba.from_dtype(self.dtype)[:] | ||
|
||
@property | ||
def layout_type(self): | ||
return LayoutType() | ||
|
||
|
||
# The docs say we should use register a function to determine the numba type | ||
# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way | ||
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the | ||
# undocumented `_numba_type_` attribute, and use our own cache. In future | ||
# this may need to be a weak cache, but for now the objects are tiny anyway. | ||
_cache = {} | ||
|
||
@property | ||
def _numba_type_(self): | ||
dt = self.value.dtype | ||
try: | ||
return _cache[dt] | ||
except KeyError: | ||
ret = _cache[dt] = MultiVectorType(dtype=dt) | ||
return ret | ||
|
||
MultiVector._numba_type_ = _numba_type_ | ||
|
||
|
||
@numba.extending.register_model(MultiVectorType) | ||
class MultiVectorModel(numba.extending.models.StructModel): | ||
def __init__(self, dmm, fe_type): | ||
members = [ | ||
('layout', fe_type.layout_type), | ||
('value', fe_type.value_type), | ||
] | ||
super().__init__(dmm, fe_type, members) | ||
|
||
|
||
@numba.extending.type_callable(MultiVector) | ||
def type_MultiVector(context): | ||
def typer(layout, value): | ||
if isinstance(layout, LayoutType) and isinstance(value, types.Array): | ||
return MultiVectorType(_numpy_support.as_dtype(value.dtype)) | ||
return typer | ||
|
||
|
||
@numba.extending.lower_builtin(MultiVector, LayoutType, types.Any) | ||
def impl_MultiVector(context, builder, sig, args): | ||
typ = sig.return_type | ||
layout, value = args | ||
mv = cgutils.create_struct_proxy(typ)(context, builder) | ||
mv.layout = layout | ||
mv.value = value | ||
return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue()) | ||
|
||
|
||
@lower_constant(MultiVectorType) | ||
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector): | ||
value = context.get_constant_generic(builder, typ.value_type, pyval.value) | ||
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout) | ||
return impl_ret_borrowed( | ||
context, | ||
builder, | ||
typ, | ||
cgutils.pack_struct(builder, (layout, value)), | ||
) | ||
|
||
|
||
@numba.extending.unbox(MultiVectorType) | ||
def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue: | ||
value = c.pyapi.object_getattr_string(obj, "value") | ||
layout = c.pyapi.object_getattr_string(obj, "layout") | ||
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder) | ||
mv.layout = c.unbox(typ.layout_type, layout).value | ||
mv.value = c.unbox(typ.value_type, value).value | ||
c.pyapi.decref(value) | ||
c.pyapi.decref(layout) | ||
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) | ||
return NativeValue(mv._getvalue(), is_error=is_error) | ||
|
||
|
||
@numba.extending.box(MultiVectorType) | ||
def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector: | ||
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) | ||
mv_obj = c.box(typ.value_type, mv.value) | ||
layout_obj = c.box(typ.layout_type, mv.layout) | ||
|
||
# All the examples use `c.pyapi.unserialize(c.pyapi.serialize_object(MultiVector))` here. | ||
# Doing so is much slower, as it incurs pickle. This is probably safe. | ||
class_obj_ptr = c.context.add_dynamic_addr(c.builder, id(MultiVector), info=MultiVector.__name__) | ||
class_obj = c.builder.bitcast(class_obj_ptr, c.pyapi.pyobj) | ||
res = c.pyapi.call_function_objargs(class_obj, (layout_obj, mv_obj)) | ||
c.pyapi.decref(mv_obj) | ||
c.pyapi.decref(layout_obj) | ||
return res | ||
|
||
|
||
numba.extending.make_attribute_wrapper(MultiVectorType, 'value', 'value') | ||
numba.extending.make_attribute_wrapper(MultiVectorType, 'layout', 'layout') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numba | ||
|
||
from clifford.g3c import layout, e1, e2 | ||
import clifford as cf | ||
|
||
|
||
@numba.njit | ||
def identity(x): | ||
return x | ||
|
||
|
||
class TestBasic: | ||
""" Test very simple construction and field access """ | ||
|
||
def test_roundtrip_layout(self): | ||
layout_r = identity(layout) | ||
assert type(layout_r) is type(layout) | ||
assert layout_r is layout | ||
|
||
def test_roundtrip_mv(self): | ||
e1_r = identity(e1) | ||
assert type(e1_r) is type(e1_r) | ||
|
||
# mvs are values, and not preserved by identity | ||
assert e1_r.layout is e1.layout | ||
assert e1_r == e1 | ||
|
||
def test_piecewise_construction(self): | ||
@numba.njit | ||
def negate(a): | ||
return cf.MultiVector(a.layout, -a.value) | ||
|
||
n_e1 = negate(e1) | ||
assert n_e1.layout is e1.layout | ||
assert n_e1 == -e1 | ||
|
||
@numba.njit | ||
def add(a, b): | ||
return cf.MultiVector(a.layout, a.value + b.value) | ||
|
||
ab = add(e1, e2) | ||
assert ab == e1 + e2 | ||
assert ab.layout is e1.layout | ||
|
||
def test_constant_multivector(self): | ||
@numba.njit | ||
def add_e1(a): | ||
return cf.MultiVector(a.layout, a.value + e1.value) | ||
|
||
assert add_e1(e2) == e1 + e2 |