Skip to content

Commit

Permalink
Implement GkrOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jun 26, 2024
1 parent c51f096 commit 4968f10
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 4 deletions.
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ fn eval_logup_singles_sum(
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Evaluations are returned in bit-reversed order.
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
let mut evals = Vec::with_capacity(1 << y.len());
evals.push(v);

Expand Down Expand Up @@ -333,7 +333,7 @@ mod tests {
let eq_evals = CpuBackend::gen_eq_evals(&y, two);

assert_eq!(
**eq_evals,
*eq_evals,
[
eq(&[zero, zero], &y) * two,
eq(&[zero, one], &y) * two,
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/cpu/lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod gkr;
pub mod gkr;
mod mle;
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod accumulation;
mod blake2s;
mod circle;
mod fri;
mod lookups;
pub mod lookups;
pub mod quotients;

use std::fmt::Debug;
Expand Down
90 changes: 90 additions & 0 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use std::iter::zip;

use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for SimdBackend {
#[allow(clippy::uninit_vec)]
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
if y.len() < LOG_N_LANES as usize {
return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect());
}

// Start DP with CPU backend to prevent dealing with instances smaller than a SIMD vector.
let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap();
let initial = SecureFieldVec::from_iter(cpu_gen_eq_evals(y_last_chunk, v));
assert_eq!(initial.len(), N_LANES);

let packed_len = 1 << y_rem.len();
let mut data = initial.data;

data.reserve(packed_len - data.len());
unsafe { data.set_len(packed_len) };

for (i, &y_j) in y_rem.iter().rev().enumerate() {
let packed_y_j = PackedSecureField::broadcast(y_j);

let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i);

for (lhs, rhs) in zip(lhs_evals, rhs_evals) {
// Equivalent to:
// `rhs = eq(1, y_j) * lhs`,
// `lhs = eq(0, y_j) * lhs`
*rhs = *lhs * packed_y_j;
*lhs -= *rhs;
}
}

let length = packed_len * N_LANES;
Mle::new(SecureFieldVec { data, length })
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
) -> UnivariatePoly<SecureField> {
todo!()
}
}

#[cfg(test)]
mod tests {
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::lookups::gkr_prover::GkrOps;

#[test]
fn gen_eq_evals_matches_cpu() {
let two = BaseField::from(2).into();
let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into());
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);

let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);

assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}

#[test]
fn gen_eq_evals_with_small_assignment_matches_cpu() {
let two = BaseField::from(2).into();
let y = [7, 3, 5].map(|v| BaseField::from(v).into());
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);

let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);

assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}
}
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/simd/lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod gkr;
mod mle;
1 change: 1 addition & 0 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
stdsimd,
get_many_mut,
int_roundings,
slice_first_last_chunk,
slice_flatten,
assert_matches,
portable_simd
Expand Down

0 comments on commit 4968f10

Please sign in to comment.