Skip to content

Commit

Permalink
Merge pull request #76 from adityapb/master
Browse files Browse the repository at this point in the history
Reorganize the JIT module
  • Loading branch information
prabhuramachandran committed Jan 28, 2021
2 parents 47d2226 + 1e8ed95 commit 4ca3dca
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 88 deletions.
187 changes: 99 additions & 88 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@ def getargspec(f):


def get_binop_return_type(a, b):
int_types = ['short', 'int', 'long']
float_types = ['float', 'double']
if a is None or b is None:
return None
preference_order = ['short', 'int', 'long', 'float', 'double']
if a.endswith('p') and b in int_types:
return a
if b.endswith('p') and a in int_types:
return b
preference_order = int_types + float_types
unsigned_a = unsigned_b = False
if a.startswith('u'):
unsigned_a = True
Expand All @@ -94,8 +100,10 @@ def get_binop_return_type(a, b):
class AnnotationHelper(ast.NodeVisitor):
def __init__(self, func, arg_types):
self.func = func
self.arg_types = arg_types
self.var_types = arg_types.copy()
self.name = self.func.__name__
self.arg_types = {name: self.get_declare_type(type_str)
for name, type_str in arg_types.items()}
self.var_types = self.arg_types.copy()
self.undecl_var_types = {}
self.external_funcs = {}
self.external_missing_decl = {}
Expand All @@ -106,7 +114,7 @@ def __init__(self, func, arg_types):
'''
)

def get_type(self, type_str):
def get_declare_type(self, type_str):
kind, address_space, ctype, shape = get_declare_info(type_str)
if 'unsigned' in ctype:
ctype = ctype.replace('unsigned ', 'u')
Expand All @@ -122,6 +130,22 @@ def get_missing_declarations(self, undecl_var_types):
missing_decl.update(self.external_missing_decl)
return missing_decl

def record_var_type(self, name, dtype):
self.var_types[name] = self.get_declare_type(dtype)

def record_undecl_var_type(self, name, dtype):
if name not in self.var_types and name not in self.undecl_var_types:
self.undecl_var_types[name] = self.get_declare_type(dtype)

def get_var_type(self, name):
return self.var_types.get(
name, self.undecl_var_types.get(name, 'double'))

def get_return_type(self):
if 'return_' not in self.arg_types:
warnings.warn("Couldn't find valid return type for %s" % self.name)
return self.arg_types.get('return_', 'double')

def annotate(self):
src = dedent('\n'.join(getsourcelines(self.func)[0]))
self._src = src.splitlines()
Expand All @@ -130,6 +154,19 @@ def annotate(self):
self.func = annotate(self.func, **self.arg_types)
return self.get_missing_declarations(self.undecl_var_types)

def recursive_annotate(self, f, node):
arg_types = {}
f_arg_names = getargspec(f)
for f_arg, arg in zip(f_arg_names, node.args):
arg_type = self.visit(arg)
if not arg_type:
arg_type = 'double'
arg_types[f_arg] = arg_type
f_helper = AnnotationHelper(f, arg_types)
self.external_missing_decl.update(f_helper.annotate())
self.external_funcs[node.func.id] = f_helper
return f_helper

def error(self, message, node):
msg = '\nError in code in line %d:\n' % node.lineno
if self._src: # pragma: no branch
Expand All @@ -150,10 +187,29 @@ def warn(self, message, node):
msg += message
warnings.warn(msg)

def visit_declare(self, node):
if not isinstance(node.args[0], ast.Str):
self.error("Argument to declare should be a string.", node)
type_str = node.args[0].s
return self.get_declare_type(type_str)

def visit_cast(self, node):
if not isinstance(node.args[1], ast.Str):
self.error("Cast type should be a string.", node)
return node.args[1].s

def visit_address(self, node):
base_type = self.visit(node.args[0])
if base_type.endswith('p'):
self.error("Cannot find address of a pointer", node)
if isinstance(node.args[0], ast.Subscript):
array_type = self.visit(node.args[0].value)
if array_type.startswith('g'):
base_type = 'g' + base_type
return base_type + 'p'

def visit_For(self, node):
if node.target.id not in self.var_types and \
node.target.id not in self.undecl_var_types:
self.undecl_var_types[node.target.id] = 'int'
self.record_undecl_var_type(node.target.id, 'int')
for stmt in node.body:
self.visit(stmt)

Expand All @@ -164,32 +220,30 @@ def visit_Call(self, node):
# FIXME: External functions have to be at the module level
# for this to work. Pass list of external functions to
# make this work
if node.func.id == 'annotate':
return
mod = importlib.import_module(self.func.__module__)
f = getattr(mod, node.func.id, None)
if node.func.id not in BUILTINS and not hasattr(f, 'is_jit'):
return None
if node.func.id == 'declare':
return self.visit_declare(node)
if node.func.id == 'cast':
return self.visit_cast(node)
if node.func.id == 'atomic_inc':
return self.visit(node.args[0])
if node.func.id == 'address':
return self.visit_address(node)
if node.func.id in self.external_funcs:
return self.external_funcs[node.func.id].arg_types.get(
'return_', None)
if isinstance(node.func, ast.Name) and \
node.func.id not in BUILTINS:
return self.external_funcs[node.func.id].get_return_type()
if isinstance(node.func, ast.Name) and node.func.id not in BUILTINS:
if f is None or isinstance(f, Extern):
return None
self.warn("%s could not be found or is an external function"
"and cannot be handled by JIT" % node.func.id)
return 'double'
else:
arg_types = []
for arg in node.args:
arg_type = self.visit(arg)
if not arg_type:
self.warn(dedent(self.warning_msg), arg)
arg_type = 'double'
arg_types.append(arg_type)
# make a new helper and call visit
f_arg_names = getargspec(f)
f_arg_types = dict(zip(f_arg_names, arg_types))
f_helper = AnnotationHelper(f, f_arg_types)
self.external_missing_decl.update(f_helper.annotate())
self.external_funcs[node.func.id] = f_helper
return f_helper.arg_types.get('return_', None)
f_helper = self.recursive_annotate(f, node)
return f_helper.get_return_type()
self.warn(dedent(self.warning_msg), node.func)
return 'double'

def visit_Subscript(self, node):
base_type = self.visit(node.value)
Expand All @@ -198,49 +252,22 @@ def visit_Subscript(self, node):
return base_type[:-1]

def visit_Name(self, node):
node_type = self.var_types.get(
node.id, self.undecl_var_types.get(node.id, 'double')
)
return node_type
return self.get_var_type(node.id)

def visit_Assign(self, node):
# Only for declare calls
if len(node.targets) != 1:
self.error("Assignments can have only one target.", node)
left, right = node.targets[0], node.value
if isinstance(right, ast.Call) and isinstance(right.func, ast.Name):
if right.func.id == 'declare':
if not isinstance(right.args[0], ast.Str):
self.error("Argument to declare should be a string.", node)
type = right.args[0].s
if isinstance(left, ast.Name):
self.var_types[left.id] = self.get_type(type)
elif isinstance(left, ast.Tuple):
names = [x.id for x in left.elts]
for name in names:
self.var_types[name] = self.get_type(type)
elif right.func.id == 'cast':
if not isinstance(right.args[1], ast.Str):
self.error("Cast type should be a string.", node)
type = right.args[1].s
if isinstance(left, ast.Name):
self.undecl_var_types[left.id] = self.get_type(type)
elif right.func.id == 'atomic_inc':
if left.id not in self.var_types and \
left.id not in self.undecl_var_types:
self.undecl_var_types[left.id] = self.visit(right.args[0])
elif isinstance(left, ast.Name):
if left.id not in self.var_types and \
left.id not in self.undecl_var_types:
self.undecl_var_types[left.id] = self.visit(right)
else:
self.visit(right)
right_type = self.visit(right)
if isinstance(right, ast.Call) and right.func.id == 'declare':
if isinstance(left, ast.Name):
self.record_var_type(left.id, right_type)
elif isinstance(left, ast.Tuple):
names = [x.id for x in left.elts]
for name in names:
self.record_var_type(name, right_type)
elif isinstance(left, ast.Name):
if left.id not in self.var_types and \
left.id not in self.undecl_var_types:
self.undecl_var_types[left.id] = self.visit(right)
else:
self.visit(right)
self.record_undecl_var_type(left.id, right_type)

def visit_Compare(self, node):
return 'int'
Expand All @@ -253,36 +280,20 @@ def visit_BinOp(self, node):
self.visit(node.right))

def visit_Num(self, node):
if isinstance(node.n, float):
return_type = 'double'
else:
if node.n > 2147483648:
return_type = 'long'
else:
return_type = 'int'
return return_type
return get_ctype_from_arg(node.n)

def visit_UnaryOp(self, node):
return self.visit(node.operand)

def visit_Return(self, node):
if isinstance(node.value, ast.Name) or \
isinstance(node.value, ast.Subscript) or \
isinstance(node.value, ast.Num) or \
isinstance(node.value, ast.BinOp) or \
isinstance(node.value, ast.Call) or \
isinstance(node.value, ast.IfExp) or \
isinstance(node.value, ast.UnaryOp):
if node and node.value:
result_type = self.visit(node.value)
if result_type:
self.arg_types['return_'] = result_type
else:
self.arg_types['return_'] = 'double'
else:
if node.value:
self.warn("Unknown type for return value. "
"Return value defaulting to 'double'", node)
self.arg_types['return_'] = 'double'
return result_type
self.warn("Unknown type for return value. "
"Return value defaulting to 'double'", node)
self.arg_types['return_'] = 'double'


class ElementwiseJIT(parallel.ElementwiseBase):
Expand Down Expand Up @@ -528,7 +539,7 @@ def __call__(self, **kwargs):
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
if self._get_backend_key() in self.output_func.arg_keys:
output_arg_keys = self.output_func.arg_keys[
self._get_backend_key()]
self._get_backend_key()]
else:
raise ValueError("No kernel arguments found for backend = %s, "
"use_openmp = %s, use_double = %s" %
Expand Down
93 changes: 93 additions & 0 deletions compyle/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ def g(x):
return x


@annotate(x='int', return_='int')
def g_nonjit(x):
return x


@annotate
def h(a, b):
return g(a) * g(b)
Expand Down Expand Up @@ -68,6 +73,24 @@ def double_f(a):
# Then
assert helper.external_funcs['g'].arg_types['x'] == 'double'

def test_declare_multiple_variables(self):
# Given
@annotate
def f(x):
a, b = declare('int', 2)
a = 0
b = 1
return x + a + b

# When
types = {'x': 'int'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.get_var_type('a') == 'int'
assert helper.get_var_type('b') == 'int'

def test_variable_as_call_arg(self):
# Given
@annotate
Expand All @@ -84,6 +107,22 @@ def f(a, b):
# Then
assert helper.external_funcs['g'].arg_types['x'] == 'int'

def test_variable_as_call_arg_nonjit(self):
# Given
@annotate
def f(a, b):
x = declare('int')
x = a + b
return g_nonjit(x)

# When
types = {'a': 'int', 'b': 'int'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.external_funcs['g_nonjit'].arg_types['x'] == 'int'

def test_subscript_as_call_arg(self):
# Given
@annotate
Expand Down Expand Up @@ -375,6 +414,60 @@ def f(a, b):
# Then
assert helper.arg_types['return_'] == 'ulong'

# When
types = {'a': 'intp', 'b': 'int'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.arg_types['return_'] == 'intp'

# When
types = {'a': 'gdoublep', 'b': 'int'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.arg_types['return_'] == 'gdoublep'

# When
types = {'a': 'int', 'b': 'intp'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.arg_types['return_'] == 'intp'

def test_cast_return_type(self):
# Given
@annotate
def f(a):
return cast(a, "int")

# When
types = {'a': 'double'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.get_return_type() == 'int'

def test_address_type(self):
# Given
@annotate
def f(a):
b = address(a[0])
return b[0]

# When
types = {'a': 'gintp'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.get_var_type('b') == 'gintp'
assert helper.get_return_type() == 'int'

def test_undeclared_variable_declaration(self):
# Given
@annotate
Expand Down

0 comments on commit 4ca3dca

Please sign in to comment.