diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 40f0e4ad43..206c0e16c0 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -87,7 +87,6 @@ template struct MultiObserveCols { T should_permute; MemoryWriteAuxCols write_sponge_state; MemoryWriteAuxCols write_final_idx; - T final_idx; }; template constexpr T constexpr_max(T a, T b) { return a > b ? a : b; } diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index ece788b0a8..59c65626ab 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -395,6 +395,14 @@ template struct Poseidon2Wrapper { start_timestamp, specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base)) ); + start_timestamp += 1; + } + if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base)) + ); } } } diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index d68ddf15f3..baf18b06a3 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -728,7 +728,6 @@ impl Air should_permute, write_sponge_state, write_final_idx, - final_idx, input_register_1, input_register_2, input_register_3, @@ -830,6 +829,16 @@ impl Air } for i in 0..CHUNK { + builder + .when(multi_observe_row) + .assert_bool(aux_after_start[i]); + builder + .when(multi_observe_row) + .assert_bool(aux_before_end[i]); + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(aux_read_enabled[i]); builder .when(multi_observe_row) .assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]); @@ -889,19 +898,22 @@ impl Air .assert_eq(*a, *b); }); - /* + builder + .when(multi_observe_row) + .when(aux_read_enabled[CHUNK - 1]) + .assert_one(should_permute); + + // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx + let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - input_register_1, - ), + MemoryAddress::new(self.address_space, input_register_1), [final_idx], - start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO, - &write_final_idx + start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + &write_final_idx, ) .eval(builder, multi_observe_row * is_last); - */ // Field transitions builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 3a6ccf094a..aecff9f10f 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -685,6 +685,7 @@ where pos += len; } } + final_timestamp_inc += 1; // write back to init_pos_register let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -810,6 +811,15 @@ where multi_observe_cols.should_permute = F::ZERO; cols.inner.inputs.clone_from_slice(&permutation_input); } + if i == num_chunks - 1 { + let final_idx = F::from_canonical_usize(chunk_end % CHUNK); + tracing_write_native_inplace( + state.memory, + init_pos_register.as_canonical_u32(), + [final_idx], + &mut multi_observe_cols.write_final_idx, + ); + } } } else { unreachable!() @@ -1213,6 +1223,14 @@ impl NativePoseidon2Filler { // Final write back and registers pub write_final_idx: MemoryWriteAuxCols, - pub final_idx: T, } diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 7e8d9a8aea..a0c1fc72a2 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -331,8 +331,9 @@ impl MeteredExecutor { #[inline(always)] fn metered_pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::>>(), + size_of::>>(), size_of::>>(), ) } @@ -613,6 +614,7 @@ unsafe fn execute_multi_observe_e12_impl< pos += len; } } + let final_idx = observation_chunks.last().map(|(_, end)| *end % CHUNK); height += 1; let mut input_idx = 0; @@ -631,6 +633,13 @@ unsafe fn execute_multi_observe_e12_impl< height += 1; } + if let Some(final_idx) = final_idx { + exec_state.vm_write::( + NATIVE_AS, + pre_compute.init_pos_register, + &[F::from_canonical_usize(final_idx)], + ); + } *pc = pc.wrapping_add(DEFAULT_PC_STEP); *instret += 1; diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index c82bbaec38..6d32f89409 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -42,6 +42,7 @@ impl Builder { len.clone(), )); + // automatically updated by Poseidon2MultiObserve operation Usize::Var(init_pos) } }, diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 2d45d896be..440b14ec59 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -77,6 +77,24 @@ impl DuplexChallengerVariable { } } + // Observes multiple elements from an array. + // This is equivalent to calling `observe` multiple times, but more efficient. + pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { + builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr); + + builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); + builder.if_ne(next_pos, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&self.output_ptr, self.io_empty_ptr); + }, + |builder| { + builder.assign(&self.output_ptr, self.io_full_ptr); + }, + ); + }); + } + fn sample(&self, builder: &mut Builder) -> Felt { builder .if_ne(self.input_ptr.address, self.io_empty_ptr.address) diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 3b78b6734b..68e033845b 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -15,10 +15,9 @@ use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, ir::{Array, Builder, Config, Felt}, - prelude::Usize, }; use openvm_native_recursion::{ - challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}, + challenger::{duplex::DuplexChallengerVariable, CanObserveVariable, CanSampleVariable}, testing_utils::inner::run_recursive_test, }; use openvm_stark_backend::{ @@ -192,7 +191,6 @@ fn test_multi_observe() { compiler.build(builder.operations); let asm_code = compiler.code(); - // let program = Program::from_instructions(&instructions); let program: Program<_> = convert_program(asm_code, compilation_options); let poseidon2_max_constraint_degree = 3; @@ -232,17 +230,12 @@ fn test_multi_observe() { } fn build_test_program(builder: &mut Builder) { - let sample_lens: Vec = vec![10, 2, 1, 3, 20]; + let sample_lens: Vec = vec![10, 2, 1, 0, 3, 20, 200, 400]; let mut rng = create_seeded_rng(); - let mut challenger = DuplexChallengerVariable::new(builder); - // Observe a setup label - let label_f: Vec = vec![128, 3098, 192, 394, 1662, 928, 374, 281, 598, 182, 475, 729]; - for n in label_f { - let f: Felt = builder.constant(C::F::from_canonical_u64(n)); - challenger.observe(builder, f); - } + let mut c1 = DuplexChallengerVariable::new(builder); + let mut c2 = DuplexChallengerVariable::new(builder); for l in sample_lens { let sample_input: Array> = builder.dyn_array(l); @@ -251,24 +244,13 @@ fn build_test_program(builder: &mut Builder) { builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); }); - let next_input_ptr = builder.poseidon2_multi_observe( - &challenger.sponge_state, - challenger.input_ptr, - &sample_input, - ); + c1.observe_slice_opt(builder, &sample_input); + c2.observe_slice(builder, sample_input); + + let e1 = c1.sample(builder); + let e2 = c2.sample(builder); - builder.assign( - &challenger.input_ptr, - challenger.io_empty_ptr + next_input_ptr.clone(), - ); - builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); - }, - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_full_ptr); - }, - ); + builder.assert_felt_eq(e1, e2); } builder.halt(); }