Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to support recursion #337

Merged
merged 2 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19,776 changes: 9,895 additions & 9,881 deletions risc0/circuit/rv32im/cxx/poly_fp.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion risc0/circuit/rv32im/cxx/step_compute_accum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Fp step_compute_accum(void* ctx, HostBridge host, size_t steps, size_t cycle, Fp
size_t mask = steps - 1;
std::array<Fp, 32> host_args;
std::array<Fp, 32> host_outs;
// loc("cirgen/circuit/rv32im/ffpu.cpp":70:85)
// loc("cirgen/circuit/rv32im/ffpu.cpp":76:85)
Fp x0(2013265910);
// loc("./cirgen/components/onehot.h":35:32)
Fp x1(11);
Expand Down
8,358 changes: 4,239 additions & 4,119 deletions risc0/circuit/rv32im/cxx/step_exec.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion risc0/circuit/rv32im/cxx/step_verify_accum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Fp step_verify_accum(void* ctx, HostBridge host, size_t steps, size_t cycle, Fp*
size_t mask = steps - 1;
std::array<Fp, 32> host_args;
std::array<Fp, 32> host_outs;
// loc("cirgen/circuit/rv32im/ffpu.cpp":70:85)
// loc("cirgen/circuit/rv32im/ffpu.cpp":76:85)
Fp x0(2013265910);
// loc("./cirgen/components/onehot.h":35:32)
Fp x1(11);
Expand Down
17,687 changes: 8,847 additions & 8,840 deletions risc0/circuit/rv32im/kernels/eval_check.cu

Large diffs are not rendered by default.

17,687 changes: 8,847 additions & 8,840 deletions risc0/circuit/rv32im/kernels/eval_check.metal

Large diffs are not rendered by default.

14,623 changes: 7,315 additions & 7,308 deletions risc0/circuit/rv32im/src/poly_ext.rs

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions risc0/zkvm/platform/src/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ pub const DIGEST_BYTES: usize = WORD_SIZE * DIGEST_WORDS;
#[cfg(target_os = "zkvm")]
static mut READ_PTR: UnsafeCell<usize> = UnsafeCell::new(memory::INPUT.start());

/// Compute `ceil(a / b)` via truncated integer division.
#[allow(dead_code)]
const fn div_ceil(a: u32, b: u32) -> u32 {
(a + b - 1) / b
}

/// Round `a` up to the nearest multipe of `b`.
#[allow(dead_code)]
const fn round_up(a: u32, b: u32) -> u32 {
div_ceil(a, b) * b
}

#[inline(always)]
pub unsafe fn sys_panic(msg_ptr: *const u8, msg_len: usize) -> ! {
#[cfg(target_os = "zkvm")]
Expand Down Expand Up @@ -116,6 +128,7 @@ pub unsafe fn sys_io(channel: u32, buf_ptr: *const u8, buf_len: usize) -> &'stat
#[cfg(target_os = "zkvm")]
{
let read_ptr: &mut usize = &mut *READ_PTR.get();
*read_ptr = round_up(*read_ptr as u32, crate::PAGE_SIZE as u32) as usize;
let out_ptr = *read_ptr as *const u8;
let out_nbytes: usize;
asm!(
Expand Down Expand Up @@ -203,6 +216,7 @@ pub unsafe fn sys_compute_poly(
#[cfg(target_os = "zkvm")]
{
let read_ptr: &mut usize = &mut *READ_PTR.get();
*read_ptr = round_up(*read_ptr as u32, crate::PAGE_SIZE as u32) as usize;
let out_ptr = *read_ptr as *const u32;
let out_nwords: usize;
asm!(
Expand Down
150 changes: 90 additions & 60 deletions risc0/zkvm/src/prove/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ use crate::{MemoryImage, CIRCUIT};

const IMM_BITS: usize = 12;

#[allow(dead_code)]
#[derive(Debug)]
enum MemoryOp {
PageIn,
Read,
Write,
PageOut,
}

impl MemoryOp {
fn as_u32(self) -> u32 {
self as u32
}
}

pub trait HostHandler {
fn is_trace_enabled(&self) -> bool;
fn on_commit(&mut self, buf: &[u32]) -> Result<()>;
Expand Down Expand Up @@ -245,29 +260,31 @@ impl OpCode {
}
}

struct PageFaults {
struct PageFaults<'a> {
reads: BTreeSet<u32>,
info: &'a PageTableInfo,
}

impl PageFaults {
pub fn new() -> Self {
impl<'a> PageFaults<'a> {
pub fn new(info: &'a PageTableInfo) -> Self {
Self {
reads: BTreeSet::new(),
info,
}
}

pub fn include(&mut self, addr: u32, info: &PageTableInfo) {
if addr < info.mem_start {
let page_idx = info.get_page_index_nondet(addr);
pub fn include(&mut self, addr: u32) {
if addr < self.info.mem_start {
let page_idx = self.info.get_page_index_nondet(addr);
self.reads.insert(page_idx);
} else {
let mut addr = addr;
loop {
let raw_page_idx = info.get_page_index_nondet(addr);
let page_idx = info.get_page_index(addr);
let entry_addr = info.get_page_entry_addr(page_idx);
let raw_page_idx = self.info.get_page_index_nondet(addr);
let page_idx = self.info.get_page_index(addr);
let entry_addr = self.info.get_page_entry_addr(page_idx);
self.reads.insert(raw_page_idx);
if raw_page_idx == info.raw_root_idx {
if raw_page_idx == self.info.raw_root_idx {
break;
}
addr = entry_addr;
Expand Down Expand Up @@ -362,6 +379,10 @@ impl<'a, H: HostHandler> CircuitStepHandler<BabyBearElem> for MachineContext<'a,
) = self.syscall(cycle)?;
Ok(())
}
"isResident" => {
outs[0] = self.is_resident(args[0]);
Ok(())
}
_ => unimplemented!("Unsupported extern: {name}"),
}
}
Expand Down Expand Up @@ -414,9 +435,9 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {

fn get_page_faults(&self, pc: u32, inst: u32, opcode: &OpCode) -> PageFaults {
let info = &self.memory.ram.info;
let mut faults = PageFaults::new();
faults.include(SYSTEM.start() as u32, info);
faults.include(pc, info);
let mut faults = PageFaults::new(info);
faults.include(SYSTEM.start() as u32);
faults.include(pc);

if opcode.major == MajorType::MemIo {
let rs1 = (inst >> 15) & 0x1f;
Expand All @@ -427,7 +448,7 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
let imm = sign_extend(imm as i32, IMM_BITS as u32);
let addr = base.checked_add_signed(imm).unwrap();
// debug!(" load: 0x{inst:08x}, M[x{rs1} + {imm}] -> 0x{addr:08x}");
faults.include(addr, info);
faults.include(addr);
} else {
// store: S-type
let imm_low = (inst >> 7) & 0x1f;
Expand All @@ -436,39 +457,42 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
let imm = sign_extend(imm as i32, IMM_BITS as u32);
let addr = base.checked_add_signed(imm).unwrap();
debug!(" store: 0x{inst:08x}, M[x{rs1} + {imm}] -> 0x{addr:08x}");
faults.include(addr, info);
faults.include(addr);
}
} else if opcode.major == MajorType::ECall {
let minor = self.memory.load_register(REG_T0);
if minor == ecall::FFPU {
let code_addr = self.memory.load_register(REG_A0);
let args_addr = self.memory.load_register(REG_A1);
let code_end = self.memory.load_register(REG_A2);
let const_addr = self.memory.load_u32(args_addr + (0 * WORD_SIZE) as u32);
let input_addr = self.memory.load_u32(args_addr + (1 * WORD_SIZE) as u32);
let output_addr = self.memory.load_u32(args_addr + (1 * WORD_SIZE) as u32);
faults.include(code_addr, info);
faults.include(args_addr, info);
faults.include(const_addr, info);
faults.include(input_addr, info);
faults.include(output_addr, info);
let output_addr = self.memory.load_u32(args_addr + (2 * WORD_SIZE) as u32);
for addr in (code_addr..code_end).step_by(WORD_SIZE) {
faults.include(addr);
}
faults.include(args_addr);
faults.include(const_addr);
faults.include(input_addr);
faults.include(output_addr);
} else if minor == ecall::SHA {
let state_out_addr = self.memory.load_register(REG_A0);
let state_in_addr = self.memory.load_register(REG_A1);
let block1_addr = self.memory.load_register(REG_A2);
let block2_addr = self.memory.load_register(REG_A3);
let count = self.memory.load_register(REG_A4);
for i in 0..DIGEST_WORDS {
faults.include(state_out_addr + (i * WORD_SIZE) as u32, info);
faults.include(state_out_addr + (i * WORD_SIZE) as u32);
}
for i in 0..DIGEST_WORDS {
faults.include(state_in_addr + (i * WORD_SIZE) as u32, info);
faults.include(state_in_addr + (i * WORD_SIZE) as u32);
}
for i in 0..count {
let addr1 = block1_addr + i * BLOCK_SIZE as u32;
let addr2 = block2_addr + i * BLOCK_SIZE as u32;
for j in 0..DIGEST_WORDS {
faults.include(addr1 + (j * WORD_SIZE) as u32, info);
faults.include(addr2 + (j * WORD_SIZE) as u32, info);
faults.include(addr1 + (j * WORD_SIZE) as u32);
faults.include(addr2 + (j * WORD_SIZE) as u32);
}
}
}
Expand Down Expand Up @@ -627,39 +651,32 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
fn ram_read(
&mut self,
addr: BabyBearElem,
page_in: BabyBearElem,
op: BabyBearElem,
) -> (BabyBearElem, BabyBearElem, BabyBearElem, BabyBearElem) {
let addr: u32 = addr.into();
let page_in: u32 = page_in.into();
let op: u32 = op.into();
if op == MemoryOp::PageIn.as_u32() {
if self.memory.resident.replace(addr).is_some() {
panic!("Memory read already marked for page in: 0x{addr:08x}");
}
} else {
if !self.memory.resident.contains(&addr) {
let addr = addr * WORD_SIZE as u32;
if addr >= self.memory.ram.info.mem_start {
let page_idx = self.memory.ram.info.get_page_index(addr);
let entry_addr = self.memory.ram.info.get_page_entry_addr(page_idx);
debug!(" ram_read: 0x{addr:08x}, op: {op:?}, entry_addr: 0x{entry_addr:08x}");
}
panic!("Memory read before page in: 0x{addr:08x}");
}
}
if addr as usize * WORD_SIZE >= FFPU.start() {
let ffpu_addr = addr as usize - FFPU.start() / WORD_SIZE;
if ffpu_addr >= self.memory.ffpu_ram.len() {
return (
BabyBearElem::ZERO,
BabyBearElem::ZERO,
BabyBearElem::ZERO,
BabyBearElem::ZERO,
);
return split_word8(0);
}
self.memory.ffpu_ram[ffpu_addr]
} else {
if page_in == 1 {
self.memory.resident.insert(addr);
} else {
if !self.memory.resident.contains(&addr) {
let addr = addr * WORD_SIZE as u32;
if addr >= self.memory.ram.info.mem_start {
let page_idx = self.memory.ram.info.get_page_index(addr);
let entry_addr = self.memory.ram.info.get_page_entry_addr(page_idx);
debug!(" ram_read: 0x{addr:08x}, page_in: {page_in}, entry_addr: 0x{entry_addr:08x}");
}
}
assert!(
self.memory.resident.contains(&addr),
"Memory read before page in"
);
}

let addr = addr * WORD_SIZE as u32;
let word = self.memory.load_u32(addr);
// debug!("ram_read: 0x{addr:08X} -> 0x{word:08X}");
Expand All @@ -671,10 +688,21 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
&mut self,
addr: BabyBearElem,
data: (BabyBearElem, BabyBearElem, BabyBearElem, BabyBearElem),
page_in: BabyBearElem,
op: BabyBearElem,
) -> Result<()> {
let addr: u32 = addr.into();
let page_in: u32 = page_in.into();
let op: u32 = op.into();
if op == MemoryOp::PageIn.as_u32() {
if self.memory.resident.replace(addr).is_some() {
panic!("Memory write already marked for page in: 0x{addr:08x}");
}
} else {
assert!(
self.memory.resident.contains(&addr),
"Memory write before page in: 0x{addr:08x}"
);
}

if addr as usize * WORD_SIZE >= FFPU.start() {
let ffpu_addr = addr as usize - FFPU.start() / WORD_SIZE;
if self.memory.ffpu_ram.len() <= ffpu_addr {
Expand All @@ -683,14 +711,6 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
}
self.memory.ffpu_ram[ffpu_addr] = data
} else {
if page_in == 1 {
self.memory.resident.insert(addr);
} else {
assert!(
self.memory.resident.contains(&addr),
"Memory write before page in"
);
}
let data = merge_word8(data);
let addr = addr * WORD_SIZE as u32;
// debug!("ram_write> 0x{:08X} <= 0x{:08X}", addr, data);
Expand All @@ -710,9 +730,19 @@ impl<'a, H: HostHandler> MachineContext<'a, H> {
}
}
}

Ok(())
}

fn is_resident(&self, addr: BabyBearElem) -> BabyBearElem {
let addr: u32 = addr.into();
if self.memory.resident.contains(&addr) {
1u32.into()
} else {
0u32.into()
}
}

fn plonk_read(&mut self, name: &str, outs: &mut [BabyBearElem]) {
match name {
"ram" => self.memory.ram_plonk.read(outs.try_into().unwrap()),
Expand Down