Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompositeBackend: Create proofs for each machine (working with Halo2) #1470

Merged
merged 19 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions backend/src/composite/merged_machines.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::collections::{BTreeMap, BTreeSet};

pub(crate) type MergedMachines = MergedMachinesImpl<String>;

/// A simple union-find data structure to keep track of which machines to merge.
pub(crate) struct MergedMachinesImpl<T: Clone + Ord> {
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
// Maps a machine ID to its "parent", i.e., the next hop towards
// the representative of the equivalence class.
// If a machine ID is not included in the map, it is a representative.
parent: BTreeMap<T, T>,
}

impl<T: Clone + Ord> MergedMachinesImpl<T> {
pub(crate) fn new() -> Self {
MergedMachinesImpl {
parent: BTreeMap::new(),
}
}

fn find(&self, machine: &T) -> T {
let mut current = machine;
while let Some(parent) = self.parent.get(current) {
current = parent;
}
current.clone()
}

pub(crate) fn merge(&mut self, machine1: T, machine2: T) {
let root1 = self.find(&machine1);
let root2 = self.find(&machine2);

if root1 != root2 {
self.parent.insert(root2, root1);
}
}

pub(crate) fn merged_machines(&self) -> BTreeSet<BTreeSet<T>> {
let mut groups: BTreeMap<T, BTreeSet<T>> = BTreeMap::new();
for machine in self.parent.keys() {
let root = self.find(machine);
groups
.entry(root.clone())
.or_default()
.insert(machine.clone());
// The root itself does not appear in the map, so we need to insert it manually.
groups.get_mut(&root).unwrap().insert(root);
}
groups.into_values().collect()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_merged_machines() {
let mut merged_machines = MergedMachinesImpl::new();
// Equivalence classes: {{1, 2, 3, 4}, {5, 6}}
merged_machines.merge(1, 2);
merged_machines.merge(3, 4);
merged_machines.merge(2, 3);
merged_machines.merge(1, 4);
merged_machines.merge(5, 6);

let merged_machines = merged_machines.merged_machines();
assert_eq!(merged_machines.len(), 2);
assert!(merged_machines.contains(&[1, 2, 3, 4].iter().cloned().collect()));
assert!(merged_machines.contains(&[5, 6].iter().cloned().collect()));
}
}
87 changes: 70 additions & 17 deletions backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,21 @@ use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::A
use powdr_ast::analyzed::Analyzed;
use powdr_executor::witgen::WitgenCallback;
use powdr_number::{DegreeType, FieldElement};
use serde::{Deserialize, Serialize};
use split::select_machine_columns;

use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};

mod merged_machines;
mod split;

/// A composite proof that contains a proof for each machine separately.
#[derive(Serialize, Deserialize)]
struct CompositeProof {
/// Map from machine name to proof
proofs: BTreeMap<String, Vec<u8>>,
}

pub(crate) struct CompositeBackendFactory<F: FieldElement, B: BackendFactory<F>> {
factory: B,
_marker: PhantomData<F>,
Expand Down Expand Up @@ -35,38 +47,50 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
unimplemented!();
}

let backend_by_machine = ["main"]
.iter()
.map(|machine_name| {
let per_machine_data = split::split_pil((*pil).clone())
.into_iter()
.map(|(machine_name, pil)| {
let pil = Arc::new(pil);
let output_dir = output_dir
.clone()
.map(|output_dir| output_dir.join(machine_name));
.map(|output_dir| output_dir.join(&machine_name));
if let Some(ref output_dir) = output_dir {
std::fs::create_dir_all(output_dir)?;
}
let fixed = Arc::new(select_machine_columns(
&fixed,
pil.constant_polys_in_source_order()
.into_iter()
.map(|(symbol, _)| symbol),
));
let backend = self.factory.create(
pil.clone(),
fixed.clone(),
fixed,
output_dir,
// TODO: Handle setup, verification_key, verification_app_key
None,
None,
None,
backend_options.clone(),
);
backend.map(|backend| (machine_name.to_string(), backend))
backend.map(|backend| (machine_name.to_string(), PerMachineData { pil, backend }))
})
.collect::<Result<BTreeMap<_, _>, _>>()?;
Ok(Box::new(CompositeBackend { backend_by_machine }))
Ok(Box::new(CompositeBackend { per_machine_data }))
}

fn generate_setup(&self, _size: DegreeType, _output: &mut dyn io::Write) -> Result<(), Error> {
Err(Error::NoSetupAvailable)
}
}

struct PerMachineData<'a, F: FieldElement> {
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
pil: Arc<Analyzed<F>>,
backend: Box<dyn Backend<'a, F> + 'a>,
}

pub(crate) struct CompositeBackend<'a, F: FieldElement> {
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
backend_by_machine: BTreeMap<String, Box<dyn Backend<'a, F> + 'a>>,
per_machine_data: BTreeMap<String, PerMachineData<'a, F>>,
georgwiese marked this conversation as resolved.
Show resolved Hide resolved
}

// TODO: This just forwards to the backend for now. In the future this should:
Expand All @@ -81,17 +105,46 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> {
prev_proof: Option<Proof>,
witgen_callback: WitgenCallback<F>,
) -> Result<Proof, Error> {
self.backend_by_machine
.get("main")
.unwrap()
.prove(witness, prev_proof, witgen_callback)
if prev_proof.is_some() {
unimplemented!();
}

let proof = CompositeProof {
proofs: self
.per_machine_data
.iter()
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
.map(|(machine, PerMachineData { pil, backend })| {
let witgen_callback = witgen_callback.clone().with_pil(pil.clone());

log::info!("== Proving machine: {}", machine);
log::debug!("PIL:\n{}", pil);

let witness = select_machine_columns(
witness,
pil.committed_polys_in_source_order()
.into_iter()
.map(|(symbol, _)| symbol),
);

backend
.prove(&witness, None, witgen_callback)
.map(|proof| (machine.clone(), proof))
})
.collect::<Result<_, _>>()?,
};
Ok(serde_json::to_vec(&proof).unwrap())
}

fn verify(&self, _proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
self.backend_by_machine
.get("main")
.unwrap()
.verify(_proof, instances)
fn verify(&self, proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
let proof: CompositeProof = serde_json::from_slice(proof).unwrap();
for (machine, machine_proof) in proof.proofs {
let machine_data = self
.per_machine_data
.get(&machine)
.ok_or_else(|| Error::BackendError(format!("Unknown machine: {machine}")))?;
machine_data.backend.verify(&machine_proof, instances)?;
}
Ok(())
}

fn export_setup(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
Expand Down
Loading
Loading