From c2810972c15858f66a7ba9639cfed2a2b025de58 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 7 Jul 2024 23:41:36 +0200 Subject: [PATCH] implement fixed columns in uni-stark, changed BaseAir --- air/src/air.rs | 5 ++ uni-stark/src/check_constraints.rs | 25 ++++++- uni-stark/src/folder.rs | 16 +++- uni-stark/src/proof.rs | 17 +++++ uni-stark/src/prover.rs | 115 +++++++++++++++++++++++++---- uni-stark/src/symbolic_builder.rs | 21 ++---- uni-stark/src/verifier.rs | 87 ++++++++++++++++------ 7 files changed, 231 insertions(+), 55 deletions(-) diff --git a/air/src/air.rs b/air/src/air.rs index e8d773579..873ea7da6 100644 --- a/air/src/air.rs +++ b/air/src/air.rs @@ -12,6 +12,11 @@ pub trait BaseAir: Sync { fn preprocessed_trace(&self) -> Option> { None } + + /// The number of preprocessed columns in this AIR + fn preprocessed_width(&self) -> usize { + 0 + } } /// An AIR that works with a particular `AirBuilder`. diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 1028ef904..7584f8a85 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::Field; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; @@ -8,8 +8,12 @@ use p3_matrix::Matrix; use tracing::instrument; #[instrument(name = "check constraints", skip_all)] -pub(crate) fn check_constraints(air: &A, main: &RowMajorMatrix, public_values: &Vec) -where +pub(crate) fn check_constraints( + air: &A, + preprocessed: &RowMajorMatrix, + main: &RowMajorMatrix, + public_values: &Vec, +) where F: Field, A: for<'a> Air>, { @@ -18,6 +22,13 @@ where (0..height).for_each(|i| { let i_next = (i + 1) % height; + let local_preprocessed = preprocessed.row_slice(i); + let next_preprocessed = preprocessed.row_slice(i_next); + let preprocessed = VerticalPair::new( + RowMajorMatrixView::new_row(&*local_preprocessed), + RowMajorMatrixView::new_row(&*next_preprocessed), + ); + let local = main.row_slice(i); let next = main.row_slice(i_next); let main = VerticalPair::new( @@ -27,6 +38,7 @@ where let mut builder = DebugConstraintBuilder { row_index: i, + preprocessed, main, public_values, is_first_row: F::from_bool(i == 0), @@ -43,6 +55,7 @@ where #[derive(Debug)] pub struct DebugConstraintBuilder<'a, F: Field> { row_index: usize, + preprocessed: VerticalPair, RowMajorMatrixView<'a, F>>, main: VerticalPair, RowMajorMatrixView<'a, F>>, public_values: &'a [F], is_first_row: F, @@ -106,3 +119,9 @@ impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> self.public_values } } + +impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { + fn preprocessed(&self) -> Self::M { + self.preprocessed + } +} diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index d7536b046..ea1563254 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use p3_air::{AirBuilder, AirBuilderWithPublicValues}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::AbstractField; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; @@ -10,6 +10,7 @@ use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; #[derive(Debug)] pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub main: RowMajorMatrix>, + pub preprocessed: RowMajorMatrix>, pub public_values: &'a Vec>, pub is_first_row: PackedVal, pub is_last_row: PackedVal, @@ -23,6 +24,7 @@ type ViewPair<'a, T> = VerticalPair, RowMajorMatrixVie #[derive(Debug)] pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { pub main: ViewPair<'a, SC::Challenge>, + pub preprocessed: ViewPair<'a, SC::Challenge>, pub public_values: &'a Vec>, pub is_first_row: SC::Challenge, pub is_last_row: SC::Challenge, @@ -72,6 +74,12 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraint } } +impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { + fn preprocessed(&self) -> Self::M { + self.preprocessed.clone() + } +} + impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> { type F = Val; type Expr = SC::Challenge; @@ -112,3 +120,9 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstrai self.public_values } } + +impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { + fn preprocessed(&self) -> Self::M { + self.preprocessed + } +} diff --git a/uni-stark/src/proof.rs b/uni-stark/src/proof.rs index ef969cf72..79b8b96a5 100644 --- a/uni-stark/src/proof.rs +++ b/uni-stark/src/proof.rs @@ -13,6 +13,10 @@ type PcsProof = <::Pcs as Pcs< ::Challenge, ::Challenger, >>::Proof; +pub type PcsProverData = <::Pcs as Pcs< + ::Challenge, + ::Challenger, +>>::ProverData; #[derive(Serialize, Deserialize)] #[serde(bound = "")] @@ -31,7 +35,20 @@ pub struct Commitments { #[derive(Debug, Serialize, Deserialize)] pub struct OpenedValues { + pub(crate) preprocessed_local: Vec, + pub(crate) preprocessed_next: Vec, pub(crate) trace_local: Vec, pub(crate) trace_next: Vec, pub(crate) quotient_chunks: Vec>, } + +pub struct StarkProvingKey { + pub preprocessed_commit: Com, + pub preprocessed_data: PcsProverData, +} + +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +pub struct StarkVerifyingKey { + pub preprocessed_commit: Com, +} diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 5a3c4c644..5314f41ef 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -16,7 +16,7 @@ use tracing::{info_span, instrument}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; use crate::{ Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder, - StarkGenericConfig, Val, + StarkGenericConfig, StarkProvingKey, Val, }; #[instrument(skip_all)] @@ -32,17 +32,44 @@ pub fn prove< trace: RowMajorMatrix>, public_values: &Vec>, ) -> Proof +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, +{ + prove_with_key(config, None, air, challenger, trace, public_values) +} + +#[instrument(skip_all)] +#[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses? +pub fn prove_with_key< + SC, + #[cfg(debug_assertions)] A: for<'a> Air>>, + #[cfg(not(debug_assertions))] A, +>( + config: &SC, + proving_key: Option<&StarkProvingKey>, + air: &A, + challenger: &mut SC::Challenger, + trace: RowMajorMatrix>, + public_values: &Vec>, +) -> Proof where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, { #[cfg(debug_assertions)] - crate::check_constraints::check_constraints(air, &trace, public_values); + crate::check_constraints::check_constraints( + air, + &air.preprocessed_trace() + .unwrap_or(RowMajorMatrix::new(vec![], 0)), + &trace, + public_values, + ); let degree = trace.height(); let log_degree = log2_strict_usize(degree); - let log_quotient_degree = get_log_quotient_degree::, A>(air, 0, public_values.len()); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); let quotient_degree = 1 << log_quotient_degree; let pcs = config.pcs(); @@ -62,6 +89,10 @@ where let quotient_domain = trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree)); + let preprocessed_on_quotient_domain = proving_key.map(|proving_key| { + pcs.get_evaluations_on_domain(&proving_key.preprocessed_data, 0, quotient_domain) + }); + let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain); let quotient_values = quotient_values( @@ -69,6 +100,7 @@ where public_values, trace_domain, quotient_domain, + preprocessed_on_quotient_domain, trace_on_quotient_domain, alpha, ); @@ -89,22 +121,54 @@ where let zeta_next = trace_domain.next_point(zeta).unwrap(); let (opened_values, opening_proof) = pcs.open( - vec![ - (&trace_data, vec![vec![zeta, zeta_next]]), - ( - "ient_data, - // open every chunk at zeta - (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), - ), - ], + iter::empty() + .chain( + proving_key + .map(|proving_key| { + (&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]]) + }) + .into_iter(), + ) + .chain([ + (&trace_data, vec![vec![zeta, zeta_next]]), + ( + "ient_data, + // open every chunk at zeta + (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), + ), + ]) + .collect_vec(), challenger, ); - let trace_local = opened_values[0][0][0].clone(); - let trace_next = opened_values[0][0][1].clone(); - let quotient_chunks = opened_values[1].iter().map(|v| v[0].clone()).collect_vec(); + let mut opened_values = opened_values.iter(); + + // maybe get values for the preprocessed columns + let (preprocessed_local, preprocessed_next) = if proving_key.is_some() { + let value = opened_values.next().unwrap(); + assert_eq!(value.len(), 1); + assert_eq!(value[0].len(), 2); + (value[0][0].clone(), value[0][1].clone()) + } else { + (vec![], vec![]) + }; + + // get values for the trace + let value = opened_values.next().unwrap(); + assert_eq!(value.len(), 1); + assert_eq!(value[0].len(), 2); + let trace_local = value[0][0].clone(); + let trace_next = value[0][1].clone(); + + // get values for the quotient + let value = opened_values.next().unwrap(); + assert_eq!(value.len(), quotient_degree); + let quotient_chunks = value.iter().map(|v| v[0].clone()).collect_vec(); + let opened_values = OpenedValues { trace_local, trace_next, + preprocessed_local, + preprocessed_next, quotient_chunks, }; Proof { @@ -121,6 +185,7 @@ fn quotient_values( public_values: &Vec>, trace_domain: Domain, quotient_domain: Domain, + preprocessed_on_quotient_domain: Option, trace_on_quotient_domain: Mat, alpha: SC::Challenge, ) -> Vec @@ -130,6 +195,10 @@ where Mat: Matrix> + Sync, { let quotient_size = quotient_domain.size(); + let preprocessed_width = preprocessed_on_quotient_domain + .as_ref() + .map(Matrix::width) + .unwrap_or_default(); let width = trace_on_quotient_domain.width(); let mut sels = trace_domain.selectors_on_coset(quotient_domain); @@ -156,6 +225,23 @@ where let is_transition = *PackedVal::::from_slice(&sels.is_transition[i_range.clone()]); let inv_zeroifier = *PackedVal::::from_slice(&sels.inv_zeroifier[i_range.clone()]); + let preprocessed = RowMajorMatrix::new( + iter::empty() + .chain(preprocessed_on_quotient_domain.iter().flat_map( + |preprocessed_on_quotient_domain| { + preprocessed_on_quotient_domain.vertically_packed_row(i_start) + }, + )) + .chain(preprocessed_on_quotient_domain.iter().flat_map( + |preprocessed_on_quotient_domain| { + preprocessed_on_quotient_domain + .vertically_packed_row(i_start + next_step) + }, + )) + .collect_vec(), + preprocessed_width, + ); + let main = RowMajorMatrix::new( iter::empty() .chain(trace_on_quotient_domain.vertically_packed_row(i_start)) @@ -166,6 +252,7 @@ where let accumulator = PackedChallenge::::zero(); let mut folder = ProverConstraintFolder { + preprocessed, main, public_values, is_first_row, diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 1b106c60a..813f60032 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -12,18 +12,13 @@ use crate::symbolic_variable::SymbolicVariable; use crate::Entry; #[instrument(name = "infer log of constraint degree", skip_all)] -pub fn get_log_quotient_degree( - air: &A, - preprocessed_width: usize, - num_public_values: usize, -) -> usize +pub fn get_log_quotient_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: Air>, { // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. - let constraint_degree = - get_max_constraint_degree(air, preprocessed_width, num_public_values).max(2); + let constraint_degree = get_max_constraint_degree(air, num_public_values).max(2); // The quotient's actual degree is approximately (max_constraint_degree - 1) n, // where subtracting 1 comes from division by the zerofier. @@ -32,16 +27,12 @@ where } #[instrument(name = "infer constraint degree", skip_all, level = "debug")] -pub fn get_max_constraint_degree( - air: &A, - preprocessed_width: usize, - num_public_values: usize, -) -> usize +pub fn get_max_constraint_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: Air>, { - get_symbolic_constraints(air, preprocessed_width, num_public_values) + get_symbolic_constraints(air, num_public_values) .iter() .map(|c| c.degree_multiple()) .max() @@ -51,14 +42,14 @@ where #[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")] pub fn get_symbolic_constraints( air: &A, - preprocessed_width: usize, num_public_values: usize, ) -> Vec> where F: Field, A: Air>, { - let mut builder = SymbolicAirBuilder::new(preprocessed_width, air.width(), num_public_values); + let mut builder = + SymbolicAirBuilder::new(air.preprocessed_width(), air.width(), num_public_values); air.eval(&mut builder); builder.constraints() } diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 50cd68238..c0e64cdac 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -1,5 +1,6 @@ use alloc::vec; use alloc::vec::Vec; +use core::iter; use itertools::Itertools; use p3_air::{Air, BaseAir}; @@ -11,7 +12,9 @@ use p3_matrix::stack::VerticalPair; use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; -use crate::{PcsError, Proof, StarkGenericConfig, Val, VerifierConstraintFolder}; +use crate::{ + PcsError, Proof, StarkGenericConfig, StarkVerifyingKey, Val, VerifierConstraintFolder, +}; #[instrument(skip_all)] pub fn verify( @@ -21,6 +24,22 @@ pub fn verify( proof: &Proof, public_values: &Vec>, ) -> Result<(), VerificationError>> +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, +{ + verify_with_key(config, None, air, challenger, proof, public_values) +} + +#[instrument(skip_all)] +pub fn verify_with_key( + config: &SC, + verifying_key: Option<&StarkVerifyingKey>, + air: &A, + challenger: &mut SC::Challenger, + proof: &Proof, + public_values: &Vec>, +) -> Result<(), VerificationError>> where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, @@ -33,7 +52,7 @@ where } = proof; let degree = 1 << degree_bits; - let log_quotient_degree = get_log_quotient_degree::, A>(air, 0, public_values.len()); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); let quotient_degree = 1 << log_quotient_degree; let pcs = config.pcs(); @@ -71,26 +90,44 @@ where let zeta_next = trace_domain.next_point(zeta).unwrap(); pcs.verify( - vec![ - ( - commitments.trace.clone(), - vec![( - trace_domain, - vec![ - (zeta, opened_values.trace_local.clone()), - (zeta_next, opened_values.trace_next.clone()), - ], - )], - ), - ( - commitments.quotient_chunks.clone(), - quotient_chunks_domains - .iter() - .zip(&opened_values.quotient_chunks) - .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) - .collect_vec(), - ), - ], + iter::empty() + .chain( + verifying_key + .map(|verifying_key| { + ( + verifying_key.preprocessed_commit.clone(), + (vec![( + trace_domain, + vec![ + (zeta, opened_values.preprocessed_local.clone()), + (zeta_next, opened_values.preprocessed_next.clone()), + ], + )]), + ) + }) + .into_iter(), + ) + .chain([ + ( + commitments.trace.clone(), + vec![( + trace_domain, + vec![ + (zeta, opened_values.trace_local.clone()), + (zeta_next, opened_values.trace_next.clone()), + ], + )], + ), + ( + commitments.quotient_chunks.clone(), + quotient_chunks_domains + .iter() + .zip(&opened_values.quotient_chunks) + .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) + .collect_vec(), + ), + ]) + .collect_vec(), opening_proof, challenger, ) @@ -126,12 +163,18 @@ where let sels = trace_domain.selectors_at_point(zeta); + let preprocessed = VerticalPair::new( + RowMajorMatrixView::new_row(&opened_values.preprocessed_local), + RowMajorMatrixView::new_row(&opened_values.preprocessed_next), + ); + let main = VerticalPair::new( RowMajorMatrixView::new_row(&opened_values.trace_local), RowMajorMatrixView::new_row(&opened_values.trace_next), ); let mut folder = VerifierConstraintFolder { + preprocessed, main, public_values, is_first_row: sels.is_first_row,