Skip to content

Commit

Permalink
Implement dynamic proving, witgen always picks largest size
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Jul 11, 2024
1 parent c3ba01f commit 8906795
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 99 deletions.
168 changes: 106 additions & 62 deletions backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use std::{

use itertools::Itertools;
use powdr_ast::analyzed::Analyzed;
use powdr_executor::{
constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn},
witgen::WitgenCallback,
};
use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
use powdr_number::{DegreeType, FieldElement};
use serde::{Deserialize, Serialize};
use split::{machine_fixed_columns, machine_witness_columns};
Expand All @@ -24,14 +21,15 @@ mod split;
#[derive(Serialize, Deserialize)]
struct CompositeVerificationKey {
/// Verification key for each machine (if available, otherwise None), sorted by machine name.
verification_keys: Vec<Option<Vec<u8>>>,
verification_keys: Vec<Option<BTreeMap<usize, Vec<u8>>>>,
}

/// A composite proof that contains a proof for each machine separately, sorted by machine name.
#[derive(Serialize, Deserialize)]
struct CompositeProof {
    /// Proof bytes for each machine, sorted by machine name (parallel to `sizes`).
    proofs: Vec<Vec<u8>>,
    /// The size (number of rows) each machine was proven at, parallel to `proofs`.
    sizes: Vec<usize>,
}

pub(crate) struct CompositeBackendFactory<F: FieldElement, B: BackendFactory<F>> {
Expand Down Expand Up @@ -63,11 +61,6 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
unimplemented!();
}

// TODO: Handle multiple sizes.
let fixed = Arc::new(
get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?,
);

let pils = split::split_pil((*pil).clone());

// Read the setup once (if any) to pass to all backends.
Expand All @@ -89,41 +82,46 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
.into_iter()
.zip_eq(verification_keys.into_iter())
.map(|((machine_name, pil), verification_key)| {
// Set up readers for the setup and verification key
let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new);
let setup = setup_cursor.as_mut().map(|cursor| cursor as &mut dyn Read);
let pil = Arc::new(pil);
machine_fixed_columns(&fixed, &pil)
.into_iter()
.map(|(size, fixed)| {
let pil = Arc::new(set_size(&pil, size as DegreeType));
// Set up readers for the setup and verification key
let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new);
let setup = setup_cursor.as_mut().map(|cursor| cursor as &mut dyn Read);

let mut verification_key_cursor = verification_key.as_ref().map(Cursor::new);
let verification_key = verification_key_cursor
.as_mut()
.map(|cursor| cursor as &mut dyn Read);
let mut verification_key_cursor = verification_key
.as_ref()
.map(|keys| Cursor::new(keys.get(&size).unwrap()));
let verification_key = verification_key_cursor
.as_mut()
.map(|cursor| cursor as &mut dyn Read);

let pil = Arc::new(pil);
let output_dir = output_dir
.clone()
.map(|output_dir| output_dir.join(&machine_name));
if let Some(ref output_dir) = output_dir {
std::fs::create_dir_all(output_dir)?;
}
let fixed = Arc::new(
machine_fixed_columns(&fixed, &pil)
.into_iter()
.map(|(column_name, values)| (column_name, values.into()))
.collect(),
);
let backend = self.factory.create(
pil.clone(),
fixed,
output_dir,
setup,
verification_key,
// TODO: Handle verification_app_key
None,
backend_options.clone(),
);
backend.map(|backend| (machine_name.to_string(), MachineData { pil, backend }))
let output_dir = output_dir
.clone()
.map(|output_dir| output_dir.join(&machine_name));
if let Some(ref output_dir) = output_dir {
std::fs::create_dir_all(output_dir)?;
}
let fixed = Arc::new(fixed);
let backend = self.factory.create(
pil.clone(),
fixed,
output_dir,
setup,
verification_key,
// TODO: Handle verification_app_key
None,
backend_options.clone(),
);
backend.map(|backend| (size, MachineData { pil, backend }))
})
.collect::<Result<BTreeMap<_, _>, _>>()
.map(|backends| (machine_name, backends))
})
.collect::<Result<_, _>>()?;
.collect::<Result<BTreeMap<_, _>, _>>()?;

Ok(Box::new(CompositeBackend { machine_data }))
}

Expand All @@ -141,7 +139,20 @@ pub(crate) struct CompositeBackend<'a, F> {
/// Maps each machine name to the corresponding machine data
/// Note that it is essential that we use BTreeMap here to ensure that the machines are
/// deterministically ordered.
machine_data: BTreeMap<String, MachineData<'a, F>>,
machine_data: BTreeMap<String, BTreeMap<usize, MachineData<'a, F>>>,
}

/// Returns a copy of the given PIL with the degree of every symbol definition
/// overwritten to `degree`, specializing a variably-sized machine PIL to one
/// concrete size before it is handed to a backend.
fn set_size<F: Clone>(pil: &Analyzed<F>, degree: DegreeType) -> Analyzed<F> {
    let pil = pil.clone();
    let definitions = pil
        .definitions
        .into_iter()
        .map(|(name, (mut symbol, def))| {
            // Unconditionally overwrite the degree, regardless of any previous value.
            symbol.degree = Some(degree);
            (name, (symbol, def))
        })
        .collect();
    // `pil.definitions` was moved out above; the functional update below only
    // moves the remaining fields of the clone. NOTE(review): only `definitions`
    // is updated here — presumably no other `Analyzed` field caches the degree;
    // confirm if `Analyzed` gains degree-dependent fields.
    Analyzed { definitions, ..pil }
}

// TODO: This just forwards to the backend for now. In the future this should:
Expand All @@ -160,36 +171,61 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> {
unimplemented!();
}

let proof = CompositeProof {
proofs: self
.machine_data
.iter()
.map(|(machine, MachineData { pil, backend })| {
let witgen_callback = witgen_callback.clone().with_pil(pil.clone());
let (sizes, proofs) = self
.machine_data
.iter()
.map(|(machine, machine_data)| {
let any_pil = &machine_data.values().next().unwrap().pil;
let witness = machine_witness_columns(witness, any_pil, machine);
let size = witness
.iter()
.map(|(_, witness)| witness.len())
.unique()
.exactly_one()
.expect("All witness columns of a machine must have the same size");
let machine_data = machine_data
.get(&size)
.expect("Machine does not support the given size");
let witgen_callback = witgen_callback.clone().with_pil(machine_data.pil.clone());

log::info!("== Proving machine: {} (size {})", machine, pil.degree());
log::debug!("PIL:\n{}", pil);
log::info!("== Proving machine: {} (size {})", machine, size);
log::debug!("PIL:\n{}", machine_data.pil);

let witness = machine_witness_columns(witness, pil, machine);
machine_data
.backend
.prove(&witness, None, witgen_callback)
.map(|proof| (size, proof))
})
.collect::<Result<Vec<(usize, Vec<u8>)>, _>>()?
.into_iter()
.unzip();

backend.prove(&witness, None, witgen_callback)
})
.collect::<Result<_, _>>()?,
};
let proof = CompositeProof { sizes, proofs };
Ok(bincode::serialize(&proof).unwrap())
}

fn verify(&self, proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
let proof: CompositeProof = bincode::deserialize(proof).unwrap();
for (machine_data, machine_proof) in self.machine_data.values().zip_eq(proof.proofs) {
machine_data.backend.verify(&machine_proof, instances)?;
for (machine_data, (machine_proof, size)) in self
.machine_data
.values()
.zip_eq(proof.proofs.into_iter().zip_eq(proof.sizes))
{
machine_data
.get(&size)
.unwrap()
.backend
.verify(&machine_proof, instances)?;
}
Ok(())
}

fn export_setup(&self, output: &mut dyn io::Write) -> Result<(), Error> {
// All backends are the same, just pick the first
self.machine_data
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
Expand All @@ -203,10 +239,18 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> {
.machine_data
.values()
.map(|machine_data| {
let backend = machine_data.backend.as_ref();
let vk_bytes = backend.verification_key_bytes();
match vk_bytes {
Ok(vk_bytes) => Ok(Some(vk_bytes)),
let verification_keys = machine_data
.iter()
.map(|(size, machine_data)| {
machine_data
.backend
.verification_key_bytes()
.map(|vk_bytes| (*size, vk_bytes))
})
.collect::<Result<_, _>>();

match verification_keys {
Ok(verification_keys) => Ok(Some(verification_keys)),
Err(Error::NoVerificationAvailable) => Ok(None),
Err(e) => Err(e),
}
Expand Down
75 changes: 62 additions & 13 deletions backend/src/composite/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use powdr_ast::{
visitor::{ExpressionVisitable, VisitOrder},
},
};
use powdr_executor::constant_evaluator::VariablySizedColumn;
use powdr_number::FieldElement;

const DUMMY_COLUMN_NAME: &str = "__dummy";
Expand Down Expand Up @@ -43,40 +44,88 @@ pub(crate) fn machine_witness_columns<F: FieldElement>(
machine_pil: &Analyzed<F>,
machine_name: &str,
) -> Vec<(String, Vec<F>)> {
let machine_columns = select_machine_columns(
all_witness_columns,
machine_pil.committed_polys_in_source_order(),
)
.into_iter()
.cloned()
.collect::<Vec<_>>();
let size = machine_columns
.iter()
.map(|(_, column)| column.len())
.unique()
.exactly_one()
.unwrap_or_else(|err| {
if err.try_len().unwrap() == 0 {
// No witness column, use degree of provided PIL
machine_pil.degree() as usize
} else {
panic!("Machine {machine_name} has witness columns of different sizes")
}
});
let dummy_column_name = format!("{machine_name}.{DUMMY_COLUMN_NAME}");
let dummy_column = vec![F::zero(); machine_pil.degree() as usize];
let dummy_column = vec![F::zero(); size];
iter::once((dummy_column_name, dummy_column))
.chain(select_machine_columns(
all_witness_columns,
machine_pil.committed_polys_in_source_order(),
))
.chain(machine_columns)
.collect::<Vec<_>>()
}

/// Given a set of columns and a PIL describing the machine, returns the fixed column that belong to the machine.
pub(crate) fn machine_fixed_columns<F: FieldElement>(
all_fixed_columns: &[(String, Vec<F>)],
all_fixed_columns: &[(String, VariablySizedColumn<F>)],
machine_pil: &Analyzed<F>,
) -> Vec<(String, Vec<F>)> {
select_machine_columns(
) -> BTreeMap<usize, Vec<(String, VariablySizedColumn<F>)>> {
let machine_columns = select_machine_columns(
all_fixed_columns,
machine_pil.constant_polys_in_source_order(),
)
);
let sizes = machine_columns
.iter()
.map(|(_, column)| {
column
.column_by_size
.keys()
.cloned()
.collect::<BTreeSet<_>>()
})
.collect::<BTreeSet<_>>();

assert_eq!(
sizes.len(),
1,
"All fixed columns of a machine must have the same sizes"
);
let sizes = sizes.into_iter().next().unwrap();

sizes
.into_iter()
.map(|size| {
(
size,
machine_columns
.iter()
.map(|(name, column)| {
(name.clone(), column.column_by_size[&size].clone().into())
})
.collect::<Vec<_>>(),
)
})
.collect()
}

/// Filter the given columns to only include those that are referenced by the given symbols.
fn select_machine_columns<F: FieldElement, T>(
columns: &[(String, Vec<F>)],
fn select_machine_columns<'a, T, C>(
columns: &'a [(String, C)],
symbols: Vec<&(Symbol, T)>,
) -> Vec<(String, Vec<F>)> {
) -> Vec<&'a (String, C)> {
let names = symbols
.into_iter()
.flat_map(|(symbol, _)| symbol.array_elements().map(|(name, _)| name))
.collect::<BTreeSet<_>>();
columns
.iter()
.filter(|(name, _)| names.contains(name))
.cloned()
.collect::<Vec<_>>()
}

Expand Down
23 changes: 22 additions & 1 deletion executor/src/constant_evaluator/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::BTreeMap;

#[derive(Serialize, Deserialize)]
pub struct VariablySizedColumn<F> {
column_by_size: BTreeMap<usize, Vec<F>>,
pub column_by_size: BTreeMap<usize, Vec<F>>,
}

#[derive(Debug)]
Expand All @@ -28,6 +28,16 @@ pub fn get_uniquely_sized<F>(
.collect()
}

/// For each column, returns its name together with a reference to the values
/// at the largest available size.
///
/// # Panics
/// Panics if any column has no sizes at all (empty `column_by_size`).
pub fn get_max_sized<F>(column: &[(String, VariablySizedColumn<F>)]) -> Vec<(String, &Vec<F>)> {
    column
        .iter()
        .map(|(name, column)| {
            // BTreeMap iterates in ascending key order, so the last entry is
            // the largest size; this avoids a second lookup by key.
            let (_, values) = column
                .column_by_size
                .iter()
                .next_back()
                .expect("VariablySizedColumn must have at least one size");
            (name.clone(), values)
        })
        .collect()
}

pub fn get_uniquely_sized_cloned<F: Clone>(
column: &[(String, VariablySizedColumn<F>)],
) -> Result<Vec<(String, Vec<F>)>, HasMultipleSizesError> {
Expand All @@ -46,3 +56,14 @@ impl<F> From<Vec<F>> for VariablySizedColumn<F> {
}
}
}

impl<F> From<Vec<Vec<F>>> for VariablySizedColumn<F> {
fn from(columns: Vec<Vec<F>>) -> Self {
VariablySizedColumn {
column_by_size: columns
.into_iter()
.map(|column| (column.len(), column))
.collect(),
}
}
}
Loading

0 comments on commit 8906795

Please sign in to comment.