From d8f813bc2ec50e5738e8e0019b0fbd69400f389e Mon Sep 17 00:00:00 2001 From: Aaryamann Challani <43716372+rymnc@users.noreply.github.com> Date: Mon, 17 Jun 2024 13:43:09 +0530 Subject: [PATCH] chore(rln): refactor resource initialization (#260) * chore(rln): optimize into Lazy OnceCells * fix * fix: dont change duration * fix: increase duration? * chore: add backtrace * fix: remove plotter to avoid f64 range failure * fix: remove ci alteration * fix: use arc over witness calc * fix: remove more lifetimes * fix: benchmark correct fn call, not the getter * fix: bench config --- Cargo.lock | 20 ----- rln-cli/src/state.rs | 8 +- rln-wasm/src/lib.rs | 6 +- rln/Cargo.toml | 2 - rln/benches/circuit_deser_benchmark.rs | 10 +-- rln/benches/circuit_loading_benchmark.rs | 7 +- rln/src/circuit.rs | 105 +++++++++-------------- rln/src/ffi.rs | 8 +- rln/src/public.rs | 41 +++++---- rln/tests/ffi.rs | 2 +- rln/tests/protocol.rs | 12 +-- 11 files changed, 86 insertions(+), 135 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7953501d..8c8ffbc9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1460,25 +1460,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "include_dir" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18762faeff7122e89e0857b02f7ce6fcc0d101d5e9ad2ad7846cc01d61b7f19e" -dependencies = [ - "include_dir_macros", -] - -[[package]] -name = "include_dir_macros" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b139284b5cf57ecfa712bcc66950bb635b31aff41c188e8a4cfc758eca374a3f" -dependencies = [ - "proc-macro2", - "quote", -] - [[package]] name = "indenter" version = "0.3.3" @@ -2303,7 +2284,6 @@ dependencies = [ "cfg-if", "color-eyre", "criterion 0.4.0", - "include_dir", "num-bigint", "num-traits", "once_cell", diff --git a/rln-cli/src/state.rs b/rln-cli/src/state.rs index b0986d79..b3cb76b9 100644 --- a/rln-cli/src/state.rs +++ b/rln-cli/src/state.rs @@ -5,12 +5,12 @@ use std::fs::File; use crate::config::{Config, InnerConfig}; #[derive(Default)] -pub(crate) struct State<'a> { - pub rln: Option>, +pub(crate) struct State { + pub rln: Option, } -impl<'a> State<'a> { - pub(crate) fn load_state() -> Result> { +impl State { + pub(crate) fn load_state() -> Result { let config = Config::load_config()?; let rln = if let Some(InnerConfig { file, tree_height }) = config.inner { let resources = File::open(&file)?; diff --git a/rln-wasm/src/lib.rs b/rln-wasm/src/lib.rs index ee2ad3b0..fcb8d25c 100644 --- a/rln-wasm/src/lib.rs +++ b/rln-wasm/src/lib.rs @@ -19,7 +19,7 @@ pub fn init_panic_hook() { pub struct RLNWrapper { // The purpose of this wrapper is to hold a RLN instance with the 'static lifetime // because wasm_bindgen does not allow returning elements with lifetimes - instance: RLN<'static>, + instance: RLN, } // Macro to call methods with arbitrary amount of arguments, @@ -150,8 +150,8 @@ impl ProcessArg for Vec { } } -impl<'a> ProcessArg for *const RLN<'a> { - type ReturnType = &'a RLN<'a>; +impl ProcessArg for *const RLN { + type ReturnType = &'static RLN; fn process(self) -> Self::ReturnType { unsafe { &*self } } diff --git a/rln/Cargo.toml b/rln/Cargo.toml index ed7b905e..e6767258 100644 --- a/rln/Cargo.toml +++ b/rln/Cargo.toml @@ -58,8 +58,6 @@ utils = { package = "zerokit_utils", version = "=0.5.0", path = "../utils/", def serde_json = "=1.0.96" serde = { version = "=1.0.163", features = ["derive"] } -include_dir = "=0.7.3" - [dev-dependencies] sled = "=0.34.7" criterion = { version = "=0.4.0", features = ["html_reports"] } diff --git a/rln/benches/circuit_deser_benchmark.rs b/rln/benches/circuit_deser_benchmark.rs index 1311c9f3..c398b073 100644 --- a/rln/benches/circuit_deser_benchmark.rs +++ b/rln/benches/circuit_deser_benchmark.rs @@ -1,17 +1,15 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use rln::circuit::{vk_from_ark_serialized, RESOURCES_DIR, VK_FILENAME}; -use std::path::Path; +use rln::circuit::{vk_from_ark_serialized, VK_BYTES}; // Here we benchmark how long the deserialization of the // verifying_key takes, only testing the json => verifying_key conversion, // and skipping conversion from bytes => string => serde_json::Value pub fn vk_deserialize_benchmark(c: &mut Criterion) { - let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME)).unwrap(); - let vk = vk.contents(); + let vk = VK_BYTES; - c.bench_function("circuit::to_verifying_key", |b| { + c.bench_function("vk::vk_from_ark_serialized", |b| { b.iter(|| { - let _ = vk_from_ark_serialized(&vk); + let _ = vk_from_ark_serialized(vk); }) }); } diff --git a/rln/benches/circuit_loading_benchmark.rs b/rln/benches/circuit_loading_benchmark.rs index d671e86c..5ef2077b 100644 --- a/rln/benches/circuit_loading_benchmark.rs +++ b/rln/benches/circuit_loading_benchmark.rs @@ -1,11 +1,14 @@ use criterion::{criterion_group, criterion_main, Criterion}; +use rln::circuit::{zkey_from_raw, ZKEY_BYTES}; // Depending on the key type (enabled by the `--features arkzkey` flag) // the upload speed from the `rln_final.zkey` or `rln_final.arkzkey` file is calculated pub fn key_load_benchmark(c: &mut Criterion) { - c.bench_function("zkey::upload_from_folder", |b| { + let zkey = ZKEY_BYTES.to_vec(); + + c.bench_function("zkey::zkey_from_raw", |b| { b.iter(|| { - let _ = rln::circuit::zkey_from_folder(); + let _ = zkey_from_raw(&zkey); }) }); } diff --git a/rln/src/circuit.rs b/rln/src/circuit.rs index bcd2bb13..34f1d38e 100644 --- a/rln/src/circuit.rs +++ b/rln/src/circuit.rs @@ -13,33 +13,49 @@ use color_eyre::{Report, Result}; cfg_if! { if #[cfg(not(target_arch = "wasm32"))] { use ark_circom::{WitnessCalculator}; - use once_cell::sync::OnceCell; + use once_cell::sync::{Lazy}; use std::sync::Mutex; use wasmer::{Module, Store}; - use include_dir::{include_dir, Dir}; - use std::path::Path; + use std::sync::Arc; } } cfg_if! { if #[cfg(feature = "arkzkey")] { use ark_zkey::read_arkzkey_from_bytes; - const ARKZKEY_FILENAME: &str = "tree_height_20/rln_final.arkzkey"; - + const ARKZKEY_BYTES: &[u8] = include_bytes!("tree_height_20/rln_final.arkzkey"); } else { use std::io::Cursor; use ark_circom::read_zkey; } } -const ZKEY_FILENAME: &str = "tree_height_20/rln_final.zkey"; -pub const VK_FILENAME: &str = "tree_height_20/verification_key.arkvkey"; -const WASM_FILENAME: &str = "tree_height_20/rln.wasm"; +pub const ZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.zkey"); +pub const VK_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/verification_key.arkvkey"); +const WASM_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln.wasm"); -pub const TEST_TREE_HEIGHT: usize = 20; +#[cfg(not(target_arch = "wasm32"))] +static ZKEY: Lazy<(ProvingKey, ConstraintMatrices)> = Lazy::new(|| { + cfg_if! { + if #[cfg(feature = "arkzkey")] { + read_arkzkey_from_bytes(ARKZKEY_BYTES).expect("Failed to read arkzkey") + } else { + let mut reader = Cursor::new(ZKEY_BYTES); + read_zkey(&mut reader).expect("Failed to read zkey") + } + } +}); #[cfg(not(target_arch = "wasm32"))] -pub static RESOURCES_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/resources"); +static VK: Lazy> = + Lazy::new(|| vk_from_ark_serialized(VK_BYTES).expect("Failed to read vk")); + +#[cfg(not(target_arch = "wasm32"))] +static WITNESS_CALCULATOR: Lazy>> = Lazy::new(|| { + circom_from_raw(WASM_BYTES.to_vec()).expect("Failed to create witness calculator") +}); + +pub const TEST_TREE_HEIGHT: usize = 20; // The following types define the pairing friendly elliptic curve, the underlying finite fields and groups default to this module // Note that proofs are serialized assuming Fr to be 4x8 = 32 bytes in size. Hence, changing to a curve with different encoding will make proof verification to fail @@ -72,26 +88,8 @@ pub fn zkey_from_raw(zkey_data: &Vec) -> Result<(ProvingKey, Constrai // Loads the proving key #[cfg(not(target_arch = "wasm32"))] -pub fn zkey_from_folder() -> Result<(ProvingKey, ConstraintMatrices)> { - #[cfg(feature = "arkzkey")] - let zkey = RESOURCES_DIR.get_file(Path::new(ARKZKEY_FILENAME)); - #[cfg(not(feature = "arkzkey"))] - let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME)); - - if let Some(zkey) = zkey { - let proving_key_and_matrices = match () { - #[cfg(feature = "arkzkey")] - () => read_arkzkey_from_bytes(zkey.contents())?, - #[cfg(not(feature = "arkzkey"))] - () => { - let mut c = Cursor::new(zkey.contents()); - read_zkey(&mut c)? - } - }; - Ok(proving_key_and_matrices) - } else { - Err(Report::msg("No proving key found!")) - } +pub fn zkey_from_folder() -> &'static (ProvingKey, ConstraintMatrices) { + &ZKEY } // Loads the verification key from a bytes vector @@ -112,49 +110,24 @@ pub fn vk_from_raw(vk_data: &[u8], zkey_data: &Vec) -> Result Result> { - let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME)); - let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME)); - - let verifying_key: VerifyingKey; - if let Some(vk) = vk { - verifying_key = vk_from_ark_serialized(vk.contents())?; - Ok(verifying_key) - } else if let Some(_zkey) = zkey { - let (proving_key, _matrices) = zkey_from_folder()?; - verifying_key = proving_key.vk; - Ok(verifying_key) - } else { - Err(Report::msg("No proving/verification key found!")) - } +pub fn vk_from_folder() -> &'static VerifyingKey { + &VK } -#[cfg(not(target_arch = "wasm32"))] -static WITNESS_CALCULATOR: OnceCell> = OnceCell::new(); - // Initializes the witness calculator using a bytes vector #[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_raw(wasm_buffer: Vec) -> Result<&'static Mutex> { - WITNESS_CALCULATOR.get_or_try_init(|| { - let store = Store::default(); - let module = Module::new(&store, wasm_buffer)?; - let result = WitnessCalculator::from_module(module)?; - Ok::, Report>(Mutex::new(result)) - }) +pub fn circom_from_raw(wasm_buffer: Vec) -> Result>> { + let store = Store::default(); + let module = Module::new(&store, wasm_buffer)?; + let result = WitnessCalculator::from_module(module)?; + let wrapped = Mutex::new(result); + Ok(Arc::new(wrapped)) } // Initializes the witness calculator #[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_folder() -> Result<&'static Mutex> { - // We read the wasm file - let wasm = RESOURCES_DIR.get_file(Path::new(WASM_FILENAME)); - - if let Some(wasm) = wasm { - let wasm_buffer = wasm.contents(); - circom_from_raw(wasm_buffer.to_vec()) - } else { - Err(Report::msg("No wasm file found!")) - } +pub fn circom_from_folder() -> &'static Arc> { + &WITNESS_CALCULATOR } // Computes the verification key from a bytes vector containing pre-processed ark-serialized verification key @@ -167,7 +140,7 @@ pub fn vk_from_ark_serialized(data: &[u8]) -> Result> { // Checks verification key to be correct with respect to proving key #[cfg(not(target_arch = "wasm32"))] pub fn check_vk_from_zkey(verifying_key: VerifyingKey) -> Result<()> { - let (proving_key, _matrices) = zkey_from_folder()?; + let (proving_key, _matrices) = zkey_from_folder(); if proving_key.vk == verifying_key { Ok(()) } else { diff --git a/rln/src/ffi.rs b/rln/src/ffi.rs index 7c89d730..7d438e76 100644 --- a/rln/src/ffi.rs +++ b/rln/src/ffi.rs @@ -143,15 +143,15 @@ impl ProcessArg for *const Buffer { } } -impl<'a> ProcessArg for *const RLN<'a> { - type ReturnType = &'a RLN<'a>; +impl ProcessArg for *const RLN { + type ReturnType = &'static RLN; fn process(self) -> Self::ReturnType { unsafe { &*self } } } -impl<'a> ProcessArg for *mut RLN<'a> { - type ReturnType = &'a mut RLN<'a>; +impl ProcessArg for *mut RLN { + type ReturnType = &'static mut RLN; fn process(self) -> Self::ReturnType { unsafe { &mut *self } } diff --git a/rln/src/public.rs b/rln/src/public.rs index 5a64fbd1..6d2dfb92 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -22,6 +22,7 @@ cfg_if! { use ark_circom::WitnessCalculator; use serde_json::{json, Value}; use utils::{Hasher}; + use std::sync::Arc; use std::str::FromStr; } else { use std::marker::*; @@ -39,7 +40,7 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809"; /// It implements the methods required to update the internal Merkle Tree, generate and verify RLN ZK proofs. /// /// I/O is mostly done using writers and readers implementing `std::io::Write` and `std::io::Read`, respectively. -pub struct RLN<'a> { +pub struct RLN { proving_key: (ProvingKey, ConstraintMatrices), pub(crate) verification_key: VerifyingKey, pub(crate) tree: PoseidonTree, @@ -48,12 +49,12 @@ pub struct RLN<'a> { // contains a lifetime, a PhantomData is necessary to avoid a compiler // error since the lifetime is not being used #[cfg(not(target_arch = "wasm32"))] - pub(crate) witness_calculator: &'a Mutex, + pub(crate) witness_calculator: Arc>, #[cfg(target_arch = "wasm32")] - _marker: PhantomData<&'a ()>, + _marker: PhantomData<()>, } -impl RLN<'_> { +impl RLN { /// Creates a new RLN object by loading circuit resources from a folder. /// /// Input parameters are @@ -70,7 +71,7 @@ impl RLN<'_> { /// let mut rln = RLN::new(tree_height, input); /// ``` #[cfg(not(target_arch = "wasm32"))] - pub fn new(tree_height: usize, mut input_data: R) -> Result> { + pub fn new(tree_height: usize, mut input_data: R) -> Result { // We read input let mut input: Vec = Vec::new(); input_data.read_to_end(&mut input)?; @@ -78,10 +79,10 @@ impl RLN<'_> { let rln_config: Value = serde_json::from_str(&String::from_utf8(input)?)?; let tree_config = rln_config["tree_config"].to_string(); - let witness_calculator = circom_from_folder()?; - let proving_key = zkey_from_folder()?; + let witness_calculator = circom_from_folder(); + let proving_key = zkey_from_folder(); - let verification_key = vk_from_folder()?; + let verification_key = vk_from_folder(); let tree_config: ::Config = if tree_config.is_empty() { ::Config::default() @@ -97,9 +98,9 @@ impl RLN<'_> { )?; Ok(RLN { - witness_calculator, - proving_key, - verification_key, + witness_calculator: witness_calculator.to_owned(), + proving_key: proving_key.to_owned(), + verification_key: verification_key.to_owned(), tree, #[cfg(target_arch = "wasm32")] _marker: PhantomData, @@ -150,7 +151,7 @@ impl RLN<'_> { zkey_vec: Vec, vk_vec: Vec, mut tree_config_input: R, - ) -> Result> { + ) -> Result { #[cfg(not(target_arch = "wasm32"))] let witness_calculator = circom_from_raw(circom_vec)?; @@ -179,15 +180,13 @@ impl RLN<'_> { proving_key, verification_key, tree, + #[cfg(target_arch = "wasm32")] + _marker: PhantomData, }) } #[cfg(target_arch = "wasm32")] - pub fn new_with_params( - tree_height: usize, - zkey_vec: Vec, - vk_vec: Vec, - ) -> Result> { + pub fn new_with_params(tree_height: usize, zkey_vec: Vec, vk_vec: Vec) -> Result { #[cfg(not(target_arch = "wasm32"))] let witness_calculator = circom_from_raw(circom_vec)?; @@ -682,7 +681,7 @@ impl RLN<'_> { } */ - let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof proof.serialize_compressed(&mut output_data)?; @@ -805,7 +804,7 @@ impl RLN<'_> { let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?; let proof_values = proof_values_from_witness(&rln_witness)?; - let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long @@ -853,7 +852,7 @@ impl RLN<'_> { let (rln_witness, _) = deserialize_witness(&witness_byte)?; let proof_values = proof_values_from_witness(&rln_witness)?; - let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?; + let proof = generate_proof(&self.witness_calculator, &self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long @@ -1284,7 +1283,7 @@ impl RLN<'_> { } #[cfg(not(target_arch = "wasm32"))] -impl Default for RLN<'_> { +impl Default for RLN { fn default() -> Self { let tree_height = TEST_TREE_HEIGHT; let buffer = Cursor::new(json!({}).to_string()); diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index 1bdb520b..778d7534 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -16,7 +16,7 @@ mod test { const NO_OF_LEAVES: usize = 256; - fn create_rln_instance() -> &'static mut RLN<'static> { + fn create_rln_instance() -> &'static mut RLN { let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit(); let input_config = json!({}).to_string(); let input_buffer = &Buffer::from(input_config.as_bytes()); diff --git a/rln/tests/protocol.rs b/rln/tests/protocol.rs index 59698e28..9a5144ad 100644 --- a/rln/tests/protocol.rs +++ b/rln/tests/protocol.rs @@ -127,9 +127,9 @@ mod test { // We test a RLN proof generation and verification fn test_witness_from_json() { // We generate all relevant keys - let proving_key = zkey_from_folder().unwrap(); - let verification_key = vk_from_folder().unwrap(); - let builder = circom_from_folder().unwrap(); + let proving_key = zkey_from_folder(); + let verification_key = vk_from_folder(); + let builder = circom_from_folder(); // We compute witness from the json input let rln_witness = get_test_witness(); @@ -156,9 +156,9 @@ mod test { assert_eq!(rln_witness_deser, rln_witness); // We generate all relevant keys - let proving_key = zkey_from_folder().unwrap(); - let verification_key = vk_from_folder().unwrap(); - let builder = circom_from_folder().unwrap(); + let proving_key = zkey_from_folder(); + let verification_key = vk_from_folder(); + let builder = circom_from_folder(); // Let's generate a zkSNARK proof let proof = generate_proof(builder, &proving_key, &rln_witness_deser).unwrap();