Skip to content

Commit

Permalink
feat: simplify inner-product generator folding (#52)
Browse files Browse the repository at this point in the history
Simplifies the handling of prover inner-product generator folding for
efficiency. Currently, this folding process is done using
element-by-element scalar-group multiplication followed by group
addition. This PR combines these operations into element-by-element
multiscalar multiplications. The result is an impressive 35% speedup for
single 64-bit range proving.

Also removes unnecessary helper functions.
  • Loading branch information
SWvheerden committed Aug 4, 2023
2 parents b4756b0 + 9d3e9d8 commit 271892c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 64 deletions.
21 changes: 13 additions & 8 deletions src/inner_product_round.rs
Expand Up @@ -240,14 +240,19 @@ where
)?;
let e_inverse = e.invert();

self.gi_base = P::add_point_vectors(
P::mul_point_vec_with_scalar(gi_base_lo, &e_inverse)?.as_slice(),
P::mul_point_vec_with_scalar(gi_base_hi, &(e * y_n_inverse))?.as_slice(),
)?;
self.hi_base = P::add_point_vectors(
P::mul_point_vec_with_scalar(hi_base_lo, &e)?.as_slice(),
P::mul_point_vec_with_scalar(hi_base_hi, &e_inverse)?.as_slice(),
)?;
// Fold the generator vectors
let e_y_n_inverse = e * y_n_inverse;
self.gi_base = gi_base_lo
.iter()
.zip(gi_base_hi.iter())
.map(|(lo, hi)| P::vartime_multiscalar_mul([&e_inverse, &e_y_n_inverse], [lo, hi]))
.collect();

self.hi_base = hi_base_lo
.iter()
.zip(hi_base_hi.iter())
.map(|(lo, hi)| P::vartime_multiscalar_mul([&e, &e_inverse], [lo, hi]))
.collect();

self.ai = Scalar::add_scalar_vectors(
Scalar::mul_scalar_vec_with_scalar(a1, &e)?.as_slice(),
Expand Down
59 changes: 3 additions & 56 deletions src/protocols/curve_point_protocol.rs
Expand Up @@ -6,20 +6,14 @@
use std::{
borrow::Borrow,
cmp::min,
ops::{Add, AddAssign, Mul},
ops::{Add, AddAssign},
};

use curve25519_dalek::{
scalar::Scalar,
traits::{Identity, VartimeMultiscalarMul},
};
use curve25519_dalek::traits::{Identity, VartimeMultiscalarMul};
use digest::Digest;
use sha3::Sha3_512;

use crate::{
errors::ProofError,
traits::{Compressable, FromUniformBytes},
};
use crate::traits::{Compressable, FromUniformBytes};

/// The `CurvePointProtocol` trait. Any implementation of this trait can be used with BP+.
pub trait CurvePointProtocol:
Expand Down Expand Up @@ -52,51 +46,4 @@ pub trait CurvePointProtocol:
(buffer[0..size]).copy_from_slice(&output.as_slice()[0..size]);
Self::from_uniform_bytes(&buffer)
}

/// Helper function to multiply a point vector with a scalar vector
fn mul_point_vec_with_scalar(point_vec: &[Self], scalar: &Scalar) -> Result<Vec<Self>, ProofError>
where for<'p> &'p Self: Mul<Scalar, Output = Self> {
if point_vec.is_empty() {
return Err(ProofError::InvalidLength(
"Cannot multiply empty point vector with scalar".to_string(),
));
}
let mut out = vec![Self::identity(); point_vec.len()];
for i in 0..point_vec.len() {
out[i] = &point_vec[i] * *scalar;
}
Ok(out)
}

/// Helper function to add two point vectors
fn add_point_vectors(a: &[Self], b: &[Self]) -> Result<Vec<Self>, ProofError>
where for<'p> &'p Self: Add<Output = Self> {
if a.len() != b.len() || a.is_empty() {
return Err(ProofError::InvalidLength("Cannot add empty point vectors".to_string()));
}
let mut out = vec![Self::identity(); a.len()];
for i in 0..a.len() {
out[i] = &a[i] + &b[i];
}
Ok(out)
}
}

#[cfg(test)]
mod test {
use curve25519_dalek::RistrettoPoint;

use super::*;

#[test]
fn test_errors() {
// Empty point vector
assert!(RistrettoPoint::mul_point_vec_with_scalar(&[], &Scalar::ONE).is_err());

// Mismatched vector lengths
assert!(RistrettoPoint::add_point_vectors(&[RistrettoPoint::default()], &[]).is_err());

// Empty point vector
assert!(RistrettoPoint::add_point_vectors(&[], &[]).is_err());
}
}

0 comments on commit 271892c

Please sign in to comment.