Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(trimming): implement pre-trimming #687

Merged
merged 15 commits into from
Nov 13, 2023
1,102 changes: 840 additions & 262 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ pin-project-lite = "0.2"
subxt = "0.29"
scale-value = "0.10.0"

# substrate
frame-election-provider-support = { git = "https://github.com/paritytech/substrate" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/substrate" }
sp-npos-elections = { git = "https://github.com/paritytech/substrate" }
frame-support = { git = "https://github.com/paritytech/substrate" }
sp-runtime = { git = "https://github.com/paritytech/substrate" }
# polkadot-sdk
frame-election-provider-support = { git = "https://github.com/paritytech/polkadot-sdk" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-npos-elections = { git = "https://github.com/paritytech/polkadot-sdk/" }
frame-support = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-runtime = { git = "https://github.com/paritytech/polkadot-sdk" }

# prometheus
prometheus = "0.13"
Expand All @@ -37,7 +37,7 @@ once_cell = "1.18"
[dev-dependencies]
anyhow = "1"
assert_cmd = "2.0"
sp-storage = { git = "https://github.com/paritytech/substrate" }
sp-storage = { git = "https://github.com/paritytech/polkadot-sdk" }
regex = "1"

[features]
Expand Down
5 changes: 2 additions & 3 deletions src/commands/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ where

(solution, score)
},
(Err(e), _) => return Err(e),
(Err(e), _) => return Err(Error::Other(e.to_string())),
};

let best_head = get_latest_head(&api, config.listen).await?;
Expand Down Expand Up @@ -421,8 +421,7 @@ where
(Err(e), _) => {
log::warn!(
target: LOG_TARGET,
"submit_and_watch_solution failed: {:?}; skipping block: {}",
e,
"submit_and_watch_solution failed: {e}; skipping block: {}",
at.number
);
},
Expand Down
259 changes: 222 additions & 37 deletions src/epm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,34 @@ use crate::{
helpers::{storage_at, RuntimeDispatchInfo},
opt::{BalanceIterations, Balancing, Solver},
prelude::*,
static_types,
prometheus,
static_types::{self},
};

use std::{
collections::{BTreeMap, BTreeSet},
marker::PhantomData,
};

use codec::{Decode, Encode};
use frame_election_provider_support::{NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::weights::Weight;
use pallet_election_provider_multi_phase::{RawSolution, ReadySolution, SolutionOrSnapshotSize};
use frame_election_provider_support::{Get, NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::{weights::Weight, BoundedVec};
use pallet_election_provider_multi_phase::{
unsigned::TrimmingStatus, RawSolution, ReadySolution, SolutionOf, SolutionOrSnapshotSize,
};
use scale_info::{PortableRegistry, TypeInfo};
use scale_value::scale::{decode_as_type, TypeId};
use sp_core::Bytes;
use sp_npos_elections::ElectionScore;
use sp_npos_elections::{ElectionScore, VoteWeight};
use subxt::{dynamic::Value, rpc::rpc_params, tx::DynamicPayload};

const EPM_PALLET_NAME: &str = "ElectionProviderMultiPhase";

type MinerVoterOf =
frame_election_provider_support::Voter<AccountId, crate::static_types::MaxVotesPerVoter>;

type RoundSnapshot = pallet_election_provider_multi_phase::RoundSnapshot<AccountId, MinerVoterOf>;
type Voters =
Vec<(AccountId, VoteWeight, BoundedVec<AccountId, crate::static_types::MaxVotesPerVoter>)>;

#[derive(Copy, Clone, Debug)]
struct EpmConstant {
Expand All @@ -62,6 +72,113 @@ impl std::fmt::Display for EpmConstant {
}
}

#[derive(Debug)]
pub struct State {
voters: Voters,
voters_by_stake: BTreeMap<VoteWeight, usize>,
}

impl State {
fn len(&self) -> usize {
self.voters_by_stake.len()
}

fn to_voters(&self) -> Voters {
self.voters.clone()
}
}

/// Represent voters that may be trimmed
///
/// The trimming works by removing the voter with the least amount of stake.
///
/// It's using an internal `BTreeMap` to determine which voter to remove next
/// and the voters Vec can't be sorted because the EPM pallet will index into it
/// when checking the solution.
#[derive(Debug)]
pub struct TrimmedVoters<T> {
state: State,
_marker: PhantomData<T>,
}

impl<T> TrimmedVoters<T>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
/// Create a new `TrimmedVotes`.
pub async fn new(mut voters: Voters, desired_targets: u32) -> Result<Self, Error> {
let mut voters_by_stake = BTreeMap::new();
let mut targets = BTreeSet::new();

for (idx, (_voter, stake, supports)) in voters.iter().enumerate() {
voters_by_stake.insert(*stake, idx);
targets.extend(supports.iter().cloned());
}

loop {
let targets_len = targets.len() as u32;
let active_voters = voters_by_stake.len() as u32;

let est_weight: Weight = tokio::task::spawn_blocking(move || {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided to add this in the constructor/new because it should only be needed to do once

T::solution_weight(active_voters, targets_len, active_voters, desired_targets)
})
.await?;

let max_weight: Weight = T::MaxWeight::get();
log::trace!(target: "staking-miner", "trimming weight: est_weight={est_weight} / max_weight={max_weight}");

if est_weight.all_lt(max_weight) {
return Ok(Self { state: State { voters, voters_by_stake }, _marker: PhantomData })
}

let Some((_, idx)) = voters_by_stake.pop_first() else { break };

let rm = voters[idx].0.clone();

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
}

targets.remove(&rm);
}

return Err(Error::Feasibility("Failed to pre-trim weight < T::MaxLength".to_string()))
}

/// Clone the state and trim it, so it get can be reverted.
pub fn trim(&mut self, n: usize) -> Result<State, Error> {
let mut voters = self.state.voters.clone();
let mut voters_by_stake = self.state.voters_by_stake.clone();

for _ in 0..n {
let Some((_, idx)) = voters_by_stake.pop_first() else {
return Err(Error::Feasibility("Failed to pre-trim len".to_string()))
};
let rm = voters[idx].0.clone();

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
}
}

Ok(State { voters, voters_by_stake })
}

pub fn to_voters(&self) -> Voters {
self.state.voters.clone()
}

pub fn len(&self) -> usize {
self.state.len()
}
}

/// Read the constants from the metadata and updates the static types.
pub(crate) async fn update_metadata_constants(api: &SubxtClient) -> Result<(), Error> {
const SIGNED_MAX_WEIGHT: EpmConstant = EpmConstant::new("SignedMaxWeight");
Expand Down Expand Up @@ -186,8 +303,44 @@ pub async fn snapshot_at(
}
}

/// Helper to fetch snapshot data via RPC
pub async fn mine_solution<T>(
solver: Solver,
targets: Vec<AccountId>,
voters: Voters,
desired_targets: u32,
) -> Result<(SolutionOf<T>, ElectionScore, SolutionOrSnapshotSize, TrimmingStatus), Error>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
match tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await
{
Ok(Ok(s)) => Ok(s),
Err(e) => Err(e.into()),
Ok(Err(e)) => Err(Error::Other(format!("{:?}", e))),
}
}

/// Helper to fetch snapshot data via RPC
/// and compute an NPos solution via [`pallet_election_provider_multi_phase`].
pub async fn fetch_snapshot_and_mine_solution<T>(
api: &SubxtClient,
Expand Down Expand Up @@ -219,47 +372,69 @@ where
.await?
.map(|score| score.0);

let voters = snapshot.voters.clone();
let targets = snapshot.targets.clone();
let mut voters = TrimmedVoters::<T>::new(snapshot.voters.clone(), desired_targets).await?;

log::trace!(
target: LOG_TARGET,
"mine solution: desired_targets={}, voters={}, targets={}",
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
voters.to_voters(),
desired_targets,
voters.len(),
targets.len()
);
)
.await?;

let blocking_task = tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await;
if !trim_status.is_trimmed() {
return Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
})
}

prometheus::on_trim_attempt();

let mut l = 1;
let mut h = voters.len();
let mut best_solution = None;

while l <= h {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary search to find the best "pre-trim"

let mid = ((h - l) / 2) + l;

let next_state = voters.trim(mid)?;

match blocking_task {
Ok(Ok((solution, score, solution_or_snapshot_size))) => Ok(MinedSolution {
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
next_state.to_voters(),
desired_targets,
)
.await?;

if !trim_status.is_trimmed() {
best_solution = Some((solution, score, solution_or_snapshot_size));
h = mid - 1;
} else {
l = mid + 1;
}
}

if let Some((solution, score, solution_or_snapshot_size)) = best_solution {
prometheus::on_trim_success();

Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
}),
Ok(Err(err)) => Err(Error::Other(format!("{:?}", err))),
Err(err) => Err(err.into()),
})
} else {
Err(Error::Feasibility("Failed pre-trim length".to_string()))
}
}

Expand Down Expand Up @@ -315,6 +490,16 @@ where
}
}

/// Compact debug output for a mined solution.
///
/// Only `round`, `desired_targets` and `score` are printed.
impl<T: MinerConfig> std::fmt::Debug for MinedSolution<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { round, desired_targets, score, .. } = self;

        let mut out = f.debug_struct("MinedSolution");
        out.field("round", round);
        out.field("desired_targets", desired_targets);
        out.field("score", score);
        out.finish()
    }
}

fn make_type<T: scale_info::TypeInfo + 'static>() -> (TypeId, PortableRegistry) {
let m = scale_info::MetaType::new::<T>();
let mut types = scale_info::Registry::new();
Expand Down
Loading