Skip to content

Commit

Permalink
Merge pull request #331 from eric-wieser/layout-in-type
Browse files Browse the repository at this point in the history
numba: Change layout to be stored in type itself
  • Loading branch information
eric-wieser committed Jun 10, 2020
2 parents d45b964 + bdd0276 commit 6bbd6cd
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 62 deletions.
9 changes: 8 additions & 1 deletion clifford/_layout_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from typing import TypeVar, Generic, Sequence, Tuple, List
from typing import TypeVar, Generic, Sequence, Tuple, List, Optional
import numpy as np
import functools
import operator
Expand Down Expand Up @@ -252,3 +252,10 @@ def __reduce__(self):
return __class__, (self._n,)
else:
return __class__, (self._n, self._first_index)


def layout_short_name(layout) -> Optional[str]:
""" helper to get the short name of a layout """
if hasattr(layout, '__name__') and '__module__' in layout.__dict__:
return "{l.__module__}.{l.__name__}".format(l=layout)
return None
15 changes: 9 additions & 6 deletions clifford/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import clifford as cf
from . import general_exp
from . import _settings
from ._layout_helpers import layout_short_name


class MultiVector(object):
Expand Down Expand Up @@ -507,11 +508,12 @@ def __repr__(self) -> str:
else:
dtype_str = None

if hasattr(self.layout, '__name__') and '__module__' in self.layout.__dict__:
fmt = "{l.__module__}.{l.__name__}.MultiVector({v!r}{d})"
l_name = layout_short_name(self.layout)
args = dict(v=list(self.value), d=dtype_str)
if l_name is not None:
return "{l}.MultiVector({v!r}{d})".format(l=l_name, **args)
else:
fmt = "{l!r}.MultiVector({v!r}{d})"
return fmt.format(l=self.layout, v=list(self.value), d=dtype_str)
return "{l!r}.MultiVector({v!r}{d})".format(l=self.layout, **args)

def _repr_pretty_(self, p, cycle):
if cycle:
Expand All @@ -521,8 +523,9 @@ def _repr_pretty_(self, p, cycle):
p.text(str(self))
return

if hasattr(self.layout, '__name__') and '__module__' in self.layout.__dict__:
prefix = "{l.__module__}.{l.__name__}.MultiVector(".format(l=self.layout)
l_name = layout_short_name(self.layout)
if l_name is not None:
prefix = "{}.MultiVector(".format(l_name)
include_layout = False
else:
include_layout = True
Expand Down
77 changes: 47 additions & 30 deletions clifford/numba/_layout.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,75 @@
import numba
import numba.extending
from numba.extending import NativeValue
import llvmlite.ir
try:
# module locations as of numba 0.49.0
from numba.core import cgutils, types
from numba.core import types
from numba.core.imputils import lower_constant
except ImportError:
# module locations prior to numba 0.49.0
from numba import cgutils, types
from numba import types
from numba.targets.imputils import lower_constant

from .._layout import Layout
from .._layout import Layout, _cached_property
from .._layout_helpers import layout_short_name
from .._multivector import MultiVector


opaque_layout = types.Opaque('Opaque(Layout)')
# In future we want to store some of the layout in the type (the `order` etc),
# but store the `names` in the layout instances, so that we can reuse jitted
# functions across different basis vector names.


class LayoutType(types.Type):
def __init__(self):
super().__init__("LayoutType")
class LayoutType(types.Dummy):
def __init__(self, layout):
self.obj = layout
# cache of multivector types for this layout
self._cache = {}
layout_name = layout_short_name(layout)
if layout_name is not None:
name = "LayoutType({})".format(layout_name)
else:
name = "LayoutType({!r})".format(layout)
super().__init__(name)


@numba.extending.register_model(LayoutType)
class LayoutModel(numba.extending.models.StructModel):
def __init__(self, dmm, fe_typ):
members = [
('obj', opaque_layout),
]
super().__init__(dmm, fe_typ, members)
class LayoutModel(numba.extending.models.OpaqueModel):
pass

# The docs say we should use register a function to determine the numba type
# with `@numba.extending.typeof_impl.register(LayoutType)`, but this is way
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
# undocumented `_numba_type_` attribute, and use our own cache.

@numba.extending.typeof_impl.register(Layout)
def _typeof_Layout(val: Layout, c) -> LayoutType:
return LayoutType()
@_cached_property
def _numba_type_(self):
return LayoutType(self)

Layout._numba_type_ = _numba_type_

# Derived from the `Dispatcher` boxing

@lower_constant(LayoutType)
def lower_constant_dispatcher(context, builder, typ, pyval):
layout = cgutils.create_struct_proxy(typ)(context, builder)
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__)
return layout._getvalue()
def lower_constant_Layout(context, builder, typ: LayoutType, pyval: Layout) -> llvmlite.ir.Value:
return context.get_dummy_value()


@numba.extending.unbox(LayoutType)
def unbox_Layout(typ, obj, context):
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder)
layout.obj = obj
return numba.extending.NativeValue(layout._getvalue())
def unbox_Layout(typ: LayoutType, obj: Layout, c) -> NativeValue:
return NativeValue(c.context.get_dummy_value())

# Derived from the `Dispatcher` boxing

@numba.extending.box(LayoutType)
def box_Layout(typ, val, context):
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val)
obj = val.obj
context.pyapi.incref(obj)
def box_Layout(typ: LayoutType, val: llvmlite.ir.Value, c) -> Layout:
obj = c.context.add_dynamic_addr(c.builder, id(typ.obj), info=typ.name)
c.pyapi.incref(obj)
return obj

# methods

@numba.extending.overload_method(LayoutType, 'MultiVector')
def Layout_MultiVector(self, value):
def impl(self, value):
return MultiVector(self, value)
return impl
51 changes: 26 additions & 25 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
For now, this just supports .value wrapping / unwrapping
"""
import numpy as np
import numba
from numba.extending import NativeValue
import llvmlite.ir

try:
# module locations as of numba 0.49.0
Expand All @@ -26,38 +26,43 @@


class MultiVectorType(types.Type):
def __init__(self, dtype: np.dtype):
assert isinstance(dtype, np.dtype)
self.dtype = dtype
super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype)))
def __init__(self, layout: LayoutType, dtype: types.DType):
self.layout_type = layout
self._scalar_type = dtype
super().__init__(name='MultiVector({!r}, {!r})'.format(
self.layout_type, self._scalar_type
))

@property
def key(self):
return self.dtype
return self.layout_type, self._scalar_type

@property
def value_type(self):
return numba.from_dtype(self.dtype)[:]

@property
def layout_type(self):
return LayoutType()
return self._scalar_type[:]


# The docs say we should use register a function to determine the numba type
# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
# undocumented `_numba_type_` attribute, and use our own cache. In future
# this may need to be a weak cache, but for now the objects are tiny anyway.
_cache = {}

@property
def _numba_type_(self):
layout_type = self.layout._numba_type_

cache = layout_type._cache
dt = self.value.dtype

# now use the dtype to key that cache.
try:
return _cache[dt]
return cache[dt]
except KeyError:
ret = _cache[dt] = MultiVectorType(dtype=dt)
# Computing and hashing `dtype_type` is slow, so we do not use it as a
# hash key. The raw numpy dtype is much faster to use as a key.
dtype_type = _numpy_support.from_dtype(dt)
ret = cache[dt] = MultiVectorType(layout_type, dtype_type)
return ret

MultiVector._numba_type_ = _numba_type_
Expand All @@ -77,7 +82,7 @@ def __init__(self, dmm, fe_type):
def type_MultiVector(context):
def typer(layout, value):
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
return MultiVectorType(_numpy_support.as_dtype(value.dtype))
return MultiVectorType(layout, value.dtype)
return typer


Expand All @@ -92,15 +97,11 @@ def impl_MultiVector(context, builder, sig, args):


@lower_constant(MultiVectorType)
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector):
value = context.get_constant_generic(builder, typ.value_type, pyval.value)
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
return impl_ret_borrowed(
context,
builder,
typ,
cgutils.pack_struct(builder, (layout, value)),
)
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector) -> llvmlite.ir.Value:
mv = cgutils.create_struct_proxy(typ)(context, builder)
mv.value = context.get_constant_generic(builder, typ.value_type, pyval.value)
mv.layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
return mv._getvalue()


@numba.extending.unbox(MultiVectorType)
Expand All @@ -117,7 +118,7 @@ def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue:


@numba.extending.box(MultiVectorType)
def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector:
def box_MultiVector(typ: MultiVectorType, val: llvmlite.ir.Value, c) -> MultiVector:
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
mv_obj = c.box(typ.value_type, mv.value)
layout_obj = c.box(typ.layout_type, mv.layout)
Expand Down
7 changes: 7 additions & 0 deletions clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ def add_e1(a):
return cf.MultiVector(a.layout, a.value + e1.value)

assert add_e1(e2) == e1 + e2

def test_multivector_shorthand(self):
@numba.njit
def double(a):
return a.layout.MultiVector(a.value*2)

assert double(e2) == 2 * e2

0 comments on commit 6bbd6cd

Please sign in to comment.