Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract EvalCheck trait from Hal #252

Merged
merged 2 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 10 additions & 67 deletions risc0/zkp/rust/src/hal/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,23 @@ use rayon::prelude::*;

use super::{Buffer, Hal};
use crate::{
adapter::{PolyFp, PolyFpContext},
core::{
fp::Fp,
fp4::{Fp4, EXT_SIZE},
log2_ceil,
ntt::{bit_rev_32, bit_reverse, evaluate_ntt, expand, interpolate_ntt},
rou::ROU_FWD,
sha::{Digest, Sha},
sha_cpu,
},
field::Elem,
FRI_FOLD, INV_RATE,
FRI_FOLD,
};

pub struct CpuHal<'a, C: PolyFp> {
circuit: &'a C,
}
pub struct CpuHal {}

impl<'a, C: PolyFp> CpuHal<'a, C> {
pub fn new(circuit: &'a C) -> Self {
CpuHal { circuit }
impl CpuHal {
pub fn new() -> Self {
CpuHal {}
}
}

Expand Down Expand Up @@ -92,15 +88,15 @@ impl<T: Default + Clone + Pod> CpuBuffer<T> {
}
}

fn as_slice<'a>(&'a self) -> Ref<'a, [T]> {
pub fn as_slice<'a>(&'a self) -> Ref<'a, [T]> {
let vec = self.buf.borrow();
Ref::map(vec, |vec| {
let slice = bytemuck::cast_slice(vec);
&slice[self.region.range()]
})
}

fn as_slice_mut<'a>(&'a self) -> RefMut<'a, [T]> {
pub fn as_slice_mut<'a>(&'a self) -> RefMut<'a, [T]> {
let vec = self.buf.borrow_mut();
RefMut::map(vec, |vec| {
let slice = bytemuck::cast_slice_mut(vec);
Expand Down Expand Up @@ -136,7 +132,7 @@ impl<T: Pod> Buffer<T> for CpuBuffer<T> {
}
}

impl<'a, E: PolyFp> Hal for CpuHal<'a, E> {
impl Hal for CpuHal {
type BufferFp = CpuBuffer<Fp>;
type BufferFp4 = CpuBuffer<Fp4>;
type BufferDigest = CpuBuffer<Digest>;
Expand Down Expand Up @@ -414,49 +410,6 @@ impl<'a, E: PolyFp> Hal for CpuHal<'a, E> {
*output = *sha.hash_pair(&input[0], &input[1]);
});
}

fn eval_check(
&self,
_circuit: &str,
check: &CpuBuffer<Fp>,
code: &CpuBuffer<Fp>,
data: &CpuBuffer<Fp>,
accum: &CpuBuffer<Fp>,
mix: &CpuBuffer<Fp>,
out: &CpuBuffer<Fp>,
poly_mix: Fp4,
po2: usize,
steps: usize,
) {
const EXP_PO2: usize = log2_ceil(INV_RATE);

let domain = steps * INV_RATE;
let code = code.as_slice();
let data = data.as_slice();
let accum = accum.as_slice();
let mix = mix.as_slice();
let out = out.as_slice();
let mut check = check.as_slice_mut();
// TODO: parallelize
for cycle in 0..domain {
let args: &[&[Fp]] = &[&code, &out, &data, &mix, &accum];
let cond = self.circuit.poly_fp(
&PolyFpContext {
size: domain,
cycle,
mix: poly_mix,
},
args,
);
let x = Fp::new(ROU_FWD[po2 + EXP_PO2]).pow(cycle);
// TODO: what is this magic number 3?
let y = (Fp::new(3) * x).pow(1 << po2);
let ret = cond.tot * (y - Fp::new(1)).inv();
for i in 0..EXT_SIZE {
check[i * domain + cycle] = ret.elems()[i];
}
}
}
}

#[cfg(test)]
Expand All @@ -465,28 +418,18 @@ mod test {

use super::*;

struct PolyFpMock {}

impl PolyFp for PolyFpMock {
fn poly_fp(&self, _ctx: &PolyFpContext, _args: &[&[Fp]]) -> crate::adapter::MixState {
unimplemented!()
}
}

#[test]
#[should_panic]
fn check_req() {
let mock = PolyFpMock {};
let hal = CpuHal::new(&mock);
let hal = CpuHal::new();
let a = hal.alloc_fp(10);
let b = hal.alloc_fp(20);
hal.eltwise_add_fp(&a, &b, &b);
}

#[test]
fn fp() {
let mock = PolyFpMock {};
let hal = CpuHal::new(&mock);
let hal = CpuHal::new();
const COUNT: usize = 1024 * 1024;
test_binary(
&hal,
Expand Down
15 changes: 8 additions & 7 deletions risc0/zkp/rust/src/hal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,18 @@ pub trait Hal {
fn sha_rows(&self, output: &Self::BufferDigest, matrix: &Self::BufferFp);

fn sha_fold(&self, io: &Self::BufferDigest, input_size: usize, output_size: usize);
}

pub trait EvalCheck<H: Hal> {
/// Compute check polynomial.
fn eval_check(
&self,
circuit: &str,
check: &Self::BufferFp,
code: &Self::BufferFp,
data: &Self::BufferFp,
accum: &Self::BufferFp,
mix: &Self::BufferFp,
out: &Self::BufferFp,
check: &H::BufferFp,
code: &H::BufferFp,
data: &H::BufferFp,
accum: &H::BufferFp,
mix: &H::BufferFp,
out: &H::BufferFp,
poly_mix: Fp4,
po2: usize,
steps: usize,
Expand Down
14 changes: 9 additions & 5 deletions risc0/zkp/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
sha::Sha,
},
field::Elem,
hal::{Buffer, Hal},
hal::{Buffer, EvalCheck, Hal},
prove::{fri::fri_prove, poly_group::PolyGroup, write_iop::WriteIOP},
taps::{RegisterGroup, TapSet},
CHECK_SIZE, INV_RATE, MAX_CYCLES_PO2,
Expand Down Expand Up @@ -64,12 +64,17 @@ pub trait Circuit {
fn get_steps(&self) -> usize;
}

pub fn prove_without_seal<H: Hal, S: Sha, C: Circuit>(_hal: &H, sha: &S, circuit: &mut C) {
pub fn prove_without_seal<S: Sha, C: Circuit>(sha: &S, circuit: &mut C) {
let mut iop = WriteIOP::new(sha);
circuit.execute(&mut iop);
}

pub fn prove<H: Hal, S: Sha, C: Circuit>(hal: &H, sha: &S, circuit: &mut C) -> Vec<u32> {
pub fn prove<H: Hal, S: Sha, C: Circuit, E: EvalCheck<H>>(
hal: &H,
sha: &S,
circuit: &mut C,
eval: &E,
) -> Vec<u32> {
let taps = circuit.get_taps();
let code_size = taps.group_size(RegisterGroup::Code);
let data_size = taps.group_size(RegisterGroup::Data);
Expand Down Expand Up @@ -111,8 +116,7 @@ pub fn prove<H: Hal, S: Sha, C: Circuit>(hal: &H, sha: &S, circuit: &mut C) -> V
let check_poly = hal.alloc_fp(EXT_SIZE * domain);
let mix = hal.copy_fp_from(circuit.get_mix());
let out = hal.copy_fp_from(circuit.get_output());
hal.eval_check(
"rv32im",
eval.eval_check(
&check_poly,
&code_group.evaluated,
&data_group.evaluated,
Expand Down
5 changes: 2 additions & 3 deletions risc0/zkvm/sdk/rust/src/method_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ impl MethodId {

#[cfg(feature = "prove")]
pub fn compute_with_limit(elf_contents: &[u8], limit: u32) -> Result<Self> {
use crate::{elf::Program, prove::CIRCUIT, CODE_SIZE};
use crate::{elf::Program, CODE_SIZE};
use risc0_zkp::{
hal::{cpu::CpuHal, Hal},
prove::poly_group::PolyGroup,
};
use risc0_zkvm_circuit::CircuitImpl;
use risc0_zkvm_platform::memory::MEM_SIZE;

let hal = CpuHal::<CircuitImpl>::new(&CIRCUIT);
let hal = CpuHal::new();
let program = Program::load_elf(elf_contents, MEM_SIZE as u32)?;

// Start with an empty table
Expand Down
83 changes: 83 additions & 0 deletions risc0/zkvm/sdk/rust/src/prove/cpu_eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risc0_zkp::{
adapter::{PolyFp, PolyFpContext},
core::{
fp::Fp,
fp4::{Fp4, EXT_SIZE},
log2_ceil,
rou::ROU_FWD,
},
field::Elem,
hal::{
cpu::{CpuBuffer, CpuHal},
EvalCheck,
},
INV_RATE,
};

pub struct CpuEvalCheck<'a, C: PolyFp> {
circuit: &'a C,
}

impl<'a, C: PolyFp> CpuEvalCheck<'a, C> {
pub fn new(circuit: &'a C) -> Self {
Self { circuit }
}
}

impl<'a, C: PolyFp> EvalCheck<CpuHal> for CpuEvalCheck<'a, C> {
fn eval_check(
&self,
check: &CpuBuffer<Fp>,
code: &CpuBuffer<Fp>,
data: &CpuBuffer<Fp>,
accum: &CpuBuffer<Fp>,
mix: &CpuBuffer<Fp>,
out: &CpuBuffer<Fp>,
poly_mix: Fp4,
po2: usize,
steps: usize,
) {
const EXP_PO2: usize = log2_ceil(INV_RATE);

let domain = steps * INV_RATE;
let code = code.as_slice();
let data = data.as_slice();
let accum = accum.as_slice();
let mix = mix.as_slice();
let out = out.as_slice();
let mut check = check.as_slice_mut();
// TODO: parallelize
for cycle in 0..domain {
let args: &[&[Fp]] = &[&code, &out, &data, &mix, &accum];
let cond = self.circuit.poly_fp(
&PolyFpContext {
size: domain,
cycle,
mix: poly_mix,
},
args,
);
let x = Fp::new(ROU_FWD[po2 + EXP_PO2]).pow(cycle);
// TODO: what is this magic number 3?
let y = (Fp::new(3) * x).pow(1 << po2);
let ret = cond.tot * (y - Fp::new(1)).inv();
for i in 0..EXT_SIZE {
check[i * domain + cycle] = ret.elems()[i];
}
}
}
}
2 changes: 0 additions & 2 deletions risc0/zkvm/sdk/rust/src/prove/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ impl MemoryState {
}
}

impl MemoryState {}

fn split_word(value: u32) -> (Fp, Fp) {
(Fp::new(value & 0xffff), Fp::new(value >> 16))
}
Expand Down
22 changes: 17 additions & 5 deletions risc0/zkvm/sdk/rust/src/prove/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod cpu_eval;
pub mod exec;

use std::io::Write;

use anyhow::Result;
use lazy_static::lazy_static;
use risc0_zkp::{
core::sha::default_implementation, hal::cpu::CpuHal, prove::adapter::ProveAdapter,
core::sha::default_implementation,
hal::{cpu::CpuHal, EvalCheck, Hal},
prove::adapter::ProveAdapter,
};
use risc0_zkvm_circuit::CircuitImpl;
use risc0_zkvm_platform::{
Expand All @@ -29,6 +32,8 @@ use risc0_zkvm_platform::{

use crate::{elf::Program, host::ProverOpts, method_id::MethodId, receipt::Receipt};

use self::cpu_eval::CpuEvalCheck;

lazy_static! {
pub static ref CIRCUIT: CircuitImpl = CircuitImpl::new();
}
Expand Down Expand Up @@ -67,20 +72,27 @@ impl<'a> Prover<'a> {
}

pub fn run(&mut self) -> Result<Receipt> {
let hal = CpuHal::new();
let circuit: &CircuitImpl = &CIRCUIT;
let eval = CpuEvalCheck::new(circuit);
self.run_with_hal(&hal, &eval)
}

pub fn run_with_hal<H: Hal, E: EvalCheck<H>>(&mut self, hal: &H, eval: &E) -> Result<Receipt> {
let skip_seal = self.inner.opts.skip_seal;

let mut executor = exec::RV32Executor::new(&CIRCUIT, &self.elf, &mut self.inner);
let circuit: &CircuitImpl = &CIRCUIT;
let mut executor = exec::RV32Executor::new(circuit, &self.elf, &mut self.inner);
executor.run()?;

let mut prover = ProveAdapter::new(&mut executor.executor);
let hal = CpuHal::<CircuitImpl>::new(&CIRCUIT);
let sha = default_implementation();

let seal = if skip_seal {
risc0_zkp::prove::prove_without_seal(&hal, sha, &mut prover);
risc0_zkp::prove::prove_without_seal(sha, &mut prover);
Vec::new()
} else {
risc0_zkp::prove::prove(&hal, sha, &mut prover)
risc0_zkp::prove::prove(hal, sha, &mut prover, eval)
};

// Attach the full version of the output journal & construct receipt object
Expand Down