Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/vm/analysis/type_checker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl CostTracker for TypeChecker<'_, '_> {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> Result<ExecutionCost, CostErrors> {
self.cost_track.compute_cost(cost_function, input)
}
Expand Down
2 changes: 1 addition & 1 deletion src/vm/callables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::iter::FromIterator;

use chainstate::stacks::events::StacksTransactionEvent;

use vm::costs::{cost_functions, runtime_cost, SimpleCostSpecification};
use vm::costs::{cost_functions, runtime_cost};

use vm::analysis::errors::CheckErrors;
use vm::contexts::ContractContext;
Expand Down
4 changes: 2 additions & 2 deletions src/vm/contexts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ impl CostTracker for Environment<'_, '_> {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> std::result::Result<ExecutionCost, CostErrors> {
self.global_context
.cost_track
Expand Down Expand Up @@ -672,7 +672,7 @@ impl CostTracker for GlobalContext<'_> {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> std::result::Result<ExecutionCost, CostErrors> {
self.cost_track.compute_cost(cost_function, input)
}
Expand Down
163 changes: 63 additions & 100 deletions src/vm/costs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2013-2020 Blocstack PBC, a public benefit corporation
// Copyright (C) 2013-2020 Blockstack PBC, a public benefit corporation
// Copyright (C) 2020 Stacks Open Internet Foundation
//
// This program is free software: you can redistribute it and/or modify
Expand Down Expand Up @@ -32,12 +32,15 @@ use vm::contexts::{ContractContext, Environment, GlobalContext, OwnedEnvironment
use vm::costs::cost_functions::ClarityCostFunction;
use vm::database::{marf::NullBackingStore, ClarityDatabase, MemoryBackingStore};
use vm::errors::{Error, InterpreterResult};
use vm::types::signatures::FunctionType::Fixed;
use vm::types::Value::UInt;
use vm::types::{PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, NONE};
use vm::types::{
FunctionArg, FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData,
TypeSignature, NONE,
};
use vm::{ast, eval_all, ClarityName, SymbolicExpression, Value};

use vm::types::signatures::{FunctionSignature, TupleTypeSignature};
use vm::types::FunctionType;

type Result<T> = std::result::Result<T, CostErrors>;

Expand All @@ -62,7 +65,7 @@ pub fn runtime_cost<T: TryInto<u64>, C: CostTracker>(
input: T,
) -> Result<()> {
let size: u64 = input.try_into().map_err(|_| CostErrors::CostOverflow)?;
let cost = tracker.compute_cost(cost_function, size)?;
let cost = tracker.compute_cost(cost_function, &[size])?;

tracker.add_cost(cost)
}
Expand All @@ -84,7 +87,7 @@ pub fn analysis_typecheck_cost<T: CostTracker>(
let t2_size = t2.type_size().map_err(|_| CostErrors::CostOverflow)?;
let cost = track.compute_cost(
ClarityCostFunction::AnalysisTypeCheck,
cmp::max(t1_size, t2_size) as u64,
&[cmp::max(t1_size, t2_size) as u64],
)?;
track.add_cost(cost)
}
Expand All @@ -103,7 +106,7 @@ pub trait CostTracker {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> Result<ExecutionCost>;
fn add_cost(&mut self, cost: ExecutionCost) -> Result<()>;
fn add_memory(&mut self, memory: u64) -> Result<()>;
Expand All @@ -125,7 +128,7 @@ impl CostTracker for () {
fn compute_cost(
&mut self,
_cost_function: ClarityCostFunction,
_input: u64,
_input: &[u64],
) -> std::result::Result<ExecutionCost, CostErrors> {
Ok(ExecutionCost::zero())
}
Expand Down Expand Up @@ -420,7 +423,9 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result<CostStateSumm

// make sure the contract is "cost contract eligible" via the
// arithmetic-checking analysis pass
let cost_func_ref = match clarity_db.load_contract_analysis(&cost_contract) {
let (cost_func_ref, cost_func_type) = match clarity_db
.load_contract_analysis(&cost_contract)
{
Some(c) => {
if !c.is_cost_contract_eligible {
warn!("Confirmed cost proposal invalid: cost-function-contract uses non-arithmetic or otherwise illegal operations";
Expand Down Expand Up @@ -454,6 +459,13 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result<CostStateSumm
);
continue;
}
(
ClarityCostFunctionReference {
contract_id: cost_contract,
function_name: cost_function.to_string(),
},
cost_function_type.clone(),
)
} else {
warn!("Confirmed cost proposal invalid: cost-function-name not defined";
"confirmed_proposal_id" => confirmed_proposal,
Expand All @@ -462,10 +474,6 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result<CostStateSumm
);
continue;
}
ClarityCostFunctionReference {
contract_id: cost_contract,
function_name: cost_function.to_string(),
}
}
None => {
warn!("Confirmed cost proposal invalid: cost-function-contract is not a published contract";
Expand All @@ -491,7 +499,7 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result<CostStateSumm
.cost_function_references
.insert(target, cost_func_ref);
} else {
// refering to a user-defined function
// referring to a user-defined function
match clarity_db.load_contract_analysis(&target_contract) {
Some(c) => {
if !c.read_only_function_types.contains_key(&target_function) {
Expand All @@ -502,6 +510,35 @@ fn load_cost_functions(clarity_db: &mut ClarityDatabase) -> Result<CostStateSumm
);
continue;
}
match (
c.read_only_function_types.get(&target_function).unwrap(),
&cost_func_type,
) {
(Fixed(tf), cf) => {
if cf.args.len() != tf.args.len() {
warn!("Confirmed cost proposal invalid: cost-function contains the wrong number of arguments";
"confirmed_proposal_id" => confirmed_proposal,
"target_contract_name" => %target_contract,
"target_function_name" => %target_function,
);
continue;
}
for arg in &cf.args {
match &arg.signature {
TypeSignature::UIntType => {}
_ => {
warn!("Confirmed cost proposal invalid: contains non UInt argument";
"confirmed_proposal_id" => confirmed_proposal,
);
continue;
}
}
}
}
_ => {
panic!("Cost and target functions should be Fixed");
}
}
}
None => {
warn!("Confirmed cost proposal invalid: contract-name not a published contract";
Expand Down Expand Up @@ -702,7 +739,7 @@ fn parse_cost(
fn compute_cost(
cost_tracker: &mut LimitedCostTracker,
cost_function_reference: ClarityCostFunctionReference,
input_size: u64,
input_sizes: &[u64],
) -> Result<ExecutionCost> {
let mut null_store = NullBackingStore::new();
let conn = null_store.as_clarity_db();
Expand All @@ -716,10 +753,15 @@ fn compute_cost(
&cost_function_reference
)))?;

let program = vec![
SymbolicExpression::atom(cost_function_reference.function_name[..].into()),
SymbolicExpression::atom_value(Value::UInt(input_size.into())),
];
let mut program = vec![SymbolicExpression::atom(
cost_function_reference.function_name[..].into(),
)];

for input_size in input_sizes.iter() {
program.push(SymbolicExpression::atom_value(Value::UInt(
*input_size as u128,
)));
}

let function_invocation = [SymbolicExpression::list(program.into_boxed_slice())];

Expand Down Expand Up @@ -763,7 +805,7 @@ impl CostTracker for LimitedCostTracker {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> std::result::Result<ExecutionCost, CostErrors> {
if self.free {
return Ok(ExecutionCost::zero());
Expand Down Expand Up @@ -814,8 +856,7 @@ impl CostTracker for LimitedCostTracker {
// grr, if HashMap::get didn't require Borrow, we wouldn't need this cloning.
let lookup_key = (contract.clone(), function.clone());
if let Some(cost_function) = self.contract_call_circuits.get(&lookup_key).cloned() {
let input_size = input.iter().fold(0, |agg, cur| agg + cur);
compute_cost(self, cost_function, input_size)?;
compute_cost(self, cost_function, input)?;
Ok(true)
} else {
Ok(false)
Expand All @@ -827,7 +868,7 @@ impl CostTracker for &mut LimitedCostTracker {
fn compute_cost(
&mut self,
cost_function: ClarityCostFunction,
input: u64,
input: &[u64],
) -> std::result::Result<ExecutionCost, CostErrors> {
LimitedCostTracker::compute_cost(self, cost_function, input)
}
Expand All @@ -853,23 +894,6 @@ impl CostTracker for &mut LimitedCostTracker {
}
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub enum CostFunctions {
Constant(u64),
Linear(u64, u64),
NLogN(u64, u64),
LogN(u64, u64),
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct SimpleCostSpecification {
pub write_count: CostFunctions,
pub write_length: CostFunctions,
pub read_count: CostFunctions,
pub read_length: CostFunctions,
pub runtime: CostFunctions,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ExecutionCost {
pub write_length: u64,
Expand Down Expand Up @@ -1018,63 +1042,6 @@ fn int_log2(input: u64) -> Option<u64> {
})
}

impl CostFunctions {
pub fn compute_cost(&self, input: u64) -> Result<u64> {
match self {
CostFunctions::Constant(val) => Ok(*val),
CostFunctions::Linear(a, b) => a.cost_overflow_mul(input)?.cost_overflow_add(*b),
CostFunctions::LogN(a, b) => {
// a*log(input)) + b
// and don't do log(0).
int_log2(cmp::max(input, 1))
.ok_or_else(|| CostErrors::CostOverflow)?
.cost_overflow_mul(*a)?
.cost_overflow_add(*b)
}
CostFunctions::NLogN(a, b) => {
// a*input*log(input)) + b
// and don't do log(0).
int_log2(cmp::max(input, 1))
.ok_or_else(|| CostErrors::CostOverflow)?
.cost_overflow_mul(input)?
.cost_overflow_mul(*a)?
.cost_overflow_add(*b)
}
}
}
}

impl SimpleCostSpecification {
pub fn compute_cost(&self, input: u64) -> Result<ExecutionCost> {
Ok(ExecutionCost {
write_length: self.write_length.compute_cost(input)?,
write_count: self.write_count.compute_cost(input)?,
read_count: self.read_count.compute_cost(input)?,
read_length: self.read_length.compute_cost(input)?,
runtime: self.runtime.compute_cost(input)?,
})
}
}

impl From<ExecutionCost> for SimpleCostSpecification {
fn from(value: ExecutionCost) -> SimpleCostSpecification {
let ExecutionCost {
write_length,
write_count,
read_count,
read_length,
runtime,
} = value;
SimpleCostSpecification {
write_length: CostFunctions::Constant(write_length),
write_count: CostFunctions::Constant(write_count),
read_length: CostFunctions::Constant(read_length),
read_count: CostFunctions::Constant(read_count),
runtime: CostFunctions::Constant(runtime),
}
}
}

#[cfg(test)]
mod unit_tests {
use super::*;
Expand All @@ -1089,10 +1056,6 @@ mod unit_tests {
u64::max_value().cost_overflow_mul(2),
Err(CostErrors::CostOverflow)
);
assert_eq!(
CostFunctions::NLogN(1, 1).compute_cost(u64::max_value()),
Err(CostErrors::CostOverflow)
);
}

#[test]
Expand Down
Loading