From 75db53a0dcfa9e137dc9cb8a6145af6f96eed3d4 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Fri, 28 Jun 2024 09:25:01 +0200 Subject: [PATCH 1/2] BackendFactory: Use Arc instead of references --- backend/src/composite/mod.rs | 10 +++++----- backend/src/estark/mod.rs | 9 +++++---- backend/src/estark/starky_wrapper.rs | 19 ++++++++++--------- backend/src/halo2/circuit_builder.rs | 16 +++++++++------- backend/src/halo2/mock_prover.rs | 4 +++- backend/src/halo2/mod.rs | 21 +++++++++++---------- backend/src/halo2/prover.rs | 19 ++++++++++--------- backend/src/lib.rs | 6 +++--- pipeline/src/pipeline.rs | 16 ++++++++-------- 9 files changed, 64 insertions(+), 56 deletions(-) diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index aca2b98a8..e8eb63b25 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf}; +use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -23,8 +23,8 @@ impl> CompositeBackendFactory { impl> BackendFactory for CompositeBackendFactory { fn create<'a>( &self, - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], + pil: Arc>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -45,8 +45,8 @@ impl> BackendFactory for CompositeBacke std::fs::create_dir_all(output_dir)?; } let backend = self.factory.create( - pil, - fixed, + pil.clone(), + fixed.clone(), output_dir, // TODO: Handle setup, verification_key, verification_app_key None, diff --git a/backend/src/estark/mod.rs b/backend/src/estark/mod.rs index 420373b67..a638866dc 100644 --- a/backend/src/estark/mod.rs +++ b/backend/src/estark/mod.rs @@ -7,6 +7,7 @@ use std::{ fs::{hard_link, remove_file}, iter::{once, repeat}, path::{Path, PathBuf}, + sync::Arc, }; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; @@ -214,8 +215,8 @@ pub struct DumpFactory; impl BackendFactory for DumpFactory { fn create<'a>( &self, - analyzed: &'a Analyzed, - fixed: &'a [(String, Vec)], + analyzed: Arc>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -223,8 +224,8 @@ impl BackendFactory for DumpFactory { options: BackendOptions, ) -> Result + 'a>, Error> { Ok(Box::new(DumpBackend(EStarkFilesCommon::create( - analyzed, - fixed, + &analyzed, + &fixed, output_dir, setup, verification_key, diff --git a/backend/src/estark/starky_wrapper.rs b/backend/src/estark/starky_wrapper.rs index 4ca379229..c94c9908c 100644 --- a/backend/src/estark/starky_wrapper.rs +++ b/backend/src/estark/starky_wrapper.rs @@ -1,6 +1,7 @@ +use std::io; use std::path::PathBuf; +use std::sync::Arc; use std::time::Instant; -use std::{borrow::Cow, io}; use crate::{Backend, BackendFactory, BackendOptions, Error}; use powdr_ast::analyzed::Analyzed; @@ -25,8 +26,8 @@ pub struct Factory; impl BackendFactory for Factory { fn create<'a>( &self, - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], + pil: Arc>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -50,9 +51,9 @@ impl BackendFactory for Factory { let params = create_stark_struct(pil.degree(), proof_type.hash_type()); - let (pil_json, patched_fixed) = first_step_fixup(pil, fixed); + let (pil_json, patched_fixed) = first_step_fixup(&pil, &fixed); - let fixed = patched_fixed.map_or_else(|| Cow::Borrowed(fixed), Cow::Owned); + let fixed = patched_fixed.map_or_else(|| fixed.clone(), Arc::new); let const_pols = to_starky_pols_array(&fixed, &pil_json, PolKind::Constant); @@ -86,8 +87,8 @@ fn create_stark_setup( .unwrap() } -pub struct EStark<'a, F: FieldElement> { - fixed: Cow<'a, [(String, Vec)]>, +pub struct EStark { + fixed: Arc)>>, pil_json: PIL, params: StarkStruct, // eSTARK calls it setup, but it works similarly to a verification key and depends only on the @@ -96,7 +97,7 @@ pub struct EStark<'a, F: FieldElement> { proof_type: ProofType, } -impl<'a, F: FieldElement> EStark<'a, F> { +impl EStark { fn verify_stark_gl_with_publics( &self, proof: &StarkProof, @@ -183,7 +184,7 @@ impl<'a, F: FieldElement> EStark<'a, F> { } } -impl<'a, F: FieldElement> Backend<'a, F> for EStark<'a, F> { +impl<'a, F: FieldElement> Backend<'a, F> for EStark { fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { match self.proof_type { ProofType::StarkGL => { diff --git a/backend/src/halo2/circuit_builder.rs b/backend/src/halo2/circuit_builder.rs index f5023feed..bad2785ef 100644 --- a/backend/src/halo2/circuit_builder.rs +++ b/backend/src/halo2/circuit_builder.rs @@ -1,4 +1,4 @@ -use std::{cmp::max, collections::BTreeMap, iter}; +use std::{cmp::max, collections::BTreeMap, iter, sync::Arc}; use halo2_curves::ff::PrimeField; use halo2_proofs::{ @@ -52,7 +52,7 @@ impl PowdrCircuitConfig { #[derive(Clone)] /// Wraps an Analyzed. This is used as the PowdrCircuit::Params type, which required /// a type that implements Default. -pub(crate) struct AnalyzedWrapper(Analyzed); +pub(crate) struct AnalyzedWrapper(Arc>); impl Default for AnalyzedWrapper where @@ -65,8 +65,8 @@ where } } -impl From> for AnalyzedWrapper { - fn from(analyzed: Analyzed) -> Self { +impl From>> for AnalyzedWrapper { + fn from(analyzed: Arc>) -> Self { Self(analyzed) } } @@ -74,7 +74,7 @@ impl From> for AnalyzedWrapper { #[derive(Clone)] pub(crate) struct PowdrCircuit<'a, T> { /// The analyzed PIL - analyzed: &'a Analyzed, + analyzed: Arc>, /// The value of the fixed columns fixed: &'a [(String, Vec)], /// The value of the witness columns, if set @@ -102,17 +102,19 @@ fn get_publics(analyzed: &Analyzed) -> Vec<(String, usize)> } impl<'a, T: FieldElement> PowdrCircuit<'a, T> { - pub(crate) fn new(analyzed: &'a Analyzed, fixed: &'a [(String, Vec)]) -> Self { + pub(crate) fn new(analyzed: Arc>, fixed: &'a [(String, Vec)]) -> Self { for (fixed_name, _) in fixed { assert!(fixed_name != FIRST_STEP_NAME); assert!(fixed_name != ENABLE_NAME); } + let publics = get_publics(&analyzed); + Self { analyzed, fixed, witness: None, - publics: get_publics(analyzed), + publics, witgen_callback: None, } } diff --git a/backend/src/halo2/mock_prover.rs b/backend/src/halo2/mock_prover.rs index 65e2e206d..3e6f41ffb 100644 --- a/backend/src/halo2/mock_prover.rs +++ b/backend/src/halo2/mock_prover.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -8,7 +10,7 @@ use powdr_number::{FieldElement, KnownField}; // Can't depend on compiler::pipeline::GeneratedWitness because of circular dependencies... pub fn mock_prove( - pil: &Analyzed, + pil: Arc>, constants: &[(String, Vec)], witness: &[(String, Vec)], witgen_callback: WitgenCallback, diff --git a/backend/src/halo2/mod.rs b/backend/src/halo2/mod.rs index bc30c2e6c..b548112c9 100644 --- a/backend/src/halo2/mod.rs +++ b/backend/src/halo2/mod.rs @@ -2,6 +2,7 @@ use std::io; use std::path::PathBuf; +use std::sync::Arc; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; use powdr_ast::analyzed::Analyzed; @@ -74,8 +75,8 @@ where impl BackendFactory for Halo2ProverFactory { fn create<'a>( &self, - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], + pil: Arc>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -108,7 +109,7 @@ fn fe_slice_to_string(fe: &[F]) -> Vec { fe.iter().map(|x| x.to_string()).collect() } -impl<'a, T: FieldElement> Backend<'a, T> for Halo2Prover<'a, T> { +impl<'a, T: FieldElement> Backend<'a, T> for Halo2Prover { fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { let proof: Halo2Proof = serde_json::from_slice(proof).unwrap(); // TODO should do a verification refactoring making it a 1d vec @@ -180,8 +181,8 @@ pub(crate) struct Halo2MockFactory; impl BackendFactory for Halo2MockFactory { fn create<'a>( &self, - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], + pil: Arc>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -202,12 +203,12 @@ impl BackendFactory for Halo2MockFactory { } } -pub struct Halo2Mock<'a, F: FieldElement> { - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], +pub struct Halo2Mock { + pil: Arc>, + fixed: Arc)>>, } -impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock<'a, T> { +impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock { fn prove( &self, witness: &[(String, Vec)], @@ -218,7 +219,7 @@ impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock<'a, T> { return Err(Error::NoAggregationAvailable); } - mock_prover::mock_prove(self.pil, self.fixed, witness, witgen_callback) + mock_prover::mock_prove(self.pil.clone(), &self.fixed, witness, witgen_callback) .map_err(Error::BackendError)?; Ok(vec![]) diff --git a/backend/src/halo2/prover.rs b/backend/src/halo2/prover.rs index 2830c3f62..153888040 100644 --- a/backend/src/halo2/prover.rs +++ b/backend/src/halo2/prover.rs @@ -48,6 +48,7 @@ use itertools::Itertools; use rand::rngs::OsRng; use std::{ io::{self, Cursor}, + sync::Arc, time::Instant, }; @@ -58,9 +59,9 @@ use std::{ /// This only works with Bn254, so it really shouldn't be generic over the field /// element, but without RFC #1210, the only alternative I found is a very ugly /// "unsafe" code, and unsafe code is harder to explain and maintain. -pub struct Halo2Prover<'a, F> { - analyzed: &'a Analyzed, - fixed: &'a [(String, Vec)], +pub struct Halo2Prover { + analyzed: Arc>, + fixed: Arc)>>, params: ParamsKZG, // Verification key of the proof type we're generating vkey: Option>, @@ -82,10 +83,10 @@ pub fn generate_setup(size: DegreeType) -> ParamsKZG { ParamsKZG::::new(std::cmp::max(4, degree_bits(size))) } -impl<'a, F: FieldElement> Halo2Prover<'a, F> { +impl Halo2Prover { pub fn new( - analyzed: &'a Analyzed, - fixed: &'a [(String, Vec)], + analyzed: Arc>, + fixed: Arc)>>, setup: Option<&mut dyn io::Read>, proof_type: ProofType, ) -> Result { @@ -129,7 +130,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> { ) -> Result<(Vec, Vec>), String> { log::info!("Starting proof generation..."); - let circuit = PowdrCircuit::new(self.analyzed, self.fixed) + let circuit = PowdrCircuit::new(self.analyzed.clone(), &self.fixed) .with_witgen_callback(witgen_callback) .with_witness(witness); let publics = vec![circuit.instance_column()]; @@ -238,7 +239,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> { log::info!("Generating circuit for app snark..."); - let circuit_app = PowdrCircuit::new(self.analyzed, self.fixed) + let circuit_app = PowdrCircuit::new(self.analyzed.clone(), &self.fixed) .with_witgen_callback(witgen_callback) .with_witness(witness); @@ -362,7 +363,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> { } fn generate_verification_key_single(&self) -> Result, String> { - let circuit = PowdrCircuit::new(self.analyzed, self.fixed); + let circuit = PowdrCircuit::new(self.analyzed.clone(), &self.fixed); keygen_vk(&self.params, &circuit).map_err(|e| e.to_string()) } diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 09d1f7dc8..9901fd438 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -11,7 +11,7 @@ mod composite; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; use powdr_number::{DegreeType, FieldElement}; -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, sync::Arc}; use strum::{Display, EnumString, EnumVariantNames}; #[derive(Clone, EnumString, EnumVariantNames, Display, Copy)] @@ -131,8 +131,8 @@ pub trait BackendFactory { #[allow(clippy::too_many_arguments)] fn create<'a>( &self, - pil: &'a Analyzed, - fixed: &'a [(String, Vec)], + pil: Arc>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index a079f6a98..509ac6cb6 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -911,8 +911,8 @@ impl Pipeline { /* Create the backend */ let backend = factory .create( - pil.borrow(), - &fixed_cols[..], + pil.clone(), + fixed_cols.clone(), self.output_dir.clone(), setup.as_io_read(), vkey.as_io_read(), @@ -993,8 +993,8 @@ impl Pipeline { let backend = factory .create( - pil.borrow(), - &fixed_cols[..], + pil.clone(), + fixed_cols.clone(), self.output_dir.clone(), setup_file .as_mut() @@ -1047,8 +1047,8 @@ impl Pipeline { let backend = factory .create( - pil.borrow(), - &fixed_cols[..], + pil.clone(), + fixed_cols.clone(), self.output_dir.clone(), setup_file .as_mut() @@ -1095,8 +1095,8 @@ impl Pipeline { let backend = factory .create( - pil.borrow(), - &fixed_cols[..], + pil.clone(), + fixed_cols.clone(), self.output_dir.clone(), setup_file .as_mut() From 093f808403bd157c54e8f469bb91b358bdbce16a Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Fri, 28 Jun 2024 09:39:39 +0200 Subject: [PATCH 2/2] Fix more backends --- backend/src/estark/polygon_wrapper.rs | 10 +++++----- backend/src/plonky3/mod.rs | 8 ++++---- plonky3/src/prover/mod.rs | 18 ++++++++++-------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/backend/src/estark/polygon_wrapper.rs b/backend/src/estark/polygon_wrapper.rs index 14b2ddd5f..51e5fc3e6 100644 --- a/backend/src/estark/polygon_wrapper.rs +++ b/backend/src/estark/polygon_wrapper.rs @@ -1,4 +1,4 @@ -use std::{fs, path::PathBuf}; +use std::{fs, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -13,8 +13,8 @@ pub struct Factory; impl BackendFactory for Factory { fn create<'a>( &self, - analyzed: &'a Analyzed, - fixed: &'a [(String, Vec)], + analyzed: Arc>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -22,8 +22,8 @@ impl BackendFactory for Factory { options: BackendOptions, ) -> Result + 'a>, Error> { Ok(Box::new(PolygonBackend(EStarkFilesCommon::create( - analyzed, - fixed, + &analyzed, + &fixed, output_dir, setup, verification_key, diff --git a/backend/src/plonky3/mod.rs b/backend/src/plonky3/mod.rs index 0bd4c35b5..3c24048b7 100644 --- a/backend/src/plonky3/mod.rs +++ b/backend/src/plonky3/mod.rs @@ -1,4 +1,4 @@ -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -12,8 +12,8 @@ pub(crate) struct Factory; impl BackendFactory for Factory { fn create<'a>( &self, - pil: &'a Analyzed, - _fixed: &'a [(String, Vec)], + pil: Arc>, + _fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -33,7 +33,7 @@ impl BackendFactory for Factory { } } -impl<'a, T: FieldElement> Backend<'a, T> for Plonky3Prover<'a, T> { +impl<'a, T: FieldElement> Backend<'a, T> for Plonky3Prover { fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { Ok(self.verify(proof, instances)?) } diff --git a/plonky3/src/prover/mod.rs b/plonky3/src/prover/mod.rs index 1c8bece98..ea0c7ad3f 100644 --- a/plonky3/src/prover/mod.rs +++ b/plonky3/src/prover/mod.rs @@ -2,6 +2,8 @@ mod params; +use std::sync::Arc; + use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -14,18 +16,18 @@ use crate::circuit_builder::{cast_to_goldilocks, PowdrCircuit}; use self::params::{get_challenger, get_config}; #[derive(Clone)] -pub struct Plonky3Prover<'a, T> { +pub struct Plonky3Prover { /// The analyzed PIL - analyzed: &'a Analyzed, + analyzed: Arc>, } -impl<'a, T> Plonky3Prover<'a, T> { - pub fn new(analyzed: &'a Analyzed) -> Self { +impl Plonky3Prover { + pub fn new(analyzed: Arc>) -> Self { Self { analyzed } } } -impl<'a, T: FieldElement> Plonky3Prover<'a, T> { +impl Plonky3Prover { pub fn prove( &self, witness: &[(String, Vec)], @@ -33,7 +35,7 @@ impl<'a, T: FieldElement> Plonky3Prover<'a, T> { ) -> Result, String> { assert_eq!(T::known_field(), Some(KnownField::GoldilocksField)); - let circuit = PowdrCircuit::new(self.analyzed) + let circuit = PowdrCircuit::new(&self.analyzed) .with_witgen_callback(witgen_callback) .with_witness(witness); @@ -68,7 +70,7 @@ impl<'a, T: FieldElement> Plonky3Prover<'a, T> { verify( &config, - &PowdrCircuit::new(self.analyzed), + &PowdrCircuit::new(&self.analyzed), &mut challenger, &proof, &publics, @@ -96,7 +98,7 @@ mod tests { let witness_callback = pipeline.witgen_callback().unwrap(); let witness = pipeline.compute_witness().unwrap(); - let prover = Plonky3Prover::new(&pil); + let prover = Plonky3Prover::new(pil); let proof = prover.prove(&witness, witness_callback); assert!(proof.is_ok());