diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 98b2fe774d..a86cdb1bd2 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -165,6 +165,7 @@ impl VmExecutionExtension for Native { VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), Poseidon2Opcode::COMP_POS2.global_opcode(), + Poseidon2Opcode::MULTI_OBSERVE.global_opcode(), ], )?; diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 373995bce9..adf01695e4 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -8,7 +8,7 @@ use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{ @@ -26,8 +26,8 @@ use openvm_stark_backend::{ use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, MultiObserveCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, }, }; @@ -118,7 +118,8 @@ impl Air builder.assert_bool(inside_row); builder.assert_bool(simple); builder.assert_bool(multi_observe_row); - let enabled = incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; + let enabled = + incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; builder.assert_bool(enabled.clone()); builder.assert_bool(end_inside_row); builder.when(end_inside_row).assert_one(inside_row); @@ -730,18 +731,12 @@ impl Air input_register_1, input_register_2, input_register_3, - output_register + output_register, } = multi_observe_specific; - builder - .when(multi_observe_row) - .assert_bool(is_first); - builder - .when(multi_observe_row) - .assert_bool(is_last); - builder - .when(multi_observe_row) - .assert_bool(should_permute); + builder.when(multi_observe_row).assert_bool(is_first); + builder.when(multi_observe_row).assert_bool(is_last); + builder.when(multi_observe_row).assert_bool(should_permute); self.execution_bridge .execute_and_increment_pc( @@ -799,19 +794,19 @@ impl Air let i_var = AB::F::from_canonical_usize(i); self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_ptr + curr_len + i_var - start_idx), + MemoryAddress::new( + self.address_space, + input_ptr + curr_len + i_var - start_idx, + ), [data[i]], start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, - &read_data[i] + &read_data[i], ) .eval(builder, aux_after_start[i] * aux_before_end[i]); - + self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - state_ptr + i_var, - ), + MemoryAddress::new(self.address_space, state_ptr + i_var), [data[i]], start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, &write_data[i], @@ -835,15 +830,15 @@ impl Air .when(multi_observe_row) .when(not(is_first)) .assert_eq( - aux_after_start[0] - + aux_after_start[1] - + aux_after_start[2] - + aux_after_start[3] - + aux_after_start[4] - + aux_after_start[5] - + aux_after_start[6] - + aux_after_start[7], - AB::Expr::from_canonical_usize(CHUNK) - start_idx.into() + aux_after_start[0] + + aux_after_start[1] + + aux_after_start[2] + + aux_after_start[3] + + aux_after_start[4] + + aux_after_start[5] + + aux_after_start[6] + + aux_after_start[7], + AB::Expr::from_canonical_usize(CHUNK) - start_idx.into(), ); builder @@ -851,43 +846,40 @@ impl Air .when(not(is_first)) .assert_eq( aux_before_end[0] - + aux_before_end[1] - + aux_before_end[2] - + aux_before_end[3] - + aux_before_end[4] - + aux_before_end[5] - + aux_before_end[6] - + aux_before_end[7], - end_idx + + aux_before_end[1] + + aux_before_end[2] + + aux_before_end[3] + + aux_before_end[4] + + aux_before_end[5] + + aux_before_end[6] + + aux_before_end[7], + end_idx, ); - let full_sponge_input = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.inputs[i]); - let full_sponge_output = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i]); + let full_sponge_input = from_fn::<_, { CHUNK * 2 }, _>(|i| local.inner.inputs[i]); + let full_sponge_output = from_fn::<_, { CHUNK * 2 }, _>(|i| { + local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i] + }); self.memory_bridge .read( - MemoryAddress::new( - self.address_space, - state_ptr, - ), + MemoryAddress::new(self.address_space, state_ptr), full_sponge_input, - start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO, + start_timestamp + (end_idx - start_idx) * AB::F::TWO, &read_sponge_state, ) .eval(builder, multi_observe_row * should_permute); - + self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - state_ptr - ), + MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + start_timestamp + (end_idx - start_idx) * AB::F::TWO + AB::F::ONE, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); + /* self.memory_bridge .write( MemoryAddress::new( @@ -899,13 +891,14 @@ impl Air &write_final_idx ) .eval(builder, multi_observe_row * is_last); + */ // Field transitions builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq(next_multi_observe_specific.curr_len, multi_observe_specific.curr_len + end_idx - start_idx); - + // Boundary conditions builder .when(multi_observe_row) @@ -927,7 +920,7 @@ impl Air .when(not(is_last)) .assert_one(next.multi_observe_row); - // Field consistency + // Fields remain same across same instance builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) @@ -951,17 +944,26 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_1, next_multi_observe_specific.input_register_1); + .assert_eq( + input_register_1, + next_multi_observe_specific.input_register_1, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_2, next_multi_observe_specific.input_register_2); + .assert_eq( + input_register_2, + next_multi_observe_specific.input_register_2, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_3, next_multi_observe_specific.input_register_3); + .assert_eq( + input_register_3, + next_multi_observe_specific.input_register_3, + ); builder .when(next.multi_observe_row) @@ -974,10 +976,12 @@ impl Air .when(not(next_multi_observe_specific.is_first)) .assert_eq(very_first_timestamp, next.very_first_timestamp); + /* builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq(next.start_timestamp, start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO); + */ } } diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 01a77d0c65..2ccfbe93a3 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip, Poseidon2SubCols}; @@ -25,7 +25,7 @@ use openvm_stark_backend::{ use crate::poseidon2::{ columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols, }, CHUNK, @@ -644,6 +644,181 @@ where if !self.optimistic { assert_eq!(commit, root); } + } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { + let &Instruction { + a: state_ptr_register, + b: init_pos_register, + c: input_ptr_register, + d: register_address_space, + e: data_address_space, + f: len_register, + .. + } = instruction; + + assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + + let [init_pos]: [F; 1] = + memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); + let [input_len]: [F; 1] = + memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + + let mut len = input_len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let mut chunks: Vec<(usize, usize)> = vec![]; + + const NUM_HEAD_ACCESSES: usize = 4; + let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + while len > 0 { + if len >= (CHUNK - pos) { + chunks.push((pos.clone(), CHUNK.clone())); + len -= CHUNK - pos; + final_timestamp_inc += 2 * (CHUNK - pos + 1); + pos = 0; + } else { + chunks.push((pos.clone(), pos + len)); + final_timestamp_inc += 2 * len; + len = 0; + pos = pos + len; + } + } + + let allocated_rows = arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { + num_rows: 1 + chunks.len(), + })) + .0; + let head_cols = &mut allocated_rows[0]; + let head_multi_observe_cols: &mut MultiObserveCols = + head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + let [state_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + state_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[0].as_mut(), + ); + let [init_pos]: [F; 1] = tracing_read_native_helper( + state.memory, + init_pos_register.as_canonical_u32(), + head_multi_observe_cols.read_data[1].as_mut(), + ); + let [input_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[2].as_mut(), + ); + let [input_len]: [F; 1] = tracing_read_native_helper( + state.memory, + len_register.as_canonical_u32(), + head_multi_observe_cols.read_data[3].as_mut(), + ); + + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let state_ptr_u32 = state_ptr.as_canonical_u32(); + + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + + for (i, cols) in allocated_rows.iter_mut().enumerate() { + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + multi_observe_cols.input_register_1 = init_pos_register; + multi_observe_cols.input_register_2 = input_ptr_register; + multi_observe_cols.input_register_3 = len_register; + multi_observe_cols.output_register = state_ptr_register; + multi_observe_cols.init_pos = init_pos; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr = state_ptr; + multi_observe_cols.len = input_len; + + cols.multi_observe_row = F::ONE; + cols.very_first_timestamp = init_timestamp; + + if i == 0 { + // head row + cols.inner.export = F::from_canonical_u32(1 + chunks.len() as u32); + multi_observe_cols.pc = F::from_canonical_u32(*state.pc); + multi_observe_cols.final_timestamp_increment = + F::from_canonical_usize(final_timestamp_inc); + multi_observe_cols.is_first = F::ONE; + multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.curr_len = F::ZERO; + multi_observe_cols.should_permute = F::ZERO; + } + } + + let mut input_idx: usize = 0; + let mut cur_timestamp = init_timestamp_u32 + NUM_HEAD_ACCESSES as u32; + let num_chunks = chunks.len(); + for (i, ((chunk_start, chunk_end), cols)) in chunks + .into_iter() + .zip(allocated_rows.iter_mut().skip(1)) + .enumerate() + { + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + cols.start_timestamp = F::from_canonical_u32(cur_timestamp); + + multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); + multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + + multi_observe_cols.is_first = F::ZERO; + multi_observe_cols.is_last = if i == num_chunks - 1 { + F::ONE + } else { + F::ZERO + }; + multi_observe_cols.curr_len = F::from_canonical_usize(input_idx); + + for j in chunk_start..CHUNK { + multi_observe_cols.aux_after_start[j] = F::ONE; + } + for j in 0..chunk_end { + multi_observe_cols.aux_before_end[j] = F::ONE; + } + for j in chunk_start..chunk_end { + let n_f: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + tracing_write_native_inplace( + state.memory, + state_ptr_u32 + j as u32, + n_f, + &mut multi_observe_cols.write_data[j], + ); + multi_observe_cols.data[j] = n_f[0]; + input_idx += 1; + cur_timestamp += 2; + } + + if chunk_end >= CHUNK { + multi_observe_cols.should_permute = F::ONE; + let permutation_input: [F; 16] = tracing_read_native_helper( + state.memory, + state_ptr_u32, + multi_observe_cols.read_sponge_state.as_mut(), + ); + cols.inner.inputs.clone_from_slice(&permutation_input); + let output = self.subchip.permute(permutation_input); + tracing_write_native_inplace( + state.memory, + state_ptr_u32, + std::array::from_fn(|i| output[i]), + &mut multi_observe_cols.write_sponge_state, + ); + cur_timestamp += 2; + } else { + multi_observe_cols.should_permute = F::ZERO; + let sponge_state: [F; 16] = + memory_read_native(state.memory.data(), state_ptr_u32); + cols.inner.inputs.clone_from_slice(&sponge_state); + } + } } else { unreachable!() } @@ -661,7 +836,7 @@ where String::from("COMP_POS2") } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { String::from("MULTI_OBSERVE") - }else { + } else { unreachable!("unsupported opcode: {}", opcode) } } @@ -688,6 +863,10 @@ impl TraceFiller let (curr, rest) = if cols.simple.is_one() { row_idx += 1; row_slice.split_at_mut(width) + } else if cols.multi_observe_row.is_one() { + let total_num_row = cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + row_slice.split_at_mut(total_num_row * width) } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let start = (num_non_inside_row - 1) * width; @@ -704,6 +883,8 @@ impl TraceFiller let cols: &NativePoseidon2Cols = chunk_slice[..width].borrow(); if cols.simple.is_one() { self.fill_simple_chunk(mem_helper, chunk_slice); + } else if cols.multi_observe_row.is_one() { + self.fill_multi_observe_chunk(mem_helper, chunk_slice); } else { self.fill_verify_batch_chunk(mem_helper, chunk_slice); } @@ -961,6 +1142,94 @@ impl NativePoseidon2Filler, + chunk_slice: &mut [F], + ) { + let inner_width = self.subchip.air.width(); + let width = NativePoseidon2Cols::::width(); + let head_cols: &mut NativePoseidon2Cols = + chunk_slice[..width].borrow_mut(); + let num_rows = head_cols.inner.export.as_canonical_u32() as usize; + + let head_multi_observe_cols: &mut MultiObserveCols = + head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); + + // state_ptr, init_pos, input_ptr, len + mem_fill_helper( + mem_helper, + start_timestamp_u32, + head_multi_observe_cols.read_data[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + head_multi_observe_cols.read_data[1].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2, + head_multi_observe_cols.read_data[2].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 3, + head_multi_observe_cols.read_data[3].as_mut(), + ); + + // generate permutation traces for each row + for row_idx in 0..num_rows { + let cols: &NativePoseidon2Cols = chunk_slice + [row_idx * width..(row_idx + 1) * width] + .as_ref() + .borrow(); + let inner_cols = &self.subchip.generate_trace(vec![cols.inner.inputs]).values; + chunk_slice[row_idx * width..(row_idx + 1) * width][..inner_width] + .copy_from_slice(inner_cols); + } + + for row_idx in 1..num_rows { + let cols: &mut NativePoseidon2Cols = + chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + let mut start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + + for j in chunk_start..chunk_end { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.read_data[j as usize].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + multi_observe_cols.write_data[j as usize].as_mut(), + ); + + start_timestamp_u32 += 2; + } + + if chunk_end >= CHUNK as u32 { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.read_sponge_state.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + multi_observe_cols.write_sponge_state.as_mut(), + ); + } + } + } + #[inline(always)] fn poseidon2_output_from_trace(inner: &Poseidon2SubCols) -> &[F; 2 * CHUNK] { &inner.ending_full_rounds.last().unwrap().post diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index fe0fce881a..934378dfbe 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -212,10 +212,15 @@ pub struct MultiObserveCols { pub final_timestamp_increment: T, // Initial reads from registers + // They are same across same instance of multi_observe pub state_ptr: T, pub input_ptr: T, pub init_pos: T, pub len: T, + pub input_register_1: T, + pub input_register_2: T, + pub input_register_3: T, + pub output_register: T, pub is_first: T, pub is_last: T, @@ -238,9 +243,4 @@ pub struct MultiObserveCols { // Final write back and registers pub write_final_idx: MemoryWriteAuxCols, pub final_idx: T, - - pub input_register_1: T, - pub input_register_2: T, - pub input_register_3: T, - pub output_register: T, -} \ No newline at end of file +} diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 20889e4186..41f4827a67 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -5,10 +5,12 @@ use std::{ use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::Poseidon2SubChip; @@ -29,6 +31,16 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { input_register_2: u32, } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + pub init_pos_register: u32, + pub input_ptr_register: u32, + pub len_register: u32, + pub state_ptr_register: u32, +} + #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct VerifyBatchPreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { @@ -87,6 +99,51 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor, + multi_observe_data: &mut MultiObservePreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f, + .. + } = inst; + + if opcode != MULTI_OBSERVE.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + multi_observe_data.subchip = &self.subchip; + multi_observe_data.state_ptr_register = a; + multi_observe_data.init_pos_register = b; + multi_observe_data.input_ptr_register = c; + multi_observe_data.len_register = f; + + Ok(()) + } + #[inline(always)] fn pre_compute_verify_batch_impl( &'a self, @@ -142,6 +199,7 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let multi_observe_data: &mut MultiObservePreCompute = + $data.borrow_mut(); + $executor.pre_compute_multi_observe_impl($pc, $inst, multi_observe_data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let verify_batch_data: &mut VerifyBatchPreCompute = $data.borrow_mut(); @@ -166,13 +229,18 @@ macro_rules! dispatch1 { }; } +fn max3(a: usize, b: usize, c: usize) -> usize { + std::cmp::max(a, std::cmp::max(b, c)) +} + impl Executor for NativePoseidon2Executor { #[inline(always)] fn pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::>(), + size_of::>(), size_of::>(), ) } @@ -187,6 +255,7 @@ impl Executor ) -> Result, StaticProgramError> { dispatch1!( execute_pos2_e1_impl, + execute_multi_observe_e1_impl, execute_verify_batch_e1_impl, self, inst.opcode, @@ -205,6 +274,7 @@ impl Executor ) -> Result, StaticProgramError> { dispatch1!( execute_pos2_e1_handler, + execute_multi_observe_e1_handler, execute_verify_batch_e1_handler, self, inst.opcode, @@ -218,6 +288,7 @@ impl Executor macro_rules! dispatch2 { ( $execute_pos2_impl:ident, + $execute_multi_observe_impl:ident, $execute_verify_batch_impl:ident, $executor:ident, $opcode:expr, @@ -237,6 +308,13 @@ macro_rules! dispatch2 { } else { Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, false>) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let pre_compute: &mut E2PreCompute> = + $data.borrow_mut(); + pre_compute.chip_idx = $chip_idx as u32; + + $executor.pre_compute_multi_observe_impl($pc, $inst, &mut pre_compute.data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let pre_compute: &mut E2PreCompute> = $data.borrow_mut(); @@ -270,6 +348,7 @@ impl MeteredExecutor ) -> Result, StaticProgramError> { dispatch2!( execute_pos2_e2_impl, + execute_multi_observe_e2_impl, execute_verify_batch_e2_impl, self, inst.opcode, @@ -290,6 +369,7 @@ impl MeteredExecutor ) -> Result, StaticProgramError> { dispatch2!( execute_pos2_e2_handler, + execute_multi_observe_e2_handler, execute_verify_batch_e2_handler, self, inst.opcode, @@ -345,6 +425,49 @@ unsafe fn execute_pos2_e2_impl< .on_height_change(pre_compute.chip_idx as usize, height); } +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e1_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &MultiObservePreCompute = pre_compute.borrow(); + execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e2_impl< + F: PrimeField32, + CTX: MeteredExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = + pre_compute.borrow(); + let height = execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>( + &pre_compute.data, + instret, + pc, + exec_state, + ); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + #[create_handler] #[inline(always)] unsafe fn execute_verify_batch_e1_impl< @@ -452,6 +575,68 @@ unsafe fn execute_pos2_e12_impl< 1 } +#[inline(always)] +unsafe fn execute_multi_observe_e12_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &MultiObservePreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let subchip = pre_compute.subchip; + + let [sponge_ptr]: [F; 1] = + exec_state.vm_read(AS::Native as u32, pre_compute.state_ptr_register); + let [init_pos]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.init_pos_register); + let [input_ptr]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.input_ptr_register); + let [len]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.len_register); + + let mut len = len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let sponge_ptr_u32 = sponge_ptr.as_canonical_u32(); + let mut height = 0; + + // split input into chunks s.t. each chunk fills the RATE portion of sponge state + let mut observation_chunks: Vec<(usize, usize)> = vec![]; + while len > 0 { + if len >= (CHUNK - pos) { + observation_chunks.push((pos, CHUNK)); + len -= CHUNK - pos; + pos = 0; + } else { + observation_chunks.push((pos, pos + len)); + len = 0; + pos = pos + len; + } + } + + height += 1; + let mut input_idx = 0; + + for (chunk_start, chunk_end) in observation_chunks { + for j in chunk_start..chunk_end { + let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS as u32, input_ptr_u32 + input_idx); + exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32 + (j as u32), &[n_f]); + input_idx += 1; + } + if chunk_end == CHUNK { + let mut p2_input: [F; CHUNK * 2] = exec_state.vm_read(NATIVE_AS as u32, sponge_ptr_u32); + subchip.permute_mut(&mut p2_input); + exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32, &p2_input); + } + + height += 1; + } + *pc = pc.wrapping_add(DEFAULT_PC_STEP); + *instret += 1; + + height +} + #[inline(always)] unsafe fn execute_verify_batch_e12_impl< F: PrimeField32, diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index e3a5fd0f3b..a147d161f0 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,9 +1,10 @@ use itertools::Itertools; use openvm_circuit::{ arch::{ - PreflightExecutionOutput, PreflightExecutor, VmBuilder, VmCircuitConfig, VmExecutionConfig, instructions::program::Program + instructions::program::Program, PreflightExecutionOutput, PreflightExecutor, VmBuilder, + VmCircuitConfig, VmExecutionConfig, }, - utils::{TestStarkEngine, air_test_impl}, + utils::{air_test_impl, TestStarkEngine}, }; use openvm_native_circuit::{ execute_program_with_config, test_native_config, NativeBuilder, NativeConfig, @@ -209,12 +210,12 @@ fn test_multi_observe() { config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); - air_test_impl::(fri_params, vb, config, program, vec![], 1, true).unwrap(); - + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); } fn build_test_program(builder: &mut Builder) { - let sample_lens: Vec = vec![10, 2, 0, 3, 20]; + let sample_lens: Vec = vec![10, 2, 1, 3, 20]; let mut rng = create_seeded_rng(); let challenger = DuplexChallengerVariable::new(builder);