From 162fdcc8ff01a4b17769b6626789cd7deed2c0a8 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 8 May 2024 10:57:56 -0400 Subject: [PATCH] Add GKR implementation of Grand Product lookups --- .../src/core/backend/cpu/lookups/gkr.rs | 87 +++++++++- crates/prover/src/core/lookups/gkr_prover.rs | 159 +++++++++++++++--- .../prover/src/core/lookups/gkr_verifier.rs | 94 ++++++++++- crates/prover/src/core/lookups/mle.rs | 2 +- crates/prover/src/core/lookups/sumcheck.rs | 9 +- crates/prover/src/core/test_utils.rs | 3 +- 6 files changed, 316 insertions(+), 38 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 8d0831efa..e5ef2e167 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -1,22 +1,69 @@ +use num_traits::Zero; + use crate::core::backend::CpuBackend; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::fields::Field; +use crate::core::lookups::gkr_prover::{ + correct_sum_as_poly_in_first_variable, GkrMultivariatePolyOracle, GkrOps, Layer, +}; use crate::core::lookups::mle::Mle; +use crate::core::lookups::sumcheck::MultivariatePolyOracle; +use crate::core::lookups::utils::UnivariatePoly; impl GkrOps for CpuBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { Mle::new(gen_eq_evals(y, v)) } - fn next_layer(_layer: &Layer) -> Layer { - todo!() + fn next_layer(layer: &Layer) -> Layer { + match layer { + Layer::_LogUp(_) => todo!(), + Layer::GrandProduct(layer) => { + let res = layer.array_chunks().map(|&[a, b]| a * b).collect(); + Layer::GrandProduct(Mle::new(res)) + } + } } fn sum_as_poly_in_first_variable( - _h: &GkrMultivariatePolyOracle, - _claim: SecureField, - ) -> crate::core::lookups::utils::UnivariatePoly { - todo!() + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, + ) -> UnivariatePoly { + let k = h.n_variables(); + let n_terms = 1 << (k - 1); + let eq_evals = h.eq_evals; + let y = eq_evals.y(); + let input_layer = &h.input_layer; + + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + match input_layer { + Layer::_LogUp(_) => todo!(), + Layer::GrandProduct(input_layer) => + { + #[allow(clippy::needless_range_loop)] + for i in 0..n_terms { + let lhs0 = input_layer[i * 2]; + let lhs1 = input_layer[i * 2 + 1]; + + let rhs0 = input_layer[(n_terms + i) * 2]; + let rhs1 = input_layer[(n_terms + i) * 2 + 1]; + + let product2 = (rhs0.double() - lhs0) * (rhs1.double() - lhs1); + let product0 = lhs0 * lhs1; + + let eq_eval = eq_evals[i]; + eval_at_0 += eq_eval * product0; + eval_at_2 += eq_eval * product2; + } + } + } + + eval_at_0 *= h.eq_fixed_var_correction; + eval_at_2 *= h.eq_fixed_var_correction; + + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k) } } @@ -45,10 +92,14 @@ mod tests { use num_traits::{One, Zero}; use crate::core::backend::CpuBackend; + use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; - use crate::core::lookups::gkr_prover::GkrOps; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; + use crate::core::test_utils::test_channel; #[test] fn gen_eq_evals() { @@ -68,4 +119,24 @@ mod tests { ] ); } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 5; + let values = test_channel().draw_felts(N); + let product = values.iter().product::(); + let col = Mle::::new(values); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point: r, + claims_to_verify_by_instance, + .. + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]); + Ok(()) + } } diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 573f3edb3..963e98c60 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; use crate::core::backend::{Col, Column, ColumnOps}; use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; use crate::core::lookups::sumcheck; -pub trait GkrOps: MleOps { +pub trait GkrOps: MleOps + MleOps { /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. /// /// [`eq(x, y)`]: crate::core::lookups::utils::eq @@ -30,7 +32,7 @@ pub trait GkrOps: MleOps { /// /// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`]. fn sum_as_poly_in_first_variable( - h: &GkrMultivariatePolyOracle, + h: &GkrMultivariatePolyOracle<'_, Self>, claim: SecureField, ) -> UnivariatePoly; } @@ -83,13 +85,16 @@ impl> Deref for EqEvals { /// [LogUp]: https://eprint.iacr.org/2023/1284.pdf pub enum Layer { _LogUp(B), - _GrandProduct(B), + GrandProduct(Mle), } impl Layer { /// Returns the number of variables used to interpolate the layer's gate values. fn n_variables(&self) -> usize { - todo!() + match self { + Self::_LogUp(_) => todo!(), + Self::GrandProduct(mle) => mle.n_variables(), + } } /// Produces the next layer from the current layer. @@ -110,7 +115,28 @@ impl Layer { /// Returns each column output if the layer is an output layer, otherwise returns an `Err`. fn try_into_output_layer_values(self) -> Result, NotOutputLayerError> { - todo!() + if !self.is_output_layer() { + return Err(NotOutputLayerError); + } + + Ok(match self { + Self::GrandProduct(col) => { + vec![col.at(0)] + } + Self::_LogUp(_) => todo!(), + }) + } + + /// Returns a transformed layer with the first variable of each column fixed to `assignment`. + fn fix_first_variable(self, x0: SecureField) -> Self { + if self.n_variables() == 0 { + return self; + } + + match self { + Self::_LogUp(_) => todo!(), + Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)), + } } /// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR @@ -143,36 +169,53 @@ impl Layer { fn into_multivariate_poly( self, _lambda: SecureField, - _eq_evals: &EqEvals, - ) -> GkrMultivariatePolyOracle { - todo!() + eq_evals: &EqEvals, + ) -> GkrMultivariatePolyOracle<'_, B> { + GkrMultivariatePolyOracle { + eq_evals, + input_layer: self, + eq_fixed_var_correction: SecureField::one(), + } } } #[derive(Debug)] struct NotOutputLayerError; -/// A multivariate polynomial expressed in the multilinear poly columns of a [`Layer`]. -pub enum GkrMultivariatePolyOracle { - LogUp, - GrandProduct, +/// A multivariate polynomial that expresses the relation between two consecutive GKR layers. +pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { + /// `eq_evals` passed by [`GkrBinaryLayer::into_multivariate_poly()`]. + pub eq_evals: &'a EqEvals, + pub input_layer: Layer, + pub eq_fixed_var_correction: SecureField, } -impl MultivariatePolyOracle for GkrMultivariatePolyOracle { +impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { fn n_variables(&self) -> usize { - todo!() + self.input_layer.n_variables() - 1 } - fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly { - todo!() + fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { + B::sum_as_poly_in_first_variable(self, claim) } - fn fix_first_variable(self, _challenge: SecureField) -> Self { - todo!() + fn fix_first_variable(self, challenge: SecureField) -> Self { + if self.n_variables() == 0 { + return self; + } + + let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()]; + let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]); + + Self { + eq_evals: self.eq_evals, + eq_fixed_var_correction, + input_layer: self.input_layer.fix_first_variable(challenge), + } } } -impl GkrMultivariatePolyOracle { +impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { /// Returns all input layer columns restricted to a line. /// /// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants @@ -184,7 +227,14 @@ impl GkrMultivariatePolyOracle { /// /// For more context see page 64. fn try_into_mask(self) -> Result { - todo!() + if self.n_variables() != 0 { + return Err(NotConstantPolyError); + } + + match self.input_layer { + Layer::_LogUp(_) => todo!(), + Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])), + } } } @@ -312,3 +362,72 @@ fn gen_layers(input_layer: Layer) -> Vec> { assert_eq!(layers.len(), n_variables + 1); layers } + +/// Corrects and interpolates GKR instance sumcheck round polynomials that are generated with the +/// precomputed `eq(x, y)` evaluations provided by [`GkrBinaryLayer::into_multivariate_poly()`]. +/// +/// Let `y` be a fixed vector of length `n` and let `z` be a subvector comprising of the last `k` +/// elements of `y`. Returns the univariate polynomial `f(t) = sum_x eq((t, x), z) * p(t, x)` for +/// `x` in the boolean hypercube `{0, 1}^(k-1)` when provided with: +/// +/// * `claim` equalling `f(0) + f(1)`. +/// * `eval_at_0/2` equalling `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` at `t=0,2` respectively. +/// +/// Note that `f` must have degree <= 3. +/// +/// For more context see [`GkrBinaryLayer::into_multivariate_poly()`] docs. +/// See also (section 3.2). +/// +/// # Panics +/// +/// Panics if: +/// * `k` is zero or greater than the length of `y`. +/// * `z_0` is zero. +pub fn correct_sum_as_poly_in_first_variable( + eval_at_0: SecureField, + eval_at_2: SecureField, + claim: SecureField, + y: &[SecureField], + k: usize, +) -> UnivariatePoly { + assert_ne!(k, 0); + let n = y.len(); + assert!(k <= n); + + let z = &y[n - k..]; + + // Corrects the difference between two sums: + // 1. `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` + // 2. `sum_x eq((0, x), z) * p(t, x)` + let eq_y_to_z_correction_factor = eq(&vec![SecureField::zero(); n - k], &y[0..n - k]).inverse(); + + // Corrects the difference between two sums: + // 1. `sum_x eq((0, x), z) * p(t, x)` + // 2. `sum_x eq((t, x), z) * p(t, x)` + let eq_correction_factor_at = |t| eq(&[t], &[z[0]]) / eq(&[SecureField::zero()], &[z[0]]); + + // Let `v(t) = sum_x eq((0, x), z) * p(t, x)`. Apply trick from + // (section 3.2) to obtain `f` from `v`. + let t0: SecureField = BaseField::zero().into(); + let t1: SecureField = BaseField::one().into(); + let t2: SecureField = BaseField::from(2).into(); + let t3: SecureField = BaseField::from(3).into(); + + // Obtain evals `v(0)`, `v(1)`, `v(2)`. + let mut y0 = eq_y_to_z_correction_factor * eval_at_0; + let mut y1 = (claim - y0) / eq_correction_factor_at(t1); + let mut y2 = eq_y_to_z_correction_factor * eval_at_2; + + // Interpolate `v` to find `v(3)`. Note `v` has degree <= 2. + let v = UnivariatePoly::interpolate_lagrange(&[t0, t1, t2], &[y0, y1, y2]); + let mut y3 = v.eval_at_point(t3); + + // Obtain evals of `f(0)`, `f(1)`, `f(2)`, `f(3)`. + y0 *= eq_correction_factor_at(t0); + y1 *= eq_correction_factor_at(t1); + y2 *= eq_correction_factor_at(t2); + y3 *= eq_correction_factor_at(t3); + + // Interpolate `f(t)`. Note `f(t)` has degree <= 3. + UnivariatePoly::interpolate_lagrange(&[t0, t1, t2, t3], &[y0, y1, y2, y3]) +} diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index a789e2372..f7a14a31a 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -167,15 +167,26 @@ pub struct GkrArtifact { /// circuit) GKR prover implementations. /// /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf +#[derive(Debug, Clone, Copy)] pub enum Gate { _LogUp, - _GrandProduct, + GrandProduct, } impl Gate { /// Returns the output after applying the gate to the mask. - fn eval(&self, _mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { - todo!() + fn eval(&self, mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { + Ok(match self { + Self::_LogUp => todo!(), + Self::GrandProduct => { + if mask.columns().len() != 1 { + return Err(InvalidNumMaskColumnsError); + } + + let [a, b] = mask.columns()[0]; + vec![a * b] + } + }) } } @@ -240,3 +251,80 @@ pub enum GkrError { layer: usize, }, } + +#[cfg(test)] +mod tests { + use super::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::backend::CpuBackend; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, Layer}; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; + + #[test] + fn prove_batch_works() -> Result<(), GkrError> { + const LOG_N: usize = 5; + let mut channel = test_channel(); + let col0 = Mle::::new(channel.draw_felts(1 << LOG_N)); + let col1 = Mle::::new(channel.draw_felts(1 << LOG_N)); + let product0 = col0.iter().product::(); + let product1 = col1.iter().product::(); + let input_layers = vec![ + Layer::GrandProduct(col0.clone()), + Layer::GrandProduct(col1.clone()), + ]; + let (proof, _) = prove_batch(&mut test_channel(), input_layers); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance, + } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; + + assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); + assert_eq!(proof.output_claims_by_instance.len(), 2); + assert_eq!(claims_to_verify_by_instance.len(), 2); + assert_eq!(proof.output_claims_by_instance[0], &[product0]); + assert_eq!(proof.output_claims_by_instance[1], &[product1]); + let claim0 = &claims_to_verify_by_instance[0]; + let claim1 = &claims_to_verify_by_instance[1]; + assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]); + assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]); + Ok(()) + } + + #[test] + fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> { + const LOG_N0: usize = 5; + const LOG_N1: usize = 7; + let mut channel = test_channel(); + let col0 = Mle::::new(channel.draw_felts(1 << LOG_N0)); + let col1 = Mle::::new(channel.draw_felts(1 << LOG_N1)); + let product0 = col0.iter().product::(); + let product1 = col1.iter().product::(); + let input_layers = vec![ + Layer::GrandProduct(col0.clone()), + Layer::GrandProduct(col1.clone()), + ]; + let (proof, _) = prove_batch(&mut test_channel(), input_layers); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance, + } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; + + assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); + assert_eq!(proof.output_claims_by_instance.len(), 2); + assert_eq!(claims_to_verify_by_instance.len(), 2); + assert_eq!(proof.output_claims_by_instance[0], &[product0]); + assert_eq!(proof.output_claims_by_instance[1], &[product1]); + let claim0 = &claims_to_verify_by_instance[0]; + let claim1 = &claims_to_verify_by_instance[1]; + let n_vars = ood_point.len(); + assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]); + assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]); + Ok(()) + } +} diff --git a/crates/prover/src/core/lookups/mle.rs b/crates/prover/src/core/lookups/mle.rs index cd6212686..7ac7f9eb3 100644 --- a/crates/prover/src/core/lookups/mle.rs +++ b/crates/prover/src/core/lookups/mle.rs @@ -72,7 +72,7 @@ mod test { B: MleOps, { /// Evaluates the multilinear polynomial at `point`. - pub(crate) fn eval_at_point(self, point: &[SecureField]) -> SecureField { + pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField { pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField { match p { [] => mle_evals[0], diff --git a/crates/prover/src/core/lookups/sumcheck.rs b/crates/prover/src/core/lookups/sumcheck.rs index 503c2d221..eb3ed28c6 100644 --- a/crates/prover/src/core/lookups/sumcheck.rs +++ b/crates/prover/src/core/lookups/sumcheck.rs @@ -186,20 +186,21 @@ pub struct SumcheckProof { pub const MAX_DEGREE: usize = 3; /// Sum-check protocol verification error. -/// -/// Round 0 corresponds to the first round. #[derive(Error, Debug)] pub enum SumcheckError { #[error("degree of the polynomial in round {round} is too high")] - DegreeInvalid { round: usize }, + DegreeInvalid { round: Round }, #[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")] SumInvalid { claim: SecureField, sum: SecureField, - round: usize, + round: Round, }, } +/// Round 0 corresponds to the first round. +pub type Round = usize; + #[cfg(test)] mod tests { diff --git a/crates/prover/src/core/test_utils.rs b/crates/prover/src/core/test_utils.rs index 083aa88ed..431c25778 100644 --- a/crates/prover/src/core/test_utils.rs +++ b/crates/prover/src/core/test_utils.rs @@ -3,6 +3,7 @@ use super::channel::Blake2sChannel; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; use crate::core::channel::Channel; +use crate::core::vcs::blake2_hash::Blake2sHash; pub fn secure_eval_to_base_eval( eval: &CpuCircleEvaluation, @@ -14,8 +15,6 @@ pub fn secure_eval_to_base_eval( } pub fn test_channel() -> Blake2sChannel { - use crate::core::vcs::blake2_hash::Blake2sHash; - let seed = Blake2sHash::from(vec![0; 32]); Blake2sChannel::new(seed) }