Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for missing return statement #592

Merged
merged 3 commits into from Dec 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/parser/exceptions/test_invalid_payable.py
Expand Up @@ -27,6 +27,7 @@ def test_variable_decleration_exception(bad_code):
@payable
def foo() -> num:
self.x = 5
return self.x
""",
"""
@public
Expand Down
Expand Up @@ -495,7 +495,7 @@ def bar() -> num: pass
bar_contract: public(Bar)

@public
def foo(contract_address: contract(Bar)) -> num:
def foo(contract_address: contract(Bar)):
self.bar_contract = contract_address

@public
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/syntax/test_invalids.py
Expand Up @@ -154,14 +154,14 @@ def foo():
must_succeed("""
x: num
@public
def foo() -> num:
def foo():
self.x = 5
""")

must_succeed("""
x: num
@private
def foo() -> num:
def foo():
self.x = 5
""")

Expand Down
14 changes: 7 additions & 7 deletions tests/parser/syntax/test_list.py
Expand Up @@ -32,21 +32,21 @@ def foo() -> num[2]:
y: num[3]

@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x[0]
""",
"""
y: num[3]

@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y[0] = x
""",
"""
y: num[4]

@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
Expand Down Expand Up @@ -168,28 +168,28 @@ def foo(x: num[3]) -> num:
y: num[3]

@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
y: decimal[3]

@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
y: decimal[2][2]

@public
def foo(x: num[2][2]) -> num:
def foo(x: num[2][2]):
self.y = x
""",
"""
y: decimal[2]

@public
def foo(x: num[2][2]) -> num:
def foo(x: num[2][2]):
self.y = x[1]
""",
"""
Expand Down
40 changes: 40 additions & 0 deletions tests/parser/syntax/test_missing_return.py
@@ -0,0 +1,40 @@
import pytest
from pytest import raises

from viper import compiler
from viper.exceptions import StructureException


fail_list = [
"""
@public
def foo() -> num:
pass
""",
]


@pytest.mark.parametrize('bad_code', fail_list)
def test_missing_return(bad_code):
with raises(StructureException):
compiler.compile(bad_code)



valid_list = [
"""
@public
def foo() -> num:
return 123
""",
"""
@public
def foo() -> num:
if false:
return 123
""", # For the time being this is valid code, even though it should not be.
]

@pytest.mark.parametrize('good_code', valid_list)
def test_return_success(good_code):
assert compiler.compile(good_code) is not None
12 changes: 12 additions & 0 deletions viper/parser/parser.py
Expand Up @@ -275,13 +275,18 @@ def __init__(self, vars=None, globals=None, sigs=None, forvars=None, return_type
self.origcode = origcode
# In Loop status. Whether body is currently evaluating within a for-loop or not.
self.in_for_loop = set()
# Count returns in function
self.function_return_count = 0

def set_in_for_loop(self, name_of_list):
self.in_for_loop.add(name_of_list)

def remove_in_for_loop(self, name_of_list):
self.in_for_loop.remove(name_of_list)

def increment_return_counter(self):
self.function_return_count += 1

# Add a new variable
def new_variable(self, name, typ):
if not is_varname_valid(name):
Expand Down Expand Up @@ -470,6 +475,13 @@ def parse_func(code, _globals, sigs, origcode, _vars=None):
['eq', ['mload', 0], method_id_node],
['seq'] + clampers + [parse_body(c, context) for c in code.body] + ['stop']
], typ=None, pos=getpos(code))

# Check for at leasts one return statement if necessary.
if context.return_type and context.function_return_count == 0:
raise StructureException(
"Missing return statement in function '%s' " % sig.name, code
)

o.context = context
o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
o.func_name = sig.name
Expand Down
1 change: 1 addition & 0 deletions viper/parser/stmt.py
Expand Up @@ -320,6 +320,7 @@ def parse_return(self):
if not self.stmt.value:
raise TypeMismatchException("Expecting to return a value", self.stmt)
sub = Expr(self.stmt.value, self.context).lll_node
self.context.increment_return_counter()
# Returning a value (most common case)
if isinstance(sub.typ, BaseType):
if not isinstance(self.context.return_type, BaseType):
Expand Down