Skip to content

Commit

Permalink
Merge pull request #2 from adityapb/auto-decl
Browse files Browse the repository at this point in the history
Use AnnotationHelper to declare undeclared variables
  • Loading branch information
prabhuramachandran committed Jan 18, 2019
2 parents c54fc2b + cef2961 commit 12ccf58
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 71 deletions.
36 changes: 22 additions & 14 deletions compyle/cython_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,19 @@ def detect_type(self, name, value):
def get_code(self):
return self.code

def parse(self, obj):
def parse(self, obj, declarations=None):
obj_type = type(obj)
if isinstance(obj, types.FunctionType):
self._parse_function(obj)
self._parse_function(obj, declarations=declarations)
elif hasattr(obj, '__class__'):
self._parse_instance(obj)
else:
raise TypeError('Unsupported type to wrap: %s' % obj_type)

def get_func_signature(self, func):
"""Given a function that is wrapped, return the Python wrapper definition
signature and the Python call signature and the C wrapper definition
and C call signature.
"""Given a function that is wrapped, return the Python wrapper
definition signature and the Python call signature and the C
wrapper definition and C call signature.
For example if we had
Expand Down Expand Up @@ -278,7 +278,7 @@ def _analyze_method(self, meth, lines):

# The call_args dict is filled up with the defaults to detect
# the appropriate type of the arguments.
for i in range(1, len(defaults)+1):
for i in range(1, len(defaults) + 1):
call_args[args[-i]] = defaults[-i]

# Set the rest to Undefined
Expand Down Expand Up @@ -333,20 +333,26 @@ def _get_methods(self, cls):
if name in self.ignore_methods:
continue

c_code, py_code = self._get_method_wrapper(meth, indent=' '*8)
c_code, py_code = self._get_method_wrapper(
meth, indent=' ' * 8)
methods.append(c_code)
if self.python_methods:
methods.append(py_code)

return methods

def _get_method_body(self, meth, lines, indent=' '*8):
def _get_method_body(self, meth, lines, indent=' ' * 8, declarations=None):
getfullargspec = getattr(
inspect, 'getfullargspec', inspect.getargspec
)
args = set(getfullargspec(meth).args)
src = [self._process_body_line(line) for line in lines]
declared = []
if declarations:
cy_decls = []
for var, decl in declarations.items():
cy_decls.append((var, indent + 'cdef %s\n' % decl[:-1]))
src = cy_decls + src
declared = [] if not declarations else declarations.keys()
for names, defn in src:
if names:
declared.extend(x.strip() for x in names.split(','))
Expand All @@ -359,12 +365,13 @@ def _get_method_body(self, meth, lines, indent=' '*8):
code = ''.join(declare) + cython_body
return code

def _get_method_wrapper(self, meth, indent=' '*8):
def _get_method_wrapper(self, meth, indent=' ' * 8, declarations=None):
sourcelines = inspect.getsourcelines(meth)[0]
defn, lines = get_func_definition(sourcelines)
m_name, returns, args = self._analyze_method(meth, lines)
c_defn = self._get_c_method_spec(m_name, returns, args)
c_body = self._get_method_body(meth, lines, indent=indent)
c_body = self._get_method_body(meth, lines, indent=indent,
declarations=declarations)
self.code = '{defn}\n{body}'.format(defn=c_defn, body=c_body)
if self.python_methods:
defn, body = self._get_py_method_spec(m_name, returns, args,
Expand All @@ -380,7 +387,7 @@ def _get_public_vars(self, obj):
for name in sorted(data.keys()))
return vars

def _get_py_method_spec(self, name, returns, args, indent=' '*8):
def _get_py_method_spec(self, name, returns, args, indent=' ' * 8):
"""Returns a Python friendly definition for the method along with the
wrapper function.
"""
Expand Down Expand Up @@ -438,8 +445,9 @@ def matrix(size):
defn = 'cdef {type} {name}'.format(type=ctype, name=name)
return defn

def _parse_function(self, obj):
c_code, py_code = self._get_method_wrapper(obj, indent=' '*4)
def _parse_function(self, obj, declarations=None):
c_code, py_code = self._get_method_wrapper(obj, indent=' ' * 4,
declarations=declarations)
code = '{defn}\n{body}'.format(defn=c_code[0], body=c_code[1])
if self.python_methods:
code += '\n'
Expand Down
57 changes: 43 additions & 14 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __init__(self, func, arg_types):
self.func = func
self.arg_types = arg_types
self.var_types = arg_types.copy()
self.undecl_var_types = {}
self.external_funcs = {}
self.external_missing_decl = {}
self.warning_msg = ('''
Function called is not marked by the annotate decorator. Argument
type defaulting to 'double'. If the type is not 'double', store
Expand All @@ -88,12 +90,21 @@ def get_type(self, type_str):
ctype = '%sp' % ctype
return ctype

def get_missing_declarations(self, undecl_var_types):
declarations = {}
for var_name, dtype in undecl_var_types.items():
declarations[var_name] = '%s %s;' % (dtype, var_name)
missing_decl = {self.func.__name__: declarations}
missing_decl.update(self.external_missing_decl)
return missing_decl

def annotate(self):
src = dedent('\n'.join(inspect.getsourcelines(self.func)[0]))
self._src = src.splitlines()
code = ast.parse(src)
self.visit(code)
self.func = annotate(self.func, **self.arg_types)
return self.get_missing_declarations(self.undecl_var_types)

def error(self, message, node):
msg = '\nError in code in line %d:\n' % node.lineno
Expand All @@ -115,6 +126,16 @@ def warn(self, message, node):
msg += message
warnings.warn(msg)

def visit_For(self, node):
if node.target.id not in self.var_types and \
node.target.id not in self.undecl_var_types:
self.undecl_var_types[node.target.id] = 'int'
for stmt in node.body:
self.visit(stmt)

def visit_IfExp(self, node):
return self.visit(node.body)

def visit_Call(self, node):
# FIXME: External functions have to be at the module level
# for this to work. Pass list of external functions to
Expand Down Expand Up @@ -142,7 +163,7 @@ def visit_Call(self, node):
f_arg_names = getargspec(f)
f_arg_types = dict(zip(f_arg_names, arg_types))
f_helper = AnnotationHelper(f, f_arg_types)
f_helper.annotate()
self.external_missing_decl.update(f_helper.annotate())
self.external_funcs[node.func.id] = f_helper
return f_helper.arg_types.get('return_', None)

Expand All @@ -153,7 +174,9 @@ def visit_Subscript(self, node):
return base_type[:-1]

def visit_Name(self, node):
node_type = self.var_types.get(node.id, 'double')
node_type = self.var_types.get(
node.id, self.undecl_var_types.get(node.id, 'double')
)
return node_type

def visit_Assign(self, node):
Expand All @@ -172,6 +195,11 @@ def visit_Assign(self, node):
names = [x.id for x in left.elts]
for name in names:
self.var_types[name] = self.get_type(type)
else:
if isinstance(left, ast.Name) and \
left.id not in self.var_types and \
left.id not in self.undecl_var_types:
self.undecl_var_types[left.id] = self.visit(right)

def visit_Compare(self, node):
return 'int'
Expand All @@ -196,10 +224,10 @@ def visit_Num(self, node):
def visit_Return(self, node):
if isinstance(node.value, ast.Name) or \
isinstance(node.value, ast.Subscript) or \
isinstance(node.value, ast.Num):
self.arg_types['return_'] = self.visit(node.value)
elif isinstance(node.value, ast.BinOp) or \
isinstance(node.value, ast.Call):
isinstance(node.value, ast.Num) or \
isinstance(node.value, ast.BinOp) or \
isinstance(node.value, ast.Call) or \
isinstance(node.value, ast.IfExp):
result_type = self.visit(node.value)
if result_type:
self.arg_types['return_'] = self.visit(node.value)
Expand Down Expand Up @@ -241,9 +269,9 @@ def _generate_kernel(self, *args):
if self.func is not None:
arg_types = self.get_type_info_from_args(*args)
helper = AnnotationHelper(self.func, arg_types)
helper.annotate()
declarations = helper.annotate()
self.func = helper.func
return self._generate()
return self._generate(declarations=declarations)

def _massage_arg(self, x):
if isinstance(x, array.Array):
Expand Down Expand Up @@ -314,9 +342,9 @@ def _generate_kernel(self, *args):
if self.func is not None:
arg_types = self.get_type_info_from_args(*args)
helper = AnnotationHelper(self.func, arg_types)
helper.annotate()
declarations = helper.annotate()
self.func = helper.func
return self._generate()
return self._generate(declarations=declarations)

def _massage_arg(self, x):
if isinstance(x, array.Array):
Expand Down Expand Up @@ -398,30 +426,31 @@ def get_type_info_from_kwargs(self, func, **kwargs):

@memoize(key=kernel_cache_key_kwargs, use_kwargs=True)
def _generate_kernel(self, **kwargs):
declarations = {}
if self.input_func is not None:
arg_types = self.get_type_info_from_kwargs(
self.input_func, **kwargs)
arg_types['return_'] = dtype_to_knowntype(self.dtype)
helper = AnnotationHelper(self.input_func, arg_types)
helper.annotate()
declarations.update(helper.annotate())
self.input_func = helper.func

if self.output_func is not None:
arg_types = self.get_type_info_from_kwargs(
self.output_func, **kwargs)
helper = AnnotationHelper(self.output_func, arg_types)
helper.annotate()
declarations.update(helper.annotate())
self.output_func = helper.func

if self.is_segment_func is not None:
arg_types = self.get_type_info_from_kwargs(
self.is_segment_func, **kwargs)
arg_types['return_'] = 'int'
helper = AnnotationHelper(self.is_segment_func, arg_types)
helper.annotate()
declarations.update(helper.annotate())
self.is_segment_func = helper.func

return self._generate()
return self._generate(declarations=declarations)

def _massage_arg(self, x):
if isinstance(x, array.Array):
Expand Down
49 changes: 28 additions & 21 deletions compyle/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,8 @@ def __init__(self, func, backend='cython'):
self.queue = None
self.c_func = self._generate()

def _generate(self):
self.tp.add(self.func)
def _generate(self, declarations=None):
self.tp.add(self.func, declarations=declarations)
if self.backend == 'cython':
py_data, c_data = self.cython_gen.get_func_signature(self.func)
py_defn = ['long SIZE'] + py_data[0][1:]
Expand Down Expand Up @@ -555,10 +555,10 @@ def __init__(self, reduce_expr, map_func=None, dtype_out=np.float64,
self.queue = None
self.c_func = self._generate()

def _generate(self):
def _generate(self, declarations=None):
if self.backend == 'cython':
if self.func is not None:
self.tp.add(self.func)
self.tp.add(self.func, declarations=declarations)
py_data, c_data = self.cython_gen.get_func_signature(self.func)
self._correct_return_type(c_data)
name = self.func.__name__
Expand Down Expand Up @@ -591,7 +591,7 @@ def _generate(self):
return getattr(self.tp.mod, 'py_' + self.name)
elif self.backend == 'opencl':
if self.func is not None:
self.tp.add(self.func)
self.tp.add(self.func, declarations=declarations)
py_data, c_data = self.cython_gen.get_func_signature(self.func)
self._correct_opencl_address_space(c_data)
name = self.func.__name__
Expand Down Expand Up @@ -629,7 +629,7 @@ def _generate(self):
return knl
elif self.backend == 'cuda':
if self.func is not None:
self.tp.add(self.func)
self.tp.add(self.func, declarations=declarations)
py_data, c_data = self.cython_gen.get_func_signature(self.func)
self._correct_opencl_address_space(c_data)
name = self.func.__name__
Expand Down Expand Up @@ -816,13 +816,13 @@ def _num_ignore_args(self, c_data):
break
return result

def _generate(self):
def _generate(self, declarations=None):
if self.backend == 'opencl':
return self._generate_opencl_kernel()
return self._generate_opencl_kernel(declarations=declarations)
elif self.backend == 'cuda':
return self._generate_cuda_kernel()
return self._generate_cuda_kernel(declarations=declarations)
elif self.backend == 'cython':
return self._generate_cython_code()
return self._generate_cython_code(declarations=declarations)

def _default_cython_input_function(self):
py_data = (['int i', '{type}[:] input'.format(type=self.type)],
Expand Down Expand Up @@ -859,7 +859,7 @@ def _append_cython_arg_data(self, all_py_data, all_c_data, py_data,
all_c_data[0].extend(c_data[0][n_ignore:])
all_c_data[1].extend(c_data[1][n_ignore:])

def _generate_cython_code(self):
def _generate_cython_code(self, declarations=None):
name = self.name
all_py_data = [[], []]
all_c_data = [[], []]
Expand Down Expand Up @@ -926,9 +926,9 @@ def _generate_cython_code(self):
self.tp.compile()
return getattr(self.tp.mod, 'py_' + self.name)

def _wrap_ocl_function(self, func, func_type=None):
def _wrap_ocl_function(self, func, func_type=None, declarations=None):
if func is not None:
self.tp.add(func)
self.tp.add(func, declarations=declarations)
py_data, c_data = self.cython_gen.get_func_signature(func)
self._correct_opencl_address_space(c_data, func, func_type)
name = func.__name__
Expand Down Expand Up @@ -961,15 +961,18 @@ def _get_scan_expr_opencl_cuda(self):
else:
return self.scan_expr

def _get_opencl_cuda_code(self):
def _get_opencl_cuda_code(self, declarations=None):
input_expr, input_args, input_c_args = \
self._wrap_ocl_function(self.input_func, func_type='input')
self._wrap_ocl_function(self.input_func, func_type='input',
declarations=declarations)

output_expr, output_args, output_c_args = \
self._wrap_ocl_function(self.output_func, func_type='output')
self._wrap_ocl_function(self.output_func, func_type='output',
declarations=declarations)

segment_expr, segment_args, segment_c_args = \
self._wrap_ocl_function(self.is_segment_func)
self._wrap_ocl_function(self.is_segment_func,
declarations=declarations)

scan_expr = self._get_scan_expr_opencl_cuda()

Expand All @@ -986,9 +989,11 @@ def _get_opencl_cuda_code(self):
return scan_expr, arg_defn, input_expr, output_expr, \
segment_expr, preamble

def _generate_opencl_kernel(self):
def _generate_opencl_kernel(self, declarations=None):
scan_expr, arg_defn, input_expr, output_expr, \
segment_expr, preamble = self._get_opencl_cuda_code()
segment_expr, preamble = self._get_opencl_cuda_code(
declarations=declarations
)

from .opencl import get_context, get_queue
from pyopencl.scan import GenericScanKernel
Expand All @@ -1007,9 +1012,11 @@ def _generate_opencl_kernel(self):
)
return knl

def _generate_cuda_kernel(self):
def _generate_cuda_kernel(self, declarations=None):
scan_expr, arg_defn, input_expr, output_expr, \
segment_expr, preamble = self._get_opencl_cuda_code()
segment_expr, preamble = self._get_opencl_cuda_code(
declarations=declarations
)

from .cuda import set_context, GenericScanKernel
set_context()
Expand Down

0 comments on commit 12ccf58

Please sign in to comment.