From 9569692d474f7e6c7d4a843e8fe897848a9ba4fa Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Fri, 8 Mar 2024 14:23:27 +0100 Subject: [PATCH] feat(trie): parallel storage roots (#6903) --- Cargo.lock | 23 ++ Cargo.toml | 2 + bin/reth/Cargo.toml | 2 +- crates/blockchain-tree/Cargo.toml | 2 +- crates/interfaces/src/provider.rs | 2 +- crates/stages/Cargo.toml | 2 +- crates/storage/provider/Cargo.toml | 2 +- .../provider/src/providers/consistent_view.rs | 11 +- .../provider/src/providers/database/mod.rs | 14 +- crates/storage/provider/src/providers/mod.rs | 2 +- crates/trie-parallel/Cargo.toml | 65 ++++ crates/trie-parallel/benches/root.rs | 135 +++++++ crates/trie-parallel/src/async_root.rs | 333 ++++++++++++++++++ crates/trie-parallel/src/lib.rs | 26 ++ crates/trie-parallel/src/metrics.rs | 44 +++ crates/trie-parallel/src/parallel_root.rs | 300 ++++++++++++++++ crates/trie-parallel/src/stats.rs | 68 ++++ .../trie-parallel/src/storage_root_targets.rs | 47 +++ crates/trie/Cargo.toml | 1 - crates/trie/src/node_iter.rs | 16 +- crates/trie/src/trie.rs | 69 ++-- crates/trie/src/updates.rs | 47 ++- crates/trie/src/walker.rs | 124 +++---- 23 files changed, 1215 insertions(+), 122 deletions(-) create mode 100644 crates/trie-parallel/Cargo.toml create mode 100644 crates/trie-parallel/benches/root.rs create mode 100644 crates/trie-parallel/src/async_root.rs create mode 100644 crates/trie-parallel/src/lib.rs create mode 100644 crates/trie-parallel/src/metrics.rs create mode 100644 crates/trie-parallel/src/parallel_root.rs create mode 100644 crates/trie-parallel/src/stats.rs create mode 100644 crates/trie-parallel/src/storage_root_targets.rs diff --git a/Cargo.lock b/Cargo.lock index 939bd11c34a..6aebb90f72a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6818,6 +6818,29 @@ dependencies = [ "triehash", ] +[[package]] +name = "reth-trie-parallel" +version = "0.2.0-beta.1" +dependencies = [ + "alloy-rlp", + "criterion", + "derive_more", + "itertools 0.12.1", + "metrics", + "proptest", + "rand 0.8.5", + "rayon", + "reth-db", + "reth-metrics", + "reth-primitives", + "reth-provider", + "reth-tasks", + "reth-trie", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "revm" version = "6.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9e2d9583898..7de53773e0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ members = [ "crates/tracing/", "crates/transaction-pool/", "crates/trie/", + "crates/trie-parallel/", "examples/", "examples/additional-rpc-namespace-in-cli/", "examples/beacon-api-sse/", @@ -182,6 +183,7 @@ reth-tokio-util = { path = "crates/tokio-util" } reth-tracing = { path = "crates/tracing" } reth-transaction-pool = { path = "crates/transaction-pool" } reth-trie = { path = "crates/trie" } +reth-trie-parallel = { path = "crates/trie-parallel" } # revm revm = { version = "6.1.0", features = ["std", "secp256k1"], default-features = false } diff --git a/bin/reth/Cargo.toml b/bin/reth/Cargo.toml index 3c974bcc49e..7cf23adbfda 100644 --- a/bin/reth/Cargo.toml +++ b/bin/reth/Cargo.toml @@ -45,7 +45,7 @@ reth-basic-payload-builder.workspace = true reth-discv4.workspace = true reth-prune.workspace = true reth-static-file = { workspace = true, features = ["clap"] } -reth-trie.workspace = true +reth-trie = { workspace = true, features = ["metrics"] } reth-nippy-jar.workspace = true reth-node-api.workspace = true reth-node-ethereum.workspace = true diff --git a/crates/blockchain-tree/Cargo.toml b/crates/blockchain-tree/Cargo.toml index 2060160d828..4910bb53963 100644 --- 
a/crates/blockchain-tree/Cargo.toml +++ b/crates/blockchain-tree/Cargo.toml @@ -17,7 +17,7 @@ reth-interfaces.workspace = true reth-db.workspace = true reth-provider.workspace = true reth-stages.workspace = true
-reth-trie.workspace = true
+reth-trie = { workspace = true, features = ["metrics"] }
 # common parking_lot.workspace = true
diff --git a/crates/interfaces/src/provider.rs index f5211072dfc..b72c92304fb 100644 --- a/crates/interfaces/src/provider.rs +++ b/crates/interfaces/src/provider.rs @@ -162,7 +162,7 @@ pub enum ConsistentViewError { Syncing(BlockNumber), /// Error thrown on inconsistent database view. #[error("inconsistent database state: {tip:?}")]
-    InconsistentView {
+    Inconsistent {
         /// The tip diff.
         tip: GotExpected<Option<B256>>,
     },
diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 0cd3e6dd302..11aa2f54501 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -18,7 +18,7 @@ reth-interfaces.workspace = true reth-db.workspace = true reth-codecs.workspace = true reth-provider.workspace = true
-reth-trie.workspace = true
+reth-trie = { workspace = true, features = ["metrics"] }
 reth-tokio-util.workspace = true reth-etl.workspace = true reth-static-file.workspace = true
diff --git a/crates/storage/provider/Cargo.toml b/crates/storage/provider/Cargo.toml index 0efe1c98033..26e6a84b5df 100644 --- a/crates/storage/provider/Cargo.toml +++ b/crates/storage/provider/Cargo.toml @@ -16,7 +16,7 @@ reth-primitives.workspace = true reth-interfaces.workspace = true reth-db.workspace = true
-reth-trie.workspace = true
+reth-trie = { workspace = true, features = ["metrics"] }
 reth-nippy-jar.workspace = true reth-codecs.workspace = true reth-node-api.workspace = true
diff --git a/crates/storage/provider/src/providers/consistent_view.rs index 1175ca53df4..452fa05a1d4 100644 --- a/crates/storage/provider/src/providers/consistent_view.rs +++ b/crates/storage/provider/src/providers/consistent_view.rs @@ -1,9 +1,11 @@ use crate::{BlockNumReader, DatabaseProviderFactory, DatabaseProviderRO, ProviderError}; use reth_db::{cursor::DbCursorRO, database::Database, tables, transaction::DbTx};
-use reth_interfaces::provider::{ConsistentViewError, ProviderResult};
+use reth_interfaces::provider::ProviderResult;
 use reth_primitives::{GotExpected, B256}; use std::marker::PhantomData;
+
+pub use reth_interfaces::provider::ConsistentViewError;
+
 /// A consistent view over state in the database.
 ///
 /// View gets initialized with the latest or provided tip.
 ///
 /// The view should only be used outside of staged-sync.
 /// Otherwise, any attempt to create a provider will result in [ConsistentViewError::Syncing].
+///
+/// When using the view, the consumer should either
+/// 1) have a failover for when the state changes and handle
+///    [ConsistentViewError::Inconsistent] appropriately, or
+/// 2) be sure that the state does not change.
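+///
+/// A minimal illustrative sketch (the `factory` below stands in for any
+/// `DatabaseProviderFactory` implementation, e.g. a `ProviderFactory`):
+///
+/// ```ignore
+/// let view = ConsistentDbView::new(factory.clone());
+/// // Fails with `ConsistentViewError::Inconsistent` if the database tip
+/// // has moved since the view was created.
+/// let provider = view.provider_ro()?;
+/// ```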
#[derive(Clone, Debug)]
 pub struct ConsistentDbView<DB, Provider> {
     database: PhantomData<DB>,
@@ -56,7 +63,7 @@ where let tip = last_entry.map(|(_, hash)| hash); if self.tip != tip {
-            return Err(ConsistentViewError::InconsistentView {
+            return Err(ConsistentViewError::Inconsistent {
                 tip: GotExpected { got: tip, expected: self.tip },
             })
         }
diff --git a/crates/storage/provider/src/providers/database/mod.rs index 7393043c4e3..6b413134fca 100644 --- a/crates/storage/provider/src/providers/database/mod.rs +++ b/crates/storage/provider/src/providers/database/mod.rs @@ -5,10 +5,10 @@ use crate::{ }, to_range, traits::{BlockSource, ReceiptProvider},
-    BlockHashReader, BlockNumReader, BlockReader, ChainSpecProvider, EvmEnvProvider,
-    HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode, ProviderError,
-    PruneCheckpointReader, StageCheckpointReader, StateProviderBox, TransactionVariant,
-    TransactionsProvider, WithdrawalsProvider,
+    BlockHashReader, BlockNumReader, BlockReader, ChainSpecProvider, DatabaseProviderFactory,
+    EvmEnvProvider, HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode,
+    ProviderError, PruneCheckpointReader, StageCheckpointReader, StateProviderBox,
+    TransactionVariant, TransactionsProvider, WithdrawalsProvider,
 };
 use reth_db::{database::Database, init_db, models::StoredBlockBodyIndices, DatabaseEnv}; use reth_interfaces::{provider::ProviderResult, RethError, RethResult};
@@ -208,6 +208,12 @@ impl<DB: Database> ProviderFactory<DB> { } }
+impl<DB: Database> DatabaseProviderFactory<DB> for ProviderFactory<DB> {
+    fn database_provider_ro(&self) -> ProviderResult<DatabaseProviderRO<DB>> {
+        self.provider()
+    }
+}
+
 impl<DB: Database> HeaderSyncGapProvider for ProviderFactory<DB> { fn sync_gap( &self,
diff --git a/crates/storage/provider/src/providers/mod.rs index d9218f48dda..4ca7502887b 100644 --- a/crates/storage/provider/src/providers/mod.rs +++ b/crates/storage/provider/src/providers/mod.rs @@ -59,7 +59,7 @@ mod chain_info; use chain_info::ChainInfoTracker; mod consistent_view;
-pub use consistent_view::ConsistentDbView;
+pub use consistent_view::{ConsistentDbView, ConsistentViewError};
 /// The main type for interacting with the blockchain.
/// diff --git a/crates/trie-parallel/Cargo.toml b/crates/trie-parallel/Cargo.toml new file mode 100644 index 00000000000..b5aea9db6f6 --- /dev/null +++ b/crates/trie-parallel/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "reth-trie-parallel" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +description = "Parallel implementation of merkle root algorithm" + +[lints] +workspace = true + +[dependencies] +# reth +reth-primitives.workspace = true +reth-db.workspace = true +reth-trie.workspace = true +reth-provider.workspace = true + +# alloy +alloy-rlp.workspace = true + +# tracing +tracing.workspace = true + +# misc +thiserror.workspace = true +derive_more.workspace = true + +# `async` feature +reth-tasks = { workspace = true, optional = true } +tokio = { workspace = true, optional = true, default-features = false } +itertools = { workspace = true, optional = true } + +# `parallel` feature +rayon = { workspace = true, optional = true } + +# `metrics` feature +reth-metrics = { workspace = true, optional = true } +metrics = { workspace = true, optional = true } + +[dev-dependencies] +# reth +reth-primitives = { workspace = true, features = ["test-utils", "arbitrary"] } +reth-provider = { workspace = true, features = ["test-utils"] } +reth-trie = { workspace = true, features = ["test-utils"] } + +# misc +rand.workspace = true +tokio = { workspace = true, default-features = false, features = ["sync", "rt", "macros"] } +rayon.workspace = true +criterion = { workspace = true, features = ["async_tokio"] } +proptest.workspace = true + +[features] +default = ["metrics"] +metrics = ["reth-metrics", "dep:metrics", "reth-trie/metrics"] +async = ["reth-tasks/rayon", "tokio/sync", "itertools"] +parallel = ["rayon"] + +[[bench]] +name = "root" +required-features = ["async", "parallel"] +harness = false diff --git a/crates/trie-parallel/benches/root.rs b/crates/trie-parallel/benches/root.rs new file mode 100644 index 00000000000..ca03ed75824 --- /dev/null +++ b/crates/trie-parallel/benches/root.rs @@ -0,0 +1,135 @@ +#![allow(missing_docs, unreachable_pub)] +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner}; +use rayon::ThreadPoolBuilder; +use reth_primitives::{Account, B256, U256}; +use reth_provider::{ + bundle_state::HashedStateChanges, providers::ConsistentDbView, + test_utils::create_test_provider_factory, +}; +use reth_tasks::pool::BlockingTaskPool; +use reth_trie::{ + hashed_cursor::HashedPostStateCursorFactory, HashedPostState, HashedStorage, StateRoot, +}; +use reth_trie_parallel::{async_root::AsyncStateRoot, parallel_root::ParallelStateRoot}; +use std::collections::HashMap; + +pub fn calculate_state_root(c: &mut Criterion) { + let mut group = c.benchmark_group("Calculate State Root"); + group.sample_size(20); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap()); + + for size in [1_000, 3_000, 5_000, 10_000] { + let (db_state, updated_state) = generate_test_data(size); + let provider_factory = create_test_provider_factory(); + { + let provider_rw = provider_factory.provider_rw().unwrap(); + HashedStateChanges(db_state).write_to_db(provider_rw.tx_ref()).unwrap(); + let (_, updates) = + StateRoot::from_tx(provider_rw.tx_ref()).root_with_updates().unwrap(); + 
updates.flush(provider_rw.tx_ref()).unwrap(); + provider_rw.commit().unwrap(); + } + + let view = ConsistentDbView::new(provider_factory.clone()); + + // state root + group.bench_function(BenchmarkId::new("sync root", size), |b| { + b.to_async(&runtime).iter_with_setup( + || { + let sorted_state = updated_state.clone().into_sorted(); + let prefix_sets = updated_state.construct_prefix_sets(); + let provider = provider_factory.provider().unwrap(); + (provider, sorted_state, prefix_sets) + }, + |(provider, sorted_state, prefix_sets)| async move { + StateRoot::from_tx(provider.tx_ref()) + .with_hashed_cursor_factory(HashedPostStateCursorFactory::new( + provider.tx_ref(), + &sorted_state, + )) + .with_prefix_sets(prefix_sets) + .root() + }, + ) + }); + + // parallel root + group.bench_function(BenchmarkId::new("parallel root", size), |b| { + b.to_async(&runtime).iter_with_setup( + || ParallelStateRoot::new(view.clone(), updated_state.clone()), + |calculator| async { calculator.incremental_root() }, + ); + }); + + // async root + group.bench_function(BenchmarkId::new("async root", size), |b| { + b.to_async(&runtime).iter_with_setup( + || AsyncStateRoot::new(view.clone(), blocking_pool.clone(), updated_state.clone()), + |calculator| calculator.incremental_root(), + ); + }); + } +} + +fn generate_test_data(size: usize) -> (HashedPostState, HashedPostState) { + let storage_size = 1_000; + let mut runner = TestRunner::new(ProptestConfig::default()); + + use proptest::{collection::hash_map, sample::subsequence}; + let db_state = hash_map( + any::(), + ( + any::().prop_filter("non empty account", |a| !a.is_empty()), + hash_map( + any::(), + any::().prop_filter("non zero value", |v| !v.is_zero()), + storage_size, + ), + ), + size, + ) + .new_tree(&mut runner) + .unwrap() + .current(); + + let keys = db_state.keys().cloned().collect::>(); + let keys_to_update = subsequence(keys, size / 2).new_tree(&mut runner).unwrap().current(); + + let updated_storages = keys_to_update + .into_iter() + .map(|address| { + let (_, storage) = db_state.get(&address).unwrap(); + let slots = storage.keys().cloned().collect::>(); + let slots_to_update = + subsequence(slots, storage_size / 2).new_tree(&mut runner).unwrap().current(); + ( + address, + slots_to_update + .into_iter() + .map(|slot| (slot, any::().new_tree(&mut runner).unwrap().current())) + .collect::>(), + ) + }) + .collect::>(); + + ( + HashedPostState::default() + .with_accounts( + db_state.iter().map(|(address, (account, _))| (*address, Some(*account))), + ) + .with_storages(db_state.into_iter().map(|(address, (_, storage))| { + (address, HashedStorage::from_iter(false, storage)) + })), + HashedPostState::default().with_storages( + updated_storages + .into_iter() + .map(|(address, storage)| (address, HashedStorage::from_iter(false, storage))), + ), + ) +} + +criterion_group!(state_root, calculate_state_root); +criterion_main!(state_root); diff --git a/crates/trie-parallel/src/async_root.rs b/crates/trie-parallel/src/async_root.rs new file mode 100644 index 00000000000..85aaad29743 --- /dev/null +++ b/crates/trie-parallel/src/async_root.rs @@ -0,0 +1,333 @@ +use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets}; +use alloy_rlp::{BufMut, Encodable}; +use itertools::Itertools; +use reth_db::database::Database; +use reth_primitives::{ + trie::{HashBuilder, Nibbles, TrieAccount}, + B256, +}; +use reth_provider::{ + providers::{ConsistentDbView, ConsistentViewError}, + DatabaseProviderFactory, ProviderError, +}; +use 
reth_tasks::pool::BlockingTaskPool;
+use reth_trie::{
+    hashed_cursor::HashedPostStateCursorFactory,
+    node_iter::{AccountNode, AccountNodeIter},
+    trie_cursor::TrieCursorFactory,
+    updates::TrieUpdates,
+    walker::TrieWalker,
+    HashedPostState, StorageRoot, StorageRootError,
+};
+use std::{collections::HashMap, sync::Arc};
+use thiserror::Error;
+use tracing::*;
+
+#[cfg(feature = "metrics")]
+use crate::metrics::ParallelStateRootMetrics;
+
+/// Async state root calculator.
+///
+/// The calculator starts off by launching tasks to compute storage roots.
+/// Then, it immediately starts walking the state trie, updating the necessary trie
+/// nodes in the process. Upon encountering a leaf node, it will poll the storage root
+/// task for the corresponding hashed address.
+///
+/// Internally, the calculator uses [ConsistentDbView] since
+/// it needs to rely on the database state staying the same while
+/// the transaction is open.
+/// See the [ConsistentDbView] docs for caveats.
+///
+/// For sync usage, take a look at `ParallelStateRoot`.
+#[derive(Debug)]
+pub struct AsyncStateRoot<DB, Provider> {
+    /// Consistent view of the database.
+    view: ConsistentDbView<DB, Provider>,
+    /// Blocking task pool.
+    blocking_pool: BlockingTaskPool,
+    /// Changed hashed state.
+    hashed_state: HashedPostState,
+    /// Parallel state root metrics.
+    #[cfg(feature = "metrics")]
+    metrics: ParallelStateRootMetrics,
+}
+
+impl<DB, Provider> AsyncStateRoot<DB, Provider> {
+    /// Create new async state root calculator.
+    pub fn new(
+        view: ConsistentDbView<DB, Provider>,
+        blocking_pool: BlockingTaskPool,
+        hashed_state: HashedPostState,
+    ) -> Self {
+        Self {
+            view,
+            blocking_pool,
+            hashed_state,
+            #[cfg(feature = "metrics")]
+            metrics: ParallelStateRootMetrics::default(),
+        }
+    }
+}
+
+impl<DB, Provider> AsyncStateRoot<DB, Provider>
+where
+    DB: Database + Clone + 'static,
+    Provider: DatabaseProviderFactory<DB> + Clone + Send + Sync + 'static,
+{
+    /// Calculate incremental state root asynchronously.
+    pub async fn incremental_root(self) -> Result<B256, AsyncStateRootError> {
+        self.calculate(false).await.map(|(root, _)| root)
+    }
+
+    /// Calculate incremental state root with updates asynchronously.
+    pub async fn incremental_root_with_updates(
+        self,
+    ) -> Result<(B256, TrieUpdates), AsyncStateRootError> {
+        self.calculate(true).await
+    }
+
+    async fn calculate(
+        self,
+        retain_updates: bool,
+    ) -> Result<(B256, TrieUpdates), AsyncStateRootError> {
+        let mut tracker = ParallelTrieTracker::default();
+        let prefix_sets = self.hashed_state.construct_prefix_sets();
+        let storage_root_targets = StorageRootTargets::new(
+            self.hashed_state.accounts.keys().copied(),
+            prefix_sets.storage_prefix_sets,
+        );
+        let hashed_state_sorted = Arc::new(self.hashed_state.into_sorted());
+
+        // Pre-calculate storage roots async for accounts which were changed.
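+        // Each target account is dispatched to the blocking task pool below and its
+        // join handle is stored by hashed address. The handles are only awaited once
+        // the account trie walk reaches the corresponding leaf, so storage roots are
+        // computed concurrently with the walk itself.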
+ tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); + debug!(target: "trie::async_state_root", len = storage_root_targets.len(), "pre-calculating storage roots"); + let mut storage_roots = HashMap::with_capacity(storage_root_targets.len()); + for (hashed_address, prefix_set) in + storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) + { + let view = self.view.clone(); + let hashed_state_sorted = hashed_state_sorted.clone(); + #[cfg(feature = "metrics")] + let metrics = self.metrics.storage_trie.clone(); + let handle = + self.blocking_pool.spawn_fifo(move || -> Result<_, AsyncStateRootError> { + let provider = view.provider_ro()?; + Ok(StorageRoot::new_hashed( + provider.tx_ref(), + HashedPostStateCursorFactory::new(provider.tx_ref(), &hashed_state_sorted), + hashed_address, + #[cfg(feature = "metrics")] + metrics, + ) + .with_prefix_set(prefix_set) + .calculate(retain_updates)?) + }); + storage_roots.insert(hashed_address, handle); + } + + trace!(target: "trie::async_state_root", "calculating state root"); + let mut trie_updates = TrieUpdates::default(); + + let provider_ro = self.view.provider_ro()?; + let tx = provider_ro.tx_ref(); + let hashed_cursor_factory = HashedPostStateCursorFactory::new(tx, &hashed_state_sorted); + let trie_cursor_factory = tx; + + let trie_cursor = + trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?; + + let mut hash_builder = HashBuilder::default().with_updates(retain_updates); + let walker = TrieWalker::new(trie_cursor, prefix_sets.account_prefix_set) + .with_updates(retain_updates); + let mut account_node_iter = + AccountNodeIter::from_factory(walker, hashed_cursor_factory.clone()) + .map_err(ProviderError::Database)?; + + let mut account_rlp = Vec::with_capacity(128); + while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? { + match node { + AccountNode::Branch(node) => { + hash_builder.add_branch(node.key, node.value, node.children_are_in_trie); + } + AccountNode::Leaf(hashed_address, account) => { + let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) { + Some(rx) => rx.await.map_err(|_| { + AsyncStateRootError::StorageRootChannelClosed { hashed_address } + })??, + // Since we do not store all intermediate nodes in the database, there might + // be a possibility of re-adding a non-modified leaf to the hash builder. + None => { + tracker.inc_missed_leaves(); + StorageRoot::new_hashed( + trie_cursor_factory, + hashed_cursor_factory.clone(), + hashed_address, + #[cfg(feature = "metrics")] + self.metrics.storage_trie.clone(), + ) + .calculate(retain_updates)? 
+ } + }; + + if retain_updates { + trie_updates.extend(updates.into_iter()); + } + + account_rlp.clear(); + let account = TrieAccount::from((account, storage_root)); + account.encode(&mut account_rlp as &mut dyn BufMut); + hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp); + } + } + } + + let root = hash_builder.root(); + + trie_updates.finalize_state_updates( + account_node_iter.walker, + hash_builder, + prefix_sets.destroyed_accounts, + ); + + let stats = tracker.finish(); + + #[cfg(feature = "metrics")] + self.metrics.record_state_trie(stats); + + trace!( + target: "trie::async_state_root", + %root, + duration = ?stats.duration(), + branches_added = stats.branches_added(), + leaves_added = stats.leaves_added(), + missed_leaves = stats.missed_leaves(), + precomputed_storage_roots = stats.precomputed_storage_roots(), + "calculated state root" + ); + + Ok((root, trie_updates)) + } +} + +/// Error during async state root calculation. +#[derive(Error, Debug)] +pub enum AsyncStateRootError { + /// Storage root channel for a given address was closed. + #[error("storage root channel for {hashed_address} got closed")] + StorageRootChannelClosed { + /// The hashed address for which channel was closed. + hashed_address: B256, + }, + /// Consistency error on attempt to create new database provider. + #[error(transparent)] + ConsistentView(#[from] ConsistentViewError), + /// Error while calculating storage root. + #[error(transparent)] + StorageRoot(#[from] StorageRootError), + /// Provider error. + #[error(transparent)] + Provider(#[from] ProviderError), +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + use rayon::ThreadPoolBuilder; + use reth_primitives::{keccak256, Account, Address, StorageEntry, U256}; + use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; + use reth_trie::{test_utils, HashedStorage}; + + #[tokio::test] + async fn random_async_root() { + let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap()); + + let factory = create_test_provider_factory(); + let consistent_view = ConsistentDbView::new(factory.clone()); + + let mut rng = rand::thread_rng(); + let mut state = (0..100) + .map(|_| { + let address = Address::random(); + let account = + Account { balance: U256::from(rng.gen::()), ..Default::default() }; + let mut storage = HashMap::::default(); + let has_storage = rng.gen_bool(0.7); + if has_storage { + for _ in 0..100 { + storage.insert( + B256::from(U256::from(rng.gen::())), + U256::from(rng.gen::()), + ); + } + } + (address, (account, storage)) + }) + .collect::>(); + + { + let provider_rw = factory.provider_rw().unwrap(); + provider_rw + .insert_account_for_hashing( + state.iter().map(|(address, (account, _))| (*address, Some(*account))), + ) + .unwrap(); + provider_rw + .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| { + ( + *address, + storage + .iter() + .map(|(slot, value)| StorageEntry { key: *slot, value: *value }), + ) + })) + .unwrap(); + provider_rw.commit().unwrap(); + } + + assert_eq!( + AsyncStateRoot::new( + consistent_view.clone(), + blocking_pool.clone(), + HashedPostState::default() + ) + .incremental_root() + .await + .unwrap(), + test_utils::state_root(state.clone()) + ); + + let mut hashed_state = HashedPostState::default(); + for (address, (account, storage)) in state.iter_mut() { + let hashed_address = keccak256(address); + + let should_update_account = rng.gen_bool(0.5); + if should_update_account { + *account = Account { balance: 
U256::from(rng.gen::()), ..*account }; + hashed_state.accounts.insert(hashed_address, Some(*account)); + } + + let should_update_storage = rng.gen_bool(0.3); + if should_update_storage { + for (slot, value) in storage.iter_mut() { + let hashed_slot = keccak256(slot); + *value = U256::from(rng.gen::()); + hashed_state + .storages + .entry(hashed_address) + .or_insert_with(|| HashedStorage::new(false)) + .storage + .insert(hashed_slot, *value); + } + } + } + + assert_eq!( + AsyncStateRoot::new(consistent_view.clone(), blocking_pool.clone(), hashed_state) + .incremental_root() + .await + .unwrap(), + test_utils::state_root(state) + ); + } +} diff --git a/crates/trie-parallel/src/lib.rs b/crates/trie-parallel/src/lib.rs new file mode 100644 index 00000000000..ff130b2187e --- /dev/null +++ b/crates/trie-parallel/src/lib.rs @@ -0,0 +1,26 @@ +//! Implementation of exotic state root computation approaches. + +#![doc( + html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png", + html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256", + issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/" +)] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] + +mod storage_root_targets; +pub use storage_root_targets::StorageRootTargets; + +/// Parallel trie calculation stats. +pub mod stats; + +/// Implementation of async state root computation. +#[cfg(feature = "async")] +pub mod async_root; + +/// Implementation of parallel state root computation. +#[cfg(feature = "parallel")] +pub mod parallel_root; + +/// Parallel state root metrics. +#[cfg(feature = "metrics")] +pub mod metrics; diff --git a/crates/trie-parallel/src/metrics.rs b/crates/trie-parallel/src/metrics.rs new file mode 100644 index 00000000000..5a4a3ba7406 --- /dev/null +++ b/crates/trie-parallel/src/metrics.rs @@ -0,0 +1,44 @@ +use crate::stats::ParallelTrieStats; +use metrics::Histogram; +use reth_metrics::Metrics; +use reth_trie::metrics::{TrieRootMetrics, TrieType}; + +/// Parallel state root metrics. +#[derive(Debug)] +pub struct ParallelStateRootMetrics { + /// State trie metrics. + pub state_trie: TrieRootMetrics, + /// Parallel trie metrics. + pub parallel: ParallelTrieMetrics, + /// Storage trie metrics. + pub storage_trie: TrieRootMetrics, +} + +impl Default for ParallelStateRootMetrics { + fn default() -> Self { + Self { + state_trie: TrieRootMetrics::new(TrieType::State), + parallel: ParallelTrieMetrics::default(), + storage_trie: TrieRootMetrics::new(TrieType::Storage), + } + } +} + +impl ParallelStateRootMetrics { + /// Record state trie metrics + pub fn record_state_trie(&self, stats: ParallelTrieStats) { + self.state_trie.record(stats.trie_stats()); + self.parallel.precomputed_storage_roots.record(stats.precomputed_storage_roots() as f64); + self.parallel.missed_leaves.record(stats.missed_leaves() as f64); + } +} + +/// Parallel state root metrics. +#[derive(Metrics)] +#[metrics(scope = "trie_parallel")] +pub struct ParallelTrieMetrics { + /// The number of storage roots computed in parallel. + pub precomputed_storage_roots: Histogram, + /// The number of leaves for which we did not pre-compute the storage roots. 
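+    /// These are leaves whose storage root could not be looked up from the
+    /// pre-computed set and had to be calculated inline during the walk.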
+    pub missed_leaves: Histogram,
+}
diff --git a/crates/trie-parallel/src/parallel_root.rs new file mode 100644 index 00000000000..d551815079e --- /dev/null +++ b/crates/trie-parallel/src/parallel_root.rs @@ -0,0 +1,300 @@
+use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
+use alloy_rlp::{BufMut, Encodable};
+use rayon::prelude::*;
+use reth_db::database::Database;
+use reth_primitives::{
+    trie::{HashBuilder, Nibbles, TrieAccount},
+    B256,
+};
+use reth_provider::{
+    providers::{ConsistentDbView, ConsistentViewError},
+    DatabaseProviderFactory, ProviderError,
+};
+use reth_trie::{
+    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
+    node_iter::{AccountNode, AccountNodeIter},
+    trie_cursor::TrieCursorFactory,
+    updates::TrieUpdates,
+    walker::TrieWalker,
+    HashedPostState, StorageRoot, StorageRootError,
+};
+use std::collections::HashMap;
+use thiserror::Error;
+use tracing::*;
+
+#[cfg(feature = "metrics")]
+use crate::metrics::ParallelStateRootMetrics;
+
+/// Parallel incremental state root calculator.
+///
+/// The calculator starts off by pre-computing storage roots of changed
+/// accounts in parallel. Once that's done, it proceeds to walk the state
+/// trie, retrieving the pre-computed storage roots when needed.
+///
+/// Internally, the calculator uses [ConsistentDbView] since
+/// it needs to rely on the database state staying the same while
+/// the transaction is open.
+/// See the [ConsistentDbView] docs for caveats.
+///
+/// If possible, use the more optimized `AsyncStateRoot` instead.
+#[derive(Debug)]
+pub struct ParallelStateRoot<DB, Provider> {
+    /// Consistent view of the database.
+    view: ConsistentDbView<DB, Provider>,
+    /// Changed hashed state.
+    hashed_state: HashedPostState,
+    /// Parallel state root metrics.
+    #[cfg(feature = "metrics")]
+    metrics: ParallelStateRootMetrics,
+}
+
+impl<DB, Provider> ParallelStateRoot<DB, Provider> {
+    /// Create new parallel state root calculator.
+    pub fn new(view: ConsistentDbView<DB, Provider>, hashed_state: HashedPostState) -> Self {
+        Self {
+            view,
+            hashed_state,
+            #[cfg(feature = "metrics")]
+            metrics: ParallelStateRootMetrics::default(),
+        }
+    }
+}
+
+impl<DB, Provider> ParallelStateRoot<DB, Provider>
+where
+    DB: Database,
+    Provider: DatabaseProviderFactory<DB> + Send + Sync,
+{
+    /// Calculate incremental state root in parallel.
+    pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
+        self.calculate(false).map(|(root, _)| root)
+    }
+
+    /// Calculate incremental state root with updates in parallel.
+    pub fn incremental_root_with_updates(
+        self,
+    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
+        self.calculate(true)
+    }
+
+    fn calculate(
+        self,
+        retain_updates: bool,
+    ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
+        let mut tracker = ParallelTrieTracker::default();
+        let prefix_sets = self.hashed_state.construct_prefix_sets();
+        let storage_root_targets = StorageRootTargets::new(
+            self.hashed_state.accounts.keys().copied(),
+            prefix_sets.storage_prefix_sets,
+        );
+        let hashed_state_sorted = self.hashed_state.into_sorted();
+
+        // Pre-calculate storage roots in parallel for accounts which were changed.
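+        // The targets below are fanned out across the rayon thread pool; every
+        // closure opens its own read-only provider from the consistent view, so each
+        // storage root is computed on an independent database transaction.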
+        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
+        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
+        let mut storage_roots = storage_root_targets
+            .into_par_iter()
+            .map(|(hashed_address, prefix_set)| {
+                let provider_ro = self.view.provider_ro()?;
+                let storage_root_result = StorageRoot::new_hashed(
+                    provider_ro.tx_ref(),
+                    HashedPostStateCursorFactory::new(provider_ro.tx_ref(), &hashed_state_sorted),
+                    hashed_address,
+                    #[cfg(feature = "metrics")]
+                    self.metrics.storage_trie.clone(),
+                )
+                .with_prefix_set(prefix_set)
+                .calculate(retain_updates);
+                Ok((hashed_address, storage_root_result?))
+            })
+            .collect::<Result<HashMap<_, _>, ParallelStateRootError>>()?;
+
+        trace!(target: "trie::parallel_state_root", "calculating state root");
+        let mut trie_updates = TrieUpdates::default();
+
+        let provider_ro = self.view.provider_ro()?;
+        let hashed_cursor_factory =
+            HashedPostStateCursorFactory::new(provider_ro.tx_ref(), &hashed_state_sorted);
+        let trie_cursor_factory = provider_ro.tx_ref();
+
+        let hashed_account_cursor =
+            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?;
+        let trie_cursor =
+            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?;
+
+        let walker = TrieWalker::new(trie_cursor, prefix_sets.account_prefix_set)
+            .with_updates(retain_updates);
+        let mut account_node_iter = AccountNodeIter::new(walker, hashed_account_cursor);
+        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
+
+        let mut account_rlp = Vec::with_capacity(128);
+        while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
+            match node {
+                AccountNode::Branch(node) => {
+                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
+                }
+                AccountNode::Leaf(hashed_address, account) => {
+                    let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
+                        Some(result) => result,
+                        // Since we do not store all intermediate nodes in the database, there might
+                        // be a possibility of re-adding a non-modified leaf to the hash builder.
+                        None => {
+                            tracker.inc_missed_leaves();
+                            StorageRoot::new_hashed(
+                                trie_cursor_factory,
+                                hashed_cursor_factory.clone(),
+                                hashed_address,
+                                #[cfg(feature = "metrics")]
+                                self.metrics.storage_trie.clone(),
+                            )
+                            .calculate(retain_updates)?
+                        }
+                    };
+
+                    if retain_updates {
+                        trie_updates.extend(updates.into_iter());
+                    }
+
+                    account_rlp.clear();
+                    let account = TrieAccount::from((account, storage_root));
+                    account.encode(&mut account_rlp as &mut dyn BufMut);
+                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
+                }
+            }
+        }
+
+        let root = hash_builder.root();
+
+        trie_updates.finalize_state_updates(
+            account_node_iter.walker,
+            hash_builder,
+            prefix_sets.destroyed_accounts,
+        );
+
+        let stats = tracker.finish();
+
+        #[cfg(feature = "metrics")]
+        self.metrics.record_state_trie(stats);
+
+        trace!(
+            target: "trie::parallel_state_root",
+            %root,
+            duration = ?stats.duration(),
+            branches_added = stats.branches_added(),
+            leaves_added = stats.leaves_added(),
+            missed_leaves = stats.missed_leaves(),
+            precomputed_storage_roots = stats.precomputed_storage_roots(),
+            "calculated state root"
+        );
+
+        Ok((root, trie_updates))
+    }
+}
+
+/// Error during parallel state root calculation.
+#[derive(Error, Debug)]
+pub enum ParallelStateRootError {
+    /// Consistency error on attempt to create new database provider.
+ #[error(transparent)] + ConsistentView(#[from] ConsistentViewError), + /// Error while calculating storage root. + #[error(transparent)] + StorageRoot(#[from] StorageRootError), + /// Provider error. + #[error(transparent)] + Provider(#[from] ProviderError), +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + use reth_primitives::{keccak256, Account, Address, StorageEntry, U256}; + use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; + use reth_trie::{test_utils, HashedStorage}; + + #[tokio::test] + async fn random_parallel_root() { + let factory = create_test_provider_factory(); + let consistent_view = ConsistentDbView::new(factory.clone()); + + let mut rng = rand::thread_rng(); + let mut state = (0..100) + .map(|_| { + let address = Address::random(); + let account = + Account { balance: U256::from(rng.gen::()), ..Default::default() }; + let mut storage = HashMap::::default(); + let has_storage = rng.gen_bool(0.7); + if has_storage { + for _ in 0..100 { + storage.insert( + B256::from(U256::from(rng.gen::())), + U256::from(rng.gen::()), + ); + } + } + (address, (account, storage)) + }) + .collect::>(); + + { + let provider_rw = factory.provider_rw().unwrap(); + provider_rw + .insert_account_for_hashing( + state.iter().map(|(address, (account, _))| (*address, Some(*account))), + ) + .unwrap(); + provider_rw + .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| { + ( + *address, + storage + .iter() + .map(|(slot, value)| StorageEntry { key: *slot, value: *value }), + ) + })) + .unwrap(); + provider_rw.commit().unwrap(); + } + + assert_eq!( + ParallelStateRoot::new(consistent_view.clone(), HashedPostState::default()) + .incremental_root() + .unwrap(), + test_utils::state_root(state.clone()) + ); + + let mut hashed_state = HashedPostState::default(); + for (address, (account, storage)) in state.iter_mut() { + let hashed_address = keccak256(address); + + let should_update_account = rng.gen_bool(0.5); + if should_update_account { + *account = Account { balance: U256::from(rng.gen::()), ..*account }; + hashed_state.accounts.insert(hashed_address, Some(*account)); + } + + let should_update_storage = rng.gen_bool(0.3); + if should_update_storage { + for (slot, value) in storage.iter_mut() { + let hashed_slot = keccak256(slot); + *value = U256::from(rng.gen::()); + hashed_state + .storages + .entry(hashed_address) + .or_insert_with(|| HashedStorage::new(false)) + .storage + .insert(hashed_slot, *value); + } + } + } + + assert_eq!( + ParallelStateRoot::new(consistent_view.clone(), hashed_state) + .incremental_root() + .unwrap(), + test_utils::state_root(state) + ); + } +} diff --git a/crates/trie-parallel/src/stats.rs b/crates/trie-parallel/src/stats.rs new file mode 100644 index 00000000000..e3d4d77a0a8 --- /dev/null +++ b/crates/trie-parallel/src/stats.rs @@ -0,0 +1,68 @@ +use derive_more::Deref; +use reth_trie::stats::{TrieStats, TrieTracker}; + +/// Trie stats. +#[derive(Deref, Clone, Copy, Debug)] +pub struct ParallelTrieStats { + #[deref] + trie: TrieStats, + precomputed_storage_roots: u64, + missed_leaves: u64, +} + +impl ParallelTrieStats { + /// Return general trie stats. + pub fn trie_stats(&self) -> TrieStats { + self.trie + } + + /// The number of pre-computed storage roots. + pub fn precomputed_storage_roots(&self) -> u64 { + self.precomputed_storage_roots + } + + /// The number of added leaf nodes for which we did not precompute the storage root. 
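+    /// A non-zero value means some account leaves were re-added to the hash builder
+    /// without a pre-computed storage root, forcing an inline calculation.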
+    pub fn missed_leaves(&self) -> u64 {
+        self.missed_leaves
+    }
+}
+
+/// Trie metrics tracker.
+#[derive(Deref, Default, Debug)]
+pub struct ParallelTrieTracker {
+    #[deref]
+    trie: TrieTracker,
+    precomputed_storage_roots: u64,
+    missed_leaves: u64,
+}
+
+impl ParallelTrieTracker {
+    /// Set the number of precomputed storage roots.
+    pub fn set_precomputed_storage_roots(&mut self, count: u64) {
+        self.precomputed_storage_roots = count;
+    }
+
+    /// Increment the number of branches added to the hash builder during the calculation.
+    pub fn inc_branch(&mut self) {
+        self.trie.inc_branch();
+    }
+
+    /// Increment the number of leaves added to the hash builder during the calculation.
+    pub fn inc_leaf(&mut self) {
+        self.trie.inc_leaf();
+    }
+
+    /// Increment the number of added leaf nodes for which we did not precompute the storage root.
+    pub fn inc_missed_leaves(&mut self) {
+        self.missed_leaves += 1;
+    }
+
+    /// Called when root calculation is finished to return trie statistics.
+    pub fn finish(self) -> ParallelTrieStats {
+        ParallelTrieStats {
+            trie: self.trie.finish(),
+            precomputed_storage_roots: self.precomputed_storage_roots,
+            missed_leaves: self.missed_leaves,
+        }
+    }
+}
diff --git a/crates/trie-parallel/src/storage_root_targets.rs new file mode 100644 index 00000000000..a0ada63491d --- /dev/null +++ b/crates/trie-parallel/src/storage_root_targets.rs @@ -0,0 +1,47 @@
+use derive_more::{Deref, DerefMut};
+use reth_primitives::B256;
+use reth_trie::prefix_set::PrefixSet;
+use std::collections::HashMap;
+
+/// Target accounts with corresponding prefix sets for storage root calculation.
+#[derive(Deref, DerefMut, Debug)]
+pub struct StorageRootTargets(HashMap<B256, PrefixSet>);
+
+impl StorageRootTargets {
+    /// Create new storage root targets from updated post state accounts
+    /// and storage prefix sets.
+    ///
+    /// NOTE: Since updated accounts and prefix sets always overlap,
+    /// it's important that the iterator over storage prefix sets takes precedence.
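+    ///
+    /// For example, if an account appears in both inputs, it is first inserted with
+    /// an empty default prefix set and then overwritten by its storage prefix set,
+    /// because later entries win when the chained iterators are collected into a map.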
+    pub fn new(
+        changed_accounts: impl IntoIterator<Item = B256>,
+        storage_prefix_sets: impl IntoIterator<Item = (B256, PrefixSet)>,
+    ) -> Self {
+        Self(
+            changed_accounts
+                .into_iter()
+                .map(|address| (address, PrefixSet::default()))
+                .chain(storage_prefix_sets)
+                .collect(),
+        )
+    }
+}
+
+impl IntoIterator for StorageRootTargets {
+    type Item = (B256, PrefixSet);
+    type IntoIter = std::collections::hash_map::IntoIter<B256, PrefixSet>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+#[cfg(feature = "parallel")]
+impl rayon::iter::IntoParallelIterator for StorageRootTargets {
+    type Item = (B256, PrefixSet);
+    type Iter = rayon::collections::hash_map::IntoIter<B256, PrefixSet>;
+
+    fn into_par_iter(self) -> Self::Iter {
+        self.0.into_par_iter()
+    }
+}
diff --git a/crates/trie/Cargo.toml b/crates/trie/Cargo.toml index 45939c14bfe..55f2672e93e 100644 --- a/crates/trie/Cargo.toml +++ b/crates/trie/Cargo.toml @@ -56,7 +56,6 @@ similar-asserts.workspace = true criterion.workspace = true [features]
-default = ["metrics"]
 metrics = ["reth-metrics", "dep:metrics"]
 test-utils = ["triehash"]
diff --git a/crates/trie/src/node_iter.rs b/crates/trie/src/node_iter.rs index 4d75c006efa..742896140bc 100644 --- a/crates/trie/src/node_iter.rs +++ b/crates/trie/src/node_iter.rs @@ -1,9 +1,9 @@ use crate::{
-    hashed_cursor::{HashedAccountCursor, HashedStorageCursor},
+    hashed_cursor::{HashedAccountCursor, HashedCursorFactory, HashedStorageCursor},
     trie_cursor::TrieCursor,
     walker::TrieWalker,
-    StateRootError, StorageRootError,
 };
+use reth_db::DatabaseError;
 use reth_primitives::{trie::Nibbles, Account, StorageEntry, B256, U256};
 /// Represents a branch node in the trie.
@@ -71,6 +71,14 @@ impl<C, H> AccountNodeIter<C, H> { } }
+    /// Create new `AccountNodeIter` by creating hashed account cursor from factory.
+    pub fn from_factory<F: HashedCursorFactory<AccountCursor = H>>(
+        walker: TrieWalker<C>,
+        factory: F,
+    ) -> Result<Self, DatabaseError> {
+        Ok(Self::new(walker, factory.hashed_account_cursor()?))
+    }
+
     /// Sets the last iterated account key and returns the modified `AccountNodeIter`.
     /// This is used to resume iteration from the last checkpoint.
     pub fn with_last_account_key(mut self, previous_account_key: B256) -> Self {
@@ -95,7 +103,7 @@ where /// 5. Repeat. /// /// NOTE: The iteration will start from the key of the previous hashed entry if it was supplied.
-    pub fn try_next(&mut self) -> Result<Option<AccountNode>, StateRootError> {
+    pub fn try_next(&mut self) -> Result<Option<AccountNode>, DatabaseError> {
         loop {
             // If the walker has a key...
             if let Some(key) = self.walker.key() {
@@ -194,7 +202,7 @@ where /// 3. Reposition the hashed storage cursor on the next unprocessed key. /// 4. Return every hashed storage entry up to the key of the current intermediate branch node. /// 5. Repeat.
-    pub fn try_next(&mut self) -> Result<Option<StorageNode>, StorageRootError> {
+    pub fn try_next(&mut self) -> Result<Option<StorageNode>, DatabaseError> {
         loop {
             // Check if there's a key in the walker.
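+            // Same pattern as the account iterator above: branch nodes that can be
+            // skipped are yielded directly, otherwise the hashed storage cursor is
+            // repositioned and entries are drained up to the next branch key.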
if let Some(key) = self.walker.key() { diff --git a/crates/trie/src/trie.rs b/crates/trie/src/trie.rs index 08742122d74..82fbeceec21 100644 --- a/crates/trie/src/trie.rs +++ b/crates/trie/src/trie.rs @@ -1,7 +1,7 @@ use crate::{ hashed_cursor::{HashedCursorFactory, HashedStorageCursor}, node_iter::{AccountNode, AccountNodeIter, StorageNode, StorageNodeIter}, - prefix_set::{PrefixSet, PrefixSetLoader, PrefixSetMut, TriePrefixSets}, + prefix_set::{PrefixSet, PrefixSetLoader, TriePrefixSets}, progress::{IntermediateStateRootState, StateRootProgress}, stats::TrieTracker, trie_cursor::TrieCursorFactory, @@ -214,31 +214,32 @@ where let mut tracker = TrieTracker::default(); let mut trie_updates = TrieUpdates::default(); - let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?; let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?; let (mut hash_builder, mut account_node_iter) = match self.previous_state { Some(state) => { + let hash_builder = state.hash_builder.with_updates(retain_updates); let walker = TrieWalker::from_stack( trie_cursor, state.walker_stack, self.prefix_sets.account_prefix_set, - ); - ( - state.hash_builder, - AccountNodeIter::new(walker, hashed_account_cursor) - .with_last_account_key(state.last_account_key), ) + .with_updates(retain_updates); + let node_iter = + AccountNodeIter::from_factory(walker, self.hashed_cursor_factory.clone())? + .with_last_account_key(state.last_account_key); + (hash_builder, node_iter) } None => { - let walker = TrieWalker::new(trie_cursor, self.prefix_sets.account_prefix_set); - (HashBuilder::default(), AccountNodeIter::new(walker, hashed_account_cursor)) + let hash_builder = HashBuilder::default().with_updates(retain_updates); + let walker = TrieWalker::new(trie_cursor, self.prefix_sets.account_prefix_set) + .with_updates(retain_updates); + let node_iter = + AccountNodeIter::from_factory(walker, self.hashed_cursor_factory.clone())?; + (hash_builder, node_iter) } }; - account_node_iter.walker.set_updates(retain_updates); - hash_builder.set_updates(retain_updates); - let mut account_rlp = Vec::with_capacity(128); let mut hashed_entries_walked = 0; while let Some(node) = account_node_iter.try_next()? { @@ -283,11 +284,9 @@ where storage_root_calculator.root()? }; - let account = TrieAccount::from((account, storage_root)); - account_rlp.clear(); + let account = TrieAccount::from((account, storage_root)); account.encode(&mut account_rlp as &mut dyn BufMut); - hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp); // Decide if we need to return intermediate progress. @@ -319,13 +318,10 @@ where let root = hash_builder.root(); - let (_, walker_updates) = account_node_iter.walker.split(); - let (_, hash_builder_updates) = hash_builder.split(); - - trie_updates.extend(walker_updates); - trie_updates.extend_with_account_updates(hash_builder_updates); - trie_updates.extend_with_deletes( - self.prefix_sets.destroyed_accounts.into_iter().map(TrieKey::StorageTrie), + trie_updates.finalize_state_updates( + account_node_iter.walker, + hash_builder, + self.prefix_sets.destroyed_accounts, ); let stats = tracker.finish(); @@ -357,8 +353,8 @@ pub struct StorageRoot { pub hashed_address: B256, /// The set of storage slot prefixes that have changed. pub prefix_set: PrefixSet, - #[cfg(feature = "metrics")] /// Storage root metrics. 
+ #[cfg(feature = "metrics")] metrics: TrieRootMetrics, } @@ -390,7 +386,7 @@ impl StorageRoot { trie_cursor_factory, hashed_cursor_factory, hashed_address, - prefix_set: PrefixSetMut::default().freeze(), + prefix_set: PrefixSet::default(), #[cfg(feature = "metrics")] metrics, } @@ -475,7 +471,13 @@ where Ok(root) } - fn calculate( + /// Walks the hashed storage table entries for a given address and calculates the storage root. + /// + /// # Returns + /// + /// The storage root, number of walked entries and trie updates + /// for a given address if requested. + pub fn calculate( self, retain_updates: bool, ) -> Result<(B256, usize, TrieUpdates), StorageRootError> { @@ -518,12 +520,12 @@ where let root = hash_builder.root(); - let (_, hash_builder_updates) = hash_builder.split(); - let (_, walker_updates) = storage_node_iter.walker.split(); - let mut trie_updates = TrieUpdates::default(); - trie_updates.extend(walker_updates); - trie_updates.extend_with_storage_updates(self.hashed_address, hash_builder_updates); + trie_updates.finalize_storage_updates( + self.hashed_address, + storage_node_iter.walker, + hash_builder, + ); let stats = tracker.finish(); @@ -548,8 +550,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test_utils::{ - state_root, state_root_prehashed, storage_root, storage_root_prehashed, + use crate::{ + prefix_set::PrefixSetMut, + test_utils::{state_root, state_root_prehashed, storage_root, storage_root_prehashed}, }; use proptest::{prelude::ProptestConfig, proptest}; use reth_db::{ @@ -788,7 +791,7 @@ mod tests { tx.commit().unwrap(); let tx = factory.provider_rw().unwrap(); - let expected = state_root(state.into_iter()); + let expected = state_root(state); let threshold = 10; let mut got = None; diff --git a/crates/trie/src/updates.rs b/crates/trie/src/updates.rs index 7580f3ab40a..fc263aefb8c 100644 --- a/crates/trie/src/updates.rs +++ b/crates/trie/src/updates.rs @@ -6,12 +6,14 @@ use reth_db::{ }; use reth_primitives::{ trie::{ - BranchNodeCompact, Nibbles, StorageTrieEntry, StoredBranchNode, StoredNibbles, + BranchNodeCompact, HashBuilder, Nibbles, StorageTrieEntry, StoredBranchNode, StoredNibbles, StoredNibblesSubKey, }, B256, }; -use std::collections::{hash_map::IntoIter, HashMap}; +use std::collections::{hash_map::IntoIter, HashMap, HashSet}; + +use crate::walker::TrieWalker; /// The key of a trie node. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -88,22 +90,45 @@ impl TrieUpdates { ); } - /// Extend the updates with storage trie updates. - pub fn extend_with_storage_updates( + /// Finalize state trie updates. + pub fn finalize_state_updates( + &mut self, + walker: TrieWalker, + hash_builder: HashBuilder, + destroyed_accounts: HashSet, + ) { + // Add updates from trie walker. + let (_, walker_updates) = walker.split(); + self.extend(walker_updates); + + // Add account node updates from hash builder. + let (_, hash_builder_updates) = hash_builder.split(); + self.extend_with_account_updates(hash_builder_updates); + + // Add deleted storage tries for destroyed accounts. + self.extend( + destroyed_accounts.into_iter().map(|key| (TrieKey::StorageTrie(key), TrieOp::Delete)), + ); + } + + /// Finalize storage trie updates for a given address. + pub fn finalize_storage_updates( &mut self, hashed_address: B256, - updates: HashMap, + walker: TrieWalker, + hash_builder: HashBuilder, ) { - self.extend(updates.into_iter().map(|(nibbles, node)| { + // Add updates from trie walker. 
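+        // `split` returns the walker's cursor together with its accumulated updates;
+        // only the updates are kept, mirroring `finalize_state_updates` above.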
+ let (_, walker_updates) = walker.split(); + self.extend(walker_updates); + + // Add storage node updates from hash builder. + let (_, hash_builder_updates) = hash_builder.split(); + self.extend(hash_builder_updates.into_iter().map(|(nibbles, node)| { (TrieKey::StorageNode(hashed_address, nibbles.into()), TrieOp::Update(node)) })); } - /// Extend the updates with deletes. - pub fn extend_with_deletes(&mut self, keys: impl IntoIterator) { - self.extend(keys.into_iter().map(|key| (key, TrieOp::Delete))); - } - /// Flush updates all aggregated updates to the database. pub fn flush(self, tx: &(impl DbTx + DbTxMut)) -> Result<(), reth_db::DatabaseError> { if self.trie_operations.is_empty() { diff --git a/crates/trie/src/walker.rs b/crates/trie/src/walker.rs index 3710648fd20..42a8dde277a 100644 --- a/crates/trie/src/walker.rs +++ b/crates/trie/src/walker.rs @@ -28,28 +28,7 @@ pub struct TrieWalker { trie_updates: Option, } -impl TrieWalker { - /// Constructs a new TrieWalker, setting up the initial state of the stack and cursor. - pub fn new(cursor: C, changes: PrefixSet) -> Self { - // Initialize the walker with a single empty stack element. - let mut this = Self { - cursor, - changes, - stack: vec![CursorSubNode::default()], - can_skip_current_node: false, - trie_updates: None, - }; - - // Set up the root node of the trie in the stack, if it exists. - if let Some((key, value)) = this.node(true).unwrap() { - this.stack[0] = CursorSubNode::new(key, Some(value)); - } - - // Update the skip state for the root node. - this.update_skip_node(); - this - } - +impl TrieWalker { /// Constructs a new TrieWalker from existing stack and a cursor. pub fn from_stack(cursor: C, stack: Vec, changes: PrefixSet) -> Self { let mut this = @@ -91,6 +70,68 @@ impl TrieWalker { self.trie_updates.as_ref().map(|u| u.len()).unwrap_or(0) } + /// Returns the current key in the trie. + pub fn key(&self) -> Option<&Nibbles> { + self.stack.last().map(|n| n.full_key()) + } + + /// Returns the current hash in the trie if any. + pub fn hash(&self) -> Option { + self.stack.last().and_then(|n| n.hash()) + } + + /// Indicates whether the children of the current node are present in the trie. + pub fn children_are_in_trie(&self) -> bool { + self.stack.last().map_or(false, |n| n.tree_flag()) + } + + /// Returns the next unprocessed key in the trie. + pub fn next_unprocessed_key(&self) -> Option { + self.key() + .and_then(|key| { + if self.can_skip_current_node { + key.increment().map(|inc| inc.pack()) + } else { + Some(key.pack()) + } + }) + .map(|mut key| { + key.resize(32, 0); + B256::from_slice(key.as_slice()) + }) + } + + /// Updates the skip node flag based on the walker's current state. + fn update_skip_node(&mut self) { + self.can_skip_current_node = self + .stack + .last() + .map_or(false, |node| !self.changes.contains(node.full_key()) && node.hash_flag()); + } +} + +impl TrieWalker { + /// Constructs a new TrieWalker, setting up the initial state of the stack and cursor. + pub fn new(cursor: C, changes: PrefixSet) -> Self { + // Initialize the walker with a single empty stack element. + let mut this = Self { + cursor, + changes, + stack: vec![CursorSubNode::default()], + can_skip_current_node: false, + trie_updates: None, + }; + + // Set up the root node of the trie in the stack, if it exists. + if let Some((key, value)) = this.node(true).unwrap() { + this.stack[0] = CursorSubNode::new(key, Some(value)); + } + + // Update the skip state for the root node. 
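+        // A node can be skipped when its full key is not in the prefix set of
+        // changes and it already has a stored hash (see `update_skip_node`).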
+ this.update_skip_node(); + this + } + /// Advances the walker to the next trie node and updates the skip node flag. /// /// # Returns @@ -200,45 +241,6 @@ impl TrieWalker { Ok(()) } - - /// Returns the current key in the trie. - pub fn key(&self) -> Option<&Nibbles> { - self.stack.last().map(|n| n.full_key()) - } - - /// Returns the current hash in the trie if any. - pub fn hash(&self) -> Option { - self.stack.last().and_then(|n| n.hash()) - } - - /// Indicates whether the children of the current node are present in the trie. - pub fn children_are_in_trie(&self) -> bool { - self.stack.last().map_or(false, |n| n.tree_flag()) - } - - /// Returns the next unprocessed key in the trie. - pub fn next_unprocessed_key(&self) -> Option { - self.key() - .and_then(|key| { - if self.can_skip_current_node { - key.increment().map(|inc| inc.pack()) - } else { - Some(key.pack()) - } - }) - .map(|mut key| { - key.resize(32, 0); - B256::from_slice(key.as_slice()) - }) - } - - /// Updates the skip node flag based on the walker's current state. - fn update_skip_node(&mut self) { - self.can_skip_current_node = self - .stack - .last() - .map_or(false, |node| !self.changes.contains(node.full_key()) && node.hash_flag()); - } } #[cfg(test)]