From b11662a65eaeb7b6485d184aa835320e2d970724 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 30 Nov 2025 22:51:07 +0800 Subject: [PATCH] tracegen --- .../circuit/cuda/include/native/sumcheck.cuh | 85 ++++++++++++ .../circuit/cuda/include/native/utils.cuh | 13 ++ .../native/circuit/cuda/src/poseidon2.cu | 10 +- .../native/circuit/cuda/src/sumcheck.cu | 126 ++++++++++++++++++ extensions/native/circuit/src/cuda_abi.rs | 38 ++++++ .../native/circuit/src/extension/cuda.rs | 5 + .../native/circuit/src/sumcheck/chip.rs | 14 +- .../native/circuit/src/sumcheck/cuda.rs | 56 ++++++++ extensions/native/recursion/tests/sumcheck.rs | 13 +- 9 files changed, 342 insertions(+), 18 deletions(-) create mode 100644 extensions/native/circuit/cuda/include/native/sumcheck.cuh create mode 100644 extensions/native/circuit/cuda/include/native/utils.cuh create mode 100644 extensions/native/circuit/cuda/src/sumcheck.cu diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh new file mode 100644 index 0000000000..052dc03fd5 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -0,0 +1,85 @@ +#pragma once + +#include "primitives/constants.h" +#include "system/memory/offline_checker.cuh" + +using namespace native; + +template struct HeaderSpecificCols { + T pc; + T registers[5]; + MemoryReadAuxCols read_records[7]; + MemoryWriteAuxCols write_records; +}; + +template struct ProdSpecificCols { + T data_ptr; + T p[EXT_DEG * 2]; + MemoryReadAuxCols read_records[2]; + T p_evals[EXT_DEG]; + MemoryWriteAuxCols write_record; + T eval_rlc[EXT_DEG]; +}; + +template struct LogupSpecificCols { + T data_ptr; + T pq[EXT_DEG * 4]; + MemoryReadAuxCols read_records[2]; + T p_evals[EXT_DEG]; + T q_evals[EXT_DEG]; + MemoryWriteAuxCols write_records[2]; + T eval_rlc[EXT_DEG]; +}; + +template constexpr T constexpr_max(T a, T b) { + return a > b ? a : b; +} + +constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( + sizeof(HeaderSpecificCols), + constexpr_max(sizeof(ProdSpecificCols), sizeof(LogupSpecificCols)) +); + +template struct NativeSumcheckCols { + T header_row; + T prod_row; + T logup_row; + T is_end; + + T prod_continued; + T logup_continued; + + T prod_in_round_evaluation; + T prod_next_round_evaluation; + T logup_in_round_evaluation; + T logup_next_round_evaluation; + + T prod_acc; + T logup_acc; + + T first_timestamp; + T start_timestamp; + T last_timestamp; + + T register_ptrs[5]; + + T ctx[EXT_DEG * 2]; + + T prod_nested_len; + T logup_nested_len; + + T curr_prod_n; + T curr_logup_n; + + T alpha[EXT_DEG]; + T challenges[EXT_DEG * 4]; + + T max_round; + T within_round_limit; + T should_acc; + + T eval_acc[EXT_DEG]; + + T specific[COL_SPECIFIC_WIDTH]; +}; + diff --git a/extensions/native/circuit/cuda/include/native/utils.cuh b/extensions/native/circuit/cuda/include/native/utils.cuh new file mode 100644 index 0000000000..f217350959 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/utils.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +__device__ __forceinline__ void mem_fill_base( + MemoryAuxColsFactory &mem_helper, + uint32_t timestamp, + RowSlice base_aux +) { + uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); + mem_helper.fill(base_aux, prev, timestamp); +} diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 59c65626ab..fdbe0d3ce5 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -2,6 +2,7 @@ #include "poseidon2-air/columns.cuh" #include "poseidon2-air/params.cuh" #include "poseidon2-air/tracegen.cuh" +#include "native/utils.cuh" #include "primitives/trace_access.h" #include "system/memory/controller.cuh" @@ -38,15 +39,6 @@ template struct NativePoseidon2Cols { T specific[COL_SPECIFIC_WIDTH]; }; -__device__ void mem_fill_base( - MemoryAuxColsFactory &mem_helper, - uint32_t timestamp, - RowSlice base_aux -) { - uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); - mem_helper.fill(base_aux, prev, timestamp); -} - template struct Poseidon2Wrapper { template using Cols = NativePoseidon2Cols; using Poseidon2Row = diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu new file mode 100644 index 0000000000..99c365135b --- /dev/null +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -0,0 +1,126 @@ +#include "launcher.cuh" +#include "native/sumcheck.cuh" +#include "native/utils.cuh" +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +using namespace native; + +__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) { + RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific)); + uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); + + if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { + for (uint32_t i = 0; i < 7; ++i) { + mem_fill_base( + mem_helper, + start_timestamp + i, + specific.slice_from(COL_INDEX(HeaderSpecificCols, read_records[i].base)) + ); + } + uint32_t last_timestamp = row[COL_INDEX(NativeSumcheckCols, last_timestamp)].asUInt32(); + mem_fill_base( + mem_helper, + last_timestamp - 1, + specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base)) + ); + } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } + } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 3, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } + } +} + +__global__ void native_sumcheck_tracegen( + Fp *trace, + size_t height, + size_t width, + const Fp *records, + size_t rows_used, + uint32_t *range_checker_ptr, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= height) { + return; + } + + RowSlice row(trace + idx, height); + if (idx < rows_used) { + const Fp *record = records + idx * width; + for (uint32_t col = 0; col < width; ++col) { + row[col] = record[col]; + } + MemoryAuxColsFactory mem_helper( + VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits + ); + fill_sumcheck_specific(row, mem_helper); + } else { + row.fill_zero(0, width); + COL_WRITE_VALUE(row, NativeSumcheckCols, is_end, Fp::one()); + } +} + +extern "C" int _native_sumcheck_tracegen( + Fp *d_trace, + size_t height, + size_t width, + const Fp *d_records, + size_t rows_used, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + assert((height & (height - 1)) == 0); + assert(width == sizeof(NativeSumcheckCols)); + auto [grid, block] = kernel_launch_params(height); + native_sumcheck_tracegen<<>>( + d_trace, + height, + width, + d_records, + rows_used, + d_range_checker, + range_checker_num_bins, + timestamp_max_bits + ); + return CHECK_KERNEL(); +} diff --git a/extensions/native/circuit/src/cuda_abi.rs b/extensions/native/circuit/src/cuda_abi.rs index ad1a454d7b..5de9124f0d 100644 --- a/extensions/native/circuit/src/cuda_abi.rs +++ b/extensions/native/circuit/src/cuda_abi.rs @@ -235,6 +235,44 @@ pub mod poseidon2_cuda { } } +pub mod sumcheck_cuda { + use super::*; + + extern "C" { + pub fn _native_sumcheck_tracegen( + d_trace: *mut F, + height: usize, + width: usize, + d_records: *const F, + rows_used: usize, + d_range_checker: *mut u32, + range_checker_max_bins: u32, + timestamp_max_bits: u32, + ) -> i32; + } + + pub unsafe fn tracegen( + d_trace: &DeviceBuffer, + height: usize, + width: usize, + d_records: &DeviceBuffer, + rows_used: usize, + d_range_checker: &DeviceBuffer, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + CudaError::from_result(_native_sumcheck_tracegen( + d_trace.as_mut_ptr(), + height, + width, + d_records.as_ptr(), + rows_used, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + timestamp_max_bits, + )) + } +} + pub mod native_loadstore_cuda { use super::*; diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 9a433fce11..765ce8d6cc 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -17,6 +17,7 @@ use crate::{ jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu}, loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu}, poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu}, + sumcheck::{air::NativeSumcheckAir, NativeSumcheckChipGpu}, CastFExtension, GpuBackend, Native, }; @@ -75,6 +76,10 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; + let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); + inventory.add_executor_chip(sumcheck); + Ok(()) } } diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 17e99c442a..a33d286cd1 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -3,7 +3,7 @@ use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, - RecordArena, TraceFiller, VmChipWrapper, VmStateMut, + RecordArena, SizedRecord, TraceFiller, VmChipWrapper, VmStateMut, }, system::{ memory::{online::TracingMemory, MemoryAuxColsFactory}, @@ -76,7 +76,7 @@ impl<'a, F: PrimeField32> // Each instruction record consists solely of some number of contiguously // stored NativeSumcheckCols<...> structs, each of which corresponds to a // single trace row. Trace fillers don't actually need to know how many rows - // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // each instruction uses, and can thus treat each NativeSumcheckCols<...> // as a single record. NativeSumcheckRecordLayout { metadata: NativeSumcheckMetadata { num_rows: 1 }, @@ -84,6 +84,16 @@ impl<'a, F: PrimeField32> } } +impl SizedRecord for NativeSumcheckRecordMut<'_, F> { + fn size(layout: &NativeSumcheckRecordLayout) -> usize { + layout.metadata.num_rows * size_of::>() + } + + fn alignment(_layout: &NativeSumcheckRecordLayout) -> usize { + align_of::>() + } +} + #[derive(derive_new::new, Copy, Clone)] pub struct NativeSumcheckExecutor; diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 8b13789179..60aba15b95 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -1 +1,57 @@ +use std::{mem::size_of, slice::from_raw_parts, sync::Arc}; +use derive_new::new; +use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; +use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU; +use openvm_cuda_backend::{ + base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, +}; +use openvm_cuda_common::copy::MemCopyH2D; +use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + +use super::columns::NativeSumcheckCols; +use crate::cuda_abi::sumcheck_cuda; + +#[derive(new)] +pub struct NativeSumcheckChipGpu { + pub range_checker: Arc, + pub timestamp_max_bits: usize, +} + +impl Chip for NativeSumcheckChipGpu { + fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext { + let records = arena.allocated(); + if records.is_empty() { + return get_empty_air_proving_ctx::(); + } + + let width = NativeSumcheckCols::::width(); + let record_size = width * size_of::(); + assert_eq!(records.len() % record_size, 0); + + let height = records.len() / record_size; + let padded_height = next_power_of_two_or_zero(height); + let trace = DeviceMatrix::::with_capacity(padded_height, width); + + let record_slice = unsafe { + let ptr = records.as_ptr(); + from_raw_parts(ptr as *const F, records.len() / size_of::()) + }; + let d_records = record_slice.to_device().unwrap(); + + unsafe { + sumcheck_cuda::tracegen( + trace.buffer(), + padded_height, + width, + &d_records, + height, + &self.range_checker.count, + self.timestamp_max_bits as u32, + ) + .unwrap(); + } + + AirProvingContext::simple_no_pis(trace) + } +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index a4039028bc..494d82c03a 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::instructions::program::Program; -#[cfg(not(feature = "cuda"))] -use openvm_circuit::utils::air_test_impl; +use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl}; +#[cfg(feature = "cuda")] +use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine; use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, @@ -11,11 +11,10 @@ use openvm_native_compiler::{ use openvm_stark_backend::p3_field::{ extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, }; +#[cfg(not(feature = "cuda"))] +use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine; use openvm_stark_sdk::{ - config::{ - baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, - }, + config::{fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters}, p3_baby_bear::BabyBear, };