Skip to content

Commit

Permalink
Implement LogupOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jun 13, 2024
1 parent 1910db0 commit 307f62d
Show file tree
Hide file tree
Showing 4 changed files with 480 additions and 57 deletions.
31 changes: 7 additions & 24 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Index};
use std::ops::Index;

use num_traits::{One, Zero};

Expand All @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{
};
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, UnivariatePoly};
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Expand Down Expand Up @@ -184,23 +184,6 @@ fn process_logup_singles_sum(
n_terms: usize,
lambda: SecureField,
) {
/// Represents the fraction `1 / x`
struct Reciprocal {
x: SecureField,
}

impl Add for Reciprocal {
type Output = Fraction<SecureField>;

fn add(self, rhs: Self) -> Fraction<SecureField> {
// `1/a + 1/b = (a + b)/(a * b)`
Fraction {
numerator: self.x + rhs.x,
denominator: self.x * rhs.x,
}
}
}

#[allow(clippy::needless_range_loop)]
for i in 0..n_terms {
// Let `q` be the multilinear polynomial representing `denominators`.
Expand All @@ -221,8 +204,8 @@ fn process_logup_singles_sum(
let q2x0 /* = q(2, x, 0) */ = q1x0.double() - q0x0;
let q2x1 /* = q(2, x, 1) */ = q1x1.double() - q0x1;

let res0 = Reciprocal { x: q0x0 } + Reciprocal { x: q0x1 };
let res2 = Reciprocal { x: q2x0 } + Reciprocal { x: q2x1 };
let res0 = Reciprocal::new(q0x0) + Reciprocal::new(q0x1);
let res2 = Reciprocal::new(q2x0) + Reciprocal::new(q2x1);

let eq_eval = eq_evals[i];
*eval_at_0 += eq_eval * (res0.numerator + lambda * res0.denominator);
Expand Down Expand Up @@ -364,7 +347,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpGeneric {
Expand Down Expand Up @@ -403,7 +386,7 @@ mod tests {
let sum = denominator_values
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
Expand Down Expand Up @@ -437,7 +420,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpMultiplicities {
Expand Down
Loading

0 comments on commit 307f62d

Please sign in to comment.