diff --git a/src/vm/analysis/type_checker/mod.rs b/src/vm/analysis/type_checker/mod.rs index f3b7d5ced08..fec8fffd0f6 100644 --- a/src/vm/analysis/type_checker/mod.rs +++ b/src/vm/analysis/type_checker/mod.rs @@ -81,7 +81,7 @@ impl CostTracker for TypeChecker<'_, '_> { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> Result { self.cost_track.compute_cost(cost_function, input) } diff --git a/src/vm/callables.rs b/src/vm/callables.rs index 140532777e8..77308a44919 100644 --- a/src/vm/callables.rs +++ b/src/vm/callables.rs @@ -21,7 +21,7 @@ use std::iter::FromIterator; use chainstate::stacks::events::StacksTransactionEvent; -use vm::costs::{cost_functions, runtime_cost, SimpleCostSpecification}; +use vm::costs::{cost_functions, runtime_cost}; use vm::analysis::errors::CheckErrors; use vm::contexts::ContractContext; diff --git a/src/vm/contexts.rs b/src/vm/contexts.rs index 40224a77ca6..188eddd1fbe 100644 --- a/src/vm/contexts.rs +++ b/src/vm/contexts.rs @@ -638,7 +638,7 @@ impl CostTracker for Environment<'_, '_> { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> std::result::Result { self.global_context .cost_track @@ -672,7 +672,7 @@ impl CostTracker for GlobalContext<'_> { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> std::result::Result { self.cost_track.compute_cost(cost_function, input) } diff --git a/src/vm/costs/mod.rs b/src/vm/costs/mod.rs index bcc090d3fe3..2ec1cf84947 100644 --- a/src/vm/costs/mod.rs +++ b/src/vm/costs/mod.rs @@ -1,4 +1,4 @@ -// Copyright (C) 2013-2020 Blocstack PBC, a public benefit corporation +// Copyright (C) 2013-2020 Blockstack PBC, a public benefit corporation // Copyright (C) 2020 Stacks Open Internet Foundation // // This program is free software: you can redistribute it and/or modify @@ -32,12 +32,15 @@ use vm::contexts::{ContractContext, Environment, GlobalContext, OwnedEnvironment use vm::costs::cost_functions::ClarityCostFunction; use vm::database::{marf::NullBackingStore, ClarityDatabase, MemoryBackingStore}; use vm::errors::{Error, InterpreterResult}; +use vm::types::signatures::FunctionType::Fixed; use vm::types::Value::UInt; -use vm::types::{PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, NONE}; +use vm::types::{ + FunctionArg, FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData, + TypeSignature, NONE, +}; use vm::{ast, eval_all, ClarityName, SymbolicExpression, Value}; use vm::types::signatures::{FunctionSignature, TupleTypeSignature}; -use vm::types::FunctionType; type Result = std::result::Result; @@ -62,7 +65,7 @@ pub fn runtime_cost, C: CostTracker>( input: T, ) -> Result<()> { let size: u64 = input.try_into().map_err(|_| CostErrors::CostOverflow)?; - let cost = tracker.compute_cost(cost_function, size)?; + let cost = tracker.compute_cost(cost_function, &[size])?; tracker.add_cost(cost) } @@ -84,7 +87,7 @@ pub fn analysis_typecheck_cost( let t2_size = t2.type_size().map_err(|_| CostErrors::CostOverflow)?; let cost = track.compute_cost( ClarityCostFunction::AnalysisTypeCheck, - cmp::max(t1_size, t2_size) as u64, + &[cmp::max(t1_size, t2_size) as u64], )?; track.add_cost(cost) } @@ -103,7 +106,7 @@ pub trait CostTracker { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> Result; fn add_cost(&mut self, cost: ExecutionCost) -> Result<()>; fn add_memory(&mut self, memory: u64) -> Result<()>; @@ -125,7 +128,7 @@ impl CostTracker for () { fn compute_cost( &mut self, _cost_function: ClarityCostFunction, - _input: u64, + _input: &[u64], ) -> std::result::Result { Ok(ExecutionCost::zero()) } @@ -420,7 +423,9 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result { if !c.is_cost_contract_eligible { warn!("Confirmed cost proposal invalid: cost-function-contract uses non-arithmetic or otherwise illegal operations"; @@ -454,6 +459,13 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result confirmed_proposal, @@ -462,10 +474,6 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result { warn!("Confirmed cost proposal invalid: cost-function-contract is not a published contract"; @@ -491,7 +499,7 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result { if !c.read_only_function_types.contains_key(&target_function) { @@ -502,6 +510,35 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result { + if cf.args.len() != tf.args.len() { + warn!("Confirmed cost proposal invalid: cost-function contains the wrong number of arguments"; + "confirmed_proposal_id" => confirmed_proposal, + "target_contract_name" => %target_contract, + "target_function_name" => %target_function, + ); + continue; + } + for arg in &cf.args { + match &arg.signature { + TypeSignature::UIntType => {} + _ => { + warn!("Confirmed cost proposal invalid: contains non UInt argument"; + "confirmed_proposal_id" => confirmed_proposal, + ); + continue; + } + } + } + } + _ => { + panic!("Cost and target functions should be Fixed"); + } + } } None => { warn!("Confirmed cost proposal invalid: contract-name not a published contract"; @@ -702,7 +739,7 @@ fn parse_cost( fn compute_cost( cost_tracker: &mut LimitedCostTracker, cost_function_reference: ClarityCostFunctionReference, - input_size: u64, + input_sizes: &[u64], ) -> Result { let mut null_store = NullBackingStore::new(); let conn = null_store.as_clarity_db(); @@ -716,10 +753,15 @@ fn compute_cost( &cost_function_reference )))?; - let program = vec![ - SymbolicExpression::atom(cost_function_reference.function_name[..].into()), - SymbolicExpression::atom_value(Value::UInt(input_size.into())), - ]; + let mut program = vec![SymbolicExpression::atom( + cost_function_reference.function_name[..].into(), + )]; + + for input_size in input_sizes.iter() { + program.push(SymbolicExpression::atom_value(Value::UInt( + *input_size as u128, + ))); + } let function_invocation = [SymbolicExpression::list(program.into_boxed_slice())]; @@ -763,7 +805,7 @@ impl CostTracker for LimitedCostTracker { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> std::result::Result { if self.free { return Ok(ExecutionCost::zero()); @@ -814,8 +856,7 @@ impl CostTracker for LimitedCostTracker { // grr, if HashMap::get didn't require Borrow, we wouldn't need this cloning. let lookup_key = (contract.clone(), function.clone()); if let Some(cost_function) = self.contract_call_circuits.get(&lookup_key).cloned() { - let input_size = input.iter().fold(0, |agg, cur| agg + cur); - compute_cost(self, cost_function, input_size)?; + compute_cost(self, cost_function, input)?; Ok(true) } else { Ok(false) @@ -827,7 +868,7 @@ impl CostTracker for &mut LimitedCostTracker { fn compute_cost( &mut self, cost_function: ClarityCostFunction, - input: u64, + input: &[u64], ) -> std::result::Result { LimitedCostTracker::compute_cost(self, cost_function, input) } @@ -853,23 +894,6 @@ impl CostTracker for &mut LimitedCostTracker { } } -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub enum CostFunctions { - Constant(u64), - Linear(u64, u64), - NLogN(u64, u64), - LogN(u64, u64), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub struct SimpleCostSpecification { - pub write_count: CostFunctions, - pub write_length: CostFunctions, - pub read_count: CostFunctions, - pub read_length: CostFunctions, - pub runtime: CostFunctions, -} - #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct ExecutionCost { pub write_length: u64, @@ -1018,63 +1042,6 @@ fn int_log2(input: u64) -> Option { }) } -impl CostFunctions { - pub fn compute_cost(&self, input: u64) -> Result { - match self { - CostFunctions::Constant(val) => Ok(*val), - CostFunctions::Linear(a, b) => a.cost_overflow_mul(input)?.cost_overflow_add(*b), - CostFunctions::LogN(a, b) => { - // a*log(input)) + b - // and don't do log(0). - int_log2(cmp::max(input, 1)) - .ok_or_else(|| CostErrors::CostOverflow)? - .cost_overflow_mul(*a)? - .cost_overflow_add(*b) - } - CostFunctions::NLogN(a, b) => { - // a*input*log(input)) + b - // and don't do log(0). - int_log2(cmp::max(input, 1)) - .ok_or_else(|| CostErrors::CostOverflow)? - .cost_overflow_mul(input)? - .cost_overflow_mul(*a)? - .cost_overflow_add(*b) - } - } - } -} - -impl SimpleCostSpecification { - pub fn compute_cost(&self, input: u64) -> Result { - Ok(ExecutionCost { - write_length: self.write_length.compute_cost(input)?, - write_count: self.write_count.compute_cost(input)?, - read_count: self.read_count.compute_cost(input)?, - read_length: self.read_length.compute_cost(input)?, - runtime: self.runtime.compute_cost(input)?, - }) - } -} - -impl From for SimpleCostSpecification { - fn from(value: ExecutionCost) -> SimpleCostSpecification { - let ExecutionCost { - write_length, - write_count, - read_count, - read_length, - runtime, - } = value; - SimpleCostSpecification { - write_length: CostFunctions::Constant(write_length), - write_count: CostFunctions::Constant(write_count), - read_length: CostFunctions::Constant(read_length), - read_count: CostFunctions::Constant(read_count), - runtime: CostFunctions::Constant(runtime), - } - } -} - #[cfg(test)] mod unit_tests { use super::*; @@ -1089,10 +1056,6 @@ mod unit_tests { u64::max_value().cost_overflow_mul(2), Err(CostErrors::CostOverflow) ); - assert_eq!( - CostFunctions::NLogN(1, 1).compute_cost(u64::max_value()), - Err(CostErrors::CostOverflow) - ); } #[test] diff --git a/src/vm/tests/costs.rs b/src/vm/tests/costs.rs index 8b806c5befe..e348409f9f9 100644 --- a/src/vm/tests/costs.rs +++ b/src/vm/tests/costs.rs @@ -467,6 +467,8 @@ fn test_cost_voting_integration() { QualifiedContractIdentifier::new(p1_principal.clone(), "cost-definer".into()); let bad_cost_definer = QualifiedContractIdentifier::new(p1_principal.clone(), "bad-cost-definer".into()); + let bad_cost_args_definer = + QualifiedContractIdentifier::new(p1_principal.clone(), "bad-cost-args-definer".into()); let intercepted = QualifiedContractIdentifier::new(p1_principal.clone(), "intercepted".into()); let caller = QualifiedContractIdentifier::new(p1_principal.clone(), "caller".into()); @@ -491,6 +493,10 @@ fn test_cost_voting_integration() { { runtime: u2, write_length: u0, write_count: u0, read_count: u0, read_length: u0 }) + (define-read-only (cost-definition-multi-arg (a uint) (b uint) (c uint)) + { + runtime: u1, write_length: u0, write_count: u0, read_count: u0, read_length: u0 + }) "; @@ -502,12 +508,23 @@ fn test_cost_voting_integration() { }) "; + let bad_cost_args_definer_src = " + (define-read-only (cost-definition (a uint) (b uint)) + { + runtime: u1, write_length: u1, write_count: u1, read_count: u1, read_length: u1 + }) + "; + let intercepted_src = " (define-read-only (intercepted-function (a uint)) (if (>= a u10) (+ (+ a a) (+ a a) (+ a a) (+ a a)) u0)) + + (define-read-only (intercepted-function2 (a uint) (b uint) (c uint)) + (- (+ a b) c)) + (define-public (non-read-only) (ok (+ 1 2 3))) "; @@ -521,6 +538,7 @@ fn test_cost_voting_integration() { (&intercepted, intercepted_src), (&caller, caller_src), (&bad_cost_definer, bad_cost_definer_src), + (&bad_cost_args_definer, bad_cost_args_definer_src), ] .iter() { @@ -606,6 +624,13 @@ fn test_cost_voting_integration() { bad_cost_definer.clone().into(), "cost-definition", ), + // cost defining contract has incorrect number of arguments + ( + intercepted.clone().into(), + "intercepted-function", + bad_cost_args_definer.clone().into(), + "cost-definition", + ), ]; let bad_proposals = bad_cases.len(); @@ -692,6 +717,12 @@ fn test_cost_voting_integration() { cost_definer.clone(), "cost-definition-le", ), + ( + intercepted.clone(), + "intercepted-function2", + cost_definer.clone(), + "cost-definition-multi-arg", + ), ]; marf_kv.test_commit(); @@ -751,13 +782,19 @@ fn test_cost_voting_integration() { let (_db, tracker) = owned_env.destruct().unwrap(); let circuits = tracker.contract_call_circuits(); - assert_eq!(circuits.len(), 1); - for (target, referenced_function) in circuits.into_iter() { - assert_eq!(&target.0, &intercepted); - assert_eq!(&target.1.to_string(), "intercepted-function"); - assert_eq!(&referenced_function.contract_id, &cost_definer); - assert_eq!(&referenced_function.function_name, "cost-definition"); - } + assert_eq!(circuits.len(), 2); + + let circuit1 = circuits.get(&(intercepted.clone(), "intercepted-function".into())); + let circuit2 = circuits.get(&(intercepted.clone(), "intercepted-function2".into())); + + assert!(circuit1.is_some()); + assert!(circuit2.is_some()); + + assert_eq!(circuit1.unwrap().contract_id, cost_definer); + assert_eq!(circuit1.unwrap().function_name, "cost-definition"); + + assert_eq!(circuit2.unwrap().contract_id, cost_definer); + assert_eq!(circuit2.unwrap().function_name, "cost-definition-multi-arg"); for (target, referenced_function) in tracker.cost_function_references().into_iter() { if target == &ClarityCostFunction::Le {