Skip to content

Commit

Permalink
wip: implement methods for structs
Browse files Browse the repository at this point in the history
test contract:
```
struct Foo:
    user: address
    balance: uint256

    def decrement_balance(self, amount: uint256):
        self.balance -= amount

@external
def foo(f: Foo):
    s: Foo = f
    s.decrement_balance(1)
```
  • Loading branch information
charles-cooper committed May 21, 2024
1 parent 0ba1c62 commit d94f548
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 29 deletions.
17 changes: 9 additions & 8 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,21 +677,22 @@ def parse_Call(self):
return arg_ir

if isinstance(func_t, MemberFunctionT):
darray = Expr(self.expr.func.value, self.context).ir_node
assert isinstance(darray.typ, DArrayT)
ptr = Expr(self.expr.func.value, self.context).ir_node

if isinstance(ptr.typ, StructT):
return self_call.ir_for_self_call(self.expr, self.context, ptr=ptr)

assert isinstance(ptr.typ, DArrayT)
args = [Expr(x, self.context).ir_node for x in self.expr.args]
if self.expr.func.attr == "pop":
# TODO consider moving this to builtins
darray = Expr(self.expr.func.value, self.context).ir_node
assert len(self.expr.args) == 0
return_item = not self.is_stmt
return pop_dyn_array(darray, return_popped_item=return_item)
return pop_dyn_array(ptr, return_popped_item=return_item)
elif self.expr.func.attr == "append":
(arg,) = args
check_assign(
dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ)
)
return append_dyn_array(darray, arg)
check_assign(dummy_node_for_type(ptr.typ.value_type), dummy_node_for_type(arg.typ))
return append_dyn_array(ptr, arg)

assert isinstance(func_t, ContractFunctionT)
assert func_t.is_internal or func_t.is_constructor
Expand Down
2 changes: 2 additions & 0 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def _runtime_reachable_functions(module_t, id_generator):
for fn_t in ret:
id_generator.ensure_id(fn_t)

print("ENTER", ret)

return ret


Expand Down
13 changes: 11 additions & 2 deletions vyper/codegen/self_call.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional

from vyper.codegen.core import _freshname, eval_once_check, make_setter
from vyper.codegen.ir_node import IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import StateAccessViolation
from vyper.semantics.types.function import MemberFunctionT

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.function
begins an import cycle.
from vyper.semantics.types.subscriptable import TupleT


Expand All @@ -20,7 +23,7 @@ def _align_kwargs(func_t, args_ir):
return [i.default_value for i in unprovided_kwargs]


def ir_for_self_call(stmt_expr, context):
def ir_for_self_call(stmt_expr, context, ptr: Optional[IRnode] = None):
from vyper.codegen.expr import Expr # TODO rethink this circular import

# ** Internal Call **
Expand All @@ -39,7 +42,10 @@ def ir_for_self_call(stmt_expr, context):
default_vals_ir = [Expr(x, context).ir_node for x in default_vals]

args_ir = pos_args_ir + default_vals_ir
assert len(args_ir) == len(func_t.arguments)
if isinstance(func_t, MemberFunctionT):
assert len(args_ir) == len(func_t.arg_types)
else:
assert len(args_ir) == len(func_t.arguments)

args_tuple_t = TupleT([x.typ for x in args_ir])
args_as_tuple = IRnode.from_list(["multi"] + [x for x in args_ir], typ=args_tuple_t)
Expand Down Expand Up @@ -89,6 +95,9 @@ def ir_for_self_call(stmt_expr, context):
copy_args = make_setter(args_dst, args_as_tuple)

goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)]
if ptr is not None:
goto_op += [ptr]

# pass return buffer to subroutine
if return_buffer is not None:
goto_op += [return_buffer]
Expand Down
7 changes: 4 additions & 3 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,16 @@ def _analyze_call_graph(module_ast: vy_ast.Module):
for call in function_calls:
try:
call_t = get_exact_type_from_node(call.func)
except VyperException:
except VyperException as e:
# there is a problem getting the call type. this might be
# an issue, but it will be handled properly later. right now
# we just want to be able to construct the call graph.
print("ENTER", e, call)
continue

if isinstance(call_t, ContractFunctionT) and (
if isinstance(call_t, MemberFunctionT) or (isinstance(call_t, ContractFunctionT) and (
call_t.is_internal or call_t.is_constructor
):
)):
fn_t.called_functions.add(call_t)

for func in function_defs:
Expand Down
21 changes: 21 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,15 @@ def __init__(
self.return_type = return_type
self.is_modifying = is_modifying

self._ir_info = None

@classmethod
def from_FunctionDef(cls, structname: str, funcdef: vy_ast.FunctionDef):
args = funcdef.args.args[1:]
argtypes = [type_from_annotation(arg.annotation) for arg in args]
return_type = _parse_return_type(funcdef)
return cls(structname, funcdef.name, argtypes, return_type, True)

@property
def modifiability(self):
return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT
Expand All @@ -856,6 +865,18 @@ def modifiability(self):
def _id(self):
return self.name

@property
def n_positional_args(self):
return len(self.arg_types)

@property
def n_total_args(self):
return self.n_positional_args

@property
def keyword_args(self):
return []

def __repr__(self):
return f"{self.underlying_type._id} member function '{self.name}'"

Expand Down
6 changes: 4 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ def __init__(self, value_type: VyperType, length: int) -> None:

from vyper.semantics.types.function import MemberFunctionT

self.add_member("append", MemberFunctionT(self, "append", [self.value_type], None, True))
self.add_member("pop", MemberFunctionT(self, "pop", [], self.value_type, True))
self.add_member(
"append", MemberFunctionT(self._id, "append", [self.value_type], None, True)
)
self.add_member("pop", MemberFunctionT(self._id, "pop", [], self.value_type, True))

def __repr__(self):
return f"DynArray[{self.value_type}, {self.length}]"
Expand Down
33 changes: 19 additions & 14 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.function import MemberFunctionT

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.function
begins an import cycle.
from vyper.semantics.types.subscriptable import HashMapT
from vyper.semantics.types.utils import type_from_abi, type_from_annotation
from vyper.utils import keccak256
Expand Down Expand Up @@ -324,14 +325,11 @@ def tuple_keys(self):
return [k for (k, _v) in self.tuple_items()]

def tuple_items(self):
return list(self.members.items())
return list(self.member_types.items())

@cached_property
def member_types(self):
"""
Alias to match TupleT API without shadowing `members` on TupleT
"""
return self.members
return {k: v for k, v in self.members.items() if not isinstance(v, MemberFunctionT)}

@classmethod
def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":
Expand All @@ -351,23 +349,30 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":
struct_name = base_node.name
members: dict[str, VyperType] = {}
for node in base_node.body:
if not isinstance(node, vy_ast.AnnAssign):
if not isinstance(node, (vy_ast.AnnAssign, vy_ast.FunctionDef)):
raise StructureException(
"Struct declarations can only contain variable definitions", node
"Struct declarations can only contain variable or function definitions", node
)
if node.value is not None:
raise StructureException("Cannot assign a value during struct declaration", node)
if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for struct member name", node.target)
member_name = node.target.id
if isinstance(node, vy_ast.AnnAssign):
if node.value is not None:
raise StructureException(
"Cannot assign a value during struct declaration", node
)
if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for struct member name", node.target)
member_name = node.target.id
typ = type_from_annotation(node.annotation)
else:
member_name = node.name
typ = MemberFunctionT.from_FunctionDef(struct_name, node)

if member_name in members:
# TODO: add prev_decl
raise NamespaceCollision(
f"struct member '{member_name}' has already been declared", node.value
f"struct member '{member_name}' has already been declared", node
)

members[member_name] = type_from_annotation(node.annotation)
members[member_name] = typ

return cls(struct_name, members, ast_def=base_node)

Expand Down

0 comments on commit d94f548

Please sign in to comment.