Skip to content

Commit

Permalink
implement fixed columns in uni-stark, changed BaseAir
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff committed Jul 8, 2024
1 parent c3d754e commit 7c5c621
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 42 deletions.
8 changes: 6 additions & 2 deletions air/src/air.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use core::ops::{Add, Mul, Sub};

use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_matrix::{dense::RowMajorMatrix, Matrix};

/// An AIR (algebraic intermediate representation).
pub trait BaseAir<F>: Sync {
Expand All @@ -12,6 +11,11 @@ pub trait BaseAir<F>: Sync {
fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
None
}

/// The number of preprocessed columns in this AIR
fn preprocessed_width(&self) -> usize {
0
}
}

/// An AIR that works with a particular `AirBuilder`.
Expand Down
25 changes: 22 additions & 3 deletions uni-stark/src/check_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use alloc::vec::Vec;

use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues};
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder};
use p3_field::Field;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::VerticalPair;
use p3_matrix::Matrix;
use tracing::instrument;

#[instrument(name = "check constraints", skip_all)]
pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &Vec<F>)
where
pub(crate) fn check_constraints<F, A>(
air: &A,
preprocessed: &RowMajorMatrix<F>,
main: &RowMajorMatrix<F>,
public_values: &Vec<F>,
) where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
Expand All @@ -18,6 +22,13 @@ where
(0..height).for_each(|i| {
let i_next = (i + 1) % height;

let local_preprocessed = preprocessed.row_slice(i);
let next_preprocessed = preprocessed.row_slice(i_next);
let preprocessed = VerticalPair::new(
RowMajorMatrixView::new_row(&*local_preprocessed),
RowMajorMatrixView::new_row(&*next_preprocessed),
);

let local = main.row_slice(i);
let next = main.row_slice(i_next);
let main = VerticalPair::new(
Expand All @@ -27,6 +38,7 @@ where

let mut builder = DebugConstraintBuilder {
row_index: i,
preprocessed,
main,
public_values,
is_first_row: F::from_bool(i == 0),
Expand All @@ -43,6 +55,7 @@ where
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field> {
row_index: usize,
preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
main: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
public_values: &'a [F],
is_first_row: F,
Expand Down Expand Up @@ -106,3 +119,9 @@ impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F>
self.public_values
}
}

impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> {
fn preprocessed(&self) -> Self::M {
self.preprocessed
}
}
16 changes: 15 additions & 1 deletion uni-stark/src/folder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use p3_air::{AirBuilder, AirBuilderWithPublicValues};
use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder};
use p3_field::AbstractField;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::VerticalPair;
Expand All @@ -10,6 +10,7 @@ use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val};
#[derive(Debug)]
pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: RowMajorMatrix<PackedVal<SC>>,
pub preprocessed: RowMajorMatrix<PackedVal<SC>>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: PackedVal<SC>,
pub is_last_row: PackedVal<SC>,
Expand All @@ -23,6 +24,7 @@ type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixVie
#[derive(Debug)]
pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: ViewPair<'a, SC::Challenge>,
pub preprocessed: ViewPair<'a, SC::Challenge>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: SC::Challenge,
pub is_last_row: SC::Challenge,
Expand Down Expand Up @@ -72,6 +74,12 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraint
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> {
fn preprocessed(&self) -> Self::M {
self.preprocessed.clone()
}
}

impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> {
type F = Val<SC>;
type Expr = SC::Challenge;
Expand Down Expand Up @@ -112,3 +120,9 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstrai
self.public_values
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> {
fn preprocessed(&self) -> Self::M {
self.preprocessed
}
}
17 changes: 17 additions & 0 deletions uni-stark/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type PcsProof<SC> = <<SC as StarkGenericConfig>::Pcs as Pcs<
<SC as StarkGenericConfig>::Challenge,
<SC as StarkGenericConfig>::Challenger,
>>::Proof;
pub type PcsProverData<SC> = <<SC as StarkGenericConfig>::Pcs as Pcs<
<SC as StarkGenericConfig>::Challenge,
<SC as StarkGenericConfig>::Challenger,
>>::ProverData;

#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
Expand All @@ -31,7 +35,20 @@ pub struct Commitments<Com> {

#[derive(Debug, Serialize, Deserialize)]
pub struct OpenedValues<Challenge> {
pub(crate) preprocessed_local: Vec<Challenge>,
pub(crate) preprocessed_next: Vec<Challenge>,
pub(crate) trace_local: Vec<Challenge>,
pub(crate) trace_next: Vec<Challenge>,
pub(crate) quotient_chunks: Vec<Vec<Challenge>>,
}

pub struct StarkProvingKey<SC: StarkGenericConfig> {
pub preprocessed_commit: Com<SC>,
pub preprocessed_data: PcsProverData<SC>,
}

#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct StarkVerifyingKey<SC: StarkGenericConfig> {
pub preprocessed_commit: Com<SC>,
}
116 changes: 102 additions & 14 deletions uni-stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tracing::{info_span, instrument};
use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder};
use crate::{
Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder,
StarkGenericConfig, Val,
StarkGenericConfig, StarkProvingKey, Val,
};

#[instrument(skip_all)]
Expand All @@ -32,17 +32,45 @@ pub fn prove<
trace: RowMajorMatrix<Val<SC>>,
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
{
prove_with_key(config, None, air, challenger, trace, public_values)
}

#[instrument(skip_all)]
#[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses?
pub fn prove_with_key<
SC,
#[cfg(debug_assertions)] A: for<'a> Air<crate::check_constraints::DebugConstraintBuilder<'a, Val<SC>>>,
#[cfg(not(debug_assertions))] A,
>(
config: &SC,
proving_key: Option<&StarkProvingKey<SC>>,
air: &A,
challenger: &mut SC::Challenger,
trace: RowMajorMatrix<Val<SC>>,
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
{
#[cfg(debug_assertions)]
crate::check_constraints::check_constraints(air, &trace, public_values);
crate::check_constraints::check_constraints(
air,
&air.preprocessed_trace()
.unwrap_or(RowMajorMatrix::new(vec![], 0)),
&trace,
public_values,
);

let degree = trace.height();
let log_degree = log2_strict_usize(degree);

let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, 0, public_values.len());
let log_quotient_degree =
get_log_quotient_degree::<Val<SC>, A>(air, air.preprocessed_width(), public_values.len());
let quotient_degree = 1 << log_quotient_degree;

let pcs = config.pcs();
Expand All @@ -62,13 +90,18 @@ where
let quotient_domain =
trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree));

let preprocessed_on_quotient_domain = proving_key.map(|proving_key| {
pcs.get_evaluations_on_domain(&proving_key.preprocessed_data, 0, quotient_domain)
});

let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain);

let quotient_values = quotient_values(
air,
public_values,
trace_domain,
quotient_domain,
preprocessed_on_quotient_domain,
trace_on_quotient_domain,
alpha,
);
Expand All @@ -89,22 +122,54 @@ where
let zeta_next = trace_domain.next_point(zeta).unwrap();

let (opened_values, opening_proof) = pcs.open(
vec![
(&trace_data, vec![vec![zeta, zeta_next]]),
(
&quotient_data,
// open every chunk at zeta
(0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
),
],
iter::empty()
.chain(
proving_key
.map(|proving_key| {
(&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]])
})
.into_iter(),
)
.chain([
(&trace_data, vec![vec![zeta, zeta_next]]),
(
&quotient_data,
// open every chunk at zeta
(0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
),
])
.collect_vec(),
challenger,
);
let trace_local = opened_values[0][0][0].clone();
let trace_next = opened_values[0][0][1].clone();
let quotient_chunks = opened_values[1].iter().map(|v| v[0].clone()).collect_vec();
let mut opened_values = opened_values.iter();

// maybe get values for the preprocessed columns
let (preprocessed_local, preprocessed_next) = if proving_key.is_some() {
let value = opened_values.next().unwrap();
assert_eq!(value.len(), 1);
assert_eq!(value[0].len(), 2);
(value[0][0].clone(), value[0][1].clone())
} else {
(vec![], vec![])
};

// get values for the trace
let value = opened_values.next().unwrap();
assert_eq!(value.len(), 1);
assert_eq!(value[0].len(), 2);
let trace_local = value[0][0].clone();
let trace_next = value[0][1].clone();

// get values for the quotient
let value = opened_values.next().unwrap();
assert_eq!(value.len(), quotient_degree);
let quotient_chunks = value.iter().map(|v| v[0].clone()).collect_vec();

let opened_values = OpenedValues {
trace_local,
trace_next,
preprocessed_local,
preprocessed_next,
quotient_chunks,
};
Proof {
Expand All @@ -121,6 +186,7 @@ fn quotient_values<SC, A, Mat>(
public_values: &Vec<Val<SC>>,
trace_domain: Domain<SC>,
quotient_domain: Domain<SC>,
preprocessed_on_quotient_domain: Option<Mat>,
trace_on_quotient_domain: Mat,
alpha: SC::Challenge,
) -> Vec<SC::Challenge>
Expand All @@ -130,6 +196,10 @@ where
Mat: Matrix<Val<SC>> + Sync,
{
let quotient_size = quotient_domain.size();
let preprocessed_width = preprocessed_on_quotient_domain
.as_ref()
.map(Matrix::width)
.unwrap_or_default();
let width = trace_on_quotient_domain.width();
let mut sels = trace_domain.selectors_on_coset(quotient_domain);

Expand All @@ -156,6 +226,23 @@ where
let is_transition = *PackedVal::<SC>::from_slice(&sels.is_transition[i_range.clone()]);
let inv_zeroifier = *PackedVal::<SC>::from_slice(&sels.inv_zeroifier[i_range.clone()]);

let preprocessed = RowMajorMatrix::new(
iter::empty()
.chain(preprocessed_on_quotient_domain.iter().flat_map(
|preprocessed_on_quotient_domain| {
preprocessed_on_quotient_domain.vertically_packed_row(i_start)
},
))
.chain(preprocessed_on_quotient_domain.iter().flat_map(
|preprocessed_on_quotient_domain| {
preprocessed_on_quotient_domain
.vertically_packed_row(i_start + next_step)
},
))
.collect_vec(),
preprocessed_width,
);

let main = RowMajorMatrix::new(
iter::empty()
.chain(trace_on_quotient_domain.vertically_packed_row(i_start))
Expand All @@ -166,6 +253,7 @@ where

let accumulator = PackedChallenge::<SC>::zero();
let mut folder = ProverConstraintFolder {
preprocessed,
main,
public_values,
is_first_row,
Expand Down
Loading

0 comments on commit 7c5c621

Please sign in to comment.