Skip to content

Commit

Permalink
Merge pull request #3 from RPate97/pate/include-create2-factory-call
Browse files Browse the repository at this point in the history
fix(cheatcodes): Properly record call to create2 factory in state diff
  • Loading branch information
RPate97 committed Feb 3, 2024
2 parents 2110f14 + c5f03bd commit fcb7c1e
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 4 deletions.
51 changes: 47 additions & 4 deletions crates/cheatcodes/src/inspector.rs
Expand Up @@ -1237,7 +1237,13 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {

// Apply the Create2 deployer
if self.broadcast.is_some() || self.config.always_use_create_2_factory {
match apply_create2_deployer(data, call, &self.prank, &self.broadcast) {
match apply_create2_deployer(
data,
call,
&self.prank,
&self.broadcast,
&mut self.recorded_account_diffs_stack,
) {
Ok(_val) => {}
Err(err) => return (InstructionResult::Revert, None, gas, Error::encode(err)),
};
Expand All @@ -1249,6 +1255,15 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
let address = self.allow_cheatcodes_on_create(data, call);
// If `recordAccountAccesses` has been called, record the create
if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack {
// If the create scheme is create2, and the caller is the DEFAULT_CREATE2_DEPLOYER then
// we must add 1 to the depth to account for the call to the create2 factory.
let mut depth = data.journaled_state.depth();
if let CreateScheme::Create2 { salt: _ } = call.scheme {
if call.caller == DEFAULT_CREATE2_DEPLOYER {
depth += 1;
}
}

// Record the create context as an account access and create a new vector to record all
// subsequent account accesses
recorded_account_diffs_stack.push(vec![AccountAccess {
Expand All @@ -1269,7 +1284,7 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
deployedCode: vec![], // updated on create_end
storageAccesses: vec![], // updated on create_end
},
depth: data.journaled_state.depth(),
depth,
}]);
}

Expand Down Expand Up @@ -1432,17 +1447,45 @@ fn apply_create2_deployer<DB: DatabaseExt>(
call: &mut CreateInputs,
prank: &Option<Prank>,
broadcast: &Option<Broadcast>,
diffs_stack: &mut Option<Vec<Vec<AccountAccess>>>,
) -> Result<(), DB::Error> {
if let CreateScheme::Create2 { salt: _ } = call.scheme {
if let CreateScheme::Create2 { salt } = call.scheme {
let mut base_depth = 1;
if let Some(prank) = &prank {
base_depth = prank.depth;
} else if let Some(broadcast) = &broadcast {
base_depth = broadcast.depth;
}

// If the create scheme is Create2 and the depth equals the broadcast/prank/default
// depth, then use the default create2 factory as the deployer
if data.journaled_state.depth() == base_depth {
if let Some(recorded_account_diffs_stack) = diffs_stack {
// If broadcasting, or the create2 factory option is enabled, then record
// the call to the create2 factory.
let calldata = [&salt.to_be_bytes::<32>()[..], &call.init_code[..]].concat();
recorded_account_diffs_stack.push(vec![AccountAccess {
access: crate::Vm::AccountAccess {
chainInfo: crate::Vm::ChainInfo {
forkId: data.db.active_fork_id().unwrap_or_default(),
chainId: U256::from(data.env.cfg.chain_id),
},
accessor: call.caller,
account: DEFAULT_CREATE2_DEPLOYER,
kind: crate::Vm::AccountAccessKind::Call,
initialized: true,
oldBalance: U256::ZERO, // updated on create_end
newBalance: U256::ZERO, // updated on create_end
value: call.value,
data: calldata,
reverted: false,
deployedCode: vec![], // updated on create_end
storageAccesses: vec![], // updated on create_end
},
depth: data.journaled_state.depth(),
}])
}

// Sanity checks for our CREATE2 deployer
let info =
&data.journaled_state.load_account(DEFAULT_CREATE2_DEPLOYER, data.db)?.0.info;
Expand Down Expand Up @@ -1475,9 +1518,9 @@ fn process_broadcast_create<DB: DatabaseExt>(
data: &mut EVMData<'_, DB>,
call: &mut CreateInputs,
) -> (Bytes, Option<Address>, u64) {
call.caller = broadcast_sender;
match call.scheme {
CreateScheme::Create => {
call.caller = broadcast_sender;
(bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce)
}
CreateScheme::Create2 { salt } => {
Expand Down
7 changes: 7 additions & 0 deletions crates/forge/tests/it/repros.rs
Expand Up @@ -307,3 +307,10 @@ test_repro!(5529; |config| {
cheats_config.always_use_create_2_factory = true;
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});

// https://github.com/foundry-rs/foundry/issues/6634
test_repro!(6634; |config| {
let mut cheats_config = config.runner.cheats_config.as_ref().clone();
cheats_config.always_use_create_2_factory = true;
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});
191 changes: 191 additions & 0 deletions testdata/repros/Issue6634.t.sol
@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.18;

import "ds-test/test.sol";
import "../cheats/Vm.sol";
import "../logs/console.sol";

contract Box {
uint256 public number;

constructor(uint256 _number) {
number = _number;
}
}

// https://github.com/foundry-rs/foundry/issues/6634
contract Issue6634Test is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);

function test() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;

vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0,
keccak256(abi.encodePacked(type(Box).creationCode, uint(1))),
address(CREATE2_DEPLOYER)
);
assertEq(
addr,
called[1].account,
"state diff contract address is not correct"
);
assertEq(
address(a),
called[1].account,
"returned address is not correct"
);

assertEq(called.length, 2, "incorrect length");
assertEq(
uint256(called[0].kind),
uint256(Vm.AccountAccessKind.Call),
"first AccountAccess is incorrect kind"
);
assertEq(
called[0].account,
CREATE2_DEPLOYER,
"first AccountAccess account is incorrect"
);
assertEq(
called[0].accessor,
address(this),
"first AccountAccess accessor is incorrect"
);
assertEq(
uint256(called[1].kind),
uint256(Vm.AccountAccessKind.Create),
"second AccountAccess is incorrect kind"
);
assertEq(
called[1].accessor,
CREATE2_DEPLOYER,
"second AccountAccess accessor is incorrect"
);
assertEq(
called[1].account,
address(a),
"second AccountAccess account is incorrect"
);
}

function testPrank() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;
address accessor = address(0x5555);

vm.startPrank(accessor);
vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0,
keccak256(abi.encodePacked(type(Box).creationCode, uint(1))),
address(CREATE2_DEPLOYER)
);
assertEq(
addr,
called[1].account,
"state diff contract address is not correct"
);
assertEq(
address(a),
called[1].account,
"returned address is not correct"
);

assertEq(called.length, 2, "incorrect length");
assertEq(
uint256(called[0].kind),
uint256(Vm.AccountAccessKind.Call),
"first AccountAccess is incorrect kind"
);
assertEq(
called[0].account,
CREATE2_DEPLOYER,
"first AccountAccess accout is incorrect"
);
assertEq(
called[0].accessor,
accessor,
"first AccountAccess accessor is incorrect"
);
assertEq(
uint256(called[1].kind),
uint256(Vm.AccountAccessKind.Create),
"second AccountAccess is incorrect kind"
);
assertEq(
called[1].accessor,
CREATE2_DEPLOYER,
"second AccountAccess accessor is incorrect"
);
assertEq(
called[1].account,
address(a),
"second AccountAccess account is incorrect"
);
}

function testBroadcast() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;
address accessor = address(0x5555);

vm.startBroadcast(accessor);
vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0,
keccak256(abi.encodePacked(type(Box).creationCode, uint(1))),
address(CREATE2_DEPLOYER)
);
assertEq(
addr,
called[1].account,
"state diff contract address is not correct"
);
assertEq(
address(a),
called[1].account,
"returned address is not correct"
);

assertEq(called.length, 2, "incorrect length");
assertEq(
uint256(called[0].kind),
uint256(Vm.AccountAccessKind.Call),
"first AccountAccess is incorrect kind"
);
assertEq(
called[0].account,
CREATE2_DEPLOYER,
"first AccountAccess accout is incorrect"
);
assertEq(
called[0].accessor,
accessor,
"first AccountAccess accessor is incorrect"
);
assertEq(
uint256(called[1].kind),
uint256(Vm.AccountAccessKind.Create),
"second AccountAccess is incorrect kind"
);
assertEq(
called[1].accessor,
CREATE2_DEPLOYER,
"second AccountAccess accessor is incorrect"
);
assertEq(
called[1].account,
address(a),
"second AccountAccess account is incorrect"
);
}
}

0 comments on commit fcb7c1e

Please sign in to comment.