Skip to content

Commit

Permalink
Merge pull request #73 from pypr/fix-opencl-bug
Browse files Browse the repository at this point in the history
Fix issue with previous PR.
  • Loading branch information
prabhuramachandran committed Dec 27, 2020
2 parents 2b8cd45 + 5ac7ecc commit d752f60
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
5 changes: 3 additions & 2 deletions compyle/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ def sort_by_keys(ary_list, out_list=None, key_bits=None,
for ary in ary_list
]

arg_types = [get_ctype_from_arg(arg) for arg in ary_list]
arg_types = [get_ctype_from_arg(arg, backend=backend)
for arg in ary_list]

sort_knl = get_cl_sort_kernel(arg_types, ary_list)
allocator = get_allocator(get_queue())
Expand Down Expand Up @@ -643,7 +644,7 @@ def template(self, i, order):

def key_align_kernel(ary_list, order, backend=None):
from .jit import get_ctype_from_arg
key = [get_ctype_from_arg(ary) for ary in ary_list]
key = [get_ctype_from_arg(ary, backend=backend) for ary in ary_list]
key.append(backend)
key.append(get_config().use_openmp)
return tuple(key)
Expand Down
15 changes: 8 additions & 7 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def wrapper(*args):
return memoize_deco


def get_ctype_from_arg(arg):
def get_ctype_from_arg(arg, backend=None):
if isinstance(arg, array.Array):
return arg.gptr_type
elif isinstance(arg, np.ndarray) or isinstance(arg, np.floating):
return dtype_to_ctype(arg.dtype)
return dtype_to_ctype(arg.dtype, backend=backend)
else:
if isinstance(arg, float):
return 'double'
Expand All @@ -50,14 +50,15 @@ def get_ctype_from_arg(arg):


def kernel_cache_key_args(obj, *args):
key = [get_ctype_from_arg(arg) for arg in args]
key = [get_ctype_from_arg(arg, backend=obj.backend) for arg in args]
key.append(obj.func)
key.append(obj.name)
return tuple(key + list(parallel.get_common_cache_key(obj)))


def kernel_cache_key_kwargs(obj, **kwargs):
key = [get_ctype_from_arg(arg) for arg in kwargs.values()]
key = [get_ctype_from_arg(arg, backend=obj.backend)
for arg in kwargs.values()]
key.append(obj.input_func)
key.append(obj.output_func)
key.append(obj.scan_expr)
Expand Down Expand Up @@ -306,7 +307,7 @@ def get_type_info_from_args(self, *args):
arg_names.remove('i')
type_info['i'] = 'int'
for arg, name in zip(args, arg_names):
arg_type = get_ctype_from_arg(arg)
arg_type = get_ctype_from_arg(arg, backend=self.backend)
if not arg_type:
arg_type = 'double'
type_info[name] = arg_type
Expand Down Expand Up @@ -384,7 +385,7 @@ def get_type_info_from_args(self, *args):
arg_names.remove('i')
type_info['i'] = 'int'
for arg, name in zip(args, arg_names):
arg_type = get_ctype_from_arg(arg)
arg_type = get_ctype_from_arg(arg, backend=self.backend)
if not arg_type:
arg_type = 'double'
type_info[name] = arg_type
Expand Down Expand Up @@ -477,7 +478,7 @@ def get_type_info_from_kwargs(self, func, **kwargs):
if name in self.builtin_types:
arg_type = self.builtin_types[name]
else:
arg_type = get_ctype_from_arg(arg)
arg_type = get_ctype_from_arg(arg, backend=self.backend)
if not arg_type:
arg_type = 'double'
type_info[name] = arg_type
Expand Down

0 comments on commit d752f60

Please sign in to comment.