diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index e8eb63b25..ad0be7e86 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -3,9 +3,20 @@ use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::A use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; use powdr_number::{DegreeType, FieldElement}; +use serde::{Deserialize, Serialize}; +use split::select_machine_columns; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; +mod split; + +/// A composite proof that contains a proof for each machine separately. +#[derive(Serialize, Deserialize)] +struct CompositeProof { + /// Map from machine name to proof + proofs: BTreeMap>, +} + pub(crate) struct CompositeBackendFactory> { factory: B, _marker: PhantomData, @@ -35,18 +46,25 @@ impl> BackendFactory for CompositeBacke unimplemented!(); } - let backend_by_machine = ["main"] - .iter() - .map(|machine_name| { + let per_machine_data = split::split_pil((*pil).clone()) + .into_iter() + .map(|(machine_name, pil)| { + let pil = Arc::new(pil); let output_dir = output_dir .clone() - .map(|output_dir| output_dir.join(machine_name)); + .map(|output_dir| output_dir.join(&machine_name)); if let Some(ref output_dir) = output_dir { std::fs::create_dir_all(output_dir)?; } + let fixed = Arc::new(select_machine_columns( + &fixed, + pil.constant_polys_in_source_order() + .into_iter() + .map(|(symbol, _)| symbol), + )); let backend = self.factory.create( pil.clone(), - fixed.clone(), + fixed, output_dir, // TODO: Handle setup, verification_key, verification_app_key None, @@ -54,10 +72,12 @@ impl> BackendFactory for CompositeBacke None, backend_options.clone(), ); - backend.map(|backend| (machine_name.to_string(), backend)) + backend.map(|backend| (machine_name.to_string(), MachineData { pil, backend })) }) .collect::, _>>()?; - Ok(Box::new(CompositeBackend { backend_by_machine })) + Ok(Box::new(CompositeBackend { + machine_data: per_machine_data, + })) } fn generate_setup(&self, _size: DegreeType, _output: &mut dyn io::Write) -> Result<(), Error> { @@ -65,8 +85,13 @@ impl> BackendFactory for CompositeBacke } } -pub(crate) struct CompositeBackend<'a, F: FieldElement> { - backend_by_machine: BTreeMap + 'a>>, +struct MachineData<'a, F> { + pil: Arc>, + backend: Box + 'a>, +} + +pub(crate) struct CompositeBackend<'a, F> { + machine_data: BTreeMap>, } // TODO: This just forwards to the backend for now. In the future this should: @@ -81,17 +106,46 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> { prev_proof: Option, witgen_callback: WitgenCallback, ) -> Result { - self.backend_by_machine - .get("main") - .unwrap() - .prove(witness, prev_proof, witgen_callback) + if prev_proof.is_some() { + unimplemented!(); + } + + let proof = CompositeProof { + proofs: self + .machine_data + .iter() + .map(|(machine, MachineData { pil, backend })| { + let witgen_callback = witgen_callback.clone().with_pil(pil.clone()); + + log::info!("== Proving machine: {}", machine); + log::debug!("PIL:\n{}", pil); + + let witness = select_machine_columns( + witness, + pil.committed_polys_in_source_order() + .into_iter() + .map(|(symbol, _)| symbol), + ); + + backend + .prove(&witness, None, witgen_callback) + .map(|proof| (machine.clone(), proof)) + }) + .collect::>()?, + }; + Ok(serde_json::to_vec(&proof).unwrap()) } - fn verify(&self, _proof: &[u8], instances: &[Vec]) -> Result<(), Error> { - self.backend_by_machine - .get("main") - .unwrap() - .verify(_proof, instances) + fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { + let proof: CompositeProof = serde_json::from_slice(proof).unwrap(); + for (machine, machine_proof) in proof.proofs { + let machine_data = self + .machine_data + .get(&machine) + .ok_or_else(|| Error::BackendError(format!("Unknown machine: {machine}")))?; + machine_data.backend.verify(&machine_proof, instances)?; + } + Ok(()) } fn export_setup(&self, _output: &mut dyn io::Write) -> Result<(), Error> { diff --git a/backend/src/composite/split.rs b/backend/src/composite/split.rs new file mode 100644 index 000000000..e0cf7cdbf --- /dev/null +++ b/backend/src/composite/split.rs @@ -0,0 +1,187 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + ops::ControlFlow, + str::FromStr, +}; + +use powdr_ast::analyzed::SelectedExpressions; +use powdr_ast::{ + analyzed::{ + AlgebraicExpression, Analyzed, Identity, IdentityKind, StatementIdentifier, Symbol, + }, + parsed::{ + asm::{AbsoluteSymbolPath, SymbolPath}, + visitor::{ExpressionVisitable, VisitOrder}, + }, +}; +use powdr_number::FieldElement; + +/// Splits a PIL into multiple PILs, one for each "machine". +/// The rough algorithm is as follows: +/// 1. The PIL is split into namespaces +/// 2. Any lookups or permutations that reference multiple namespaces are removed. +pub(crate) fn split_pil(pil: Analyzed) -> BTreeMap> { + let statements_by_machine = split_by_namespace(&pil); + + statements_by_machine + .into_iter() + .filter_map(|(machine_name, statements)| { + build_machine_pil(pil.clone(), statements).map(|pil| (machine_name, pil)) + }) + .collect() +} + +/// Given a set of columns and a set of polynomial symbols, returns the columns that correspond to the symbols. +pub(crate) fn select_machine_columns<'a, F: FieldElement>( + columns: &[(String, Vec)], + symbols: impl Iterator, +) -> Vec<(String, Vec)> { + let names = symbols + .flat_map(|symbol| symbol.array_elements().map(|(name, _)| name)) + .collect::>(); + columns + .iter() + .filter(|(name, _)| names.contains(name)) + .cloned() + .collect::>() +} + +/// From a symbol name, get the namespace of the symbol. +fn extract_namespace(name: &str) -> String { + let mut namespace = AbsoluteSymbolPath::default().join(SymbolPath::from_str(name).unwrap()); + namespace.pop().unwrap(); + namespace.relative_to(&Default::default()).to_string() +} + +/// From an identity, get the namespaces of the symbols it references. +fn referenced_namespaces( + expression_visitable: &impl ExpressionVisitable>, +) -> BTreeSet { + let mut namespaces = BTreeSet::new(); + expression_visitable.visit_expressions( + &mut (|expr| { + match expr { + AlgebraicExpression::Reference(reference) => { + namespaces.insert(extract_namespace(&reference.name)); + } + AlgebraicExpression::PublicReference(_) => unimplemented!(), + AlgebraicExpression::Challenge(_) => {} + AlgebraicExpression::Number(_) => {} + AlgebraicExpression::BinaryOperation(_) => {} + AlgebraicExpression::UnaryOperation(_) => {} + } + ControlFlow::Continue::<()>(()) + }), + VisitOrder::Pre, + ); + + namespaces +} + +/// Organizes the PIL statements by namespace: +/// - Any definition or public declaration belongs to the namespace of the symbol. +/// - Lookups and permutations that reference multiple namespaces removed. +/// +/// Returns: +/// - statements_by_namespace: A map from namespace to the statements in that namespace. +fn split_by_namespace( + pil: &Analyzed, +) -> BTreeMap> { + pil.source_order + .iter() + // split, filtering out some statements + .filter_map(|statement| match &statement { + StatementIdentifier::Definition(name) + | StatementIdentifier::PublicDeclaration(name) => { + let namespace = extract_namespace(name); + // add `statement` to `namespace` + Some((namespace, statement)) + } + StatementIdentifier::Identity(i) => { + let identity = &pil.identities[*i]; + let namespaces = referenced_namespaces(identity); + + match namespaces.len() { + 0 => panic!("Identity references no namespace: {identity}"), + // add this identity to the only referenced namespace + 1 => Some((namespaces.into_iter().next().unwrap(), statement)), + _ => match identity.kind { + IdentityKind::Plookup | IdentityKind::Permutation => { + assert_eq!( + referenced_namespaces(&identity.left).len(), + 1, + "LHS of identity references multiple namespaces: {identity}" + ); + assert_eq!( + referenced_namespaces(&identity.right).len(), + 1, + "RHS of identity references multiple namespaces: {identity}" + ); + log::debug!("Skipping connecting identity: {identity}"); + None + } + _ => { + panic!("Identity references multiple namespaces: {identity}"); + } + }, + } + } + }) + // collect into a map + .fold(Default::default(), |mut acc, (namespace, statement)| { + acc.entry(namespace).or_default().push(statement.clone()); + acc + }) +} + +/// Given a PIL and a list of statements, returns a new PIL that only contains the +/// given subset of statements. +/// Returns None if there are no identities in the subset of statements. +fn build_machine_pil( + pil: Analyzed, + statements: Vec, +) -> Option> { + // TODO: After #1488 is fixed, we can implement this like so: + // let pil = Analyzed { + // source_order: statements, + // ..pil.clone() + // }; + // let parsed_string = powdr_parser::parse(None, &pil.to_string()).unwrap(); + // let pil = powdr_pil_analyzer::analyze_ast(parsed_string); + + // HACK: Replace unreferenced identities with 0 = 0, to avoid having to re-assign IDs. + let identities = statements + .iter() + .filter_map(|statement| match statement { + StatementIdentifier::Identity(i) => Some(*i as u64), + _ => None, + }) + .collect::>(); + if identities.is_empty() { + // This can happen if a hint references some std module, + // but the module is empty. + return None; + } + let identities = pil + .identities + .iter() + .enumerate() + .map(|(identity_index, identity)| { + if identities.contains(&(identity_index as u64)) { + identity.clone() + } else { + Identity::>>::from_polynomial_identity( + identity.id, + identity.source.clone(), + AlgebraicExpression::Number(F::zero()), + ) + } + }) + .collect(); + + Some(Analyzed { + source_order: statements, + identities, + ..pil + }) +} diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index 36dd0f381..6d8d72fb5 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -67,6 +67,10 @@ impl WitgenCallback { } } + pub fn with_pil(self, analyzed: Arc>) -> Self { + Self { analyzed, ..self } + } + /// Computes the next-stage witness, given the current witness and challenges. pub fn next_stage_witness( &self, diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 193ce19ba..4d6b8c121 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -74,13 +74,6 @@ pub fn gen_estark_proof(file_name: &str, inputs: Vec) { pipeline.clone().compute_proof().unwrap(); - // Also test composite backend: - pipeline - .clone() - .with_backend(powdr_backend::BackendType::EStarkStarkyComposite, None) - .compute_proof() - .unwrap(); - // Repeat the proof generation, but with an externally generated verification key // Verification Key