Skip to content

Commit

Permalink
chore(rln): refactor resource initialization (#260)
Browse files Browse the repository at this point in the history
* chore(rln): optimize into Lazy OnceCells

* fix

* fix: dont change duration

* fix: increase duration?

* chore: add backtrace

* fix: remove plotter to avoid f64 range failure

* fix: remove ci alteration

* fix: use arc over witness calc

* fix: remove more lifetimes

* fix: benchmark correct fn call, not the getter

* fix: bench config
  • Loading branch information
rymnc committed Jun 17, 2024
1 parent c6493bd commit d8f813b
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 135 deletions.
20 changes: 0 additions & 20 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions rln-cli/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use std::fs::File;
use crate::config::{Config, InnerConfig};

#[derive(Default)]
pub(crate) struct State<'a> {
pub rln: Option<RLN<'a>>,
pub(crate) struct State {
pub rln: Option<RLN>,
}

impl<'a> State<'a> {
pub(crate) fn load_state() -> Result<State<'a>> {
impl State {
pub(crate) fn load_state() -> Result<State> {
let config = Config::load_config()?;
let rln = if let Some(InnerConfig { file, tree_height }) = config.inner {
let resources = File::open(&file)?;
Expand Down
6 changes: 3 additions & 3 deletions rln-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn init_panic_hook() {
pub struct RLNWrapper {
// The purpose of this wrapper is to hold a RLN instance with the 'static lifetime
// because wasm_bindgen does not allow returning elements with lifetimes
instance: RLN<'static>,
instance: RLN,
}

// Macro to call methods with arbitrary amount of arguments,
Expand Down Expand Up @@ -150,8 +150,8 @@ impl<T> ProcessArg for Vec<T> {
}
}

impl<'a> ProcessArg for *const RLN<'a> {
type ReturnType = &'a RLN<'a>;
impl ProcessArg for *const RLN {
type ReturnType = &'static RLN;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}
Expand Down
2 changes: 0 additions & 2 deletions rln/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ utils = { package = "zerokit_utils", version = "=0.5.0", path = "../utils/", def
serde_json = "=1.0.96"
serde = { version = "=1.0.163", features = ["derive"] }

include_dir = "=0.7.3"

[dev-dependencies]
sled = "=0.34.7"
criterion = { version = "=0.4.0", features = ["html_reports"] }
Expand Down
10 changes: 4 additions & 6 deletions rln/benches/circuit_deser_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use criterion::{criterion_group, criterion_main, Criterion};
use rln::circuit::{vk_from_ark_serialized, RESOURCES_DIR, VK_FILENAME};
use std::path::Path;
use rln::circuit::{vk_from_ark_serialized, VK_BYTES};

// Here we benchmark how long the deserialization of the
// verifying_key takes, only testing the json => verifying_key conversion,
// and skipping conversion from bytes => string => serde_json::Value
pub fn vk_deserialize_benchmark(c: &mut Criterion) {
let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME)).unwrap();
let vk = vk.contents();
let vk = VK_BYTES;

c.bench_function("circuit::to_verifying_key", |b| {
c.bench_function("vk::vk_from_ark_serialized", |b| {
b.iter(|| {
let _ = vk_from_ark_serialized(&vk);
let _ = vk_from_ark_serialized(vk);
})
});
}
Expand Down
7 changes: 5 additions & 2 deletions rln/benches/circuit_loading_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use criterion::{criterion_group, criterion_main, Criterion};
use rln::circuit::{zkey_from_raw, ZKEY_BYTES};

// Depending on the key type (enabled by the `--features arkzkey` flag)
// the upload speed from the `rln_final.zkey` or `rln_final.arkzkey` file is calculated
pub fn key_load_benchmark(c: &mut Criterion) {
c.bench_function("zkey::upload_from_folder", |b| {
let zkey = ZKEY_BYTES.to_vec();

c.bench_function("zkey::zkey_from_raw", |b| {
b.iter(|| {
let _ = rln::circuit::zkey_from_folder();
let _ = zkey_from_raw(&zkey);
})
});
}
Expand Down
105 changes: 39 additions & 66 deletions rln/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,49 @@ use color_eyre::{Report, Result};
cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
use ark_circom::{WitnessCalculator};
use once_cell::sync::OnceCell;
use once_cell::sync::{Lazy};
use std::sync::Mutex;
use wasmer::{Module, Store};
use include_dir::{include_dir, Dir};
use std::path::Path;
use std::sync::Arc;
}
}

cfg_if! {
if #[cfg(feature = "arkzkey")] {
use ark_zkey::read_arkzkey_from_bytes;
const ARKZKEY_FILENAME: &str = "tree_height_20/rln_final.arkzkey";

const ARKZKEY_BYTES: &[u8] = include_bytes!("tree_height_20/rln_final.arkzkey");
} else {
use std::io::Cursor;
use ark_circom::read_zkey;
}
}

const ZKEY_FILENAME: &str = "tree_height_20/rln_final.zkey";
pub const VK_FILENAME: &str = "tree_height_20/verification_key.arkvkey";
const WASM_FILENAME: &str = "tree_height_20/rln.wasm";
pub const ZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.zkey");
pub const VK_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/verification_key.arkvkey");
const WASM_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln.wasm");

pub const TEST_TREE_HEIGHT: usize = 20;
#[cfg(not(target_arch = "wasm32"))]
static ZKEY: Lazy<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> = Lazy::new(|| {
cfg_if! {
if #[cfg(feature = "arkzkey")] {
read_arkzkey_from_bytes(ARKZKEY_BYTES).expect("Failed to read arkzkey")
} else {
let mut reader = Cursor::new(ZKEY_BYTES);
read_zkey(&mut reader).expect("Failed to read zkey")
}
}
});

#[cfg(not(target_arch = "wasm32"))]
pub static RESOURCES_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/resources");
static VK: Lazy<VerifyingKey<Curve>> =
Lazy::new(|| vk_from_ark_serialized(VK_BYTES).expect("Failed to read vk"));

#[cfg(not(target_arch = "wasm32"))]
static WITNESS_CALCULATOR: Lazy<Arc<Mutex<WitnessCalculator>>> = Lazy::new(|| {
circom_from_raw(WASM_BYTES.to_vec()).expect("Failed to create witness calculator")
});

pub const TEST_TREE_HEIGHT: usize = 20;

// The following types define the pairing friendly elliptic curve, the underlying finite fields and groups default to this module
// Note that proofs are serialized assuming Fr to be 4x8 = 32 bytes in size. Hence, changing to a curve with different encoding will make proof verification to fail
Expand Down Expand Up @@ -72,26 +88,8 @@ pub fn zkey_from_raw(zkey_data: &Vec<u8>) -> Result<(ProvingKey<Curve>, Constrai

// Loads the proving key
#[cfg(not(target_arch = "wasm32"))]
pub fn zkey_from_folder() -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> {
#[cfg(feature = "arkzkey")]
let zkey = RESOURCES_DIR.get_file(Path::new(ARKZKEY_FILENAME));
#[cfg(not(feature = "arkzkey"))]
let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME));

if let Some(zkey) = zkey {
let proving_key_and_matrices = match () {
#[cfg(feature = "arkzkey")]
() => read_arkzkey_from_bytes(zkey.contents())?,
#[cfg(not(feature = "arkzkey"))]
() => {
let mut c = Cursor::new(zkey.contents());
read_zkey(&mut c)?
}
};
Ok(proving_key_and_matrices)
} else {
Err(Report::msg("No proving key found!"))
}
pub fn zkey_from_folder() -> &'static (ProvingKey<Curve>, ConstraintMatrices<Fr>) {
&ZKEY
}

// Loads the verification key from a bytes vector
Expand All @@ -112,49 +110,24 @@ pub fn vk_from_raw(vk_data: &[u8], zkey_data: &Vec<u8>) -> Result<VerifyingKey<C

// Loads the verification key
#[cfg(not(target_arch = "wasm32"))]
pub fn vk_from_folder() -> Result<VerifyingKey<Curve>> {
let vk = RESOURCES_DIR.get_file(Path::new(VK_FILENAME));
let zkey = RESOURCES_DIR.get_file(Path::new(ZKEY_FILENAME));

let verifying_key: VerifyingKey<Curve>;
if let Some(vk) = vk {
verifying_key = vk_from_ark_serialized(vk.contents())?;
Ok(verifying_key)
} else if let Some(_zkey) = zkey {
let (proving_key, _matrices) = zkey_from_folder()?;
verifying_key = proving_key.vk;
Ok(verifying_key)
} else {
Err(Report::msg("No proving/verification key found!"))
}
pub fn vk_from_folder() -> &'static VerifyingKey<Curve> {
&VK
}

#[cfg(not(target_arch = "wasm32"))]
static WITNESS_CALCULATOR: OnceCell<Mutex<WitnessCalculator>> = OnceCell::new();

// Initializes the witness calculator using a bytes vector
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> Result<&'static Mutex<WitnessCalculator>> {
WITNESS_CALCULATOR.get_or_try_init(|| {
let store = Store::default();
let module = Module::new(&store, wasm_buffer)?;
let result = WitnessCalculator::from_module(module)?;
Ok::<Mutex<WitnessCalculator>, Report>(Mutex::new(result))
})
pub fn circom_from_raw(wasm_buffer: Vec<u8>) -> Result<Arc<Mutex<WitnessCalculator>>> {
let store = Store::default();
let module = Module::new(&store, wasm_buffer)?;
let result = WitnessCalculator::from_module(module)?;
let wrapped = Mutex::new(result);
Ok(Arc::new(wrapped))
}

// Initializes the witness calculator
#[cfg(not(target_arch = "wasm32"))]
pub fn circom_from_folder() -> Result<&'static Mutex<WitnessCalculator>> {
// We read the wasm file
let wasm = RESOURCES_DIR.get_file(Path::new(WASM_FILENAME));

if let Some(wasm) = wasm {
let wasm_buffer = wasm.contents();
circom_from_raw(wasm_buffer.to_vec())
} else {
Err(Report::msg("No wasm file found!"))
}
pub fn circom_from_folder() -> &'static Arc<Mutex<WitnessCalculator>> {
&WITNESS_CALCULATOR
}

// Computes the verification key from a bytes vector containing pre-processed ark-serialized verification key
Expand All @@ -167,7 +140,7 @@ pub fn vk_from_ark_serialized(data: &[u8]) -> Result<VerifyingKey<Curve>> {
// Checks verification key to be correct with respect to proving key
#[cfg(not(target_arch = "wasm32"))]
pub fn check_vk_from_zkey(verifying_key: VerifyingKey<Curve>) -> Result<()> {
let (proving_key, _matrices) = zkey_from_folder()?;
let (proving_key, _matrices) = zkey_from_folder();
if proving_key.vk == verifying_key {
Ok(())
} else {
Expand Down
8 changes: 4 additions & 4 deletions rln/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,15 @@ impl ProcessArg for *const Buffer {
}
}

impl<'a> ProcessArg for *const RLN<'a> {
type ReturnType = &'a RLN<'a>;
impl ProcessArg for *const RLN {
type ReturnType = &'static RLN;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}
}

impl<'a> ProcessArg for *mut RLN<'a> {
type ReturnType = &'a mut RLN<'a>;
impl ProcessArg for *mut RLN {
type ReturnType = &'static mut RLN;
fn process(self) -> Self::ReturnType {
unsafe { &mut *self }
}
Expand Down
Loading

0 comments on commit d8f813b

Please sign in to comment.