Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fixed columns to uni-stark #1

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ pub trait BaseAir<F>: Sync {
fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
None
}

/// The number of preprocessed columns in this AIR
fn preprocessed_width(&self) -> usize {
0
}
}

/// An AIR that works with a particular `AirBuilder`.
Expand Down
25 changes: 22 additions & 3 deletions uni-stark/src/check_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use alloc::vec::Vec;

use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues};
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder};
use p3_field::Field;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::VerticalPair;
use p3_matrix::Matrix;
use tracing::instrument;

#[instrument(name = "check constraints", skip_all)]
pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &Vec<F>)
where
pub(crate) fn check_constraints<F, A>(
air: &A,
preprocessed: &RowMajorMatrix<F>,
main: &RowMajorMatrix<F>,
public_values: &Vec<F>,
) where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
Expand All @@ -18,6 +22,13 @@ where
(0..height).for_each(|i| {
let i_next = (i + 1) % height;

let local_preprocessed = preprocessed.row_slice(i);
let next_preprocessed = preprocessed.row_slice(i_next);
let preprocessed = VerticalPair::new(
RowMajorMatrixView::new_row(&*local_preprocessed),
RowMajorMatrixView::new_row(&*next_preprocessed),
);

let local = main.row_slice(i);
let next = main.row_slice(i_next);
let main = VerticalPair::new(
Expand All @@ -27,6 +38,7 @@ where

let mut builder = DebugConstraintBuilder {
row_index: i,
preprocessed,
main,
public_values,
is_first_row: F::from_bool(i == 0),
Expand All @@ -43,6 +55,7 @@ where
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field> {
row_index: usize,
preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
main: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
public_values: &'a [F],
is_first_row: F,
Expand Down Expand Up @@ -106,3 +119,9 @@ impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F>
self.public_values
}
}

impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> {
fn preprocessed(&self) -> Self::M {
self.preprocessed
}
}
16 changes: 15 additions & 1 deletion uni-stark/src/folder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use p3_air::{AirBuilder, AirBuilderWithPublicValues};
use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder};
use p3_field::AbstractField;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::VerticalPair;
Expand All @@ -10,6 +10,7 @@ use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val};
#[derive(Debug)]
pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: RowMajorMatrix<PackedVal<SC>>,
pub preprocessed: RowMajorMatrix<PackedVal<SC>>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: PackedVal<SC>,
pub is_last_row: PackedVal<SC>,
Expand All @@ -23,6 +24,7 @@ type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixVie
#[derive(Debug)]
pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: ViewPair<'a, SC::Challenge>,
pub preprocessed: ViewPair<'a, SC::Challenge>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: SC::Challenge,
pub is_last_row: SC::Challenge,
Expand Down Expand Up @@ -72,6 +74,12 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraint
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> {
fn preprocessed(&self) -> Self::M {
self.preprocessed.clone()
}
}

impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> {
type F = Val<SC>;
type Expr = SC::Challenge;
Expand Down Expand Up @@ -112,3 +120,9 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstrai
self.public_values
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> {
fn preprocessed(&self) -> Self::M {
self.preprocessed
}
}
17 changes: 17 additions & 0 deletions uni-stark/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type PcsProof<SC> = <<SC as StarkGenericConfig>::Pcs as Pcs<
<SC as StarkGenericConfig>::Challenge,
<SC as StarkGenericConfig>::Challenger,
>>::Proof;
pub type PcsProverData<SC> = <<SC as StarkGenericConfig>::Pcs as Pcs<
<SC as StarkGenericConfig>::Challenge,
<SC as StarkGenericConfig>::Challenger,
>>::ProverData;

#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
Expand All @@ -31,7 +35,20 @@ pub struct Commitments<Com> {

#[derive(Debug, Serialize, Deserialize)]
pub struct OpenedValues<Challenge> {
pub(crate) preprocessed_local: Vec<Challenge>,
pub(crate) preprocessed_next: Vec<Challenge>,
pub(crate) trace_local: Vec<Challenge>,
pub(crate) trace_next: Vec<Challenge>,
pub(crate) quotient_chunks: Vec<Vec<Challenge>>,
}

pub struct StarkProvingKey<SC: StarkGenericConfig> {
pub preprocessed_commit: Com<SC>,
pub preprocessed_data: PcsProverData<SC>,
}

#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct StarkVerifyingKey<SC: StarkGenericConfig> {
pub preprocessed_commit: Com<SC>,
}
114 changes: 100 additions & 14 deletions uni-stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tracing::{info_span, instrument};
use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder};
use crate::{
Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder,
StarkGenericConfig, Val,
StarkGenericConfig, StarkProvingKey, Val,
};

#[instrument(skip_all)]
Expand All @@ -32,17 +32,44 @@ pub fn prove<
trace: RowMajorMatrix<Val<SC>>,
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
{
prove_with_key(config, None, air, challenger, trace, public_values)
}

#[instrument(skip_all)]
#[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses?
pub fn prove_with_key<
SC,
#[cfg(debug_assertions)] A: for<'a> Air<crate::check_constraints::DebugConstraintBuilder<'a, Val<SC>>>,
#[cfg(not(debug_assertions))] A,
>(
config: &SC,
proving_key: Option<&StarkProvingKey<SC>>,
air: &A,
challenger: &mut SC::Challenger,
trace: RowMajorMatrix<Val<SC>>,
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
{
#[cfg(debug_assertions)]
crate::check_constraints::check_constraints(air, &trace, public_values);
crate::check_constraints::check_constraints(
air,
&air.preprocessed_trace()
.unwrap_or(RowMajorMatrix::new(vec![], 0)),
&trace,
public_values,
);

let degree = trace.height();
let log_degree = log2_strict_usize(degree);

let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, 0, public_values.len());
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, public_values.len());
let quotient_degree = 1 << log_quotient_degree;

let pcs = config.pcs();
Expand All @@ -55,20 +82,28 @@ where
challenger.observe(Val::<SC>::from_canonical_usize(log_degree));
// TODO: Might be best practice to include other instance data here; see verifier comment.

if let Some(proving_key) = proving_key {
challenger.observe(proving_key.preprocessed_commit.clone())
};
challenger.observe(trace_commit.clone());
challenger.observe_slice(public_values);
let alpha: SC::Challenge = challenger.sample_ext_element();

let quotient_domain =
trace_domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree));

let preprocessed_on_quotient_domain = proving_key.map(|proving_key| {
pcs.get_evaluations_on_domain(&proving_key.preprocessed_data, 0, quotient_domain)
});

let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain);

let quotient_values = quotient_values(
air,
public_values,
trace_domain,
quotient_domain,
preprocessed_on_quotient_domain,
trace_on_quotient_domain,
alpha,
);
Expand All @@ -89,22 +124,54 @@ where
let zeta_next = trace_domain.next_point(zeta).unwrap();

let (opened_values, opening_proof) = pcs.open(
vec![
(&trace_data, vec![vec![zeta, zeta_next]]),
(
&quotient_data,
// open every chunk at zeta
(0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
),
],
iter::empty()
.chain(
proving_key
.map(|proving_key| {
(&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]])
})
.into_iter(),
)
.chain([
(&trace_data, vec![vec![zeta, zeta_next]]),
(
&quotient_data,
// open every chunk at zeta
(0..quotient_degree).map(|_| vec![zeta]).collect_vec(),
),
])
.collect_vec(),
challenger,
);
let trace_local = opened_values[0][0][0].clone();
let trace_next = opened_values[0][0][1].clone();
let quotient_chunks = opened_values[1].iter().map(|v| v[0].clone()).collect_vec();
let mut opened_values = opened_values.iter();

// maybe get values for the preprocessed columns
let (preprocessed_local, preprocessed_next) = if proving_key.is_some() {
let value = opened_values.next().unwrap();
assert_eq!(value.len(), 1);
assert_eq!(value[0].len(), 2);
(value[0][0].clone(), value[0][1].clone())
} else {
(vec![], vec![])
};

// get values for the trace
let value = opened_values.next().unwrap();
assert_eq!(value.len(), 1);
assert_eq!(value[0].len(), 2);
let trace_local = value[0][0].clone();
let trace_next = value[0][1].clone();

// get values for the quotient
let value = opened_values.next().unwrap();
assert_eq!(value.len(), quotient_degree);
let quotient_chunks = value.iter().map(|v| v[0].clone()).collect_vec();

let opened_values = OpenedValues {
trace_local,
trace_next,
preprocessed_local,
preprocessed_next,
quotient_chunks,
};
Proof {
Expand All @@ -121,6 +188,7 @@ fn quotient_values<SC, A, Mat>(
public_values: &Vec<Val<SC>>,
trace_domain: Domain<SC>,
quotient_domain: Domain<SC>,
preprocessed_on_quotient_domain: Option<Mat>,
trace_on_quotient_domain: Mat,
alpha: SC::Challenge,
) -> Vec<SC::Challenge>
Expand All @@ -130,6 +198,10 @@ where
Mat: Matrix<Val<SC>> + Sync,
{
let quotient_size = quotient_domain.size();
let preprocessed_width = preprocessed_on_quotient_domain
.as_ref()
.map(Matrix::width)
.unwrap_or_default();
let width = trace_on_quotient_domain.width();
let mut sels = trace_domain.selectors_on_coset(quotient_domain);

Expand All @@ -156,6 +228,19 @@ where
let is_transition = *PackedVal::<SC>::from_slice(&sels.is_transition[i_range.clone()]);
let inv_zeroifier = *PackedVal::<SC>::from_slice(&sels.inv_zeroifier[i_range.clone()]);

let preprocessed = RowMajorMatrix::new(
preprocessed_on_quotient_domain
.as_ref()
.map(|on_quotient_domain| {
iter::empty()
.chain(on_quotient_domain.vertically_packed_row(i_start))
.chain(on_quotient_domain.vertically_packed_row(i_start + next_step))
.collect_vec()
})
.unwrap_or_default(),
preprocessed_width,
);

let main = RowMajorMatrix::new(
iter::empty()
.chain(trace_on_quotient_domain.vertically_packed_row(i_start))
Expand All @@ -166,6 +251,7 @@ where

let accumulator = PackedChallenge::<SC>::zero();
let mut folder = ProverConstraintFolder {
preprocessed,
main,
public_values,
is_first_row,
Expand Down
Loading
Loading