Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: protect external calls with keyword #2938

Merged
merged 44 commits into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
aa8d7e9
add await to lark grammar
charles-cooper Jun 24, 2022
a5e1645
add await to ast.nodes, add typechecker rules
charles-cooper Jun 24, 2022
a123ff7
fix lint, mypy
charles-cooper Jun 24, 2022
609a5b4
Merge branch 'master' into feat/await
charles-cooper Apr 8, 2023
20a9f90
thread await down thru codegen
charles-cooper Apr 8, 2023
5107e11
fix mypy
charles-cooper Apr 8, 2023
bc064a3
Merge branch 'master' into feat/await
charles-cooper Dec 19, 2023
bd4c963
Merge branch 'master' into feat/await
charles-cooper Feb 20, 2024
ed4a4ec
improve an exception
charles-cooper Feb 20, 2024
f2d9812
refactor extcall handling
charles-cooper Feb 22, 2024
e44337e
fix some small issues
charles-cooper Feb 23, 2024
af8df84
fix dead function
charles-cooper Feb 28, 2024
d114c09
undo changes to extcall ast, add staticcall support
charles-cooper Feb 28, 2024
365e3f0
fix a check
charles-cooper Feb 29, 2024
1f30ea9
Merge branch 'master' into feat/await
charles-cooper Feb 29, 2024
2f432a5
update the grammar
charles-cooper Feb 29, 2024
ccfd771
fixup modification_offsets
charles-cooper Feb 29, 2024
69b0580
rename a variable
charles-cooper Feb 29, 2024
d8a03e2
fix some grammar rules
charles-cooper Mar 1, 2024
37e29ac
fix grammar
charles-cooper Mar 1, 2024
bd4f3b5
update existing tests
charles-cooper Mar 2, 2024
c75248a
ban standalone staticcalls
charles-cooper Mar 2, 2024
271de77
more test updates
charles-cooper Mar 2, 2024
2fb5759
more test updates
charles-cooper Mar 2, 2024
bd53dab
grammar updates
charles-cooper Mar 2, 2024
d60ae63
fix examples
tserg Mar 2, 2024
b80e593
fix some tests
tserg Mar 2, 2024
670313e
refactor grammar
charles-cooper Mar 2, 2024
8d3e0cc
Merge pull request #27 from tserg/tests/await
charles-cooper Mar 3, 2024
1db1321
fix example
tserg Mar 3, 2024
4f183fe
fix tests
tserg Mar 3, 2024
9b4a831
Merge pull request #28 from tserg/tests/await
charles-cooper Mar 3, 2024
62f60d6
fix more tests
charles-cooper Mar 3, 2024
4422425
Merge branch 'master' into feat/await
charles-cooper Mar 3, 2024
fa70c34
fix lint
charles-cooper Mar 3, 2024
5808c10
fix some more tests
charles-cooper Mar 3, 2024
e014b40
minor refactoring
charles-cooper Mar 4, 2024
5aa263f
tests for new extcall/staticcall syntax
charles-cooper Mar 4, 2024
632e66c
fix referenced node for some error messages
charles-cooper Mar 4, 2024
5f728c3
add a getter for node.parent
charles-cooper Mar 4, 2024
0de483f
add more tests for invalid extcall/staticcalls
charles-cooper Mar 4, 2024
44e0294
fix test
tserg Mar 4, 2024
f83ffcc
remove dead functions
charles-cooper Mar 4, 2024
b625c04
Merge pull request #29 from tserg/test/await
charles-cooper Mar 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/factory/Exchange.vy
Expand Up @@ -20,7 +20,7 @@ def __init__(_token: IERC20, _factory: Factory):
@external
def initialize():
# Anyone can safely call this function because of EXTCODEHASH
self.factory.register()
extcall self.factory.register()


# NOTE: This contract restricts trading to only be done by the factory.
Expand All @@ -31,12 +31,12 @@ def initialize():
@external
def receive(_from: address, _amt: uint256):
assert msg.sender == self.factory.address
success: bool = self.token.transferFrom(_from, self, _amt)
success: bool = extcall self.token.transferFrom(_from, self, _amt)
assert success


@external
def transfer(_to: address, _amt: uint256):
assert msg.sender == self.factory.address
success: bool = self.token.transfer(_to, _amt)
success: bool = extcall self.token.transfer(_to, _amt)
assert success
6 changes: 3 additions & 3 deletions examples/factory/Factory.vy
Expand Up @@ -35,12 +35,12 @@ def register():
# NOTE: Should do checks that it hasn't already been set,
# which has to be rectified with any upgrade strategy.
exchange: Exchange = Exchange(msg.sender)
self.exchanges[exchange.token()] = exchange
self.exchanges[staticcall exchange.token()] = exchange


@external
def trade(_token1: IERC20, _token2: IERC20, _amt: uint256):
# Perform a straight exchange of token1 to token 2 (1:1 price)
# NOTE: Any practical implementation would need to solve the price oracle problem
self.exchanges[_token1].receive(msg.sender, _amt)
self.exchanges[_token2].transfer(msg.sender, _amt)
extcall self.exchanges[_token1].receive(msg.sender, _amt)
extcall self.exchanges[_token2].transfer(msg.sender, _amt)
12 changes: 6 additions & 6 deletions examples/market_maker/on_chain_market_maker.vy
Expand Up @@ -8,7 +8,7 @@ totalTokenQty: public(uint256)
# Constant set in `initiate` that's used to calculate
# the amount of ether/tokens that are exchanged
invariant: public(uint256)
token_address: IERC20
token: IERC20
owner: public(address)

# Sets the on chain market maker with its owner, initial token quantity,
Expand All @@ -17,8 +17,8 @@ owner: public(address)
@payable
def initiate(token_addr: address, token_quantity: uint256):
assert self.invariant == 0
self.token_address = IERC20(token_addr)
self.token_address.transferFrom(msg.sender, self, token_quantity)
self.token = IERC20(token_addr)
extcall self.token.transferFrom(msg.sender, self, token_quantity)
self.owner = msg.sender
self.totalEthQty = msg.value
self.totalTokenQty = token_quantity
Expand All @@ -33,14 +33,14 @@ def ethToTokens():
eth_in_purchase: uint256 = msg.value - fee
new_total_eth: uint256 = self.totalEthQty + eth_in_purchase
new_total_tokens: uint256 = self.invariant // new_total_eth
self.token_address.transfer(msg.sender, self.totalTokenQty - new_total_tokens)
extcall self.token.transfer(msg.sender, self.totalTokenQty - new_total_tokens)
self.totalEthQty = new_total_eth
self.totalTokenQty = new_total_tokens

# Sells tokens to the contract in exchange for ether
@external
def tokensToEth(sell_quantity: uint256):
self.token_address.transferFrom(msg.sender, self, sell_quantity)
extcall self.token.transferFrom(msg.sender, self, sell_quantity)
new_total_tokens: uint256 = self.totalTokenQty + sell_quantity
new_total_eth: uint256 = self.invariant // new_total_tokens
eth_to_send: uint256 = self.totalEthQty - new_total_eth
Expand All @@ -52,5 +52,5 @@ def tokensToEth(sell_quantity: uint256):
@external
def ownerWithdraw():
assert self.owner == msg.sender
self.token_address.transfer(self.owner, self.totalTokenQty)
extcall self.token.transfer(self.owner, self.totalTokenQty)
selfdestruct(self.owner)
20 changes: 10 additions & 10 deletions examples/tokens/ERC4626.vy
Expand Up @@ -79,7 +79,7 @@ def transferFrom(sender: address, receiver: address, amount: uint256) -> bool:
@view
@external
def totalAssets() -> uint256:
return self.asset.balanceOf(self)
return staticcall self.asset.balanceOf(self)


@view
Expand All @@ -91,7 +91,7 @@ def _convertToAssets(shareAmount: uint256) -> uint256:

# NOTE: `shareAmount = 0` is extremely rare case, not optimizing for it
# NOTE: `totalAssets = 0` is extremely rare case, not optimizing for it
return shareAmount * self.asset.balanceOf(self) // totalSupply
return shareAmount * staticcall self.asset.balanceOf(self) // totalSupply


@view
Expand All @@ -104,7 +104,7 @@ def convertToAssets(shareAmount: uint256) -> uint256:
@internal
def _convertToShares(assetAmount: uint256) -> uint256:
totalSupply: uint256 = self.totalSupply
totalAssets: uint256 = self.asset.balanceOf(self)
totalAssets: uint256 = staticcall self.asset.balanceOf(self)
if totalAssets == 0 or totalSupply == 0:
return assetAmount # 1:1 price

Expand Down Expand Up @@ -133,7 +133,7 @@ def previewDeposit(assets: uint256) -> uint256:
@external
def deposit(assets: uint256, receiver: address=msg.sender) -> uint256:
shares: uint256 = self._convertToShares(assets)
self.asset.transferFrom(msg.sender, self, assets)
extcall self.asset.transferFrom(msg.sender, self, assets)

self.totalSupply += shares
self.balanceOf[receiver] += shares
Expand All @@ -153,7 +153,7 @@ def previewMint(shares: uint256) -> uint256:
assets: uint256 = self._convertToAssets(shares)

# NOTE: Vyper does lazy eval on `and`, so this avoids SLOADs most of the time
if assets == 0 and self.asset.balanceOf(self) == 0:
if assets == 0 and staticcall self.asset.balanceOf(self) == 0:
return shares # NOTE: Assume 1:1 price if nothing deposited yet

return assets
Expand All @@ -163,10 +163,10 @@ def previewMint(shares: uint256) -> uint256:
def mint(shares: uint256, receiver: address=msg.sender) -> uint256:
assets: uint256 = self._convertToAssets(shares)

if assets == 0 and self.asset.balanceOf(self) == 0:
if assets == 0 and staticcall self.asset.balanceOf(self) == 0:
assets = shares # NOTE: Assume 1:1 price if nothing deposited yet

self.asset.transferFrom(msg.sender, self, assets)
extcall self.asset.transferFrom(msg.sender, self, assets)

self.totalSupply += shares
self.balanceOf[receiver] += shares
Expand Down Expand Up @@ -206,7 +206,7 @@ def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.s
self.totalSupply -= shares
self.balanceOf[owner] -= shares

self.asset.transfer(receiver, assets)
extcall self.asset.transfer(receiver, assets)
log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares)
return shares

Expand All @@ -232,7 +232,7 @@ def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sen
self.totalSupply -= shares
self.balanceOf[owner] -= shares

self.asset.transfer(receiver, assets)
extcall self.asset.transfer(receiver, assets)
log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares)
return assets

Expand All @@ -241,4 +241,4 @@ def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sen
def DEBUG_steal_tokens(amount: uint256):
# NOTE: This is the primary method of mocking share price changes
# do not put in production code!!!
self.asset.transfer(msg.sender, amount)
extcall self.asset.transfer(msg.sender, amount)
2 changes: 1 addition & 1 deletion examples/tokens/ERC721.vy
Expand Up @@ -248,7 +248,7 @@ def safeTransferFrom(
"""
self._transferFrom(_from, _to, _tokenId, msg.sender)
if _to.is_contract: # check if `_to` is a contract address
returnValue: bytes4 = ERC721Receiver(_to).onERC721Received(msg.sender, _from, _tokenId, _data)
returnValue: bytes4 = extcall ERC721Receiver(_to).onERC721Received(msg.sender, _from, _tokenId, _data)
# Throws if transfer destination is a contract which does not implement 'onERC721Received'
assert returnValue == method_id("onERC721Received(address,address,uint256,bytes)", output_type=bytes4)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -17,7 +17,7 @@
"py-evm>=0.7.0a1,<0.8",
"web3==6.0.0",
"tox>=3.15,<4.0",
"lark==1.1.2",
"lark==1.1.9",
"hypothesis[lark]>=5.37.1,<6.0",
"eth-stdlib==0.2.6",
],
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_abi_decode.py
Expand Up @@ -244,7 +244,7 @@ def get_counter() -> Bytes[128]: nonpayable
def foo(addr: address) -> (uint256, String[5]):
a: uint256 = 0
b: String[5] = ""
a, b = _abi_decode(Foo(addr).get_counter(), (uint256, String[5]), unwrap_tuple=False)
a, b = _abi_decode(extcall Foo(addr).get_counter(), (uint256, String[5]), unwrap_tuple=False)
return a, b
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_abi_encode.py
Expand Up @@ -281,7 +281,7 @@ def get_counter() -> (uint256, String[6]): nonpayable

@external
def foo(addr: address) -> Bytes[164]:
return _abi_encode(Foo(addr).get_counter(), method_id=0xdeadbeef)
return _abi_encode(extcall Foo(addr).get_counter(), method_id=0xdeadbeef)
"""

c2 = get_contract(contract_2)
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_addmod.py
Expand Up @@ -19,12 +19,12 @@ def test_uint256_addmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_addmod(32, 2, f.foo(32))

interface Foo:
def foo(x: uint256) -> uint256: payable

@external
def foo(f: Foo) -> uint256:
return uint256_addmod(32, 2, extcall f.foo(32))
"""

c1 = side_effects_contract("uint256")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_as_wei_value.py
Expand Up @@ -97,12 +97,12 @@ def foo(a: {data_type}) -> uint256:

def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> uint256:
return as_wei_value(a.foo(7), "ether")

interface Foo:
def foo(x: uint8) -> uint8: nonpayable

@external
def foo(a: Foo) -> uint256:
return as_wei_value(extcall a.foo(7), "ether")
"""

c1 = side_effects_contract("uint8")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_ceil.py
Expand Up @@ -108,12 +108,12 @@ def ceil_param(p: decimal) -> int256:

def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return ceil(a.foo(2.5))

interface Foo:
def foo(x: decimal) -> decimal: payable

@external
def foo(a: Foo) -> int256:
return ceil(extcall a.foo(2.5))
"""

c1 = side_effects_contract("decimal")
Expand Down
17 changes: 2 additions & 15 deletions tests/functional/builtins/codegen/test_create_functions.py
Expand Up @@ -44,30 +44,23 @@ def test() -> address:

def test_create_minimal_proxy_to_call(get_contract, w3):
code = """

interface SubContract:

def hello() -> Bytes[100]: view


other: public(address)


@external
def test() -> address:
self.other = create_minimal_proxy_to(self)
return self.other


@external
def hello() -> Bytes[100]:
return b"hello world!"


@external
def test2() -> Bytes[100]:
return SubContract(self.other).hello()

return staticcall SubContract(self.other).hello()
"""

c = get_contract(code)
Expand All @@ -79,30 +72,24 @@ def test2() -> Bytes[100]:

def test_minimal_proxy_exception(w3, get_contract, tx_failed):
code = """

interface SubContract:

def hello(a: uint256) -> Bytes[100]: view


other: public(address)


@external
def test() -> address:
self.other = create_minimal_proxy_to(self)
return self.other


@external
def hello(a: uint256) -> Bytes[100]:
assert a > 0, "invaliddddd"
return b"hello world!"


@external
def test2(a: uint256) -> Bytes[100]:
return SubContract(self.other).hello(a)
return staticcall SubContract(self.other).hello(a)
"""

c = get_contract(code)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/builtins/codegen/test_ec.py
Expand Up @@ -66,7 +66,7 @@ def foo(x: uint256[2]) -> uint256[2]: payable

@external
def foo(a: Foo) -> uint256[2]:
return ecadd([1, 2], a.foo([1, 2]))
return ecadd([1, 2], extcall a.foo([1, 2]))
"""
c1 = side_effects_contract("uint256[2]")
c2 = get_contract(code)
Expand Down Expand Up @@ -148,7 +148,7 @@ def foo(x: uint256) -> uint256: payable

@external
def foo(a: Foo) -> uint256[2]:
return ecmul([1, 2], a.foo(3))
return ecmul([1, 2], extcall a.foo(3))
"""
c1 = side_effects_contract("uint256")
c2 = get_contract(code)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/builtins/codegen/test_empty.py
Expand Up @@ -445,7 +445,7 @@ def pub2() -> bool:
@external
def pub3(x: address) -> bool:
self.write_junk_to_memory()
return Mirror(x).test_empty(empty(int128[111]), empty(Bytes[1024]), empty(Bytes[31]))
return staticcall Mirror(x).test_empty(empty(int128[111]), empty(Bytes[1024]), empty(Bytes[31]))
"""
c = get_contract_with_gas_estimation(code)
mirror = get_contract_with_gas_estimation(code)
Expand Down Expand Up @@ -658,7 +658,7 @@ def foo(
@view
@external
def bar(a: address) -> (uint256, Bytes[33], Bytes[65], uint256):
return Foo(a).foo(12, {a}, 42, {b})
return staticcall Foo(a).foo(12, {a}, 42, {b})
"""

c1 = get_contract(code_a)
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_floor.py
Expand Up @@ -112,12 +112,12 @@ def floor_param(p: decimal) -> int256:

def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return floor(a.foo(2.5))

interface Foo:
def foo(x: decimal) -> decimal: nonpayable

@external
def foo(a: Foo) -> int256:
return floor(extcall a.foo(2.5))
"""

c1 = side_effects_contract("decimal")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_mulmod.py
Expand Up @@ -36,12 +36,12 @@ def test_uint256_mulmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_mulmod(200, 3, f.foo(601))

interface Foo:
def foo(x: uint256) -> uint256: nonpayable

@external
def foo(f: Foo) -> uint256:
return uint256_mulmod(200, 3, extcall f.foo(601))
"""

c1 = side_effects_contract("uint256")
Expand Down