Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve zkp performance #265

Merged
merged 1 commit into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 18 additions & 6 deletions risc0/zkp/rust/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,11 @@ impl CircuitStep {
}
CircuitStep::If(cond, block, _loc) => {
if stack[*cond] != Fp::new(0) {
let mut stack = stack.clone();
let stacklen = stack.len();
for op in block.iter() {
op.step(&mut stack, ctx, custom, args)?;
op.step(stack, ctx, custom, args)?;
}
stack.truncate(stacklen);
}
}
CircuitStep::Add(x1, x2, _loc) => {
Expand Down Expand Up @@ -376,10 +377,11 @@ impl CircuitStep {
stack.extend(custom.call(name, extra, &args)?);
}
CircuitStep::Nondet(block, _loc) => {
let mut stack = stack.clone();
let stacklen = stack.len();
for op in block.iter() {
op.step(&mut stack, ctx, custom, args)?;
op.step(stack, ctx, custom, args)?;
}
stack.truncate(stacklen);
}
})
}
Expand Down Expand Up @@ -559,11 +561,21 @@ impl PolyExtStep {

impl PolyExtStepDef {
pub fn step(&self, ctx: &PolyExtContext, u: &[Fp4], args: &[&[Fp]]) -> MixState {
let mut fp_vars = Vec::new();
let mut mix_vars = Vec::new();
let mut fp_vars = Vec::with_capacity(self.block.len() - (self.ret + 1));
let mut mix_vars = Vec::with_capacity(self.ret + 1);
for op in self.block.iter() {
op.step(&mut fp_vars, &mut mix_vars, ctx, u, args);
}
assert_eq!(
fp_vars.len(),
self.block.len() - (self.ret + 1),
"Miscalculated capacity for fp_vars"
);
assert_eq!(
mix_vars.len(),
self.ret + 1,
"Miscalculated capacity for mix_vars"
);
mix_vars[self.ret]
}
}
11 changes: 10 additions & 1 deletion risc0/zkp/rust/src/taps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub struct TapSet<'a> {
pub combo_begin: &'a [u16],
pub group_begin: [usize; REGISTER_GROUPS.len() + 1],
pub combos_count: usize,
pub reg_count: usize,
}

impl<'a> TapSet<'a> {
Expand Down Expand Up @@ -149,6 +150,10 @@ impl<'a> TapSet<'a> {
self.combos_count
}

pub fn reg_count(&self) -> usize {
self.reg_count
}

pub fn combos(&self) -> ComboIter {
ComboIter {
data: ComboData {
Expand Down Expand Up @@ -177,6 +182,7 @@ pub struct TapSetOwned {
combo_begin: Vec<u16>,
group_begin: [usize; REGISTER_GROUPS.len() + 1],
combos_count: usize,
reg_count: usize,
}

impl TapSetOwned {
Expand All @@ -198,7 +204,7 @@ impl TapSetOwned {
let mut combo_begin = Vec::new();
let mut combo_taps = Vec::new();
let mut taps = Vec::new();

let mut tot_reg_count = 0;
// Pre-insert the 'only self' combo
let myself = BTreeSet::from([0_usize]);
let mut combos = vec![&myself];
Expand All @@ -211,6 +217,7 @@ impl TapSetOwned {
group_begin[group_id] = taps.len();
let regs = all.get(group).unwrap();
let reg_count = regs.keys().last().unwrap() + 1;
tot_reg_count += reg_count;
for reg in 0..reg_count {
// Make sure all registers have at least one tap
assert!(regs.contains_key(&reg));
Expand Down Expand Up @@ -247,6 +254,7 @@ impl TapSetOwned {
combo_begin,
group_begin,
combos_count: combos.len(),
reg_count: tot_reg_count,
}
}
}
Expand All @@ -259,6 +267,7 @@ impl<'a> From<&'a TapSetOwned> for TapSet<'a> {
combo_begin: owned.combo_begin.as_slice(),
group_begin: owned.group_begin,
combos_count: owned.combos_count,
reg_count: owned.reg_count,
}
}
}
Expand Down
36 changes: 24 additions & 12 deletions risc0/zkp/rust/src/verify/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use alloc::vec;
use alloc::vec::Vec;

use rand::RngCore;

Expand All @@ -27,7 +27,7 @@ use crate::{
},
field::{baby_bear::BabyBear, Elem},
verify::{merkle::MerkleTreeVerifier, read_iop::ReadIOP, VerificationError},
FRI_FOLD, FRI_MIN_DEGREE, INV_RATE, QUERIES,
FRI_FOLD, FRI_FOLD_PO2, FRI_MIN_DEGREE, INV_RATE, QUERIES,
};

/// VerifyRoundInfo contains the data against which the queries for a particular
Expand Down Expand Up @@ -75,15 +75,16 @@ impl<'a, S: Sha> VerifyRoundInfo<'a, S> {
let group = *pos % self.domain;
// Get the column data
let data = self.merkle.verify::<BabyBear>(iop, group)?;
let mut data4 = vec![];
for i in 0..FRI_FOLD {
data4.push(Fp4::new(
data[0 * FRI_FOLD + i],
data[1 * FRI_FOLD + i],
data[2 * FRI_FOLD + i],
data[3 * FRI_FOLD + i],
));
}
let mut data4: Vec<_> = (0..FRI_FOLD)
.map(|i| {
Fp4::new(
data[0 * FRI_FOLD + i],
data[1 * FRI_FOLD + i],
data[2 * FRI_FOLD + i],
data[3 * FRI_FOLD + i],
)
})
.collect();
// Check the existing goal
if data4[quot] != *goal {
return Err(VerificationError::InvalidProof);
Expand All @@ -106,12 +107,23 @@ where
let orig_domain = INV_RATE * degree;
let mut domain = orig_domain;
// Prep the folding verfiers
let mut rounds = vec![];
let rounds_capacity =
(log2_ceil((degree + FRI_FOLD - 1) / FRI_FOLD) + FRI_FOLD_PO2 - 1) / FRI_FOLD_PO2;
let mut rounds = Vec::with_capacity(rounds_capacity);
while degree > FRI_MIN_DEGREE {
rounds.push(VerifyRoundInfo::new(iop, domain));
domain /= FRI_FOLD;
degree /= FRI_FOLD;
}
// We want to minimize reallocation in verify, so make sure we
// didn't have to reallocate.
assert!(
rounds.len() < rounds_capacity,
"Did not allocate enough rounds; needed {} for degree {} but only allocated {}",
rounds.len(),
degree,
rounds_capacity
);
// Grab the final coeffs + commit
let final_coeffs = iop.read_pod_slice(EXT_SIZE * degree);
let final_digest = iop.get_sha().hash_raw_pod_slice(final_coeffs);
Expand Down
55 changes: 38 additions & 17 deletions risc0/zkp/rust/src/verify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ mod fri;
pub(crate) mod merkle;
pub mod read_iop;

use alloc::vec;
use alloc::{vec, vec::Vec};
use core::fmt;
use core::iter::zip;
// use log::debug;

use crate::{
Expand Down Expand Up @@ -137,7 +138,7 @@ where

// Now, convert to evaluated values
let mut cur_pos = 0;
let mut eval_u = vec![];
let mut eval_u = Vec::with_capacity(num_taps);
for reg in taps.regs() {
for i in 0..reg.size() {
let x = back_one.pow(reg.back(i)) * z;
Expand All @@ -146,6 +147,7 @@ where
}
cur_pos += reg.size();
}
assert_eq!(eval_u.len(), num_taps, "Miscalculated capacity for eval_us");

// Compute the core polynomial
let result = circuit.compute_polynomial(&eval_u, poly_mix);
Expand Down Expand Up @@ -174,27 +176,48 @@ where
// debug!("mix = {mix:?}");

// Make the mixed U polynomials
let mut combo_u = vec![];
for i in 0..combo_count {
combo_u.push(vec![Fp4::ZERO; taps.get_combo(i).size()]);
}
let mut combo_u: Vec<Vec<Fp4>> = Vec::with_capacity(combo_count + 1);
combo_u.extend(
(0..combo_count)
.into_iter()
.map(|i| vec![Fp4::ZERO; taps.get_combo(i).size()]),
);
let mut cur_mix = Fp4::ONE;
cur_pos = 0;
let mut tap_mix_pows = Vec::with_capacity(taps.reg_count());
for reg in taps.regs() {
for i in 0..reg.size() {
combo_u[reg.combo_id()][i] += cur_mix * coeff_u[cur_pos + i];
}
tap_mix_pows.push(cur_mix);
cur_mix *= mix;
cur_pos += reg.size();
}
assert_eq!(
tap_mix_pows.len(),
taps.reg_count(),
"Miscalculated capacity for tap_mix_pows"
);
// debug!("cur_mix: {cur_mix:?}, cur_pos: {cur_pos}");
// Handle check group
combo_u.push(vec![Fp4::ZERO]);
assert_eq!(
combo_u.len(),
combo_count + 1,
"Miscalculated capacity for combo_u"
);
let mut check_mix_pows = Vec::with_capacity(CHECK_SIZE);
for _ in 0..CHECK_SIZE {
combo_u[combo_count][0] += cur_mix * coeff_u[cur_pos];
cur_pos += 1;
check_mix_pows.push(cur_mix);
cur_mix *= mix;
}
assert_eq!(
check_mix_pows.len(),
CHECK_SIZE,
"Miscalculated capacity for check_mix_pows"
);
// debug!("cur_mix: {cur_mix:?}");

let gen = Fp::new(ROU_FWD[log2_ceil(domain)]);
Expand All @@ -204,20 +227,18 @@ where
size,
|iop: &mut ReadIOP<S>, idx: usize| -> Result<Fp4, VerificationError> {
let x = Fp4::from_fp(gen.pow(idx));
let mut rows = vec![];
rows.push(accum_merkle.verify::<BabyBear>(iop, idx)?);
rows.push(code_merkle.verify::<BabyBear>(iop, idx)?);
rows.push(data_merkle.verify::<BabyBear>(iop, idx)?);
let rows = [
accum_merkle.verify::<BabyBear>(iop, idx)?,
code_merkle.verify::<BabyBear>(iop, idx)?,
data_merkle.verify::<BabyBear>(iop, idx)?,
];
let check_row = check_merkle.verify::<BabyBear>(iop, idx)?;
let mut cur = Fp4::ONE;
let mut tot = vec![Fp4::ZERO; combo_count + 1];
for reg in taps.regs() {
tot[reg.combo_id()] += cur * rows[reg.group() as usize][reg.offset()];
cur *= mix;
for (reg, cur) in zip(taps.regs(), tap_mix_pows.iter()) {
tot[reg.combo_id()] += *cur * rows[reg.group() as usize][reg.offset()];
}
for i in 0..CHECK_SIZE {
tot[combo_count] += cur * check_row[i];
cur *= mix;
for (i, cur) in zip(0..CHECK_SIZE, check_mix_pows.iter()) {
tot[combo_count] += *cur * check_row[i];
}
let mut ret = Fp4::ZERO;
for i in 0..combo_count {
Expand Down
1 change: 1 addition & 0 deletions risc0/zkvm/sdk/rust/circuit/src/taps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5949,4 +5949,5 @@ pub(crate) const TAPSET: &'static TapSet = &TapSet::<'static> {
combo_begin: &[0, 1, 3, 9, 16, 18],
group_begin: [0, 18, 34, 742],
combos_count: 5,
reg_count: 188,
};