Skip to content

Commit

Permalink
Merge pull request #77 from adityapb/master
Browse files Browse the repository at this point in the history
Fix return visit in JIT and fix unsigned types handling in binop
  • Loading branch information
prabhuramachandran committed Feb 19, 2021
2 parents 4ca3dca + 0aab06b commit 6b3a512
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
18 changes: 11 additions & 7 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,32 @@ def getargspec(f):
return getargspec_f(f)[0]


def get_signed_type(a):
return a[1:] if a.startswith('u') else a


def get_binop_return_type(a, b):
int_types = ['short', 'int', 'long']
float_types = ['float', 'double']

if a is None or b is None:
return None
if a.endswith('p') and b in int_types:

if a.endswith('p') and get_signed_type(b) in int_types:
return a
if b.endswith('p') and a in int_types:
if b.endswith('p') and get_signed_type(a) in int_types:
return b

preference_order = int_types + float_types

unsigned_a = unsigned_b = False
if a.startswith('u'):
unsigned_a = True
a = a[1:]
if b.startswith('u'):
unsigned_b = True
b = b[1:]

idx_a = preference_order.index(a)
idx_b = preference_order.index(b)
return_type = preference_order[idx_a] if idx_a > idx_b else \
Expand Down Expand Up @@ -142,8 +151,6 @@ def get_var_type(self, name):
name, self.undecl_var_types.get(name, 'double'))

def get_return_type(self):
if 'return_' not in self.arg_types:
warnings.warn("Couldn't find valid return type for %s" % self.name)
return self.arg_types.get('return_', 'double')

def annotate(self):
Expand Down Expand Up @@ -291,9 +298,6 @@ def visit_Return(self, node):
if result_type:
self.arg_types['return_'] = result_type
return result_type
self.warn("Unknown type for return value. "
"Return value defaulting to 'double'", node)
self.arg_types['return_'] = 'double'


class ElementwiseJIT(parallel.ElementwiseBase):
Expand Down
47 changes: 47 additions & 0 deletions compyle/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,22 @@ def f(a, b):
# Then
assert helper.arg_types['return_'] == 'intp'

# When
types = {'a': 'int', 'b': 'guintp'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.arg_types['return_'] == 'guintp'

# When
types = {'a': 'uint', 'b': 'guintp'}
helper = AnnotationHelper(f, types)
helper.annotate()

# Then
assert helper.arg_types['return_'] == 'guintp'

def test_cast_return_type(self):
# Given
@annotate
Expand Down Expand Up @@ -531,3 +547,34 @@ def f(a, b):

# Then
assert helper.undecl_var_types['i'] == 'int'

def test_no_return_value(self):
# Given
@annotate
def f_no_return(a, n):
for i in range(n):
a[i] += 1
return

# When
types = {'a': 'guintp', 'n': 'int'}
helper = AnnotationHelper(f_no_return, types)
helper.annotate()

# Then
assert 'return_' not in helper.arg_types

# Given
@annotate
def f_return(a, n):
for i in range(n):
a[i] += 1
return n

# When
helper = AnnotationHelper(f_return, types)
helper.annotate()

# Then
assert 'return_' in helper.arg_types and \
helper.arg_types['return_'] == 'int'

0 comments on commit 6b3a512

Please sign in to comment.