From 9e3724cdb83b5c0e4f9fabe39bf591f250c9c365 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 16:35:47 +0100 Subject: [PATCH 01/27] Added numba overloaded functions to layout --- clifford/_layout.py | 204 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) diff --git a/clifford/_layout.py b/clifford/_layout.py index 49003991..eb0ef3c7 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -5,6 +5,10 @@ import numpy as np import sparse +from numba.extending import overload +from numba import types + + # TODO: move some of these functions to this file if they're not useful anywhere # else import clifford as cf @@ -175,6 +179,177 @@ def construct_graded_mt( return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims)) +def get_as_ga_vector_func(layout): + """ + Returns a function that converts a scalar into a GA value vector + for the given algebra + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + ndims = layout.gaDims + @_numba_utils.njit + def as_ga_value_vector(x): + op = np.zeros(ndims) + op[scalar_index] = x + return op + return as_ga_value_vector + + +def get_overload_add(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + + def ga_add(x): + # dummy function to overload + pass + + @overload(ga_add, inline='always') + def ol_ga_add(a, b): + if isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + def impl(a, b): + op = b.astype(np.float32) + op[scalar_index] += a + return op + return impl + elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + def impl(a, b): + op = a.astype(np.float32) + op[scalar_index] += b + return op + return impl + else: + def impl(a, b): + return a + b + return impl + + return ga_add + + +def get_overload_sub(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + + def ga_sub(x): + # dummy function to overload + pass + + @overload(ga_sub, inline='always') + def ol_ga_sub(a, b): + if isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + def impl(a, b): + op = -b.astype(np.float32) + op[scalar_index] += a + return op + return impl + elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + def impl(a, b): + op = a.astype(np.float32) + op[scalar_index] -= b + return op + return impl + else: + def impl(a, b): + return a - b + return impl + + return ga_sub + + +def get_overload_mul(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_mul(x): + # dummy function to overload + pass + + gmt_func = layout.gmt_func + @overload(ga_mul, inline='always') + def ol_ga_mul(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return gmt_func(a, b) + return impl + else: + def impl(a, b): + return a*b + return impl + + return ga_mul + + +def get_overload_xor(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_xor(x): + # dummy function to overload + pass + + as_ga = layout.as_ga_value_vector_func + omt_func = layout.omt_func + @overload(ga_xor, inline='always') + def ol_ga_xor(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return omt_func(a, b) + return impl + elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + def impl(a, b): + return omt_func(a, as_ga(b)) + return impl + elif isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + def impl(a, b): + return omt_func(as_ga(a), b) + return impl + else: + def impl(a, b): + return a^b + return impl + + return ga_xor + + +def get_overload_or(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_or(x): + # dummy function to overload + pass + + as_ga = layout.as_ga_value_vector_func + imt_func = layout.imt_func + @overload(ga_or, inline='always') + def ol_ga_or(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return imt_func(a, b) + return impl + elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + def impl(a, b): + return imt_func(a, as_ga(b)) + return impl + elif isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + def impl(a, b): + return imt_func(as_ga(a), b) + return impl + else: + def impl(a, b): + return a|b + return impl + + return ga_or + + class Layout(object): r""" Layout stores information regarding the geometric algebra itself and the internal representation of multivectors. @@ -372,6 +547,11 @@ def __init__(self, *args, **kw): self.dual_func self.vee_func self.inv_func + self.overload_mul_func + self.overload_xor_func + self.overload_or_func + self.overload_add_func + self.overload_sub_func @_cached_property def gmt(self): @@ -572,6 +752,10 @@ def comp_func(Xval): return Yval return comp_func + @_cached_property + def as_ga_value_vector_func(self): + return get_as_ga_vector_func(self) + @_cached_property def gmt_func(self): return get_mult_function(self.gmt, self.gradeList) @@ -596,6 +780,26 @@ def left_complement_func(self): def right_complement_func(self): return self._gen_complement_func(omt=self.omt.T) + @_cached_property + def overload_mul_func(self): + return get_overload_mul(self) + + @_cached_property + def overload_xor_func(self): + return get_overload_xor(self) + + @_cached_property + def overload_or_func(self): + return get_overload_or(self) + + @_cached_property + def overload_add_func(self): + return get_overload_add(self) + + @_cached_property + def overload_sub_func(self): + return get_overload_sub(self) + @_cached_property def adjoint_func(self): ''' From 03826c57bd2d80b296fa09b71610a7cb8cd1063d Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 16:36:25 +0100 Subject: [PATCH 02/27] Added a GA specific ast transformer --- clifford/_ast_transformer.py | 54 ++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 clifford/_ast_transformer.py diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py new file mode 100644 index 00000000..b0ba035f --- /dev/null +++ b/clifford/_ast_transformer.py @@ -0,0 +1,54 @@ + +import ast + + +class GATransformer(ast.NodeTransformer): + """ + This is an AST transformer that converts operations into + JITable counterparts that work on MultiVector value arrays. + We crawl the AST and convert BinOp's into numba overloaded + functions. + """ + def visit_BinOp(self, node): + if isinstance(node.op, ast.Mult): + new_node = ast.Call( + func=ast.Name(id='ga_mul', ctx=ast.Load()), + args=[node.left, node.right], + keywords=[] + ) + new_node = GATransformer().visit(new_node) + return new_node + elif isinstance(node.op, ast.BitXor): + new_node = ast.Call( + func=ast.Name(id='ga_xor', ctx=ast.Load()), + args=[node.left, node.right], + keywords=[] + ) + new_node = GATransformer().visit(new_node) + return new_node + elif isinstance(node.op, ast.BitOr): + new_node = ast.Call( + func=ast.Name(id='ga_or', ctx=ast.Load()), + args=[node.left, node.right], + keywords=[] + ) + new_node = GATransformer().visit(new_node) + return new_node + elif isinstance(node.op, ast.Add): + new_node = ast.Call( + func=ast.Name(id='ga_add', ctx=ast.Load()), + args=[node.left, node.right], + keywords=[] + ) + new_node = GATransformer().visit(new_node) + return new_node + elif isinstance(node.op, ast.Sub): + new_node = ast.Call( + func=ast.Name(id='ga_sub', ctx=ast.Load()), + args=[node.left, node.right], + keywords=[] + ) + new_node = GATransformer().visit(new_node) + return new_node + return node + From ef6125783bc34052e255c260a30350833288b455 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 16:37:36 +0100 Subject: [PATCH 03/27] Added a jit_func decorator to ast transform and numba jit --- clifford/jit_func.py | 61 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 clifford/jit_func.py diff --git a/clifford/jit_func.py b/clifford/jit_func.py new file mode 100644 index 00000000..33ee79cc --- /dev/null +++ b/clifford/jit_func.py @@ -0,0 +1,61 @@ + +import ast +import astpretty +import inspect +from ._numba_utils import njit +from ._ast_transformer import GATransformer + + +class jit_func(object): + + def __init__(self, layout, ast_debug=False): + self.layout = layout + self.ast_debug = ast_debug + + def __call__(self, func): + # Get the function source + fname = func.__name__ + source = inspect.getsource(func) + source = '\n'.join(source.splitlines()[1:]) # remove the decorator first line. + + # Re-write the ast + tree = ast.parse(source) + if self.ast_debug: + print('\n\n\n\n TRANFORMING FROM \n\n\n\n') + astpretty.pprint(tree) + + tree = GATransformer().visit(tree) + ast.fix_missing_locations(tree) + + if self.ast_debug: + print('\n\n\n\n TRANFORMING TO \n\n\n\n') + astpretty.pprint(tree) + + # Compile the function + co = compile(tree, '', "exec") + locals_dict = {} + exec(co, globals(), locals_dict) + new_func = locals_dict[fname] + + # Set things up into memory so that they JIT ok... + as_ga = self.layout.as_ga_value_vector_func + ga_add = self.layout.overload_add_func + ga_sub = self.layout.overload_sub_func + ga_mul = self.layout.overload_mul_func + ga_xor = self.layout.overload_xor_func + ga_or = self.layout.overload_or_func + + # globals()['as_ga'] = as_ga + # globals()['ga_add'] = ga_add + # globals()['ga_sub'] = ga_sub + # globals()['ga_mul'] = ga_mul + # globals()['ga_xor'] = ga_xor + # globals()['ga_or'] = ga_or + + # JIT the function + jitted_func = njit(new_func) + + # Wrap the JITed function + def wrapper(*args, **kwargs): + return self.layout.MultiVector(value=jitted_func(*[a.value for a in args], **kwargs)) + return wrapper From c13bc94d4fdd27862dbc46f97677e99303fdea2d Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 17:19:45 +0100 Subject: [PATCH 04/27] Corrected jit_func, added a test --- clifford/jit_func.py | 45 ++++++++++++++------------ clifford/test/test_jit_func.py | 59 ++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 21 deletions(-) create mode 100644 clifford/test/test_jit_func.py diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 33ee79cc..7f0bd583 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -7,7 +7,10 @@ class jit_func(object): - + """ + This is a JIT decorator that re-writes the AST and then numba JITs + the resulting function. + """ def __init__(self, layout, ast_debug=False): self.layout = layout self.ast_debug = ast_debug @@ -16,42 +19,42 @@ def __call__(self, func): # Get the function source fname = func.__name__ source = inspect.getsource(func) - source = '\n'.join(source.splitlines()[1:]) # remove the decorator first line. + # Remove the decorator first line. + source = '\n'.join(source.splitlines()[1:]) + # Remove the indentation + indentation = source.splitlines()[0].find('def') + print(indentation) + source = '\n'.join([line[indentation:] for line in source.splitlines()]) # Re-write the ast tree = ast.parse(source) if self.ast_debug: - print('\n\n\n\n TRANFORMING FROM \n\n\n\n') + print('\n\n\n\n TRANSFORMING FROM \n\n\n\n') astpretty.pprint(tree) tree = GATransformer().visit(tree) ast.fix_missing_locations(tree) if self.ast_debug: - print('\n\n\n\n TRANFORMING TO \n\n\n\n') + print('\n\n\n\n TRANSFORMING TO \n\n\n\n') astpretty.pprint(tree) + # Set things up into locals and globals so that they JIT ok... + locals_dict = {'as_ga': self.layout.as_ga_value_vector_func, + 'ga_add': self.layout.overload_add_func, + 'ga_sub': self.layout.overload_sub_func, + 'ga_mul': self.layout.overload_mul_func, + 'ga_xor': self.layout.overload_xor_func, + 'ga_or': self.layout.overload_or_func} + globs = globals() + for k, v in locals_dict.items(): + globs[k] = v + # Compile the function co = compile(tree, '', "exec") - locals_dict = {} - exec(co, globals(), locals_dict) + exec(co, globs, locals_dict) new_func = locals_dict[fname] - # Set things up into memory so that they JIT ok... - as_ga = self.layout.as_ga_value_vector_func - ga_add = self.layout.overload_add_func - ga_sub = self.layout.overload_sub_func - ga_mul = self.layout.overload_mul_func - ga_xor = self.layout.overload_xor_func - ga_or = self.layout.overload_or_func - - # globals()['as_ga'] = as_ga - # globals()['ga_add'] = ga_add - # globals()['ga_sub'] = ga_sub - # globals()['ga_mul'] = ga_mul - # globals()['ga_xor'] = ga_xor - # globals()['ga_or'] = ga_or - # JIT the function jitted_func = njit(new_func) diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py new file mode 100644 index 00000000..55516ffd --- /dev/null +++ b/clifford/test/test_jit_func.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +import time +from ..jit_func import jit_func + + +class TestJITFunc(unittest.TestCase): + def test_compound_expression(self): + from clifford.g3c import layout, blades + e1 = blades['e1'] + e2 = blades['e2'] + einf = layout.einf + + @jit_func(layout=layout, ast_debug=True) + def test_func(A, B, C): + op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + return op + + def slow_test_func(A, B, C): + op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + return op + + np.testing.assert_allclose(test_func(e1, e2, einf).value, slow_test_func(e1, e2, einf).value) + + def test_benchmark(self): + from clifford.g3c import layout, blades + e1 = blades['e1'] + e2 = blades['e2'] + einf = layout.einf + + @jit_func(layout=layout) + def test_func(A, B, C): + op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + return op + + def slow_test_func(A, B, C): + op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + return op + + test_func(e1, e2, einf) + slow_test_func(e1, e2, einf) + + nrepeats = 10000 + start_time = time.time() + for i in range(nrepeats): + test_func(e1, e2, einf) + end_time = time.time() + print('With jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) + + nrepeats = 10000 + start_time = time.time() + for i in range(nrepeats): + slow_test_func(e1, e2, einf) + end_time = time.time() + print('Without jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) + + +if __name__ == '__main__': + unittest.main() From 51f8a42459d3faa5633706f596e55bb62856c2af Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:20:44 +0100 Subject: [PATCH 05/27] remove duplication in ast_transformer --- clifford/_ast_transformer.py | 57 ++++++++++-------------------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index b0ba035f..48de5a95 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -10,45 +10,20 @@ class GATransformer(ast.NodeTransformer): functions. """ def visit_BinOp(self, node): - if isinstance(node.op, ast.Mult): - new_node = ast.Call( - func=ast.Name(id='ga_mul', ctx=ast.Load()), - args=[node.left, node.right], + ops = { + ast.Mult: 'ga_mul', + ast.BitXor: 'ga_xor', + ast.BitOr: 'ga_or', + ast.Add: 'ga_add', + ast.Sub: 'ga_sub', + } + try: + func_name = ops[type(node.op)] + except KeyError: + return node + else: + return ast.Call( + func=ast.Name(id=func_name, ctx=ast.Load()), + args=[self.visit(node.left), self.visit(node.right)], keywords=[] - ) - new_node = GATransformer().visit(new_node) - return new_node - elif isinstance(node.op, ast.BitXor): - new_node = ast.Call( - func=ast.Name(id='ga_xor', ctx=ast.Load()), - args=[node.left, node.right], - keywords=[] - ) - new_node = GATransformer().visit(new_node) - return new_node - elif isinstance(node.op, ast.BitOr): - new_node = ast.Call( - func=ast.Name(id='ga_or', ctx=ast.Load()), - args=[node.left, node.right], - keywords=[] - ) - new_node = GATransformer().visit(new_node) - return new_node - elif isinstance(node.op, ast.Add): - new_node = ast.Call( - func=ast.Name(id='ga_add', ctx=ast.Load()), - args=[node.left, node.right], - keywords=[] - ) - new_node = GATransformer().visit(new_node) - return new_node - elif isinstance(node.op, ast.Sub): - new_node = ast.Call( - func=ast.Name(id='ga_sub', ctx=ast.Load()), - args=[node.left, node.right], - keywords=[] - ) - new_node = GATransformer().visit(new_node) - return new_node - return node - + ) \ No newline at end of file From 8022092c60f02f19c1e65b390ea794b5500e60b6 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:21:11 +0100 Subject: [PATCH 06/27] convert to abstract numeric types in the numba jit overload --- clifford/_layout.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/clifford/_layout.py b/clifford/_layout.py index eb0ef3c7..d945064b 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -207,13 +207,13 @@ def ga_add(x): @overload(ga_add, inline='always') def ol_ga_add(a, b): - if isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): def impl(a, b): op = b.astype(np.float32) op[scalar_index] += a return op return impl - elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): def impl(a, b): op = a.astype(np.float32) op[scalar_index] += b @@ -240,13 +240,13 @@ def ga_sub(x): @overload(ga_sub, inline='always') def ol_ga_sub(a, b): - if isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): def impl(a, b): op = -b.astype(np.float32) op[scalar_index] += a return op return impl - elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): def impl(a, b): op = a.astype(np.float32) op[scalar_index] -= b @@ -301,11 +301,11 @@ def ol_ga_xor(a, b): def impl(a, b): return omt_func(a, b) return impl - elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): def impl(a, b): return omt_func(a, as_ga(b)) return impl - elif isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): def impl(a, b): return omt_func(as_ga(a), b) return impl @@ -334,11 +334,11 @@ def ol_ga_or(a, b): def impl(a, b): return imt_func(a, b) return impl - elif isinstance(a, types.Array) and isinstance(b, (types.Integer, types.Float)): + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): def impl(a, b): return imt_func(a, as_ga(b)) return impl - elif isinstance(a, (types.Integer, types.Float)) and isinstance(b, types.Array): + elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): def impl(a, b): return imt_func(as_ga(a), b) return impl From f14521b27bd806f3642337e9bbde3688c73322be Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:23:32 +0100 Subject: [PATCH 07/27] Improved handling globals, added a TODO --- clifford/jit_func.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 7f0bd583..d850e267 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -46,7 +46,10 @@ def __call__(self, func): 'ga_mul': self.layout.overload_mul_func, 'ga_xor': self.layout.overload_xor_func, 'ga_or': self.layout.overload_or_func} - globs = globals() + # TODO: Work out a better way to deal with changes to globals + globs = {} + for k, v in globals().items(): + globs[k] = v for k, v in locals_dict.items(): globs[k] = v From 5fdbb865c840887ca98ce01ec00b69d50e70f54c Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:27:32 +0100 Subject: [PATCH 08/27] Added ast_pretty warning if not installed --- clifford/jit_func.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index d850e267..485e6b7d 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -1,7 +1,12 @@ import ast -import astpretty +try: + import astpretty + AST_PRETTY_AVAILABLE = True +except: + AST_PRETTY_AVAILABLE = False import inspect +import warnings from ._numba_utils import njit from ._ast_transformer import GATransformer @@ -13,7 +18,15 @@ class jit_func(object): """ def __init__(self, layout, ast_debug=False): self.layout = layout - self.ast_debug = ast_debug + if AST_PRETTY_AVAILABLE: + self.ast_debug = ast_debug + else: + if ast_debug: + warnings.warn(''' +The ast_debug flag is set to True, but the astpretty module is not importable. +To see ast_debug output please pip install astpretty +''') + self.ast_debug = False def __call__(self, func): # Get the function source From d6c6e067a09bcc54a95ff2c6c140346a54c03ad4 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:29:12 +0100 Subject: [PATCH 09/27] removed unnescary print --- clifford/jit_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 485e6b7d..1ff426ef 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -36,7 +36,6 @@ def __call__(self, func): source = '\n'.join(source.splitlines()[1:]) # Remove the indentation indentation = source.splitlines()[0].find('def') - print(indentation) source = '\n'.join([line[indentation:] for line in source.splitlines()]) # Re-write the ast From 8094a61427b1a0ba00946653d7c4a5b5bfe2c673 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 18:51:04 +0100 Subject: [PATCH 10/27] Added reversion to AST rewriter and JIT --- clifford/_ast_transformer.py | 21 ++++++++++++++++--- clifford/_layout.py | 38 +++++++++++++++++++++++++++++----- clifford/jit_func.py | 3 ++- clifford/test/test_jit_func.py | 24 +++++++++++++++++---- 4 files changed, 73 insertions(+), 13 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 48de5a95..a0e8ce08 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -6,8 +6,8 @@ class GATransformer(ast.NodeTransformer): """ This is an AST transformer that converts operations into JITable counterparts that work on MultiVector value arrays. - We crawl the AST and convert BinOp's into numba overloaded - functions. + We crawl the AST and convert BinOps and UnaryOps into numba + overloaded functions. """ def visit_BinOp(self, node): ops = { @@ -26,4 +26,19 @@ def visit_BinOp(self, node): func=ast.Name(id=func_name, ctx=ast.Load()), args=[self.visit(node.left), self.visit(node.right)], keywords=[] - ) \ No newline at end of file + ) + + def visit_UnaryOp(self, node): + ops = { + ast.Invert: 'ga_rev' + } + try: + func_name = ops[type(node.op)] + except KeyError: + return node + else: + return ast.Call( + func=ast.Name(id=func_name, ctx=ast.Load()), + args=[self.visit(node.operand)], + keywords=[] + ) diff --git a/clifford/_layout.py b/clifford/_layout.py index d945064b..dcdefadd 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -201,7 +201,7 @@ def get_overload_add(layout): """ scalar_index = layout._basis_blade_order.bitmap_to_index[0] - def ga_add(x): + def ga_add(a, b): # dummy function to overload pass @@ -234,7 +234,7 @@ def get_overload_sub(layout): """ scalar_index = layout._basis_blade_order.bitmap_to_index[0] - def ga_sub(x): + def ga_sub(a, b): # dummy function to overload pass @@ -265,7 +265,7 @@ def get_overload_mul(layout): Returns an overloaded JITed function that works on MultiVector value arrays """ - def ga_mul(x): + def ga_mul(a, b): # dummy function to overload pass @@ -289,7 +289,7 @@ def get_overload_xor(layout): Returns an overloaded JITed function that works on MultiVector value arrays """ - def ga_xor(x): + def ga_xor(a, b): # dummy function to overload pass @@ -322,7 +322,7 @@ def get_overload_or(layout): Returns an overloaded JITed function that works on MultiVector value arrays """ - def ga_or(x): + def ga_or(a, b): # dummy function to overload pass @@ -350,6 +350,30 @@ def impl(a, b): return ga_or +def get_overload_reverse(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_rev(x): + # dummy function to overload + pass + + adjoint_func = layout.adjoint_func + @overload(ga_rev, inline='always') + def ol_ga_rev(x): + if isinstance(x, types.Array): + def impl(x): + return adjoint_func(x) + return impl + else: + def impl(x): + return ~x + return impl + + return ga_rev + + class Layout(object): r""" Layout stores information regarding the geometric algebra itself and the internal representation of multivectors. @@ -800,6 +824,10 @@ def overload_add_func(self): def overload_sub_func(self): return get_overload_sub(self) + @_cached_property + def overload_reverse_func(self): + return get_overload_reverse(self) + @_cached_property def adjoint_func(self): ''' diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 1ff426ef..d628491c 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -57,7 +57,8 @@ def __call__(self, func): 'ga_sub': self.layout.overload_sub_func, 'ga_mul': self.layout.overload_mul_func, 'ga_xor': self.layout.overload_xor_func, - 'ga_or': self.layout.overload_or_func} + 'ga_or': self.layout.overload_or_func, + 'ga_rev': self.layout.overload_reverse_func} # TODO: Work out a better way to deal with changes to globals globs = {} for k, v in globals().items(): diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 55516ffd..94772ed2 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -5,6 +5,22 @@ class TestJITFunc(unittest.TestCase): + + def test_reverse(self): + from clifford.g3c import layout, blades + e12 = blades['e12'] + + @jit_func(layout=layout, ast_debug=True) + def test_func(A): + op = ~A + return op + + def slow_test_func(A): + op = ~A + return op + + np.testing.assert_allclose(test_func(e12).value, slow_test_func(e12).value) + def test_compound_expression(self): from clifford.g3c import layout, blades e1 = blades['e1'] @@ -13,11 +29,11 @@ def test_compound_expression(self): @jit_func(layout=layout, ast_debug=True) def test_func(A, B, C): - op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) return op def slow_test_func(A, B, C): - op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) return op np.testing.assert_allclose(test_func(e1, e2, einf).value, slow_test_func(e1, e2, einf).value) @@ -30,11 +46,11 @@ def test_benchmark(self): @jit_func(layout=layout) def test_func(A, B, C): - op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) return op def slow_test_func(A, B, C): - op = (((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) return op test_func(e1, e2, einf) From 176734239e99f16586b8233abb53db8c6904cd68 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Fri, 5 Jun 2020 20:20:01 +0100 Subject: [PATCH 11/27] Added grade selection via the call syntax --- clifford/_ast_transformer.py | 15 ++++++++++ clifford/_layout.py | 48 ++++++++++++++++++++++++++++++++ clifford/jit_func.py | 3 +- clifford/test/test_jit_func.py | 51 ++++++++++++++++++++++------------ 4 files changed, 98 insertions(+), 19 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index a0e8ce08..5ed21dd1 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -42,3 +42,18 @@ def visit_UnaryOp(self, node): args=[self.visit(node.operand)], keywords=[] ) + + def visit_Call(self, node): + try: + nfuncid = node.func.id + return node + except: + # Only allow a single grade to be selected for now + if len(node.args) == 1: + return ast.Call( + func=ast.Name(id='ga_call', ctx=ast.Load()), + args=[node.func, node.args[0]], + keywords=[] + ) + else: + return node diff --git a/clifford/_layout.py b/clifford/_layout.py index dcdefadd..246cd04b 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -374,6 +374,46 @@ def impl(x): return ga_rev +def get_project_to_grade_func(layout): + """ + Returns a function that projects a multivector to a given grade + """ + gradeList = np.array(layout.gradeList, dtype=int) + ndims = layout.gaDims + @_numba_utils.njit + def project_to_grade(A, g): + op = np.zeros(ndims) + for i in range(ndims): + if gradeList[i] == g: + op[i] = A[i] + return op + return project_to_grade + + +def get_overload_call(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_call(a, b): + # dummy function to overload + pass + + project_to_grade = layout.project_to_grade_func + @overload(ga_call, inline='always') + def ol_ga_call(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Integer): + def impl(a, b): + return project_to_grade(a, b) + return impl + else: + def impl(a, b): + return a(b) + return impl + + return ga_call + + class Layout(object): r""" Layout stores information regarding the geometric algebra itself and the internal representation of multivectors. @@ -828,6 +868,14 @@ def overload_sub_func(self): def overload_reverse_func(self): return get_overload_reverse(self) + @_cached_property + def project_to_grade_func(self): + return get_project_to_grade_func(self) + + @_cached_property + def overload_call_func(self): + return get_overload_call(self) + @_cached_property def adjoint_func(self): ''' diff --git a/clifford/jit_func.py b/clifford/jit_func.py index d628491c..e7438f6d 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -58,7 +58,8 @@ def __call__(self, func): 'ga_mul': self.layout.overload_mul_func, 'ga_xor': self.layout.overload_xor_func, 'ga_or': self.layout.overload_or_func, - 'ga_rev': self.layout.overload_reverse_func} + 'ga_rev': self.layout.overload_reverse_func, + 'ga_call': self.layout.overload_call_func} # TODO: Work out a better way to deal with changes to globals globs = {} for k, v in globals().items(): diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 94772ed2..6dd93904 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -1,10 +1,26 @@ -import unittest import numpy as np import time from ..jit_func import jit_func -class TestJITFunc(unittest.TestCase): +class TestJITFunc: + + def test_grade_selection(self): + from clifford.g3c import layout, blades + e1 = blades['e1'] + e12 = blades['e12'] + + @jit_func(layout=layout, ast_debug=True) + def test_func(A, B): + op = (A+B)(1) + return op + + def slow_test_func(A, B): + op = (A+B)(1) + return op + + np.testing.assert_allclose(test_func(e1, e12).value, slow_test_func(e1, e12).value) + def test_reverse(self): from clifford.g3c import layout, blades @@ -25,51 +41,50 @@ def test_compound_expression(self): from clifford.g3c import layout, blades e1 = blades['e1'] e2 = blades['e2'] + e34 = blades['e34'] einf = layout.einf @jit_func(layout=layout, ast_debug=True) - def test_func(A, B, C): - op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + def test_func(A, B, C, D): + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op - def slow_test_func(A, B, C): - op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + def slow_test_func(A, B, C, D): + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op - np.testing.assert_allclose(test_func(e1, e2, einf).value, slow_test_func(e1, e2, einf).value) + np.testing.assert_allclose(test_func(e1, e2, einf, e34).value, slow_test_func(e1, e2, einf, e34).value) def test_benchmark(self): from clifford.g3c import layout, blades e1 = blades['e1'] e2 = blades['e2'] + e34 = blades['e34'] einf = layout.einf @jit_func(layout=layout) - def test_func(A, B, C): - op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + def test_func(A, B, C, D): + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op - def slow_test_func(A, B, C): - op = ~(((A * B) * C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + def slow_test_func(A, B, C, D): + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op - test_func(e1, e2, einf) - slow_test_func(e1, e2, einf) + print(test_func(e1, e2, einf, e34)) + print(slow_test_func(e1, e2, einf, e34)) nrepeats = 10000 start_time = time.time() for i in range(nrepeats): - test_func(e1, e2, einf) + test_func(e1, e2, einf, e34) end_time = time.time() print('With jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) nrepeats = 10000 start_time = time.time() for i in range(nrepeats): - slow_test_func(e1, e2, einf) + slow_test_func(e1, e2, einf, e34) end_time = time.time() print('Without jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) - -if __name__ == '__main__': - unittest.main() From 81601ce0135f8909c54937bd2b32ca4c15bbedc9 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 09:17:46 +0100 Subject: [PATCH 12/27] Set up pytest benchmark --- clifford/test/test_jit_func.py | 64 ++++++++++++---------------------- 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 6dd93904..88734321 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -1,16 +1,20 @@ import numpy as np -import time from ..jit_func import jit_func class TestJITFunc: - def test_grade_selection(self): + @classmethod + def setup_class(cls): from clifford.g3c import layout, blades - e1 = blades['e1'] - e12 = blades['e12'] + cls.layout = layout + cls.blades = blades + + def test_grade_selection(self): + e1 = self.blades['e1'] + e12 = self.blades['e12'] - @jit_func(layout=layout, ast_debug=True) + @jit_func(layout=self.layout, ast_debug=True) def test_func(A, B): op = (A+B)(1) return op @@ -21,12 +25,10 @@ def slow_test_func(A, B): np.testing.assert_allclose(test_func(e1, e12).value, slow_test_func(e1, e12).value) - def test_reverse(self): - from clifford.g3c import layout, blades - e12 = blades['e12'] + e12 = self.blades['e12'] - @jit_func(layout=layout, ast_debug=True) + @jit_func(layout=self.layout, ast_debug=True) def test_func(A): op = ~A return op @@ -38,13 +40,12 @@ def slow_test_func(A): np.testing.assert_allclose(test_func(e12).value, slow_test_func(e12).value) def test_compound_expression(self): - from clifford.g3c import layout, blades - e1 = blades['e1'] - e2 = blades['e2'] - e34 = blades['e34'] - einf = layout.einf + e1 = self.blades['e1'] + e2 = self.blades['e2'] + e34 = self.blades['e34'] + einf = self.layout.einf - @jit_func(layout=layout, ast_debug=True) + @jit_func(layout=self.layout, ast_debug=True) def test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op @@ -55,36 +56,15 @@ def slow_test_func(A, B, C, D): np.testing.assert_allclose(test_func(e1, e2, einf, e34).value, slow_test_func(e1, e2, einf, e34).value) - def test_benchmark(self): - from clifford.g3c import layout, blades - e1 = blades['e1'] - e2 = blades['e2'] - e34 = blades['e34'] - einf = layout.einf - - @jit_func(layout=layout) + def test_benchmark_jit(self, benchmark): + @jit_func(layout=self.layout) def test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op + benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) - def slow_test_func(A, B, C, D): + def test_benchmark_nojit(self, benchmark): + def test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op - - print(test_func(e1, e2, einf, e34)) - print(slow_test_func(e1, e2, einf, e34)) - - nrepeats = 10000 - start_time = time.time() - for i in range(nrepeats): - test_func(e1, e2, einf, e34) - end_time = time.time() - print('With jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) - - nrepeats = 10000 - start_time = time.time() - for i in range(nrepeats): - slow_test_func(e1, e2, einf, e34) - end_time = time.time() - print('Without jit_func (us): ', 1E6 * (end_time - start_time) / nrepeats) - + benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) From d9053931f77857ce98d6b493656b1938be0f621b Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 09:18:09 +0100 Subject: [PATCH 13/27] Make node visitation recursive for Call --- clifford/_ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 5ed21dd1..13a838e6 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -52,7 +52,7 @@ def visit_Call(self, node): if len(node.args) == 1: return ast.Call( func=ast.Name(id='ga_call', ctx=ast.Load()), - args=[node.func, node.args[0]], + args=[self.visit(node.func), node.args[0]], keywords=[] ) else: From 750ec8555a883c7e0c744fddf677e3ba37f40043 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 09:18:28 +0100 Subject: [PATCH 14/27] Add ImportError type for astpretty --- clifford/jit_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index e7438f6d..62968c76 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -3,7 +3,7 @@ try: import astpretty AST_PRETTY_AVAILABLE = True -except: +except ImportError: AST_PRETTY_AVAILABLE = False import inspect import warnings From e0263f81535842873f3316490eecc13b88631042 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 09:20:09 +0100 Subject: [PATCH 15/27] Improve warning whitespace --- clifford/jit_func.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 62968c76..fda38a36 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -22,10 +22,9 @@ def __init__(self, layout, ast_debug=False): self.ast_debug = ast_debug else: if ast_debug: - warnings.warn(''' -The ast_debug flag is set to True, but the astpretty module is not importable. -To see ast_debug output please pip install astpretty -''') + warnings.warn( + 'The ast_debug flag is set to True, but the astpretty module is not importable.\n' + 'To see ast_debug output please pip install astpretty') self.ast_debug = False def __call__(self, func): From e878dbeba427df5aec371ea9f866bd1b591885db Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 09:53:23 +0100 Subject: [PATCH 16/27] Make the Call rewrite exception an AttributeError --- clifford/_ast_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 13a838e6..ca8044cf 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -47,7 +47,7 @@ def visit_Call(self, node): try: nfuncid = node.func.id return node - except: + except AttributeError: # Only allow a single grade to be selected for now if len(node.args) == 1: return ast.Call( @@ -57,3 +57,5 @@ def visit_Call(self, node): ) else: return node + except: + return node From 482b0919d16fe3ac66177454b57f2d6e3f7c7981 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sat, 6 Jun 2020 17:23:46 +0100 Subject: [PATCH 17/27] Moved the decorator removal to the AST level --- clifford/_ast_transformer.py | 7 +++++++ clifford/jit_func.py | 14 ++++++++------ clifford/test/test_jit_func.py | 35 +++++++++++++--------------------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index ca8044cf..27753b3c 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -2,6 +2,13 @@ import ast +class DecoratorRemover(ast.NodeTransformer): + """ Strip decorators from FunctionDefs""" + def visit_FunctionDef(self, node): + node.decorator_list = [] + return node + + class GATransformer(ast.NodeTransformer): """ This is an AST transformer that converts operations into diff --git a/clifford/jit_func.py b/clifford/jit_func.py index fda38a36..6e44fd92 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -7,8 +7,9 @@ AST_PRETTY_AVAILABLE = False import inspect import warnings +import textwrap from ._numba_utils import njit -from ._ast_transformer import GATransformer +from ._ast_transformer import DecoratorRemover, GATransformer class jit_func(object): @@ -31,19 +32,20 @@ def __call__(self, func): # Get the function source fname = func.__name__ source = inspect.getsource(func) - # Remove the decorator first line. - source = '\n'.join(source.splitlines()[1:]) # Remove the indentation - indentation = source.splitlines()[0].find('def') - source = '\n'.join([line[indentation:] for line in source.splitlines()]) + source = textwrap.dedent(source) - # Re-write the ast + # Parse the source tree = ast.parse(source) if self.ast_debug: print('\n\n\n\n TRANSFORMING FROM \n\n\n\n') astpretty.pprint(tree) + # Remove the decorators from the function + tree = DecoratorRemover().visit(tree) + # Re-write the ast tree = GATransformer().visit(tree) + # Fix it all up ast.fix_missing_locations(tree) if self.ast_debug: diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 88734321..02b567da 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from ..jit_func import jit_func @@ -14,29 +15,23 @@ def test_grade_selection(self): e1 = self.blades['e1'] e12 = self.blades['e12'] - @jit_func(layout=self.layout, ast_debug=True) - def test_func(A, B): - op = (A+B)(1) - return op - def slow_test_func(A, B): op = (A+B)(1) return op + test_func = jit_func(self.layout, ast_debug=True)(slow_test_func) + np.testing.assert_allclose(test_func(e1, e12).value, slow_test_func(e1, e12).value) def test_reverse(self): e12 = self.blades['e12'] - @jit_func(layout=self.layout, ast_debug=True) - def test_func(A): - op = ~A - return op - def slow_test_func(A): op = ~A return op + test_func = jit_func(self.layout, ast_debug=True)(slow_test_func) + np.testing.assert_allclose(test_func(e12).value, slow_test_func(e12).value) def test_compound_expression(self): @@ -45,26 +40,22 @@ def test_compound_expression(self): e34 = self.blades['e34'] einf = self.layout.einf - @jit_func(layout=self.layout, ast_debug=True) - def test_func(A, B, C, D): - op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) - return op - def slow_test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op + test_func = jit_func(self.layout, ast_debug=True)(slow_test_func) + np.testing.assert_allclose(test_func(e1, e2, einf, e34).value, slow_test_func(e1, e2, einf, e34).value) - def test_benchmark_jit(self, benchmark): - @jit_func(layout=self.layout) - def test_func(A, B, C, D): - op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) - return op - benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) + @pytest.mark.parametrize('use_jit', [False, True]) + def test_benchmark(self, use_jit, benchmark): - def test_benchmark_nojit(self, benchmark): def test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) return op + + if use_jit: + test_func = jit_func(self.layout)(test_func) benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) + From ff9648d826d55d1a032cb13e0a62101a51d3822f Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sun, 7 Jun 2020 15:10:24 +0100 Subject: [PATCH 18/27] Add scalar and multivector constants to decorator arguments --- clifford/jit_func.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 6e44fd92..30ab3636 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -17,8 +17,10 @@ class jit_func(object): This is a JIT decorator that re-writes the AST and then numba JITs the resulting function. """ - def __init__(self, layout, ast_debug=False): + def __init__(self, layout, ast_debug=False, mv_constants={}, scalar_constants={}): self.layout = layout + self.mv_constants = mv_constants + self.scalar_constants = scalar_constants if AST_PRETTY_AVAILABLE: self.ast_debug = ast_debug else: @@ -61,6 +63,13 @@ def __call__(self, func): 'ga_or': self.layout.overload_or_func, 'ga_rev': self.layout.overload_reverse_func, 'ga_call': self.layout.overload_call_func} + + # Add the passed multivector and scalar constants + for k, v in self.mv_constants.items(): + locals_dict[k] = v.value + for k, v in self.scalar_constants.items(): + locals_dict[k] = v + # TODO: Work out a better way to deal with changes to globals globs = {} for k, v in globals().items(): @@ -79,4 +88,5 @@ def __call__(self, func): # Wrap the JITed function def wrapper(*args, **kwargs): return self.layout.MultiVector(value=jitted_func(*[a.value for a in args], **kwargs)) + wrapper.value = jitted_func return wrapper From 5d2787456da3368e16911afa44b3cfaf1064fdb6 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sun, 7 Jun 2020 15:10:48 +0100 Subject: [PATCH 19/27] Fix nested function call transformer --- clifford/_ast_transformer.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 27753b3c..1147a0e5 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -51,18 +51,14 @@ def visit_UnaryOp(self, node): ) def visit_Call(self, node): - try: - nfuncid = node.func.id - return node - except AttributeError: - # Only allow a single grade to be selected for now - if len(node.args) == 1: - return ast.Call( - func=ast.Name(id='ga_call', ctx=ast.Load()), - args=[self.visit(node.func), node.args[0]], - keywords=[] - ) - else: - return node - except: + if len(node.args) == 1: + return ast.Call( + func=ast.Name(id='ga_call', ctx=ast.Load()), + args=[self.visit(node.func), self.visit(node.args[0])], + keywords=[] + ) + else: + node.func = self.visit(node.func) + node.args = [self.visit(a) for a in node.args] return node + From 6c2cea6bbed5cb80074075115323e998a10bd6fa Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sun, 7 Jun 2020 15:11:44 +0100 Subject: [PATCH 20/27] Improve speed of linear_operator_to_matrix --- clifford/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/clifford/__init__.py b/clifford/__init__.py index 01d07eba..d814d64c 100644 --- a/clifford/__init__.py +++ b/clifford/__init__.py @@ -117,7 +117,8 @@ def linear_operator_as_matrix(func, input_blades, output_blades): ndimout = len(output_blades) mat = np.zeros((ndimout, ndimin)) for i, b in enumerate(input_blades): - mat[:, i] = np.array([func(b)[j] for j in output_blades]) + b_result = func(b) + mat[:, i] = np.array([b_result[j] for j in output_blades]) return mat From 307874f153c775710b6d7a5e34954ca562c77334 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Sun, 7 Jun 2020 15:12:04 +0100 Subject: [PATCH 21/27] Add testing for new jit decorator features --- clifford/test/test_jit_func.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 02b567da..e374783b 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -16,7 +16,7 @@ def test_grade_selection(self): e12 = self.blades['e12'] def slow_test_func(A, B): - op = (A+B)(1) + op = (A(1)+B(2))(1) return op test_func = jit_func(self.layout, ast_debug=True)(slow_test_func) @@ -38,24 +38,33 @@ def test_compound_expression(self): e1 = self.blades['e1'] e2 = self.blades['e2'] e34 = self.blades['e34'] + e12345 = self.blades['e12345'] einf = self.layout.einf + pi = np.pi def slow_test_func(A, B, C, D): - op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + pi*e12345 return op - test_func = jit_func(self.layout, ast_debug=True)(slow_test_func) + test_func = jit_func(self.layout, + mv_constants={'e12345': e12345}, + scalar_constants={'pi': pi} + )(slow_test_func) np.testing.assert_allclose(test_func(e1, e2, einf, e34).value, slow_test_func(e1, e2, einf, e34).value) @pytest.mark.parametrize('use_jit', [False, True]) def test_benchmark(self, use_jit, benchmark): - + e12345 = self.blades['e12345'] + pi = np.pi def test_func(A, B, C, D): - op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + pi*e12345 return op if use_jit: - test_func = jit_func(self.layout)(test_func) + test_func = jit_func(self.layout, + mv_constants={'e12345': e12345}, + scalar_constants={'pi': pi} + )(test_func) benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) From c5be87a8150adba7a89e962abdffc1c15d310909 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Mon, 8 Jun 2020 09:28:21 +0100 Subject: [PATCH 22/27] Added a nested jitted function test --- clifford/test/test_jit_func.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index e374783b..f8ab70aa 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -23,6 +23,31 @@ def slow_test_func(A, B): np.testing.assert_allclose(test_func(e1, e12).value, slow_test_func(e1, e12).value) + def test_nested_functions(self): + e1 = self.blades['e1'] + e12 = self.blades['e12'] + + def test_func_1(A, B): + op = (A(1)+B(2))(1) + return op + + def test_func_2(A): + op = ~A + 5*e12 + return op + + def compound_func(A, B): + return test_func_2(test_func_1(A, B)) + + test_func_1_jit = jit_func(self.layout)(test_func_1) + test_func_2_jit = jit_func(self.layout, + mv_constants={'e12': e12})(test_func_2) + + test_compound_func = jit_func(self.layout, + mv_constants={'test_func_1': test_func_1_jit, + 'test_func_2': test_func_2_jit})(compound_func) + + np.testing.assert_allclose(test_compound_func(e1, e12).value, compound_func(e1, e12).value) + def test_reverse(self): e12 = self.blades['e12'] @@ -66,5 +91,4 @@ def test_func(A, B, C, D): mv_constants={'e12345': e12345}, scalar_constants={'pi': pi} )(test_func) - benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) - + benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) \ No newline at end of file From 8f0296078a91be9dec79d5f6bc7432b930099707 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Mon, 8 Jun 2020 09:45:11 +0100 Subject: [PATCH 23/27] Fixed flake8 complaints --- clifford/_ast_transformer.py | 1 - clifford/test/test_jit_func.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 1147a0e5..05d456f6 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -61,4 +61,3 @@ def visit_Call(self, node): node.func = self.visit(node.func) node.args = [self.visit(a) for a in node.args] return node - diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index f8ab70aa..651e7f7c 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -67,6 +67,7 @@ def test_compound_expression(self): einf = self.layout.einf pi = np.pi + def slow_test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + pi*e12345 return op @@ -82,6 +83,7 @@ def slow_test_func(A, B, C, D): def test_benchmark(self, use_jit, benchmark): e12345 = self.blades['e12345'] pi = np.pi + def test_func(A, B, C, D): op = ~(((A * B) * ~C) | (B ^ A)) - 3.1 - A - 7 * B + 5 + C + 2.5 + (2 ^ (A * B * C) ^ 3) + (A | 5) + (A + D)(2) + pi*e12345 return op @@ -91,4 +93,4 @@ def test_func(A, B, C, D): mv_constants={'e12345': e12345}, scalar_constants={'pi': pi} )(test_func) - benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) \ No newline at end of file + benchmark(test_func, self.blades['e1'], self.blades['e2'], self.layout.einf, self.blades['e34']) From 8e96d8142da8d684d557809c80433bba3e388ada Mon Sep 17 00:00:00 2001 From: hugohadfield Date: Tue, 9 Jun 2020 10:41:29 +0100 Subject: [PATCH 24/27] Apply suggestions from Eric code review Co-authored-by: Eric Wieser --- clifford/_ast_transformer.py | 15 ++++++--------- clifford/jit_func.py | 12 ++++-------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/clifford/_ast_transformer.py b/clifford/_ast_transformer.py index 05d456f6..41645e3d 100644 --- a/clifford/_ast_transformer.py +++ b/clifford/_ast_transformer.py @@ -3,7 +3,7 @@ class DecoratorRemover(ast.NodeTransformer): - """ Strip decorators from FunctionDefs""" + """ Strip decorators from top-level FunctionDefs""" def visit_FunctionDef(self, node): node.decorator_list = [] return node @@ -52,12 +52,9 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): if len(node.args) == 1: - return ast.Call( - func=ast.Name(id='ga_call', ctx=ast.Load()), - args=[self.visit(node.func), self.visit(node.args[0])], - keywords=[] - ) - else: - node.func = self.visit(node.func) - node.args = [self.visit(a) for a in node.args] + node = self.generic_visit(node) + node.args = [node.func] + node.args + node.func = ast.Name(id='ga_call', ctx=ast.Load()) return node + else: + return self.generic_visit(node) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index 30ab3636..b9a38cf4 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -65,17 +65,13 @@ def __call__(self, func): 'ga_call': self.layout.overload_call_func} # Add the passed multivector and scalar constants - for k, v in self.mv_constants.items(): - locals_dict[k] = v.value - for k, v in self.scalar_constants.items(): - locals_dict[k] = v + locals_dict.update(self.mv_constants) + locals_dict.update(self.scalar_constants) # TODO: Work out a better way to deal with changes to globals globs = {} - for k, v in globals().items(): - globs[k] = v - for k, v in locals_dict.items(): - globs[k] = v + globs.update(globals) + globs.update(locals_dict) # Compile the function co = compile(tree, '', "exec") From 2315f3f04e882ccb81d4aafb8c28148d4870d8c0 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Tue, 9 Jun 2020 10:47:43 +0100 Subject: [PATCH 25/27] Fix up review comments --- clifford/jit_func.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clifford/jit_func.py b/clifford/jit_func.py index b9a38cf4..f35e72f7 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -65,12 +65,12 @@ def __call__(self, func): 'ga_call': self.layout.overload_call_func} # Add the passed multivector and scalar constants - locals_dict.update(self.mv_constants) + locals_dict.update({k: v.value for k, v in self.mv_constants.items()}) locals_dict.update(self.scalar_constants) # TODO: Work out a better way to deal with changes to globals globs = {} - globs.update(globals) + globs.update(globals()) globs.update(locals_dict) # Compile the function From 87a41b92d49f996a34c672d0c9c7937651a90e04 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Tue, 9 Jun 2020 11:03:55 +0100 Subject: [PATCH 26/27] Moved jit_impls into jit_func --- clifford/_layout.py | 279 ------------------------------------------ clifford/jit_func.py | 280 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 272 insertions(+), 287 deletions(-) diff --git a/clifford/_layout.py b/clifford/_layout.py index 246cd04b..421c06a8 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -5,9 +5,6 @@ import numpy as np import sparse -from numba.extending import overload -from numba import types - # TODO: move some of these functions to this file if they're not useful anywhere # else @@ -179,241 +176,6 @@ def construct_graded_mt( return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims)) -def get_as_ga_vector_func(layout): - """ - Returns a function that converts a scalar into a GA value vector - for the given algebra - """ - scalar_index = layout._basis_blade_order.bitmap_to_index[0] - ndims = layout.gaDims - @_numba_utils.njit - def as_ga_value_vector(x): - op = np.zeros(ndims) - op[scalar_index] = x - return op - return as_ga_value_vector - - -def get_overload_add(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - scalar_index = layout._basis_blade_order.bitmap_to_index[0] - - def ga_add(a, b): - # dummy function to overload - pass - - @overload(ga_add, inline='always') - def ol_ga_add(a, b): - if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): - def impl(a, b): - op = b.astype(np.float32) - op[scalar_index] += a - return op - return impl - elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): - def impl(a, b): - op = a.astype(np.float32) - op[scalar_index] += b - return op - return impl - else: - def impl(a, b): - return a + b - return impl - - return ga_add - - -def get_overload_sub(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - scalar_index = layout._basis_blade_order.bitmap_to_index[0] - - def ga_sub(a, b): - # dummy function to overload - pass - - @overload(ga_sub, inline='always') - def ol_ga_sub(a, b): - if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): - def impl(a, b): - op = -b.astype(np.float32) - op[scalar_index] += a - return op - return impl - elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): - def impl(a, b): - op = a.astype(np.float32) - op[scalar_index] -= b - return op - return impl - else: - def impl(a, b): - return a - b - return impl - - return ga_sub - - -def get_overload_mul(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - def ga_mul(a, b): - # dummy function to overload - pass - - gmt_func = layout.gmt_func - @overload(ga_mul, inline='always') - def ol_ga_mul(a, b): - if isinstance(a, types.Array) and isinstance(b, types.Array): - def impl(a, b): - return gmt_func(a, b) - return impl - else: - def impl(a, b): - return a*b - return impl - - return ga_mul - - -def get_overload_xor(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - def ga_xor(a, b): - # dummy function to overload - pass - - as_ga = layout.as_ga_value_vector_func - omt_func = layout.omt_func - @overload(ga_xor, inline='always') - def ol_ga_xor(a, b): - if isinstance(a, types.Array) and isinstance(b, types.Array): - def impl(a, b): - return omt_func(a, b) - return impl - elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): - def impl(a, b): - return omt_func(a, as_ga(b)) - return impl - elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): - def impl(a, b): - return omt_func(as_ga(a), b) - return impl - else: - def impl(a, b): - return a^b - return impl - - return ga_xor - - -def get_overload_or(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - def ga_or(a, b): - # dummy function to overload - pass - - as_ga = layout.as_ga_value_vector_func - imt_func = layout.imt_func - @overload(ga_or, inline='always') - def ol_ga_or(a, b): - if isinstance(a, types.Array) and isinstance(b, types.Array): - def impl(a, b): - return imt_func(a, b) - return impl - elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): - def impl(a, b): - return imt_func(a, as_ga(b)) - return impl - elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): - def impl(a, b): - return imt_func(as_ga(a), b) - return impl - else: - def impl(a, b): - return a|b - return impl - - return ga_or - - -def get_overload_reverse(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - def ga_rev(x): - # dummy function to overload - pass - - adjoint_func = layout.adjoint_func - @overload(ga_rev, inline='always') - def ol_ga_rev(x): - if isinstance(x, types.Array): - def impl(x): - return adjoint_func(x) - return impl - else: - def impl(x): - return ~x - return impl - - return ga_rev - - -def get_project_to_grade_func(layout): - """ - Returns a function that projects a multivector to a given grade - """ - gradeList = np.array(layout.gradeList, dtype=int) - ndims = layout.gaDims - @_numba_utils.njit - def project_to_grade(A, g): - op = np.zeros(ndims) - for i in range(ndims): - if gradeList[i] == g: - op[i] = A[i] - return op - return project_to_grade - - -def get_overload_call(layout): - """ - Returns an overloaded JITed function that works on - MultiVector value arrays - """ - def ga_call(a, b): - # dummy function to overload - pass - - project_to_grade = layout.project_to_grade_func - @overload(ga_call, inline='always') - def ol_ga_call(a, b): - if isinstance(a, types.Array) and isinstance(b, types.Integer): - def impl(a, b): - return project_to_grade(a, b) - return impl - else: - def impl(a, b): - return a(b) - return impl - - return ga_call - - class Layout(object): r""" Layout stores information regarding the geometric algebra itself and the internal representation of multivectors. @@ -611,11 +373,6 @@ def __init__(self, *args, **kw): self.dual_func self.vee_func self.inv_func - self.overload_mul_func - self.overload_xor_func - self.overload_or_func - self.overload_add_func - self.overload_sub_func @_cached_property def gmt(self): @@ -816,10 +573,6 @@ def comp_func(Xval): return Yval return comp_func - @_cached_property - def as_ga_value_vector_func(self): - return get_as_ga_vector_func(self) - @_cached_property def gmt_func(self): return get_mult_function(self.gmt, self.gradeList) @@ -844,38 +597,6 @@ def left_complement_func(self): def right_complement_func(self): return self._gen_complement_func(omt=self.omt.T) - @_cached_property - def overload_mul_func(self): - return get_overload_mul(self) - - @_cached_property - def overload_xor_func(self): - return get_overload_xor(self) - - @_cached_property - def overload_or_func(self): - return get_overload_or(self) - - @_cached_property - def overload_add_func(self): - return get_overload_add(self) - - @_cached_property - def overload_sub_func(self): - return get_overload_sub(self) - - @_cached_property - def overload_reverse_func(self): - return get_overload_reverse(self) - - @_cached_property - def project_to_grade_func(self): - return get_project_to_grade_func(self) - - @_cached_property - def overload_call_func(self): - return get_overload_call(self) - @_cached_property def adjoint_func(self): ''' diff --git a/clifford/jit_func.py b/clifford/jit_func.py index f35e72f7..c65bb045 100644 --- a/clifford/jit_func.py +++ b/clifford/jit_func.py @@ -8,10 +8,278 @@ import inspect import warnings import textwrap +import weakref +import functools +import numpy as np +from numba.extending import overload +from numba import types from ._numba_utils import njit from ._ast_transformer import DecoratorRemover, GATransformer +def get_as_ga_vector_func(layout): + """ + Returns a function that converts a scalar into a GA value vector + for the given algebra + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + ndims = layout.gaDims + @njit + def as_ga_value_vector(x): + op = np.zeros(ndims) + op[scalar_index] = x + return op + return as_ga_value_vector + + +def get_overload_add_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + + def ga_add(a, b): + # dummy function to overload + pass + + @overload(ga_add, inline='always') + def ol_ga_add(a, b): + if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): + def impl(a, b): + op = b.astype(np.float32) + op[scalar_index] += a + return op + return impl + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): + def impl(a, b): + op = a.astype(np.float32) + op[scalar_index] += b + return op + return impl + else: + def impl(a, b): + return a + b + return impl + + return ga_add + + +def get_overload_sub_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + scalar_index = layout._basis_blade_order.bitmap_to_index[0] + + def ga_sub(a, b): + # dummy function to overload + pass + + @overload(ga_sub, inline='always') + def ol_ga_sub(a, b): + if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): + def impl(a, b): + op = -b.astype(np.float32) + op[scalar_index] += a + return op + return impl + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): + def impl(a, b): + op = a.astype(np.float32) + op[scalar_index] -= b + return op + return impl + else: + def impl(a, b): + return a - b + return impl + + return ga_sub + + +def get_overload_mul_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_mul(a, b): + # dummy function to overload + pass + + gmt_func = layout.gmt_func + @overload(ga_mul, inline='always') + def ol_ga_mul(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return gmt_func(a, b) + return impl + else: + def impl(a, b): + return a*b + return impl + + return ga_mul + + +def get_overload_xor_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_xor(a, b): + # dummy function to overload + pass + + as_ga = get_as_ga_vector_func(layout) + omt_func = layout.omt_func + @overload(ga_xor, inline='always') + def ol_ga_xor(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return omt_func(a, b) + return impl + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): + def impl(a, b): + return omt_func(a, as_ga(b)) + return impl + elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): + def impl(a, b): + return omt_func(as_ga(a), b) + return impl + else: + def impl(a, b): + return a^b + return impl + + return ga_xor + + +def get_overload_or_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_or(a, b): + # dummy function to overload + pass + + as_ga = get_as_ga_vector_func(layout) + imt_func = layout.imt_func + @overload(ga_or, inline='always') + def ol_ga_or(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Array): + def impl(a, b): + return imt_func(a, b) + return impl + elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): + def impl(a, b): + return imt_func(a, as_ga(b)) + return impl + elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): + def impl(a, b): + return imt_func(as_ga(a), b) + return impl + else: + def impl(a, b): + return a|b + return impl + + return ga_or + + +def get_overload_reverse_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_rev(x): + # dummy function to overload + pass + + adjoint_func = layout.adjoint_func + @overload(ga_rev, inline='always') + def ol_ga_rev(x): + if isinstance(x, types.Array): + def impl(x): + return adjoint_func(x) + return impl + else: + def impl(x): + return ~x + return impl + + return ga_rev + + +def get_project_to_grade_func(layout): + """ + Returns a function that projects a multivector to a given grade + """ + gradeList = np.array(layout.gradeList, dtype=int) + ndims = layout.gaDims + @njit + def project_to_grade(A, g): + op = np.zeros(ndims) + for i in range(ndims): + if gradeList[i] == g: + op[i] = A[i] + return op + return project_to_grade + + +def get_overload_call_func(layout): + """ + Returns an overloaded JITed function that works on + MultiVector value arrays + """ + def ga_call(a, b): + # dummy function to overload + pass + + project_to_grade = get_project_to_grade_func(layout) + @overload(ga_call, inline='always') + def ol_ga_call(a, b): + if isinstance(a, types.Array) and isinstance(b, types.Integer): + def impl(a, b): + return project_to_grade(a, b) + return impl + else: + def impl(a, b): + return a(b) + return impl + + return ga_call + + +def weak_cache(f): + _cache = weakref.WeakKeyDictionary() + @functools.wraps(f) + def wrapped(*args, **kwargs): + a, *args = args + try: + return _cache[a] + except KeyError: + ret =_cache[a] = f(a, *args, **kwargs) + return ret + wrapped._cache = _cache + return wrapped + + +@weak_cache +def _get_jit_impls(layout): + return { + 'as_ga': get_as_ga_vector_func(layout), + 'ga_add': get_overload_add_func(layout), + 'ga_sub': get_overload_sub_func(layout), + 'ga_mul': get_overload_mul_func(layout), + 'ga_xor': get_overload_xor_func(layout), + 'ga_or': get_overload_or_func(layout), + 'ga_rev': get_overload_reverse_func(layout), + 'ga_call': get_overload_call_func(layout), + } + + class jit_func(object): """ This is a JIT decorator that re-writes the AST and then numba JITs @@ -21,6 +289,7 @@ def __init__(self, layout, ast_debug=False, mv_constants={}, scalar_constants={} self.layout = layout self.mv_constants = mv_constants self.scalar_constants = scalar_constants + self.jit_impls = _get_jit_impls(layout) if AST_PRETTY_AVAILABLE: self.ast_debug = ast_debug else: @@ -44,6 +313,7 @@ def __call__(self, func): astpretty.pprint(tree) # Remove the decorators from the function + # TODO: Work out how to remove only the jit_func decorator tree = DecoratorRemover().visit(tree) # Re-write the ast tree = GATransformer().visit(tree) @@ -55,14 +325,8 @@ def __call__(self, func): astpretty.pprint(tree) # Set things up into locals and globals so that they JIT ok... - locals_dict = {'as_ga': self.layout.as_ga_value_vector_func, - 'ga_add': self.layout.overload_add_func, - 'ga_sub': self.layout.overload_sub_func, - 'ga_mul': self.layout.overload_mul_func, - 'ga_xor': self.layout.overload_xor_func, - 'ga_or': self.layout.overload_or_func, - 'ga_rev': self.layout.overload_reverse_func, - 'ga_call': self.layout.overload_call_func} + locals_dict = {} + locals_dict.update(self.jit_impls) # Add the passed multivector and scalar constants locals_dict.update({k: v.value for k, v in self.mv_constants.items()}) From ccf5551687a8f64e3e9ac139cd19f088d23bcae7 Mon Sep 17 00:00:00 2001 From: hugo hadfield Date: Tue, 9 Jun 2020 11:47:03 +0100 Subject: [PATCH 27/27] Moved jit_func into an experimental directory --- clifford/experimental/__init__.py | 0 clifford/{ => experimental}/jit_func.py | 4 ++-- clifford/test/test_jit_func.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 clifford/experimental/__init__.py rename clifford/{ => experimental}/jit_func.py (99%) diff --git a/clifford/experimental/__init__.py b/clifford/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clifford/jit_func.py b/clifford/experimental/jit_func.py similarity index 99% rename from clifford/jit_func.py rename to clifford/experimental/jit_func.py index c65bb045..beed5fab 100644 --- a/clifford/jit_func.py +++ b/clifford/experimental/jit_func.py @@ -13,8 +13,8 @@ import numpy as np from numba.extending import overload from numba import types -from ._numba_utils import njit -from ._ast_transformer import DecoratorRemover, GATransformer +from .._numba_utils import njit +from .._ast_transformer import DecoratorRemover, GATransformer def get_as_ga_vector_func(layout): diff --git a/clifford/test/test_jit_func.py b/clifford/test/test_jit_func.py index 651e7f7c..80dfbb65 100644 --- a/clifford/test/test_jit_func.py +++ b/clifford/test/test_jit_func.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from ..jit_func import jit_func +from clifford.experimental.jit_func import jit_func class TestJITFunc: