From cd7588ee8eaebe862fe9cf5d7c3fd92981703e87 Mon Sep 17 00:00:00 2001
From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com>
Date: Mon, 19 Jun 2023 09:30:33 -0500
Subject: [PATCH] feat: use precomputation on (most) fixed generators (#19)

Uses precomputation of fixed generator vectors to speed up verification, at
the cost of storing precomputation tables between verification operations.
This doesn't apply precomputation to the Pedersen generators, since those can
be set independently of the others and we can't mix and match precomputation
tables due to upstream limitations.

Note that this requires and uses a custom curve library fork. The fork
supports partial precomputation by removing an existing restriction that the
number of static points and scalars used for precomputation evaluation must
match. It also implements `Clone` on the underlying types used for
precomputation. This is unfortunate, since at several megabytes in total such
tables should almost certainly never be cloned; it's done for the reason
explained below.

The generator tables are wrapped in an `Arc` for shared ownership. This is
needed because precomputation evaluation is an instance method on a
precomputation type, not a static method that takes a reference to the
underlying tables. I have no idea why this design was chosen (static methods
are used for other types of multiscalar multiplication), especially because
there's no mutation involved. Because of it, the verifier needs to own the
precomputation structure containing the tables, even though those tables are
expected to be reused (that's the entire point of precomputation). Using an
`Arc` takes care of this nicely and avoids cloning.

However, `#[derive(Clone)]` only produces a usable `Clone` implementation when
every generic type in the struct implements `Clone`, which means that even
though cloning the table `Arc` isn't any kind of deep copy, we can't use the
attribute unless the precomputation tables implement `Clone`. Manually
implementing `Clone` on the containing struct is a headache, so it seemed
easier to add `#[derive(Clone)]` at the curve library level. This means it's
_very important_ to use the precomputation tables carefully to avoid
unintended cloning. Testing confirmed that the current implementation handles
this as expected and won't clone any of the tables, despite the compiler
requiring that they implement `Clone`.

Closes [issue #18](https://github.com/tari-project/bulletproofs-plus/issues/18).
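
As a rough illustration of the ownership pattern described above (not part of
this patch), the sketch below builds precomputation tables once, shares them
through an `Arc`, and runs the final batch check with
`vartime_mixed_multiscalar_mul`. It is written against the upstream
`curve25519-dalek` API rather than the fork, so it assumes one static scalar
per precomputed point; the `BatchVerifier` type and its inputs are
hypothetical stand-ins for this crate's verifier state.

```rust
use std::sync::Arc;

use curve25519_dalek::{
    ristretto::{RistrettoPoint, VartimeRistrettoPrecomputation},
    scalar::Scalar,
    traits::{Identity, VartimePrecomputedMultiscalarMul},
};

/// Hypothetical verifier state holding shared tables for the fixed generators.
struct BatchVerifier {
    precomp: Arc<VartimeRistrettoPrecomputation>,
}

impl BatchVerifier {
    /// Build the tables once from the fixed (interleaved G/H) generator vectors.
    fn new(fixed_generators: &[RistrettoPoint]) -> Self {
        Self {
            precomp: Arc::new(VartimeRistrettoPrecomputation::new(fixed_generators.iter())),
        }
    }

    /// Hand the same tables to another verification task: cloning the `Arc`
    /// only bumps a reference count and never copies the tables themselves.
    fn share_tables(&self) -> Arc<VartimeRistrettoPrecomputation> {
        Arc::clone(&self.precomp)
    }

    /// Final batch check: static scalars hit the precomputed generators, while
    /// per-proof points (and, here, the Pedersen generators) are supplied
    /// dynamically. The batch is valid iff the weighted sum is the identity.
    fn check(
        &self,
        static_scalars: &[Scalar],
        dynamic_scalars: &[Scalar],
        dynamic_points: &[RistrettoPoint],
    ) -> bool {
        self.precomp.vartime_mixed_multiscalar_mul(
            static_scalars.iter(),
            dynamic_scalars.iter(),
            dynamic_points.iter(),
        ) == RistrettoPoint::identity()
    }
}
```

The property that matters is that `share_tables` (and any `#[derive(Clone)]`
on a struct holding the `Arc`) copies only the pointer, never the
multi-megabyte tables.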
---
 Cargo.toml                         |   1 +
 benches/range_proof.rs             |  91 ++++++++++++--------
 src/generators/bulletproof_gens.rs |  90 +++++++++++++++-----------
 src/generators/generators_chain.rs |   9 ---
 src/generators/mod.rs              |  23 -------
 src/range_parameters.rs            |  26 ++++----
 src/range_proof.rs                 | 100 ++++++++++++++---------
 src/range_statement.rs             |   8 +--
 src/ristretto.rs                   |   8 ++-
 src/traits.rs                      |   8 +++
 src/transcripts.rs                 |   4 +-
 11 files changed, 179 insertions(+), 189 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index d8feb4d..3e2859a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ curve25519-dalek = { package="tari-curve25519-dalek", version = "4.0.2", default
 derive_more = "0.99.17"
 derivative = "2.2.0"
 digest = { version = "0.9.0", default-features = false }
+itertools = "0.6.0"
 lazy_static = "1.4.0"
 merlin = { version = "2", default-features = false }
 rand = "0.7"
diff --git a/benches/range_proof.rs b/benches/range_proof.rs
index 1f2565d..b17ac1a 100644
--- a/benches/range_proof.rs
+++ b/benches/range_proof.rs
@@ -9,8 +9,6 @@
 #[macro_use]
 extern crate criterion;
 
-use std::convert::TryInto;
-
 use criterion::{Criterion, SamplingMode};
 use curve25519_dalek::scalar::Scalar;
 use rand::{self, Rng};
@@ -197,65 +195,60 @@ fn verify_batched_rangeproofs_helper(bit_length: usize, extension_degree: Extens
     #[allow(clippy::cast_possible_truncation)]
     let (value_min, value_max) = (0u64, (1u128 << (bit_length - 1)) as u64);
 
-    let max_range_proofs = BATCHED_SIZES
-        .to_vec()
-        .iter()
-        .fold(u32::MIN, |a, &b| a.max(b.try_into().unwrap()));
-    // 0. Batch data
-    let mut statements = vec![];
-    let mut proofs = vec![];
-    let pc_gens = ristretto::create_pedersen_gens_with_extension_degree(extension_degree);
-
-    // 1. Generators
-    let generators = RangeParameters::init(bit_length, 1, pc_gens).unwrap();
-
-    let mut rng = rand::thread_rng();
-    for _ in 0..max_range_proofs {
-        // 2. Create witness data
-        let mut openings = vec![];
-        let value = rng.gen_range(value_min, value_max);
-        let blindings = vec![Scalar::random_not_zero(&mut rng); extension_degree as usize];
-        openings.push(CommitmentOpening::new(value, blindings.clone()));
-        let witness = RangeWitness::init(openings).unwrap();
-
-        // 3. Generate the statement
-        let seed_nonce = Some(Scalar::random_not_zero(&mut rng));
-        let statement = RangeStatement::init(
-            generators.clone(),
-            vec![generators
-                .pc_gens()
-                .commit(&Scalar::from(value), blindings.as_slice())
-                .unwrap()],
-            vec![Some(value / 3)],
-            seed_nonce,
-        )
-        .unwrap();
-        statements.push(statement.clone());
-
-        // 4. Create the proof
-        let proof = RistrettoRangeProof::prove(transcript_label, &statement, &witness).unwrap();
-        proofs.push(proof);
-    }
-
     for extract_masks in EXTRACT_MASKS {
         for number_of_range_proofs in BATCHED_SIZES {
             let label = format!(
                 "Batched {}-bit BP+ verify {} deg {:?} masks {:?}",
                 bit_length, number_of_range_proofs, extension_degree, extract_masks
             );
-            let statements = &statements[0..number_of_range_proofs];
-            let proofs = &proofs[0..number_of_range_proofs];
+
+            // Generators
+            let pc_gens = ristretto::create_pedersen_gens_with_extension_degree(extension_degree);
+            let generators = RangeParameters::init(bit_length, 1, pc_gens).unwrap();
+
+            let mut rng = rand::thread_rng();
 
             group.bench_function(&label, move |b| {
+                // Batch data
+                let mut statements = vec![];
+                let mut proofs = vec![];
+
+                for _ in 0..number_of_range_proofs {
+                    // Witness data
+                    let mut openings = vec![];
+                    let value = rng.gen_range(value_min, value_max);
+                    let blindings = vec![Scalar::random_not_zero(&mut rng); extension_degree as usize];
+                    openings.push(CommitmentOpening::new(value, blindings.clone()));
+                    let witness = RangeWitness::init(openings).unwrap();
+
+                    // Statement data
+                    let seed_nonce = Some(Scalar::random_not_zero(&mut rng));
+                    let statement = RangeStatement::init(
+                        generators.clone(),
+                        vec![generators
+                            .pc_gens()
+                            .commit(&Scalar::from(value), blindings.as_slice())
+                            .unwrap()],
+                        vec![Some(value / 3)],
+                        seed_nonce,
+                    )
+                    .unwrap();
+                    statements.push(statement.clone());
+
+                    // Proof
+                    let proof = RistrettoRangeProof::prove(transcript_label, &statement, &witness).unwrap();
+                    proofs.push(proof);
+                }
+
                 // Benchmark this code
                 b.iter(|| {
-                    // 5. Verify the entire batch of single proofs
+                    // Verify the entire batch of proofs
                     match extract_masks {
                         VerifyAction::VerifyOnly => {
                             let _masks = RangeProof::verify_batch(
                                 transcript_label,
-                                statements,
-                                proofs,
+                                &statements,
+                                &proofs,
                                 VerifyAction::VerifyOnly,
                             )
                             .unwrap();
@@ -263,8 +256,8 @@ fn verify_batched_rangeproofs_helper(bit_length: usize, extension_degree: Extens
                         VerifyAction::RecoverOnly => {
                             let _masks = RangeProof::verify_batch(
                                 transcript_label,
-                                statements,
-                                proofs,
+                                &statements,
+                                &proofs,
                                 VerifyAction::RecoverOnly,
                             )
                             .unwrap();
diff --git a/src/generators/bulletproof_gens.rs b/src/generators/bulletproof_gens.rs
index 178ea87..de47028 100644
--- a/src/generators/bulletproof_gens.rs
+++ b/src/generators/bulletproof_gens.rs
@@ -4,7 +4,19 @@
 // Copyright (c) 2018 Chain, Inc.
 // SPDX-License-Identifier: MIT
 
-use crate::{generators::aggregated_gens_iter::AggregatedGensIter, traits::FromUniformBytes};
+use std::{
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
+
+use byteorder::{ByteOrder, LittleEndian};
+use curve25519_dalek::traits::VartimePrecomputedMultiscalarMul;
+use itertools::Itertools;
+
+use crate::{
+    generators::{aggregated_gens_iter::AggregatedGensIter, generators_chain::GeneratorsChain},
+    traits::{Compressable, FromUniformBytes, Precomputable},
+};
 
 /// The `BulletproofGens` struct contains all the generators needed for aggregating up to `m` range proofs of up to `n`
 /// bits each.
@@ -25,8 +37,8 @@ use crate::{generators::aggregated_gens_iter::AggregatedGensIter, traits::FromUn
 /// This construction is also forward-compatible with constraint system proofs, which use a much larger slice of the
 /// generator chain, and even forward-compatible to multiparty aggregation of constraint system proofs, since the
 /// generators are namespaced by their party index.
-#[derive(Clone, Debug)]
-pub struct BulletproofGens<P> {
+#[derive(Clone)]
+pub struct BulletproofGens<P: Precomputable> {
     /// The maximum number of usable generators for each party.
     pub gens_capacity: usize,
     /// Number of values or parties
@@ -35,9 +47,11 @@ pub struct BulletproofGens<P> {
     pub(crate) g_vec: Vec<Vec<P>>,
     /// Precomputed \\(\mathbf H\\) generators for each party.
     pub(crate) h_vec: Vec<Vec<P>>,
+    /// Interleaved precomputed generators
+    pub(crate) precomp: Arc<P::Precomputation>,
 }
-impl<P: FromUniformBytes> BulletproofGens<P> {
+impl<P: FromUniformBytes + Precomputable> BulletproofGens<P> {
     /// Create a new `BulletproofGens` object.
     ///
     /// # Inputs
     ///
@@ -48,46 +62,35 @@ impl<P: FromUniformBytes> BulletproofGens<P> {
     ///
     /// * `party_capacity` is the maximum number of parties that can produce an aggregated proof.
     pub fn new(gens_capacity: usize, party_capacity: usize) -> Self {
-        let mut gens = BulletproofGens {
-            gens_capacity: 0,
-            party_capacity,
-            g_vec: (0..party_capacity).map(|_| Vec::new()).collect(),
-            h_vec: (0..party_capacity).map(|_| Vec::new()).collect(),
-        };
-        gens.increase_capacity(gens_capacity);
-        gens
-    }
-
-    /// Increases the generators' capacity to the amount specified. If less than or equal to the current capacity,
-    /// does nothing.
-    pub fn increase_capacity(&mut self, new_capacity: usize) {
-        use byteorder::{ByteOrder, LittleEndian};
-
-        use crate::generators::generators_chain::GeneratorsChain;
-
-        if self.gens_capacity >= new_capacity {
-            return;
-        }
+        let mut g_vec: Vec<Vec<P>> = (0..party_capacity).map(|_| Vec::new()).collect();
+        let mut h_vec: Vec<Vec<P>> = (0..party_capacity).map(|_| Vec::new()).collect();
 
-        for i in 0..self.party_capacity {
+        // Generate the points
+        for i in 0..party_capacity {
             #[allow(clippy::cast_possible_truncation)]
             let party_index = i as u32;
+
             let mut label = [b'G', 0, 0, 0, 0];
             LittleEndian::write_u32(&mut label[1..5], party_index);
-            self.g_vec[i].extend(
-                &mut GeneratorsChain::new(&label)
-                    .fast_forward(self.gens_capacity)
-                    .take(new_capacity - self.gens_capacity),
-            );
+            g_vec[i].extend(&mut GeneratorsChain::<P>::new(&label).take(gens_capacity));
 
             label[0] = b'H';
-            self.h_vec[i].extend(
-                &mut GeneratorsChain::new(&label)
-                    .fast_forward(self.gens_capacity)
-                    .take(new_capacity - self.gens_capacity),
-            );
+            h_vec[i].extend(&mut GeneratorsChain::<P>::new(&label).take(gens_capacity));
+        }
+
+        // Generate a flattened interleaved iterator for the precomputation tables
+        let iter_g_vec = g_vec.iter().flat_map(move |g_j| g_j.iter());
+        let iter_h_vec = h_vec.iter().flat_map(move |h_j| h_j.iter());
+        let iter_interleaved = iter_g_vec.interleave(iter_h_vec);
+        let precomp = Arc::new(P::Precomputation::new(iter_interleaved));
+
+        BulletproofGens {
+            gens_capacity,
+            party_capacity,
+            g_vec,
+            h_vec,
+            precomp,
         }
-        self.gens_capacity = new_capacity;
     }
 
     /// Return an iterator over the aggregation of the parties' G generators with given size `n`.
@@ -112,3 +115,18 @@ impl<P: FromUniformBytes> BulletproofGens<P> {
         }
     }
 }
+
+impl<P> Debug for BulletproofGens<P>
+where
+    P: Compressable + Debug + Precomputable,
+    P::Compressed: Debug,
+{
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RangeParameters")
+            .field("gens_capacity", &self.gens_capacity)
+            .field("party_capacity", &self.party_capacity)
+            .field("g_vec", &self.g_vec)
+            .field("h_vec", &self.h_vec)
+            .finish()
+    }
+}
diff --git a/src/generators/generators_chain.rs b/src/generators/generators_chain.rs
index f3e7979..a85cc33 100644
--- a/src/generators/generators_chain.rs
+++ b/src/generators/generators_chain.rs
@@ -30,15 +30,6 @@ impl<P: FromUniformBytes> GeneratorsChain<P> {
             _phantom: PhantomData,
         }
     }
-
-    /// Advances the reader n times, squeezing and discarding the result
-    pub(crate) fn fast_forward(mut self, n: usize) -> Self {
-        let mut buf = [0u8; 64];
-        for _ in 0..n {
-            self.reader.read(&mut buf);
-        }
-        self
-    }
 }
 
 impl<P: FromUniformBytes> Default for GeneratorsChain<P> {
diff --git a/src/generators/mod.rs b/src/generators/mod.rs
index a1e51b6..e84f0d1 100644
--- a/src/generators/mod.rs
+++ b/src/generators/mod.rs
@@ -61,27 +61,4 @@ mod tests {
         helper(16, 2);
         helper(16, 1);
     }
-
-    #[test]
-    fn resizing_small_gens_matches_creating_bigger_gens() {
-        let gens = BulletproofGens::new(64, 8);
-
-        let mut gen_resized = BulletproofGens::new(32, 8);
-        gen_resized.increase_capacity(64);
-
-        let helper = |n: usize, m: usize| {
-            let gens_g: Vec<RistrettoPoint> = gens.g_iter(n, m).copied().collect();
-            let gens_h: Vec<RistrettoPoint> = gens.h_iter(n, m).copied().collect();
-
-            let resized_g: Vec<RistrettoPoint> = gen_resized.g_iter(n, m).copied().collect();
-            let resized_h: Vec<RistrettoPoint> = gen_resized.h_iter(n, m).copied().collect();
-
-            assert_eq!(gens_g, resized_g);
-            assert_eq!(gens_h, resized_h);
-        };
-
-        helper(64, 8);
-        helper(32, 8);
-        helper(16, 8);
-    }
 }
diff --git a/src/range_parameters.rs b/src/range_parameters.rs
index 6bd8ec6..55020a5 100644
--- a/src/range_parameters.rs
+++ b/src/range_parameters.rs
@@ -3,7 +3,10 @@
 //! Bulletproofs+ range parameters (generators and base points) needed for a batch of range proofs
 
-use std::fmt::{Debug, Formatter};
+use std::{
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
 
 use crate::{
     errors::ProofError,
@@ -12,12 +15,12 @@ use crate::{
         pedersen_gens::{ExtensionDegree, PedersenGens},
     },
     range_proof::MAX_RANGE_PROOF_BIT_LENGTH,
-    traits::{Compressable, FromUniformBytes},
+    traits::{Compressable, FromUniformBytes, Precomputable},
 };
 
 /// Contains all the generators and base points needed for a batch of range proofs
 #[derive(Clone)]
-pub struct RangeParameters<P: Compressable> {
+pub struct RangeParameters<P: Compressable + Precomputable> {
     /// Generators needed for aggregating up to `m` range proofs of up to `n` bits each.
     bp_gens: BulletproofGens<P>,
     /// The pair of base points for Pedersen commitments.
@@ -25,7 +28,7 @@ pub struct RangeParameters<P: Compressable> {
 }
 
 impl<P> RangeParameters<P>
-where P: FromUniformBytes + Compressable + Clone
+where P: FromUniformBytes + Compressable + Clone + Precomputable
 {
     /// Initialize a new `RangeParameters` with sanity checks
     pub fn init(bit_length: usize, aggregation_factor: usize, pc_gens: PedersenGens<P>) -> Result<RangeParameters<P>, ProofError> {
@@ -107,11 +110,6 @@ where P: FromUniformBytes + Compressable + Clone
         self.hi_base_iter().collect()
     }
 
-    /// Return the non-public value bulletproof generator references
-    pub fn hi_base_copied(&self) -> Vec<P> {
-        self.hi_base_iter().cloned().collect()
-    }
-
     /// Return the non-public mask iterator to the bulletproof generators
    pub fn gi_base_iter(&self) -> impl Iterator<Item = &P> {
         self.bp_gens.g_iter(self.bit_length(), self.aggregation_factor())
@@ -122,15 +120,17 @@ where P: FromUniformBytes + Compressable + Clone
         self.gi_base_iter().collect()
     }
 
-    /// Return the non-public mask bulletproof generators
-    pub fn gi_base_copied(&self) -> Vec<P> {
-        self.gi_base_iter().cloned().collect()
+    /// Return the interleaved precomputation tables
+    pub fn precomp(&self) -> Arc<P::Precomputation> {
+        // We use shared ownership since precomputation evaluation is an instance method and we don't want to actually
+        // clone
+        Arc::clone(&self.bp_gens.precomp)
     }
 }
 
 impl<P> Debug for RangeParameters<P>
 where
-    P: Compressable + Debug,
+    P: Compressable + Debug + Precomputable,
     P::Compressed: Debug,
 {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
diff --git a/src/range_proof.rs b/src/range_proof.rs
index f20b228..ba2a826 100644
--- a/src/range_proof.rs
+++ b/src/range_proof.rs
@@ -13,8 +13,9 @@ use std::{
 
 use curve25519_dalek::{
     scalar::Scalar,
-    traits::{Identity, IsIdentity},
+    traits::{Identity, IsIdentity, VartimePrecomputedMultiscalarMul},
 };
+use itertools::Itertools;
 use merlin::Transcript;
 use rand::thread_rng;
 use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
@@ -31,7 +32,7 @@ use crate::{
     },
     range_statement::RangeStatement,
     range_witness::RangeWitness,
-    traits::{Compressable, Decompressable, FixedBytesRepr},
+    traits::{Compressable, Decompressable, FixedBytesRepr, Precomputable},
     transcripts,
     utils::generic::{bit_vector_of_scalars, nonce, read_1_byte, read_32_bytes},
 };
@@ -186,8 +187,8 @@ impl<P> RangeProof<P>
 where
     for<'p> &'p P: Mul<Scalar, Output = P>,
     for<'p> &'p P: Add<Output = P>,
-    P: CurvePointProtocol,
-    P::Compressed: FixedBytesRepr + IsIdentity + Identity + Copy,
+    P: CurvePointProtocol + Precomputable,
+    P::Compressed: FixedBytesRepr + IsIdentity + Identity,
 {
     /// Helper function to return the proof's extension degree
     pub fn extension_degree(&self) -> ExtensionDegree {
@@ -314,8 +315,8 @@ where
         // Calculate the inner product
         transcript.domain_separator(b"Bulletproofs+", b"Inner Product Proof");
         let mut ip_data = InnerProductRound::init(
-            statement.generators.gi_base_copied(),
-            statement.generators.hi_base_copied(),
+            statement.generators.gi_base_iter().cloned().collect(),
+            statement.generators.hi_base_iter().cloned().collect(),
             statement.generators.g_bases().to_vec(),
             statement.generators.h_base().clone(),
             a_li,
@@ -399,7 +400,7 @@ where
             statements[max_index].generators.hi_base_ref(),
         );
         for (i, statement) in statements.iter().enumerate() {
-            for value in statement.minimum_value_promises.iter().flatten() {
+            for value in Iterator::flatten(statement.minimum_value_promises.iter()) {
                 if value >> (bit_length - 1) > 1 {
                     return Err(ProofError::InvalidLength(
                         "Minimum value promise exceeds bit vector capacity".to_string(),
@@ -480,13 +481,10 @@ where
        let (max_mn, max_index) = RangeProof::verify_statements_and_generators_consistency(statements, range_proofs)?;
        let (g_base_vec, h_base) = (statements[0].generators.g_bases(), statements[0].generators.h_base());
        let bit_length = statements[0].generators.bit_length();
-       let (gi_base_ref, hi_base_ref) = (
-           statements[max_index].generators.gi_base_ref(),
-           statements[max_index].generators.hi_base_ref(),
-       );
        let extension_degree = statements[0].generators.extension_degree() as usize;
        let g_bases_compressed = statements[0].generators.g_bases_compressed();
        let h_base_compressed = statements[0].generators.h_base_compressed();
+       let precomp = statements[max_index].generators.precomp();
 
        // Compute log2(N)
        let mut log_n = 0u32;
@@ -510,13 +508,15 @@ where
        let mut hi_base_scalars = vec![Scalar::zero(); max_mn];
 
        // Final multiscalar multiplication data
-       let mut msm_len = 0;
+       // Because we use precomputation on the generator vectors, we need to separate the static data from the dynamic
+       // data. However, we can't combine precomputation data, so the Pedersen generators go with the dynamic
+       // data :(
+       let mut msm_dynamic_len = extension_degree + 1;
        for (index, item) in statements.iter().enumerate() {
-           msm_len += item.generators.aggregation_factor() + 3 + range_proofs[index].li.len() * 2;
+           msm_dynamic_len += item.generators.aggregation_factor() + 3 + range_proofs[index].li.len() * 2;
        }
-       msm_len += 2 + max_mn * 2 + (extension_degree - 1);
-       let mut scalars: Vec<Scalar> = Vec::with_capacity(msm_len);
-       let mut points: Vec<P> = Vec::with_capacity(msm_len);
+       let mut dynamic_scalars: Vec<Scalar> = Vec::with_capacity(msm_dynamic_len);
+       let mut dynamic_points: Vec<P> = Vec::with_capacity(msm_dynamic_len);
 
        // Recovered masks
        let mut masks = match extract_masks {
@@ -530,9 +530,9 @@ where
 
        // Process each proof and add it to the batch
        let rng = &mut thread_rng();
-       for (index, proof) in range_proofs.iter().enumerate() {
-           let commitments = statements[index].commitments.clone();
-           let minimum_value_promises = statements[index].minimum_value_promises.clone();
+       for (proof, statement) in range_proofs.iter().zip(statements) {
+           let commitments = statement.commitments.clone();
+           let minimum_value_promises = statement.minimum_value_promises.clone();
            let a = proof.a_decompressed()?;
            let a1 = proof.a1_decompressed()?;
            let b = proof.b_decompressed()?;
@@ -569,7 +569,7 @@ where
                bit_length,
                extension_degree,
                aggregation_factor,
-               &statements[index],
+               statement,
            )?;
 
            // Reconstruct challenges
@@ -632,7 +632,7 @@ where
            match extract_masks {
                VerifyAction::VerifyOnly => masks.push(None),
                _ => {
-                   if let Some(seed_nonce) = statements[index].seed_nonce {
+                   if let Some(seed_nonce) = statement.seed_nonce {
                        let mut temp_masks = Vec::with_capacity(extension_degree);
                        for (k, d1_val) in d1.iter().enumerate().take(extension_degree) {
                            let mut this_mask = (*d1_val -
@@ -681,54 +681,52 @@ where
 
            // Remaining terms
            let mut z_even_powers = Scalar::one();
-           for k in 0..aggregation_factor {
+           for minimum_value_promise in minimum_value_promises {
                z_even_powers *= z_square;
                let weighted = weight * (-e_square * z_even_powers * y_nm_1);
-               scalars.push(weighted);
-               points.push(commitments[k].clone());
-               if let Some(minimum_value) = minimum_value_promises[k] {
+               dynamic_scalars.push(weighted);
+               if let Some(minimum_value) = minimum_value_promise {
                    h_base_scalar -= weighted * Scalar::from(minimum_value);
                }
            }
+           dynamic_points.extend(commitments);
 
            h_base_scalar += weight * (r1 * y * s1 + e_square * (y_nm_1 * z * d_sum + (z_square - z) * y_sum));
 
            for k in 0..extension_degree {
                g_base_scalars[k] += weight * d1[k];
            }
 
-           scalars.push(weight * (-e));
-           points.push(a1);
-           scalars.push(-weight);
-           points.push(b);
-           scalars.push(weight * (-e_square));
-           points.push(a);
-
-           for j in 0..rounds {
-               scalars.push(weight * (-e_square * challenges_sq[j]));
-               points.push(li[j].clone());
-               scalars.push(weight * (-e_square * challenges_sq_inv[j]));
-               points.push(ri[j].clone());
-           }
+           dynamic_scalars.push(weight * (-e));
+           dynamic_points.push(a1);
+           dynamic_scalars.push(-weight);
+           dynamic_points.push(b);
+           dynamic_scalars.push(weight * (-e_square));
+           dynamic_points.push(a);
+
+           dynamic_scalars.extend(challenges_sq.into_iter().map(|c| weight * -e_square * c));
+           dynamic_points.extend(li.into_iter());
+           dynamic_scalars.extend(challenges_sq_inv.into_iter().map(|c| weight * -e_square * c));
+           dynamic_points.extend(ri.into_iter());
        }
 
        if extract_masks == VerifyAction::RecoverOnly {
            return Ok(masks);
        }
 
-       // Common generators
+       // Pedersen generators
        for k in 0..extension_degree {
-           scalars.push(g_base_scalars[k]);
-           points.push(g_base_vec[k].clone());
-       }
-       scalars.push(h_base_scalar);
-       points.push(h_base.clone());
-       for i in 0..max_mn {
-           scalars.push(gi_base_scalars[i]);
-           points.push(gi_base_ref[i].clone());
-           scalars.push(hi_base_scalars[i]);
-           points.push(hi_base_ref[i].clone());
+           dynamic_scalars.push(g_base_scalars[k]);
+           dynamic_points.push(g_base_vec[k].clone());
        }
-
-       if P::vartime_multiscalar_mul(scalars, points) != P::identity() {
+       dynamic_scalars.push(h_base_scalar);
+       dynamic_points.push(h_base.clone());
+
+       // Perform the final check using precomputation
+       if precomp.vartime_mixed_multiscalar_mul(
+           gi_base_scalars.iter().interleave(hi_base_scalars.iter()),
+           dynamic_scalars.iter(),
+           dynamic_points.iter(),
+       ) != P::identity()
+       {
            return Err(ProofError::VerificationFailed(
                "Range proof batch not valid".to_string(),
            ));
diff --git a/src/range_statement.rs b/src/range_statement.rs
index 78c47aa..22d71f4 100644
--- a/src/range_statement.rs
+++ b/src/range_statement.rs
@@ -10,13 +10,13 @@ use zeroize::Zeroize;
 use crate::{
     errors::ProofError,
     range_parameters::RangeParameters,
-    traits::{Compressable, FromUniformBytes},
+    traits::{Compressable, FromUniformBytes, Precomputable},
 };
 
 /// The range proof statement contains the generators, vector of commitments, vector of optional minimum promised
 /// values and a vector of optional seed nonces for mask recovery
 #[derive(Clone)]
-pub struct RangeStatement<P: Compressable> {
+pub struct RangeStatement<P: Compressable + Precomputable> {
     /// The generators and base points needed for aggregating range proofs
     pub generators: RangeParameters<P>,
     /// The aggregated commitments
@@ -29,7 +29,7 @@ pub struct RangeStatement<P: Compressable> {
     pub seed_nonce: Option<Scalar>,
 }
 
-impl<P: FromUniformBytes + Compressable + Clone> RangeStatement<P> {
+impl<P: FromUniformBytes + Compressable + Clone + Precomputable> RangeStatement<P> {
     /// Initialize a new `RangeStatement` with sanity checks
     pub fn init(
         generators: RangeParameters<P>,
@@ -72,7 +72,7 @@ impl<P: FromUniformBytes + Compressable + Clone> RangeStatement<P> {
 }
 
 /// Overwrite secrets with null bytes when they go out of scope.
-impl<P: Compressable> Drop for RangeStatement<P> {
+impl<P: Compressable + Precomputable> Drop for RangeStatement<P> {
     fn drop(&mut self) {
         self.seed_nonce.zeroize();
     }
diff --git a/src/ristretto.rs b/src/ristretto.rs
index 0075fca..de132e8 100644
--- a/src/ristretto.rs
+++ b/src/ristretto.rs
@@ -7,14 +7,14 @@
 use curve25519_dalek::{
     constants::{RISTRETTO_BASEPOINT_COMPRESSED, RISTRETTO_BASEPOINT_POINT},
-    ristretto::{CompressedRistretto, RistrettoPoint},
+    ristretto::{CompressedRistretto, RistrettoPoint, VartimeRistrettoPrecomputation},
 };
 
 use crate::{
     generators::pedersen_gens::ExtensionDegree,
     protocols::curve_point_protocol::CurvePointProtocol,
     range_proof::RangeProof,
-    traits::{Compressable, Decompressable, FixedBytesRepr, FromUniformBytes},
+    traits::{Compressable, Decompressable, FixedBytesRepr, FromUniformBytes, Precomputable},
     PedersenGens,
 };
 
@@ -55,6 +55,10 @@ impl Compressable for RistrettoPoint {
     }
 }
 
+impl Precomputable for RistrettoPoint {
+    type Precomputation = VartimeRistrettoPrecomputation;
+}
+
 /// Create extended Pedersen generators for the required extension degree using pre-calculated compressed constants
 pub fn create_pedersen_gens_with_extension_degree(extension_degree: ExtensionDegree) -> PedersenGens<RistrettoPoint> {
     let (g_base_vec, g_base_compressed_vec) = get_g_base(extension_degree);
diff --git a/src/traits.rs b/src/traits.rs
index 89c49d4..142c3a4 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -1,6 +1,8 @@
 // Copyright 2022 The Tari Project
 // SPDX-License-Identifier: BSD-3-Clause
 
+use curve25519_dalek::traits::VartimePrecomputedMultiscalarMul;
+
 /// Abstrations for any type that can be represented as 32 bytes
 pub trait FixedBytesRepr {
     /// Returns a reference to the 32-byte representation
@@ -33,3 +35,9 @@ pub trait Decompressable {
     /// Try decompress this instance. None is returned if this fails.
     fn decompress(&self) -> Option<Self::Decompressed>;
 }
+
+/// Abstraction for any type supporting multiscalar multiplication precomputation
+pub trait Precomputable {
+    /// The type representing the precomputation instantiation
+    type Precomputation: Clone + VartimePrecomputedMultiscalarMul<Point = Self>;
+}
diff --git a/src/transcripts.rs b/src/transcripts.rs
index cef775e..72a1f91 100644
--- a/src/transcripts.rs
+++ b/src/transcripts.rs
@@ -8,7 +8,7 @@ use crate::{
     errors::ProofError,
     protocols::transcript_protocol::TranscriptProtocol,
     range_statement::RangeStatement,
-    traits::{Compressable, FixedBytesRepr},
+    traits::{Compressable, FixedBytesRepr, Precomputable},
 };
 
 // Helper function to construct the initial transcript
@@ -22,7 +22,7 @@ pub(crate) fn transcript_initialize<P>(
     statement: &RangeStatement<P>,
 ) -> Result<(), ProofError>
 where
-    P: Compressable,
+    P: Compressable + Precomputable,
     P::Compressed: FixedBytesRepr + IsIdentity,
 {
     transcript.validate_and_append_point(b"H", h_base_compressed)?;