Skip to content

Commit

Permalink
feat(nifs): impl protogalaxy::compute_G
Browse files Browse the repository at this point in the history
refactor(nifs): mv `folded_trace` to mod

fix(util): err in doc-test

docs(nifs): `protogalaxy::poly`
  • Loading branch information
cyphersnake committed Jun 11, 2024
1 parent ec4e1c3 commit 6b0ae01
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 156 deletions.
169 changes: 169 additions & 0 deletions src/nifs/protogalaxy/poly/folded_trace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::iter;

use ff::PrimeField;
use rayon::prelude::*;

use crate::{
plonk::{GetChallenges, GetWitness, PlonkWitness},
polynomial::lagrange,
util::MultiCartesianProduct,
};

pub(crate) struct FoldedTrace<F: PrimeField> {
witness: PlonkWitness<F>,
challenges: Vec<F>,
}

impl<F: PrimeField> FoldedTrace<F> {
pub(crate) fn new(
points_for_fft: &[F],
accumulator: impl Sync + GetChallenges<F> + GetWitness<F>,
traces: &[(impl Sync + GetChallenges<F> + GetWitness<F>)],
) -> Box<[Self]> {
let folded_witnesses_collection = fold_witnesses(points_for_fft, &accumulator, traces);
let folded_challenges_collection =
fold_plonk_challenges(points_for_fft, &accumulator, traces);

folded_witnesses_collection
.into_iter()
.zip(folded_challenges_collection)
.map(|(witness, challenges)| Self {
witness,
challenges,
})
.collect()
}
}
impl<F: PrimeField> GetChallenges<F> for FoldedTrace<F> {
fn get_challenges(&self) -> &[F] {
&self.challenges
}
}
impl<F: PrimeField> GetWitness<F> for FoldedTrace<F> {
fn get_witness(&self) -> &[Vec<F>] {
&self.witness.W
}
}

/// For each `X` we must perform the operation of sum all all matrices [`PlonkWitness`] with
/// coefficients taken from [`lagrange::iter_eval_lagrange_polynomials_for_cyclic_group`]
///
/// Since the number of rows is large, we do this in one pass, counting the points for each
/// challenge at each iteration, and laying them out in separate [`PlonkWitness`] at the end.
fn fold_witnesses<F: PrimeField>(
X_challenges: &[F],
accumulator: &(impl GetWitness<F> + Sync),
witnesses: &[impl Sync + GetWitness<F>],
) -> Vec<PlonkWitness<F>> {
let log_n = (witnesses.len() + 1).next_power_of_two().ilog2();

let lagrange_poly_by_challenge = X_challenges
.iter()
.map(|X| {
lagrange::iter_eval_lagrange_polynomials_for_cyclic_group(*X, log_n)
.collect::<Box<[_]>>()
})
.collect::<Box<[_]>>();

let columns_count = accumulator.get_witness().len();
let rows_count = accumulator.get_witness()[0].len();

let mut result_matrix_by_challenge = vec![
PlonkWitness {
W: vec![vec![F::ZERO; rows_count]; columns_count],
};
X_challenges.len()
];

itertools::iproduct!(0..columns_count, 0..rows_count)
.map(|(col, row)| {
iter::once(accumulator.get_witness())
.chain(witnesses.iter().map(GetWitness::get_witness))
.zip(
lagrange_poly_by_challenge
.iter()
.map(|m| m.iter().copied())
.multi_product(),
)
.fold(
vec![F::ZERO; X_challenges.len()].into_boxed_slice(),
// every element of this collection - one cell for each `X_challenge`
|mut cells_by_challenge, (witness, multiplier)| {
cells_by_challenge
.iter_mut()
.zip(multiplier.iter())
.for_each(|(res, cell)| {
*res += *cell * witness[col][row];
});

cells_by_challenge
},
)
})
.zip(
// Here we take an iterator that on each iteration returns [column][row] elements for
// each witness for its challenge
//
// next -> vec![result[0][col][row], result[1][col][row], ... result[challenges_len][col][row]]
result_matrix_by_challenge
.iter_mut()
.map(|matrix| matrix.W.iter_mut().flat_map(|col| col.iter_mut()))
.multi_product(),
)
.par_bridge()
.for_each(|(elements, mut results)| {
results
.iter_mut()
.zip(elements.iter())
.for_each(|(result, cell)| **result = *cell);
});

result_matrix_by_challenge
}

fn fold_plonk_challenges<F: PrimeField>(
X_challenges: &[F],
accumulator: &(impl GetChallenges<F> + Sync),
plonk_challenges: &[impl Sync + GetChallenges<F>],
) -> Vec<Vec<F>> {
let log_n = (plonk_challenges.len() + 1).next_power_of_two().ilog2();

let lagrange_poly_by_challenge = X_challenges
.iter()
.map(|X| {
lagrange::iter_eval_lagrange_polynomials_for_cyclic_group(*X, log_n)
.collect::<Box<[_]>>()
})
.collect::<Box<[_]>>();

let plonk_challenges_len = accumulator.get_challenges().len();

iter::once(accumulator.get_challenges())
.chain(plonk_challenges.iter().map(GetChallenges::get_challenges))
.zip(
lagrange_poly_by_challenge
.iter()
.map(|m| m.iter().copied())
.multi_product(),
)
.fold(
vec![vec![F::ZERO; plonk_challenges_len]; X_challenges.len()],
|mut folded, (plonk_challenge, lagrange_by_X_challenges)| {
folded
.iter_mut()
.zip(lagrange_by_X_challenges.iter())
.for_each(
|(folded_plonk_challenge, lagrange_multiplier_by_plonk_challenge)| {
folded_plonk_challenge
.iter_mut()
.zip(plonk_challenge)
.for_each(|(folded, element)| {
*folded += *element * lagrange_multiplier_by_plonk_challenge;
});
},
);

folded
},
)
}
Loading

0 comments on commit 6b0ae01

Please sign in to comment.