Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include tapsets in generated rust code instead of making rust do a bu… #242

Merged
merged 1 commit into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ pub trait PolyExt {
}

pub trait TapsProvider {
fn get_taps(&self) -> &TapSet;
fn get_taps(&self) -> &'static TapSet<'static>;
}

pub trait CircuitInfo {
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/prove/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<'a, C: CircuitDef<S>, S: CustomStep> ProveAdapter<'a, C, S> {
}

impl<'a, C: CircuitDef<CS>, CS: CustomStep> Circuit for ProveAdapter<'a, C, CS> {
fn get_taps(&self) -> &TapSet {
fn get_taps(&self) -> &'static TapSet<'static> {
self.exec.circuit.get_taps()
}

Expand Down
5 changes: 3 additions & 2 deletions risc0/zkp/rust/src/prove/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ pub struct Executor<C: CircuitDef<S>, S: CustomStep> {
impl<C: CircuitDef<S>, S: CustomStep> Executor<C, S> {
pub fn new(circuit: C, custom: S, min_po2: usize, max_po2: usize) -> Self {
let po2 = max(min_po2, MIN_PO2);
let code_size = circuit.get_taps().group_size(RegisterGroup::Code);
let data_size = circuit.get_taps().group_size(RegisterGroup::Data);
let taps = circuit.get_taps();
let code_size = taps.group_size(RegisterGroup::Code);
let data_size = taps.group_size(RegisterGroup::Data);
let steps = 1 << po2;
let output_size = circuit.output_size();
debug!("po2: {po2}, steps: {steps}, code_size: {code_size}");
Expand Down
4 changes: 2 additions & 2 deletions risc0/zkp/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::{
};

pub trait Circuit {
fn get_taps(&self) -> &TapSet;
fn get_taps(&self) -> &'static TapSet<'static>;

/// Perform initial 'execution' setting code + data.
/// Additionally, write any 'results' as needed.
Expand Down Expand Up @@ -74,7 +74,7 @@ pub fn prove_without_seal<H: Hal, S: Sha, C: Circuit>(_hal: &H, sha: &S, circuit
}

pub fn prove<H: Hal, S: Sha, C: Circuit>(hal: &H, sha: &S, circuit: &mut C) -> Vec<u32> {
let taps = circuit.get_taps().clone();
let taps = circuit.get_taps();
let code_size = taps.group_size(RegisterGroup::Code);
let data_size = taps.group_size(RegisterGroup::Data);
let accum_size = taps.group_size(RegisterGroup::Accum);
Expand Down
183 changes: 99 additions & 84 deletions risc0/zkp/rust/src/taps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use alloc::{
collections::{BTreeMap, BTreeSet},
rc::Rc,
vec,
vec::Vec,
};
Expand Down Expand Up @@ -90,20 +89,97 @@ impl PartialOrd for TapData {
}
}

struct TapSetData {
#[derive(Debug)]
pub struct TapSet<'a> {
pub taps: &'a [TapData],
pub combo_taps: &'a [u16],
pub combo_begin: &'a [u16],
pub group_begin: [usize; REGISTER_GROUPS.len() + 1],
pub combos_count: usize,
}

impl<'a> TapSet<'a> {
pub fn tap_size(&self) -> usize {
self.group_begin[REGISTER_GROUPS.len()]
}

pub fn taps(&self) -> TapIter {
TapIter {
data: &self.taps,
cursor: 0,
end: self.group_begin[REGISTER_GROUPS.len()],
}
}

pub fn regs(&self) -> RegisterIter {
RegisterIter {
data: &self.taps,
cursor: 0,
end: self.group_begin[REGISTER_GROUPS.len()],
}
}

pub fn group_taps(&self, group: RegisterGroup) -> TapIter {
let group_id = group as usize;
TapIter {
data: &self.taps,
cursor: self.group_begin[group_id],
end: self.group_begin[group_id + 1],
}
}

pub fn group_regs(&self, group: RegisterGroup) -> RegisterIter {
let group_id = group as usize;
RegisterIter {
data: &self.taps,
cursor: self.group_begin[group_id],
end: self.group_begin[group_id + 1],
}
}

pub fn group_size(&self, group: RegisterGroup) -> usize {
let group_id = group as usize;
let idx = self.group_begin[group_id + 1] - 1;
let last = self.taps[idx].offset as usize;
last + 1
}

// size_t combosSize() const { return data_->combos.count; }
pub fn combos_size(&self) -> usize {
self.combos_count
}

pub fn combos(&self) -> ComboIter {
ComboIter {
data: ComboData {
taps: &self.combo_taps,
offsets: &self.combo_begin,
},
id: 0,
end: self.combos_count,
}
}

pub fn get_combo(&self, id: usize) -> ComboRef {
ComboRef {
data: ComboData {
taps: &self.combo_taps,
offsets: &self.combo_begin,
},
id,
}
}
}

pub struct TapSetOwned {
taps: Vec<TapData>,
combo_taps: Vec<u16>,
combo_begin: Vec<u16>,
group_begin: [usize; REGISTER_GROUPS.len() + 1],
combos_count: usize,
}

#[derive(Clone)]
pub struct TapSet {
data: Rc<TapSetData>,
}

impl TapSet {
impl TapSetOwned {
pub fn new(raw: &[Tap]) -> Self {
type Reg = BTreeSet<usize>;
type Group = BTreeMap<usize, Reg>;
Expand Down Expand Up @@ -165,85 +241,24 @@ impl TapSet {
}
combo_begin.push(combo_taps.len().try_into().unwrap());
assert!(combo_taps.len() < 64 * 1024);
TapSet {
data: Rc::new(TapSetData {
taps,
combo_taps,
combo_begin,
group_begin,
combos_count: combos.len(),
}),
}
}

pub fn tap_size(&self) -> usize {
self.data.group_begin[REGISTER_GROUPS.len()]
}

pub fn taps(&self) -> TapIter {
TapIter {
data: &self.data.taps,
cursor: 0,
end: self.data.group_begin[REGISTER_GROUPS.len()],
}
}

pub fn regs(&self) -> RegisterIter {
RegisterIter {
data: &self.data.taps,
cursor: 0,
end: self.data.group_begin[REGISTER_GROUPS.len()],
}
}

pub fn group_taps(&self, group: RegisterGroup) -> TapIter {
let group_id = group as usize;
TapIter {
data: &self.data.taps,
cursor: self.data.group_begin[group_id],
end: self.data.group_begin[group_id + 1],
}
}

pub fn group_regs(&self, group: RegisterGroup) -> RegisterIter {
let group_id = group as usize;
RegisterIter {
data: &self.data.taps,
cursor: self.data.group_begin[group_id],
end: self.data.group_begin[group_id + 1],
TapSetOwned {
taps,
combo_taps,
combo_begin,
group_begin,
combos_count: combos.len(),
}
}
}

pub fn group_size(&self, group: RegisterGroup) -> usize {
let group_id = group as usize;
let idx = self.data.group_begin[group_id + 1] - 1;
let last = self.data.taps[idx].offset as usize;
last + 1
}

// size_t combosSize() const { return data_->combos.count; }
pub fn combos_size(&self) -> usize {
self.data.combos_count
}

pub fn combos(&self) -> ComboIter {
ComboIter {
data: ComboData {
taps: &self.data.combo_taps,
offsets: &self.data.combo_begin,
},
id: 0,
end: self.data.combos_count,
}
}

pub fn get_combo(&self, id: usize) -> ComboRef {
ComboRef {
data: ComboData {
taps: &self.data.combo_taps,
offsets: &self.data.combo_begin,
},
id,
impl<'a> From<&'a TapSetOwned> for TapSet<'a> {
fn from(owned: &'a TapSetOwned) -> Self {
Self {
taps: owned.taps.as_slice(),
combo_taps: owned.combo_taps.as_slice(),
combo_begin: owned.combo_begin.as_slice(),
group_begin: owned.group_begin,
combos_count: owned.combos_count,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion risc0/zkp/rust/src/verify/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<'a, C: CircuitInfo + PolyExt + TapsProvider> VerifyAdapter<'a, C> {
}

impl<'a, C: CircuitInfo + PolyExt + TapsProvider> Circuit for VerifyAdapter<'a, C> {
fn taps(&self) -> &TapSet {
fn taps(&self) -> &'static TapSet<'static> {
self.circuit.get_taps()
}

Expand Down
4 changes: 2 additions & 2 deletions risc0/zkp/rust/src/verify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl fmt::Display for VerificationError {
}

pub trait Circuit {
fn taps(&self) -> &TapSet;
fn taps(&self) -> &'static TapSet<'static>;
fn execute<S: Sha>(&mut self, iop: &mut ReadIOP<S>);
fn accumulate<S: Sha>(&mut self, iop: &mut ReadIOP<S>);
fn po2(&self) -> u32;
Expand All @@ -70,7 +70,7 @@ where
if seal.len() == 0 {
return Err(VerificationError::ReceiptFormatError);
}
let taps = circuit.taps().clone();
let taps = circuit.taps();

// Make IOP
let mut iop = ReadIOP::new(sha, seal);
Expand Down
32 changes: 24 additions & 8 deletions risc0/zkvm/sdk/rust/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,17 @@ use risc0_zkp::{
taps::TapSet,
};

pub struct CircuitImpl {
taps: TapSet,
}
pub struct CircuitImpl;

impl CircuitImpl {
pub fn new() -> Self {
CircuitImpl {
taps: TapSet::new(taps::TAPS),
}
CircuitImpl
}
}

impl TapsProvider for CircuitImpl {
fn get_taps(&self) -> &TapSet {
&self.taps
fn get_taps(&self) -> &'static TapSet<'static> {
taps::TAPSET
}
}

Expand All @@ -41,3 +37,23 @@ impl CircuitInfo for CircuitImpl {
20
}
}

#[cfg(test)]
mod test {
use super::taps;
use risc0_zkp::taps::{TapSet, TapSetOwned};

#[test]
fn generated_tapset_matches() {
let cirgen_generated = taps::TAPSET;
let rs_generated = TapSetOwned::new(taps::TAPS);

// TapData includes its own PartialEq implementation which
// skips some fields, so make sure the debug representation of
// these two structures are identical.
assert_eq!(
format!("{:?}", &TapSet::from(&rs_generated)),
format!("{:?}", cirgen_generated)
);
}
}