Skip to content

Commit

Permalink
feat: allow dynamic array iterators (#2606)
Browse files Browse the repository at this point in the history
for x in <dynarray>:
    ... # do something
  • Loading branch information
charles-cooper committed Jan 21, 2022
1 parent 25075ae commit 0004e0c
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 62 deletions.
57 changes: 57 additions & 0 deletions tests/parser/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def data() -> int128:
return -1""",
3,
),
# basic for-in-dynamic array
(
"""
@external
def data() -> int128:
s: DynArray[int128, 10] = [1, 2, 3, 4, 5]
for i in s:
if i >= 3:
return i
return -1""",
3,
),
# basic for-in-list literal
(
"""
Expand Down Expand Up @@ -88,6 +100,31 @@ def data() -> int128:
assert c.data() == 7


def test_basic_for_dyn_array_storage(get_contract_with_gas_estimation):
code = """
x: DynArray[int128, 4]
@external
def set(xs: DynArray[int128, 4]):
self.x = xs
@external
def data() -> int128:
t: int128 = 0
for i in self.x:
t += i
return t
"""

c = get_contract_with_gas_estimation(code)

assert c.data() == 0
# test all sorts of lists
for xs in [[3, 5, 7, 9], [4, 6, 8], [1, 2], [5], []]:
c.set(xs, transact={})
assert c.data() == sum(xs)


def test_basic_for_list_storage_address(get_contract_with_gas_estimation):
code = """
addresses: address[3]
Expand Down Expand Up @@ -171,6 +208,26 @@ def func(amounts: uint256[3]) -> uint256:
assert c.func([100, 200, 300]) == 600


def test_for_in_dyn_array(get_contract_with_gas_estimation):
code = """
@external
@view
def func(amounts: DynArray[uint256, 3]) -> uint256:
total: uint256 = 0
# calculate total
for amount in amounts:
total += amount
return total
"""

c = get_contract_with_gas_estimation(code)

assert c.func([100, 200, 300]) == 600
assert c.func([100, 200]) == 300


GOOD_CODE = [
# multiple for loops
"""
Expand Down
105 changes: 45 additions & 60 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import (
LLLnode,
get_dyn_array_count,
get_element_ptr,
getpos,
make_byte_array_copier,
make_setter,
Expand All @@ -14,7 +16,7 @@
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.codegen.types import BaseType, ByteArrayType, SArrayType, parse_type
from vyper.codegen.types import BaseType, ByteArrayType, DArrayType, SArrayType, parse_type
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure


Expand Down Expand Up @@ -294,80 +296,63 @@ def _parse_For_range(self):

def _parse_For_list(self):
with self.context.range_scope():
iter_list_node = Expr(self.stmt.iter, self.context).lll_node
if not isinstance(iter_list_node.typ.subtype, BaseType): # Sanity check on list subtype.
iter_list = Expr(self.stmt.iter, self.context).lll_node

# TODO relax this restriction
if not isinstance(iter_list.typ.subtype, BaseType):
return

iter_var_type = (
self.context.vars.get(self.stmt.iter.id).typ
if isinstance(self.stmt.iter, vy_ast.Name)
else None
)
# override with type inferred at typechecking time
subtype = BaseType(self.stmt.target._metadata["type"]._id)
iter_list_node.typ.subtype = subtype
iter_list.typ.subtype = subtype

# user-supplied name for loop variable
varname = self.stmt.target.id
value_pos = self.context.new_variable(varname, subtype)
i_pos = self.context.new_internal_variable(subtype)
loop_var = LLLnode.from_list(
self.context.new_variable(varname, subtype),
typ=subtype,
location="memory",
)

iptr = LLLnode.from_list(
self.context.new_internal_variable(BaseType("uint256")),
typ="uint256",
location="memory",
)

self.context.forvars[varname] = True

# Is a list that is already allocated to memory.
if iter_var_type:
iter_var = self.context.vars.get(self.stmt.iter.id)
if iter_var.location == "calldata":
fetcher = "calldataload"
elif iter_var.location == "memory":
fetcher = "mload"
else:
return
body = [
"seq",
[
"mstore",
value_pos,
[fetcher, ["add", iter_var.pos, ["mul", ["mload", i_pos], 32]]],
],
parse_body(self.stmt.body, self.context),
]
lll_node = LLLnode.from_list(
["repeat", i_pos, 0, iter_var.size, body], typ=None, pos=getpos(self.stmt)
)
ret = ["seq"]

# List gets defined in the for statement.
elif isinstance(self.stmt.iter, vy_ast.List):
# Allocate list to memory.
count = iter_list_node.typ.count
# list literal, force it to memory first
if isinstance(self.stmt.iter, vy_ast.List):
count = iter_list.typ.count
tmp_list = LLLnode.from_list(
obj=self.context.new_internal_variable(SArrayType(subtype, count)),
typ=SArrayType(subtype, count),
location="memory",
)
setter = make_setter(tmp_list, iter_list_node, self.context, pos=getpos(self.stmt))
body = [
"seq",
["mstore", value_pos, ["mload", ["add", tmp_list, ["mul", ["mload", i_pos], 32]]]],
parse_body(self.stmt.body, self.context),
]
lll_node = LLLnode.from_list(
["seq", setter, ["repeat", i_pos, 0, count, body]], typ=None, pos=getpos(self.stmt)
)
ret.append(make_setter(tmp_list, iter_list, self.context, pos=getpos(self.stmt)))
iter_list = tmp_list

# List contained in storage.
elif isinstance(self.stmt.iter, vy_ast.Attribute):
count = iter_list_node.typ.count
body = [
"seq",
["mstore", value_pos, ["sload", ["add", iter_list_node, ["mload", i_pos]]]],
parse_body(self.stmt.body, self.context),
]
lll_node = LLLnode.from_list(
["seq", ["repeat", i_pos, 0, count, body]], typ=None, pos=getpos(self.stmt)
)
# set up the loop variable
loop_var_ast = getpos(self.stmt.target)
e = get_element_ptr(iter_list, iptr, array_bounds_check=False, pos=loop_var_ast)
body = [
"seq",
make_setter(loop_var, e, self.context, pos=loop_var_ast),
parse_body(self.stmt.body, self.context),
]

repeat_bound = iter_list.typ.count
if isinstance(iter_list.typ, DArrayType):
array_len = get_dyn_array_count(iter_list)
ret.append(["repeat", iptr, 0, array_len, repeat_bound, body])
else:
ret.append(["repeat", iptr, 0, repeat_bound, body])

# this kind of open access to the vars dict should be disallowed.
# we should use member functions to provide an API for these kinds
# of operations.
del self.context.forvars[varname]
return lll_node
return LLLnode.from_list(ret, pos=getpos(self.stmt))

def parse_AugAssign(self):
target = self._get_target(self.stmt.target)
Expand Down
8 changes: 6 additions & 2 deletions vyper/semantics/validation/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from vyper.semantics.types.abstract import IntegerAbstractType
from vyper.semantics.types.bases import DataLocation
from vyper.semantics.types.function import ContractFunction, FunctionVisibility, StateMutability
from vyper.semantics.types.indexable.sequence import ArrayDefinition, TupleDefinition
from vyper.semantics.types.indexable.sequence import (
ArrayDefinition,
DynamicArrayDefinition,
TupleDefinition,
)
from vyper.semantics.types.user.event import Event
from vyper.semantics.types.user.struct import StructDefinition
from vyper.semantics.types.utils import get_type_from_annotation
Expand Down Expand Up @@ -359,7 +363,7 @@ def visit_For(self, node):
type_list = [
i.value_type
for i in get_possible_types_from_node(node.iter)
if isinstance(i, ArrayDefinition)
if isinstance(i, (DynamicArrayDefinition, ArrayDefinition))
]

if not type_list:
Expand Down

0 comments on commit 0004e0c

Please sign in to comment.