Skip to content

Commit

Permalink
BackendFactory: Use Arc instead of references
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Jun 28, 2024
1 parent e7c87a8 commit f6c7c81
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 56 deletions.
10 changes: 5 additions & 5 deletions backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf};
use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::Arc};

use powdr_ast::analyzed::Analyzed;
use powdr_executor::witgen::WitgenCallback;
Expand All @@ -23,8 +23,8 @@ impl<F: FieldElement, B: BackendFactory<F>> CompositeBackendFactory<F, B> {
impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBackendFactory<F, B> {
fn create<'a>(
&self,
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
output_dir: Option<PathBuf>,
setup: Option<&mut dyn std::io::Read>,
verification_key: Option<&mut dyn std::io::Read>,
Expand All @@ -45,8 +45,8 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
std::fs::create_dir_all(output_dir)?;
}
let backend = self.factory.create(
pil,
fixed,
pil.clone(),
fixed.clone(),
output_dir,
// TODO: Handle setup, verification_key, verification_app_key
None,
Expand Down
9 changes: 5 additions & 4 deletions backend/src/estark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
fs::{hard_link, remove_file},
iter::{once, repeat},
path::{Path, PathBuf},
sync::Arc,
};

use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};
Expand Down Expand Up @@ -214,17 +215,17 @@ pub struct DumpFactory;
impl<F: FieldElement> BackendFactory<F> for DumpFactory {
fn create<'a>(
&self,
analyzed: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
analyzed: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
output_dir: Option<PathBuf>,
setup: Option<&mut dyn std::io::Read>,
verification_key: Option<&mut dyn std::io::Read>,
verification_app_key: Option<&mut dyn std::io::Read>,
options: BackendOptions,
) -> Result<Box<dyn crate::Backend<'a, F> + 'a>, Error> {
Ok(Box::new(DumpBackend(EStarkFilesCommon::create(
analyzed,
fixed,
&analyzed,
&fixed,
output_dir,
setup,
verification_key,
Expand Down
19 changes: 10 additions & 9 deletions backend/src/estark/starky_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use std::{borrow::Cow, io};

use crate::{Backend, BackendFactory, BackendOptions, Error};
use powdr_ast::analyzed::Analyzed;
Expand All @@ -25,8 +26,8 @@ pub struct Factory;
impl<F: FieldElement> BackendFactory<F> for Factory {
fn create<'a>(
&self,
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
_output_dir: Option<PathBuf>,
setup: Option<&mut dyn std::io::Read>,
verification_key: Option<&mut dyn std::io::Read>,
Expand All @@ -50,9 +51,9 @@ impl<F: FieldElement> BackendFactory<F> for Factory {

let params = create_stark_struct(pil.degree(), proof_type.hash_type());

let (pil_json, patched_fixed) = first_step_fixup(pil, fixed);
let (pil_json, patched_fixed) = first_step_fixup(&pil, &fixed);

let fixed = patched_fixed.map_or_else(|| Cow::Borrowed(fixed), Cow::Owned);
let fixed = patched_fixed.map_or_else(|| fixed.clone(), Arc::new);

let const_pols = to_starky_pols_array(&fixed, &pil_json, PolKind::Constant);

Expand Down Expand Up @@ -86,8 +87,8 @@ fn create_stark_setup(
.unwrap()
}

pub struct EStark<'a, F: FieldElement> {
fixed: Cow<'a, [(String, Vec<F>)]>,
pub struct EStark<F: FieldElement> {
fixed: Arc<Vec<(String, Vec<F>)>>,
pil_json: PIL,
params: StarkStruct,
// eSTARK calls it setup, but it works similarly to a verification key and depends only on the
Expand All @@ -96,7 +97,7 @@ pub struct EStark<'a, F: FieldElement> {
proof_type: ProofType,
}

impl<'a, F: FieldElement> EStark<'a, F> {
impl<F: FieldElement> EStark<F> {
fn verify_stark_gl_with_publics(
&self,
proof: &StarkProof<MerkleTreeGL>,
Expand Down Expand Up @@ -183,7 +184,7 @@ impl<'a, F: FieldElement> EStark<'a, F> {
}
}

impl<'a, F: FieldElement> Backend<'a, F> for EStark<'a, F> {
impl<'a, F: FieldElement> Backend<'a, F> for EStark<F> {
fn verify(&self, proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
match self.proof_type {
ProofType::StarkGL => {
Expand Down
16 changes: 9 additions & 7 deletions backend/src/halo2/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cmp::max, collections::BTreeMap, iter};
use std::{cmp::max, collections::BTreeMap, iter, sync::Arc};

use halo2_curves::ff::PrimeField;
use halo2_proofs::{
Expand Down Expand Up @@ -52,7 +52,7 @@ impl PowdrCircuitConfig {
#[derive(Clone)]
/// Wraps an Analyzed<T>. This is used as the PowdrCircuit::Params type, which required
/// a type that implements Default.
pub(crate) struct AnalyzedWrapper<T: FieldElement>(Analyzed<T>);
pub(crate) struct AnalyzedWrapper<T: FieldElement>(Arc<Analyzed<T>>);

impl<T> Default for AnalyzedWrapper<T>
where
Expand All @@ -65,16 +65,16 @@ where
}
}

impl<T: FieldElement> From<Analyzed<T>> for AnalyzedWrapper<T> {
fn from(analyzed: Analyzed<T>) -> Self {
impl<T: FieldElement> From<Arc<Analyzed<T>>> for AnalyzedWrapper<T> {
fn from(analyzed: Arc<Analyzed<T>>) -> Self {
Self(analyzed)
}
}

#[derive(Clone)]
pub(crate) struct PowdrCircuit<'a, T> {
/// The analyzed PIL
analyzed: &'a Analyzed<T>,
analyzed: Arc<Analyzed<T>>,
/// The value of the fixed columns
fixed: &'a [(String, Vec<T>)],
/// The value of the witness columns, if set
Expand Down Expand Up @@ -102,17 +102,19 @@ fn get_publics<T: FieldElement>(analyzed: &Analyzed<T>) -> Vec<(String, usize)>
}

impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
pub(crate) fn new(analyzed: &'a Analyzed<T>, fixed: &'a [(String, Vec<T>)]) -> Self {
pub(crate) fn new(analyzed: Arc<Analyzed<T>>, fixed: &'a [(String, Vec<T>)]) -> Self {
for (fixed_name, _) in fixed {
assert!(fixed_name != FIRST_STEP_NAME);
assert!(fixed_name != ENABLE_NAME);
}

let publics = get_publics(&analyzed);

Self {
analyzed,
fixed,
witness: None,
publics: get_publics(analyzed),
publics,
witgen_callback: None,
}
}
Expand Down
4 changes: 3 additions & 1 deletion backend/src/halo2/mock_prover.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use powdr_ast::analyzed::Analyzed;
use powdr_executor::witgen::WitgenCallback;

Expand All @@ -8,7 +10,7 @@ use powdr_number::{FieldElement, KnownField};

// Can't depend on compiler::pipeline::GeneratedWitness because of circular dependencies...
pub fn mock_prove<T: FieldElement>(
pil: &Analyzed<T>,
pil: Arc<Analyzed<T>>,
constants: &[(String, Vec<T>)],
witness: &[(String, Vec<T>)],
witgen_callback: WitgenCallback<T>,
Expand Down
21 changes: 11 additions & 10 deletions backend/src/halo2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::io;
use std::path::PathBuf;
use std::sync::Arc;

use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};
use powdr_ast::analyzed::Analyzed;
Expand Down Expand Up @@ -74,8 +75,8 @@ where
impl<F: FieldElement> BackendFactory<F> for Halo2ProverFactory {
fn create<'a>(
&self,
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
_output_dir: Option<PathBuf>,
setup: Option<&mut dyn io::Read>,
verification_key: Option<&mut dyn io::Read>,
Expand Down Expand Up @@ -108,7 +109,7 @@ fn fe_slice_to_string<F: FieldElement>(fe: &[F]) -> Vec<String> {
fe.iter().map(|x| x.to_string()).collect()
}

impl<'a, T: FieldElement> Backend<'a, T> for Halo2Prover<'a, T> {
impl<'a, T: FieldElement> Backend<'a, T> for Halo2Prover<T> {
fn verify(&self, proof: &[u8], instances: &[Vec<T>]) -> Result<(), Error> {
let proof: Halo2Proof = serde_json::from_slice(proof).unwrap();
// TODO should do a verification refactoring making it a 1d vec
Expand Down Expand Up @@ -180,8 +181,8 @@ pub(crate) struct Halo2MockFactory;
impl<F: FieldElement> BackendFactory<F> for Halo2MockFactory {
fn create<'a>(
&self,
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
_output_dir: Option<PathBuf>,
setup: Option<&mut dyn io::Read>,
verification_key: Option<&mut dyn io::Read>,
Expand All @@ -202,12 +203,12 @@ impl<F: FieldElement> BackendFactory<F> for Halo2MockFactory {
}
}

pub struct Halo2Mock<'a, F: FieldElement> {
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pub struct Halo2Mock<F: FieldElement> {
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
}

impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock<'a, T> {
impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock<T> {
fn prove(
&self,
witness: &[(String, Vec<T>)],
Expand All @@ -218,7 +219,7 @@ impl<'a, T: FieldElement> Backend<'a, T> for Halo2Mock<'a, T> {
return Err(Error::NoAggregationAvailable);
}

mock_prover::mock_prove(self.pil, self.fixed, witness, witgen_callback)
mock_prover::mock_prove(self.pil.clone(), &self.fixed, witness, witgen_callback)
.map_err(Error::BackendError)?;

Ok(vec![])
Expand Down
19 changes: 10 additions & 9 deletions backend/src/halo2/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use itertools::Itertools;
use rand::rngs::OsRng;
use std::{
io::{self, Cursor},
sync::Arc,
time::Instant,
};

Expand All @@ -58,9 +59,9 @@ use std::{
/// This only works with Bn254, so it really shouldn't be generic over the field
/// element, but without RFC #1210, the only alternative I found is a very ugly
/// "unsafe" code, and unsafe code is harder to explain and maintain.
pub struct Halo2Prover<'a, F> {
analyzed: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pub struct Halo2Prover<F> {
analyzed: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
params: ParamsKZG<Bn256>,
// Verification key of the proof type we're generating
vkey: Option<VerifyingKey<G1Affine>>,
Expand All @@ -82,10 +83,10 @@ pub fn generate_setup(size: DegreeType) -> ParamsKZG<Bn256> {
ParamsKZG::<Bn256>::new(std::cmp::max(4, degree_bits(size)))
}

impl<'a, F: FieldElement> Halo2Prover<'a, F> {
impl<F: FieldElement> Halo2Prover<F> {
pub fn new(
analyzed: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
analyzed: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
setup: Option<&mut dyn io::Read>,
proof_type: ProofType,
) -> Result<Self, io::Error> {
Expand Down Expand Up @@ -129,7 +130,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> {
) -> Result<(Vec<u8>, Vec<Vec<Fr>>), String> {
log::info!("Starting proof generation...");

let circuit = PowdrCircuit::new(self.analyzed, self.fixed)
let circuit = PowdrCircuit::new(self.analyzed.clone(), &self.fixed)
.with_witgen_callback(witgen_callback)
.with_witness(witness);
let publics = vec![circuit.instance_column()];
Expand Down Expand Up @@ -238,7 +239,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> {

log::info!("Generating circuit for app snark...");

let circuit_app = PowdrCircuit::new(self.analyzed, self.fixed)
let circuit_app = PowdrCircuit::new(self.analyzed.clone(), &self.fixed)
.with_witgen_callback(witgen_callback)
.with_witness(witness);

Expand Down Expand Up @@ -362,7 +363,7 @@ impl<'a, F: FieldElement> Halo2Prover<'a, F> {
}

fn generate_verification_key_single(&self) -> Result<VerifyingKey<G1Affine>, String> {
let circuit = PowdrCircuit::new(self.analyzed, self.fixed);
let circuit = PowdrCircuit::new(self.analyzed.clone(), &self.fixed);
keygen_vk(&self.params, &circuit).map_err(|e| e.to_string())
}

Expand Down
6 changes: 3 additions & 3 deletions backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod composite;
use powdr_ast::analyzed::Analyzed;
use powdr_executor::witgen::WitgenCallback;
use powdr_number::{DegreeType, FieldElement};
use std::{io, path::PathBuf};
use std::{io, path::PathBuf, sync::Arc};
use strum::{Display, EnumString, EnumVariantNames};

#[derive(Clone, EnumString, EnumVariantNames, Display, Copy)]
Expand Down Expand Up @@ -131,8 +131,8 @@ pub trait BackendFactory<F: FieldElement> {
#[allow(clippy::too_many_arguments)]
fn create<'a>(
&self,
pil: &'a Analyzed<F>,
fixed: &'a [(String, Vec<F>)],
pil: Arc<Analyzed<F>>,
fixed: Arc<Vec<(String, Vec<F>)>>,
output_dir: Option<PathBuf>,
setup: Option<&mut dyn io::Read>,
verification_key: Option<&mut dyn io::Read>,
Expand Down
16 changes: 8 additions & 8 deletions pipeline/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -911,8 +911,8 @@ impl<T: FieldElement> Pipeline<T> {
/* Create the backend */
let backend = factory
.create(
pil.borrow(),
&fixed_cols[..],
pil.clone(),
fixed_cols.clone(),
self.output_dir.clone(),
setup.as_io_read(),
vkey.as_io_read(),
Expand Down Expand Up @@ -993,8 +993,8 @@ impl<T: FieldElement> Pipeline<T> {

let backend = factory
.create(
pil.borrow(),
&fixed_cols[..],
pil.clone(),
fixed_cols.clone(),
self.output_dir.clone(),
setup_file
.as_mut()
Expand Down Expand Up @@ -1047,8 +1047,8 @@ impl<T: FieldElement> Pipeline<T> {

let backend = factory
.create(
pil.borrow(),
&fixed_cols[..],
pil.clone(),
fixed_cols.clone(),
self.output_dir.clone(),
setup_file
.as_mut()
Expand Down Expand Up @@ -1095,8 +1095,8 @@ impl<T: FieldElement> Pipeline<T> {

let backend = factory
.create(
pil.borrow(),
&fixed_cols[..],
pil.clone(),
fixed_cols.clone(),
self.output_dir.clone(),
setup_file
.as_mut()
Expand Down

0 comments on commit f6c7c81

Please sign in to comment.