Skip to content

Commit

Permalink
Split up table computation to compute each table separately (#297)
Browse files Browse the repository at this point in the history
This means that subclasses can override the `gmt` attribute, and will automatically get correct `imt`, `omt`, and `lcmt` tables.

Performance impact seems negligeable.
  • Loading branch information
eric-wieser committed Mar 31, 2020
1 parent 557d38f commit 3c70a07
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 74 deletions.
172 changes: 98 additions & 74 deletions clifford/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class _cached_property:
def __init__(self, getter):
self.fget = getter
self.__name__ = getter.__name__
self.__doc__ = getter.__doc__

def __get__(self, obj, cls):
if obj is None:
Expand Down Expand Up @@ -73,6 +74,8 @@ def imt_check(grade_v, grade_i, grade_j):
"""
A check used in imt table generation
"""
# A_r . B_s = <A_r B_s>_|r-s|
# if r, s != 0
return (grade_v == abs(grade_i - grade_j)) and (grade_i != 0) and (grade_j != 0)


Expand All @@ -81,6 +84,7 @@ def omt_check(grade_v, grade_i, grade_j):
"""
A check used in omt table generation
"""
# A_r ^ B_s = <A_r B_s>_|r+s|
return grade_v == (grade_i + grade_j)


Expand All @@ -89,75 +93,87 @@ def lcmt_check(grade_v, grade_i, grade_j):
"""
A check used in lcmt table generation
"""
# A_r _| B_s = <A_r B_s>_(s-r) if s-r >= 0
return grade_v == (grade_j - grade_i)


@_numba_utils.njit(parallel=NUMBA_PARALLEL, nogil=True)
def _numba_construct_tables(
index_to_grade, index_to_bitmap, bitmap_to_index, signature
def _numba_construct_gmt(
index_to_bitmap, bitmap_to_index, signature
):
array_length = int(len(index_to_grade) * len(index_to_grade))
indices = np.zeros((3, array_length), dtype=np.uint64)
k_list = indices[0, :]
l_list = indices[1, :]
m_list = indices[2, :]

imt_prod_mask = np.zeros(array_length, dtype=np.bool_)

omt_prod_mask = np.zeros(array_length, dtype=np.bool_)

lcmt_prod_mask = np.zeros(array_length, dtype=np.bool_)
n = len(index_to_bitmap)
array_length = int(n * n)
coords = np.zeros((3, array_length), dtype=np.uint64)
k_list = coords[0, :]
l_list = coords[1, :]
m_list = coords[2, :]

# use as small a type as possible to minimize type promotion
mult_table_vals = np.zeros(array_length, dtype=np.int8)

for i, grade_i in enumerate(index_to_grade):
for i in range(n):
bitmap_i = index_to_bitmap[i]

for j, grade_j in enumerate(index_to_grade):
for j in range(n):
bitmap_j = index_to_bitmap[j]
bitmap_v, mul = gmt_element(bitmap_i, bitmap_j, signature)
v = bitmap_to_index[bitmap_v]

list_ind = i * len(index_to_grade) + j
list_ind = i * n + j
k_list[list_ind] = i
l_list[list_ind] = v
m_list[list_ind] = j

mult_table_vals[list_ind] = mul
grade_v = index_to_grade[v]

# A_r . B_s = <A_r B_s>_|r-s|
# if r, s != 0
imt_prod_mask[list_ind] = imt_check(grade_v, grade_i, grade_j)

# A_r ^ B_s = <A_r B_s>_|r+s|
omt_prod_mask[list_ind] = omt_check(grade_v, grade_i, grade_j)

# A_r _| B_s = <A_r B_s>_(s-r) if s-r >= 0
lcmt_prod_mask[list_ind] = lcmt_check(grade_v, grade_i, grade_j)

return indices, mult_table_vals, imt_prod_mask, omt_prod_mask, lcmt_prod_mask
return coords, mult_table_vals


def construct_tables(
def construct_gmt(
blade_order: BasisBladeOrder, signature
) -> Tuple[sparse.COO, sparse.COO, sparse.COO, sparse.COO]:
) -> sparse.COO:
# wrap the numba one
indices, *arrs = _numba_construct_tables(
blade_order.grades,
coords, mult_table_vals = _numba_construct_gmt(
blade_order.index_to_bitmap,
blade_order.bitmap_to_index,
signature
)
dims = len(blade_order.grades)
return tuple(
sparse.COO(
coords=indices, data=arr, shape=(dims, dims, dims),
prune=True
)
for arr in arrs
return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims))


@_numba_utils.njit(parallel=NUMBA_PARALLEL, nogil=True)
def _numba_construct_graded_mt(
index_to_grade, coords, gmt_vals, check_func
):
n_elems = coords.shape[1]

mask = np.zeros(n_elems, dtype=np.bool_)

for ind in range(coords.shape[1]):
k, l, m = coords[:, ind]

grade_k = index_to_grade[k]
grade_l = index_to_grade[l]
grade_m = index_to_grade[m]

mask[ind] = check_func(grade_l, grade_k, grade_m)

return coords[:, mask], gmt_vals[mask]


def construct_graded_mt(
blade_order: BasisBladeOrder, gmt: sparse.COO, check_func
) -> sparse.COO:
# wrap the numba one
coords, mult_table_vals = _numba_construct_graded_mt(
blade_order.grades,
gmt.coords,
gmt.data,
check_func
)
dims = len(blade_order.grades)
return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims))


@_utils.set_module('clifford')
Expand Down Expand Up @@ -261,20 +277,6 @@ class Layout(object):
2**dims
names :
pretty-printing symbols for the blades
gmt :
multiplication table for geometric product
imt :
multiplication table for inner product
omt :
multiplication table for outer product
lcmt :
multiplication table for the left-contraction
Notes
-----
The multiplication tables :math:`M` are tensors of rank 3 such that
:math:`a = b \operatorname{op} c` can be computed as
:math:`a_j = \sum_{i,k} b_i \mathit{M}_{ijk} c_k`.
"""
# old signature
def __init__(self, sig, bladeTupList, firstIdx=1, names=None):
Expand Down Expand Up @@ -361,15 +363,42 @@ def __init__(self, *args, **kw):
"names list of length %i needs to be of length %i" %
(len(names), self.gaDims))

self._genTables()
# preload these lazy properties. Not doing this would likely be faster.
self.gmt_func
self.imt_func
self.omt_func
self.lcmt_func
self.adjoint_func
self.left_complement_func
self.right_complement_func
self.dual_func
self.vee_func
self.inv_func

@_cached_property
def gmt(self):
r""" Multiplication table for the geometric product.
This is a tensor of rank 3 such that
:math:`a = b c` can be computed as
:math:`a_j = \sum_{i,k} b_i \mathit{M}_{ijk} c_k`."""
return construct_gmt(self._basis_blade_order, self.sig)

@_cached_property
def omt(self):
""" Multiplication table for the inner product, stored in the same way as :attr:`gmt` """
return construct_graded_mt(self._basis_blade_order, self.gmt, omt_check)

@_cached_property
def imt(self):
""" Multiplication table for the outer product, stored in the same way as :attr:`gmt` """
return construct_graded_mt(self._basis_blade_order, self.gmt, imt_check)

@_cached_property
def lcmt(self):
""" Multiplication table for the left-contraction, stored in the same way as :attr:`gmt` """
return construct_graded_mt(self._basis_blade_order, self.gmt, lcmt_check)

@_cached_property
def bladeTupList(self):
return self._basis_vector_ids.order_as_tuples(self._basis_blade_order)
Expand Down Expand Up @@ -490,27 +519,6 @@ def parse_multivector(self, mv_string: str) -> MultiVector:
from ._parser import parse_multivector
return parse_multivector(self, mv_string)

def _genTables(self):
"Generate the multiplication tables."
self.gmt, imt_prod_mask, omt_prod_mask, lcmt_prod_mask = construct_tables(
self._basis_blade_order,
self.sig
)
self.omt = sparse.where(omt_prod_mask, self.gmt, self.gmt.dtype.type(0))
self.imt = sparse.where(imt_prod_mask, self.gmt, self.gmt.dtype.type(0))
self.lcmt = sparse.where(lcmt_prod_mask, self.gmt, self.gmt.dtype.type(0))

# This generates the functions that will perform the various products
self.gmt_func = get_mult_function(self.gmt, self.gradeList)
self.imt_func = get_mult_function(self.imt, self.gradeList)
self.omt_func = get_mult_function(self.omt, self.gradeList)
self.lcmt_func = get_mult_function(self.lcmt, self.gradeList)

# these are probably not useful, but someone might want them
self.imt_prod_mask = imt_prod_mask
self.omt_prod_mask = omt_prod_mask
self.lcmt_prod_mask = lcmt_prod_mask

def gmt_func_generator(self, grades_a=None, grades_b=None, filter_mask=None):
return get_mult_function(
self.gmt, self.gradeList,
Expand Down Expand Up @@ -566,6 +574,22 @@ def comp_func(Xval):
return Yval
return comp_func

@_cached_property
def gmt_func(self):
return get_mult_function(self.gmt, self.gradeList)

@_cached_property
def imt_func(self):
return get_mult_function(self.imt, self.gradeList)

@_cached_property
def omt_func(self):
return get_mult_function(self.omt, self.gradeList)

@_cached_property
def lcmt_func(self):
return get_mult_function(self.lcmt, self.gradeList)

@_cached_property
def left_complement_func(self):
return self._gen_complement_func(omt=self.omt)
Expand Down
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ Compatibility notes
numpy array of ``bytes``. The result now matches the construction order, rather
than being sorted alphabetically. The order of :meth:`Layout.metric` has
been adjusted for consistency.
* The ``imt_prod_mask``, ``omt_prod_mask``, and ``lcmt_prod_mask`` attributes
of :class:`Layout` objects have been removed, as these were an unnecessary
intermediate computation that had no need to be public.


Changes in 1.2.x
Expand Down

0 comments on commit 3c70a07

Please sign in to comment.