Skip to content

Commit

Permalink
Implement LogupOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jun 13, 2024
1 parent 7ae75ab commit 9bc0cd3
Show file tree
Hide file tree
Showing 4 changed files with 484 additions and 61 deletions.
39 changes: 11 additions & 28 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Index};
use std::ops::Index;

use num_traits::{One, Zero};

Expand All @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{
};
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, UnivariatePoly};
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Expand Down Expand Up @@ -184,23 +184,6 @@ fn process_logup_singles_sum(
n_terms: usize,
lambda: SecureField,
) {
/// Represents the fraction `1 / x`
struct Reciprocal {
x: SecureField,
}

impl Add for Reciprocal {
type Output = Fraction<SecureField>;

fn add(self, rhs: Self) -> Fraction<SecureField> {
// `1/a + 1/b = (a + b)/(a * b)`
Fraction {
numerator: self.x + rhs.x,
denominator: self.x * rhs.x,
}
}
}

#[allow(clippy::needless_range_loop)]
for i in 0..n_terms {
// Let `q` be the multilinear polynomial representing `denominators`.
Expand All @@ -221,8 +204,8 @@ fn process_logup_singles_sum(
let q2x0 /* = q(2, x, 0) */ = q1x0.double() - q0x0;
let q2x1 /* = q(2, x, 1) */ = q1x1.double() - q0x1;

let res0 = Reciprocal { x: q0x0 } + Reciprocal { x: q0x1 };
let res2 = Reciprocal { x: q2x0 } + Reciprocal { x: q2x1 };
let res0 = Reciprocal::new(q0x0) + Reciprocal::new(q0x1);
let res2 = Reciprocal::new(q2x0) + Reciprocal::new(q2x1);

let eq_eval = eq_evals[i];
*eval_at_0 += eq_eval * (res0.numerator + lambda * res0.denominator);
Expand Down Expand Up @@ -348,7 +331,7 @@ mod tests {
let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
..
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
Expand All @@ -364,7 +347,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpGeneric {
Expand Down Expand Up @@ -409,7 +392,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
..
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

assert_eq!(claims_to_verify_by_instance.len(), 1);
Expand All @@ -436,7 +419,7 @@ mod tests {
let sum = denominator_values
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
Expand All @@ -446,7 +429,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
..
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

assert_eq!(claims_to_verify_by_instance.len(), 1);
Expand All @@ -470,7 +453,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpMultiplicities {
Expand All @@ -482,7 +465,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
..
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

assert_eq!(claims_to_verify_by_instance.len(), 1);
Expand Down
Loading

0 comments on commit 9bc0cd3

Please sign in to comment.