From 041b6f8593f682233880974c180bb834dc2e1d61 Mon Sep 17 00:00:00 2001 From: sam bacha Date: Thu, 5 Dec 2024 04:54:26 -0800 Subject: [PATCH 1/2] refactor(async): new createAsyncCall --- src/AsyncEnabled.sol | 72 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/src/AsyncEnabled.sol b/src/AsyncEnabled.sol index 5523519..0fba857 100644 --- a/src/AsyncEnabled.sol +++ b/src/AsyncEnabled.sol @@ -1,48 +1,65 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.13; + import {AsyncUtils} from "./AsyncUtils.sol"; import {console} from "forge-std/console.sol"; import {LocalAsyncProxy} from "./LocalAsyncProxy.sol"; import {AsyncCall, AsyncCallback} from "./AsyncUtils.sol"; import {SuperchainEnabled} from "./SuperchainEnabled.sol"; import {AsyncPromise} from "./AsyncPromise.sol"; -import { IL2ToL2CrossDomainMessenger } from "@contracts-bedrock/L2/interfaces/IL2ToL2CrossDomainMessenger.sol"; -import { Predeploys } from "@contracts-bedrock/libraries/Predeploys.sol"; +import {IL2ToL2CrossDomainMessenger} from "@contracts-bedrock/L2/interfaces/IL2ToL2CrossDomainMessenger.sol"; +import {Predeploys} from "@contracts-bedrock/libraries/Predeploys.sol"; contract AsyncEnabled is SuperchainEnabled { - // mapping of address to chainId to remote caller proxy, should probably be private mapping(address => mapping(uint256 => LocalAsyncProxy)) public remoteAsyncProxies; constructor() { console.log("an asyncEnabled contract was just deployed!"); } - // gets a remote instance of the contract, creating it if it doesn't exist function getAsyncProxy(address _remoteAddress, uint256 _remoteChainId) internal returns (address) { - if (address(remoteAsyncProxies[_remoteAddress][_remoteChainId]) == address(0)) { - remoteAsyncProxies[_remoteAddress][_remoteChainId] = new LocalAsyncProxy{salt: bytes32(0)}(_remoteAddress, _remoteChainId); + if (isProxyNotCreated(_remoteAddress, _remoteChainId)) { + createLocalAsyncProxy(_remoteAddress, _remoteChainId); } return address(remoteAsyncProxies[_remoteAddress][_remoteChainId]); } + function isProxyNotCreated(address _remoteAddress, uint256 _remoteChainId) internal view returns (bool) { + return address(remoteAsyncProxies[_remoteAddress][_remoteChainId]) == address(0); + } + + function createLocalAsyncProxy(address _remoteAddress, uint256 _remoteChainId) internal { + remoteAsyncProxies[_remoteAddress][_remoteChainId] = new LocalAsyncProxy{salt: bytes32(0)}(_remoteAddress, _remoteChainId); + } + function relayAsyncCall(AsyncCall calldata _asyncCall) external { - // Ensure the crossDomainSender is a valid async proxy for the remote address and chain - // TODO: other sanity checks on _asyncCall values - LocalAsyncProxy expectedCrossDomainSender = AsyncUtils.calculateLocalAsyncProxyAddress( - _asyncCall.from.addr, - address(this), - block.chainid - ); - require(_isValidCrossDomainSender(address(expectedCrossDomainSender))); + require(isValidCrossDomainSender(_asyncCall), "Invalid cross-domain sender"); console.log("valid CDM, relaying async call"); - (bool success, bytes memory returndata) = address(this).call(_asyncCall.data); + (bool success, bytes memory returndata) = executeAsyncCall(_asyncCall); console.log("AsyncCallRelayer relayed, success: %s, returndata: ", success); console.logBytes(returndata); require(success, "Relaying async call failed"); + relayCallback(_asyncCall, success, returndata); + } + + function isValidCrossDomainSender(AsyncCall calldata _asyncCall) internal view returns (bool) { + LocalAsyncProxy expectedCrossDomainSender = AsyncUtils.calculateLocalAsyncProxyAddress( + _asyncCall.from.addr, + address(this), + block.chainid + ); + return _isValidCrossDomainSender(address(expectedCrossDomainSender)); + } + + function executeAsyncCall(AsyncCall calldata _asyncCall) internal returns (bool, bytes memory) { + return address(this).call(_asyncCall.data); + } + + function relayCallback(AsyncCall calldata _asyncCall, bool success, bytes memory returndata) internal { bytes32 asyncCallId = AsyncUtils.getAsyncCallId(_asyncCall); AsyncCallback memory callback = AsyncCallback({ asyncCallId: asyncCallId, @@ -65,6 +82,12 @@ contract AsyncEnabled is SuperchainEnabled { function relayAsyncCallback(AsyncCallback calldata _callback) external { console.log("in relayAsyncCallback"); + require(isValidPromiseCallbackSender(_callback), "Invalid promise callback sender"); + + executeCallback(_callback); + } + + function isValidPromiseCallbackSender(AsyncCallback calldata _callback) internal view returns (bool) { address crossDomainCallbackSender = IL2ToL2CrossDomainMessenger(Predeploys.L2_TO_L2_CROSS_DOMAIN_MESSENGER).crossDomainMessageSender(); uint256 crossDomainCallbackSource = IL2ToL2CrossDomainMessenger(Predeploys.L2_TO_L2_CROSS_DOMAIN_MESSENGER).crossDomainMessageSource(); // TODO @@ -77,7 +100,11 @@ contract AsyncEnabled is SuperchainEnabled { AsyncPromise promiseContract = remoteProxy.promisesById(_callback.asyncCallId); - require(promiseContract.remoteTarget() == crossDomainCallbackSender, "Invalid promise callback sender"); + return promiseContract.remoteTarget() == crossDomainCallbackSender; + } + + function executeCallback(AsyncCallback calldata _callback) internal { + AsyncPromise promiseContract = getPromiseContract(_callback); bytes4 callbackSelector = promiseContract.callbackSelector(); (bool success, bytes memory returnData) = address(this).call( @@ -92,6 +119,19 @@ contract AsyncEnabled is SuperchainEnabled { promiseContract.markResolved(); } + function getPromiseContract(AsyncCallback calldata _callback) internal view returns (AsyncPromise) { + address crossDomainCallbackSender = IL2ToL2CrossDomainMessenger(Predeploys.L2_TO_L2_CROSS_DOMAIN_MESSENGER).crossDomainMessageSender(); + uint256 crossDomainCallbackSource = IL2ToL2CrossDomainMessenger(Predeploys.L2_TO_L2_CROSS_DOMAIN_MESSENGER).crossDomainMessageSource(); + + LocalAsyncProxy remoteProxy = AsyncUtils.calculateLocalAsyncProxyAddress( + address(this), + crossDomainCallbackSender, + crossDomainCallbackSource + ); + + return remoteProxy.promisesById(_callback.asyncCallId); + } + modifier async() { // only callable by self via relayAsyncCall require(msg.sender == address(this)); From 569905fad919cf6080501856bc99ef28ba81a746 Mon Sep 17 00:00:00 2001 From: sam bacha Date: Thu, 5 Dec 2024 04:59:38 -0800 Subject: [PATCH 2/2] refactor(transition): separate transition and add checks --- src/AsyncPromise.sol | 51 +++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/src/AsyncPromise.sol b/src/AsyncPromise.sol index 20cbfeb..a4b19de 100644 --- a/src/AsyncPromise.sol +++ b/src/AsyncPromise.sol @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.13; import {console} from "forge-std/console.sol"; @@ -16,33 +17,53 @@ contract AsyncPromise { bytes32 public messageId; AsyncPromiseState public state = AsyncPromiseState.WAITING_FOR_SET_CALLBACK_SELECTOR; + error OnlyInvokerAllowed(); + error PromiseAlreadySetup(); + + modifier onlyInvoker() { + if (msg.sender != localInvoker) revert OnlyInvokerAllowed(); + _; + } + constructor(address _invoker, address _remoteTarget, bytes32 _messageId) { localInvoker = _invoker; remoteTarget = _remoteTarget; messageId = _messageId; } - function markResolved() external { - require(msg.sender == localInvoker, "Only the invoker can mark this promise's callback resolved"); + function markResolved() external onlyInvoker { + _setResolved(); + } + + function _setResolved() internal { resolved = true; state = AsyncPromiseState.RESOLVED; } - fallback() external { - require(msg.sender == localInvoker, "Only the caller can set this promise's callback"); + function _isWaitingForCallback() internal view returns (bool) { + return state == AsyncPromiseState.WAITING_FOR_CALLBACK_EXECUTION; + } - if (state == AsyncPromiseState.WAITING_FOR_CALLBACK_EXECUTION) { - revert("Promise already setup"); - } + function _isWaitingForSelector() internal view returns (bool) { + return state == AsyncPromiseState.WAITING_FOR_SET_CALLBACK_SELECTOR; + } + + function _setCallbackSelector(bytes calldata data) internal { + callbackSelector = bytes4(data[24:28]); + state = AsyncPromiseState.WAITING_FOR_CALLBACK_EXECUTION; + } - if (state == AsyncPromiseState.WAITING_FOR_SET_CALLBACK_SELECTOR) { - // TODO: is there a way to confirm in the general case this is ".then"? - console.log("got callback selector"); - console.logBytes(msg.data); - // 4 bytes for the outer selector, 20 bytes for the address, 4 bytes for the callback selector - // TODO: battle test this against more examples / confirm sufficiently generalized - callbackSelector = bytes4(msg.data[24:28]); - state = AsyncPromiseState.WAITING_FOR_CALLBACK_EXECUTION; + function _handleCallbackSetup(bytes calldata data) internal { + if (_isWaitingForCallback()) { + revert PromiseAlreadySetup(); } + + if (_isWaitingForSelector()) { + _setCallbackSelector(data); + } + } + + fallback() external onlyInvoker { + _handleCallbackSetup(msg.data); } } \ No newline at end of file