Skip to content

Commit

Permalink
combined batch operations (insert + delete) (#160)
Browse files Browse the repository at this point in the history
* fix(rln): clippy error

* feat: batch ops in ZerokitMerkleTree

* chore: bump pmtree

* fix: upstream root calc
  • Loading branch information
rymnc committed May 15, 2023
1 parent 584c2cf commit 8f2c9e3
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions rln/src/pm_tree_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use crate::hashers::{poseidon_hash, PoseidonHash};
use crate::utils::{bytes_le_to_fr, fr_to_bytes_le};
use color_eyre::{Report, Result};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt::Debug;
use std::path::PathBuf;
use std::str::FromStr;
use utils::pmtree::Hasher;
use utils::*;

pub struct PmTree {
Expand Down Expand Up @@ -127,6 +129,10 @@ impl ZerokitMerkleTree for PmTree {
.map_err(|e| Report::msg(e.to_string()))
}

fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>> {
self.tree.get(index).map_err(|e| Report::msg(e.to_string()))
}

fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>(
&mut self,
start: usize,
Expand All @@ -137,6 +143,42 @@ impl ZerokitMerkleTree for PmTree {
.map_err(|e| Report::msg(e.to_string()))
}

fn override_range<I: IntoIterator<Item = FrOf<Self::Hasher>>, J: IntoIterator<Item = usize>>(
&mut self,
start: usize,
leaves: I,
indices: J,
) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
let indices = indices.into_iter().collect::<HashSet<_>>();
let end = start + leaves.len();

if leaves.len() + start - indices.len() > self.capacity() {
return Err(Report::msg("index out of bounds"));
}

// extend the range to include indices to be removed
let min_index = indices.iter().min().unwrap_or(&start);
let max_index = indices.iter().max().unwrap_or(&end);

let mut new_leaves = Vec::new();

// insert leaves into new_leaves
for i in *min_index..*max_index {
if indices.contains(&i) {
// insert 0
new_leaves.push(Self::Hasher::default_leaf());
} else {
// insert leaf
new_leaves.push(leaves[i - start]);
}
}

self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))
}

fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {
self.tree
.update_next(leaf)
Expand All @@ -161,6 +203,10 @@ impl ZerokitMerkleTree for PmTree {
Err(Report::msg("verify failed"))
}
}

fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.tree.root())
}
}

impl ZerokitMerkleProof for PmTreeProof {
Expand Down
2 changes: 1 addition & 1 deletion utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
ark-ff = { version = "=0.4.1", default-features = false, features = ["asm"] }
num-bigint = { version = "=0.4.3", default-features = false, features = ["rand"] }
color-eyre = "=0.6.2"
pmtree = { git = "https://github.com/Rate-Limiting-Nullifier/pmtree", rev = "b3a02216cece3e9c24e1754ea381bf784fd1df48", optional = true}
pmtree = { git = "https://github.com/Rate-Limiting-Nullifier/pmtree", rev = "bf27d4273b34a7f4ec3c77cdf67fb17a5602504a", optional = true}
sled = "=0.34.7"
serde = "1.0.44"

Expand Down
45 changes: 45 additions & 0 deletions utils/src/merkle_tree/full_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ where
Ok(())
}

// Get a leaf from the specified tree index
fn get(&self, leaf: usize) -> Result<FrOf<Self::Hasher>> {
if leaf >= self.capacity() {
return Err(Report::msg("leaf index out of bounds"));
}
Ok(self.nodes[self.capacity() + leaf - 1])
}

// Sets tree nodes, starting from start index
// Function proper of FullMerkleTree implementation
fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>(
Expand All @@ -151,6 +159,39 @@ where
Ok(())
}

fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
{
let index = self.capacity() + start - 1;
let mut count = 0;
let leaves = leaves.into_iter().collect::<Vec<_>>();
let to_remove_indices = to_remove_indices.into_iter().collect::<Vec<_>>();
// first count number of hashes, and check that they fit in the tree
// then insert into the tree
if leaves.len() + start - to_remove_indices.len() > self.capacity() {
return Err(Report::msg("provided hashes do not fit in the tree"));
}

// remove leaves
for i in &to_remove_indices {
self.delete(*i)?;
}

// insert new leaves
for hash in leaves {
self.nodes[index + count] = hash;
count += 1;
}

if count != 0 {
self.update_nodes(index, index + (count - 1))?;
self.next_index = max(self.next_index, start + count - to_remove_indices.len());
}
Ok(())
}

// Sets a leaf at the next available index
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {
self.set(self.next_index, leaf)?;
Expand Down Expand Up @@ -189,6 +230,10 @@ where
fn verify(&self, hash: &FrOf<Self::Hasher>, proof: &FullMerkleProof<H>) -> Result<bool> {
Ok(proof.compute_root_from(hash) == self.root())
}

fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.root())
}
}

impl<H: Hasher> FullMerkleTree<H>
Expand Down
6 changes: 6 additions & 0 deletions utils/src/merkle_tree/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,16 @@ pub trait ZerokitMerkleTree {
fn capacity(&self) -> usize;
fn leaves_set(&mut self) -> usize;
fn root(&self) -> FrOf<Self::Hasher>;
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>>;
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn set_range<I>(&mut self, start: usize, leaves: I) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>;
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>>;
fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>;
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn delete(&mut self, index: usize) -> Result<()>;
fn proof(&self, index: usize) -> Result<Self::Proof>;
Expand Down
44 changes: 44 additions & 0 deletions utils/src/merkle_tree/optimal_merkle_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::merkle_tree::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree};
use crate::FrOf;
use color_eyre::{Report, Result};
use std::collections::HashMap;
use std::str::FromStr;
Expand Down Expand Up @@ -110,6 +111,14 @@ where
Ok(())
}

// Get a leaf from the specified tree index
fn get(&self, index: usize) -> Result<H::Fr> {
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
}
Ok(self.get_node(self.depth, index))
}

// Sets multiple leaves from the specified tree index
fn set_range<I: IntoIterator<Item = H::Fr>>(&mut self, start: usize, leaves: I) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
Expand All @@ -125,6 +134,36 @@ where
Ok(())
}

fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
{
let leaves = leaves.into_iter().collect::<Vec<_>>();
let to_remove_indices = to_remove_indices.into_iter().collect::<Vec<_>>();
// check if the range is valid
if leaves.len() + start - to_remove_indices.len() > self.capacity() {
return Err(Report::msg("provided range exceeds set size"));
}

// remove leaves
for i in &to_remove_indices {
self.delete(*i)?;
}

// add leaves
for (i, leaf) in leaves.iter().enumerate() {
self.nodes.insert((self.depth, start + i), *leaf);
self.recalculate_from(start + i)?;
}

self.next_index = max(
self.next_index,
start + leaves.len() - to_remove_indices.len(),
);
Ok(())
}

// Sets a leaf at the next available index
fn update_next(&mut self, leaf: H::Fr) -> Result<()> {
self.set(self.next_index, leaf)?;
Expand Down Expand Up @@ -172,6 +211,11 @@ where
let expected_root = witness.compute_root_from(leaf);
Ok(expected_root.eq(&self.root()))
}

fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
self.recalculate_from(0)?;
Ok(self.root())
}
}

impl<H: Hasher> OptimalMerkleTree<H>
Expand Down
5 changes: 2 additions & 3 deletions utils/src/pm_tree/sled_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ impl Database for SledDB {
Err(e) => {
return Err(PmtreeErrorKind::DatabaseError(
DatabaseErrorKind::CustomError(format!(
"Cannot create database: {} {:#?}",
e, config
"Cannot create database: {e} {config:#?}",
)),
))
}
Expand All @@ -29,7 +28,7 @@ impl Database for SledDB {
Ok(db) => db,
Err(e) => {
return Err(PmtreeErrorKind::DatabaseError(
DatabaseErrorKind::CustomError(format!("Cannot load database: {}", e)),
DatabaseErrorKind::CustomError(format!("Cannot load database: {e}")),
))
}
};
Expand Down
37 changes: 37 additions & 0 deletions utils/tests/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,41 @@ mod test {
.unwrap());
}
}

#[test]
fn test_override_range() {
let initial_leaves = [
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
hex!("0000000000000000000000000000000000000000000000000000000000000002"),
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
hex!("0000000000000000000000000000000000000000000000000000000000000004"),
];

let mut tree =
OptimalMerkleTree::<Keccak256>::new(2, [0; 32], OptimalMerkleConfig::default())
.unwrap();

// We set the leaves
tree.set_range(0, initial_leaves.iter().cloned()).unwrap();

let new_leaves = [
hex!("0000000000000000000000000000000000000000000000000000000000000005"),
hex!("0000000000000000000000000000000000000000000000000000000000000006"),
];

let to_delete_indices: [usize; 2] = [0, 1];

// We override the leaves
tree.override_range(
0, // start from the end of the initial leaves
new_leaves.iter().cloned(),
to_delete_indices.iter().cloned(),
)
.unwrap();

// ensure that the leaves are set correctly
for i in 0..new_leaves.len() {
assert_eq!(tree.get_leaf(i), new_leaves[i]);
}
}
}

0 comments on commit 8f2c9e3

Please sign in to comment.