diff --git a/tests/parser/syntax/test_conditionals.py b/tests/parser/syntax/test_conditionals.py new file mode 100644 index 0000000000..10b5172a15 --- /dev/null +++ b/tests/parser/syntax/test_conditionals.py @@ -0,0 +1,27 @@ +import pytest +from vyper import compiler + +valid_list = [ + """ +@private +def mkint() -> int128: + return 1 + +@public +def test_zerovalent(): + if True: + self.mkint() + +@public +def test_valency_mismatch(): + if True: + self.mkint() + else: + pass + """ +] + + +@pytest.mark.parametrize('good_code', valid_list) +def test_conditional_return_code(good_code): + assert compiler.compile_code(good_code) is not None diff --git a/vyper/parser/lll_node.py b/vyper/parser/lll_node.py index bf96d6308e..cd989f0568 100644 --- a/vyper/parser/lll_node.py +++ b/vyper/parser/lll_node.py @@ -107,12 +107,8 @@ def __init__(self, value, args=None, typ=None, location=None, pos=None, annotati elif self.value == 'if': if len(self.args) == 3: self.gas = self.args[0].gas + max(self.args[1].gas, self.args[2].gas) + 3 - if self.args[1].valency != self.args[2].valency: - raise Exception("Valency mismatch between then and else clause: %r %r" % (self.args[1], self.args[2])) if len(self.args) == 2: self.gas = self.args[0].gas + self.args[1].gas + 17 - if self.args[1].valency: - raise Exception("2-clause if statement must have a zerovalent body: %r" % self.args[1]) if not self.args[0].valency: raise Exception("Can't have a zerovalent argument as a test to an if statement! %r" % self.args[0]) if len(self.args) not in (2, 3): diff --git a/vyper/parser/stmt.py b/vyper/parser/stmt.py index c353a9d472..6fde572638 100644 --- a/vyper/parser/stmt.py +++ b/vyper/parser/stmt.py @@ -252,7 +252,7 @@ def parse_if(self): if self.stmt.orelse: block_scope_id = id(self.stmt.orelse) with self.context.make_blockscope(block_scope_id): - add_on = [parse_body(self.stmt.orelse, self.context)] + add_on = [['seq', parse_body(self.stmt.orelse, self.context)]] else: add_on = [] @@ -262,12 +262,13 @@ def parse_if(self): if not self.is_bool_expr(test_expr): raise TypeMismatchException('Only boolean expressions allowed', self.stmt.test) - + body = ['if', test_expr, + ['seq', parse_body(self.stmt.body, self.context)]] \ + + add_on o = LLLnode.from_list( - ['if', test_expr, parse_body(self.stmt.body, self.context)] + add_on, + body, typ=None, pos=getpos(self.stmt) ) - return o def _clear(self):