From 6dcc038c522606cca08baf9abfe83da001b3379e Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Fri, 24 May 2024 23:17:51 -0400 Subject: [PATCH] Implement LogupOps for SIMD backend --- .../src/core/backend/cpu/lookups/gkr.rs | 39 +- .../src/core/backend/simd/lookups/gkr.rs | 443 +++++++++++++++++- .../src/core/backend/simd/lookups/mod.rs | 1 - crates/prover/src/core/lookups/utils.rs | 57 ++- 4 files changed, 476 insertions(+), 64 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 3e35c4017..6c227ec17 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, Index}; +use std::ops::Index; use num_traits::{One, Zero}; @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{ }; use crate::core::lookups::mle::{Mle, MleOps}; use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::{Fraction, UnivariatePoly}; +use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; impl GkrOps for CpuBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { @@ -185,23 +185,6 @@ fn process_logup_singles_sum( n_terms: usize, lambda: SecureField, ) { - /// Represents the fraction `1 / x` - struct Reciprocal { - x: SecureField, - } - - impl Add for Reciprocal { - type Output = Fraction; - - fn add(self, rhs: Self) -> Fraction { - // `1/a + 1/b = (a + b)/(a * b)` - Fraction { - numerator: self.x + rhs.x, - denominator: self.x * rhs.x, - } - } - } - for i in 0..n_terms { // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. let inp_denom_at_r0i0 = denominators[i * 2]; @@ -220,19 +203,11 @@ fn process_logup_singles_sum( let Fraction { numerator: numer_at_r0i, denominator: denom_at_r0i, - } = Reciprocal { - x: inp_denom_at_r0i0, - } + Reciprocal { - x: inp_denom_at_r0i1, - }; + } = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1); let Fraction { numerator: numer_at_r2i, denominator: denom_at_r2i, - } = Reciprocal { - x: inp_denom_at_r2i0, - } + Reciprocal { - x: inp_denom_at_r2i1, - }; + } = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1); let eq_eval_at_0i = eq_evals[i]; *eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); @@ -375,7 +350,7 @@ mod tests { let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); let sum = zip(&numerator_values, &denominator_values) .map(|(&n, &d)| Fraction::new(n, d)) - .sum::>(); + .sum::>(); let numerators = Mle::::new(numerator_values); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpGeneric { @@ -417,7 +392,7 @@ mod tests { let sum = denominator_values .iter() .map(|&d| Fraction::new(SecureField::one(), d)) - .sum::>(); + .sum::>(); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpSingles { denominators: denominators.clone(), @@ -451,7 +426,7 @@ mod tests { let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); let sum = zip(&numerator_values, &denominator_values) .map(|(&n, &d)| Fraction::new(n.into(), d)) - .sum::>(); + .sum::>(); let numerators = Mle::::new(numerator_values); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpMultiplicities { diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index ce2cc5a37..427f5226d 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -8,13 +8,14 @@ use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::lookups::gkr_prover::{ correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, }; use crate::core::lookups::mle::Mle; use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::UnivariatePoly; +use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; impl GkrOps for SimdBackend { #[allow(clippy::uninit_vec)] @@ -61,14 +62,14 @@ impl GkrOps for SimdBackend { match layer { Layer::GrandProduct(col) => next_grand_product_layer(col), Layer::LogUpGeneric { - numerators: _, - denominators: _, - } => todo!(), + numerators, + denominators, + } => next_logup_generic_layer(numerators, denominators), Layer::LogUpMultiplicities { - numerators: _, - denominators: _, - } => todo!(), - Layer::LogUpSingles { denominators: _ } => todo!(), + numerators, + denominators, + } => next_logup_multiplicities_layer(numerators, denominators), + Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators), } } @@ -88,6 +89,7 @@ impl GkrOps for SimdBackend { } let n_packed_terms = n_terms / N_LANES; + let packed_lambda = PackedSecureField::broadcast(h.lambda); let mut packed_eval_at_0 = PackedSecureField::zero(); let mut packed_eval_at_2 = PackedSecureField::zero(); @@ -100,14 +102,37 @@ impl GkrOps for SimdBackend { n_packed_terms, ), Layer::LogUpGeneric { - numerators: _, - denominators: _, - } => todo!(), + numerators, + denominators, + } => process_logup_generic_sum( + &mut packed_eval_at_0, + &mut packed_eval_at_2, + eq_evals, + numerators, + denominators, + n_packed_terms, + packed_lambda, + ), Layer::LogUpMultiplicities { - numerators: _, - denominators: _, - } => todo!(), - Layer::LogUpSingles { denominators: _ } => todo!(), + numerators, + denominators, + } => process_logup_multiplicities_sum( + &mut packed_eval_at_0, + &mut packed_eval_at_2, + eq_evals, + numerators, + denominators, + n_packed_terms, + packed_lambda, + ), + Layer::LogUpSingles { denominators } => process_logup_singles_sum( + &mut packed_eval_at_0, + &mut packed_eval_at_2, + eq_evals, + denominators, + n_packed_terms, + packed_lambda, + ), } // Corrects the difference between two univariate sums in `t`: @@ -121,7 +146,7 @@ impl GkrOps for SimdBackend { } } -// Can assume `len(layer) > N_LANES * 2` +// Can assume `len(layer) > N_LANES`. fn next_grand_product_layer(layer: &Mle) -> Layer { assert!(layer.len() > N_LANES); let next_layer_len = layer.len() / 2; @@ -141,6 +166,128 @@ fn next_grand_product_layer(layer: &Mle) -> Layer, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureFieldVec { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureFieldVec { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +// TODO: Code duplication of `next_logup_generic_layer`. Consider unifying these. +fn next_logup_multiplicities_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureFieldVec { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureFieldVec { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +fn next_logup_singles_layer(denominators: &Mle) -> Layer { + assert!(denominators.len() > N_LANES); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Reciprocal::new(d_even) + Reciprocal::new(d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureFieldVec { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureFieldVec { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + fn process_grand_product_sum( packed_eval_at_0: &mut PackedSecureField, packed_eval_at_2: &mut PackedSecureField, @@ -170,6 +317,151 @@ fn process_grand_product_sum( } } +fn process_logup_generic_sum( + packed_eval_at_0: &mut PackedSecureField, + packed_eval_at_2: &mut PackedSecureField, + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) { + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + *packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + *packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } +} + +// Can assume `n_terms > N_LANES`. +// TODO: Code duplication of `process_logup_generic_sum`. Consider unifying these. +fn process_logup_multiplicities_sum( + packed_eval_at_0: &mut PackedSecureField, + packed_eval_at_2: &mut PackedSecureField, + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) { + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + *packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + *packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } +} + +fn process_logup_singles_sum( + packed_eval_at_0: &mut PackedSecureField, + packed_eval_at_2: &mut PackedSecureField, + eq_evals: &EqEvals, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) { + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + *packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + *packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } +} + fn into_simd_layer(cpu_layer: Layer) -> Layer { match cpu_layer { Layer::GrandProduct(mle) => { @@ -197,6 +489,12 @@ fn into_simd_layer(cpu_layer: Layer) -> Layer { #[cfg(test)] mod tests { + use std::iter::zip; + + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; use crate::core::channel::Channel; @@ -205,6 +503,7 @@ mod tests { use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; use crate::core::lookups::mle::Mle; + use crate::core::lookups::utils::Fraction; use crate::core::test_utils::test_channel; #[test] @@ -251,4 +550,116 @@ mod tests { ); Ok(()) } + + #[test] + fn logup_with_generic_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_singles_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominators + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [SecureField::one(), denominators.eval_at_point(&ood_point)] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } } diff --git a/crates/prover/src/core/backend/simd/lookups/mod.rs b/crates/prover/src/core/backend/simd/lookups/mod.rs index e2ee801c4..34395e985 100644 --- a/crates/prover/src/core/backend/simd/lookups/mod.rs +++ b/crates/prover/src/core/backend/simd/lookups/mod.rs @@ -1,3 +1,2 @@ mod gkr; -// mod grandproduct; mod mle; diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 70adb4c64..da6875e52 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -196,13 +196,13 @@ where /// Projective fraction. #[derive(Debug, Clone, Copy)] -pub struct Fraction { - pub numerator: F, - pub denominator: SecureField, +pub struct Fraction { + pub numerator: N, + pub denominator: D, } -impl Fraction { - pub fn new(numerator: F, denominator: SecureField) -> Self { +impl Fraction { + pub fn new(numerator: N, denominator: D) -> Self { Self { numerator, denominator, @@ -210,14 +210,12 @@ impl Fraction { } } -impl Add for Fraction -where - F: Field, - SecureField: ExtensionOf + Field, +impl + Add + Mul + Mul + Copy> Add + for Fraction { - type Output = Fraction; + type Output = Fraction; - fn add(self, rhs: Self) -> Fraction { + fn add(self, rhs: Self) -> Fraction { Fraction { numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, denominator: self.denominator * rhs.denominator, @@ -225,11 +223,14 @@ where } } -impl Zero for Fraction { +impl Zero for Fraction +where + Self: Add, +{ fn zero() -> Self { Self { - numerator: SecureField::zero(), - denominator: SecureField::one(), + numerator: N::zero(), + denominator: D::one(), } } @@ -238,13 +239,39 @@ impl Zero for Fraction { } } -impl Sum for Fraction { +impl Sum for Fraction +where + Self: Zero, +{ fn sum>(mut iter: I) -> Self { let first = iter.next().unwrap_or_else(Self::zero); iter.fold(first, |a, b| a + b) } } +/// Represents the fraction `1 / x` +pub struct Reciprocal { + x: T, +} + +impl Reciprocal { + pub fn new(x: T) -> Self { + Self { x } + } +} + +impl + Mul + Copy> Add for Reciprocal { + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + // `1/a + 1/b = (a + b)/(a * b)` + Fraction { + numerator: self.x + rhs.x, + denominator: self.x * rhs.x, + } + } +} + #[cfg(test)] mod tests { use std::iter::zip;