-
Notifications
You must be signed in to change notification settings - Fork 71
/
_ast_transformer.py
44 lines (40 loc) · 1.19 KB
/
_ast_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import ast
class GATransformer(ast.NodeTransformer):
    """
    This is an AST transformer that converts operations into
    JITable counterparts that work on MultiVector value arrays.
    We crawl the AST and convert BinOps and UnaryOps into numba
    overloaded functions.

    Supported binary operators ``*``, ``^``, ``|``, ``+`` and ``-`` become
    calls ``ga_mul``, ``ga_xor``, ``ga_or``, ``ga_add`` and ``ga_sub``;
    the unary ``~`` becomes a call to ``ga_rev``. The replacement names are
    emitted as plain loads, so they must be resolvable in the namespace where
    the rewritten code is eventually executed.

    Any other operator node is kept as-is, but its operands are still
    transformed, so ``(a * b) / c`` becomes ``ga_mul(a, b) / c``.
    Synthesized nodes inherit the source location of the node they replace,
    so the rewritten tree can be passed to ``compile`` directly.
    """

    # Binary operator node type -> name of the replacement function.
    _BINOP_FUNCS = {
        ast.Mult: 'ga_mul',
        ast.BitXor: 'ga_xor',
        ast.BitOr: 'ga_or',
        ast.Add: 'ga_add',
        ast.Sub: 'ga_sub',
    }

    # Unary operator node type -> name of the replacement function.
    _UNARYOP_FUNCS = {
        ast.Invert: 'ga_rev',
    }

    @staticmethod
    def _make_call(func_name, args, template):
        """Build ``func_name(*args)`` carrying *template*'s source location.

        :param func_name: identifier of the function to call.
        :param args: already-transformed argument expression nodes.
        :param template: the node being replaced; its lineno/col_offset
            (and end positions, when present) are copied onto the new nodes
            so ``compile`` does not reject them for missing locations.
        :returns: a new ``ast.Call`` node.
        """
        func = ast.copy_location(ast.Name(id=func_name, ctx=ast.Load()), template)
        call = ast.Call(func=func, args=args, keywords=[])
        return ast.copy_location(call, template)

    def visit_BinOp(self, node):
        """Replace a supported binary operation with a ``ga_*`` call.

        :param node: the ``ast.BinOp`` being visited.
        :returns: an ``ast.Call`` for supported operators; otherwise the
            original node with its operands recursively transformed.
        """
        func_name = self._BINOP_FUNCS.get(type(node.op))
        if func_name is None:
            # Unsupported operator: keep this node, but still rewrite any
            # supported operations nested inside its operands.
            return self.generic_visit(node)
        # Left is visited before right, matching source evaluation order.
        args = [self.visit(node.left), self.visit(node.right)]
        return self._make_call(func_name, args, node)

    def visit_UnaryOp(self, node):
        """Replace a supported unary operation with a ``ga_*`` call.

        :param node: the ``ast.UnaryOp`` being visited.
        :returns: an ``ast.Call`` for supported operators; otherwise the
            original node with its operand recursively transformed.
        """
        func_name = self._UNARYOP_FUNCS.get(type(node.op))
        if func_name is None:
            # e.g. ``-(a | b)``: keep the negation, rewrite the inner ``|``.
            return self.generic_visit(node)
        return self._make_call(func_name, [self.visit(node.operand)], node)