Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompositeBackend: Create proofs for each machine (working with Halo2) #1470

Merged
merged 19 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 71 additions & 17 deletions backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@ use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::A
use powdr_ast::analyzed::Analyzed;
use powdr_executor::witgen::WitgenCallback;
use powdr_number::{DegreeType, FieldElement};
use serde::{Deserialize, Serialize};
use split::select_machine_columns;

use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};

mod split;

/// A composite proof that contains a proof for each machine separately.
#[derive(Serialize, Deserialize)]
struct CompositeProof {
/// Map from machine name to proof
proofs: BTreeMap<String, Vec<u8>>,
}

pub(crate) struct CompositeBackendFactory<F: FieldElement, B: BackendFactory<F>> {
factory: B,
_marker: PhantomData<F>,
Expand Down Expand Up @@ -35,38 +46,52 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
unimplemented!();
}

let backend_by_machine = ["main"]
.iter()
.map(|machine_name| {
let per_machine_data = split::split_pil((*pil).clone())
.into_iter()
.map(|(machine_name, pil)| {
let pil = Arc::new(pil);
let output_dir = output_dir
.clone()
.map(|output_dir| output_dir.join(machine_name));
.map(|output_dir| output_dir.join(&machine_name));
if let Some(ref output_dir) = output_dir {
std::fs::create_dir_all(output_dir)?;
}
let fixed = Arc::new(select_machine_columns(
&fixed,
pil.constant_polys_in_source_order()
.into_iter()
.map(|(symbol, _)| symbol),
));
let backend = self.factory.create(
pil.clone(),
fixed.clone(),
fixed,
output_dir,
// TODO: Handle setup, verification_key, verification_app_key
None,
None,
None,
backend_options.clone(),
);
backend.map(|backend| (machine_name.to_string(), backend))
backend.map(|backend| (machine_name.to_string(), MachineData { pil, backend }))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;
Ok(Box::new(CompositeBackend { backend_by_machine }))
Ok(Box::new(CompositeBackend {
machine_data: per_machine_data,
}))
}

fn generate_setup(&self, _size: DegreeType, _output: &mut dyn io::Write) -> Result<(), Error> {
Err(Error::NoSetupAvailable)
}
}

struct MachineData<'a, F: FieldElement> {
pil: Arc<Analyzed<F>>,
backend: Box<dyn Backend<'a, F> + 'a>,
}

pub(crate) struct CompositeBackend<'a, F: FieldElement> {
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
backend_by_machine: BTreeMap<String, Box<dyn Backend<'a, F> + 'a>>,
machine_data: BTreeMap<String, MachineData<'a, F>>,
}

// TODO: This just forwards to the backend for now. In the future this should:
Expand All @@ -81,17 +106,46 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> {
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<F>,
) -> Result<Proof, Error> {
self.backend_by_machine
.get("main")
.unwrap()
.prove(witness, prev_proof, witgen_callback)
if prev_proof.is_some() {
unimplemented!();
}

let proof = CompositeProof {
proofs: self
.machine_data
.iter()
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
.map(|(machine, MachineData { pil, backend })| {
let witgen_callback = witgen_callback.clone().with_pil(pil.clone());

log::info!("== Proving machine: {}", machine);
log::debug!("PIL:\n{}", pil);

let witness = select_machine_columns(
witness,
pil.committed_polys_in_source_order()
.into_iter()
.map(|(symbol, _)| symbol),
);

backend
.prove(&witness, None, witgen_callback)
.map(|proof| (machine.clone(), proof))
})
.collect::<Result<_, _>>()?,
};
Ok(serde_json::to_vec(&proof).unwrap())
}

fn verify(&self, _proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
self.backend_by_machine
.get("main")
.unwrap()
.verify(_proof, instances)
fn verify(&self, proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
let proof: CompositeProof = serde_json::from_slice(proof).unwrap();
for (machine, machine_proof) in proof.proofs {
let machine_data = self
.machine_data
.get(&machine)
.ok_or_else(|| Error::BackendError(format!("Unknown machine: {machine}")))?;
machine_data.backend.verify(&machine_proof, instances)?;
}
Ok(())
}

fn export_setup(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
Expand Down
194 changes: 194 additions & 0 deletions backend/src/composite/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use std::{
collections::{BTreeMap, BTreeSet},
ops::ControlFlow,
str::FromStr,
};

use powdr_ast::analyzed::SelectedExpressions;
use powdr_ast::{
analyzed::{
AlgebraicExpression, Analyzed, Identity, IdentityKind, StatementIdentifier, Symbol,
},
parsed::{
asm::{AbsoluteSymbolPath, SymbolPath},
visitor::{ExpressionVisitable, VisitOrder},
},
};
use powdr_number::FieldElement;

/// Splits a PIL into multiple PILs, one for each "machine".
/// The rough algorithm is as follows:
/// 1. The PIL is split into namespaces
/// 2. Any lookups or permutations that reference multiple namespaces are removed.
pub(crate) fn split_pil<F: FieldElement>(pil: Analyzed<F>) -> BTreeMap<String, Analyzed<F>> {
let statements_by_machine = split_by_namespace(&pil);

statements_by_machine
.into_iter()
.filter_map(|(machine_name, statements)| {
build_machine_pil(pil.clone(), statements).map(|pil| (machine_name, pil))
})
.collect()
}

/// Given a set of columns and a set of symbols, returns the columns that correspond to the symbols.
pub(crate) fn select_machine_columns<'a, F: FieldElement>(
columns: &[(String, Vec<F>)],
symbols: impl Iterator<Item = &'a Symbol>,
) -> Vec<(String, Vec<F>)> {
let names = symbols
.flat_map(|symbol| symbol.array_elements().map(|(name, _)| name))
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
.collect::<BTreeSet<_>>();
columns
.iter()
.filter(|(name, _)| names.contains(name))
.cloned()
.collect::<Vec<_>>()
}

/// From a symbol name, get the namespace of the symbol.
fn extract_namespace(name: &str) -> String {
let mut namespace = AbsoluteSymbolPath::default().join(SymbolPath::from_str(name).unwrap());
namespace.pop().unwrap();
namespace.relative_to(&Default::default()).to_string()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this always correct? I don't have a counterexample, but wondering if there are cases that would break this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so! The code is inspired by our implementation of Display:

let name = namespace.pop().unwrap();
if namespace != current_namespace {
current_namespace = namespace;
writeln!(
f,
"namespace {}({degree});",
current_namespace.relative_to(&Default::default())
)?;
};

}

/// From an identity, get the namespaces of the symbols it references.
fn referenced_namespaces<F: FieldElement>(
expression_visitable: &impl ExpressionVisitable<AlgebraicExpression<F>>,
) -> BTreeSet<String> {
let mut namespaces = BTreeSet::new();
expression_visitable.visit_expressions(
&mut (|expr| {
match expr {
AlgebraicExpression::Reference(reference) => {
namespaces.insert(extract_namespace(&reference.name));
}
AlgebraicExpression::PublicReference(_) => unimplemented!(),
AlgebraicExpression::Challenge(_) => {}
AlgebraicExpression::Number(_) => {}
AlgebraicExpression::BinaryOperation(_) => {}
AlgebraicExpression::UnaryOperation(_) => {}
}
ControlFlow::Continue::<()>(())
}),
VisitOrder::Pre,
);

namespaces
}

/// Organizes the PIL statements by namespace:
/// - Any definition or public declaration belongs to the namespace of the symbol.
/// - Lookups and permutations that reference multiple namespaces removed.
///
/// Returns:
/// - statements_by_namespace: A map from namespace to the statements in that namespace.
fn split_by_namespace<F: FieldElement>(
pil: &Analyzed<F>,
) -> BTreeMap<String, Vec<StatementIdentifier>> {
let mut current_namespace = String::new();

let mut statements_by_namespace: BTreeMap<String, Vec<StatementIdentifier>> = BTreeMap::new();
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
for statement in pil.source_order.clone() {
let statement = match &statement {
StatementIdentifier::Definition(name)
| StatementIdentifier::PublicDeclaration(name) => {
let new_namespace = extract_namespace(name);
current_namespace = new_namespace;
Some(statement)
}
StatementIdentifier::Identity(i) => {
let identity = &pil.identities[*i];
let namespaces = referenced_namespaces(identity);

match namespaces.len() {
0 => panic!("Identity references no namespace: {identity}"),
1 => {
assert!(namespaces.iter().next().unwrap() == &current_namespace);
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
Some(statement)
}
_ => match identity.kind {
IdentityKind::Plookup | IdentityKind::Permutation => {
assert_eq!(
referenced_namespaces(&identity.left).len(),
1,
"LHS of identity references multiple namespaces: {identity}"
);
assert_eq!(
referenced_namespaces(&identity.right).len(),
1,
"RHS of identity references multiple namespaces: {identity}"
);
log::debug!("Skipping connecting identity: {identity}");
None
}
_ => {
panic!("Identity references multiple namespaces: {identity}");
}
},
}
}
};

if let Some(statement) = statement {
statements_by_namespace
.entry(current_namespace.clone())
.or_default()
.push(statement);
}
}
statements_by_namespace
}

/// Given a PIL and a list of statements, returns a new PIL that only contains the
/// given subset of statements.
/// Returns None if there are no identities in the subset of statements.
fn build_machine_pil<F: FieldElement>(
pil: Analyzed<F>,
statements: Vec<StatementIdentifier>,
) -> Option<Analyzed<F>> {
// TODO: After #1488 is fixed, we can implement this like so:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this PR going to wait for that, or should it be merged before?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be merged before!

// let pil = Analyzed {
// source_order: statements,
// ..pil.clone()
// };
// let parsed_string = powdr_parser::parse(None, &pil.to_string()).unwrap();
// let pil = powdr_pil_analyzer::analyze_ast(parsed_string);

// HACK: Replace unreferenced identities with 0 = 0, to avoid having to re-assign IDs.
leonardoalt marked this conversation as resolved.
Show resolved Hide resolved
let identities = statements
.iter()
.filter_map(|statement| match statement {
StatementIdentifier::Identity(i) => Some(*i as u64),
_ => None,
})
.collect::<BTreeSet<_>>();
if identities.is_empty() {
// This can happen if a hint references some std module,
// but the module is empty.
return None;
}
let identities = pil
.identities
.iter()
.enumerate()
.map(|(identity_index, identity)| {
if identities.contains(&(identity_index as u64)) {
identity.clone()
} else {
Identity::<SelectedExpressions<AlgebraicExpression<F>>>::from_polynomial_identity(
identity.id,
identity.source.clone(),
AlgebraicExpression::Number(F::zero()),
)
}
})
.collect();

Some(Analyzed {
source_order: statements,
identities,
..pil
})
}
7 changes: 7 additions & 0 deletions executor/src/witgen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ impl<T: FieldElement> WitgenCallback<T> {
}
}

pub fn with_pil(&self, analyzed: Arc<Analyzed<T>>) -> Self {
Self {
analyzed,
..self.clone()
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Computes the next-stage witness, given the current witness and challenges.
pub fn next_stage_witness(
&self,
Expand Down
7 changes: 0 additions & 7 deletions pipeline/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ pub fn gen_estark_proof(file_name: &str, inputs: Vec<GoldilocksField>) {

pipeline.clone().compute_proof().unwrap();

// Also test composite backend:
pipeline
.clone()
.with_backend(powdr_backend::BackendType::EStarkStarkyComposite, None)
.compute_proof()
.unwrap();

Comment on lines -77 to -83
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed for now, until the re-parsing is happening.

// Repeat the proof generation, but with an externally generated verification key

// Verification Key
Expand Down
Loading