Skip to content

Commit

Permalink
Add GKR implementation of Grand Product lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 19, 2024
1 parent a014862 commit 164390d
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 18 deletions.
85 changes: 85 additions & 0 deletions crates/prover/src/core/backend/cpu/lookups/grandproduct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use num_traits::Zero;

use crate::core::backend::CpuBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::Field;
use crate::core::lookups::gkr::correct_sum_as_poly_in_first_variable;
use crate::core::lookups::grandproduct::{GrandProductOps, GrandProductOracle, GrandProductTrace};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GrandProductOps for CpuBackend {
fn next_layer(layer: &GrandProductTrace<Self>) -> GrandProductTrace<Self> {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
GrandProductTrace::new(Mle::new(res))
}

fn sum_as_poly_in_first_variable(
h: &GrandProductOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let k = h.n_variables();
let n_terms = 1 << (k - 1);
let eq_evals = h.eq_evals();
let y = eq_evals.y();
let trace = h.trace();

let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

#[allow(clippy::needless_range_loop)]
for i in 0..n_terms {
let lhs0 = trace[i * 2];
let lhs1 = trace[i * 2 + 1];

let rhs0 = trace[(n_terms + i) * 2];
let rhs1 = trace[(n_terms + i) * 2 + 1];

let product2 = (rhs0.double() - lhs0) * (rhs1.double() - lhs1);
let product0 = lhs0 * lhs1;

let eq_eval = eq_evals[i];
eval_at_0 += eq_eval * product0;
eval_at_2 += eq_eval * product2;
}

eval_at_0 *= h.eq_fixed_var_correction();
eval_at_2 *= h.eq_fixed_var_correction();

correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k)
}
}

#[cfg(test)]
mod tests {
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr::{partially_verify_batch, prove_batch, GkrArtifact, GkrError};
use crate::core::lookups::grandproduct::{GrandProductGate, GrandProductTrace};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let top_layer = GrandProductTrace::<CpuBackend>::new(Mle::new(values));
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer.clone()]);

let GkrArtifact {
ood_point,
claims_to_verify_by_component,
..
} = partially_verify_batch(vec![&GrandProductGate], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_component, [vec![product]]);
assert_eq!(
claims_to_verify_by_component,
[vec![top_layer.eval_at_point(&ood_point)]]
);
Ok(())
}
}
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/cpu/lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod gkr;
mod grandproduct;
mod mle;
132 changes: 121 additions & 11 deletions crates/prover/src/core/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@ use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::sumcheck;

pub trait GkrOps: MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
/// See [`eq(x, y)`](eq).
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField>;
}

/// Stores evaluations of [`eq(x, y)`] on all boolean hypercube points of the form
/// Stores evaluations of [`eq(x, y)`](eq) on all boolean hypercube points of the form
/// `x = (0, x_1, ..., x_{n-1})`.
///
/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`,
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
Expand Down Expand Up @@ -105,7 +104,7 @@ pub trait GkrBinaryLayer: Sized {
/// implementation because the prover only has to generate `eq_evals` once for an entire batch
/// of multiple GKR layer instances.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
/// [`eq(x, y)`]: eq
/// [^note]: By "representing" we mean `g_i` agrees with the next layer's `c_i` on the boolean
/// hypercube that interpolates `c_i`.
fn into_multivariate_poly(
Expand Down Expand Up @@ -518,11 +517,122 @@ pub enum GkrError {
/// * `k` is zero or greater than the length of `y`.
/// * `z_0` is zero.
pub fn correct_sum_as_poly_in_first_variable(
_eval_at_0: SecureField,
_eval_at_2: SecureField,
_claim: SecureField,
_y: &[SecureField],
_k: usize,
eval_at_0: SecureField,
eval_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
todo!()
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);

let z = &y[n - k..];

// Corrects the difference between two sums:
// 1. `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)`
// 2. `sum_x eq((0, x), z) * p(t, x)`
let eq_y_to_z_correction_factor = eq(&vec![SecureField::zero(); n - k], &y[0..n - k]).inverse();

// Corrects the difference between two sums:
// 1. `sum_x eq((0, x), z) * p(t, x)`
// 2. `sum_x eq((t, x), z) * p(t, x)`
let eq_correction_factor_at = |t| eq(&[t], &[z[0]]) / eq(&[SecureField::zero()], &[z[0]]);

// Let `v(t) = sum_x eq((0, x), z) * p(t, x)`. Apply trick from
// <https://ia.cr/2024/108> (section 3.2) to obtain `f` from `v`.
let t0: SecureField = BaseField::zero().into();
let t1: SecureField = BaseField::one().into();
let t2: SecureField = BaseField::from(2).into();
let t3: SecureField = BaseField::from(3).into();

// Obtain evals `v(0)`, `v(1)`, `v(2)`.
let mut y0 = eq_y_to_z_correction_factor * eval_at_0;
let mut y1 = (claim - y0) / eq_correction_factor_at(t1);
let mut y2 = eq_y_to_z_correction_factor * eval_at_2;

// Interpolate `v` to find `v(3)`. Note `v` has degree <= 2.
let v = UnivariatePoly::interpolate_lagrange(&[t0, t1, t2], &[y0, y1, y2]);
let mut y3 = v.eval_at_point(t3);

// Obtain evals of `f(0)`, `f(1)`, `f(2)`, `f(3)`.
y0 *= eq_correction_factor_at(t0);
y1 *= eq_correction_factor_at(t1);
y2 *= eq_correction_factor_at(t2);
y3 *= eq_correction_factor_at(t3);

// Interpolate `f(t)`. Note `f(t)` has degree <= 3.
UnivariatePoly::interpolate_lagrange(&[t0, t1, t2, t3], &[y0, y1, y2, y3])
}

#[cfg(test)]
mod tests {
use super::GkrError;
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr::{partially_verify_batch, prove_batch, GkrArtifact};
use crate::core::lookups::grandproduct::{GrandProductGate, GrandProductTrace};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn prove_batch_works() -> Result<(), GkrError> {
const LOG_N: usize = 5;
let mut channel = test_channel();
let col0 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N)));
let col1 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N)));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let top_layers = vec![col0.clone(), col1.clone()];
let (proof, _) = prove_batch(&mut test_channel(), top_layers);

let GkrArtifact {
ood_point,
claims_to_verify_by_component,
n_variables_by_component,
} = partially_verify_batch(vec![&GrandProductGate; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_component, [LOG_N, LOG_N]);
assert_eq!(proof.output_claims_by_component.len(), 2);
assert_eq!(claims_to_verify_by_component.len(), 2);
assert_eq!(proof.output_claims_by_component[0], &[product0]);
assert_eq!(proof.output_claims_by_component[1], &[product1]);
let claim0 = &claims_to_verify_by_component[0];
let claim1 = &claims_to_verify_by_component[1];
assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]);
assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]);
Ok(())
}

#[test]
fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> {
const LOG_N0: usize = 5;
const LOG_N1: usize = 7;
let mut channel = test_channel();
let col0 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N0)));
let col1 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N1)));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let top_layers = vec![col0.clone(), col1.clone()];
let (proof, _) = prove_batch(&mut test_channel(), top_layers);

let GkrArtifact {
ood_point,
claims_to_verify_by_component,
n_variables_by_component,
} = partially_verify_batch(vec![&GrandProductGate; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_component, [LOG_N0, LOG_N1]);
assert_eq!(proof.output_claims_by_component.len(), 2);
assert_eq!(claims_to_verify_by_component.len(), 2);
assert_eq!(proof.output_claims_by_component[0], &[product0]);
assert_eq!(proof.output_claims_by_component[1], &[product1]);
let claim0 = &claims_to_verify_by_component[0];
let claim1 = &claims_to_verify_by_component[1];
let n_vars = ood_point.len();
assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]);
assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]);
Ok(())
}
}
Loading

0 comments on commit 164390d

Please sign in to comment.