Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 32 additions & 1 deletion extensions/native/circuit/cuda/include/native/poseidon2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,43 @@ template <typename T> struct SimplePoseidonSpecificCols {
MemoryWriteAuxCols<T, CHUNK> write_data_2;
};

// Column layout of the chip-specific region for a "multi observe" row.
// The CUDA trace filler addresses these members by offset through
// COL_INDEX(MultiObserveCols, <member>), and sizeof(MultiObserveCols<uint8_t>)
// feeds into COL_SPECIFIC_WIDTH, so member order and array extents are part of
// the layout contract.
// NOTE(review): presumably this must stay in lockstep with the matching
// host-side column struct -- confirm before reordering or resizing anything.
template <typename T> struct MultiObserveCols {
T pc;
T final_timestamp_increment;
T state_ptr;
T input_ptr;
T init_pos;
T len;
T input_register_1;
T input_register_2;
T input_register_3;
T output_register;
// Row-kind flag: the trace filler takes its "first row" path when this equals one.
T is_first;
T is_last;
T curr_len;
// Half-open range [start_idx, end_idx) of chunk positions the trace filler
// walks on rows where is_first is not one.
T start_idx;
T end_idx;
T aux_after_start[CHUNK];
T aux_before_end[CHUNK];
T aux_read_enabled[CHUNK];
// Per-position memory auxiliary columns. On an is_first row the trace filler
// populates read_data[0..4); on other rows it populates read_data[j] and
// write_data[j] for each j in [start_idx, end_idx).
MemoryReadAuxCols<T> read_data[CHUNK];
MemoryWriteAuxCols<T, 1> write_data[CHUNK];
T data[CHUNK];
T should_permute;
// Auxiliary columns for a write of CHUNK * 2 cells; the trace filler populates
// the base part when end_idx >= CHUNK.
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
MemoryWriteAuxCols<T, 1> write_final_idx;
T final_idx;
};

// Compile-time maximum of two values of the same type.
// Returns `a` only when `a > b` holds; otherwise (including ties) returns `b`.
template <typename T>
constexpr T constexpr_max(T a, T b) {
    return !(a > b) ? b : a;
}

// Size of the shared `specific` column region: the largest of the per-row-kind
// layouts (top-level, inside-row, simple Poseidon2, multi-observe), measured as
// sizeof of each layout's uint8_t instantiation.
// NOTE(review): this assumes every member of those layouts is built from T
// only, so that sizeof(...<uint8_t>) equals the column count -- confirm.
constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
    sizeof(TopLevelSpecificCols<uint8_t>),
    constexpr_max(
        sizeof(InsideRowSpecificCols<uint8_t>),
        constexpr_max(
            sizeof(SimplePoseidonSpecificCols<uint8_t>),
            sizeof(MultiObserveCols<uint8_t>)
        )
    )
);
65 changes: 64 additions & 1 deletion extensions/native/circuit/cuda/src/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
T incorporate_sibling;
T inside_row;
T simple;
T multi_observe_row;

T end_inside_row;
T end_top_level;
Expand All @@ -38,7 +39,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
};

__device__ void mem_fill_base(
MemoryAuxColsFactory mem_helper,
MemoryAuxColsFactory &mem_helper,
uint32_t timestamp,
RowSlice base_aux
) {
Expand All @@ -58,6 +59,8 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
) {
if (row[COL_INDEX(Cols, simple)] == Fp::one()) {
fill_simple_chunk(row, range_checker, timestamp_max_bits);
} else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) {
fill_multi_observe_chunk(row, range_checker, timestamp_max_bits);
} else {
fill_verify_batch_chunk(row, range_checker, timestamp_max_bits);
}
Expand Down Expand Up @@ -335,6 +338,66 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
}
}
}

// Fills every row of one multi-observe chunk.
//
// `row` is the first row of the chunk. The number of rows to process is read
// from element 0 of that row's export column. For each row offset in
// [0, count) the row is obtained with `row.shift_row(offset)`, then
// `fill_inner` and `fill_multi_observe_specific` are applied to it in that
// order. A single MemoryAuxColsFactory, built from `range_checker` and
// `timestamp_max_bits`, is shared by all rows of the chunk.
__device__ static void fill_multi_observe_chunk(
    RowSlice row,
    VariableRangeChecker range_checker,
    uint32_t timestamp_max_bits
) {
    MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits);

    // The first row carries the row count for the whole chunk.
    Poseidon2Row first_row(row);
    const uint32_t row_count = first_row.export_col()[0].asUInt32();

    uint32_t offset = 0;
    while (offset != row_count) {
        RowSlice target = row.shift_row(offset);
        fill_inner(target);
        fill_multi_observe_specific(target, mem_helper);
        ++offset;
    }
}

// Fills the memory auxiliary "base" columns of a single multi-observe row.
//
// `row` is the full row; its `specific` region is interpreted with the
// MultiObserveCols layout. `mem_helper` is passed through to mem_fill_base.
//
// * Row with is_first == one: fills read_data[k].base for k = 0..3 at
//   timestamps very_first_timestamp + k, then returns.
// * Any other row: starting at start_timestamp, for each position in
//   [start_idx, end_idx) fills read_data[pos].base at the current timestamp and
//   write_data[pos].base at the next one (two timestamps per position). If
//   end_idx >= CHUNK, write_sponge_state.base is then filled at the timestamp
//   following the last one used.
//
// NOTE(review): the write_final_idx auxiliary columns are not filled in this
// function -- confirm that is intended or handled elsewhere.
__device__ static void fill_multi_observe_specific(
    RowSlice row,
    MemoryAuxColsFactory &mem_helper
) {
    RowSlice cols = row.slice_from(COL_INDEX(Cols, specific));

    if (cols[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
        const uint32_t base_ts = row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
        for (uint32_t k = 0; k != 4; ++k) {
            mem_fill_base(
                mem_helper,
                base_ts + k,
                cols.slice_from(COL_INDEX(MultiObserveCols, read_data[k].base))
            );
        }
        return;
    }

    // Running timestamp cursor: each fill below consumes one tick.
    uint32_t ts = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
    const uint32_t first_pos = cols[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
    const uint32_t past_last_pos = cols[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();

    for (uint32_t pos = first_pos; pos < past_last_pos; ++pos) {
        mem_fill_base(
            mem_helper,
            ts++,
            cols.slice_from(COL_INDEX(MultiObserveCols, read_data[pos].base))
        );
        mem_fill_base(
            mem_helper,
            ts++,
            cols.slice_from(COL_INDEX(MultiObserveCols, write_data[pos].base))
        );
    }

    if (past_last_pos >= CHUNK) {
        mem_fill_base(
            mem_helper,
            ts,
            cols.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
        );
    }
}
};

template <size_t SBOX_REGISTERS>
Expand Down
3 changes: 3 additions & 0 deletions extensions/native/circuit/src/poseidon2/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ impl<const SBOX_REGISTERS: usize> Chip<DenseRecordArena, GpuBackend>
chunk_start.push(row_idx as u32);
if cols.simple.is_one() {
row_idx += 1;
} else if cols.multi_observe_row.is_one() {
let num_rows = cols.inner.export.as_canonical_u32() as usize;
row_idx += num_rows;
} else {
let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize;
let non_inside_start = start + (num_non_inside_row - 1) * width;
Expand Down
3 changes: 2 additions & 1 deletion extensions/native/recursion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repository.workspace = true

[dependencies]
openvm-stark-backend = { workspace = true }
openvm-cuda-backend = { workspace = true, optional = true }
openvm-native-circuit = { workspace = true, features = ["test-utils"] }
openvm-native-compiler = { workspace = true }
openvm-native-compiler-derive = { workspace = true }
Expand Down Expand Up @@ -58,4 +59,4 @@ parallel = ["openvm-stark-backend/parallel"]
mimalloc = ["openvm-stark-backend/mimalloc"]
jemalloc = ["openvm-stark-backend/jemalloc"]
nightly-features = ["openvm-circuit/nightly-features"]
cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda"]
cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"]
16 changes: 16 additions & 0 deletions extensions/native/recursion/tests/recursion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use openvm_circuit::{
},
utils::{air_test_impl, TestStarkEngine},
};
#[cfg(feature = "cuda")]
use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
use openvm_native_circuit::{
execute_program_with_config, test_native_config, NativeBuilder, NativeConfig,
};
Expand Down Expand Up @@ -211,8 +213,22 @@ fn test_multi_observe() {
config.system.memory_config.max_access_adapter_n = 16;

let vb = NativeBuilder::default();
#[cfg(not(feature = "cuda"))]
air_test_impl::<BabyBearPoseidon2Engine, _>(fri_params, vb, config, program, vec![], 1, true)
.unwrap();
#[cfg(feature = "cuda")]
{
air_test_impl::<GpuBabyBearPoseidon2Engine, _>(
fri_params,
vb,
config,
program,
vec![],
1,
true,
)
.unwrap();
}
}

fn build_test_program<C: Config>(builder: &mut Builder<C>) {
Expand Down
Loading