Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental AST rewriter and JIT decorator #326

Open
wants to merge 27 commits into
base: experimental/abc-mangling
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9e3724c
Added numba overloaded functions to layout
hugohadfield Jun 5, 2020
03826c5
Added a GA specific ast transformer
hugohadfield Jun 5, 2020
ef61257
Added a jit_func decorator to ast transform and numba jit
hugohadfield Jun 5, 2020
c13bc94
Corrected jit_func, added a test
hugohadfield Jun 5, 2020
51f8a42
remove duplication in ast_transformer
hugohadfield Jun 5, 2020
8022092
convert to abstract numeric types in the numba jit overload
hugohadfield Jun 5, 2020
f14521b
Improved handling globals, added a TODO
hugohadfield Jun 5, 2020
5fdbb86
Added ast_pretty warning if not installed
hugohadfield Jun 5, 2020
d6c6e06
removed unnescary print
hugohadfield Jun 5, 2020
8094a61
Added reversion to AST rewriter and JIT
hugohadfield Jun 5, 2020
1767342
Added grade selection via the call syntax
hugohadfield Jun 5, 2020
81601ce
Set up pytest benchmark
hugohadfield Jun 6, 2020
d905393
Make node visitation recursive for Call
hugohadfield Jun 6, 2020
750ec85
Add ImportError type for astpretty
hugohadfield Jun 6, 2020
e0263f8
Improve warning whitespace
hugohadfield Jun 6, 2020
e878dbe
Make the Call rewrite exception an AttributeError
hugohadfield Jun 6, 2020
482b091
Moved the decorator removal to the AST level
hugohadfield Jun 6, 2020
ff9648d
Add scalar and multivector constants to decorator arguments
hugohadfield Jun 7, 2020
5d27874
Fix nested function call transformer
hugohadfield Jun 7, 2020
6c2cea6
Improve speed of linear_operator_to_matrix
hugohadfield Jun 7, 2020
307874f
Add testing for new jit decorator features
hugohadfield Jun 7, 2020
c5be87a
Added a nested jitted function test
hugohadfield Jun 8, 2020
8f02960
Fixed flake8 complaints
hugohadfield Jun 8, 2020
8e96d81
Apply suggestions from Eric code review
hugohadfield Jun 9, 2020
2315f3f
Fix up review comments
hugohadfield Jun 9, 2020
87a41b9
Moved jit_impls into jit_func
hugohadfield Jun 9, 2020
ccf5551
Moved jit_func into an experimental directory
hugohadfield Jun 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions clifford/_ast_transformer.py
@@ -0,0 +1,44 @@

import ast


class GATransformer(ast.NodeTransformer):
"""
This is an AST transformer that converts operations into
JITable counterparts that work on MultiVector value arrays.
We crawl the AST and convert BinOps and UnaryOps into numba
overloaded functions.
"""
def visit_BinOp(self, node):
ops = {
ast.Mult: 'ga_mul',
ast.BitXor: 'ga_xor',
ast.BitOr: 'ga_or',
ast.Add: 'ga_add',
ast.Sub: 'ga_sub',
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.left), self.visit(node.right)],
keywords=[]
)

def visit_UnaryOp(self, node):
ops = {
ast.Invert: 'ga_rev'
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.operand)],
keywords=[]
)
232 changes: 232 additions & 0 deletions clifford/_layout.py
Expand Up @@ -5,6 +5,10 @@
import numpy as np
import sparse

from numba.extending import overload
from numba import types


# TODO: move some of these functions to this file if they're not useful anywhere
# else
import clifford as cf
Expand Down Expand Up @@ -175,6 +179,201 @@ def construct_graded_mt(
return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims))


def get_as_ga_vector_func(layout):
"""
Returns a function that converts a scalar into a GA value vector
for the given algebra
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]
ndims = layout.gaDims
@_numba_utils.njit
def as_ga_value_vector(x):
op = np.zeros(ndims)
op[scalar_index] = x
return op
return as_ga_value_vector


def get_overload_add(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]

def ga_add(a, b):
# dummy function to overload
pass

@overload(ga_add, inline='always')
def ol_ga_add(a, b):
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
op = b.astype(np.float32)
op[scalar_index] += a
return op
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
op = a.astype(np.float32)
op[scalar_index] += b
return op
return impl
else:
def impl(a, b):
return a + b
return impl

return ga_add


def get_overload_sub(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]

def ga_sub(a, b):
# dummy function to overload
pass

@overload(ga_sub, inline='always')
def ol_ga_sub(a, b):
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
op = -b.astype(np.float32)
op[scalar_index] += a
return op
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
op = a.astype(np.float32)
op[scalar_index] -= b
return op
return impl
else:
def impl(a, b):
return a - b
return impl

return ga_sub


def get_overload_mul(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_mul(a, b):
# dummy function to overload
pass

gmt_func = layout.gmt_func
@overload(ga_mul, inline='always')
def ol_ga_mul(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return gmt_func(a, b)
return impl
else:
def impl(a, b):
return a*b
return impl

return ga_mul


def get_overload_xor(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_xor(a, b):
# dummy function to overload
pass

as_ga = layout.as_ga_value_vector_func
omt_func = layout.omt_func
@overload(ga_xor, inline='always')
def ol_ga_xor(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return omt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a^b
return impl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tempting to make these shorter:

Suggested change
def impl(a, b):
return omt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return omt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a^b
return impl
return lambda a, b: omt_func(a, b)
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
return lambda a, b: omt_func(a, as_ga(b))
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
return lambda a, b: omt_func(as_ga(a), b)
else:
return lambda a, b: return a^b


return ga_xor


def get_overload_or(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_or(a, b):
# dummy function to overload
pass

as_ga = layout.as_ga_value_vector_func
imt_func = layout.imt_func
@overload(ga_or, inline='always')
def ol_ga_or(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return imt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return imt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return imt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a|b
return impl

return ga_or


def get_overload_reverse(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_rev(x):
# dummy function to overload
pass

adjoint_func = layout.adjoint_func
@overload(ga_rev, inline='always')
def ol_ga_rev(x):
if isinstance(x, types.Array):
def impl(x):
return adjoint_func(x)
return impl
else:
def impl(x):
return ~x
return impl

return ga_rev


class Layout(object):
r""" Layout stores information regarding the geometric algebra itself and the
internal representation of multivectors.
Expand Down Expand Up @@ -372,6 +571,11 @@ def __init__(self, *args, **kw):
self.dual_func
self.vee_func
self.inv_func
self.overload_mul_func
self.overload_xor_func
self.overload_or_func
self.overload_add_func
self.overload_sub_func

@_cached_property
def gmt(self):
Expand Down Expand Up @@ -572,6 +776,10 @@ def comp_func(Xval):
return Yval
return comp_func

@_cached_property
def as_ga_value_vector_func(self):
return get_as_ga_vector_func(self)

@_cached_property
def gmt_func(self):
return get_mult_function(self.gmt, self.gradeList)
Expand All @@ -596,6 +804,30 @@ def left_complement_func(self):
def right_complement_func(self):
return self._gen_complement_func(omt=self.omt.T)

@_cached_property
def overload_mul_func(self):
return get_overload_mul(self)

@_cached_property
def overload_xor_func(self):
return get_overload_xor(self)

@_cached_property
def overload_or_func(self):
return get_overload_or(self)

@_cached_property
def overload_add_func(self):
return get_overload_add(self)

@_cached_property
def overload_sub_func(self):
return get_overload_sub(self)

@_cached_property
def overload_reverse_func(self):
return get_overload_reverse(self)

@_cached_property
def adjoint_func(self):
'''
Expand Down
80 changes: 80 additions & 0 deletions clifford/jit_func.py
@@ -0,0 +1,80 @@

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps clifford/experimental/jit_func

import ast
try:
import astpretty
AST_PRETTY_AVAILABLE = True
except:
eric-wieser marked this conversation as resolved.
Show resolved Hide resolved
AST_PRETTY_AVAILABLE = False
import inspect
import warnings
from ._numba_utils import njit
from ._ast_transformer import GATransformer


class jit_func(object):
"""
This is a JIT decorator that re-writes the AST and then numba JITs
the resulting function.
"""
def __init__(self, layout, ast_debug=False):
self.layout = layout
if AST_PRETTY_AVAILABLE:
self.ast_debug = ast_debug
else:
if ast_debug:
warnings.warn('''
The ast_debug flag is set to True, but the astpretty module is not importable.
To see ast_debug output please pip install astpretty
''')
eric-wieser marked this conversation as resolved.
Show resolved Hide resolved
self.ast_debug = False

def __call__(self, func):
# Get the function source
fname = func.__name__
source = inspect.getsource(func)
# Remove the decorator first line.
source = '\n'.join(source.splitlines()[1:])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to remove this at the ast level

# Remove the indentation
indentation = source.splitlines()[0].find('def')
source = '\n'.join([line[indentation:] for line in source.splitlines()])

# Re-write the ast
tree = ast.parse(source)
if self.ast_debug:
print('\n\n\n\n TRANSFORMING FROM \n\n\n\n')
astpretty.pprint(tree)

tree = GATransformer().visit(tree)
ast.fix_missing_locations(tree)

if self.ast_debug:
print('\n\n\n\n TRANSFORMING TO \n\n\n\n')
astpretty.pprint(tree)

# Set things up into locals and globals so that they JIT ok...
locals_dict = {'as_ga': self.layout.as_ga_value_vector_func,
'ga_add': self.layout.overload_add_func,
'ga_sub': self.layout.overload_sub_func,
'ga_mul': self.layout.overload_mul_func,
'ga_xor': self.layout.overload_xor_func,
'ga_or': self.layout.overload_or_func,
'ga_rev': self.layout.overload_reverse_func}
# TODO: Work out a better way to deal with changes to globals
globs = {}
for k, v in globals().items():
globs[k] = v
for k, v in locals_dict.items():
globs[k] = v
hugohadfield marked this conversation as resolved.
Show resolved Hide resolved

# Compile the function
co = compile(tree, '<ast>', "exec")
exec(co, globs, locals_dict)
new_func = locals_dict[fname]

# JIT the function
jitted_func = njit(new_func)

# Wrap the JITed function
def wrapper(*args, **kwargs):
return self.layout.MultiVector(value=jitted_func(*[a.value for a in args], **kwargs))
return wrapper