diff --git a/Cargo.lock b/Cargo.lock index 9cb5dfbba2..ad540352cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5840,6 +5840,7 @@ dependencies = [ "metrics", "once_cell", "openvm-circuit", + "openvm-cuda-backend", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-compiler-derive", diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 737406839f..40f0e4ad43 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -62,12 +62,43 @@ template struct SimplePoseidonSpecificCols { MemoryWriteAuxCols write_data_2; }; +template struct MultiObserveCols { + T pc; + T final_timestamp_increment; + T state_ptr; + T input_ptr; + T init_pos; + T len; + T input_register_1; + T input_register_2; + T input_register_3; + T output_register; + T is_first; + T is_last; + T curr_len; + T start_idx; + T end_idx; + T aux_after_start[CHUNK]; + T aux_before_end[CHUNK]; + T aux_read_enabled[CHUNK]; + MemoryReadAuxCols read_data[CHUNK]; + MemoryWriteAuxCols write_data[CHUNK]; + T data[CHUNK]; + T should_permute; + MemoryWriteAuxCols write_sponge_state; + MemoryWriteAuxCols write_final_idx; + T final_idx; +}; + template constexpr T constexpr_max(T a, T b) { return a > b ? a : b; } constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( sizeof(TopLevelSpecificCols), constexpr_max( sizeof(InsideRowSpecificCols), - sizeof(SimplePoseidonSpecificCols) + constexpr_max( + sizeof(SimplePoseidonSpecificCols), + sizeof(MultiObserveCols) + ) ) ); diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 9779b601e4..ece788b0a8 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -22,6 +22,7 @@ template struct NativePoseidon2Cols { T incorporate_sibling; T inside_row; T simple; + T multi_observe_row; T end_inside_row; T end_top_level; @@ -38,7 +39,7 @@ template struct NativePoseidon2Cols { }; __device__ void mem_fill_base( - MemoryAuxColsFactory mem_helper, + MemoryAuxColsFactory &mem_helper, uint32_t timestamp, RowSlice base_aux ) { @@ -58,6 +59,8 @@ template struct Poseidon2Wrapper { ) { if (row[COL_INDEX(Cols, simple)] == Fp::one()) { fill_simple_chunk(row, range_checker, timestamp_max_bits); + } else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) { + fill_multi_observe_chunk(row, range_checker, timestamp_max_bits); } else { fill_verify_batch_chunk(row, range_checker, timestamp_max_bits); } @@ -335,6 +338,66 @@ template struct Poseidon2Wrapper { } } } + + __device__ static void fill_multi_observe_chunk( + RowSlice row, + VariableRangeChecker range_checker, + uint32_t timestamp_max_bits + ) { + MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits); + Poseidon2Row head_row(row); + uint32_t num_rows = head_row.export_col()[0].asUInt32(); + + for (uint32_t idx = 0; idx < num_rows; ++idx) { + RowSlice curr_row = row.shift_row(idx); + fill_inner(curr_row); + fill_multi_observe_specific(curr_row, mem_helper); + } + } + + __device__ static void fill_multi_observe_specific( + RowSlice row, + MemoryAuxColsFactory &mem_helper + ) { + RowSlice specific = row.slice_from(COL_INDEX(Cols, specific)); + if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { + uint32_t very_start_timestamp = + row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); + for (uint32_t i = 0; i < 4; ++i) { + mem_fill_base( + mem_helper, + very_start_timestamp + i, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) + ); + } + } else { + uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); + uint32_t chunk_start = + specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); + uint32_t chunk_end = + specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + for (uint32_t j = chunk_start; j < chunk_end; ++j) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) + ); + start_timestamp += 2; + } + if (chunk_end >= CHUNK) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base)) + ); + } + } + } }; template diff --git a/extensions/native/circuit/src/poseidon2/cuda.rs b/extensions/native/circuit/src/poseidon2/cuda.rs index 107589ef49..0425cfda18 100644 --- a/extensions/native/circuit/src/poseidon2/cuda.rs +++ b/extensions/native/circuit/src/poseidon2/cuda.rs @@ -53,6 +53,9 @@ impl Chip chunk_start.push(row_idx as u32); if cols.simple.is_one() { row_idx += 1; + } else if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + row_idx += num_rows; } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let non_inside_start = start + (num_non_inside_row - 1) * width; diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml index a8efd69ab3..f47f263693 100644 --- a/extensions/native/recursion/Cargo.toml +++ b/extensions/native/recursion/Cargo.toml @@ -8,6 +8,7 @@ repository.workspace = true [dependencies] openvm-stark-backend = { workspace = true } +openvm-cuda-backend = { workspace = true, optional = true } openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler = { workspace = true } openvm-native-compiler-derive = { workspace = true } @@ -58,4 +59,4 @@ parallel = ["openvm-stark-backend/parallel"] mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] nightly-features = ["openvm-circuit/nightly-features"] -cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda"] +cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"] diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 4e3bba92e7..3b78b6734b 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -6,6 +6,8 @@ use openvm_circuit::{ }, utils::{air_test_impl, TestStarkEngine}, }; +#[cfg(feature = "cuda")] +use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine; use openvm_native_circuit::{ execute_program_with_config, test_native_config, NativeBuilder, NativeConfig, }; @@ -211,8 +213,22 @@ fn test_multi_observe() { config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); + #[cfg(not(feature = "cuda"))] air_test_impl::(fri_params, vb, config, program, vec![], 1, true) .unwrap(); + #[cfg(feature = "cuda")] + { + air_test_impl::( + fri_params, + vb, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); + } } fn build_test_program(builder: &mut Builder) {