From c5f03bddbeefe2547a41766e79f7c60d97e0ca3e Mon Sep 17 00:00:00 2001 From: Ryan Pate Date: Thu, 1 Feb 2024 11:46:54 -0800 Subject: [PATCH] fix(cheatcodes): Properly record call to create2 factory in state diff --- crates/cheatcodes/src/inspector.rs | 51 +++++++- crates/forge/tests/it/repros.rs | 7 ++ testdata/repros/Issue6634.t.sol | 191 +++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 4 deletions(-) create mode 100644 testdata/repros/Issue6634.t.sol diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index e4131030d1f2..e20d96c60d5c 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -1237,7 +1237,13 @@ impl Inspector for Cheatcodes { // Apply the Create2 deployer if self.broadcast.is_some() || self.config.always_use_create_2_factory { - match apply_create2_deployer(data, call, &self.prank, &self.broadcast) { + match apply_create2_deployer( + data, + call, + &self.prank, + &self.broadcast, + &mut self.recorded_account_diffs_stack, + ) { Ok(_val) => {} Err(err) => return (InstructionResult::Revert, None, gas, Error::encode(err)), }; @@ -1249,6 +1255,15 @@ impl Inspector for Cheatcodes { let address = self.allow_cheatcodes_on_create(data, call); // If `recordAccountAccesses` has been called, record the create if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // If the create scheme is create2, and the caller is the DEFAULT_CREATE2_DEPLOYER then + // we must add 1 to the depth to account for the call to the create2 factory. + let mut depth = data.journaled_state.depth(); + if let CreateScheme::Create2 { salt: _ } = call.scheme { + if call.caller == DEFAULT_CREATE2_DEPLOYER { + depth += 1; + } + } + // Record the create context as an account access and create a new vector to record all // subsequent account accesses recorded_account_diffs_stack.push(vec![AccountAccess { @@ -1269,7 +1284,7 @@ impl Inspector for Cheatcodes { deployedCode: vec![], // updated on create_end storageAccesses: vec![], // updated on create_end }, - depth: data.journaled_state.depth(), + depth, }]); } @@ -1432,17 +1447,45 @@ fn apply_create2_deployer( call: &mut CreateInputs, prank: &Option, broadcast: &Option, + diffs_stack: &mut Option>>, ) -> Result<(), DB::Error> { - if let CreateScheme::Create2 { salt: _ } = call.scheme { + if let CreateScheme::Create2 { salt } = call.scheme { let mut base_depth = 1; if let Some(prank) = &prank { base_depth = prank.depth; } else if let Some(broadcast) = &broadcast { base_depth = broadcast.depth; } + // If the create scheme is Create2 and the depth equals the broadcast/prank/default // depth, then use the default create2 factory as the deployer if data.journaled_state.depth() == base_depth { + if let Some(recorded_account_diffs_stack) = diffs_stack { + // If broadcasting, or the create2 factory option is enabled, then record + // the call to the create2 factory. + let calldata = [&salt.to_be_bytes::<32>()[..], &call.init_code[..]].concat(); + recorded_account_diffs_stack.push(vec![AccountAccess { + access: crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: call.caller, + account: DEFAULT_CREATE2_DEPLOYER, + kind: crate::Vm::AccountAccessKind::Call, + initialized: true, + oldBalance: U256::ZERO, // updated on create_end + newBalance: U256::ZERO, // updated on create_end + value: call.value, + data: calldata, + reverted: false, + deployedCode: vec![], // updated on create_end + storageAccesses: vec![], // updated on create_end + }, + depth: data.journaled_state.depth(), + }]) + } + // Sanity checks for our CREATE2 deployer let info = &data.journaled_state.load_account(DEFAULT_CREATE2_DEPLOYER, data.db)?.0.info; @@ -1475,9 +1518,9 @@ fn process_broadcast_create( data: &mut EVMData<'_, DB>, call: &mut CreateInputs, ) -> (Bytes, Option
, u64) { + call.caller = broadcast_sender; match call.scheme { CreateScheme::Create => { - call.caller = broadcast_sender; (bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce) } CreateScheme::Create2 { salt } => { diff --git a/crates/forge/tests/it/repros.rs b/crates/forge/tests/it/repros.rs index 90a253e2242f..37e7be545806 100644 --- a/crates/forge/tests/it/repros.rs +++ b/crates/forge/tests/it/repros.rs @@ -307,3 +307,10 @@ test_repro!(5529; |config| { cheats_config.always_use_create_2_factory = true; config.runner.cheats_config = std::sync::Arc::new(cheats_config); }); + +// https://github.com/foundry-rs/foundry/issues/6634 +test_repro!(6634; |config| { + let mut cheats_config = config.runner.cheats_config.as_ref().clone(); + cheats_config.always_use_create_2_factory = true; + config.runner.cheats_config = std::sync::Arc::new(cheats_config); +}); diff --git a/testdata/repros/Issue6634.t.sol b/testdata/repros/Issue6634.t.sol new file mode 100644 index 000000000000..8276462c1015 --- /dev/null +++ b/testdata/repros/Issue6634.t.sol @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity 0.8.18; + +import "ds-test/test.sol"; +import "../cheats/Vm.sol"; +import "../logs/console.sol"; + +contract Box { + uint256 public number; + + constructor(uint256 _number) { + number = _number; + } +} + +// https://github.com/foundry-rs/foundry/issues/6634 +contract Issue6634Test is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + function test() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, + keccak256(abi.encodePacked(type(Box).creationCode, uint(1))), + address(CREATE2_DEPLOYER) + ); + assertEq( + addr, + called[1].account, + "state diff contract address is not correct" + ); + assertEq( + address(a), + called[1].account, + "returned address is not correct" + ); + + assertEq(called.length, 2, "incorrect length"); + assertEq( + uint256(called[0].kind), + uint256(Vm.AccountAccessKind.Call), + "first AccountAccess is incorrect kind" + ); + assertEq( + called[0].account, + CREATE2_DEPLOYER, + "first AccountAccess account is incorrect" + ); + assertEq( + called[0].accessor, + address(this), + "first AccountAccess accessor is incorrect" + ); + assertEq( + uint256(called[1].kind), + uint256(Vm.AccountAccessKind.Create), + "second AccountAccess is incorrect kind" + ); + assertEq( + called[1].accessor, + CREATE2_DEPLOYER, + "second AccountAccess accessor is incorrect" + ); + assertEq( + called[1].account, + address(a), + "second AccountAccess account is incorrect" + ); + } + + function testPrank() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + address accessor = address(0x5555); + + vm.startPrank(accessor); + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, + keccak256(abi.encodePacked(type(Box).creationCode, uint(1))), + address(CREATE2_DEPLOYER) + ); + assertEq( + addr, + called[1].account, + "state diff contract address is not correct" + ); + assertEq( + address(a), + called[1].account, + "returned address is not correct" + ); + + assertEq(called.length, 2, "incorrect length"); + assertEq( + uint256(called[0].kind), + uint256(Vm.AccountAccessKind.Call), + "first AccountAccess is incorrect kind" + ); + assertEq( + called[0].account, + CREATE2_DEPLOYER, + "first AccountAccess accout is incorrect" + ); + assertEq( + called[0].accessor, + accessor, + "first AccountAccess accessor is incorrect" + ); + assertEq( + uint256(called[1].kind), + uint256(Vm.AccountAccessKind.Create), + "second AccountAccess is incorrect kind" + ); + assertEq( + called[1].accessor, + CREATE2_DEPLOYER, + "second AccountAccess accessor is incorrect" + ); + assertEq( + called[1].account, + address(a), + "second AccountAccess account is incorrect" + ); + } + + function testBroadcast() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + address accessor = address(0x5555); + + vm.startBroadcast(accessor); + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, + keccak256(abi.encodePacked(type(Box).creationCode, uint(1))), + address(CREATE2_DEPLOYER) + ); + assertEq( + addr, + called[1].account, + "state diff contract address is not correct" + ); + assertEq( + address(a), + called[1].account, + "returned address is not correct" + ); + + assertEq(called.length, 2, "incorrect length"); + assertEq( + uint256(called[0].kind), + uint256(Vm.AccountAccessKind.Call), + "first AccountAccess is incorrect kind" + ); + assertEq( + called[0].account, + CREATE2_DEPLOYER, + "first AccountAccess accout is incorrect" + ); + assertEq( + called[0].accessor, + accessor, + "first AccountAccess accessor is incorrect" + ); + assertEq( + uint256(called[1].kind), + uint256(Vm.AccountAccessKind.Create), + "second AccountAccess is incorrect kind" + ); + assertEq( + called[1].accessor, + CREATE2_DEPLOYER, + "second AccountAccess accessor is incorrect" + ); + assertEq( + called[1].account, + address(a), + "second AccountAccess account is incorrect" + ); + } +}