diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 97aa60e65..e0204a2c0 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -36,7 +36,7 @@ type PolicyCache = ///Ordered f64 for comparison #[derive(Copy, Clone, PartialEq, PartialOrd, Debug)] -struct OrdF64(f64); +pub(crate) struct OrdF64(pub f64); impl Eq for OrdF64 {} impl Ord for OrdF64 { @@ -1007,18 +1007,23 @@ where }) .collect(); - if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG { - insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec))); - } - // Not a threshold, it's always more optimal to translate it to and()s as we save the - // resulting threshold check (N EQUAL) in any case. - else if k == subs.len() { - let mut policy = subs.first().expect("No sub policy in thresh() ?").clone(); - for sub in &subs[1..] { - policy = Concrete::And(vec![sub.clone(), policy]); + match Ctx::sig_type() { + SigType::Schnorr if key_vec.len() == subs.len() => { + insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec))) + } + SigType::Ecdsa + if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG => + { + insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec))) } + _ if k == subs.len() => { + let mut it = subs.iter(); + let mut policy = it.next().expect("No sub policy in thresh() ?").clone(); + policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()])); - ret = best_compilations(policy_cache, &policy, sat_prob, dissat_prob)?; + ret = best_compilations(policy_cache, &policy, sat_prob, dissat_prob)?; + } + _ => {} } // FIXME: Should we also optimize thresh(1, subs) ? @@ -1569,6 +1574,17 @@ mod tests { )) ); } + + #[test] + fn compile_tr_thresh() { + for k in 1..4 { + let small_thresh: Concrete = + policy_str!("{}", &format!("thresh({},pk(B),pk(C),pk(D))", k)); + let small_thresh_ms: Miniscript = small_thresh.compile().unwrap(); + let small_thresh_ms_expected: Miniscript = ms_str!("multi_a({},B,C,D)", k); + assert_eq!(small_thresh_ms, small_thresh_ms_expected); + } + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 7fb48a1f0..d3cda6b82 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -20,20 +20,27 @@ use std::{error, fmt, str}; use bitcoin::hashes::hex::FromHex; use bitcoin::hashes::{hash160, ripemd160, sha256, sha256d}; +#[cfg(feature = "compiler")] +use { + crate::descriptor::TapTree, + crate::miniscript::ScriptContext, + crate::policy::compiler::CompilerError, + crate::policy::compiler::OrdF64, + crate::policy::{compiler, Concrete, Liftable, Semantic}, + crate::Descriptor, + crate::Miniscript, + crate::Tap, + std::cmp::Reverse, + std::collections::{BinaryHeap, HashMap}, + std::sync::Arc, +}; use super::ENTAILMENT_MAX_TERMINALS; use crate::expression::{self, FromTree}; use crate::miniscript::limits::{HEIGHT_TIME_THRESHOLD, SEQUENCE_LOCKTIME_TYPE_FLAG}; use crate::miniscript::types::extra_props::TimeLockInfo; -#[cfg(feature = "compiler")] -use crate::miniscript::ScriptContext; -#[cfg(feature = "compiler")] -use crate::policy::compiler; -#[cfg(feature = "compiler")] -use crate::policy::compiler::CompilerError; -#[cfg(feature = "compiler")] -use crate::Miniscript; use crate::{errstr, Error, ForEach, ForEachKey, MiniscriptKey}; + /// Concrete policy which corresponds directly to a Miniscript structure, /// and whose disjunctions are annotated with satisfaction probabilities /// to assist the compiler @@ -128,6 +135,136 @@ impl fmt::Display for PolicyError { } impl Policy { + /// Flatten the [`Policy`] tree structure into a Vector of tuple `(leaf script, leaf probability)` + /// with leaf probabilities corresponding to odds for sub-branch in the policy. + /// We calculate the probability of selecting the sub-branch at every level and calculate the + /// leaf probabilities as the probability of traversing through required branches to reach the + /// leaf node, i.e. multiplication of the respective probabilities. + /// + /// For example, the policy tree: OR + /// / \ + /// 2 1 odds + /// / \ + /// A OR + /// / \ + /// 3 1 odds + /// / \ + /// B C + /// + /// gives the vector [(2/3, A), (1/3 * 3/4, B), (1/3 * 1/4, C)]. + #[cfg(feature = "compiler")] + fn to_tapleaf_prob_vec(&self, prob: f64) -> Vec<(f64, Policy)> { + match *self { + Policy::Or(ref subs) => { + let total_odds: usize = subs.iter().map(|(ref k, _)| k).sum(); + subs.iter() + .map(|(k, ref policy)| { + policy.to_tapleaf_prob_vec(prob * *k as f64 / total_odds as f64) + }) + .flatten() + .collect::>() + } + Policy::Threshold(k, ref subs) if k == 1 => { + let total_odds = subs.len(); + subs.iter() + .map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64)) + .flatten() + .collect::>() + } + ref x => vec![(prob, x.clone())], + } + } + + /// Compile [`Policy::Or`] and [`Policy::Threshold`] according to odds + #[cfg(feature = "compiler")] + fn compile_tr_policy(&self) -> Result, Error> { + let leaf_compilations: Vec<_> = self + .to_tapleaf_prob_vec(1.0) + .into_iter() + .filter(|x| x.1 != Policy::Unsatisfiable) + .map(|(prob, ref policy)| (OrdF64(prob), compiler::best_compilation(policy).unwrap())) + .collect(); + let taptree = with_huffman_tree::(leaf_compilations).unwrap(); + Ok(taptree) + } + + /// Extract the internal_key from policy tree. + #[cfg(feature = "compiler")] + fn extract_key(self, unspendable_key: Option) -> Result<(Pk, Policy), Error> { + let mut internal_key: Option = None; + { + let mut prob = 0.; + let semantic_policy = self.lift()?; + let concrete_keys = self.keys(); + let key_prob_map: HashMap<_, _> = self + .to_tapleaf_prob_vec(1.0) + .into_iter() + .filter(|(_, ref pol)| match *pol { + Concrete::Key(..) => true, + _ => false, + }) + .map(|(prob, key)| (key, prob)) + .collect(); + + for key in concrete_keys.into_iter() { + if semantic_policy + .clone() + .satisfy_constraint(&Semantic::KeyHash(key.to_pubkeyhash()), true) + == Semantic::Trivial + { + match key_prob_map.get(&Concrete::Key(key.clone())) { + Some(val) => { + if *val > prob { + prob = *val; + internal_key = Some(key.clone()); + } + } + None => return Err(errstr("Key should have existed in the HashMap!")), + } + } + } + } + match (internal_key, unspendable_key) { + (Some(ref key), _) => Ok((key.clone(), self.translate_unsatisfiable_pk(&key))), + (_, Some(key)) => Ok((key, self)), + _ => Err(errstr("No viable internal key found.")), + } + } + + /// Compile the [`Policy`] into a [`Tr`][`Descriptor::Tr`] Descriptor. + /// + /// ### TapTree compilation + /// + /// The policy tree constructed by root-level disjunctions over [`Or`][`Policy::Or`] and + /// [`Thresh`][`Policy::Threshold`](1, ..) which is flattened into a vector (with respective + /// probabilities derived from odds) of policies. + /// For example, the policy `thresh(1,or(pk(A),pk(B)),and(or(pk(C),pk(D)),pk(E)))` gives the vector + /// `[pk(A),pk(B),and(or(pk(C),pk(D)),pk(E)))]`. Each policy in the vector is compiled into + /// the respective miniscripts. A Huffman Tree is created from this vector which optimizes over + /// the probabilitity of satisfaction for the respective branch in the TapTree. + // TODO: We might require other compile errors for Taproot. + #[cfg(feature = "compiler")] + pub fn compile_tr(&self, unspendable_key: Option) -> Result, Error> { + self.is_valid()?; // Check for validity + match self.is_safe_nonmalleable() { + (false, _) => Err(Error::from(CompilerError::TopLevelNonSafe)), + (_, false) => Err(Error::from( + CompilerError::ImpossibleNonMalleableCompilation, + )), + _ => { + let (internal_key, policy) = self.clone().extract_key(unspendable_key)?; + let tree = Descriptor::new_tr( + internal_key, + match policy { + Policy::Trivial => None, + policy => Some(policy.compile_tr_policy()?), + }, + )?; + Ok(tree) + } + } + } + /// Compile the descriptor into an optimized `Miniscript` representation #[cfg(feature = "compiler")] pub fn compile(&self) -> Result, CompilerError> { @@ -226,6 +363,30 @@ impl Policy { } } + /// Translate `Concrete::Key(key)` to `Concrete::Unsatisfiable` when extracting TapKey + pub fn translate_unsatisfiable_pk(self, key: &Pk) -> Policy { + match self { + Policy::Key(ref k) if k.clone() == *key => Policy::Unsatisfiable, + Policy::And(subs) => Policy::And( + subs.into_iter() + .map(|sub| sub.translate_unsatisfiable_pk(key)) + .collect::>(), + ), + Policy::Or(subs) => Policy::Or( + subs.into_iter() + .map(|(k, sub)| (k, sub.translate_unsatisfiable_pk(key))) + .collect::>(), + ), + Policy::Threshold(k, subs) => Policy::Threshold( + k, + subs.into_iter() + .map(|sub| sub.translate_unsatisfiable_pk(key)) + .collect::>(), + ), + x => x, + } + } + /// Get all keys in the policy pub fn keys(&self) -> Vec<&Pk> { match *self { @@ -645,3 +806,34 @@ where Policy::from_tree_prob(top, false).map(|(_, result)| result) } } + +/// Create a Huffman Tree from compiled [Miniscript] nodes +#[cfg(feature = "compiler")] +fn with_huffman_tree( + ms: Vec<(OrdF64, Miniscript)>, +) -> Result, Error> { + let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); + for (prob, script) in ms { + node_weights.push((Reverse(prob), TapTree::Leaf(Arc::new(script)))); + } + if node_weights.is_empty() { + return Err(errstr("Empty Miniscript compilation")); + } + while node_weights.len() > 1 { + let (p1, s1) = node_weights.pop().expect("len must atleast be two"); + let (p2, s2) = node_weights.pop().expect("len must atleast be two"); + + let p = (p1.0).0 + (p2.0).0; + node_weights.push(( + Reverse(OrdF64(p)), + TapTree::Tree(Arc::from(s1), Arc::from(s2)), + )); + } + + debug_assert!(node_weights.len() == 1); + let node = node_weights + .pop() + .expect("huffman tree algorithm is broken") + .1; + Ok(node) +} diff --git a/src/policy/mod.rs b/src/policy/mod.rs index 220caa223..2f0f17f1b 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -225,6 +225,8 @@ impl Liftable for Concrete { #[cfg(test)] mod tests { use std::str::FromStr; + #[cfg(feature = "compiler")] + use std::sync::Arc; use bitcoin; @@ -232,6 +234,8 @@ mod tests { use super::super::miniscript::Miniscript; use super::{Concrete, Liftable, Semantic}; use crate::DummyKey; + #[cfg(feature = "compiler")] + use crate::{descriptor::TapTree, Descriptor, Tap}; type ConcretePol = Concrete; type SemanticPol = Semantic; @@ -361,4 +365,121 @@ mod tests { ms_str.lift().unwrap() ); } + + #[test] + #[cfg(feature = "compiler")] + fn taproot_compile() { + // Trivial single-node compilation + let unspendable_key: String = "UNSPENDABLE".to_string(); + { + let policy: Concrete = policy_str!("thresh(2,pk(A),pk(B),pk(C),pk(D))"); + let descriptor = policy.compile_tr(Some(unspendable_key.clone())).unwrap(); + + let ms_compilation: Miniscript = ms_str!("multi_a(2,A,B,C,D)"); + let tree: TapTree = TapTree::Leaf(Arc::new(ms_compilation)); + let expected_descriptor = + Descriptor::new_tr(unspendable_key.clone(), Some(tree)).unwrap(); + assert_eq!(descriptor, expected_descriptor); + } + + // Trivial multi-node compilation + { + let policy: Concrete = policy_str!("or(and(pk(A),pk(B)),and(pk(C),pk(D)))"); + let descriptor = policy.compile_tr(Some(unspendable_key.clone())).unwrap(); + + let left_ms_compilation: Arc> = + Arc::new(ms_str!("and_v(v:pk(C),pk(D))")); + let right_ms_compilation: Arc> = + Arc::new(ms_str!("and_v(v:pk(A),pk(B))")); + let left_node: Arc> = Arc::from(TapTree::Leaf(left_ms_compilation)); + let right_node: Arc> = Arc::from(TapTree::Leaf(right_ms_compilation)); + let tree: TapTree = TapTree::Tree(left_node, right_node); + let expected_descriptor = + Descriptor::new_tr(unspendable_key.clone(), Some(tree)).unwrap(); + assert_eq!(descriptor, expected_descriptor); + } + + { + // Invalid policy compilation (Duplicate PubKeys) + let policy: Concrete = policy_str!("or(and(pk(A),pk(B)),and(pk(A),pk(D)))"); + let descriptor = policy.compile_tr(Some(unspendable_key.clone())); + + assert_eq!( + descriptor.unwrap_err().to_string(), + "Policy contains duplicate keys" + ); + } + + // Non-trivial multi-node compilation + { + let node_policies = [ + "and(pk(A),pk(B))", + "and(pk(C),older(12960))", + "pk(D)", + "pk(E)", + "thresh(3,pk(F),pk(G),pk(H))", + "and(and(or(2@pk(I),1@pk(J)),or(1@pk(K),20@pk(L))),pk(M))", + "pk(N)", + ]; + + // Floating-point precision errors cause the minor errors + let node_probabilities: [f64; 7] = + [0.12000002, 0.28, 0.08, 0.12, 0.19, 0.18999998, 0.02]; + + let policy: Concrete = policy_str!( + "{}", + &format!( + "or(4@or(3@{},7@{}),6@thresh(1,or(4@{},6@{}),{},or(9@{},1@{})))", + node_policies[0], + node_policies[1], + node_policies[2], + node_policies[3], + node_policies[4], + node_policies[5], + node_policies[6] + ) + ); + let descriptor = policy.compile_tr(Some(unspendable_key.clone())).unwrap(); + + let mut sorted_policy_prob = node_policies + .into_iter() + .zip(node_probabilities.into_iter()) + .collect::>(); + sorted_policy_prob.sort_by(|a, b| (a.1).partial_cmp(&b.1).unwrap()); + let sorted_policies = sorted_policy_prob + .into_iter() + .map(|(x, _prob)| x) + .collect::>(); + + // Generate TapTree leaves compilations from the given sub-policies + let node_compilations = sorted_policies + .into_iter() + .map(|x| { + let leaf_policy: Concrete = policy_str!("{}", x); + TapTree::Leaf(Arc::from(leaf_policy.compile::().unwrap())) + }) + .collect::>(); + + // Arrange leaf compilations (acc. to probabilities) using huffman encoding into a TapTree + let tree = TapTree::Tree( + Arc::from(TapTree::Tree( + Arc::from(node_compilations[4].clone()), + Arc::from(node_compilations[5].clone()), + )), + Arc::from(TapTree::Tree( + Arc::from(TapTree::Tree( + Arc::from(TapTree::Tree( + Arc::from(node_compilations[0].clone()), + Arc::from(node_compilations[1].clone()), + )), + Arc::from(node_compilations[3].clone()), + )), + Arc::from(node_compilations[6].clone()), + )), + ); + + let expected_descriptor = Descriptor::new_tr("E".to_string(), Some(tree)).unwrap(); + assert_eq!(descriptor, expected_descriptor); + } + } } diff --git a/src/policy/semantic.rs b/src/policy/semantic.rs index de60a3bf4..4411aa899 100644 --- a/src/policy/semantic.rs +++ b/src/policy/semantic.rs @@ -184,7 +184,7 @@ impl Policy { // policy. // Witness is currently encoded as policy. Only accepts leaf fragment and // a normalized policy - fn satisfy_constraint(self, witness: &Policy, available: bool) -> Policy { + pub(crate) fn satisfy_constraint(self, witness: &Policy, available: bool) -> Policy { debug_assert!(self.clone().normalized() == self); match *witness { // only for internal purposes, safe to use unreachable!