diff --git a/stake-pool/program/src/state.rs b/stake-pool/program/src/state.rs index 31c2ef94665..b453b1b5a30 100644 --- a/stake-pool/program/src/state.rs +++ b/stake-pool/program/src/state.rs @@ -942,9 +942,13 @@ impl Fee { if self.denominator == 0 { return Some(0); } - (amt as u128) - .checked_mul(self.numerator as u128)? - .checked_div(self.denominator as u128) + let numerator = (amt as u128).checked_mul(self.numerator as u128)?; + // ceiling the calculation by adding (denominator - 1) to the numerator + let denominator = self.denominator as u128; + numerator + .checked_add(denominator)? + .checked_sub(1)? + .checked_div(denominator) } /// Withdrawal fees have some additional restrictions, diff --git a/stake-pool/program/tests/helpers/mod.rs b/stake-pool/program/tests/helpers/mod.rs index d82dbe5ad7d..23f146578ea 100644 --- a/stake-pool/program/tests/helpers/mod.rs +++ b/stake-pool/program/tests/helpers/mod.rs @@ -897,15 +897,17 @@ impl StakePoolAccounts { } pub fn calculate_fee(&self, amount: u64) -> u64 { - amount * self.epoch_fee.numerator / self.epoch_fee.denominator + (amount * self.epoch_fee.numerator + self.epoch_fee.denominator - 1) + / self.epoch_fee.denominator } pub fn calculate_withdrawal_fee(&self, pool_tokens: u64) -> u64 { - pool_tokens * self.withdrawal_fee.numerator / self.withdrawal_fee.denominator + (pool_tokens * self.withdrawal_fee.numerator + self.withdrawal_fee.denominator - 1) + / self.withdrawal_fee.denominator } pub fn calculate_inverse_withdrawal_fee(&self, pool_tokens: u64) -> u64 { - pool_tokens * self.withdrawal_fee.denominator + (pool_tokens * self.withdrawal_fee.denominator + self.withdrawal_fee.denominator - 1) / (self.withdrawal_fee.denominator - self.withdrawal_fee.numerator) } @@ -914,7 +916,8 @@ impl StakePoolAccounts { } pub fn calculate_sol_deposit_fee(&self, pool_tokens: u64) -> u64 { - pool_tokens * self.sol_deposit_fee.numerator / self.sol_deposit_fee.denominator + (pool_tokens * self.sol_deposit_fee.numerator + self.sol_deposit_fee.denominator - 1) + / self.sol_deposit_fee.denominator } pub fn calculate_sol_referral_fee(&self, deposit_fee_collected: u64) -> u64 { diff --git a/stake-pool/program/tests/withdraw_sol.rs b/stake-pool/program/tests/withdraw_sol.rs index 67b60a7a221..e8eb2423a68 100644 --- a/stake-pool/program/tests/withdraw_sol.rs +++ b/stake-pool/program/tests/withdraw_sol.rs @@ -214,7 +214,7 @@ async fn fail_overdraw_reserve() { .await; assert!(error.is_none(), "{:?}", error); - // try to withdraw one lamport, will overdraw + // try to withdraw one lamport after fees, will overdraw let error = stake_pool_accounts .withdraw_sol( &mut context.banks_client, @@ -222,7 +222,7 @@ async fn fail_overdraw_reserve() { &context.last_blockhash, &user, &pool_token_account, - 1, + 2, None, ) .await diff --git a/stake-pool/program/tests/withdraw_with_fee.rs b/stake-pool/program/tests/withdraw_with_fee.rs index b55283fbd1a..7a8e000c1cb 100644 --- a/stake-pool/program/tests/withdraw_with_fee.rs +++ b/stake-pool/program/tests/withdraw_with_fee.rs @@ -6,10 +6,10 @@ mod helpers; use { bincode::deserialize, helpers::*, - solana_program::{borsh0_10::try_from_slice_unchecked, pubkey::Pubkey, stake}, + solana_program::{pubkey::Pubkey, stake}, solana_program_test::*, solana_sdk::signature::{Keypair, Signer}, - spl_stake_pool::{minimum_stake_lamports, state}, + spl_stake_pool::minimum_stake_lamports, }; #[tokio::test] @@ -183,20 +183,8 @@ async fn success_empty_out_stake_with_fee() { .await; let lamports_to_withdraw = validator_stake_account.lamports - minimum_stake_lamports(&meta, stake_minimum_delegation); - let stake_pool_account = get_account( - &mut context.banks_client, - &stake_pool_accounts.stake_pool.pubkey(), - ) - .await; - let stake_pool = - try_from_slice_unchecked::(stake_pool_account.data.as_slice()).unwrap(); - let fee = stake_pool.stake_withdrawal_fee; - let inverse_fee = state::Fee { - numerator: fee.denominator - fee.numerator, - denominator: fee.denominator, - }; let pool_tokens_to_withdraw = - lamports_to_withdraw * inverse_fee.denominator / inverse_fee.numerator; + stake_pool_accounts.calculate_inverse_withdrawal_fee(lamports_to_withdraw); let last_blockhash = context .banks_client