Skip to content

Commit

Permalink
Add overflow checks to STP
Browse files Browse the repository at this point in the history
- Added overflow checks to sender transaction protocol. It is possible to crash the system in the
change calculation before the transaction is validated; this PR prevents it.
- Added an additinal unit test to verify three overflow errors are handled.
  • Loading branch information
hansieodendaal committed Oct 12, 2023
1 parent 1d1332d commit 0d77212
Show file tree
Hide file tree
Showing 8 changed files with 399 additions and 36 deletions.
3 changes: 2 additions & 1 deletion base_layer/core/src/blocks/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ impl BlockBuilder {
/// This function adds the provided transaction kernels to the block WITHOUT updating kernel_mmr_size in the header
pub fn add_kernels(mut self, mut kernels: Vec<TransactionKernel>) -> Self {
for kernel in &kernels {
self.total_fee += kernel.fee;
// Saturating add is used here to prevent overflow; invalid fees will be caught by block validation
self.total_fee = self.total_fee.saturating_add(kernel.fee);
}
self.kernels.append(&mut kernels);
self
Expand Down
4 changes: 3 additions & 1 deletion base_layer/core/src/transactions/fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ impl Fee {
num_outputs,
rounded_features_and_scripts_byte_size,
);
MicroMinotari::from(weight) * fee_per_gram
// Saturating multiplication is used here to prevent overflow only; invalid values will be caught with
// validation
MicroMinotari::from(weight.saturating_mul(fee_per_gram.0))
}

pub fn calculate_body(&self, fee_per_gram: MicroMinotari, body: &AggregateBody) -> std::io::Result<MicroMinotari> {
Expand Down
40 changes: 28 additions & 12 deletions base_layer/core/src/transactions/tari_amount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,42 @@ impl MicroMinotari {
Self(0)
}

pub fn checked_add(self, v: MicroMinotari) -> Option<MicroMinotari> {
self.as_u64().checked_add(v.as_u64()).map(Into::into)
pub fn checked_add<T>(&self, v: T) -> Option<MicroMinotari>
where T: AsRef<MicroMinotari> {
self.as_u64().checked_add(v.as_ref().as_u64()).map(Into::into)
}

pub fn checked_sub(self, v: MicroMinotari) -> Option<MicroMinotari> {
if self >= v {
return Some(self - v);
pub fn checked_sub<T>(&self, v: T) -> Option<MicroMinotari>
where T: AsRef<MicroMinotari> {
if self >= v.as_ref() {
return Some(self - v.as_ref());
}
None
}

pub fn checked_mul(self, v: MicroMinotari) -> Option<MicroMinotari> {
self.as_u64().checked_mul(v.as_u64()).map(Into::into)
pub fn checked_mul<T>(&self, v: T) -> Option<MicroMinotari>
where T: AsRef<MicroMinotari> {
self.as_u64().checked_mul(v.as_ref().as_u64()).map(Into::into)
}

pub fn checked_div(self, v: MicroMinotari) -> Option<MicroMinotari> {
self.as_u64().checked_div(v.as_u64()).map(Into::into)
pub fn checked_div<T>(&self, v: T) -> Option<MicroMinotari>
where T: AsRef<MicroMinotari> {
self.as_u64().checked_div(v.as_ref().as_u64()).map(Into::into)
}

pub fn saturating_sub(self, v: MicroMinotari) -> MicroMinotari {
if self >= v {
return self - v;
pub fn saturating_sub<T>(&self, v: T) -> MicroMinotari
where T: AsRef<MicroMinotari> {
if self >= v.as_ref() {
return self - v.as_ref();
}
Self(0)
}

pub fn saturating_add<T>(&self, v: T) -> MicroMinotari
where T: AsRef<MicroMinotari> {
self.0.saturating_add(v.as_ref().0).into()
}

#[inline]
pub fn as_u64(&self) -> u64 {
self.0
Expand All @@ -149,6 +159,12 @@ impl MicroMinotari {
}
}

impl AsRef<MicroMinotari> for MicroMinotari {
fn as_ref(&self) -> &MicroMinotari {
self
}
}

#[allow(clippy::identity_op)]
impl Display for MicroMinotari {
fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
Expand Down
25 changes: 18 additions & 7 deletions base_layer/core/src/transactions/test_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use crate::{
WalletOutput,
WalletOutputBuilder,
},
transaction_protocol::TransactionMetadata,
transaction_protocol::{transaction_initializer::SenderTransactionInitializer, TransactionMetadata},
weight::TransactionWeight,
SenderTransactionProtocol,
},
Expand Down Expand Up @@ -651,6 +651,22 @@ pub async fn create_stx_protocol(
schema: TransactionSchema,
key_manager: &TestKeyManager,
) -> (SenderTransactionProtocol, Vec<WalletOutput>) {
let mut outputs = Vec::with_capacity(schema.to.len());
let stx_builder = create_stx_protocol_internal(schema, key_manager, &mut outputs).await;

let stx_protocol = stx_builder.build().await.unwrap();
let change_output = stx_protocol.get_change_output().unwrap().unwrap();

outputs.push(change_output);
(stx_protocol, outputs)
}

#[allow(clippy::too_many_lines)]
pub async fn create_stx_protocol_internal(
schema: TransactionSchema,
key_manager: &TestKeyManager,
outputs: &mut Vec<WalletOutput>,
) -> SenderTransactionInitializer<TestKeyManager> {
let constants = ConsensusManager::builder(Network::LocalNet)
.build()
.unwrap()
Expand All @@ -676,7 +692,6 @@ pub async fn create_stx_protocol(
for tx_input in &schema.from {
stx_builder.with_input(tx_input.clone()).await.unwrap();
}
let mut outputs = Vec::with_capacity(schema.to.len());
for val in schema.to {
let (spending_key, _) = key_manager
.get_next_key(TransactionKeyManagerBranch::CommitmentMask.get_branch_key())
Expand Down Expand Up @@ -741,11 +756,7 @@ pub async fn create_stx_protocol(
stx_builder.with_output(utxo, sender_offset_key_id).await.unwrap();
}

let stx_protocol = stx_builder.build().await.unwrap();
let change_output = stx_protocol.get_change_output().unwrap().unwrap();

outputs.push(change_output);
(stx_protocol, outputs)
stx_builder
}

pub async fn create_coinbase_kernel(spending_key_id: &TariKeyId, key_manager: &TestKeyManager) -> TransactionKernel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,20 @@ where KM: TransactionKeyManagerInterface
// The number of outputs excluding a possible residual change output
let num_outputs = self.sender_custom_outputs.len() + usize::from(self.recipient.is_some());
let num_inputs = self.inputs.len();
let total_being_spent = self.inputs.iter().map(|i| i.output.value).sum::<MicroMinotari>();
let total_being_spent = self
.inputs
.iter()
.map(|i| i.output.value)
.fold(Ok(MicroMinotari::zero()), |acc, x| {
acc?.checked_add(x).ok_or("Total inputs being spent amount overflow")
})?;
let total_to_self = self
.sender_custom_outputs
.iter()
.map(|o| o.output.value)
.sum::<MicroMinotari>();
.fold(Ok(MicroMinotari::zero()), |acc, x| {
acc?.checked_add(x).ok_or("Total outputs to self amount overflow")
})?;
let total_amount = match &self.recipient {
Some(data) => data.amount,
None => 0.into(),
Expand Down Expand Up @@ -332,19 +340,23 @@ where KM: TransactionKeyManagerInterface
.weighting()
.round_up_features_and_scripts_size(change_features_and_scripts_size);

let change_fee = self
.fee()
.calculate(fee_per_gram, 0, 0, 1, change_features_and_scripts_size);
// Subtract with a check on going negative
let total_input_value = total_to_self + total_amount + fee_without_change;
let total_input_value = [total_to_self, total_amount, fee_without_change]
.iter()
.fold(Ok(MicroMinotari::zero()), |acc, x| {
acc?.checked_add(x).ok_or("Total input value overflow")
})?;
let change_amount = total_being_spent.checked_sub(total_input_value);
match change_amount {
None => Err(format!(
"You are spending ({}) more than you're providing ({}).",
total_input_value, total_being_spent
"You are spending more than you're providing: provided {}, required {}.",
total_being_spent, total_input_value
)),
Some(MicroMinotari(0)) => Ok((fee_without_change, MicroMinotari(0), None)),
Some(v) => {
let change_fee = self
.fee()
.calculate(fee_per_gram, 0, 0, 1, change_features_and_scripts_size);
let change_amount = v.checked_sub(change_fee);
match change_amount {
// You can't win. Just add the change to the fee (which is less than the cost of adding another
Expand Down Expand Up @@ -909,7 +921,7 @@ mod test {
let err = builder.build().await.unwrap_err();
assert_eq!(
err.message,
"You are spending (528 µT) more than you're providing (400 µT)."
"You are spending more than you're providing: provided 400 µT, required 528 µT."
);
}

Expand Down
146 changes: 144 additions & 2 deletions base_layer/core/tests/tests/block_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ use tari_common_types::types::FixedHash;
use tari_core::{
blocks::{Block, BlockHeaderAccumulatedData, BlockHeaderValidationError, BlockValidationError, ChainBlock},
chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, ChainStorageError, Validators},
consensus::{consensus_constants::PowAlgorithmConstants, ConsensusConstantsBuilder, ConsensusManager},
consensus::{
consensus_constants::PowAlgorithmConstants,
emission::Emission,
ConsensusConstantsBuilder,
ConsensusManager,
},
proof_of_work::{
monero_rx,
monero_rx::{verify_header, FixedByteArray, MoneroPowData},
Expand All @@ -42,7 +47,7 @@ use tari_core::{
transactions::{
aggregated_body::AggregateBody,
key_manager::TransactionKeyManagerInterface,
tari_amount::{uT, T},
tari_amount::{uT, MicroMinotari, T},
test_helpers::{
create_test_core_key_manager_with_memory_db,
create_wallet_output_with_data,
Expand Down Expand Up @@ -1166,3 +1171,140 @@ async fn add_block_with_large_many_output_block() {
// of the block
println!("finished validating in: {}", finished.as_millis());
}

use tari_core::{
blocks::{BlockHeader, NewBlockTemplate},
transactions::{
test_helpers::create_stx_protocol_internal,
transaction_components::{Transaction, TransactionKernel},
},
};

use crate::helpers::{block_builders::generate_new_block, sample_blockchains::create_new_blockchain};

#[tokio::test]
#[allow(clippy::too_many_lines)]
async fn test_fee_overflow() {
let network = Network::LocalNet;
let (mut store, mut blocks, mut outputs, consensus_manager, key_manager) = create_new_blockchain(network).await;
let schemas = vec![txn_schema!(
from: vec![outputs[0][0].clone()],
to: vec![10 * T, 10 * T, 10 * T, 10 * T]
)];
generate_new_block(
&mut store,
&mut blocks,
&mut outputs,
schemas,
&consensus_manager,
&key_manager,
)
.await
.unwrap();

let schemas = vec![
txn_schema!(
from: vec![outputs[1][0].clone()],
to: vec![1 * T, 1 * T, 1 * T, 1 * T]
),
txn_schema!(
from: vec![outputs[1][1].clone()],
to: vec![1 * T, 1 * T, 1 * T, 1 * T]
),
txn_schema!(
from: vec![outputs[1][2].clone()],
to: vec![1 * T, 1 * T, 1 * T, 1 * T]
),
];

let coinbase_value = consensus_manager
.emission_schedule()
.block_reward(store.get_height().unwrap() + 1);

let mut transactions = Vec::new();
let mut block_utxos = Vec::new();
let mut fees = MicroMinotari(0);
for schema in schemas {
let (tx, mut utxos) = spend_utxos(schema, &key_manager).await;
fees += tx.body.get_total_fee().unwrap();
transactions.push(tx);
block_utxos.append(&mut utxos);
}

let (coinbase_utxo, coinbase_kernel, coinbase_output) =
create_coinbase(coinbase_value + fees, 100, None, &key_manager).await;
block_utxos.push(coinbase_output);

outputs.push(block_utxos);

let mut header = BlockHeader::from_previous(blocks.last().unwrap().header());
header.version = consensus_manager
.consensus_constants(header.height)
.blockchain_version();
let height = header.height;

let mut transactions_new = Vec::with_capacity(transactions.len());
for txn in transactions {
transactions_new.push(Transaction {
offset: txn.offset,
body: {
let mut inputs = Vec::with_capacity(txn.body.inputs().len());
for input in txn.body.inputs().iter() {
inputs.push(input.clone());
}
let mut outputs = Vec::with_capacity(txn.body.outputs().len());
for output in txn.body.outputs().iter() {
outputs.push(output.clone());
}
let mut kernels = Vec::with_capacity(txn.body.kernels().len());
for kernel in txn.body.kernels().iter() {
kernels.push(TransactionKernel {
version: kernel.version,
features: kernel.features,
fee: (u64::MAX / 2).into(), // This is the adversary's attack!
lock_height: kernel.lock_height,
excess_sig: kernel.excess_sig.clone(),
excess: kernel.excess.clone(),
burn_commitment: kernel.burn_commitment.clone(),
});
}
AggregateBody::new(inputs, outputs, kernels)
},
script_offset: txn.script_offset,
});
}

// This will call `BlockBuilder::add_kernels(...)` and `AggregateBody::get_total_fee(...)`, which will overflow if
// regressed
let template_result = NewBlockTemplate::from_block(
header
.into_builder()
.with_transactions(transactions_new)
.with_coinbase_utxo(coinbase_utxo, coinbase_kernel)
.build(),
Difficulty::min(),
consensus_manager.get_block_reward_at(height),
);
assert!(template_result.is_err());
assert_eq!(
template_result.unwrap_err().to_string(),
"Invalid kernel in body: Aggregated body has greater fee than u64::MAX".to_string()
);

let schema = txn_schema!(
from: vec![outputs[1][3].clone()],
to: vec![],
fee: MicroMinotari(u64::MAX / 2), // This is the adversary's attack!
lock: 0,
features: OutputFeatures::default()
);
let stx_builder = create_stx_protocol_internal(schema, &key_manager, &mut Vec::new()).await;

// This will call `Fee::calculate(...)`, which will overflow if regressed
let build_result = stx_builder.build().await;
assert!(build_result.is_err());
assert!(format!("{:?}", build_result.unwrap_err())
// Using less decimal points in the comparison to make provision for the rounding error - the actual value
// should be `18446744073691.551615 T`, but the formatting results in `18446744073709.550781 T`
.contains("You are spending more than you're providing: provided 10.000000 T, required 18446744073709.55"));
}
2 changes: 1 addition & 1 deletion base_layer/core/tests/tests/mempool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{convert::TryFrom, ops::Deref, panic, sync::Arc, time::Duration};
use std::{convert::TryFrom, ops::Deref, sync::Arc, time::Duration};

use randomx_rs::RandomXFlag;
use tari_common::configuration::Network;
Expand Down
Loading

0 comments on commit 0d77212

Please sign in to comment.