Skip to content

Commit

Permalink
Pick one algo to select random K3 indices
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Mar 17, 2023
1 parent 902c12c commit e7dedcb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 103 deletions.
121 changes: 29 additions & 92 deletions src/random_values_gen.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/// Random number source wrapping a BLAKE3 output reader.
///
/// The iterators in this module create it with `Blake3Rng::from_seed` and
/// draw their random samples from it through `next_u16`.
#[derive(Debug, Clone)]
struct Blake3Rng(blake3::OutputReader);

impl Blake3Rng {
Expand All @@ -17,77 +18,39 @@ impl Blake3Rng {
}

/// Picks random items from the provided Vec.
///
/// Every call to `next` hands out one element of `values`, selected with a
/// number drawn from `rng`. Iteration stops once `j` reaches `values.len()`.
pub struct RandomValuesIterator<T> {
// Items to pick from. `next` keeps the not-yet-picked items in the prefix
// `values[..values.len() - j]` by swapping each picked item to the end of
// that prefix.
values: Vec<T>,
// Source of the random numbers used to choose the next item; seeded in `new`.
rng: Blake3Rng,
// Number of items handed out so far.
j: usize,
}

impl<T> RandomValuesIterator<T> {
pub(crate) fn new(values: Vec<T>, seed: &[&[u8]]) -> Self {
Self {
j: 0,
values,
rng: Blake3Rng::from_seed(seed),
}
}
}

/// Yields the stored values one at a time in pseudo-random order, each value
/// exactly once.
impl<T: Copy> Iterator for RandomValuesIterator<T> {
    type Item = T;

    /// Returns a randomly chosen value that has not been yielded yet, or
    /// `None` once every value has been handed out.
    ///
    /// NOTE(review): the count of remaining values is cast to `u16`, so this
    /// presumably expects at most `u16::MAX` values — confirm against callers.
    fn next(&mut self) -> Option<Self::Item> {
        // Number of values still available; bail out when none are left.
        let remaining = self
            .values
            .len()
            .checked_sub(self.j)
            .filter(|&left| left > 0)?;

        // Largest multiple of `remaining` that fits in a u16. Samples at or
        // above it are rejected so that the modulo below stays unbiased.
        let limit = u16::MAX - u16::MAX % (remaining as u16);
        let sample = loop {
            let candidate = self.rng.next_u16();
            if candidate < limit {
                break candidate;
            }
        };

        // Pick from the not-yet-yielded prefix, then move the picked slot to
        // the end of that prefix so it cannot be chosen again.
        let picked = usize::from(sample) % remaining;
        let item = self.values[picked];
        self.values.swap(picked, remaining - 1);
        self.j += 1;
        Some(item)
    }
}

pub(crate) struct FisherYatesShuffle<T> {
pub(crate) struct RandomValuesIterator<T> {
// data shuffled in-place
data: Vec<T>,
rng: Blake3Rng,
index: usize,
idx: usize,
}

impl<T> FisherYatesShuffle<T> {
impl<T> RandomValuesIterator<T> {
pub(crate) fn new(data: Vec<T>, seed: &[&[u8]]) -> Self {
Self {
index: 0,
idx: 0,
data,
rng: Blake3Rng::from_seed(seed),
}
}
}

impl<T: Copy> Iterator for FisherYatesShuffle<T> {
impl<T: Copy> Iterator for RandomValuesIterator<T> {
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
let remaining = self.data.len() - self.index;
let remaining = self.data.len() - self.idx;
if remaining == 0 {
return None;
}
let max_allowed = u16::MAX - u16::MAX % remaining as u16;
loop {
let rand_num = self.rng.next_u16();
let sample_max = u16::MAX - u16::MAX % remaining as u16;

if rand_num < sample_max {
let replacement_position = (rand_num as usize % remaining) + self.index;
self.data.swap(self.index, replacement_position);
let value = self.data[self.index];
self.index += 1;
if rand_num < max_allowed {
self.data
.swap(self.idx, (rand_num as usize % remaining) + self.idx);
let value = self.data[self.idx];
self.idx += 1;
return Some(value);
}
}
Expand All @@ -96,15 +59,12 @@ impl<T: Copy> Iterator for FisherYatesShuffle<T> {

#[cfg(test)]
mod tests {
use super::RandomValuesIterator;
use itertools::Itertools;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::random_values_gen::FisherYatesShuffle;

use super::RandomValuesIterator;

/// Checks that the iterator returns unique items from its data set.
#[test]
fn gives_each_value_once() {
Expand All @@ -118,13 +78,18 @@ mod tests {
}

#[test]
fn fisher_yates_shuffling_iter() {
let k2 = 1000;
let mut occurences = HashSet::new();
for item in FisherYatesShuffle::new((0..k2).collect(), &[]) {
assert!(occurences.insert(item));
}
assert_eq!(k2, occurences.len());
fn test_vec() {
let expected = [
39, 13, 95, 77, 36, 41, 74, 17, 59, 87, 91, 63, 40, 20, 94, 78, 48, 60, 18, 32, 67, 43,
23, 69, 71, 1, 51, 79, 19, 53, 86, 80, 14, 84, 97, 92, 83, 26, 2, 81, 42, 55, 50, 88,
75, 82, 44, 34, 58, 72, 35, 25, 10, 68, 12, 11, 70, 27, 98, 57, 96, 16, 45, 73, 0, 15,
62, 46, 30, 89, 33, 54, 9, 29, 7, 90, 38, 5, 49, 61, 93, 99, 22, 6, 64, 24, 76, 85, 37,
65, 31, 4, 52, 3, 56, 21, 8, 28, 66, 47,
];
let input = (0..expected.len()).collect();

let iter = RandomValuesIterator::new(input, &[]);
assert_eq!(&expected, iter.collect_vec().as_slice());
}

#[test]
Expand All @@ -140,43 +105,15 @@ mod tests {
(0u64..iterations).into_par_iter().for_each(|seed| {
for value in RandomValuesIterator::new(data_set.clone(), &[&seed.to_le_bytes()]).take(n)
{
occurences[value].fetch_add(1, Ordering::Relaxed);
}
});

// Verify distribution
let expected_count = (iterations * n as u64 / data_set.len() as u64) as f64;
let max_deviation = 0.002;

for (value, count) in occurences.into_iter().enumerate() {
let count = count.load(Ordering::Relaxed);
let deviation = (count as f64 - expected_count) / expected_count;
assert!(deviation.abs() < max_deviation, "{value} occured {count} times (expected {expected_count}). deviation {deviation} > {max_deviation}");
}
}

#[test]
fn distribution_is_uniform_fisher_iter() {
let data_set = (0..200).collect_vec();
let occurences = (0..data_set.len())
.map(|_| AtomicUsize::new(0))
.collect_vec();

// Take random n values many times and count each occurrence
let n = 50;
let iterations = 10_000_000;
(0u64..iterations).into_par_iter().for_each(|seed| {
for value in FisherYatesShuffle::new(data_set.clone(), &[&seed.to_le_bytes()]).take(n) {
occurences[value].fetch_add(1, Ordering::Relaxed);
occurences[value].fetch_add(1, Ordering::Release);
}
});

// Verify distribution
let expected_count = (iterations * n as u64 / data_set.len() as u64) as f64;
let max_deviation = 0.002;

for (value, count) in occurences.into_iter().enumerate() {
let count = count.load(Ordering::Relaxed);
let count = count.load(Ordering::Acquire);
let deviation = (count as f64 - expected_count) / expected_count;
assert!(deviation.abs() < max_deviation, "{value} occured {count} times (expected {expected_count}). deviation {deviation} > {max_deviation}");
}
Expand Down
21 changes: 10 additions & 11 deletions src/verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@
//! ## Selecting a subset of K3 proven indices
//!
//! ```text
//! seed = concat(ch, nonce, all_indices, k2pow, k3pow)
//! seed = concat(ch, nonce, indices, k2pow, k3pow)
//! random_bytes = blake3(seed) // infinite blake output
//! for (j:=0; j<K3; j++)
//! max_allowed = u16::MAX - (u16::MAX % (K2 - j))
//! for (index=0; index<K3; index++)
//! remaining = K2 - index
//! max_allowed = u16::MAX - (u16::MAX % remaining)
//! do {
//! rand_num = random_bytes.read_u16_le()
//! } while rand_num > max_allowed;
//! } while rand_num >= max_allowed;
//!
//! index = rand_num % (K2 - j)
//! if validate_label(all_indices[index]) is INVALID
//! return INVALID
//! all_indices[index] = all_indices[k2-j-1]
//! return true
//! to_swap = (rand_num % remaining) + index
//! indices.swap(index, to_swap)
//! ```
//! indices[0..K3] now contains randomly picked values
//!
//! ## Verifying K3 indices
//!
Expand All @@ -49,7 +48,7 @@ use crate::{
metadata::ProofMetadata,
pow::{hash_k2_pow, hash_k3_pow},
prove::Proof,
random_values_gen::FisherYatesShuffle,
random_values_gen::RandomValuesIterator,
};

#[inline]
Expand Down Expand Up @@ -156,7 +155,7 @@ pub fn verify(
&proof.k3_pow.to_le_bytes(),
];

let k3_indices = FisherYatesShuffle::new(indices_unpacked, seed).take(params.k3 as usize);
let k3_indices = RandomValuesIterator::new(indices_unpacked, seed).take(params.k3 as usize);

let pool = rayon::ThreadPoolBuilder::new()
.num_threads(threads)
Expand Down

0 comments on commit e7dedcb

Please sign in to comment.