Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ template <typename T> struct MultiObserveCols {
T should_permute;
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
MemoryWriteAuxCols<T, 1> write_final_idx;
T final_idx;
};

template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; }
Expand Down
8 changes: 8 additions & 0 deletions extensions/native/circuit/cuda/src/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
);
start_timestamp += 1;
}
if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base))
);
}
}
}
Expand Down
30 changes: 21 additions & 9 deletions extensions/native/circuit/src/poseidon2/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,6 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
should_permute,
write_sponge_state,
write_final_idx,
final_idx,
input_register_1,
input_register_2,
input_register_3,
Expand Down Expand Up @@ -830,6 +829,16 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
}

for i in 0..CHUNK {
builder
.when(multi_observe_row)
.assert_bool(aux_after_start[i]);
builder
.when(multi_observe_row)
.assert_bool(aux_before_end[i]);
builder
.when(multi_observe_row)
.when(is_first)
.assert_zero(aux_read_enabled[i]);
builder
.when(multi_observe_row)
.assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]);
Expand Down Expand Up @@ -889,19 +898,22 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
.assert_eq(*a, *b);
});

/*
builder
.when(multi_observe_row)
.when(aux_read_enabled[CHUNK - 1])
.assert_one(should_permute);

// final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx
let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO
+ (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx;
self.memory_bridge
.write(
MemoryAddress::new(
self.address_space,
input_register_1,
),
MemoryAddress::new(self.address_space, input_register_1),
[final_idx],
start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO,
&write_final_idx
start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute,
&write_final_idx,
)
.eval(builder, multi_observe_row * is_last);
*/

// Field transitions
builder
Expand Down
18 changes: 18 additions & 0 deletions extensions/native/circuit/src/poseidon2/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ where
pos += len;
}
}
final_timestamp_inc += 1; // write back to init_pos_register

let allocated_rows = arena
.alloc(MultiRowLayout::new(NativePoseidon2Metadata {
Expand Down Expand Up @@ -810,6 +811,15 @@ where
multi_observe_cols.should_permute = F::ZERO;
cols.inner.inputs.clone_from_slice(&permutation_input);
}
if i == num_chunks - 1 {
let final_idx = F::from_canonical_usize(chunk_end % CHUNK);
tracing_write_native_inplace(
state.memory,
init_pos_register.as_canonical_u32(),
[final_idx],
&mut multi_observe_cols.write_final_idx,
);
}
}
} else {
unreachable!()
Expand Down Expand Up @@ -1213,6 +1223,14 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Filler<F, SBOX
start_timestamp_u32,
multi_observe_cols.write_sponge_state.as_mut(),
);
start_timestamp_u32 += 1;
}
if row_idx == num_rows - 1 {
mem_fill_helper(
mem_helper,
start_timestamp_u32,
multi_observe_cols.write_final_idx.as_mut(),
);
}
}
}
Expand Down
1 change: 0 additions & 1 deletion extensions/native/circuit/src/poseidon2/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,4 @@ pub struct MultiObserveCols<T> {

// Final write back and registers
pub write_final_idx: MemoryWriteAuxCols<T, 1>,
pub final_idx: T,
}
11 changes: 10 additions & 1 deletion extensions/native/circuit/src/poseidon2/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,9 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> MeteredExecutor<F>
{
#[inline(always)]
fn metered_pre_compute_size(&self) -> usize {
std::cmp::max(
max3(
size_of::<E2PreCompute<Pos2PreCompute<F, SBOX_REGISTERS>>>(),
size_of::<E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>>>(),
size_of::<E2PreCompute<VerifyBatchPreCompute<F, SBOX_REGISTERS>>>(),
)
}
Expand Down Expand Up @@ -613,6 +614,7 @@ unsafe fn execute_multi_observe_e12_impl<
pos += len;
}
}
let final_idx = observation_chunks.last().map(|(_, end)| *end % CHUNK);

height += 1;
let mut input_idx = 0;
Expand All @@ -631,6 +633,13 @@ unsafe fn execute_multi_observe_e12_impl<

height += 1;
}
if let Some(final_idx) = final_idx {
exec_state.vm_write::<F, 1>(
NATIVE_AS,
pre_compute.init_pos_register,
&[F::from_canonical_usize(final_idx)],
);
}
*pc = pc.wrapping_add(DEFAULT_PC_STEP);
*instret += 1;

Expand Down
1 change: 1 addition & 0 deletions extensions/native/compiler/src/ir/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl<C: Config> Builder<C> {
len.clone(),
));

// automatically updated by Poseidon2MultiObserve operation
Usize::Var(init_pos)
}
},
Expand Down
18 changes: 18 additions & 0 deletions extensions/native/recursion/src/challenger/duplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,24 @@ impl<C: Config> DuplexChallengerVariable<C> {
}
}

// Observes multiple elements from an array.
// This is equivalent to calling `observe` multiple times, but more efficient.
pub fn observe_slice_opt(&self, builder: &mut Builder<C>, arr: &Array<C, Felt<C::F>>) {
builder.if_ne(arr.len(), Usize::from(0)).then(|builder| {
let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr);

builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone());
builder.if_ne(next_pos, Usize::from(0)).then_or_else(
|builder| {
builder.assign(&self.output_ptr, self.io_empty_ptr);
},
|builder| {
builder.assign(&self.output_ptr, self.io_full_ptr);
},
);
});
}

fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
builder
.if_ne(self.input_ptr.address, self.io_empty_ptr.address)
Expand Down
38 changes: 10 additions & 28 deletions extensions/native/recursion/tests/recursion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ use openvm_native_compiler::{
asm::{AsmBuilder, AsmCompiler},
conversion::{convert_program, CompilerOptions},
ir::{Array, Builder, Config, Felt},
prelude::Usize,
};
use openvm_native_recursion::{
challenger::{duplex::DuplexChallengerVariable, CanObserveVariable},
challenger::{duplex::DuplexChallengerVariable, CanObserveVariable, CanSampleVariable},
testing_utils::inner::run_recursive_test,
};
use openvm_stark_backend::{
Expand Down Expand Up @@ -192,7 +191,6 @@ fn test_multi_observe() {
compiler.build(builder.operations);
let asm_code = compiler.code();

// let program = Program::from_instructions(&instructions);
let program: Program<_> = convert_program(asm_code, compilation_options);

let poseidon2_max_constraint_degree = 3;
Expand Down Expand Up @@ -232,17 +230,12 @@ fn test_multi_observe() {
}

fn build_test_program<C: Config>(builder: &mut Builder<C>) {
let sample_lens: Vec<usize> = vec![10, 2, 1, 3, 20];
let sample_lens: Vec<usize> = vec![10, 2, 1, 0, 3, 20, 200, 400];

let mut rng = create_seeded_rng();
let mut challenger = DuplexChallengerVariable::new(builder);

// Observe a setup label
let label_f: Vec<u64> = vec![128, 3098, 192, 394, 1662, 928, 374, 281, 598, 182, 475, 729];
for n in label_f {
let f: Felt<C::F> = builder.constant(C::F::from_canonical_u64(n));
challenger.observe(builder, f);
}
let mut c1 = DuplexChallengerVariable::new(builder);
let mut c2 = DuplexChallengerVariable::new(builder);

for l in sample_lens {
let sample_input: Array<C, Felt<C::F>> = builder.dyn_array(l);
Expand All @@ -251,24 +244,13 @@ fn build_test_program<C: Config>(builder: &mut Builder<C>) {
builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32));
});

let next_input_ptr = builder.poseidon2_multi_observe(
&challenger.sponge_state,
challenger.input_ptr,
&sample_input,
);
c1.observe_slice_opt(builder, &sample_input);
c2.observe_slice(builder, sample_input);

let e1 = c1.sample(builder);
let e2 = c2.sample(builder);

builder.assign(
&challenger.input_ptr,
challenger.io_empty_ptr + next_input_ptr.clone(),
);
builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(
|builder| {
builder.assign(&challenger.output_ptr, challenger.io_empty_ptr);
},
|builder| {
builder.assign(&challenger.output_ptr, challenger.io_full_ptr);
},
);
builder.assert_felt_eq(e1, e2);
}
builder.halt();
}
Loading