Skip to content

Commit

Permalink
stake-pool: Truncate on withdrawal calculation (#3804)
Browse files Browse the repository at this point in the history
  • Loading branch information
joncinque committed Nov 19, 2022
1 parent c5fcbef commit 84182ce
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions stake-pool/program/src/state.rs
Expand Up @@ -18,7 +18,6 @@ use {
pubkey::{Pubkey, PUBKEY_BYTES},
stake::state::Lockup,
},
spl_math::checked_ceil_div::CheckedCeilDiv,
spl_token::state::{Account, AccountState},
std::{borrow::Borrow, convert::TryFrom, fmt, matches},
};
Expand Down Expand Up @@ -172,16 +171,15 @@ impl StakePool {
/// calculate lamports amount on withdrawal
#[inline]
pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
// `checked_ceil_div` returns `None` for a 0 quotient result, but in this
// `checked_div` returns `None` for a 0 quotient result, but in this
// case, a return of 0 is valid for small amounts of pool tokens. So
// we check for that separately
let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
let denominator = self.pool_token_supply as u128;
if numerator < denominator || denominator == 0 {
Some(0)
} else {
let (quotient, _) = numerator.checked_ceil_div(denominator)?;
u64::try_from(quotient).ok()
u64::try_from(numerator.checked_div(denominator)?).ok()
}
}

Expand Down Expand Up @@ -1033,7 +1031,7 @@ mod test {
let fee_lamports = stake_pool
.calc_lamports_withdraw_amount(pool_token_fee)
.unwrap();
assert_eq!(fee_lamports, LAMPORTS_PER_SOL);
assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); // off-by-one due to truncation
}

#[test]
Expand Down Expand Up @@ -1148,6 +1146,80 @@ mod test {
stake_pool.pool_token_supply += deposit_result;
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
assert!(withdraw_result <= deposit_stake);

// also test splitting the withdrawal in two operations
if deposit_result >= 2 {
let first_half_deposit = deposit_result / 2;
let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
stake_pool.total_lamports -= first_withdraw_result;
stake_pool.pool_token_supply -= first_half_deposit;
let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
}
}
}

#[test]
fn specific_split_withdrawal() {
let total_lamports = 1_100_000_000_000;
let pool_token_supply = 1_000_000_000_000;
let deposit_stake = 3;
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
let deposit_result = stake_pool
.calc_pool_tokens_for_deposit(deposit_stake)
.unwrap();
assert!(deposit_result > 0);
stake_pool.total_lamports += deposit_stake;
stake_pool.pool_token_supply += deposit_result;
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(deposit_result / 2)
.unwrap();
assert!(withdraw_result * 2 <= deposit_stake);
}

#[test]
fn withdraw_all() {
let total_lamports = 1_100_000_000_000;
let pool_token_supply = 1_000_000_000_000;
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
// take everything out at once
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(pool_token_supply)
.unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);

// take out 1, then the rest
let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
stake_pool.total_lamports -= withdraw_result;
stake_pool.pool_token_supply -= 1;
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
.unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);

// take out all except 1, then the rest
let mut stake_pool = StakePool {
total_lamports,
pool_token_supply,
..StakePool::default()
};
let withdraw_result = stake_pool
.calc_lamports_withdraw_amount(pool_token_supply - 1)
.unwrap();
stake_pool.total_lamports -= withdraw_result;
stake_pool.pool_token_supply = 1;
assert_ne!(stake_pool.total_lamports, 0);

let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
assert_eq!(stake_pool.total_lamports, withdraw_result);
}
}

0 comments on commit 84182ce

Please sign in to comment.