diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index bb400ddae..2f230e5b2 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -85,41 +85,6 @@ where ZeroKnowledgeProof::new(proof) } - pub fn compute_final_proof( - uv: &(GenericArray, GenericArray), - p_0: F, - q_0: F, - lagrange_table: &LagrangeTable, λ>, - ) -> ZeroKnowledgeProof> - where - λ: Add + Add, - <λ as Add>::Output: Add, - <<λ as Add>::Output as Add>::Output: ArrayLength, - <λ as Add>::Output: ArrayLength, - { - let mut p = GenericArray::>::generate(|_| F::ZERO); - let mut q = GenericArray::>::generate(|_| F::ZERO); - let mut proof: GenericArray> = GenericArray::generate(|_| F::ZERO); - p[0] = p_0; - q[0] = q_0; - proof[0] = p_0 * q_0; - - for i in 0..λ::USIZE { - p[i + 1] = uv.0[i]; - q[i + 1] = uv.1[i]; - proof[i + 1] += uv.0[i] * uv.1[i]; - } - // We need a table of size `λ + 1` since we add a random point at x=0 - let p_extrapolated = lagrange_table.eval(&p); - let q_extrapolated = lagrange_table.eval(&q); - - for (i, (x, y)) in zip(p_extrapolated.into_iter(), q_extrapolated.into_iter()).enumerate() { - proof[λ::USIZE + 1 + i] += x * y; - } - - ZeroKnowledgeProof::new(proof) - } - pub fn gen_challenge_and_recurse( proof_left: &GenericArray>, proof_right: &GenericArray>, @@ -194,17 +159,33 @@ where #[cfg(all(test, unit_test))] mod test { - use generic_array::{sequence::GenericSequence, GenericArray}; - use typenum::{U2, U3, U4, U7}; + use std::iter::zip; + + use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; + use typenum::{U3, U4, U7}; use super::ProofGenerator; use crate::{ - ff::{Fp31, U128Conversions}, + ff::{Fp31, PrimeField, U128Conversions}, protocol::ipa_prf::malicious_security::lagrange::{ CanonicalLagrangeDenominator, LagrangeTable, }, }; + fn zip_chunks( + a: &[u128], + b: &[u128], + ) -> Vec<(GenericArray, GenericArray)> { + zip(a.chunks(U::USIZE), b.chunks(U::USIZE)) + .map(|(u_chunk, v_chunk)| { + ( + GenericArray::generate(|i| F::try_from(u_chunk[i]).unwrap()), + GenericArray::generate(|i| F::try_from(v_chunk[i]).unwrap()), + ) + }) + .collect::>() + } + #[test] fn sample_proof() { const U_1: [u128; 32] = [ @@ -222,51 +203,27 @@ mod test { const PROOF_2: [u128; 7] = [12, 6, 15, 8, 29, 30, 6]; const PROOF_LEFT_2: [u128; 7] = [5, 26, 14, 9, 0, 25, 2]; - const U_3: [u128; 2] = [3, 3]; - const V_3: [u128; 2] = [5, 24]; + const U_3: [u128; 4] = [3, 3, 0, 0]; // padded with zeroes + const V_3: [u128; 4] = [5, 24, 0, 0]; // padded with zeroes - const PROOF_3: [u128; 5] = [12, 15, 10, 14, 17]; + const PROOF_3: [u128; 7] = [12, 15, 10, 0, 18, 6, 5]; const P_RANDOM_WEIGHT: u128 = 12; const Q_RANDOM_WEIGHT: u128 = 1; let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table = LagrangeTable::::from(denominator); - // convert to field - let vec_u_1 = U_1 - .into_iter() - .map(|x| Fp31::try_from(x).unwrap()) - .collect::>(); - let vec_v_1 = V_1 - .into_iter() - .map(|x| Fp31::try_from(x).unwrap()) - .collect::>(); - let vec_u_2 = U_2 - .into_iter() - .map(|x| Fp31::try_from(x).unwrap()) - .collect::>(); - let vec_v_2 = V_2 - .into_iter() - .map(|x| Fp31::try_from(x).unwrap()) - .collect::>(); - - // uv values in input format - let uv_1 = (0usize..8) - .map(|i| { - ( - *GenericArray::::from_slice(&vec_u_1[4 * i..4 * i + 4]), - *GenericArray::::from_slice(&vec_v_1[4 * i..4 * i + 4]), - ) - }) - .collect::>(); - let uv_2 = (0usize..2) - .map(|i| { - ( - *GenericArray::::from_slice(&vec_u_2[4 * i..4 * i + 4]), - *GenericArray::::from_slice(&vec_v_2[4 * i..4 * i + 4]), - ) - }) - .collect::>(); + // uv values in input format (iterator of tuples of GenericArrays of length 4) + let uv_1 = zip_chunks(&U_1, &V_1); + let uv_2 = zip_chunks(&U_2, &V_2); + let uv_3 = { + let u_chunk = [P_RANDOM_WEIGHT, U_3[0], U_3[1], U_3[2]]; + let v_chunk = [Q_RANDOM_WEIGHT, V_3[0], V_3[1], V_3[2]]; + vec![( + GenericArray::generate(|i| Fp31::try_from(u_chunk[i]).unwrap()), + GenericArray::generate(|i| Fp31::try_from(v_chunk[i]).unwrap()), + )] + }; // first iteration let proof_1 = ProofGenerator::::compute_proof(uv_1.iter(), &lagrange_table); @@ -308,24 +265,10 @@ mod test { &proof_right_2, pg_2.uv.iter(), ); - - // final proof trim pg_3 from U4 to U2 - let uv = ( - *GenericArray::::from_slice(&pg_3.uv[0].0.as_slice()[0..2]), - *GenericArray::::from_slice(&pg_3.uv[0].1.as_slice()[0..2]), - ); - - assert_eq!(ProofGenerator { uv: vec![uv; 1] }, (&U_3[..], &V_3[..])); + assert_eq!(pg_3, (&U_3[..], &V_3[..])); // final iteration - let denominator = CanonicalLagrangeDenominator::::new(); - let lagrange_table = LagrangeTable::::from(denominator); - let proof_3 = ProofGenerator::::compute_final_proof( - &uv, - Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), - Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(), - &lagrange_table, - ); + let proof_3 = ProofGenerator::::compute_proof(uv_3.iter(), &lagrange_table); assert_eq!( proof_3.g.iter().map(Fp31::as_u128).collect::>(), PROOF_3,