Skip to content

Commit

Permalink
CommitmentSchemeProver generic in backend
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Apr 3, 2024
1 parent 4b23127 commit 1a09f8c
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/core/backend/avx512/bit_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,6 @@ mod tests {
let mut data: BaseFieldVec = data.into_iter().collect();

bit_reverse_m31(&mut data.data[..]);
assert_eq!(data.to_vec(), expected);
assert_eq!(data.to_cpu(), expected);
}
}
8 changes: 4 additions & 4 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ mod tests {
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_vec(), evaluation2.values.to_vec());
assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu());
}
}

Expand All @@ -377,8 +377,8 @@ mod tests {
let evaluation2 = poly.evaluate(domain_ext);
let poly2 = evaluation2.interpolate();
assert_eq!(
poly.extend(log_size + 3).coeffs.to_vec(),
poly2.coeffs.to_vec()
poly.extend(log_size + 3).coeffs.to_cpu(),
poly2.coeffs.to_cpu()
);
}
}
Expand Down Expand Up @@ -419,7 +419,7 @@ mod tests {
.extend(log_size + 2)
.evaluate(CanonicCoset::new(log_size + 2).circle_domain());

assert_eq!(eval0.values.to_vec(), eval1.values.to_vec());
assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/backend/avx512/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}
}
Expand Down Expand Up @@ -718,7 +718,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/backend/avx512/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}
}
Expand Down Expand Up @@ -744,7 +744,7 @@ mod tests {
);

// Compare.
assert_eq!(values.to_vec(), expected_coeffs);
assert_eq!(values.to_cpu(), expected_coeffs);
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl Column<BaseField> for BaseFieldVec {
length: len,
}
}
fn to_vec(&self) -> Vec<BaseField> {
fn to_cpu(&self) -> Vec<BaseField> {
self.data
.iter()
.flat_map(|x| x.to_array())
Expand Down Expand Up @@ -242,10 +242,10 @@ impl SecureColumn<AVX512Backend> {

pub fn to_vec(&self) -> Vec<SecureField> {
izip!(
self.columns[0].to_vec(),
self.columns[1].to_vec(),
self.columns[2].to_vec(),
self.columns[3].to_vec(),
self.columns[0].to_cpu(),
self.columns[1].to_cpu(),
self.columns[2].to_cpu(),
self.columns[3].to_cpu(),
)
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
.collect()
Expand All @@ -269,7 +269,7 @@ mod tests {
for i in 0..100 {
let col = Col::<B, BaseField>::from_iter((0..i).map(BaseField::from));
assert_eq!(
col.to_vec(),
col.to_cpu(),
(0..i).map(BaseField::from).collect::<Vec<_>>()
);
for j in 0..i {
Expand All @@ -285,7 +285,7 @@ mod tests {
let mut col = Col::<B, BaseField>::from_iter((0..len).map(BaseField::from));
<B as ColumnOps<BaseField>>::bit_reverse_column(&mut col);
assert_eq!(
col.to_vec(),
col.to_cpu(),
(0..len)
.map(|x| BaseField::from(utils::bit_reverse_index(x, i as u32)))
.collect::<Vec<_>>()
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ mod tests {
.map(|c| {
CircleEvaluation::<CPUBackend, _, BitReversedOrder>::new(
c.domain,
c.values.to_vec(),
c.values.to_cpu(),
)
})
.collect::<Vec<_>>();
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
fn zeros(len: usize) -> Self {
vec![T::default(); len]
}
fn to_vec(&self) -> Vec<T> {
fn to_cpu(&self) -> Vec<T> {
self.clone()
}
fn len(&self) -> usize {
Expand Down
6 changes: 4 additions & 2 deletions src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ use std::fmt::Debug;

pub use cpu::CPUBackend;

use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::poly::circle::PolyOps;

#[cfg(target_arch = "x86_64")]
pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps + QuotientOps + FriOps
{
}

Expand All @@ -28,7 +30,7 @@ pub trait Column<T>: Clone + Debug + FromIterator<T> {
/// Creates a new column of zeros with the given length.
fn zeros(len: usize) -> Self;
/// Returns a cpu vector of the column.
fn to_vec(&self) -> Vec<T>;
fn to_cpu(&self) -> Vec<T>;
/// Returns the length of the column.
fn len(&self) -> usize;
/// Returns true if the column is empty.
Expand Down
34 changes: 17 additions & 17 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use std::collections::BTreeMap;

use itertools::Itertools;

use super::super::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly};
use super::super::backend::CPUBackend;
use super::super::channel::Blake2sChannel;
use super::super::circle::CirclePoint;
use super::super::fields::m31::BaseField;
Expand All @@ -20,27 +18,30 @@ use super::quotients::{compute_fri_quotients, PointSample};
use super::utils::TreeVec;
use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::ops::MerkleOps;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::core::backend::Backend;
use crate::core::channel::Channel;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};

type MerkleHasher = Blake2sMerkleHasher;
type ProofChannel = Blake2sChannel;

/// The prover side of a FRI polynomial commitment scheme. See [super].
pub struct CommitmentSchemeProver {
pub trees: TreeVec<CommitmentTreeProver>,
pub struct CommitmentSchemeProver<B: Backend + MerkleOps<MerkleHasher>> {
pub trees: TreeVec<CommitmentTreeProver<B>>,
pub log_blowup_factor: u32,
}

impl CommitmentSchemeProver {
impl<B: Backend + MerkleOps<MerkleHasher>> CommitmentSchemeProver<B> {
pub fn new(log_blowup_factor: u32) -> Self {
CommitmentSchemeProver {
trees: TreeVec::<CommitmentTreeProver>::default(),
trees: TreeVec::default(),
log_blowup_factor,
}
}

pub fn commit(&mut self, polynomials: ColumnVec<CPUCirclePoly>, channel: &mut ProofChannel) {
pub fn commit(&mut self, polynomials: ColumnVec<CirclePoly<B>>, channel: &mut ProofChannel) {
let tree = CommitmentTreeProver::new(polynomials, self.log_blowup_factor, channel);
self.trees.push(tree);
}
Expand All @@ -49,13 +50,13 @@ impl CommitmentSchemeProver {
self.trees.as_ref().map(|tree| tree.commitment.root())
}

pub fn polynomials(&self) -> TreeVec<ColumnVec<&CPUCirclePoly>> {
pub fn polynomials(&self) -> TreeVec<ColumnVec<&CirclePoly<B>>> {
self.trees
.as_ref()
.map(|tree| tree.polynomials.iter().collect())
}

fn evaluations(&self) -> TreeVec<ColumnVec<&CPUCircleEvaluation<BaseField, BitReversedOrder>>> {
fn evaluations(&self) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
self.trees
.as_ref()
.map(|tree| tree.evaluations.iter().collect())
Expand Down Expand Up @@ -90,8 +91,7 @@ impl CommitmentSchemeProver {

// Run FRI commitment phase on the oods quotients.
let fri_config = FriConfig::new(LOG_LAST_LAYER_DEGREE_BOUND, LOG_BLOWUP_FACTOR, N_QUERIES);
let fri_prover =
FriProver::<CPUBackend, MerkleHasher>::commit(channel, fri_config, &quotients);
let fri_prover = FriProver::<B, MerkleHasher>::commit(channel, fri_config, &quotients);

// Proof of work.
let proof_of_work = ProofOfWork::new(PROOF_OF_WORK_BITS).prove(channel);
Expand Down Expand Up @@ -132,15 +132,15 @@ pub struct CommitmentSchemeProof {

/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows to
/// commit on a set of polynomials at a time. This corresponds to such a set.
pub struct CommitmentTreeProver {
pub polynomials: ColumnVec<CPUCirclePoly>,
pub evaluations: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
pub commitment: MerkleProver<CPUBackend, MerkleHasher>,
pub struct CommitmentTreeProver<B: Backend + MerkleOps<MerkleHasher>> {
pub polynomials: ColumnVec<CirclePoly<B>>,
pub evaluations: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
pub commitment: MerkleProver<B, MerkleHasher>,
}

impl CommitmentTreeProver {
impl<B: Backend + MerkleOps<MerkleHasher>> CommitmentTreeProver<B> {
fn new(
polynomials: ColumnVec<CPUCirclePoly>,
polynomials: ColumnVec<CirclePoly<B>>,
log_blowup_factor: u32,
channel: &mut ProofChannel,
) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<B: FieldOps<BaseField>> SecureColumn<B> {

pub fn to_cpu(&self) -> SecureColumn<CPUBackend> {
SecureColumn {
columns: self.columns.clone().map(|c| c.to_vec()),
columns: self.columns.clone().map(|c| c.to_cpu()),
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ use itertools::Itertools;
use num_traits::Zero;
use thiserror::Error;

use super::backend::{Backend, CPUBackend};
use super::backend::CPUBackend;
use super::channel::Channel;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use super::fields::FieldOps;
use super::poly::circle::{CircleEvaluation, SecureEvaluation};
use super::poly::line::{LineEvaluation, LinePoly};
use super::poly::BitReversedOrder;
Expand Down Expand Up @@ -67,7 +69,7 @@ impl FriConfig {
}
}

pub trait FriOps: Backend + Sized {
pub trait FriOps: FieldOps<BaseField> + Sized {
/// Folds a degree `d` polynomial into a degree `d/2` polynomial.
///
/// Let `eval` be a polynomial evaluated on a [LineDomain] `E`, `alpha` be a random field
Expand Down
1 change: 1 addition & 0 deletions src/core/poly/circle/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> CircleEvaluation<B, F
// Note: The concrete implementation of the poly operations is in the specific backend used.
// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`.
impl<F: ExtensionOf<BaseField>, B: FieldOps<F>> CircleEvaluation<B, F, NaturalOrder> {
// TODO(spapini): Remove. Is this even used.
pub fn get_at(&self, point_index: CirclePointIndex) -> F {
self.values
.at(self.domain.find(point_index).expect("Not in domain"))
Expand Down
8 changes: 4 additions & 4 deletions src/core/poly/line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ use itertools::Itertools;
use num_traits::Zero;

use super::utils::fold;
use crate::core::backend::{Backend, CPUBackend, ColumnOps};
use crate::core::backend::{CPUBackend, ColumnOps};
use crate::core::circle::{CirclePoint, Coset, CosetIterator};
use crate::core::fft::ibutterfly;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps};
use crate::core::utils::bit_reverse;

/// Domain comprising of the x-coordinates of points in a [Coset].
Expand Down Expand Up @@ -176,13 +176,13 @@ impl DerefMut for LinePoly {
// only used by FRI where evaluations are in bit-reversed order.
// TODO(spapini): Remove pub.
#[derive(Clone, Debug)]
pub struct LineEvaluation<B: Backend> {
pub struct LineEvaluation<B: FieldOps<BaseField>> {
/// Evaluations of a univariate polynomial on `domain`.
pub values: SecureColumn<B>,
domain: LineDomain,
}

impl<B: Backend> LineEvaluation<B> {
impl<B: FieldOps<BaseField>> LineEvaluation<B> {
/// Creates new [LineEvaluation] from a set of polynomial evaluations over a [LineDomain].
///
/// # Panics
Expand Down

0 comments on commit 1a09f8c

Please sign in to comment.