Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stake-pool: Truncate on withdrawal calculation #3804

Merged
merged 1 commit into from Nov 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
82 changes: 77 additions & 5 deletions stake-pool/program/src/state.rs
Expand Up @@ -18,7 +18,6 @@ use {
pubkey::{Pubkey, PUBKEY_BYTES},
stake::state::Lockup,
},
spl_math::checked_ceil_div::CheckedCeilDiv,
spl_token::state::{Account, AccountState},
std::{borrow::Borrow, convert::TryFrom, fmt, matches},
};
Expand Down Expand Up @@ -172,16 +171,15 @@ impl StakePool {
/// calculate lamports amount on withdrawal
#[inline]
pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
// `checked_ceil_div` returns `None` for a 0 quotient result, but in this
// `checked_div` returns `None` for a 0 quotient result, but in this
// case, a return of 0 is valid for small amounts of pool tokens. So
// we check for that separately
let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
let denominator = self.pool_token_supply as u128;
if numerator < denominator || denominator == 0 {
Some(0)
} else {
let (quotient, _) = numerator.checked_ceil_div(denominator)?;
u64::try_from(quotient).ok()
u64::try_from(numerator.checked_div(denominator)?).ok()
}
}

Expand Down Expand Up @@ -1033,7 +1031,7 @@ mod test {
let fee_lamports = stake_pool
.calc_lamports_withdraw_amount(pool_token_fee)
.unwrap();
assert_eq!(fee_lamports, LAMPORTS_PER_SOL);
assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); // off-by-one due to truncation
}

#[test]
Expand Down Expand Up @@ -1148,6 +1146,80 @@ mod test {
stake_pool.pool_token_supply += deposit_result;
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
assert!(withdraw_result <= deposit_stake);

// also test splitting the withdrawal in two operations
if deposit_result >= 2 {
let first_half_deposit = deposit_result / 2;
let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
stake_pool.total_lamports -= first_withdraw_result;
stake_pool.pool_token_supply -= first_half_deposit;
let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
}
}
}

#[test]
fn specific_split_withdrawal() {
let total_lamports = 1_100_000_000_000;
let pool_token_supply = 1_000_000_000_000;
let deposit_stake = 3;
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
let deposit_result = stake_pool
.calc_pool_tokens_for_deposit(deposit_stake)
.unwrap();
assert!(deposit_result > 0);
stake_pool.total_lamports += deposit_stake;
stake_pool.pool_token_supply += deposit_result;
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(deposit_result / 2)
.unwrap();
assert!(withdraw_result * 2 <= deposit_stake);
}

#[test]
fn withdraw_all() {
let total_lamports = 1_100_000_000_000;
let pool_token_supply = 1_000_000_000_000;
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
// take everything out at once
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(pool_token_supply)
.unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);

// take out 1, then the rest
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
stake_pool.total_lamports -= withdraw_result;
stake_pool.pool_token_supply -= 1;
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
.unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);

// take out all except 1, then the rest
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(pool_token_supply - 1)
.unwrap();
stake_pool.total_lamports -= withdraw_result;
stake_pool.pool_token_supply = 1;
assert_ne!(stake_pool.total_lamports, 0);

let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);
}
}