Skip to content

Commit

Permalink
Merge pull request #52 from zkcrypto/ff-batch-invert
Browse files Browse the repository at this point in the history
Use batch inversion APIs from `ff` crate
  • Loading branch information
str4d committed Aug 11, 2021
2 parents 96ab416 + 3c7996d commit 53c4895
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 79 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ version = "0.5"
default-features = false

[dependencies.ff]
version = "0.10"
version = "0.10.1"
default-features = false

[dependencies.group]
Expand All @@ -46,7 +46,7 @@ default-features = false

[features]
default = ["alloc", "bits"]
alloc = ["group/alloc"]
alloc = ["ff/alloc", "group/alloc"]
bits = ["ff/bits"]

[[bench]]
Expand Down
136 changes: 59 additions & 77 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use core::borrow::Borrow;
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use ff::Field;
use ff::{BatchInverter, Field};
use group::{
cofactor::{CofactorCurve, CofactorCurveAffine, CofactorGroup},
prime::PrimeGroup,
Expand Down Expand Up @@ -534,6 +534,8 @@ impl AffinePoint {
/// according to ZIP 216.
#[cfg(feature = "alloc")]
pub fn batch_from_bytes(items: impl Iterator<Item = [u8; 32]>) -> Vec<CtOption<Self>> {
use ff::BatchInvert;

#[derive(Clone, Copy, Default)]
struct Item {
sign: u8,
Expand All @@ -553,54 +555,53 @@ impl AffinePoint {
}
}

let items = items.map(|mut b| {
// Grab the sign bit from the representation
let sign = b[31] >> 7;

// Mask away the sign bit
b[31] &= 0b0111_1111;

// Interpret what remains as the v-coordinate
Fq::from_bytes(&b).map(|v| {
// -u^2 + v^2 = 1 + d.u^2.v^2
// -u^2 = 1 + d.u^2.v^2 - v^2 (rearrange)
// -u^2 - d.u^2.v^2 = 1 - v^2 (rearrange)
// u^2 + d.u^2.v^2 = v^2 - 1 (flip signs)
// u^2 (1 + d.v^2) = v^2 - 1 (factor)
// u^2 = (v^2 - 1) / (1 + d.v^2) (isolate u^2)
// We know that (1 + d.v^2) is nonzero for all v:
// (1 + d.v^2) = 0
// d.v^2 = -1
// v^2 = -(1 / d) No solutions, as -(1 / d) is not a square

let v2 = v.square();

Item {
v,
sign,
numerator: (v2 - Fq::one()),
denominator: Fq::one() + EDWARDS_D * v2,
}
let items: Vec<_> = items
.map(|mut b| {
// Grab the sign bit from the representation
let sign = b[31] >> 7;

// Mask away the sign bit
b[31] &= 0b0111_1111;

// Interpret what remains as the v-coordinate
Fq::from_bytes(&b).map(|v| {
// -u^2 + v^2 = 1 + d.u^2.v^2
// -u^2 = 1 + d.u^2.v^2 - v^2 (rearrange)
// -u^2 - d.u^2.v^2 = 1 - v^2 (rearrange)
// u^2 + d.u^2.v^2 = v^2 - 1 (flip signs)
// u^2 (1 + d.v^2) = v^2 - 1 (factor)
// u^2 = (v^2 - 1) / (1 + d.v^2) (isolate u^2)
// We know that (1 + d.v^2) is nonzero for all v:
// (1 + d.v^2) = 0
// d.v^2 = -1
// v^2 = -(1 / d) No solutions, as -(1 / d) is not a square

let v2 = v.square();

Item {
v,
sign,
numerator: (v2 - Fq::one()),
denominator: Fq::one() + EDWARDS_D * v2,
}
})
})
});
.collect();

let mut acc = Fq::one();
let mut tmp = Vec::with_capacity(items.size_hint().0);
for item in items {
tmp.push((acc, item));
acc *= item.map(|item| item.denominator).unwrap_or(Fq::one());
}
acc = acc.invert().unwrap();
let mut denominators: Vec<_> = items
.iter()
.map(|item| item.map(|item| item.denominator).unwrap_or(Fq::zero()))
.collect();
denominators.iter_mut().batch_invert();

let mut ret: Vec<_> = tmp
items
.into_iter()
.rev()
.map(|(tmp, item)| {
let ret = item.and_then(
.zip(denominators.into_iter())
.map(|(item, inv_denominator)| {
item.and_then(
|Item {
v, sign, numerator, ..
}| {
let inv_denominator = tmp * acc;
(numerator * inv_denominator).sqrt().and_then(|u| {
// Fix the sign of `u` if necessary
let flip_sign = Choice::from((u.to_bytes()[0] ^ sign) & 1);
Expand All @@ -615,13 +616,9 @@ impl AffinePoint {
CtOption::new(AffinePoint { u: final_u, v }, !(u_is_zero & flip_sign))
})
},
);
acc *= item.map(|item| item.denominator).unwrap_or(Fq::one());
ret
)
})
.collect();
ret.reverse();
ret
.collect()
}

/// Returns the `u`-coordinate of this point.
Expand Down Expand Up @@ -838,23 +835,16 @@ impl ExtendedPoint {
fn batch_normalize(p: &[Self], q: &mut [AffinePoint]) {
assert_eq!(p.len(), q.len());

let mut acc = Fq::one();
for (p, q) in p.iter().zip(q.iter_mut()) {
// We use the `u` field of `AffinePoint` to store the product
// of previous z-coordinates seen.
q.u = acc;
acc *= &p.z;
// We use the `u` field of `AffinePoint` to store the z-coordinate being
// inverted, and the `v` field for scratch space.
q.u = p.z;
}

// This is the inverse, as all z-coordinates are nonzero.
acc = acc.invert().unwrap();
BatchInverter::invert_with_internal_scratch(q, |q| &mut q.u, |q| &mut q.v);

for (p, q) in p.iter().zip(q.iter_mut()).rev() {
// Compute tmp = 1/z
let tmp = q.u * acc;

// Cancel out z-coordinate in denominator of `acc`
acc *= &p.z;
let tmp = q.u;

// Set the coordinates to the correct value
q.u = p.u * &tmp; // Multiply by 1/z
Expand Down Expand Up @@ -1087,25 +1077,12 @@ impl Default for ExtendedPoint {
///
/// This costs 5 multiplications per element, and a field inversion.
pub fn batch_normalize<'a>(v: &'a mut [ExtendedPoint]) -> impl Iterator<Item = AffinePoint> + 'a {
let mut acc = Fq::one();
for p in v.iter_mut() {
// We use the `t1` field of `ExtendedPoint` to store the product
// of previous z-coordinates seen.
p.t1 = acc;
acc *= &p.z;
}
// We use the `t1` field of `ExtendedPoint` for scratch space.
BatchInverter::invert_with_internal_scratch(v, |p| &mut p.z, |p| &mut p.t1);

// This is the inverse, as all z-coordinates are nonzero.
acc = acc.invert().unwrap();

for p in v.iter_mut().rev() {
for p in v.iter_mut() {
let mut q = *p;

// Compute tmp = 1/z
let tmp = q.t1 * acc;

// Cancel out z-coordinate in denominator of `acc`
acc *= &q.z;
let tmp = q.z;

// Set the coordinates to the correct value
q.u *= &tmp; // Multiply by 1/z
Expand Down Expand Up @@ -1573,6 +1550,11 @@ fn test_batch_normalize() {
}

let expected: std::vec::Vec<_> = v.iter().map(|p| AffinePoint::from(*p)).collect();
let mut result0 = vec![AffinePoint::identity(); v.len()];
ExtendedPoint::batch_normalize(&v, &mut result0);
for i in 0..10 {
assert!(expected[i] == result0[i]);
}
let result1: std::vec::Vec<_> = batch_normalize(&mut v).collect();
for i in 0..10 {
assert!(expected[i] == result1[i]);
Expand Down

0 comments on commit 53c4895

Please sign in to comment.