Skip to content

Commit

Permalink
Add GKR implementation of Grand Product lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jun 13, 2024
1 parent 38cef08 commit 162fdcc
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 38 deletions.
87 changes: 79 additions & 8 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,69 @@
use num_traits::Zero;

use crate::core::backend::CpuBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::fields::Field;
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Mle::new(gen_eq_evals(y, v))
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
match layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(layer) => {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
Layer::GrandProduct(Mle::new(res))
}
}
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle,
_claim: SecureField,
) -> crate::core::lookups::utils::UnivariatePoly<SecureField> {
todo!()
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let k = h.n_variables();
let n_terms = 1 << (k - 1);
let eq_evals = h.eq_evals;
let y = eq_evals.y();
let input_layer = &h.input_layer;

let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

match input_layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(input_layer) =>
{
#[allow(clippy::needless_range_loop)]
for i in 0..n_terms {
let lhs0 = input_layer[i * 2];
let lhs1 = input_layer[i * 2 + 1];

let rhs0 = input_layer[(n_terms + i) * 2];
let rhs1 = input_layer[(n_terms + i) * 2 + 1];

let product2 = (rhs0.double() - lhs0) * (rhs1.double() - lhs1);
let product0 = lhs0 * lhs1;

let eq_eval = eq_evals[i];
eval_at_0 += eq_eval * product0;
eval_at_2 += eq_eval * product2;
}
}
}

eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;

correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k)
}
}

Expand Down Expand Up @@ -45,10 +92,14 @@ mod tests {
use num_traits::{One, Zero};

use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals() {
Expand All @@ -68,4 +119,24 @@ mod tests {
]
);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let col = Mle::<CpuBackend, SecureField>::new(values);
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
..
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
Ok(())
}
}
159 changes: 139 additions & 20 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::sumcheck;

pub trait GkrOps: MleOps<SecureField> {
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
Expand All @@ -30,7 +32,7 @@ pub trait GkrOps: MleOps<SecureField> {
///
/// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`].
fn sum_as_poly_in_first_variable(
h: &GkrMultivariatePolyOracle,
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField>;
}
Expand Down Expand Up @@ -83,13 +85,16 @@ impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
pub enum Layer<B: GkrOps> {
_LogUp(B),
_GrandProduct(B),
GrandProduct(Mle<B, SecureField>),
}

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
todo!()
match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => mle.n_variables(),
}
}

/// Produces the next layer from the current layer.
Expand All @@ -110,7 +115,28 @@ impl<B: GkrOps> Layer<B> {

/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
todo!()
if !self.is_output_layer() {
return Err(NotOutputLayerError);
}

Ok(match self {
Self::GrandProduct(col) => {
vec![col.at(0)]
}
Self::_LogUp(_) => todo!(),
})
}

/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
fn fix_first_variable(self, x0: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
}
}

/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
Expand Down Expand Up @@ -143,36 +169,53 @@ impl<B: GkrOps> Layer<B> {
fn into_multivariate_poly(
self,
_lambda: SecureField,
_eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle {
todo!()
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
GkrMultivariatePolyOracle {
eq_evals,
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
}
}
}

#[derive(Debug)]
struct NotOutputLayerError;

/// A multivariate polynomial expressed in the multilinear poly columns of a [`Layer`].
pub enum GkrMultivariatePolyOracle {
LogUp,
GrandProduct,
/// A multivariate polynomial that expresses the relation between two consecutive GKR layers.
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by [`GkrBinaryLayer::into_multivariate_poly()`].
pub eq_evals: &'a EqEvals<B>,
pub input_layer: Layer<B>,
pub eq_fixed_var_correction: SecureField,
}

impl MultivariatePolyOracle for GkrMultivariatePolyOracle {
impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
fn n_variables(&self) -> usize {
todo!()
self.input_layer.n_variables() - 1
}

fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly<SecureField> {
todo!()
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
B::sum_as_poly_in_first_variable(self, claim)
}

fn fix_first_variable(self, _challenge: SecureField) -> Self {
todo!()
fn fix_first_variable(self, challenge: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);

Self {
eq_evals: self.eq_evals,
eq_fixed_var_correction,
input_layer: self.input_layer.fix_first_variable(challenge),
}
}
}

impl GkrMultivariatePolyOracle {
impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
/// Returns all input layer columns restricted to a line.
///
/// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants
Expand All @@ -184,7 +227,14 @@ impl GkrMultivariatePolyOracle {
///
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
todo!()
if self.n_variables() != 0 {
return Err(NotConstantPolyError);
}

match self.input_layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])),
}
}
}

Expand Down Expand Up @@ -312,3 +362,72 @@ fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
assert_eq!(layers.len(), n_variables + 1);
layers
}

/// Corrects and interpolates GKR instance sumcheck round polynomials that are generated with the
/// precomputed `eq(x, y)` evaluations provided by [`GkrBinaryLayer::into_multivariate_poly()`].
///
/// Let `y` be a fixed vector of length `n` and let `z` be a subvector comprising of the last `k`
/// elements of `y`. Returns the univariate polynomial `f(t) = sum_x eq((t, x), z) * p(t, x)` for
/// `x` in the boolean hypercube `{0, 1}^(k-1)` when provided with:
///
/// * `claim` equalling `f(0) + f(1)`.
/// * `eval_at_0/2` equalling `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` at `t=0,2` respectively.
///
/// Note that `f` must have degree <= 3.
///
/// For more context see [`GkrBinaryLayer::into_multivariate_poly()`] docs.
/// See also <https://ia.cr/2024/108> (section 3.2).
///
/// # Panics
///
/// Panics if:
/// * `k` is zero or greater than the length of `y`.
/// * `z_0` is zero.
pub fn correct_sum_as_poly_in_first_variable(
eval_at_0: SecureField,
eval_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);

let z = &y[n - k..];

// Corrects the difference between two sums:
// 1. `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)`
// 2. `sum_x eq((0, x), z) * p(t, x)`
let eq_y_to_z_correction_factor = eq(&vec![SecureField::zero(); n - k], &y[0..n - k]).inverse();

// Corrects the difference between two sums:
// 1. `sum_x eq((0, x), z) * p(t, x)`
// 2. `sum_x eq((t, x), z) * p(t, x)`
let eq_correction_factor_at = |t| eq(&[t], &[z[0]]) / eq(&[SecureField::zero()], &[z[0]]);

// Let `v(t) = sum_x eq((0, x), z) * p(t, x)`. Apply trick from
// <https://ia.cr/2024/108> (section 3.2) to obtain `f` from `v`.
let t0: SecureField = BaseField::zero().into();
let t1: SecureField = BaseField::one().into();
let t2: SecureField = BaseField::from(2).into();
let t3: SecureField = BaseField::from(3).into();

// Obtain evals `v(0)`, `v(1)`, `v(2)`.
let mut y0 = eq_y_to_z_correction_factor * eval_at_0;
let mut y1 = (claim - y0) / eq_correction_factor_at(t1);
let mut y2 = eq_y_to_z_correction_factor * eval_at_2;

// Interpolate `v` to find `v(3)`. Note `v` has degree <= 2.
let v = UnivariatePoly::interpolate_lagrange(&[t0, t1, t2], &[y0, y1, y2]);
let mut y3 = v.eval_at_point(t3);

// Obtain evals of `f(0)`, `f(1)`, `f(2)`, `f(3)`.
y0 *= eq_correction_factor_at(t0);
y1 *= eq_correction_factor_at(t1);
y2 *= eq_correction_factor_at(t2);
y3 *= eq_correction_factor_at(t3);

// Interpolate `f(t)`. Note `f(t)` has degree <= 3.
UnivariatePoly::interpolate_lagrange(&[t0, t1, t2, t3], &[y0, y1, y2, y3])
}
Loading

0 comments on commit 162fdcc

Please sign in to comment.