Skip to content

Commit

Permalink
refactor(trimming): implement pre-trimming (#687)
Browse files Browse the repository at this point in the history
* refactor(trimming): implement pre-trimming

This commit refactors trimming such that the staking-miner is
responsible to ensure that trimming doesn't have to be done be on-chain.

Instead the staking-miner iteratively removes the votes by one voter at
the time and re-computes the solution. Once a solution with no trimming
is found that solution is submitted to the chain

* cargo fmt

* cargo fmt

* update to polkadot-sdk

* cleanup the pre-trimming code

* remove checking solution score

* update polkadot-sdk

* cleaup

* Update src/epm.rs

Co-authored-by: Gonçalo Pestana <g6pestana@gmail.com>

* add promethous counter for trimming

* cargo fmt

* fix nits

* pre-trimming: impl bsearch to find pre-trim

* fix nit in weight trimmming

---------

Co-authored-by: Gonçalo Pestana <g6pestana@gmail.com>
  • Loading branch information
niklasad1 and gpestana committed Nov 13, 2023
1 parent 511cbb0 commit 2ea305e
Show file tree
Hide file tree
Showing 8 changed files with 1,129 additions and 393 deletions.
1,102 changes: 840 additions & 262 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ pin-project-lite = "0.2"
subxt = "0.29"
scale-value = "0.10.0"

# substrate
frame-election-provider-support = { git = "https://github.com/paritytech/substrate" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/substrate" }
sp-npos-elections = { git = "https://github.com/paritytech/substrate" }
frame-support = { git = "https://github.com/paritytech/substrate" }
sp-runtime = { git = "https://github.com/paritytech/substrate" }
# polkadot-sdk
frame-election-provider-support = { git = "https://github.com/paritytech/polkadot-sdk" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-npos-elections = { git = "https://github.com/paritytech/polkadot-sdk/" }
frame-support = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-runtime = { git = "https://github.com/paritytech/polkadot-sdk" }

# prometheus
prometheus = "0.13"
Expand All @@ -37,7 +37,7 @@ once_cell = "1.18"
[dev-dependencies]
anyhow = "1"
assert_cmd = "2.0"
sp-storage = { git = "https://github.com/paritytech/substrate" }
sp-storage = { git = "https://github.com/paritytech/polkadot-sdk" }
regex = "1"

[features]
Expand Down
5 changes: 2 additions & 3 deletions src/commands/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ where

(solution, score)
},
(Err(e), _) => return Err(e),
(Err(e), _) => return Err(Error::Other(e.to_string())),
};

let best_head = get_latest_head(&api, config.listen).await?;
Expand Down Expand Up @@ -421,8 +421,7 @@ where
(Err(e), _) => {
log::warn!(
target: LOG_TARGET,
"submit_and_watch_solution failed: {:?}; skipping block: {}",
e,
"submit_and_watch_solution failed: {e}; skipping block: {}",
at.number
);
},
Expand Down
259 changes: 222 additions & 37 deletions src/epm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,34 @@ use crate::{
helpers::{storage_at, RuntimeDispatchInfo},
opt::{BalanceIterations, Balancing, Solver},
prelude::*,
static_types,
prometheus,
static_types::{self},
};

use std::{
collections::{BTreeMap, BTreeSet},
marker::PhantomData,
};

use codec::{Decode, Encode};
use frame_election_provider_support::{NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::weights::Weight;
use pallet_election_provider_multi_phase::{RawSolution, ReadySolution, SolutionOrSnapshotSize};
use frame_election_provider_support::{Get, NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::{weights::Weight, BoundedVec};
use pallet_election_provider_multi_phase::{
unsigned::TrimmingStatus, RawSolution, ReadySolution, SolutionOf, SolutionOrSnapshotSize,
};
use scale_info::{PortableRegistry, TypeInfo};
use scale_value::scale::{decode_as_type, TypeId};
use sp_core::Bytes;
use sp_npos_elections::ElectionScore;
use sp_npos_elections::{ElectionScore, VoteWeight};
use subxt::{dynamic::Value, rpc::rpc_params, tx::DynamicPayload};

const EPM_PALLET_NAME: &str = "ElectionProviderMultiPhase";

type MinerVoterOf =
frame_election_provider_support::Voter<AccountId, crate::static_types::MaxVotesPerVoter>;

type RoundSnapshot = pallet_election_provider_multi_phase::RoundSnapshot<AccountId, MinerVoterOf>;
type Voters =
Vec<(AccountId, VoteWeight, BoundedVec<AccountId, crate::static_types::MaxVotesPerVoter>)>;

#[derive(Copy, Clone, Debug)]
struct EpmConstant {
Expand All @@ -62,6 +72,113 @@ impl std::fmt::Display for EpmConstant {
}
}

#[derive(Debug)]
pub struct State {
voters: Voters,
voters_by_stake: BTreeMap<VoteWeight, usize>,
}

impl State {
fn len(&self) -> usize {
self.voters_by_stake.len()
}

fn to_voters(&self) -> Voters {
self.voters.clone()
}
}

/// Represent voters that may be trimmed
///
/// The trimming works by removing the voter with the least amount of stake.
///
/// It's using an internal `BTreeMap` to determine which voter to remove next
/// and the voters Vec can't be sorted because the EPM pallet will index into it
/// when checking the solution.
#[derive(Debug)]
pub struct TrimmedVoters<T> {
state: State,
_marker: PhantomData<T>,
}

impl<T> TrimmedVoters<T>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
/// Create a new `TrimmedVotes`.
pub async fn new(mut voters: Voters, desired_targets: u32) -> Result<Self, Error> {
let mut voters_by_stake = BTreeMap::new();
let mut targets = BTreeSet::new();

for (idx, (_voter, stake, supports)) in voters.iter().enumerate() {
voters_by_stake.insert(*stake, idx);
targets.extend(supports.iter().cloned());
}

loop {
let targets_len = targets.len() as u32;
let active_voters = voters_by_stake.len() as u32;

let est_weight: Weight = tokio::task::spawn_blocking(move || {
T::solution_weight(active_voters, targets_len, active_voters, desired_targets)
})
.await?;

let max_weight: Weight = T::MaxWeight::get();
log::trace!(target: "staking-miner", "trimming weight: est_weight={est_weight} / max_weight={max_weight}");

if est_weight.all_lt(max_weight) {
return Ok(Self { state: State { voters, voters_by_stake }, _marker: PhantomData })
}

let Some((_, idx)) = voters_by_stake.pop_first() else { break };

let rm = voters[idx].0.clone();

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
}

targets.remove(&rm);
}

return Err(Error::Feasibility("Failed to pre-trim weight < T::MaxLength".to_string()))
}

/// Clone the state and trim it, so it get can be reverted.
pub fn trim(&mut self, n: usize) -> Result<State, Error> {
let mut voters = self.state.voters.clone();
let mut voters_by_stake = self.state.voters_by_stake.clone();

for _ in 0..n {
let Some((_, idx)) = voters_by_stake.pop_first() else {
return Err(Error::Feasibility("Failed to pre-trim len".to_string()))
};
let rm = voters[idx].0.clone();

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
}
}

Ok(State { voters, voters_by_stake })
}

pub fn to_voters(&self) -> Voters {
self.state.voters.clone()
}

pub fn len(&self) -> usize {
self.state.len()
}
}

/// Read the constants from the metadata and updates the static types.
pub(crate) async fn update_metadata_constants(api: &SubxtClient) -> Result<(), Error> {
const SIGNED_MAX_WEIGHT: EpmConstant = EpmConstant::new("SignedMaxWeight");
Expand Down Expand Up @@ -186,8 +303,44 @@ pub async fn snapshot_at(
}
}

/// Helper to fetch snapshot data via RPC
pub async fn mine_solution<T>(
solver: Solver,
targets: Vec<AccountId>,
voters: Voters,
desired_targets: u32,
) -> Result<(SolutionOf<T>, ElectionScore, SolutionOrSnapshotSize, TrimmingStatus), Error>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
match tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await
{
Ok(Ok(s)) => Ok(s),
Err(e) => Err(e.into()),
Ok(Err(e)) => Err(Error::Other(format!("{:?}", e))),
}
}

/// Helper to fetch snapshot data via RPC
/// and compute an NPos solution via [`pallet_election_provider_multi_phase`].
pub async fn fetch_snapshot_and_mine_solution<T>(
api: &SubxtClient,
Expand Down Expand Up @@ -219,47 +372,69 @@ where
.await?
.map(|score| score.0);

let voters = snapshot.voters.clone();
let targets = snapshot.targets.clone();
let mut voters = TrimmedVoters::<T>::new(snapshot.voters.clone(), desired_targets).await?;

log::trace!(
target: LOG_TARGET,
"mine solution: desired_targets={}, voters={}, targets={}",
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
voters.to_voters(),
desired_targets,
voters.len(),
targets.len()
);
)
.await?;

let blocking_task = tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await;
if !trim_status.is_trimmed() {
return Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
})
}

prometheus::on_trim_attempt();

let mut l = 1;
let mut h = voters.len();
let mut best_solution = None;

while l <= h {
let mid = ((h - l) / 2) + l;

let next_state = voters.trim(mid)?;

match blocking_task {
Ok(Ok((solution, score, solution_or_snapshot_size))) => Ok(MinedSolution {
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
next_state.to_voters(),
desired_targets,
)
.await?;

if !trim_status.is_trimmed() {
best_solution = Some((solution, score, solution_or_snapshot_size));
h = mid - 1;
} else {
l = mid + 1;
}
}

if let Some((solution, score, solution_or_snapshot_size)) = best_solution {
prometheus::on_trim_success();

Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
}),
Ok(Err(err)) => Err(Error::Other(format!("{:?}", err))),
Err(err) => Err(err.into()),
})
} else {
Err(Error::Feasibility("Failed pre-trim length".to_string()))
}
}

Expand Down Expand Up @@ -315,6 +490,16 @@ where
}
}

impl<T: MinerConfig> std::fmt::Debug for MinedSolution<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MinedSolution")
.field("round", &self.round)
.field("desired_targets", &self.desired_targets)
.field("score", &self.score)
.finish()
}
}

fn make_type<T: scale_info::TypeInfo + 'static>() -> (TypeId, PortableRegistry) {
let m = scale_info::MetaType::new::<T>();
let mut types = scale_info::Registry::new();
Expand Down
Loading

0 comments on commit 2ea305e

Please sign in to comment.