Skip to content

Commit

Permalink
Merge pull request #592 from jacqueswww/590-missing-return-statement
Browse files Browse the repository at this point in the history
Check for missing return statement
  • Loading branch information
DavidKnott committed Dec 21, 2017
2 parents c5a4a7a + d898f0c commit bee759a
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 10 deletions.
1 change: 1 addition & 0 deletions tests/parser/exceptions/test_invalid_payable.py
Expand Up @@ -27,6 +27,7 @@ def test_variable_decleration_exception(bad_code):
@payable
def foo() -> num:
self.x = 5
return self.x
""",
"""
@public
Expand Down
Expand Up @@ -495,7 +495,7 @@ def bar() -> num: pass
bar_contract: public(Bar)
@public
def foo(contract_address: contract(Bar)) -> num:
def foo(contract_address: contract(Bar)):
self.bar_contract = contract_address
@public
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/syntax/test_invalids.py
Expand Up @@ -154,14 +154,14 @@ def foo():
must_succeed("""
x: num
@public
def foo() -> num:
def foo():
self.x = 5
""")

must_succeed("""
x: num
@private
def foo() -> num:
def foo():
self.x = 5
""")

Expand Down
14 changes: 7 additions & 7 deletions tests/parser/syntax/test_list.py
Expand Up @@ -32,21 +32,21 @@ def foo() -> num[2]:
y: num[3]
@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x[0]
""",
"""
y: num[3]
@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y[0] = x
""",
"""
y: num[4]
@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
Expand Down Expand Up @@ -168,28 +168,28 @@ def foo(x: num[3]) -> num:
y: num[3]
@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
y: decimal[3]
@public
def foo(x: num[3]) -> num:
def foo(x: num[3]):
self.y = x
""",
"""
y: decimal[2][2]
@public
def foo(x: num[2][2]) -> num:
def foo(x: num[2][2]):
self.y = x
""",
"""
y: decimal[2]
@public
def foo(x: num[2][2]) -> num:
def foo(x: num[2][2]):
self.y = x[1]
""",
"""
Expand Down
40 changes: 40 additions & 0 deletions tests/parser/syntax/test_missing_return.py
@@ -0,0 +1,40 @@
import pytest
from pytest import raises

from viper import compiler
from viper.exceptions import StructureException


fail_list = [
"""
@public
def foo() -> num:
pass
""",
]


@pytest.mark.parametrize('bad_code', fail_list)
def test_missing_return(bad_code):
with raises(StructureException):
compiler.compile(bad_code)



valid_list = [
"""
@public
def foo() -> num:
return 123
""",
"""
@public
def foo() -> num:
if false:
return 123
""", # For the time being this is valid code, even though it should not be.
]

@pytest.mark.parametrize('good_code', valid_list)
def test_return_success(good_code):
assert compiler.compile(good_code) is not None
12 changes: 12 additions & 0 deletions viper/parser/parser.py
Expand Up @@ -275,13 +275,18 @@ def __init__(self, vars=None, globals=None, sigs=None, forvars=None, return_type
self.origcode = origcode
# In Loop status. Whether body is currently evaluating within a for-loop or not.
self.in_for_loop = set()
# Count returns in function
self.function_return_count = 0

def set_in_for_loop(self, name_of_list):
self.in_for_loop.add(name_of_list)

def remove_in_for_loop(self, name_of_list):
self.in_for_loop.remove(name_of_list)

def increment_return_counter(self):
self.function_return_count += 1

# Add a new variable
def new_variable(self, name, typ):
if not is_varname_valid(name):
Expand Down Expand Up @@ -470,6 +475,13 @@ def parse_func(code, _globals, sigs, origcode, _vars=None):
['eq', ['mload', 0], method_id_node],
['seq'] + clampers + [parse_body(c, context) for c in code.body] + ['stop']
], typ=None, pos=getpos(code))

# Check for at leasts one return statement if necessary.
if context.return_type and context.function_return_count == 0:
raise StructureException(
"Missing return statement in function '%s' " % sig.name, code
)

o.context = context
o.total_gas = o.gas + calc_mem_gas(o.context.next_mem)
o.func_name = sig.name
Expand Down
1 change: 1 addition & 0 deletions viper/parser/stmt.py
Expand Up @@ -320,6 +320,7 @@ def parse_return(self):
if not self.stmt.value:
raise TypeMismatchException("Expecting to return a value", self.stmt)
sub = Expr(self.stmt.value, self.context).lll_node
self.context.increment_return_counter()
# Returning a value (most common case)
if isinstance(sub.typ, BaseType):
if not isinstance(self.context.return_type, BaseType):
Expand Down

0 comments on commit bee759a

Please sign in to comment.