From 011285567de052bb651898fce125851d1dd2dd76 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 26 Nov 2025 15:48:09 +0800 Subject: [PATCH 01/18] wip --- .gitignore | 3 + crates/sdk/src/prover/agg.rs | 6 +- .../circuit/src/field_extension/core.rs | 4 +- extensions/native/circuit/src/fri/mod.rs | 2 +- extensions/native/circuit/src/lib.rs | 1 + extensions/native/circuit/src/sumcheck/air.rs | 532 +++++++ .../native/circuit/src/sumcheck/chip.rs | 565 ++++++++ .../native/circuit/src/sumcheck/columns.rs | 142 ++ .../native/circuit/src/sumcheck/cuda.rs | 1 + .../native/circuit/src/sumcheck/execution.rs | 194 +++ extensions/native/circuit/src/sumcheck/mod.rs | 11 + .../native/compiler/src/asm/compiler.rs | 12 + .../native/compiler/src/asm/instruction.rs | 16 + .../native/compiler/src/conversion/mod.rs | 16 +- .../native/compiler/src/ir/instructions.rs | 25 + extensions/native/compiler/src/ir/mod.rs | 1 + extensions/native/compiler/src/ir/sumcheck.rs | 52 + extensions/native/compiler/src/lib.rs | 15 + extensions/native/recursion/tests/sumcheck.rs | 1261 +++++++++++++++++ 19 files changed, 2851 insertions(+), 8 deletions(-) create mode 100644 extensions/native/circuit/src/sumcheck/air.rs create mode 100644 extensions/native/circuit/src/sumcheck/chip.rs create mode 100644 extensions/native/circuit/src/sumcheck/columns.rs create mode 100644 extensions/native/circuit/src/sumcheck/cuda.rs create mode 100644 extensions/native/circuit/src/sumcheck/execution.rs create mode 100644 extensions/native/circuit/src/sumcheck/mod.rs create mode 100644 extensions/native/compiler/src/ir/sumcheck.rs create mode 100644 extensions/native/recursion/tests/sumcheck.rs diff --git a/.gitignore b/.gitignore index c6e6aa2049..c87b73f198 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index 5e22562675..fd4491ddec 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -27,12 +27,12 @@ where E: StarkFriEngine, NativeBuilder: VmBuilder, { - leaf_prover: VmInstance, - leaf_controller: LeafProvingController, + pub leaf_prover: VmInstance, + pub leaf_controller: LeafProvingController, pub internal_prover: VmInstance, #[cfg(feature = "evm-prove")] - root_prover: RootVerifierLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, } diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index 5afaf74af5..a7d535a14b 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -254,10 +254,10 @@ pub(crate) struct FieldExtension; impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 1e1ec65cb8..4a0d3847d5 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -542,7 +542,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index b5db4f0010..aa2bb3bd31 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -42,6 +42,7 @@ mod fri; mod jal_rangecheck; mod loadstore; mod poseidon2; +mod sumcheck; mod extension; pub use extension::*; diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..494204ceff --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,532 @@ +use std::{array::from_fn, borrow::Borrow, sync::Arc}; + +use openvm_circuit::{ + arch::{ContinuationVmProof, ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, +}; +use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_instructions::LocalOpcode; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::{ + air_builders::sub::SubAirBuilder, + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + sumcheck::columns::{ + HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, + }, + FieldExtension, EXT_DEG, +}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + pub address_space: F, +} + +impl BaseAir for NativeSumcheckAir { + fn width(&self) -> usize { + NativeSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues for NativeSumcheckAir {} + +impl PartitionedBaseAir for NativeSumcheckAir {} + +impl Air for NativeSumcheckAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &NativeSumcheckCols = (*next).borrow(); + + let &NativeSumcheckCols { + // Row indicators + header_row, + prod_row, + logup_row, + + // Whether valid prod/logup row operations follow this row + header_continuation, + prod_continuation, + logup_continuation, + + // Round limit + prod_row_within_max_round, + logup_row_within_max_round, + + // What type of evaluation is performed + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + + // Timestamps + first_timestamp, + start_timestamp, + last_timestamp, + + // Results from reading registers + register_ptrs, + ctx, + prod_nested_len, + logup_nested_len, + + // Challenges + alpha, + challenges, + + curr_prod_n, + curr_logup_n, + + max_round, + within_round_limit, + should_acc, + eval_acc, + specific, + } = local; + + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + builder.assert_bool(header_continuation); + builder.assert_bool(prod_continuation); + builder.assert_bool(logup_continuation); + builder.assert_bool(prod_row_within_max_round); + builder.assert_bool(logup_row_within_max_round); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); + let enabled = header_row + prod_row + logup_row; + builder.assert_bool(enabled.clone()); + let in_round = ctx[7]; + let continuation = header_continuation + prod_continuation + logup_continuation; + builder.assert_bool(continuation.clone()); + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().expect(""); + let c2: [_; EXT_DEG] = challenges[{ EXT_DEG * 2 }..{ EXT_DEG * 3 }] + .try_into() + .expect(""); + let alpha2: [_; EXT_DEG] = challenges[{ EXT_DEG * 3 }..{ EXT_DEG * 4 }] + .try_into() + .expect(""); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); + + // Carry along columns + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + register_ptrs, + next.register_ptrs, + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + ctx, + next.ctx, + ); + assert_array_eq::<_, _, _, { EXT_DEG * 2 }>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)] + .try_into() + .expect(""), + ); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_nested_len, next.prod_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_nested_len, next.logup_nested_len); + + // Row transition + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(ctx[1]); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(prod_row) + .when(not(prod_continuation)) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(logup_row) + .when(not(logup_continuation)) + .assert_eq(ctx[2], curr_logup_n); + + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::from_canonical_usize(7), + ); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO, + ); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3), + ); + + // Termination condition + assert_array_eq( + &mut builder.when::(not(continuation)), + eval_acc, + [AB::F::ZERO; 4], + ); + + // Randomness transition + assert_array_eq( + &mut builder.when(header_continuation), + next.challenges[0..EXT_DEG].try_into().expect(""), + [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], + ); + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_row), + alpha_denominator, + alpha2, + ); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(prod_continuation), + prod_next_alpha, + next_alpha1, + ); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_continuation), + logup_next_alpha, + next_alpha1, + ); + + // Header + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + self.address_space.into(), + self.address_space.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp - first_timestamp, + ) + .eval(builder, header_row); + + // Read registers + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, registers[i]), + [register_ptrs[i]], + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + // React ctx + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + // Write final result + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, register_ptrs[4]), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + + // Prod spec evaluation + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[0] + + AB::F::from_canonical_usize(EXT_DEG * 2) + + (curr_prod_n - AB::F::ONE), + ), // curr_prod_n starts at 1. + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + builder.when(prod_row_within_max_round).assert_eq( + prod_row_specific.data_ptr, + (prod_nested_len * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + prod_row * prod_row_within_max_round * in_round, + prod_in_round_evaluation, + ); + builder.assert_eq( + prod_row * prod_row_within_max_round * not(in_round), + prod_next_round_evaluation, + ); + builder.assert_eq(prod_row * should_acc, prod_acc); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[2] + prod_row_specific.data_ptr, + ), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row_within_max_round); + + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row_within_max_round); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(prod_in_round_evaluation), + in_round_p_evals, + prod_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(prod_next_round_evaluation), + next_round_p_evals, + prod_row_specific.p_evals, + ); + + // Accumulate evaluation + let acc_eval = + FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(prod_acc), + prod_row_specific.acc_eval, + acc_eval, + ); + + let next_acc = FieldExtension::subtract(eval_acc, next_prod_row_specific.acc_eval); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(next.prod_acc), + next.eval_acc, + next_acc, + ); + + // Logup spec evaluation + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[0] + + AB::F::from_canonical_usize(EXT_DEG * 2) + + ctx[1] + + (curr_logup_n - AB::F::ONE), + ), // curr_logup_n starts at 1. + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + builder.when(logup_row_within_max_round).assert_eq( + logup_row_specific.data_ptr, + (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + logup_row * logup_row_within_max_round * in_round, + logup_in_round_evaluation, + ); + builder.assert_eq( + logup_row * logup_row_within_max_round * not(in_round), + logup_next_round_evaluation, + ); + builder.assert_eq(logup_row * should_acc, logup_acc); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[3] + logup_row_specific.data_ptr, + ), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row_within_max_round); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .expect(""); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{ EXT_DEG * 3 }] + .try_into() + .expect(""); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)] + .try_into() + .expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + + (ctx[1] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row_within_max_round); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + + (ctx[1] + ctx[2] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row_within_max_round); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_in_round_evaluation), + in_round_p_evals, + logup_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_next_round_evaluation), + next_round_p_evals, + logup_row_specific.p_evals, + ); + + let next_round_q_evals = FieldExtension::add( + FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_in_round_evaluation), + in_round_q_evals, + logup_row_specific.q_evals, + ); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_next_round_evaluation), + next_round_q_evals, + logup_row_specific.q_evals, + ); + + // Accumulate evaluation + let acc_eval = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(logup_acc), + logup_row_specific.acc_eval, + acc_eval, + ); + + let next_acc = FieldExtension::subtract(eval_acc, next_logup_row_specfic.acc_eval); + assert_array_eq::<_, _, _, EXT_DEG>( + &mut builder.when(next.logup_acc), + next.eval_acc, + next_acc, + ); + } +} diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..2ec3f1bd3c --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,565 @@ +use std::sync::{Arc, Mutex}; + +use openvm_circuit::{ + arch::{ + ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PreflightExecutor, + RecordArena, Streams, SystemPort, TraceFiller, VmChipWrapper, VmStateMut, + }, + system::memory::{ + online::TracingMemory, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + }, +}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL}; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField, PrimeField32}, + p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + sumcheck::{ + air::NativeSumcheckAir, + columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, + }, + utils::const_max, +}; +const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; + +#[repr(C)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(bound = "F: Field")] +pub struct SumcheckEvalRecord { + pub from_state: ExecutionState, + pub instruction: Instruction, + pub row_type: usize, // 0 - header; 1 - prod; 2 - logup + pub curr_timestamp_increment: usize, + pub final_timestamp_increment: usize, + pub continuation: bool, + + pub register_ptrs: [F; 5], + pub registers: [F; 5], + pub ctx: [F; EXT_DEG * 2], + pub challenges: [F; EXT_DEG * 4], + pub read_data_records: [RecordId; 7], + pub write_data_records: [RecordId; 2], + + pub max_round: F, + pub within_round_limit: bool, + pub should_acc: bool, + pub prod_spec_n: usize, + pub logup_spec_n: usize, + pub alpha: [F; EXT_DEG], + pub alpha1: [F; EXT_DEG], + pub alpha2: [F; EXT_DEG], + pub data_ptr: F, + pub p1: [F; EXT_DEG], + pub p2: [F; EXT_DEG], + pub q1: [F; EXT_DEG], + pub q2: [F; EXT_DEG], + pub p_evals: [F; EXT_DEG], + pub q_evals: [F; EXT_DEG], + pub eval_acc: [F; EXT_DEG], + pub acc_eval: [F; EXT_DEG], +} + +fn calculate_3d_ext_idx( + inner_inner_len: F, + inner_len: F, + outer_idx: F, + inner_idx: F, + inner_inner_idx: F, +) -> F { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) + * F::from_canonical_usize(EXT_DEG) +} + +#[derive(derive_new::new, Copy, Clone)] +pub struct NativeSumcheckExecutor; + +#[derive(derive_new::new)] +pub struct NativeSumcheckFiller; + +pub type NativeSumcheckChip = VmChipWrapper; + +impl Default for NativeSumcheckExecutor { + fn default() -> Self { + Self::new() + } +} + +impl PreflightExecutor for NativeSumcheckExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, FriReducedOpeningLayout, FriReducedOpeningRecordMut<'buf, F>>, +{ + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode: op, + a: output_register, + b: input_register_1, + c: input_register_2, + d: data_address_space, + e: register_address_space, + f: input_register_3, + g: input_register_4, + } = instruction; + + if op == SUMCHECK_LAYER_EVAL.global_opcode() { + let mut observation_records: Vec> = vec![]; + let mut curr_timestamp: usize = 0; + + let (read_ctx_pointer, ctx_pointer) = + memory.read_cell(register_address_space, input_register_1); + let (read_cs_pointer, cs_pointer) = + memory.read_cell(register_address_space, input_register_2); + let (read_prod_pointer, prod_ptr) = + memory.read_cell(register_address_space, input_register_3); + let (read_logup_pointer, logup_ptr) = + memory.read_cell(register_address_space, input_register_4); + let (read_result_pointer, r_ptr) = + memory.read_cell(register_address_space, output_register); + let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; + + let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = + memory.read::<{ EXT_DEG * 2 }>(data_address_space, ctx_pointer); + let [ + round, + num_prod_spec, + num_logup_spec, + prod_specs_inner_len, + prod_specs_inner_inner_len, + logup_specs_inner_len, + logup_specs_inner_inner_len, + is_op_for_cur_sumcheck_round, // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + ] = ctx; + + let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = + memory.read::<{ EXT_DEG * 4 }>(data_address_space, cs_pointer); + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); + + let mut header_row = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 0, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + alpha, + registers: [ + input_register_1, + input_register_2, + input_register_3, + input_register_4, + output_register, + ], + ctx, + challenges, + read_data_records: [ + read_ctx_pointer, + read_cs_pointer, + read_prod_pointer, + read_logup_pointer, + read_result_pointer, + ctx_read, + challenges_read, + ], + ..Default::default() + }; + + observation_records.push(header_row); + self.height += 1; + curr_timestamp += 7; + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)] + .try_into() + .expect(""); + + let mut i = F::ZERO; + let mut i_usize = 0usize; + while i < num_prod_spec { + let mut prod_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 1, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + alpha, + prod_spec_n: i_usize, + ..Default::default() + }; + prod_row.alpha1 = alpha_acc; + + let (read_max_round, max_round) = memory.read_cell( + data_address_space, + ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + i, + ); + prod_row.max_round = max_round; + prod_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; + + if round < (max_round - F::from_canonical_usize(1)) { + prod_row.within_round_limit = true; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + prod_row.data_ptr = start; + + let (read_p, ps) = + memory.read::<{ EXT_DEG * 2 }>(data_address_space, prod_ptr + start); + let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + + prod_row.read_data_records[1] = read_p; + prod_row.p1 = p1; + prod_row.p2 = p2; + + let evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::multiply(p1, p2) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + prod_row.p_evals = evals; + + let (write_slice_eval_1, _) = memory.write::( + data_address_space, + r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), + evals, + ); + prod_row.write_data_records[0] = write_slice_eval_1; + + let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; + let acc_eval = FieldExtension::multiply(alpha_acc, evals); + prod_row.acc_eval = acc_eval; + + if (round + is_op_for_next_sumcheck_round) + < (max_round - F::from_canonical_usize(1)) + { + eval_acc = FieldExtension::add(eval_acc, acc_eval); + prod_row.should_acc = true; + prod_row.eval_acc = eval_acc.clone(); + } + + curr_timestamp += 2; + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + + i = i + F::ONE; + i_usize += 1; + observation_records.push(prod_row); + self.height += 1; + } + + let mut i = F::ZERO; + let mut i_usize = 0usize; + while i < num_logup_spec { + let mut logup_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 2, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + alpha, + logup_spec_n: i_usize, + ..Default::default() + }; + logup_row.alpha1 = alpha_acc; + + let (read_max_round, max_round) = memory.read_cell( + data_address_space, + ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + num_prod_spec + i, + ); + logup_row.max_round = max_round; + logup_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; + + if round < (max_round - F::from_canonical_usize(1)) { + logup_row.within_round_limit = true; + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + logup_row.data_ptr = start; + + let (read_pqs, pqs) = + memory.read::<{ EXT_DEG * 4 }>(data_address_space, logup_ptr + start); + let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [F; 4] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + let q2: [F; 4] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + + logup_row.read_data_records[1] = read_pqs; + logup_row.p1 = p1; + logup_row.p2 = p2; + logup_row.q1 = q1; + logup_row.q2 = q2; + + let p_evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + + let q_evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::multiply(q1, q2) + } else { + FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ) + }; + + logup_row.p_evals = p_evals; + logup_row.q_evals = q_evals; + + let (write_slice_eval_1, _) = memory.write::( + data_address_space, + r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), + p_evals, + ); + let (write_slice_eval_2, _) = memory.write::( + data_address_space, + r_ptr + + (F::ONE + num_prod_spec + num_logup_spec + i) + * F::from_canonical_usize(EXT_DEG), + q_evals, + ); + + logup_row.write_data_records[0] = write_slice_eval_1; + logup_row.write_data_records[1] = write_slice_eval_2; + + let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.alpha2 = alpha_denominator; + + if (round + is_op_for_next_sumcheck_round) + < (max_round - F::from_canonical_usize(1)) + { + let acc_eval = FieldExtension::add( + FieldExtension::multiply(alpha_acc, p_evals), + FieldExtension::multiply(alpha_denominator, q_evals), + ); + logup_row.acc_eval = acc_eval; + eval_acc = FieldExtension::add(eval_acc, acc_eval); + logup_row.should_acc = true; + logup_row.eval_acc = eval_acc.clone(); + } + + curr_timestamp += 3; + } + + alpha_acc = + FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); + + i = i + F::ONE; + i_usize += 1; + observation_records.push(logup_row); + self.height += 1; + } + + let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); + curr_timestamp += 1; + observation_records[0].write_data_records[0] = write_r; + + for record in &mut observation_records { + record.final_timestamp_increment = curr_timestamp; + record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + } + let last_idx = observation_records.len() - 1; + observation_records[last_idx].continuation = false; + + self.record_set.extend(observation_records); + } else { + unreachable!() + } + + Ok(ExecutionState { + pc: from_state.pc + DEFAULT_PC_STEP, + timestamp: memory.timestamp(), + }) + } + + // GKR layered IOP for product and logup relations + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, SUMCHECK_LAYER_EVAL.global_opcode().as_usize()); + String::from("SUMCHECK_LAYER_EVAL") + } +} + +impl TraceFiller for NativeSumcheckFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + todo!(); + let slice = &mut flat_trace[used_cells..used_cells + width]; + let cols: &mut NativeSumcheckCols = slice.borrow_mut(); + cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); + cols.start_timestamp = F::from_canonical_usize( + record.from_state.timestamp as usize + record.curr_timestamp_increment, + ); + cols.last_timestamp = F::from_canonical_usize( + record.from_state.timestamp as usize + record.final_timestamp_increment, + ); + cols.register_ptrs = record.register_ptrs; + cols.ctx = record.ctx; + cols.prod_nested_len = record.ctx[4] * record.ctx[3]; + cols.logup_nested_len = record.ctx[6] * record.ctx[5]; + cols.challenges = record.challenges; + cols.alpha = record.alpha; + cols.max_round = record.max_round; + cols.within_round_limit = if record.within_round_limit { + F::ONE + } else { + F::ZERO + }; + cols.should_acc = if record.should_acc { F::ONE } else { F::ZERO }; + cols.eval_acc = record.eval_acc; + + if record.row_type == 0 { + cols.header_row = F::ONE; + cols.header_continuation = if record.continuation { F::ONE } else { F::ZERO }; + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + header.pc = F::from_canonical_u32(record.from_state.pc); + header.registers = record.registers; + + for i in 0..7usize { + let mem_record = memory.record_by_id(record.read_data_records[i]); + aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); + } + + // write the final result + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut header.write_records); + } else if record.row_type == 1 { + cols.prod_row = F::ONE; + cols.prod_continuation = if record.continuation { F::ONE } else { F::ZERO }; + cols.prod_row_within_max_round = if record.within_round_limit { + F::ONE + } else { + F::ZERO + }; + cols.prod_in_round_evaluation = if record.within_round_limit { + record.ctx[7] + } else { + F::ZERO + }; + cols.prod_next_round_evaluation = if record.within_round_limit { + F::ONE - record.ctx[7] + } else { + F::ZERO + }; + cols.prod_acc = if record.should_acc { F::ONE } else { F::ZERO }; + let prod: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + cols.curr_prod_n = F::from_canonical_usize(record.prod_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + prod.p[0..EXT_DEG].copy_from_slice(&record.p1); + prod.p[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + prod.data_ptr = record.data_ptr; + prod.acc_eval = record.acc_eval; + + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[0]); + + if record.within_round_limit { + // Read p1, p2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[1]); + + // Write p eval + prod.p_evals = record.p_evals; + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut prod.write_record); + } + } else if record.row_type == 2 { + cols.logup_row = F::ONE; + cols.logup_continuation = if record.continuation { F::ONE } else { F::ZERO }; + cols.logup_row_within_max_round = if record.within_round_limit { + F::ONE + } else { + F::ZERO + }; + cols.logup_in_round_evaluation = if record.within_round_limit { + record.ctx[7] + } else { + F::ZERO + }; + cols.logup_next_round_evaluation = if record.within_round_limit { + F::ONE - record.ctx[7] + } else { + F::ZERO + }; + cols.logup_acc = if record.should_acc { F::ONE } else { F::ZERO }; + let logup: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + cols.curr_logup_n = F::from_canonical_usize(record.logup_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + cols.challenges[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.alpha2); + logup.pq[0..EXT_DEG].copy_from_slice(&record.p1); + logup.pq[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + logup.pq[(EXT_DEG * 2)..(EXT_DEG * 3)].copy_from_slice(&record.q1); + logup.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.q2); + logup.data_ptr = record.data_ptr; + logup.acc_eval = record.acc_eval; + + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[0]); + + if record.within_round_limit { + // Read p1, p2, q1, q2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[1]); + + // Write p and q eval + logup.p_evals = record.p_evals; + logup.q_evals = record.q_evals; + for i in 0..2usize { + let mem_record = memory.record_by_id(record.write_data_records[i]); + aux_cols_factory.generate_write_aux(mem_record, &mut logup.write_records[i]); + } + } + } + } +} diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs new file mode 100644 index 0000000000..ca4a264277 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -0,0 +1,142 @@ +use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit_primitives_derive::AlignedBorrow; + +use crate::{field_extension::EXT_DEG, utils::const_max}; + +const fn max3(a: usize, b: usize, c: usize) -> usize { + const_max(a, const_max(b, c)) +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct NativeSumcheckCols { + /// Indicates that this row is the header for a layer sum operation + pub header_row: T, + /// Indicates that this row is a step for prod_spec in the layer sum operation + pub prod_row: T, + /// Indicates that this row is a step for logup_spec in the layer sum operation + pub logup_row: T, + + /// Indicates that there are valid operations following this header row + pub header_continuation: T, + /// Indicates that there are valid operations following this product evaluation row + pub prod_continuation: T, + /// Indicates that there are valid operations following this logup row + pub logup_continuation: T, + + /// Indicates that the prod row is within maximum round + pub prod_row_within_max_round: T, + /// Indicates that the logup row is within maximum round + pub logup_row_within_max_round: T, + + /// Indicates what type of evaluation constraints should be applied + pub prod_in_round_evaluation: T, + pub prod_next_round_evaluation: T, + pub logup_in_round_evaluation: T, + pub logup_next_round_evaluation: T, + + /// Indicates if evaluations are accumulated + pub prod_acc: T, + pub logup_acc: T, + + /// Timestamps + pub first_timestamp: T, + pub start_timestamp: T, + pub last_timestamp: T, + + // Register values + pub register_ptrs: [T; 5], + + // Context variables + // [ + // round, + // num_prod_spec, + // num_logup_spec, + // prod_spec_inner_len, + // prod_spec_inner_inner_len, + // logup_spec_inner_len, + // logup_spec_inner_inner_len, + // in_layer, + // ] + pub ctx: [T; EXT_DEG * 2], + + pub prod_nested_len: T, + pub logup_nested_len: T, + + pub curr_prod_n: T, + pub curr_logup_n: T, + + // alpha1, c1, c2, alpha2 (for logup rows) + pub alpha: [T; EXT_DEG], + pub challenges: [T; EXT_DEG * 4], + + // Specific to each row + pub max_round: T, + // Is this round within max_round + pub within_round_limit: T, + // Should the evaluation be accumualted + pub should_acc: T, + + // The current final evaluation accumulator. Extension element. + pub eval_acc: [T; EXT_DEG], + + // /// 1. For header row, 5 registers, ctx, challenges + // /// 2. For the rest: max_variables, p1, p2, q1, q2 + // pub read_records: [MemoryReadAuxCols; 7], + // /// 1. For header row, write final result + // /// 2. For prod rows: write prod_evals + // /// 3. For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], +} diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -0,0 +1 @@ + diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs new file mode 100644 index 0000000000..ce09c8fc3d --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -0,0 +1,194 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_native_compiler::conversion::AS; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{elem_to_ext, FriReducedOpeningExecutor}; +use crate::field_extension::{FieldExtension, EXT_DEG}; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeSumcheckPreCompute { + a_ptr_ptr: u32, + b_ptr_ptr: u32, + length_ptr: u32, + alpha_ptr: u32, + result_ptr: u32, + hint_id_ptr: u32, + is_init_ptr: u32, +} + +impl NativeSumcheckPreCompute { + #[inline(always)] + fn pre_compute_impl( + &self, + _pc: u32, + inst: &Instruction, + data: &mut FriReducedOpeningPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let a_ptr_ptr = a.as_canonical_u32(); + let b_ptr_ptr = b.as_canonical_u32(); + let length_ptr = c.as_canonical_u32(); + let alpha_ptr = d.as_canonical_u32(); + let result_ptr = e.as_canonical_u32(); + let hint_id_ptr = f.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + + *data = FriReducedOpeningPreCompute { + a_ptr_ptr, + b_ptr_ptr, + length_ptr, + alpha_ptr, + result_ptr, + hint_id_ptr, + is_init_ptr, + }; + + Ok(()) + } +} + +impl Executor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_handler; + Ok(fn_ptr) + } + + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_handler; + Ok(fn_ptr) + } +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &NativeSumcheckPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, instret, pc, exec_state); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &NativeSumcheckPreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + todo!() +} diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..8c35a0a7aa --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,11 @@ +pub mod air; +pub mod chip; +mod columns; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; + +// mod tests; +mod execution; diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 42e32f3a7d..c80615ca7a 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -637,6 +637,18 @@ impl + TwoAdicField> AsmCo ); } } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval( + input_ctx.fp(), + challenges.fp(), + prod_ptr.fp(), + logup_ptr.fp(), + r_ptr.fp(), + ), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index ae0875b83a..3d498c00f4 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -171,6 +171,15 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + // Native opcode for calculating sumcheck layer evaluation + // SumcheckLayerEval(reg_a, reg_b, reg_c, ... , reg_f, reg_g) + // - reg_a: Output ptr for next layer's evaluations + // - reg_b: Context variables + // - reg_c: Challenge values (alpha, coeff) + // - reg_g: GKR product IOP evaluations + // - reg_f: GKR logup IOP evaluations + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -407,6 +416,13 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!( + f, + "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", + ctx, cs, p_ptr, l_ptr, r_ptr + ) + } } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 82ee912703..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -12,7 +12,7 @@ use crate::{ asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -535,7 +535,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: i32_f(r_ptr), + b: i32_f(ctx), + c: i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 3b30a45ad6..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -320,6 +320,31 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Native operation for calculating a sumcheck layer's evaluation + /// This op supports two modes: + /// 1. for computing expected evaluation for current layer, output = [ \sum_i alpha^i * + /// prod[i][0] * prod[i][1] + \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* + /// logup_p[i][0] * logup_q[i][1] + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected evaluation of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) + /// * p[i][1]. + SumcheckLayerEval( + Ptr, // Context variables: + // 0. round, + // 1. number of product + // 2. number of logup + // 3. (3D array description) prod_specs_eval inner length + // 4. (3D array description) prod_specs_eval inner_inner length + // 5. (3D array description) logup_spec_eval inner length + // 6. (3D array description) logup_spec_eval inner length + // 7. Operational mode indicator + // 8+. usize-type variables indicating maximum rounds + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr, // output + ), } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..f708318c34 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -18,6 +18,7 @@ mod instructions; mod poseidon; mod ptr; mod select; +mod sumcheck; mod symbolic; mod types; mod utils; diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..ff5823600f --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,52 @@ +use openvm_native_compiler_derive::iter_zip; +use openvm_stark_backend::p3_field::FieldAlgebra; + +use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; +use crate::ir::Variable; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + /// This opcode supports two modes (indicated by a context variable): + /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup) + /// 2. calculate the expected value of next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1] + /// + /// Context variables + /// + /// 0: round, + /// 1: number of product + /// 2. number of logup + /// 3. (3D array description) prod_specs_eval inner length + /// 4. (3D array description) prod_specs_eval inner_inner length + /// 5. (3D array description) logup_spec_eval inner length + /// 6. (3D array description) logup_spec_eval inner length + /// 7. Operational mode indicator + /// 8+ Additional usize-type variables indicating maximum rounds + /// + /// Output + /// + /// 1. for computing expected evaluation, output = [ \sum_i alpha^i * prod[i][0] * prod[i][1] + + /// \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* logup_p[i][0] * logup_q[i][1] + /// + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected eval of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) * + /// p[i][1]. + pub fn sumcheck_layer_eval( + &mut self, + input_ctx: &Array>, // Context variables + challenges: &Array>, // Challenges + prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + * from 3D array. */ + logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ + r_evals: &Array>, /* Next layer's evaluations (pointer used for + * storing opcode output) */ + ) { + self.operations.push(DslIr::SumcheckLayerEval( + input_ctx.ptr(), + challenges.ptr(), + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), + r_evals.ptr(), + )); + } +} diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index 66c786fbd9..efb45b0159 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -212,3 +212,18 @@ pub enum VerifyBatchOpcode { /// per column polynomial, per opening point VERIFY_BATCH, } + +/// Opcodes for sumcheck. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x180] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum SumcheckOpcode { + /// Compute the expected evaluation for each layer in the tower structure that GKR product IOP + /// and logup IOP uses Supports two modes of operation: + /// 1. Calculate current layer's expected evaluation + /// 2. Calculate next layer's evaluation + SUMCHECK_LAYER_EVAL, +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs new file mode 100644 index 0000000000..607c2a783b --- /dev/null +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -0,0 +1,1261 @@ +use itertools::Itertools; +use openvm_circuit::arch::{ + instructions::program::Program, verify_single, SystemConfig, VirtualMachine, VmConfig, + VmExecutor, +}; +use openvm_native_circuit::{Native, NativeConfig, EXT_DEG}; +use openvm_native_compiler::{ + asm::{AsmBuilder, AsmCompiler}, + conversion::{convert_program, CompilerOptions}, + ir::{Ext, Felt, Usize}, + prelude::*, +}; +use openvm_native_recursion::{ + challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}, + testing_utils::inner::run_recursive_test, +}; +use openvm_stark_backend::{ + config::{Domain, StarkGenericConfig}, + p3_commit::PolynomialSpace, + p3_field::{ + extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, PackedValue, + }, +}; +use openvm_stark_sdk::{ + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, + }, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, + utils::{create_seeded_rng, ProofInputForTest}, +}; +use rand::Rng; +pub type F = BabyBear; +pub type E = BinomialExtensionField; + +#[test] +fn test_sumcheck_layer_eval() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + let sumcheck_max_constraint_degree = 3; + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } +} + +fn build_test_program(builder: &mut Builder) { + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); + for (idx, n) in ctx_u32s.into_iter().enumerate() { + builder.set(&ctx, idx, Usize::from(n as usize)); + } + + let challenges_u32s = [ + 548478283u32, + 456436544, + 1716290291, + 791326976, + 1829717553, + 1422025771, + 1917123958, + 727015942, + 183548369, + 591240150, + 96141963, + 1286249979, + ]; + let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); + for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&challenges, idx, e); + } + + let prod_spec_eval_u32s = [ + 1538906710u32, + 637535518, + 1753132406, + 1395236651, + 278806441, + 1722910382, + 1475548665, + 1117874675, + 1578586709, + 1826764884, + 384068476, + 1852240363, + 707958906, + 1960944944, + 183554399, + 1259273357, + 227285124, + 243066436, + 1718037317, + 369721963, + 1752968006, + 1061013677, + 775617499, + 1464907431, + 544300429, + 871461966, + 135151545, + 1343592602, + 1622220528, + 643966158, + 3932580, + 434948358, + 540553922, + 1446502052, + 153298741, + 1191216273, + 265936762, + 1463035257, + 1237633339, + 1797346310, + 1355791584, + 389527741, + 1741650463, + 1728913415, + 1825739540, + 1790924136, + 460776743, + 29536554, + 6842036, + 252495270, + 1968285155, + 299467416, + 49085744, + 1499815729, + 1098802236, + 644489275, + 1827273105, + 1888401527, + 390077051, + 565528894, + 1366177188, + 67441791, + 958486301, + 402056716, + 590379691, + 462035406, + 633459131, + 843304872, + 584100013, + 1932496508, + 250656031, + 146983915, + 1835173157, + 939973454, + 1844873638, + 1916054832, + 1601784696, + 167251717, + 409107688, + 1062925788, + 1291319514, + 1790529531, + 495655592, + 1093359708, + 790197205, + 674458164, + 195988318, + 399764452, + 106865258, + 967050329, + 350035523, + 1109292118, + 1815460301, + 281986036, + 900636603, + 1121197008, + 1228976590, + 1879998708, + 1924332706, + 434695844, + 1159360621, + 471397106, + 473371067, + 1009065094, + 1320176846, + 168020789, + 1265321929, + 1901808675, + 223657700, + 1480150183, + 1779968584, + 144416591, + 304407746, + 1864498679, + 1482460119, + 1554376965, + 1479261548, + 1657723043, + 1039345063, + 1053923521, + 442080513, + 1964082352, + 691664908, + 1941008321, + 1007729002, + 860529393, + 849697342, + 754485488, + 584295923, + 1072251466, + 1105105254, + 996079746, + 1305909868, + 1348028973, + 122275988, + 464050036, + 692807777, + 1098809324, + 397235220, + 596459886, + 1663209783, + 720230826, + 1422510715, + 1760654694, + 544197700, + 1417744567, + 1938716517, + 1571826328, + 1591430185, + 1173137446, + 175285007, + 1541718596, + 1715958587, + 1429966110, + 583013357, + 1667787861, + 109891172, + 668253167, + 161783842, + 296183397, + 1681897325, + 1054396117, + 264741948, + 464026995, + 1907686022, + 1532786783, + 394869458, + 1766734740, + 136047179, + 536856195, + 376188855, + 700633625, + 515518419, + 531043483, + 60673499, + 556496527, + 1743028981, + 873954569, + 1371062291, + 632169731, + 1353239206, + 526507035, + 1894490088, + 589441599, + 1610487168, + 1074160583, + 366366374, + 247602990, + 1535354896, + 894493713, + 1555870413, + 1389854934, + 1897251683, + 1525812801, + 675621735, + 697919636, + 1690274072, + 1466810921, + 1221110784, + 1741995587, + 1877169764, + 390876982, + 1794129810, + 297662156, + 144295349, + 417037264, + 1290835727, + 1654971513, + 1674131303, + 1625667423, + 1471248832, + 1676797844, + 1172916558, + 1707775403, + 423725211, + 1643279661, + 1695774264, + 378140395, + 1517569394, + 1666625392, + 1803981250, + 439036260, + 247966130, + 709534816, + 361144100, + 1546096548, + 1240886454, + 1898161518, + 843262057, + 1709259464, + 1301015977, + 1997626928, + 677153173, + 1606710353, + 1216038070, + 435565562, + 98686333, + 1773787396, + 267051994, + 99395396, + 545509105, + 782289675, + 1289865975, + 1707775075, + 1158993015, + 1506576588, + 993215179, + 1523099397, + 923914455, + 1895162386, + 284489994, + 1444139016, + 1943825680, + 466202724, + 1632522710, + 1384015062, + 723147188, + 1284031324, + 1430481515, + 341213007, + 171192499, + 1061688239, + 808927167, + 83182639, + 759209907, + 1728321272, + 976049976, + 1652071995, + 1002877840, + 69880246, + 1095135165, + 677588420, + 1384715290, + 829619452, + 170122781, + 1958173727, + 13389238, + 789379698, + 1883383039, + 1279195174, + 1618672336, + 1192839317, + 1348311124, + 758896285, + 1939775389, + 684108413, + 1838340479, + 1332232130, + 1070486028, + 549228790, + 868851698, + 1678207843, + 1754321489, + 637000403, + 647901906, + 45343322, + 1768524074, + 1167955205, + 1816497210, + 1609414096, + 1985231742, + 1540534482, + 232730819, + 232221968, + 1509637836, + 1480860627, + 884647789, + 1096458024, + 163721583, + 1248032262, + 436419506, + 1737102298, + 651105860, + 452298073, + 1064372507, + 1792838683, + 619243471, + 860127631, + 721724708, + 950768433, + 279913448, + 339693210, + 47730422, + 1952683911, + 1316500770, + 675944216, + 386902809, + 619333956, + 1194800389, + 43989936, + 1944372656, + 666045666, + 1155873844, + 522696968, + 58874730, + 1497238023, + 421619994, + 1980672127, + 1657191856, + 1913792631, + 1784663131, + 1118400672, + 1828104993, + 1637808383, + 414755472, + 775410449, + 747132157, + 136820101, + 1082674285, + 93190395, + 357955402, + 335652723, + 1192102705, + 480365232, + 1354935730, + 1391829361, + 966662991, + 1601510445, + 569528575, + 545490940, + 1753711688, + 807025222, + 580374183, + 587718008, + 977546290, + 1055719519, + 1157107032, + 562799608, + 859466927, + 840450024, + 815325134, + 936576801, + 1010587056, + 246624382, + 1808049797, + 1098183398, + 1005077390, + 772432546, + 1976629565, + 1003772218, + 1655315418, + 1767931114, + 982008720, + 785023351, + ]; + + let prod_spec_evals: Array> = + builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&prod_spec_evals, idx, e); + } + + let logup_spec_eval_u32s = [ + 1522353967u32, + 457603397, + 421847521, + 1352563318, + 1746817766, + 737872688, + 1087008622, + 1850835028, + 456475558, + 892966330, + 638163666, + 148568548, + 678863061, + 1334386850, + 1896333039, + 154585769, + 433618446, + 1186936470, + 970218722, + 1213827097, + 1798557019, + 861757965, + 119285527, + 395360622, + 226164366, + 1330279872, + 66561048, + 785421608, + 1950755756, + 1559889596, + 348449876, + 1090789452, + 257578851, + 273164442, + 1644906, + 295600924, + 1187949602, + 1168249609, + 469763604, + 60929061, + 291163036, + 403842501, + 1421902433, + 1700188477, + 1046093370, + 921059131, + 1638991894, + 464012042, + 96905857, + 1370999592, + 271896041, + 13595534, + 1489760970, + 1650552701, + 133367846, + 25680377, + 377631580, + 652729291, + 645763356, + 426747355, + 482475486, + 1877299223, + 103226636, + 1333832358, + 1399609097, + 458536972, + 976248802, + 1109365280, + 515164588, + 1579426417, + 1601829549, + 607169702, + 852817956, + 1980537127, + 134138338, + 913344050, + 737880920, + 476360275, + 61624034, + 1610624252, + 264461991, + 546933535, + 937769429, + 293346965, + 1522058041, + 1012551797, + 994330314, + 23333322, + 1969510890, + 974351570, + 2012030621, + 120742000, + 450250620, + 180547360, + 642746933, + 1815029950, + 629489142, + 1176992624, + 723354779, + 572648755, + 1218615348, + 648847054, + 351903235, + 723149764, + 248065753, + 243829448, + 1283393001, + 1912627886, + 581641342, + 702465306, + 205969758, + 1061911274, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1703043252, + 1467887451, + 1714319214, + 907866644, + 1542426838, + 742609036, + 1814393459, + 448706641, + 1960340767, + 46490834, + 186512520, + 363973095, + 846448854, + 463742343, + 2012517527, + 40473617, + 9472552, + 263483342, + 105738598, + 586389136, + 254290990, + 625150844, + 960233097, + 1488303724, + 1700231692, + 1471714612, + 1540211186, + 1590246915, + 945341972, + 1343225515, + 179976237, + 34857822, + 276912528, + 984309272, + 1277293398, + 1520924162, + 1823117694, + 604836357, + 1460812009, + 600052559, + 970469338, + 1771022707, + 181855831, + 1445947220, + 467514809, + 1514677498, + 947030389, + 170390653, + 415409007, + 1601463730, + 204153427, + 904614278, + 1855419512, + 2009471607, + 1352607379, + 576586082, + 1343812879, + 1176377580, + 1166188815, + 1592289048, + 761793881, + 1529621462, + 193034837, + 344011596, + 1669461833, + 1356800025, + 314186361, + 586497329, + 1832810846, + 1288092861, + 1619454491, + 732529408, + 737934269, + 909504928, + 769680420, + 1437893101, + 1727002258, + 1618231110, + 535125583, + 153412473, + 1917760929, + 588586507, + 564531165, + 1790797737, + 1666283994, + 1366948884, + 117673690, + 476470378, + 2012274032, + 1951406668, + 1739767532, + 1273142151, + 1591812317, + 1900205312, + 1912608761, + 1734766024, + 1265002082, + 1450462894, + 749810837, + 1329222552, + 745081805, + 1231519431, + 1420957967, + 883846107, + 1995463911, + 407795592, + 161655852, + 125886157, + 995318920, + 484905024, + 284135318, + 551493419, + 406742309, + 1089024446, + 637339867, + 1858138403, + 1230680117, + 187078889, + 1929517480, + 1125646261, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1610035932, + 462442436, + 831412555, + 44798862, + 1748147276, + 1911945531, + 1329343740, + 971894393, + 362147969, + 1583335926, + 1528700112, + 426908674, + 847905883, + 447889090, + 1050883911, + 1883537469, + 1487501632, + 964178870, + 1818828551, + 1980840799, + 340372118, + 1697179193, + 215113037, + 1893217470, + 1138628493, + 1788052486, + 443362955, + 1349213730, + 589553425, + 562526667, + 1006040406, + 1194546769, + 1831034644, + 612004157, + 730213913, + 1068905440, + 371983982, + 502900790, + 802785198, + 822377635, + 1477528437, + 501356237, + 684668525, + 1306043781, + 621032592, + 1971342708, + 1411586583, + 733418745, + 186045462, + 1559301855, + 323758310, + 453170140, + 498381240, + 976247416, + 631213663, + 898017829, + 501459603, + 609703046, + 1379288251, + 177682695, + 912381595, + 121915494, + 1137416430, + 504054388, + 1138277238, + 1603388253, + 1838013301, + 1700271853, + 20488607, + 58775264, + 217974275, + 979141729, + 53136584, + 1331566240, + 1460303356, + 525812787, + 718385521, + 1477919263, + 1663622276, + 1089788203, + 1204483837, + 54225863, + 290660186, + 1441441958, + 134168813, + 349638823, + 1867912015, + 1579183319, + 55528656, + 1602973359, + 194297109, + 949763297, + 101931919, + 242300116, + 1610052257, + 1351823848, + 174522860, + 776955925, + 1706962365, + 808187490, + 1487253852, + 431806906, + 213982593, + 1170647308, + 1776840400, + 295916317, + 378708073, + 381270341, + 457494568, + 705823997, + 1407301442, + 1693003013, + 700310785, + 1349874247, + 1284363817, + 1566253815, + 1014298154, + 215294365, + 1070968678, + 871641358, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1302679751, + 1121894357, + 368587356, + 1564724097, + 733815591, + 2012670011, + 1146780092, + 1439780227, + 1801628424, + 838692317, + 932318853, + 213634365, + 155292454, + 1644317110, + 1599846194, + 978829059, + 1282095862, + 1780431647, + 527412087, + 1024583705, + 804423802, + 951808322, + 689345230, + 180304167, + 1784562773, + 1514653374, + 2009396440, + 1143778943, + 235299446, + 1553017484, + 475425117, + 758292254, + 716575432, + 517083432, + 1728864125, + 418010549, + 43202592, + 507659742, + 433077118, + 1268144019, + 1462778342, + 1928073362, + 1330130180, + 1749624351, + 827401013, + 1236194147, + 1875519726, + 1437946791, + 607293265, + 309229599, + 1009445595, + 1725229718, + 1436309341, + 1952606463, + 943149111, + 291680468, + 1989684076, + 1944713370, + 1285294139, + 399758737, + 1572979232, + 213817406, + 214840530, + 184898060, + 1483844295, + 1536616777, + 494816009, + 217625163, + 529448032, + 786640964, + 1766471731, + 1424140424, + 1721961711, + 740275169, + 169908711, + 913969302, + 1359358267, + 1328322971, + 593228769, + 771095186, + 801680440, + 450930656, + 1796349530, + 1824428677, + 1111258504, + 1741666629, + 1098430204, + 1792001884, + 1679003061, + 590088446, + 647614538, + 1324461639, + 818996796, + 229187928, + 74288115, + 1158900266, + 1512606270, + 1381672753, + 785927403, + 493453164, + 425259497, + 1367873539, + 931023744, + 221202218, + 669580668, + 424996238, + 1840425275, + 1873362670, + 967642716, + 263556335, + 578560519, + 1558449223, + 607579284, + 1724012378, + 333582342, + 1195784167, + 1419727276, + 199294290, + 138807165, + 1061030752, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 776332180, + 1333076185, + 1855163818, + 1897408938, + 799274251, + 950452503, + 691904988, + 1205387466, + 659107883, + 434394982, + 129587940, + 639018629, + 659238594, + 1957584892, + 864291238, + 589178070, + 1267157231, + 48925338, + 200093884, + 1953762869, + 1227617341, + 1471420621, + 193077633, + 1007876111, + 228491220, + 1377349503, + 1889411060, + 1807513892, + 1593042934, + 1240864695, + 1472870721, + 583021932, + 598239104, + 1862008818, + 1811242869, + 780768026, + 520870395, + 292016292, + 322246659, + 868240490, + 1715620331, + 1183509209, + 2010262726, + 1003957251, + 264895455, + 307755941, + 201990485, + 1662471178, + 1643997923, + 1573129362, + 277821143, + 388834470, + 943361405, + 1449402196, + 614413575, + 1504113993, + 1860552739, + 1755127315, + 1734129760, + 1232115188, + 803035456, + 360488092, + 271342171, + 1269544258, + 290642673, + 660703582, + 986842267, + 870891877, + 454573044, + 1999346236, + 701614601, + 820253867, + 883282765, + 137247873, + 1727164949, + 1320585493, + 1738664600, + 1900116905, + 472215154, + 1114994489, + 104218174, + 1694603079, + 771486383, + 935361143, + 92277671, + 881040480, + 925124484, + 1464396527, + 100625197, + 65290355, + 1001454341, + 134627585, + 58629702, + 1541542242, + 568583607, + 1706262052, + 530687550, + 1303187245, + 1010302462, + 264001857, + 789816678, + 561378226, + 827432508, + 801307507, + 1613508315, + 1650822853, + 1603502703, + 439320335, + 15283580, + 1244486577, + 254345266, + 1745653280, + 1648250354, + 1528271018, + 528366563, + 1078707735, + 1430767759, + 1890467731, + 2001894083, + 799949326, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1341839494, + 1092219735, + 755644898, + 966729319, + 1914277278, + 1545367697, + 1765189119, + 1693413008, + ]; + + let logup_spec_evals: Array> = + builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&logup_spec_evals, idx, e); + } + + let r_evals_u32s = [ + 941378355u32, + 1078920879, + 696738840, + 496039492, + 1555445457, + 184545404, + 905938226, + 1847966044, + 1024875886, + 1782716223, + 1625644635, + 266865456, + 465953066, + 1663531470, + 757423849, + 1957075986, + 1919693393, + 839104130, + 127480221, + 1527842912, + 918650796, + 921462354, + 575456073, + 696646705, + 1585912361, + 258186488, + 353168830, + 1111094691, + 1401166558, + 1905942163, + 1923083163, + 393037255, + 1042127700, + 1126793296, + 895794165, + 1124924482, + 1324266058, + 722406365, + 1963838171, + 968504459, + 1934378800, + 714588691, + 6465911, + 1168379648, + 903786009, + 1326035939, + 518289228, + 418998914, + 1513133474, + 1578096058, + 617547414, + 1658315126, + 68556894, + 1697802593, + 1346510664, + 1709381671, + 345062962, + 1254089535, + 1002281845, + 1882822096, + 700581748, + 1431345304, + 489112954, + 98435728, + 1799886007, + 479788390, + 223111065, + 631662309, + ]; + + let next_layer_evals: Array> = + builder.dyn_array(r_evals_u32s.len() / EXT_DEG); + for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&next_layer_evals, idx, e); + } + + builder.sumcheck_layer_eval( + &ctx, + &challenges, + &prod_spec_evals, + &logup_spec_evals, + &next_layer_evals, + ); +} From 33679c099348fab1277e63d95bf421066855b69f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 14:46:30 +0800 Subject: [PATCH 02/18] wip2 --- .../native/circuit/src/extension/mod.rs | 20 +- extensions/native/circuit/src/lib.rs | 1 + extensions/native/circuit/src/sumcheck/air.rs | 83 +++---- .../native/circuit/src/sumcheck/chip.rs | 116 +++++----- .../native/circuit/src/sumcheck/execution.rs | 204 +++++++++++++++--- extensions/native/compiler/src/ir/sumcheck.rs | 2 +- extensions/native/recursion/tests/sumcheck.rs | 32 ++- 7 files changed, 319 insertions(+), 139 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index a86cdb1bd2..f175a9c4de 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -17,7 +17,8 @@ use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscrimi use openvm_native_compiler::{ CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, + BLOCK_LOAD_STORE_SIZE, }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::BranchEqualCoreAir; @@ -61,6 +62,10 @@ use crate::{ chip::{NativePoseidon2Executor, NativePoseidon2Filler}, NativePoseidon2Chip, }, + sumcheck::{ + air::NativeSumcheckAir, + chip::{NativeSumcheckChip, NativeSumcheckExecutor, NativeSumcheckFiller}, + }, }; cfg_if::cfg_if! { @@ -94,6 +99,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionExecutor), FriReducedOpening(FriReducedOpeningExecutor), VerifyBatch(NativePoseidon2Executor), + TowerVerify(NativeSumcheckExecutor), } impl VmExecutionExtension for Native { @@ -169,6 +175,12 @@ impl VmExecutionExtension for Native { ], )?; + let tower_verify = NativeSumcheckExecutor::new(); + inventory.add_executor( + tower_verify, + [SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode()], + ); + inventory.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), @@ -262,6 +274,9 @@ where ); inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + inventory.add_air(tower_evaluate); + Ok(()) } } @@ -342,6 +357,9 @@ where ); inventory.add_executor_chip(poseidon2); + let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + inventory.add_executor_chip(tower_verify); + Ok(()) } } diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index aa2bb3bd31..ce257c9c22 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -46,6 +46,7 @@ mod sumcheck; mod extension; pub use extension::*; +pub use field_extension::EXT_DEG; mod utils; #[cfg(any(test, feature = "test-utils"))] diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 494204ceff..0275e009b9 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -5,7 +5,7 @@ use openvm_circuit::{ system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; use openvm_circuit_primitives::utils::{assert_array_eq, not}; -use openvm_instructions::LocalOpcode; +use openvm_instructions::{LocalOpcode, NATIVE_AS}; use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; use openvm_stark_backend::{ air_builders::sub::SubAirBuilder, @@ -17,36 +17,45 @@ use openvm_stark_backend::{ }; use crate::{ + field_extension::{FieldExtension, EXT_DEG}, sumcheck::columns::{ HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, }, - FieldExtension, EXT_DEG, }; #[derive(Clone, Debug)] -pub struct NativeSumcheckAir { +pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, - pub address_space: F, } -impl BaseAir for NativeSumcheckAir { +impl NativeSumcheckAir { + pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + Self { + execution_bridge, + memory_bridge, + } + } +} + +impl BaseAir for NativeSumcheckAir { fn width(&self) -> usize { NativeSumcheckCols::::width() } } -impl BaseAirWithPublicValues for NativeSumcheckAir {} +impl BaseAirWithPublicValues for NativeSumcheckAir {} -impl PartitionedBaseAir for NativeSumcheckAir {} +impl PartitionedBaseAir for NativeSumcheckAir {} -impl Air for NativeSumcheckAir { +impl Air for NativeSumcheckAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); let local: &NativeSumcheckCols = (*local).borrow(); let next = main.row_slice(1); let next: &NativeSumcheckCols = (*next).borrow(); + let native_as = AB::F::from_canonical_u32(NATIVE_AS); let &NativeSumcheckCols { // Row indicators @@ -211,19 +220,19 @@ impl Air for NativeSumcheckAir { [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], ); let alpha_denominator = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_row), alpha_denominator, alpha2, ); let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(prod_continuation), prod_next_alpha, next_alpha1, ); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_continuation), logup_next_alpha, next_alpha1, @@ -241,8 +250,8 @@ impl Air for NativeSumcheckAir { registers[4].into(), registers[0].into(), registers[1].into(), - self.address_space.into(), - self.address_space.into(), + native_as.into(), + native_as.into(), registers[2].into(), registers[3].into(), ], @@ -255,7 +264,7 @@ impl Air for NativeSumcheckAir { for i in 0..5usize { self.memory_bridge .read( - MemoryAddress::new(self.address_space, registers[i]), + MemoryAddress::new(native_as, registers[i]), [register_ptrs[i]], first_timestamp + AB::F::from_canonical_usize(i), &header_row_specific.read_records[i], @@ -266,7 +275,7 @@ impl Air for NativeSumcheckAir { // React ctx self.memory_bridge .read( - MemoryAddress::new(self.address_space, register_ptrs[0]), + MemoryAddress::new(native_as, register_ptrs[0]), ctx, first_timestamp + AB::F::from_canonical_usize(5), &header_row_specific.read_records[5], @@ -276,7 +285,7 @@ impl Air for NativeSumcheckAir { // Read challenges self.memory_bridge .read( - MemoryAddress::new(self.address_space, register_ptrs[1]), + MemoryAddress::new(native_as, register_ptrs[1]), challenges, first_timestamp + AB::F::from_canonical_usize(6), &header_row_specific.read_records[6], @@ -286,7 +295,7 @@ impl Air for NativeSumcheckAir { // Write final result self.memory_bridge .write( - MemoryAddress::new(self.address_space, register_ptrs[4]), + MemoryAddress::new(native_as, register_ptrs[4]), eval_acc, last_timestamp - AB::F::ONE, &header_row_specific.write_records, @@ -302,7 +311,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .read( MemoryAddress::new( - self.address_space, + native_as, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2) + (curr_prod_n - AB::F::ONE), @@ -330,10 +339,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .read( - MemoryAddress::new( - self.address_space, - register_ptrs[2] + prod_row_specific.data_ptr, - ), + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, start_timestamp + AB::F::ONE, &prod_row_specific.read_records[1], @@ -348,7 +354,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .write( MemoryAddress::new( - self.address_space, + native_as, register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, @@ -363,12 +369,12 @@ impl Air for NativeSumcheckAir { FieldExtension::multiply::(p2, c2), ); let in_round_p_evals = FieldExtension::multiply::(p1, p2); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(prod_in_round_evaluation), in_round_p_evals, prod_row_specific.p_evals, ); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(prod_next_round_evaluation), next_round_p_evals, prod_row_specific.p_evals, @@ -377,14 +383,14 @@ impl Air for NativeSumcheckAir { // Accumulate evaluation let acc_eval = FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(prod_acc), prod_row_specific.acc_eval, acc_eval, ); let next_acc = FieldExtension::subtract(eval_acc, next_prod_row_specific.acc_eval); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(next.prod_acc), next.eval_acc, next_acc, @@ -399,7 +405,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .read( MemoryAddress::new( - self.address_space, + native_as, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2) + ctx[1] @@ -428,10 +434,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .read( - MemoryAddress::new( - self.address_space, - register_ptrs[3] + logup_row_specific.data_ptr, - ), + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, start_timestamp + AB::F::ONE, &logup_row_specific.read_records[1], @@ -452,7 +455,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .write( MemoryAddress::new( - self.address_space, + native_as, register_ptrs[4] + (ctx[1] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), @@ -465,7 +468,7 @@ impl Air for NativeSumcheckAir { self.memory_bridge .write( MemoryAddress::new( - self.address_space, + native_as, register_ptrs[4] + (ctx[1] + ctx[2] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), @@ -484,12 +487,12 @@ impl Air for NativeSumcheckAir { FieldExtension::multiply::(p1, q2), FieldExtension::multiply::(p2, q1), ); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_in_round_evaluation), in_round_p_evals, logup_row_specific.p_evals, ); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_next_round_evaluation), next_round_p_evals, logup_row_specific.p_evals, @@ -500,12 +503,12 @@ impl Air for NativeSumcheckAir { FieldExtension::multiply::(q2, c2), ); let in_round_q_evals = FieldExtension::multiply::(q1, q2); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_in_round_evaluation), in_round_q_evals, logup_row_specific.q_evals, ); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_next_round_evaluation), next_round_q_evals, logup_row_specific.q_evals, @@ -516,14 +519,14 @@ impl Air for NativeSumcheckAir { FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), ); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_acc), logup_row_specific.acc_eval, acc_eval, ); let next_acc = FieldExtension::subtract(eval_acc, next_logup_row_specfic.acc_eval); - assert_array_eq::<_, _, _, EXT_DEG>( + assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(next.logup_acc), next.eval_acc, next_acc, diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 2ec3f1bd3c..09e362f8ac 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -2,12 +2,11 @@ use std::sync::{Arc, Mutex}; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PreflightExecutor, - RecordArena, Streams, SystemPort, TraceFiller, VmChipWrapper, VmStateMut, - }, - system::memory::{ - online::TracingMemory, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + CustomBorrow, ExecutionBridge, ExecutionError, ExecutionState, MultiRowLayout, + MultiRowMetadata, PreflightExecutor, RecordArena, Streams, TraceFiller, VmChipWrapper, + VmStateMut, }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory, MemoryController}, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL}; @@ -28,52 +27,58 @@ use crate::{ }; const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(bound = "F: Field")] -pub struct SumcheckEvalRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - pub row_type: usize, // 0 - header; 1 - prod; 2 - logup - pub curr_timestamp_increment: usize, - pub final_timestamp_increment: usize, - pub continuation: bool, - - pub register_ptrs: [F; 5], - pub registers: [F; 5], - pub ctx: [F; EXT_DEG * 2], - pub challenges: [F; EXT_DEG * 4], - pub read_data_records: [RecordId; 7], - pub write_data_records: [RecordId; 2], - - pub max_round: F, - pub within_round_limit: bool, - pub should_acc: bool, - pub prod_spec_n: usize, - pub logup_spec_n: usize, - pub alpha: [F; EXT_DEG], - pub alpha1: [F; EXT_DEG], - pub alpha2: [F; EXT_DEG], - pub data_ptr: F, - pub p1: [F; EXT_DEG], - pub p2: [F; EXT_DEG], - pub q1: [F; EXT_DEG], - pub q2: [F; EXT_DEG], - pub p_evals: [F; EXT_DEG], - pub q_evals: [F; EXT_DEG], - pub eval_acc: [F; EXT_DEG], - pub acc_eval: [F; EXT_DEG], +pub(crate) fn calculate_3d_ext_idx( + inner_inner_len: u32, + inner_len: u32, + outer_idx: u32, + inner_idx: u32, + inner_inner_idx: u32, +) -> u32 { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) + * EXT_DEG as u32 } -fn calculate_3d_ext_idx( - inner_inner_len: F, - inner_len: F, - outer_idx: F, - inner_idx: F, - inner_inner_idx: F, -) -> F { - (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) - * F::from_canonical_usize(EXT_DEG) +#[derive(Debug, Clone, Default)] +pub struct NativeSumcheckMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for NativeSumcheckMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +type NativeSumcheckRecordLayout = MultiRowLayout; + +pub struct NativeSumcheckRecordMut<'a, F>(&'a mut [NativeSumcheckCols]); + +impl<'a, F: PrimeField32> + CustomBorrow<'a, NativeSumcheckRecordMut<'a, F>, NativeSumcheckRecordLayout> for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativeSumcheckRecordLayout, + ) -> NativeSumcheckRecordMut<'a, F> { + // SAFETY: + // - align_to_mut() ensures proper alignment for NativeSumcheckCols + // - Layout guarantees sufficient length for num_rows records + // - Slice bounds validated by taking only num_rows elements + let arr = unsafe { self.align_to_mut::>().1 }; + NativeSumcheckRecordMut(&mut arr[..layout.metadata.num_rows]) + } + + unsafe fn extract_layout(&self) -> NativeSumcheckRecordLayout { + // Each instruction record consists solely of some number of contiguously + // stored NativeSumcheckCols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // as a single record. + NativeSumcheckRecordLayout { + metadata: NativeSumcheckMetadata { num_rows: 1 }, + } + } } #[derive(derive_new::new, Copy, Clone)] @@ -93,7 +98,7 @@ impl Default for NativeSumcheckExecutor { impl PreflightExecutor for NativeSumcheckExecutor where F: PrimeField32, - for<'buf> RA: RecordArena<'buf, FriReducedOpeningLayout, FriReducedOpeningRecordMut<'buf, F>>, + for<'buf> RA: RecordArena<'buf, NativeSumcheckRecordLayout, NativeSumcheckRecordMut<'buf, F>>, { fn execute( &self, @@ -112,6 +117,7 @@ where } = instruction; if op == SUMCHECK_LAYER_EVAL.global_opcode() { + /* let mut observation_records: Vec> = vec![]; let mut curr_timestamp: usize = 0; @@ -140,7 +146,7 @@ where is_op_for_cur_sumcheck_round, // This opcode supports two modes of operation: // 1. calculate the expected evaluation of two types of sumchecks for the current round // a. product sumcheck: v' = v[0] * v[1] - // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. // 2. calculate the expected value of next layer: // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] @@ -407,14 +413,12 @@ where observation_records[last_idx].continuation = false; self.record_set.extend(observation_records); + */ } else { unreachable!() } - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) + Ok(()) } // GKR layered IOP for product and logup relations @@ -427,6 +431,7 @@ where impl TraceFiller for NativeSumcheckFiller { fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { todo!(); + /* let slice = &mut flat_trace[used_cells..used_cells + width]; let cols: &mut NativeSumcheckCols = slice.borrow_mut(); cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); @@ -561,5 +566,6 @@ impl TraceFiller for NativeSumcheckFiller { } } } + */ } } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index ce09c8fc3d..e8f781a4ae 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -5,32 +5,32 @@ use std::{ use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; -use openvm_native_compiler::conversion::AS; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, NATIVE_AS}; use openvm_stark_backend::p3_field::PrimeField32; -use super::{elem_to_ext, FriReducedOpeningExecutor}; -use crate::field_extension::{FieldExtension, EXT_DEG}; +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + sumcheck::chip::{calculate_3d_ext_idx, NativeSumcheckExecutor}, +}; #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct NativeSumcheckPreCompute { - a_ptr_ptr: u32, - b_ptr_ptr: u32, - length_ptr: u32, - alpha_ptr: u32, - result_ptr: u32, - hint_id_ptr: u32, - is_init_ptr: u32, + r_evals_reg: u32, + ctx_reg: u32, + challenges_reg: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, } -impl NativeSumcheckPreCompute { +impl NativeSumcheckExecutor { #[inline(always)] fn pre_compute_impl( &self, - _pc: u32, + pc: u32, inst: &Instruction, - data: &mut FriReducedOpeningPreCompute, + data: &mut NativeSumcheckPreCompute, ) -> Result<(), StaticProgramError> { let &Instruction { a, @@ -43,22 +43,25 @@ impl NativeSumcheckPreCompute { .. } = inst; - let a_ptr_ptr = a.as_canonical_u32(); - let b_ptr_ptr = b.as_canonical_u32(); - let length_ptr = c.as_canonical_u32(); - let alpha_ptr = d.as_canonical_u32(); - let result_ptr = e.as_canonical_u32(); - let hint_id_ptr = f.as_canonical_u32(); - let is_init_ptr = g.as_canonical_u32(); - - *data = FriReducedOpeningPreCompute { - a_ptr_ptr, - b_ptr_ptr, - length_ptr, - alpha_ptr, - result_ptr, - hint_id_ptr, - is_init_ptr, + let r_evals_reg = a.as_canonical_u32(); + let ctx_reg = b.as_canonical_u32(); + let challenges_reg = c.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); + + if d.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = NativeSumcheckPreCompute { + r_evals_reg, + ctx_reg, + challenges_reg, + prod_evals_reg, + logup_evals_reg, }; Ok(()) @@ -190,5 +193,144 @@ unsafe fn execute_e12_impl( pc: &mut u32, exec_state: &mut VmExecState, ) -> u32 { - todo!() + let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); + let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); + let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + + let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); + let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + + let ctx: [u32; 8] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32) + .map(|x: F| x.as_canonical_u32()); + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + let challenges: [F; EXT_DEG * 3] = + exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + + let mut height = 0; + let mut alpha_acc = elem_to_ext(F::ONE); + let mut eval_acc = elem_to_ext(F::ZERO); + + for i in 0..num_prod_spec { + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32 + 8) + .map(|x: F| x.as_canonical_u32()); + + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + 0, + ); + + if round < max_round - 1 { + let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + + let eval = match mode { + 1 => FieldExtension::multiply(p1, p2), + 0 => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be 0 or 1"), + }; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + 1 + i, &eval); + + if round + mode < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + height += 1; + } + + for i in 0..num_logup_spec { + // read max_round + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32 + 8 + num_prod_spec + i) + .map(|x: F| x.as_canonical_u32()); + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + 0, + ); + + if round < max_round - 1 { + // read logup_evals + let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); + + // compute p_eval and q_eval + let p_eval = match mode { + 1 => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + 0 => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be 0 or 1"), + }; + let q_eval = match mode { + 1 => FieldExtension::multiply(q1, q2), + 0 => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + _ => unreachable!("mode can only be 0 or 1"), + }; + + // write eval to r_evals + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + i) * EXT_DEG as u32, + &p_eval, + ); + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + num_logup_spec + i) * EXT_DEG as u32, + &q_eval, + ); + + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + let alpha_numerator = alpha_acc; + + if round + mode < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_acc, FieldExtension::multiply(alpha, alpha)); + height += 1; + } + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32, &eval_acc); + // return height delta + height } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index ff5823600f..1649c3f11d 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -35,7 +35,7 @@ impl Builder { input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened - * from 3D array. */ + * from 3D array. */ logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened * from 3D array. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 607c2a783b..cfd3527adf 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -3,7 +3,9 @@ use openvm_circuit::arch::{ instructions::program::Program, verify_single, SystemConfig, VirtualMachine, VmConfig, VmExecutor, }; -use openvm_native_circuit::{Native, NativeConfig, EXT_DEG}; +#[cfg(not(feature = "cuda"))] +use openvm_circuit::utils::air_test_impl; +use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, @@ -30,9 +32,9 @@ use openvm_stark_sdk::{ p3_baby_bear::BabyBear, utils::{create_seeded_rng, ProofInputForTest}, }; -use rand::Rng; + pub type F = BabyBear; -pub type E = BinomialExtensionField; +pub type E = BinomialExtensionField; #[test] fn test_sumcheck_layer_eval() { @@ -67,14 +69,22 @@ fn test_sumcheck_layer_eval() { let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; - let vm = VirtualMachine::new(engine, config); - - let pk = vm.keygen(); - let result = vm.execute_and_generate(program, vec![]).unwrap(); - let proofs = vm.prove(&pk, result); - - for proof in proofs { - verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + let vb = NativeBuilder::default(); + #[cfg(not(feature = "cuda"))] + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); + #[cfg(feature = "cuda")] + { + air_test_impl::( + fri_params, + vb, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); } } From 240beb60092b2cf35a32073cc7bfb96ed5817bc7 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 18:16:30 +0800 Subject: [PATCH 03/18] wip3 --- extensions/native/circuit/src/extension/mod.rs | 2 +- extensions/native/circuit/src/sumcheck/execution.rs | 9 ++++++--- extensions/native/recursion/tests/sumcheck.rs | 7 ++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index f175a9c4de..9f3e2035ad 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -179,7 +179,7 @@ impl VmExecutionExtension for Native { inventory.add_executor( tower_verify, [SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode()], - ); + )?; inventory.add_phantom_sub_executor( NativeHintInputSubEx, diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index e8f781a4ae..a4b4f484f3 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -215,7 +215,7 @@ unsafe fn execute_e12_impl( let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); - let mut height = 0; + let mut height = 1; let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); @@ -265,8 +265,8 @@ unsafe fn execute_e12_impl( .vm_read(NATIVE_AS, ctx_ptr_u32 + 8 + num_prod_spec + i) .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( - prod_specs_inner_inner_len, - prod_specs_inner_len, + logup_specs_inner_len, + logup_specs_inner_inner_len, i, round, 0, @@ -330,6 +330,9 @@ unsafe fn execute_e12_impl( height += 1; } + *pc += DEFAULT_PC_STEP; + *instret += 1; + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32, &eval_acc); // return height delta height diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index cfd3527adf..d01c63729c 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -30,7 +30,6 @@ use openvm_stark_sdk::{ }, engine::StarkFriEngine, p3_baby_bear::BabyBear, - utils::{create_seeded_rng, ProofInputForTest}, }; pub type F = BabyBear; @@ -42,15 +41,11 @@ fn test_sumcheck_layer_eval() { build_test_program(&mut builder); - // Fill in test program logic - builder.halt(); - let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); compiler.build(builder.operations); let asm_code = compiler.code(); - // let program = Program::from_instructions(&instructions); let program: Program<_> = convert_program(asm_code, compilation_options); let sumcheck_max_constraint_degree = 3; let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { @@ -1268,4 +1263,6 @@ fn build_test_program(builder: &mut Builder) { &logup_spec_evals, &next_layer_evals, ); + + builder.halt(); } From fee9d2cc75efb6bbbb49b7b5b128d32c21e74fe5 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 21:19:58 +0800 Subject: [PATCH 04/18] wip4 --- .../native/circuit/src/poseidon2/chip.rs | 35 +- .../native/circuit/src/sumcheck/chip.rs | 757 ++++++++---------- extensions/native/circuit/src/utils.rs | 27 + 3 files changed, 367 insertions(+), 452 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index aecff9f10f..d3fe6e3e38 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -23,12 +23,16 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelSliceMut, *}, }; -use crate::poseidon2::{ - columns::{ - InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, +use crate::{ + mem_fill_helper, + poseidon2::{ + columns::{ + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, + SimplePoseidonSpecificCols, TopLevelSpecificCols, + }, + CHUNK, }, - CHUNK, + tracing_read_native_helper, }; #[derive(Clone)] @@ -1240,24 +1244,3 @@ impl NativePoseidon2Filler( - memory: &mut TracingMemory, - ptr: u32, - base_aux: &mut MemoryBaseAuxCols, -) -> [F; BLOCK_SIZE] { - let mut prev_ts = 0; - let ret = tracing_read_native(memory, ptr, &mut prev_ts); - base_aux.set_prev(F::from_canonical_u32(prev_ts)); - ret -} - -/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. -fn mem_fill_helper( - mem_helper: &MemoryAuxColsFactory, - timestamp: u32, - base_aux: &mut MemoryBaseAuxCols, -) { - let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); - mem_helper.fill(prev_ts, timestamp, base_aux); -} diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 09e362f8ac..e6ffd25a4a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -1,31 +1,41 @@ -use std::sync::{Arc, Mutex}; +use std::borrow::{Borrow, BorrowMut}; +use itertools::Itertools; use openvm_circuit::{ arch::{ - CustomBorrow, ExecutionBridge, ExecutionError, ExecutionState, MultiRowLayout, - MultiRowMetadata, PreflightExecutor, RecordArena, Streams, TraceFiller, VmChipWrapper, - VmStateMut, + CustomBorrow, ExecutionError, ExecutionState, MultiRowLayout, MultiRowMetadata, + PreflightExecutor, RecordArena, Streams, TraceFiller, VmChipWrapper, VmStateMut, + }, + system::{ + memory::{online::TracingMemory, MemoryAuxColsFactory, MemoryController}, + native_adapter::util::{ + memory_read_native, tracing_read_native, tracing_write_native_inplace, + }, }, - system::memory::{online::TracingMemory, MemoryAuxColsFactory, MemoryController}, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL}; use openvm_stark_backend::{ p3_field::{Field, PrimeField, PrimeField32}, - p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator, ParallelSlice}, }; -use serde::{Deserialize, Serialize}; use crate::{ field_extension::{FieldExtension, EXT_DEG}, fri::elem_to_ext, + mem_fill_helper, sumcheck::{ air::NativeSumcheckAir, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, }, + tracing_read_native_helper, utils::const_max, }; + const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; +const CURRENT_LAYER_MODE: u32 = 1; +const NEXT_LAYER_MODE: u32 = 0; pub(crate) fn calculate_3d_ext_idx( inner_inner_len: u32, @@ -107,317 +117,306 @@ where ) -> Result<(), ExecutionError> { let &Instruction { opcode: op, - a: output_register, - b: input_register_1, - c: input_register_2, + a: r_evals_reg, + b: ctx_reg, + c: challenges_reg, d: data_address_space, e: register_address_space, - f: input_register_3, - g: input_register_4, + f: prod_evals_reg, + g: logup_evals_reg, } = instruction; - if op == SUMCHECK_LAYER_EVAL.global_opcode() { - /* - let mut observation_records: Vec> = vec![]; - let mut curr_timestamp: usize = 0; - - let (read_ctx_pointer, ctx_pointer) = - memory.read_cell(register_address_space, input_register_1); - let (read_cs_pointer, cs_pointer) = - memory.read_cell(register_address_space, input_register_2); - let (read_prod_pointer, prod_ptr) = - memory.read_cell(register_address_space, input_register_3); - let (read_logup_pointer, logup_ptr) = - memory.read_cell(register_address_space, input_register_4); - let (read_result_pointer, r_ptr) = - memory.read_cell(register_address_space, output_register); - let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; - - let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = - memory.read::<{ EXT_DEG * 2 }>(data_address_space, ctx_pointer); - let [ - round, - num_prod_spec, - num_logup_spec, - prod_specs_inner_len, - prod_specs_inner_inner_len, - logup_specs_inner_len, - logup_specs_inner_inner_len, - is_op_for_cur_sumcheck_round, // This opcode supports two modes of operation: - // 1. calculate the expected evaluation of two types of sumchecks for the current round - // a. product sumcheck: v' = v[0] * v[1] - // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. - // 2. calculate the expected value of next layer: - // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] - // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] - ] = ctx; - - let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = - memory.read::<{ EXT_DEG * 4 }>(data_address_space, cs_pointer); - let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); - - let mut header_row = SumcheckEvalRecord { - from_state, - instruction: instruction.clone(), - row_type: 0, - continuation: true, - curr_timestamp_increment: curr_timestamp, - register_ptrs, - alpha, - registers: [ - input_register_1, - input_register_2, - input_register_3, - input_register_4, - output_register, - ], - ctx, - challenges, - read_data_records: [ - read_ctx_pointer, - read_cs_pointer, - read_prod_pointer, - read_logup_pointer, - read_result_pointer, - ctx_read, - challenges_read, - ], - ..Default::default() - }; - - observation_records.push(header_row); - self.height += 1; - curr_timestamp += 7; - - let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); - let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)] - .try_into() - .expect(""); - - let mut i = F::ZERO; - let mut i_usize = 0usize; - while i < num_prod_spec { - let mut prod_row: SumcheckEvalRecord = SumcheckEvalRecord { - from_state, - instruction: instruction.clone(), - row_type: 1, - continuation: true, - curr_timestamp_increment: curr_timestamp, - register_ptrs, - ctx, - challenges, - alpha, - prod_spec_n: i_usize, - ..Default::default() - }; - prod_row.alpha1 = alpha_acc; + // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] + // and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + assert_eq!(op, SUMCHECK_LAYER_EVAL.global_opcode()); + + let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); + let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) + .map(|x: F| x.as_canonical_u32()); + + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + // allocate n rows + let num_rows = (1 + num_prod_spec + num_logup_spec) as usize; + println!("num_rows = {}", num_rows); + let rows = state + .ctx + .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) + .0; + + let mut cur_timestamp = 0; + // head row + let head_row: &mut NativeSumcheckCols = &mut rows[0]; + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + head_row.header_row = F::ONE; + + head_specific.pc = F::from_canonical_u32(*state.pc); + + head_specific.registers[0] = ctx_reg; + head_specific.registers[1] = challenges_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; + + // read pointers + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_reg.as_canonical_u32(), + head_specific.read_records[0].as_mut(), + ); + let [challenges_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + challenges_reg.as_canonical_u32(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), + ); + let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + r_evals_reg.as_canonical_u32(), + head_specific.read_records[4].as_mut(), + ); - let (read_max_round, max_round) = memory.read_cell( - data_address_space, - ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + i, - ); - prod_row.max_round = max_round; - prod_row.read_data_records[0] = read_max_round; - curr_timestamp += 1; - - if round < (max_round - F::from_canonical_usize(1)) { - prod_row.within_round_limit = true; - let start = calculate_3d_ext_idx( - prod_specs_inner_inner_len, - prod_specs_inner_len, - i, - round, - F::from_canonical_usize(0), - ); - prod_row.data_ptr = start; - - let (read_p, ps) = - memory.read::<{ EXT_DEG * 2 }>(data_address_space, prod_ptr + start); - let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); - let p2: [F; 4] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - - prod_row.read_data_records[1] = read_p; - prod_row.p1 = p1; - prod_row.p2 = p2; - - let evals = if is_op_for_cur_sumcheck_round > F::ZERO { - FieldExtension::multiply(p1, p2) - } else { - FieldExtension::add( - FieldExtension::multiply(p1, c1), - FieldExtension::multiply(p2, c2), - ) - }; - prod_row.p_evals = evals; - - let (write_slice_eval_1, _) = memory.write::( - data_address_space, - r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), - evals, - ); - prod_row.write_data_records[0] = write_slice_eval_1; + let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_specific.read_records[5].as_mut(), + ); - let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; - let acc_eval = FieldExtension::multiply(alpha_acc, evals); - prod_row.acc_eval = acc_eval; + let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + challenges_ptr.as_canonical_u32(), + head_specific.read_records[6].as_mut(), + ); + cur_timestamp += 7; // 5 register reads + ctx read + challenges read + + // challenges = [alpha, c1=r, c2=1-r] + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + // TODO: write final eval + + // all rows share same register values, ctx, challenges + for row in rows.iter_mut() { + row.challenges = challenges; + row.alpha = alpha; + row.ctx = ctx; + row.register_ptrs[0] = ctx_ptr; + row.register_ptrs[1] = challenges_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; + } - if (round + is_op_for_next_sumcheck_round) - < (max_round - F::from_canonical_usize(1)) - { - eval_acc = FieldExtension::add(eval_acc, acc_eval); - prod_row.should_acc = true; - prod_row.eval_acc = eval_acc.clone(); - } + // product rows + for (i, prod_row) in rows + .iter_mut() + .skip(1) + .take(num_prod_spec as usize) + .enumerate() + { + let prod_specific: &mut ProdSpecificCols = + prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + prod_row.prod_row = F::ONE; + prod_row.curr_prod_n = F::from_canonical_usize(i); + + // read max_round + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, + prod_specific.read_records[0].as_mut(), + ); + cur_timestamp += 1; + + prod_row.alpha = alpha_acc; + prod_row.max_round = max_round; + + // round starts from 0 + if round < max_round.as_canonical_u32() - 1 { + prod_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i as u32, + round, + 0, + ); + prod_specific.data_ptr = F::from_canonical_u32(start); - curr_timestamp += 2; - } + // read p1, p2 + let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + + prod_specific.p = ps; + + // compute expected eval + let eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + prod_specific.p_evals = eval; + + // write p eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + (1 + i as u32) * (EXT_DEG as u32), + eval, + &mut prod_specific.write_record, + ); + cur_timestamp += 1; - alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + let acc_eval = FieldExtension::multiply(alpha_acc, eval); + prod_row.eval_acc = acc_eval; - i = i + F::ONE; - i_usize += 1; - observation_records.push(prod_row); - self.height += 1; + if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { + eval_acc = FieldExtension::add(eval_acc, acc_eval); + prod_row.should_acc = F::ONE; + } } - let mut i = F::ZERO; - let mut i_usize = 0usize; - while i < num_logup_spec { - let mut logup_row: SumcheckEvalRecord = SumcheckEvalRecord { - from_state, - instruction: instruction.clone(), - row_type: 2, - continuation: true, - curr_timestamp_increment: curr_timestamp, - register_ptrs, - ctx, - challenges, - alpha, - logup_spec_n: i_usize, - ..Default::default() - }; - logup_row.alpha1 = alpha_acc; + prod_row.alpha = FieldExtension::multiply(alpha_acc, alpha); + } - let (read_max_round, max_round) = memory.read_cell( - data_address_space, - ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + num_prod_spec + i, + // logup rows + for (i, logup_row) in rows.iter_mut().skip(1 + num_prod_spec as usize).enumerate() { + let logup_specific: &mut LogupSpecificCols = + logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); + + logup_row.logup_row = F::ONE; + + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, + logup_specific.read_records[0].as_mut(), + ); + logup_row.max_round = max_round; + cur_timestamp += 1; + + if round < max_round.as_canonical_u32() - 1 { + logup_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i as u32, + round, + 0, ); - logup_row.max_round = max_round; - logup_row.read_data_records[0] = read_max_round; - curr_timestamp += 1; - - if round < (max_round - F::from_canonical_usize(1)) { - logup_row.within_round_limit = true; - let start = calculate_3d_ext_idx( - logup_specs_inner_inner_len, - logup_specs_inner_len, - i, - round, - F::from_canonical_usize(0), - ); - logup_row.data_ptr = start; - - let (read_pqs, pqs) = - memory.read::<{ EXT_DEG * 4 }>(data_address_space, logup_ptr + start); - let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); - let p2: [F; 4] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - let q1: [F; 4] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); - let q2: [F; 4] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); - - logup_row.read_data_records[1] = read_pqs; - logup_row.p1 = p1; - logup_row.p2 = p2; - logup_row.q1 = q1; - logup_row.q2 = q2; - - let p_evals = if is_op_for_cur_sumcheck_round > F::ZERO { - FieldExtension::add( - FieldExtension::multiply(p1, q2), - FieldExtension::multiply(p2, q1), - ) - } else { - FieldExtension::add( - FieldExtension::multiply(p1, c1), - FieldExtension::multiply(p2, c2), - ) - }; - - let q_evals = if is_op_for_cur_sumcheck_round > F::ZERO { - FieldExtension::multiply(q1, q2) - } else { - FieldExtension::add( - FieldExtension::multiply(q1, c1), - FieldExtension::multiply(q2, c2), - ) - }; - - logup_row.p_evals = p_evals; - logup_row.q_evals = q_evals; - - let (write_slice_eval_1, _) = memory.write::( - data_address_space, - r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), - p_evals, - ); - let (write_slice_eval_2, _) = memory.write::( - data_address_space, - r_ptr - + (F::ONE + num_prod_spec + num_logup_spec + i) - * F::from_canonical_usize(EXT_DEG), - q_evals, - ); + logup_specific.data_ptr = F::from_canonical_u32(start); - logup_row.write_data_records[0] = write_slice_eval_1; - logup_row.write_data_records[1] = write_slice_eval_2; - - let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; - let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); - logup_row.alpha2 = alpha_denominator; - - if (round + is_op_for_next_sumcheck_round) - < (max_round - F::from_canonical_usize(1)) - { - let acc_eval = FieldExtension::add( - FieldExtension::multiply(alpha_acc, p_evals), - FieldExtension::multiply(alpha_denominator, q_evals), - ); - logup_row.acc_eval = acc_eval; - eval_acc = FieldExtension::add(eval_acc, acc_eval); - logup_row.should_acc = true; - logup_row.eval_acc = eval_acc.clone(); - } - - curr_timestamp += 3; - } + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().unwrap(); + + logup_specific.pq = pqs; + + // compute expected evals + let p_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; - alpha_acc = - FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); + logup_specific.p_evals = p_eval; + logup_specific.q_evals = q_eval; - i = i + F::ONE; - i_usize += 1; - observation_records.push(logup_row); - self.height += 1; - } + // write p_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + i as u32) * (EXT_DEG as u32), + p_eval, + &mut logup_specific.write_records[0], + ); + // write q_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + num_logup_spec + i as u32) * (EXT_DEG as u32), + p_eval, + &mut logup_specific.write_records[1], + ); + cur_timestamp += 3; // 1 read, 2 writes - let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); - curr_timestamp += 1; - observation_records[0].write_data_records[0] = write_r; + let alpha_numerator = alpha_acc; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); - for record in &mut observation_records { - record.final_timestamp_increment = curr_timestamp; - record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { + let eval = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + logup_specific.acc_eval = eval; + eval_acc = FieldExtension::add(eval_acc, eval); + logup_row.should_acc = F::ONE; + logup_row.eval_acc = eval_acc; + } } - let last_idx = observation_records.len() - 1; - observation_records[last_idx].continuation = false; - self.record_set.extend(observation_records); - */ - } else { - unreachable!() + alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); } + let head_row = &mut rows[0]; + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32(), + eval_acc, + &mut head_specific.write_records, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); Ok(()) } @@ -429,143 +428,49 @@ where } impl TraceFiller for NativeSumcheckFiller { + /* + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) where + F: Send + Sync + Clone, + { + // Split the trace rows by instruction + let width = trace.width(); + let mut row_idx = 0; + let mut row_slice = trace.values.as_mut_slice(); + let mut chunk_start = Vec::new(); + while row_idx < rows_used { + let cols: &NativeSumcheckCols = row_slice[..width].borrow(); + let num_rows = cols.num_rows.as_canonical_u32() as usize; + row_idx += num_rows; + let (curr, rest) = row_slice.split_at_mut(num_rows * width); + chunk_start.push(curr); + row_slice = rest; + } + chunk_start.into_par_iter().for_each(|chunk_slice| { + self.fill_trace_row(mem_helper, chunk_slice); + }); + } + */ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { - todo!(); - /* - let slice = &mut flat_trace[used_cells..used_cells + width]; - let cols: &mut NativeSumcheckCols = slice.borrow_mut(); - cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); - cols.start_timestamp = F::from_canonical_usize( - record.from_state.timestamp as usize + record.curr_timestamp_increment, - ); - cols.last_timestamp = F::from_canonical_usize( - record.from_state.timestamp as usize + record.final_timestamp_increment, - ); - cols.register_ptrs = record.register_ptrs; - cols.ctx = record.ctx; - cols.prod_nested_len = record.ctx[4] * record.ctx[3]; - cols.logup_nested_len = record.ctx[6] * record.ctx[5]; - cols.challenges = record.challenges; - cols.alpha = record.alpha; - cols.max_round = record.max_round; - cols.within_round_limit = if record.within_round_limit { - F::ONE - } else { - F::ZERO - }; - cols.should_acc = if record.should_acc { F::ONE } else { F::ZERO }; - cols.eval_acc = record.eval_acc; - - if record.row_type == 0 { - cols.header_row = F::ONE; - cols.header_continuation = if record.continuation { F::ONE } else { F::ZERO }; + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + let start_timestamp = cols.start_timestamp.as_canonical_u32(); + + if cols.header_row == F::ONE { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - header.pc = F::from_canonical_u32(record.from_state.pc); - header.registers = record.registers; - for i in 0..7usize { - let mem_record = memory.record_by_id(record.read_data_records[i]); - aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); - } - - // write the final result - let mem_record = memory.record_by_id(record.write_data_records[0]); - aux_cols_factory.generate_write_aux(mem_record, &mut header.write_records); - } else if record.row_type == 1 { - cols.prod_row = F::ONE; - cols.prod_continuation = if record.continuation { F::ONE } else { F::ZERO }; - cols.prod_row_within_max_round = if record.within_round_limit { - F::ONE - } else { - F::ZERO - }; - cols.prod_in_round_evaluation = if record.within_round_limit { - record.ctx[7] - } else { - F::ZERO - }; - cols.prod_next_round_evaluation = if record.within_round_limit { - F::ONE - record.ctx[7] - } else { - F::ZERO - }; - cols.prod_acc = if record.should_acc { F::ONE } else { F::ZERO }; - let prod: &mut ProdSpecificCols = - cols.specific[..ProdSpecificCols::::width()].borrow_mut(); - - cols.curr_prod_n = F::from_canonical_usize(record.prod_spec_n + 1); - cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); - prod.p[0..EXT_DEG].copy_from_slice(&record.p1); - prod.p[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); - prod.data_ptr = record.data_ptr; - prod.acc_eval = record.acc_eval; - - // Read max_round - let mem_record = memory.record_by_id(record.read_data_records[0]); - aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[0]); - - if record.within_round_limit { - // Read p1, p2 - let mem_record = memory.record_by_id(record.read_data_records[1]); - aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[1]); - - // Write p eval - prod.p_evals = record.p_evals; - let mem_record = memory.record_by_id(record.write_data_records[0]); - aux_cols_factory.generate_write_aux(mem_record, &mut prod.write_record); - } - } else if record.row_type == 2 { - cols.logup_row = F::ONE; - cols.logup_continuation = if record.continuation { F::ONE } else { F::ZERO }; - cols.logup_row_within_max_round = if record.within_round_limit { - F::ONE - } else { - F::ZERO - }; - cols.logup_in_round_evaluation = if record.within_round_limit { - record.ctx[7] - } else { - F::ZERO - }; - cols.logup_next_round_evaluation = if record.within_round_limit { - F::ONE - record.ctx[7] - } else { - F::ZERO - }; - cols.logup_acc = if record.should_acc { F::ONE } else { F::ZERO }; - let logup: &mut LogupSpecificCols = - cols.specific[..LogupSpecificCols::::width()].borrow_mut(); - - cols.curr_logup_n = F::from_canonical_usize(record.logup_spec_n + 1); - cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); - cols.challenges[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.alpha2); - logup.pq[0..EXT_DEG].copy_from_slice(&record.p1); - logup.pq[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); - logup.pq[(EXT_DEG * 2)..(EXT_DEG * 3)].copy_from_slice(&record.q1); - logup.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.q2); - logup.data_ptr = record.data_ptr; - logup.acc_eval = record.acc_eval; - - // Read max_round - let mem_record = memory.record_by_id(record.read_data_records[0]); - aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[0]); - - if record.within_round_limit { - // Read p1, p2, q1, q2 - let mem_record = memory.record_by_id(record.read_data_records[1]); - aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[1]); - - // Write p and q eval - logup.p_evals = record.p_evals; - logup.q_evals = record.q_evals; - for i in 0..2usize { - let mem_record = memory.record_by_id(record.write_data_records[i]); - aux_cols_factory.generate_write_aux(mem_record, &mut logup.write_records[i]); - } + mem_fill_helper(mem_helper, start_timestamp, header.read_records[i].as_mut()); } + } else if cols.prod_row == F::ONE { + todo!() + } else if cols.logup_row == F::ONE { + todo!() } - */ } } diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index c38e656e74..3d05656f16 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,9 +1,36 @@ +use openvm_circuit::system::{ + memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::tracing_read_native, +}; +use p3_field::PrimeField32; + pub(crate) const CASTF_MAX_BITS: usize = 30; pub(crate) const fn const_max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. +pub(crate) fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +pub(crate) fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { From 84ec5d2b2f31e8e123eae72a62627d2371285f5b Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 21:33:33 +0800 Subject: [PATCH 05/18] wip5 --- .../native/circuit/src/sumcheck/chip.rs | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index e6ffd25a4a..be5473e02a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -1,36 +1,27 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::BorrowMut; -use itertools::Itertools; use openvm_circuit::{ arch::{ - CustomBorrow, ExecutionError, ExecutionState, MultiRowLayout, MultiRowMetadata, - PreflightExecutor, RecordArena, Streams, TraceFiller, VmChipWrapper, VmStateMut, + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, TraceFiller, VmChipWrapper, VmStateMut, }, system::{ - memory::{online::TracingMemory, MemoryAuxColsFactory, MemoryController}, - native_adapter::util::{ - memory_read_native, tracing_read_native, tracing_write_native_inplace, - }, + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, }, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; -use openvm_native_compiler::{conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL}; -use openvm_stark_backend::{ - p3_field::{Field, PrimeField, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator, ParallelSlice}, -}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::p3_field::PrimeField32; use crate::{ field_extension::{FieldExtension, EXT_DEG}, fri::elem_to_ext, mem_fill_helper, - sumcheck::{ - air::NativeSumcheckAir, - columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, + sumcheck::columns::{ + HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, }, tracing_read_native_helper, - utils::const_max, }; const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; From 875996bb8311e560596c1644f143be6ddb649a7b Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 22:31:53 +0800 Subject: [PATCH 06/18] replace variable position by name --- extensions/native/circuit/src/sumcheck/air.rs | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 0275e009b9..38047ebf01 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -107,6 +107,8 @@ impl Air for NativeSumcheckAir { specific, } = local; + let [round, num_prod_spec, num_logup_spec, _prod_spec_inner_len, prod_spec_inner_inner_len, _logup_spec_inner_len, logup_spec_inner_inner_len, in_round] = + ctx; builder.assert_bool(header_row); builder.assert_bool(prod_row); builder.assert_bool(logup_row); @@ -119,7 +121,6 @@ impl Air for NativeSumcheckAir { builder.assert_bool(logup_in_round_evaluation); let enabled = header_row + prod_row + logup_row; builder.assert_bool(enabled.clone()); - let in_round = ctx[7]; let continuation = header_continuation + prod_continuation + logup_continuation; builder.assert_bool(continuation.clone()); @@ -169,19 +170,19 @@ impl Air for NativeSumcheckAir { builder .when(header_row) .when(next.logup_row) - .assert_zero(ctx[1]); + .assert_zero(num_prod_spec); builder .when(prod_row) .when(next.logup_row) - .assert_eq(ctx[1], curr_prod_n); + .assert_eq(num_prod_spec, curr_prod_n); builder .when(prod_row) .when(not(prod_continuation)) - .assert_eq(ctx[1], curr_prod_n); + .assert_eq(num_prod_spec, curr_prod_n); builder .when(logup_row) .when(not(logup_continuation)) - .assert_eq(ctx[2], curr_logup_n); + .assert_eq(num_logup_spec, curr_logup_n); // Timestamp transition builder @@ -324,7 +325,7 @@ impl Air for NativeSumcheckAir { builder.when(prod_row_within_max_round).assert_eq( prod_row_specific.data_ptr, - (prod_nested_len * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) + (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) * AB::F::from_canonical_usize(EXT_DEG), ); builder.assert_eq( @@ -408,7 +409,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2) - + ctx[1] + + num_prod_spec + (curr_logup_n - AB::F::ONE), ), // curr_logup_n starts at 1. [max_round], @@ -419,7 +420,7 @@ impl Air for NativeSumcheckAir { builder.when(logup_row_within_max_round).assert_eq( logup_row_specific.data_ptr, - (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) + (logup_nested_len * (curr_logup_n - AB::F::ONE) + logup_spec_inner_inner_len * round) * AB::F::from_canonical_usize(EXT_DEG), ); builder.assert_eq( @@ -457,7 +458,7 @@ impl Air for NativeSumcheckAir { MemoryAddress::new( native_as, register_ptrs[4] - + (ctx[1] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, start_timestamp + AB::F::TWO, @@ -470,7 +471,8 @@ impl Air for NativeSumcheckAir { MemoryAddress::new( native_as, register_ptrs[4] - + (ctx[1] + ctx[2] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + + (num_prod_spec + num_logup_spec + curr_logup_n) + * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, start_timestamp + AB::F::from_canonical_usize(3), From 93bdaffa73b3ca2b13a4613bcf3b0bb8fee34900 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 27 Nov 2025 23:54:45 +0800 Subject: [PATCH 07/18] wip6 --- .../native/circuit/src/sumcheck/chip.rs | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index be5473e02a..188964ee73 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -419,34 +419,6 @@ where } impl TraceFiller for NativeSumcheckFiller { - /* - fn fill_trace( - &self, - mem_helper: &MemoryAuxColsFactory, - trace: &mut RowMajorMatrix, - rows_used: usize, - ) where - F: Send + Sync + Clone, - { - // Split the trace rows by instruction - let width = trace.width(); - let mut row_idx = 0; - let mut row_slice = trace.values.as_mut_slice(); - let mut chunk_start = Vec::new(); - while row_idx < rows_used { - let cols: &NativeSumcheckCols = row_slice[..width].borrow(); - let num_rows = cols.num_rows.as_canonical_u32() as usize; - row_idx += num_rows; - let (curr, rest) = row_slice.split_at_mut(num_rows * width); - chunk_start.push(curr); - row_slice = rest; - } - chunk_start.into_par_iter().for_each(|chunk_slice| { - self.fill_trace_row(mem_helper, chunk_slice); - }); - } - */ - fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); let start_timestamp = cols.start_timestamp.as_canonical_u32(); From 769cdf595607355aa5bf4fb11cf0a6b9dcee8281 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 28 Nov 2025 16:45:16 +0800 Subject: [PATCH 08/18] wip7 --- extensions/native/circuit/src/sumcheck/air.rs | 196 +++++++++++------- .../native/circuit/src/sumcheck/chip.rs | 122 +++++++++-- .../native/circuit/src/sumcheck/columns.rs | 22 +- 3 files changed, 225 insertions(+), 115 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 38047ebf01..90110347f6 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -1,14 +1,16 @@ -use std::{array::from_fn, borrow::Borrow, sync::Arc}; +use std::borrow::Borrow; use openvm_circuit::{ arch::{ContinuationVmProof, ExecutionBridge, ExecutionState}, system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; -use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_circuit_primitives::{ + utils::{and, assert_array_eq, not, or}, + var_range::VariableRangeCheckerBus, +}; use openvm_instructions::{LocalOpcode, NATIVE_AS}; use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; use openvm_stark_backend::{ - air_builders::sub::SubAirBuilder, interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra}, @@ -18,8 +20,9 @@ use openvm_stark_backend::{ use crate::{ field_extension::{FieldExtension, EXT_DEG}, - sumcheck::columns::{ - HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, + sumcheck::{ + chip::CONTEXT_ARR_BASE_LEN, + columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, }, }; @@ -62,17 +65,10 @@ impl Air for NativeSumcheckAir { header_row, prod_row, logup_row, - - // Whether valid prod/logup row operations follow this row - header_continuation, - prod_continuation, - logup_continuation, - - // Round limit - prod_row_within_max_round, - logup_row_within_max_round, + is_end, // What type of evaluation is performed + // mainly for reducing constraint degree prod_in_round_evaluation, prod_next_round_evaluation, logup_in_round_evaluation, @@ -112,28 +108,37 @@ impl Air for NativeSumcheckAir { builder.assert_bool(header_row); builder.assert_bool(prod_row); builder.assert_bool(logup_row); - builder.assert_bool(header_continuation); - builder.assert_bool(prod_continuation); - builder.assert_bool(logup_continuation); - builder.assert_bool(prod_row_within_max_round); - builder.assert_bool(logup_row_within_max_round); + builder.assert_bool(within_round_limit); builder.assert_bool(prod_in_round_evaluation); builder.assert_bool(logup_in_round_evaluation); + let enabled = header_row + prod_row + logup_row; + let next_enabled = next.header_row + next.prod_row + next.logup_row; builder.assert_bool(enabled.clone()); - let continuation = header_continuation + prod_continuation + logup_continuation; - builder.assert_bool(continuation.clone()); + + builder.assert_eq::( + or::( + or::( + and(prod_row, next.header_row), + and(logup_row, next.header_row), + ), + not::(next_enabled), + ), + is_end.into(), + ); + + // TODO: within_round_limit = true => round < max_round // Randomness transition - let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); - let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().expect(""); + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().unwrap(); let c2: [_; EXT_DEG] = challenges[{ EXT_DEG * 2 }..{ EXT_DEG * 3 }] .try_into() - .expect(""); + .unwrap(); let alpha2: [_; EXT_DEG] = challenges[{ EXT_DEG * 3 }..{ EXT_DEG * 4 }] .try_into() - .expect(""); - let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); + .unwrap(); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().unwrap(); // Carry along columns assert_array_eq( @@ -146,6 +151,7 @@ impl Air for NativeSumcheckAir { ctx, next.ctx, ); + // c1, c2 remain the same assert_array_eq::<_, _, _, { EXT_DEG * 2 }>( &mut builder.when(next.prod_row + next.logup_row), challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), @@ -153,6 +159,11 @@ impl Air for NativeSumcheckAir { .try_into() .expect(""), ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + alpha, + next.alpha, + ); builder .when(next.prod_row + next.logup_row) .assert_eq(prod_nested_len, next.prod_nested_len); @@ -160,29 +171,48 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq(logup_nested_len, next.logup_nested_len); - // Row transition + //////////////////////////////////////////////////////////////// + // Row transitions from current to next row + // The basic pattern is + // header_row -> prod_row -> ... -> prod_row + // -> logup_row -> ... -> logup_row + //////////////////////////////////////////////////////////////// + + // (curr_prod_n, curr_logup_n) start at 0 + builder.when(header_row).assert_zero(curr_prod_n); + builder + .when(header_row + prod_row) + .assert_zero(curr_logup_n); builder .when(next.prod_row) .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); builder .when(next.logup_row) .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + // if header row is followed by another header row + // then num_prod_spec and num_logup_spec should be zero builder .when(header_row) - .when(next.logup_row) + .when(next.header_row) .assert_zero(num_prod_spec); builder - .when(prod_row) + .when(header_row) + .when(next.header_row) + .assert_zero(num_logup_spec); + // if header row is followed by a logup row, + // then num_prod_spec should be zero + builder + .when(header_row) .when(next.logup_row) - .assert_eq(num_prod_spec, curr_prod_n); + .assert_zero(num_prod_spec); builder .when(prod_row) - .when(not(prod_continuation)) - .assert_eq(num_prod_spec, curr_prod_n); + .when(next.logup_row) + .assert_eq(curr_prod_n, num_prod_spec); builder .when(logup_row) - .when(not(logup_continuation)) - .assert_eq(num_logup_spec, curr_logup_n); + .when(next.header_row) + .assert_eq(curr_logup_n, num_logup_spec); // Timestamp transition builder @@ -209,37 +239,41 @@ impl Air for NativeSumcheckAir { // Termination condition assert_array_eq( - &mut builder.when::(not(continuation)), + &mut builder.when::(is_end.into()), eval_acc, [AB::F::ZERO; 4], ); // Randomness transition assert_array_eq( - &mut builder.when(header_continuation), - next.challenges[0..EXT_DEG].try_into().expect(""), + &mut builder.when(and(header_row, or(next.prod_row, next.logup_row))), + next.challenges[0..EXT_DEG].try_into().unwrap(), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], ); + assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(and(prod_row, next.prod_row)), + prod_next_alpha, + next_alpha1, + ); + // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_row), alpha_denominator, alpha2, ); - let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(prod_continuation), - prod_next_alpha, - next_alpha1, - ); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(logup_continuation), + &mut builder.when(and(logup_row, next.logup_row)), logup_next_alpha, next_alpha1, ); + /////////////////////////////////////// // Header + /////////////////////////////////////// let header_row_specific: &HeaderSpecificCols = specific[..HeaderSpecificCols::::width()].borrow(); let registers = header_row_specific.registers; @@ -273,7 +307,7 @@ impl Air for NativeSumcheckAir { .eval(builder, header_row); } - // React ctx + // Read ctx self.memory_bridge .read( MemoryAddress::new(native_as, register_ptrs[0]), @@ -303,7 +337,9 @@ impl Air for NativeSumcheckAir { ) .eval(builder, header_row); + /////////////////////////////////////// // Prod spec evaluation + /////////////////////////////////////// let prod_row_specific: &ProdSpecificCols = specific[..ProdSpecificCols::::width()].borrow(); let next_prod_row_specific: &ProdSpecificCols = @@ -314,7 +350,7 @@ impl Air for NativeSumcheckAir { MemoryAddress::new( native_as, register_ptrs[0] - + AB::F::from_canonical_usize(EXT_DEG * 2) + + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + (curr_prod_n - AB::F::ONE), ), // curr_prod_n starts at 1. [max_round], @@ -323,17 +359,17 @@ impl Air for NativeSumcheckAir { ) .eval(builder, prod_row); - builder.when(prod_row_within_max_round).assert_eq( + builder.when(prod_row * within_round_limit).assert_eq( prod_row_specific.data_ptr, (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) * AB::F::from_canonical_usize(EXT_DEG), ); builder.assert_eq( - prod_row * prod_row_within_max_round * in_round, + prod_row * within_round_limit * in_round, prod_in_round_evaluation, ); builder.assert_eq( - prod_row * prod_row_within_max_round * not(in_round), + prod_row * within_round_limit * not(in_round), prod_next_round_evaluation, ); builder.assert_eq(prod_row * should_acc, prod_acc); @@ -345,12 +381,12 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::ONE, &prod_row_specific.read_records[1], ) - .eval(builder, prod_row_within_max_round); + .eval(builder, prod_row * within_round_limit); - let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] .try_into() - .expect(""); + .unwrap(); self.memory_bridge .write( @@ -362,7 +398,7 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::TWO, &prod_row_specific.write_record, ) - .eval(builder, prod_row_within_max_round); + .eval(builder, prod_row * within_round_limit); // Calculate evaluations let next_round_p_evals = FieldExtension::add( @@ -381,23 +417,26 @@ impl Air for NativeSumcheckAir { prod_row_specific.p_evals, ); - // Accumulate evaluation - let acc_eval = + // TODO: add constraint on should_acc + + // Accumulate `eval_rlc` into global accumulator `eval_acc` + // when round < max_round - 2 + let eval_rlc = FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(prod_acc), - prod_row_specific.acc_eval, - acc_eval, + prod_row_specific.eval_rlc, + eval_rlc, ); - - let next_acc = FieldExtension::subtract(eval_acc, next_prod_row_specific.acc_eval); assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(next.prod_acc), - next.eval_acc, - next_acc, + FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + eval_acc, ); + /////////////////////////////////////// // Logup spec evaluation + /////////////////////////////////////// let logup_row_specific: &LogupSpecificCols = specific[..LogupSpecificCols::::width()].borrow(); let next_logup_row_specfic: &LogupSpecificCols = @@ -418,17 +457,17 @@ impl Air for NativeSumcheckAir { ) .eval(builder, logup_row); - builder.when(logup_row_within_max_round).assert_eq( + builder.when(logup_row * within_round_limit).assert_eq( logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + logup_spec_inner_inner_len * round) * AB::F::from_canonical_usize(EXT_DEG), ); builder.assert_eq( - logup_row * logup_row_within_max_round * in_round, + logup_row * within_round_limit * in_round, logup_in_round_evaluation, ); builder.assert_eq( - logup_row * logup_row_within_max_round * not(in_round), + logup_row * within_round_limit * not(in_round), logup_next_round_evaluation, ); builder.assert_eq(logup_row * should_acc, logup_acc); @@ -440,19 +479,20 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::ONE, &logup_row_specific.read_records[1], ) - .eval(builder, logup_row_within_max_round); + .eval(builder, logup_row * within_round_limit); - let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() - .expect(""); + .unwrap(); let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{ EXT_DEG * 3 }] .try_into() - .expect(""); + .unwrap(); let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)] .try_into() - .expect(""); + .unwrap(); + // write p_evals self.memory_bridge .write( MemoryAddress::new( @@ -464,8 +504,9 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::TWO, &logup_row_specific.write_records[0], ) - .eval(builder, logup_row_within_max_round); + .eval(builder, logup_row * within_round_limit); + // write q_evals self.memory_bridge .write( MemoryAddress::new( @@ -478,7 +519,7 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::from_canonical_usize(3), &logup_row_specific.write_records[1], ) - .eval(builder, logup_row_within_max_round); + .eval(builder, logup_row * within_round_limit); // Calculate evaluations let next_round_p_evals = FieldExtension::add( @@ -517,21 +558,22 @@ impl Air for NativeSumcheckAir { ); // Accumulate evaluation - let acc_eval = FieldExtension::add( + let eval_rlc = FieldExtension::add( FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), ); assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(logup_acc), - logup_row_specific.acc_eval, - acc_eval, + logup_row_specific.eval_rlc, + eval_rlc, ); - let next_acc = FieldExtension::subtract(eval_acc, next_logup_row_specfic.acc_eval); + // Accumulate into global accumulator `eval_acc` + // when round < max_round - 2 assert_array_eq::<_, _, _, { EXT_DEG }>( &mut builder.when(next.logup_acc), - next.eval_acc, - next_acc, + FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + eval_acc, ); } } diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 188964ee73..680dfcf31b 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -10,7 +10,9 @@ use openvm_circuit::{ native_adapter::util::{memory_read_native, tracing_write_native_inplace}, }, }; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; use openvm_stark_backend::p3_field::PrimeField32; @@ -24,7 +26,7 @@ use crate::{ tracing_read_native_helper, }; -const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; +pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; const CURRENT_LAYER_MODE: u32 = 1; const NEXT_LAYER_MODE: u32 = 0; @@ -126,6 +128,8 @@ where // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] // and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] assert_eq!(op, SUMCHECK_LAYER_EVAL.global_opcode()); + assert_eq!(data_address_space.as_canonical_u32(), NATIVE_AS); + assert_eq!(register_address_space.as_canonical_u32(), NATIVE_AS); let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) @@ -135,7 +139,6 @@ where ctx; // allocate n rows let num_rows = (1 + num_prod_spec + num_logup_spec) as usize; - println!("num_rows = {}", num_rows); let rows = state .ctx .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) @@ -204,11 +207,10 @@ where let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - // TODO: write final eval // all rows share same register values, ctx, challenges for row in rows.iter_mut() { - row.challenges = challenges; + row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); row.alpha = alpha; row.ctx = ctx; row.register_ptrs[0] = ctx_ptr; @@ -230,6 +232,7 @@ where prod_row.prod_row = F::ONE; prod_row.curr_prod_n = F::from_canonical_usize(i); + prod_row.start_timestamp = F::from_canonical_usize(cur_timestamp); // read max_round let [max_round]: [F; 1] = tracing_read_native_helper( @@ -239,7 +242,7 @@ where ); cur_timestamp += 1; - prod_row.alpha = alpha_acc; + prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); prod_row.max_round = max_round; // round starts from 0 @@ -276,6 +279,16 @@ where }; prod_specific.p_evals = eval; + match mode { + NEXT_LAYER_MODE => { + prod_row.prod_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + prod_row.prod_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + // write p eval tracing_write_native_inplace( state.memory, @@ -285,16 +298,18 @@ where ); cur_timestamp += 1; - let acc_eval = FieldExtension::multiply(alpha_acc, eval); - prod_row.eval_acc = acc_eval; + let eval_rlc = FieldExtension::multiply(alpha_acc, eval); + prod_specific.eval_rlc = eval_rlc; if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { - eval_acc = FieldExtension::add(eval_acc, acc_eval); + eval_acc = FieldExtension::add(eval_acc, eval_rlc); prod_row.should_acc = F::ONE; + prod_row.eval_acc = eval_acc; } } - prod_row.alpha = FieldExtension::multiply(alpha_acc, alpha); + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); } // logup rows @@ -303,6 +318,8 @@ where logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); logup_row.logup_row = F::ONE; + logup_row.curr_logup_n = F::from_canonical_usize(i); + logup_row.start_timestamp = F::from_canonical_usize(cur_timestamp); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, @@ -312,6 +329,11 @@ where logup_row.max_round = max_round; cur_timestamp += 1; + let alpha_numerator = alpha_acc; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + logup_row.challenges[2 * EXT_DEG..(3 * EXT_DEG)].copy_from_slice(&alpha_denominator); + if round < max_round.as_canonical_u32() - 1 { logup_row.within_round_limit = F::ONE; let start = calculate_3d_ext_idx( @@ -357,6 +379,16 @@ where _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), }; + match mode { + NEXT_LAYER_MODE => { + logup_row.logup_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + logup_row.logup_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + logup_specific.p_evals = p_eval; logup_specific.q_evals = q_eval; @@ -378,22 +410,20 @@ where ); cur_timestamp += 3; // 1 read, 2 writes - let alpha_numerator = alpha_acc; - let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); - + let eval = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + logup_specific.eval_rlc = eval; if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { - let eval = FieldExtension::add( - FieldExtension::multiply(alpha_numerator, p_eval), - FieldExtension::multiply(alpha_denominator, q_eval), - ); - logup_specific.acc_eval = eval; eval_acc = FieldExtension::add(eval_acc, eval); logup_row.should_acc = F::ONE; + logup_row.logup_acc = F::ONE; logup_row.eval_acc = eval_acc; } } - alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); } let head_row = &mut rows[0]; @@ -428,12 +458,60 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); for i in 0..7usize { - mem_fill_helper(mem_helper, start_timestamp, header.read_records[i].as_mut()); + mem_fill_helper( + mem_helper, + start_timestamp + i, + header.read_records[i].as_mut(), + ); } + mem_fill_helper( + mem_helper, + start_timestamp + 7, + header.write_records.as_mut(), + ); } else if cols.prod_row == F::ONE { - todo!() + let prod_row_specific: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.read_records[1].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp + 2, + prod_row_specific.write_record.as_mut(), + ); } else if cols.logup_row == F::ONE { - todo!() + let logup_row_specific: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.read_records[1].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp + 3, + logup_row_specific.write_records[1].as_mut(), + ); } } } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index ca4a264277..0cee93c381 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -16,18 +16,8 @@ pub struct NativeSumcheckCols { pub prod_row: T, /// Indicates that this row is a step for logup_spec in the layer sum operation pub logup_row: T, - - /// Indicates that there are valid operations following this header row - pub header_continuation: T, - /// Indicates that there are valid operations following this product evaluation row - pub prod_continuation: T, - /// Indicates that there are valid operations following this logup row - pub logup_continuation: T, - - /// Indicates that the prod row is within maximum round - pub prod_row_within_max_round: T, - /// Indicates that the logup row is within maximum round - pub logup_row_within_max_round: T, + /// Indicates that this row is the end of the entire layer sum operation + pub is_end: T, /// Indicates what type of evaluation constraints should be applied pub prod_in_round_evaluation: T, @@ -66,8 +56,8 @@ pub struct NativeSumcheckCols { pub curr_prod_n: T, pub curr_logup_n: T, - // alpha1, c1, c2, alpha2 (for logup rows) pub alpha: [T; EXT_DEG], + // alpha1, c1, c2, alpha2 (for logup rows) pub challenges: [T; EXT_DEG * 4], // Specific to each row @@ -118,8 +108,8 @@ pub struct ProdSpecificCols { pub p_evals: [T; EXT_DEG], /// write p_evals pub write_record: MemoryWriteAuxCols, - /// Evaluation for the accumulator - pub acc_eval: [T; EXT_DEG], + /// p_evals * alpha^i + pub eval_rlc: [T; EXT_DEG], } #[repr(C)] @@ -138,5 +128,5 @@ pub struct LogupSpecificCols { /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], /// Evaluation for the accumulator - pub acc_eval: [T; EXT_DEG], + pub eval_rlc: [T; EXT_DEG], } From a8ee4b8d67cf56ef6a230b9776290003fb16b2c5 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 28 Nov 2025 19:48:54 +0800 Subject: [PATCH 09/18] clippy --- extensions/native/recursion/tests/sumcheck.rs | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index d01c63729c..9a7dca4306 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -1,27 +1,15 @@ -use itertools::Itertools; -use openvm_circuit::arch::{ - instructions::program::Program, verify_single, SystemConfig, VirtualMachine, VmConfig, - VmExecutor, -}; +use openvm_circuit::arch::instructions::program::Program; #[cfg(not(feature = "cuda"))] use openvm_circuit::utils::air_test_impl; use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, - ir::{Ext, Felt, Usize}, + ir::{Ext, Usize}, prelude::*, }; -use openvm_native_recursion::{ - challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}, - testing_utils::inner::run_recursive_test, -}; -use openvm_stark_backend::{ - config::{Domain, StarkGenericConfig}, - p3_commit::PolynomialSpace, - p3_field::{ - extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, PackedValue, - }, +use openvm_stark_backend::p3_field::{ + extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, PackedValue, }; use openvm_stark_sdk::{ config::{ From cbbe3c105173eec044cdae96e08c28de48b93638 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 28 Nov 2025 21:18:18 +0800 Subject: [PATCH 10/18] wip8 --- extensions/native/recursion/tests/sumcheck.rs | 1388 ++++------------- 1 file changed, 281 insertions(+), 1107 deletions(-) diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 9a7dca4306..d6981c4fdb 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -9,14 +9,13 @@ use openvm_native_compiler::{ prelude::*, }; use openvm_stark_backend::p3_field::{ - extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, PackedValue, + extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, }; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::BabyBearPoseidon2Engine, fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, }, - engine::StarkFriEngine, p3_baby_bear::BabyBear, }; @@ -48,7 +47,6 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; - let engine = BabyBearPoseidon2Engine::new(fri_params); let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; @@ -78,19 +76,11 @@ fn build_test_program(builder: &mut Builder) { builder.set(&ctx, idx, Usize::from(n as usize)); } + #[rustfmt::skip] let challenges_u32s = [ - 548478283u32, - 456436544, - 1716290291, - 791326976, - 1829717553, - 1422025771, - 1917123958, - 727015942, - 183548369, - 591240150, - 96141963, - 1286249979, + 548478283u32, 456436544, 1716290291, 791326976, + 1829717553, 1422025771, 1917123958, 727015942, + 183548369, 591240150, 96141963, 1286249979, ]; let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { @@ -104,391 +94,104 @@ fn build_test_program(builder: &mut Builder) { builder.set(&challenges, idx, e); } + #[rustfmt::skip] let prod_spec_eval_u32s = [ - 1538906710u32, - 637535518, - 1753132406, - 1395236651, - 278806441, - 1722910382, - 1475548665, - 1117874675, - 1578586709, - 1826764884, - 384068476, - 1852240363, - 707958906, - 1960944944, - 183554399, - 1259273357, - 227285124, - 243066436, - 1718037317, - 369721963, - 1752968006, - 1061013677, - 775617499, - 1464907431, - 544300429, - 871461966, - 135151545, - 1343592602, - 1622220528, - 643966158, - 3932580, - 434948358, - 540553922, - 1446502052, - 153298741, - 1191216273, - 265936762, - 1463035257, - 1237633339, - 1797346310, - 1355791584, - 389527741, - 1741650463, - 1728913415, - 1825739540, - 1790924136, - 460776743, - 29536554, - 6842036, - 252495270, - 1968285155, - 299467416, - 49085744, - 1499815729, - 1098802236, - 644489275, - 1827273105, - 1888401527, - 390077051, - 565528894, - 1366177188, - 67441791, - 958486301, - 402056716, - 590379691, - 462035406, - 633459131, - 843304872, - 584100013, - 1932496508, - 250656031, - 146983915, - 1835173157, - 939973454, - 1844873638, - 1916054832, - 1601784696, - 167251717, - 409107688, - 1062925788, - 1291319514, - 1790529531, - 495655592, - 1093359708, - 790197205, - 674458164, - 195988318, - 399764452, - 106865258, - 967050329, - 350035523, - 1109292118, - 1815460301, - 281986036, - 900636603, - 1121197008, - 1228976590, - 1879998708, - 1924332706, - 434695844, - 1159360621, - 471397106, - 473371067, - 1009065094, - 1320176846, - 168020789, - 1265321929, - 1901808675, - 223657700, - 1480150183, - 1779968584, - 144416591, - 304407746, - 1864498679, - 1482460119, - 1554376965, - 1479261548, - 1657723043, - 1039345063, - 1053923521, - 442080513, - 1964082352, - 691664908, - 1941008321, - 1007729002, - 860529393, - 849697342, - 754485488, - 584295923, - 1072251466, - 1105105254, - 996079746, - 1305909868, - 1348028973, - 122275988, - 464050036, - 692807777, - 1098809324, - 397235220, - 596459886, - 1663209783, - 720230826, - 1422510715, - 1760654694, - 544197700, - 1417744567, - 1938716517, - 1571826328, - 1591430185, - 1173137446, - 175285007, - 1541718596, - 1715958587, - 1429966110, - 583013357, - 1667787861, - 109891172, - 668253167, - 161783842, - 296183397, - 1681897325, - 1054396117, - 264741948, - 464026995, - 1907686022, - 1532786783, - 394869458, - 1766734740, - 136047179, - 536856195, - 376188855, - 700633625, - 515518419, - 531043483, - 60673499, - 556496527, - 1743028981, - 873954569, - 1371062291, - 632169731, - 1353239206, - 526507035, - 1894490088, - 589441599, - 1610487168, - 1074160583, - 366366374, - 247602990, - 1535354896, - 894493713, - 1555870413, - 1389854934, - 1897251683, - 1525812801, - 675621735, - 697919636, - 1690274072, - 1466810921, - 1221110784, - 1741995587, - 1877169764, - 390876982, - 1794129810, - 297662156, - 144295349, - 417037264, - 1290835727, - 1654971513, - 1674131303, - 1625667423, - 1471248832, - 1676797844, - 1172916558, - 1707775403, - 423725211, - 1643279661, - 1695774264, - 378140395, - 1517569394, - 1666625392, - 1803981250, - 439036260, - 247966130, - 709534816, - 361144100, - 1546096548, - 1240886454, - 1898161518, - 843262057, - 1709259464, - 1301015977, - 1997626928, - 677153173, - 1606710353, - 1216038070, - 435565562, - 98686333, - 1773787396, - 267051994, - 99395396, - 545509105, - 782289675, - 1289865975, - 1707775075, - 1158993015, - 1506576588, - 993215179, - 1523099397, - 923914455, - 1895162386, - 284489994, - 1444139016, - 1943825680, - 466202724, - 1632522710, - 1384015062, - 723147188, - 1284031324, - 1430481515, - 341213007, - 171192499, - 1061688239, - 808927167, - 83182639, - 759209907, - 1728321272, - 976049976, - 1652071995, - 1002877840, - 69880246, - 1095135165, - 677588420, - 1384715290, - 829619452, - 170122781, - 1958173727, - 13389238, - 789379698, - 1883383039, - 1279195174, - 1618672336, - 1192839317, - 1348311124, - 758896285, - 1939775389, - 684108413, - 1838340479, - 1332232130, - 1070486028, - 549228790, - 868851698, - 1678207843, - 1754321489, - 637000403, - 647901906, - 45343322, - 1768524074, - 1167955205, - 1816497210, - 1609414096, - 1985231742, - 1540534482, - 232730819, - 232221968, - 1509637836, - 1480860627, - 884647789, - 1096458024, - 163721583, - 1248032262, - 436419506, - 1737102298, - 651105860, - 452298073, - 1064372507, - 1792838683, - 619243471, - 860127631, - 721724708, - 950768433, - 279913448, - 339693210, - 47730422, - 1952683911, - 1316500770, - 675944216, - 386902809, - 619333956, - 1194800389, - 43989936, - 1944372656, - 666045666, - 1155873844, - 522696968, - 58874730, - 1497238023, - 421619994, - 1980672127, - 1657191856, - 1913792631, - 1784663131, - 1118400672, - 1828104993, - 1637808383, - 414755472, - 775410449, - 747132157, - 136820101, - 1082674285, - 93190395, - 357955402, - 335652723, - 1192102705, - 480365232, - 1354935730, - 1391829361, - 966662991, - 1601510445, - 569528575, - 545490940, - 1753711688, - 807025222, - 580374183, - 587718008, - 977546290, - 1055719519, - 1157107032, - 562799608, - 859466927, - 840450024, - 815325134, - 936576801, - 1010587056, - 246624382, - 1808049797, - 1098183398, - 1005077390, - 772432546, - 1976629565, - 1003772218, - 1655315418, - 1767931114, - 982008720, - 785023351, + 1538906710u32, 637535518, 1753132406, 1395236651, + 278806441, 1722910382, 1475548665, 1117874675, + 1578586709, 1826764884, 384068476, 1852240363, + 707958906, 1960944944, 183554399, 1259273357, + 227285124, 243066436, 1718037317, 369721963, + 1752968006, 1061013677, 775617499, 1464907431, + 544300429, 871461966, 135151545, 1343592602, + 1622220528, 643966158, 3932580, 434948358, + 540553922, 1446502052, 153298741, 1191216273, + 265936762, 1463035257, 1237633339, 1797346310, + 1355791584, 389527741, 1741650463, 1728913415, + 1825739540, 1790924136, 460776743, 29536554, + 6842036, 252495270, 1968285155, 299467416, + 49085744, 1499815729, 1098802236, 644489275, + 1827273105, 1888401527, 390077051, 565528894, + 1366177188, 67441791, 958486301, 402056716, + 590379691, 462035406, 633459131, 843304872, + 584100013, 1932496508, 250656031, 146983915, + 1835173157, 939973454, 1844873638, 1916054832, + 1601784696, 167251717, 409107688, 1062925788, + 1291319514, 1790529531, 495655592, 1093359708, + 790197205, 674458164, 195988318, 399764452, + 106865258, 967050329, 350035523, 1109292118, + 1815460301, 281986036, 900636603, 1121197008, + 1228976590, 1879998708, 1924332706, 434695844, + 1159360621, 471397106, 473371067, 1009065094, + 1320176846, 168020789, 1265321929, 1901808675, + 223657700, 1480150183, 1779968584, 144416591, + 304407746, 1864498679, 1482460119, 1554376965, + 1479261548, 1657723043, 1039345063, 1053923521, + 442080513, 1964082352, 691664908, 1941008321, + 1007729002, 860529393, 849697342, 754485488, + 584295923, 1072251466, 1105105254, 996079746, + 1305909868, 1348028973, 122275988, 464050036, + 692807777, 1098809324, 397235220, 596459886, + 1663209783, 720230826, 1422510715, 1760654694, + 544197700, 1417744567, 1938716517, 1571826328, + 1591430185, 1173137446, 175285007, 1541718596, + 1715958587, 1429966110, 583013357, 1667787861, + 109891172, 668253167, 161783842, 296183397, + 1681897325, 1054396117, 264741948, 464026995, + 1907686022, 1532786783, 394869458, 1766734740, + 136047179, 536856195, 376188855, 700633625, + 515518419, 531043483, 60673499, 556496527, + 1743028981, 873954569, 1371062291, 632169731, + 1353239206, 526507035, 1894490088, 589441599, + 1610487168, 1074160583, 366366374, 247602990, + 1535354896, 894493713, 1555870413, 1389854934, + 1897251683, 1525812801, 675621735, 697919636, + 1690274072, 1466810921, 1221110784, 1741995587, + 1877169764, 390876982, 1794129810, 297662156, + 144295349, 417037264, 1290835727, 1654971513, + 1674131303, 1625667423, 1471248832, 1676797844, + 1172916558, 1707775403, 423725211, 1643279661, + 1695774264, 378140395, 1517569394, 1666625392, + 1803981250, 439036260, 247966130, 709534816, + 361144100, 1546096548, 1240886454, 1898161518, + 843262057, 1709259464, 1301015977, 1997626928, + 677153173, 1606710353, 1216038070, 435565562, + 98686333, 1773787396, 267051994, 99395396, + 545509105, 782289675, 1289865975, 1707775075, + 1158993015, 1506576588, 993215179, 1523099397, + 923914455, 1895162386, 284489994, 1444139016, + 1943825680, 466202724, 1632522710, 1384015062, + 723147188, 1284031324, 1430481515, 341213007, + 171192499, 1061688239, 808927167, 83182639, + 759209907, 1728321272, 976049976, 1652071995, + 1002877840, 69880246, 1095135165, 677588420, + 1384715290, 829619452, 170122781, 1958173727, + 13389238, 789379698, 1883383039, 1279195174, + 1618672336, 1192839317, 1348311124, 758896285, + 1939775389, 684108413, 1838340479, 1332232130, + 1070486028, 549228790, 868851698, 1678207843, + 1754321489, 637000403, 647901906, 45343322, + 1768524074, 1167955205, 1816497210, 1609414096, + 1985231742, 1540534482, 232730819, 232221968, + 1509637836, 1480860627, 884647789, 1096458024, + 163721583, 1248032262, 436419506, 1737102298, + 651105860, 452298073, 1064372507, 1792838683, + 619243471, 860127631, 721724708, 950768433, + 279913448, 339693210, 47730422, 1952683911, + 1316500770, 675944216, 386902809, 619333956, + 1194800389, 43989936, 1944372656, 666045666, + 1155873844, 522696968, 58874730, 1497238023, + 421619994, 1980672127, 1657191856, 1913792631, + 1784663131, 1118400672, 1828104993, 1637808383, + 414755472, 775410449, 747132157, 136820101, + 1082674285, 93190395, 357955402, 335652723, + 1192102705, 480365232, 1354935730, 1391829361, + 966662991, 1601510445, 569528575, 545490940, + 1753711688, 807025222, 580374183, 587718008, + 977546290, 1055719519, 1157107032, 562799608, + 859466927, 840450024, 815325134, 936576801, + 1010587056, 246624382, 1808049797, 1098183398, + 1005077390, 772432546, 1976629565, 1003772218, + 1655315418, 1767931114, 982008720, 785023351, ]; let prod_spec_evals: Array> = @@ -504,647 +207,168 @@ fn build_test_program(builder: &mut Builder) { builder.set(&prod_spec_evals, idx, e); } + #[rustfmt::skip] let logup_spec_eval_u32s = [ - 1522353967u32, - 457603397, - 421847521, - 1352563318, - 1746817766, - 737872688, - 1087008622, - 1850835028, - 456475558, - 892966330, - 638163666, - 148568548, - 678863061, - 1334386850, - 1896333039, - 154585769, - 433618446, - 1186936470, - 970218722, - 1213827097, - 1798557019, - 861757965, - 119285527, - 395360622, - 226164366, - 1330279872, - 66561048, - 785421608, - 1950755756, - 1559889596, - 348449876, - 1090789452, - 257578851, - 273164442, - 1644906, - 295600924, - 1187949602, - 1168249609, - 469763604, - 60929061, - 291163036, - 403842501, - 1421902433, - 1700188477, - 1046093370, - 921059131, - 1638991894, - 464012042, - 96905857, - 1370999592, - 271896041, - 13595534, - 1489760970, - 1650552701, - 133367846, - 25680377, - 377631580, - 652729291, - 645763356, - 426747355, - 482475486, - 1877299223, - 103226636, - 1333832358, - 1399609097, - 458536972, - 976248802, - 1109365280, - 515164588, - 1579426417, - 1601829549, - 607169702, - 852817956, - 1980537127, - 134138338, - 913344050, - 737880920, - 476360275, - 61624034, - 1610624252, - 264461991, - 546933535, - 937769429, - 293346965, - 1522058041, - 1012551797, - 994330314, - 23333322, - 1969510890, - 974351570, - 2012030621, - 120742000, - 450250620, - 180547360, - 642746933, - 1815029950, - 629489142, - 1176992624, - 723354779, - 572648755, - 1218615348, - 648847054, - 351903235, - 723149764, - 248065753, - 243829448, - 1283393001, - 1912627886, - 581641342, - 702465306, - 205969758, - 1061911274, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1703043252, - 1467887451, - 1714319214, - 907866644, - 1542426838, - 742609036, - 1814393459, - 448706641, - 1960340767, - 46490834, - 186512520, - 363973095, - 846448854, - 463742343, - 2012517527, - 40473617, - 9472552, - 263483342, - 105738598, - 586389136, - 254290990, - 625150844, - 960233097, - 1488303724, - 1700231692, - 1471714612, - 1540211186, - 1590246915, - 945341972, - 1343225515, - 179976237, - 34857822, - 276912528, - 984309272, - 1277293398, - 1520924162, - 1823117694, - 604836357, - 1460812009, - 600052559, - 970469338, - 1771022707, - 181855831, - 1445947220, - 467514809, - 1514677498, - 947030389, - 170390653, - 415409007, - 1601463730, - 204153427, - 904614278, - 1855419512, - 2009471607, - 1352607379, - 576586082, - 1343812879, - 1176377580, - 1166188815, - 1592289048, - 761793881, - 1529621462, - 193034837, - 344011596, - 1669461833, - 1356800025, - 314186361, - 586497329, - 1832810846, - 1288092861, - 1619454491, - 732529408, - 737934269, - 909504928, - 769680420, - 1437893101, - 1727002258, - 1618231110, - 535125583, - 153412473, - 1917760929, - 588586507, - 564531165, - 1790797737, - 1666283994, - 1366948884, - 117673690, - 476470378, - 2012274032, - 1951406668, - 1739767532, - 1273142151, - 1591812317, - 1900205312, - 1912608761, - 1734766024, - 1265002082, - 1450462894, - 749810837, - 1329222552, - 745081805, - 1231519431, - 1420957967, - 883846107, - 1995463911, - 407795592, - 161655852, - 125886157, - 995318920, - 484905024, - 284135318, - 551493419, - 406742309, - 1089024446, - 637339867, - 1858138403, - 1230680117, - 187078889, - 1929517480, - 1125646261, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1610035932, - 462442436, - 831412555, - 44798862, - 1748147276, - 1911945531, - 1329343740, - 971894393, - 362147969, - 1583335926, - 1528700112, - 426908674, - 847905883, - 447889090, - 1050883911, - 1883537469, - 1487501632, - 964178870, - 1818828551, - 1980840799, - 340372118, - 1697179193, - 215113037, - 1893217470, - 1138628493, - 1788052486, - 443362955, - 1349213730, - 589553425, - 562526667, - 1006040406, - 1194546769, - 1831034644, - 612004157, - 730213913, - 1068905440, - 371983982, - 502900790, - 802785198, - 822377635, - 1477528437, - 501356237, - 684668525, - 1306043781, - 621032592, - 1971342708, - 1411586583, - 733418745, - 186045462, - 1559301855, - 323758310, - 453170140, - 498381240, - 976247416, - 631213663, - 898017829, - 501459603, - 609703046, - 1379288251, - 177682695, - 912381595, - 121915494, - 1137416430, - 504054388, - 1138277238, - 1603388253, - 1838013301, - 1700271853, - 20488607, - 58775264, - 217974275, - 979141729, - 53136584, - 1331566240, - 1460303356, - 525812787, - 718385521, - 1477919263, - 1663622276, - 1089788203, - 1204483837, - 54225863, - 290660186, - 1441441958, - 134168813, - 349638823, - 1867912015, - 1579183319, - 55528656, - 1602973359, - 194297109, - 949763297, - 101931919, - 242300116, - 1610052257, - 1351823848, - 174522860, - 776955925, - 1706962365, - 808187490, - 1487253852, - 431806906, - 213982593, - 1170647308, - 1776840400, - 295916317, - 378708073, - 381270341, - 457494568, - 705823997, - 1407301442, - 1693003013, - 700310785, - 1349874247, - 1284363817, - 1566253815, - 1014298154, - 215294365, - 1070968678, - 871641358, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1302679751, - 1121894357, - 368587356, - 1564724097, - 733815591, - 2012670011, - 1146780092, - 1439780227, - 1801628424, - 838692317, - 932318853, - 213634365, - 155292454, - 1644317110, - 1599846194, - 978829059, - 1282095862, - 1780431647, - 527412087, - 1024583705, - 804423802, - 951808322, - 689345230, - 180304167, - 1784562773, - 1514653374, - 2009396440, - 1143778943, - 235299446, - 1553017484, - 475425117, - 758292254, - 716575432, - 517083432, - 1728864125, - 418010549, - 43202592, - 507659742, - 433077118, - 1268144019, - 1462778342, - 1928073362, - 1330130180, - 1749624351, - 827401013, - 1236194147, - 1875519726, - 1437946791, - 607293265, - 309229599, - 1009445595, - 1725229718, - 1436309341, - 1952606463, - 943149111, - 291680468, - 1989684076, - 1944713370, - 1285294139, - 399758737, - 1572979232, - 213817406, - 214840530, - 184898060, - 1483844295, - 1536616777, - 494816009, - 217625163, - 529448032, - 786640964, - 1766471731, - 1424140424, - 1721961711, - 740275169, - 169908711, - 913969302, - 1359358267, - 1328322971, - 593228769, - 771095186, - 801680440, - 450930656, - 1796349530, - 1824428677, - 1111258504, - 1741666629, - 1098430204, - 1792001884, - 1679003061, - 590088446, - 647614538, - 1324461639, - 818996796, - 229187928, - 74288115, - 1158900266, - 1512606270, - 1381672753, - 785927403, - 493453164, - 425259497, - 1367873539, - 931023744, - 221202218, - 669580668, - 424996238, - 1840425275, - 1873362670, - 967642716, - 263556335, - 578560519, - 1558449223, - 607579284, - 1724012378, - 333582342, - 1195784167, - 1419727276, - 199294290, - 138807165, - 1061030752, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 776332180, - 1333076185, - 1855163818, - 1897408938, - 799274251, - 950452503, - 691904988, - 1205387466, - 659107883, - 434394982, - 129587940, - 639018629, - 659238594, - 1957584892, - 864291238, - 589178070, - 1267157231, - 48925338, - 200093884, - 1953762869, - 1227617341, - 1471420621, - 193077633, - 1007876111, - 228491220, - 1377349503, - 1889411060, - 1807513892, - 1593042934, - 1240864695, - 1472870721, - 583021932, - 598239104, - 1862008818, - 1811242869, - 780768026, - 520870395, - 292016292, - 322246659, - 868240490, - 1715620331, - 1183509209, - 2010262726, - 1003957251, - 264895455, - 307755941, - 201990485, - 1662471178, - 1643997923, - 1573129362, - 277821143, - 388834470, - 943361405, - 1449402196, - 614413575, - 1504113993, - 1860552739, - 1755127315, - 1734129760, - 1232115188, - 803035456, - 360488092, - 271342171, - 1269544258, - 290642673, - 660703582, - 986842267, - 870891877, - 454573044, - 1999346236, - 701614601, - 820253867, - 883282765, - 137247873, - 1727164949, - 1320585493, - 1738664600, - 1900116905, - 472215154, - 1114994489, - 104218174, - 1694603079, - 771486383, - 935361143, - 92277671, - 881040480, - 925124484, - 1464396527, - 100625197, - 65290355, - 1001454341, - 134627585, - 58629702, - 1541542242, - 568583607, - 1706262052, - 530687550, - 1303187245, - 1010302462, - 264001857, - 789816678, - 561378226, - 827432508, - 801307507, - 1613508315, - 1650822853, - 1603502703, - 439320335, - 15283580, - 1244486577, - 254345266, - 1745653280, - 1648250354, - 1528271018, - 528366563, - 1078707735, - 1430767759, - 1890467731, - 2001894083, - 799949326, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1341839494, - 1092219735, - 755644898, - 966729319, - 1914277278, - 1545367697, - 1765189119, - 1693413008, + 1522353967u32, 457603397, 421847521, 1352563318, + 1746817766, 737872688, 1087008622, 1850835028, + 456475558, 892966330, 638163666, 148568548, + 678863061, 1334386850, 1896333039, 154585769, + 433618446, 1186936470, 970218722, 1213827097, + 1798557019, 861757965, 119285527, 395360622, + 226164366, 1330279872, 66561048, 785421608, + 1950755756, 1559889596, 348449876, 1090789452, + 257578851, 273164442, 1644906, 295600924, + 1187949602, 1168249609, 469763604, 60929061, + 291163036, 403842501, 1421902433, 1700188477, + 1046093370, 921059131, 1638991894, 464012042, + 96905857, 1370999592, 271896041, 13595534, + 1489760970, 1650552701, 133367846, 25680377, + 377631580, 652729291, 645763356, 426747355, + 482475486, 1877299223, 103226636, 1333832358, + 1399609097, 458536972, 976248802, 1109365280, + 515164588, 1579426417, 1601829549, 607169702, + 852817956, 1980537127, 134138338, 913344050, + 737880920, 476360275, 61624034, 1610624252, + 264461991, 546933535, 937769429, 293346965, + 1522058041, 1012551797, 994330314, 23333322, + 1969510890, 974351570, 2012030621, 120742000, + 450250620, 180547360, 642746933, 1815029950, + 629489142, 1176992624, 723354779, 572648755, + 1218615348, 648847054, 351903235, 723149764, + 248065753, 243829448, 1283393001, 1912627886, + 581641342, 702465306, 205969758, 1061911274, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1703043252, 1467887451, 1714319214, 907866644, + 1542426838, 742609036, 1814393459, 448706641, + 1960340767, 46490834, 186512520, 363973095, + 846448854, 463742343, 2012517527, 40473617, + 9472552, 263483342, 105738598, 586389136, + 254290990, 625150844, 960233097, 1488303724, + 1700231692, 1471714612, 1540211186, 1590246915, + 945341972, 1343225515, 179976237, 34857822, + 276912528, 984309272, 1277293398, 1520924162, + 1823117694, 604836357, 1460812009, 600052559, + 970469338, 1771022707, 181855831, 1445947220, + 467514809, 1514677498, 947030389, 170390653, + 415409007, 1601463730, 204153427, 904614278, + 1855419512, 2009471607, 1352607379, 576586082, + 1343812879, 1176377580, 1166188815, 1592289048, + 761793881, 1529621462, 193034837, 344011596, + 1669461833, 1356800025, 314186361, 586497329, + 1832810846, 1288092861, 1619454491, 732529408, + 737934269, 909504928, 769680420, 1437893101, + 1727002258, 1618231110, 535125583, 153412473, + 1917760929, 588586507, 564531165, 1790797737, + 1666283994, 1366948884, 117673690, 476470378, + 2012274032, 1951406668, 1739767532, 1273142151, + 1591812317, 1900205312, 1912608761, 1734766024, + 1265002082, 1450462894, 749810837, 1329222552, + 745081805, 1231519431, 1420957967, 883846107, + 1995463911, 407795592, 161655852, 125886157, + 995318920, 484905024, 284135318, 551493419, + 406742309, 1089024446, 637339867, 1858138403, + 1230680117, 187078889, 1929517480, 1125646261, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1610035932, 462442436, 831412555, 44798862, + 1748147276, 1911945531, 1329343740, 971894393, + 362147969, 1583335926, 1528700112, 426908674, + 847905883, 447889090, 1050883911, 1883537469, + 1487501632, 964178870, 1818828551, 1980840799, + 340372118, 1697179193, 215113037, 1893217470, + 1138628493, 1788052486, 443362955, 1349213730, + 589553425, 562526667, 1006040406, 1194546769, + 1831034644, 612004157, 730213913, 1068905440, + 371983982, 502900790, 802785198, 822377635, + 1477528437, 501356237, 684668525, 1306043781, + 621032592, 1971342708, 1411586583, 733418745, + 186045462, 1559301855, 323758310, 453170140, + 498381240, 976247416, 631213663, 898017829, + 501459603, 609703046, 1379288251, 177682695, + 912381595, 121915494, 1137416430, 504054388, + 1138277238, 1603388253, 1838013301, 1700271853, + 20488607, 58775264, 217974275, 979141729, + 53136584, 1331566240, 1460303356, 525812787, + 718385521, 1477919263, 1663622276, 1089788203, + 1204483837, 54225863, 290660186, 1441441958, + 134168813, 349638823, 1867912015, 1579183319, + 55528656, 1602973359, 194297109, 949763297, + 101931919, 242300116, 1610052257, 1351823848, + 174522860, 776955925, 1706962365, 808187490, + 1487253852, 431806906, 213982593, 1170647308, + 1776840400, 295916317, 378708073, 381270341, + 457494568, 705823997, 1407301442, 1693003013, + 700310785, 1349874247, 1284363817, 1566253815, + 1014298154, 215294365, 1070968678, 871641358, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1302679751, 1121894357, 368587356, 1564724097, + 733815591, 2012670011, 1146780092, 1439780227, + 1801628424, 838692317, 932318853, 213634365, + 155292454, 1644317110, 1599846194, 978829059, + 1282095862, 1780431647, 527412087, 1024583705, + 804423802, 951808322, 689345230, 180304167, + 1784562773, 1514653374, 2009396440, 1143778943, + 235299446, 1553017484, 475425117, 758292254, + 716575432, 517083432, 1728864125, 418010549, + 43202592, 507659742, 433077118, 1268144019, + 1462778342, 1928073362, 1330130180, 1749624351, + 827401013, 1236194147, 1875519726, 1437946791, + 607293265, 309229599, 1009445595, 1725229718, + 1436309341, 1952606463, 943149111, 291680468, + 1989684076, 1944713370, 1285294139, 399758737, + 1572979232, 213817406, 214840530, 184898060, + 1483844295, 1536616777, 494816009, 217625163, + 529448032, 786640964, 1766471731, 1424140424, + 1721961711, 740275169, 169908711, 913969302, + 1359358267, 1328322971, 593228769, 771095186, + 801680440, 450930656, 1796349530, 1824428677, + 1111258504, 1741666629, 1098430204, 1792001884, + 1679003061, 590088446, 647614538, 1324461639, + 818996796, 229187928, 74288115, 1158900266, + 1512606270, 1381672753, 785927403, 493453164, + 425259497, 1367873539, 931023744, 221202218, + 669580668, 424996238, 1840425275, 1873362670, + 967642716, 263556335, 578560519, 1558449223, + 607579284, 1724012378, 333582342, 1195784167, + 1419727276, 199294290, 138807165, 1061030752, + 1, 0, 0, 0, + 1, 0, 0, 0, + 776332180, 1333076185, 1855163818, 1897408938, + 799274251, 950452503, 691904988, 1205387466, + 659107883, 434394982, 129587940, 639018629, + 659238594, 1957584892, 864291238, 589178070, + 1267157231, 48925338, 200093884, 1953762869, + 1227617341, 1471420621, 193077633, 1007876111, + 228491220, 1377349503, 1889411060, 1807513892, + 1593042934, 1240864695, 1472870721, 583021932, + 598239104, 1862008818, 1811242869, 780768026, + 520870395, 292016292, 322246659, 868240490, + 1715620331, 1183509209, 2010262726, 1003957251, + 264895455, 307755941, 201990485, 1662471178, + 1643997923, 1573129362, 277821143, 388834470, + 943361405, 1449402196, 614413575, 1504113993, + 1860552739, 1755127315, 1734129760, 1232115188, + 803035456, 360488092, 271342171, 1269544258, + 290642673, 660703582, 986842267, 870891877, + 454573044, 1999346236, 701614601, 820253867, + 883282765, 137247873, 1727164949, 1320585493, + 1738664600, 1900116905, 472215154, 1114994489, + 104218174, 1694603079, 771486383, 935361143, + 92277671, 881040480, 925124484, 1464396527, + 100625197, 65290355, 1001454341, 134627585, + 58629702, 1541542242, 568583607, 1706262052, + 530687550, 1303187245, 1010302462, 264001857, + 789816678, 561378226, 827432508, 801307507, + 1613508315, 1650822853, 1603502703, 439320335, + 15283580, 1244486577, 254345266, 1745653280, + 1648250354, 1528271018, 528366563, 1078707735, + 1430767759, 1890467731, 2001894083, 799949326, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1341839494, 1092219735, 755644898, 966729319, + 1914277278, 1545367697, 1765189119, 1693413008, ]; let logup_spec_evals: Array> = @@ -1160,75 +384,25 @@ fn build_test_program(builder: &mut Builder) { builder.set(&logup_spec_evals, idx, e); } + #[rustfmt::skip] let r_evals_u32s = [ - 941378355u32, - 1078920879, - 696738840, - 496039492, - 1555445457, - 184545404, - 905938226, - 1847966044, - 1024875886, - 1782716223, - 1625644635, - 266865456, - 465953066, - 1663531470, - 757423849, - 1957075986, - 1919693393, - 839104130, - 127480221, - 1527842912, - 918650796, - 921462354, - 575456073, - 696646705, - 1585912361, - 258186488, - 353168830, - 1111094691, - 1401166558, - 1905942163, - 1923083163, - 393037255, - 1042127700, - 1126793296, - 895794165, - 1124924482, - 1324266058, - 722406365, - 1963838171, - 968504459, - 1934378800, - 714588691, - 6465911, - 1168379648, - 903786009, - 1326035939, - 518289228, - 418998914, - 1513133474, - 1578096058, - 617547414, - 1658315126, - 68556894, - 1697802593, - 1346510664, - 1709381671, - 345062962, - 1254089535, - 1002281845, - 1882822096, - 700581748, - 1431345304, - 489112954, - 98435728, - 1799886007, - 479788390, - 223111065, - 631662309, + 941378355u32, 1078920879, 696738840, 496039492, + 1555445457, 184545404, 905938226, 1847966044, + 1024875886, 1782716223, 1625644635, 266865456, + 465953066, 1663531470, 757423849, 1957075986, + 1919693393, 839104130, 127480221, 1527842912, + 918650796, 921462354, 575456073, 696646705, + 1585912361, 258186488, 353168830, 1111094691, + 1401166558, 1905942163, 1923083163, 393037255, + 1042127700, 1126793296, 895794165, 1124924482, + 1324266058, 722406365, 1963838171, 968504459, + 1934378800, 714588691, 6465911, 1168379648, + 903786009, 1326035939, 518289228, 418998914, + 1513133474, 1578096058, 617547414, 1658315126, + 68556894, 1697802593, 1346510664, 1709381671, + 345062962, 1254089535, 1002281845, 1882822096, + 700581748, 1431345304, 489112954, 98435728, + 1799886007, 479788390, 223111065, 631662309, ]; let next_layer_evals: Array> = From 939cd1e44776f38b6287e7c36a33460efaa09756 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 28 Nov 2025 21:34:39 +0800 Subject: [PATCH 11/18] wip9 --- .../native/circuit/src/sumcheck/chip.rs | 105 +++++++++++------- .../native/circuit/src/sumcheck/execution.rs | 54 +++++---- 2 files changed, 92 insertions(+), 67 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 680dfcf31b..cbe38fded4 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -27,8 +27,8 @@ use crate::{ }; pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; -const CURRENT_LAYER_MODE: u32 = 1; -const NEXT_LAYER_MODE: u32 = 0; +pub(crate) const CURRENT_LAYER_MODE: u32 = 1; +pub(crate) const NEXT_LAYER_MODE: u32 = 0; pub(crate) fn calculate_3d_ext_idx( inner_inner_len: u32, @@ -144,13 +144,15 @@ where .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) .0; - let mut cur_timestamp = 0; + let mut cur_timestamp = state.memory.timestamp(); // head row let head_row: &mut NativeSumcheckCols = &mut rows[0]; let head_specific: &mut HeaderSpecificCols = head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); head_row.header_row = F::ONE; + head_row.first_timestamp = F::from_canonical_u32(cur_timestamp); + head_row.start_timestamp = F::from_canonical_u32(cur_timestamp); head_specific.pc = F::from_canonical_u32(*state.pc); @@ -210,6 +212,7 @@ where // all rows share same register values, ctx, challenges for row in rows.iter_mut() { + // c1, c2 are same during the entire execution row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); row.alpha = alpha; row.ctx = ctx; @@ -231,8 +234,8 @@ where prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); prod_row.prod_row = F::ONE; - prod_row.curr_prod_n = F::from_canonical_usize(i); - prod_row.start_timestamp = F::from_canonical_usize(cur_timestamp); + prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 + prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); // read max_round let [max_round]: [F; 1] = tracing_read_native_helper( @@ -245,8 +248,9 @@ where prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); prod_row.max_round = max_round; + let max_round = max_round.as_canonical_u32(); // round starts from 0 - if round < max_round.as_canonical_u32() - 1 { + if round < max_round - 1 { prod_row.within_round_limit = F::ONE; let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -296,12 +300,12 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 1; + cur_timestamp += 2; let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; - if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { eval_acc = FieldExtension::add(eval_acc, eval_rlc); prod_row.should_acc = F::ONE; prod_row.eval_acc = eval_acc; @@ -309,7 +313,6 @@ where } alpha_acc = FieldExtension::multiply(alpha_acc, alpha); - prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); } // logup rows @@ -318,8 +321,8 @@ where logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); logup_row.logup_row = F::ONE; - logup_row.curr_logup_n = F::from_canonical_usize(i); - logup_row.start_timestamp = F::from_canonical_usize(cur_timestamp); + logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 + logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, @@ -334,7 +337,8 @@ where logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); logup_row.challenges[2 * EXT_DEG..(3 * EXT_DEG)].copy_from_slice(&alpha_denominator); - if round < max_round.as_canonical_u32() - 1 { + let max_round = max_round.as_canonical_u32(); + if round < max_round - 1 { logup_row.within_round_limit = F::ONE; let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -410,13 +414,13 @@ where ); cur_timestamp += 3; // 1 read, 2 writes - let eval = FieldExtension::add( + let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), FieldExtension::multiply(alpha_denominator, q_eval), ); - logup_specific.eval_rlc = eval; - if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 { - eval_acc = FieldExtension::add(eval_acc, eval); + logup_specific.eval_rlc = eval_rlc; + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); logup_row.should_acc = F::ONE; logup_row.logup_acc = F::ONE; logup_row.eval_acc = eval_acc; @@ -427,6 +431,8 @@ where } let head_row = &mut rows[0]; + head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); + let head_specific: &mut HeaderSpecificCols = head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); @@ -452,7 +458,9 @@ impl TraceFiller for NativeSumcheckFiller { fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); let start_timestamp = cols.start_timestamp.as_canonical_u32(); + let last_timestamp = cols.last_timestamp.as_canonical_u32(); + println!("start_timestamp: {}, cols.header_row: {:?}, prod_row: {:?}, logup_row: {:?}", start_timestamp, cols.header_row, cols.prod_row, cols.logup_row); if cols.header_row == F::ONE { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); @@ -460,58 +468,69 @@ impl TraceFiller for NativeSumcheckFiller { for i in 0..7usize { mem_fill_helper( mem_helper, - start_timestamp + i, + start_timestamp + i as u32, header.read_records[i].as_mut(), ); } mem_fill_helper( mem_helper, - start_timestamp + 7, + last_timestamp - 1, header.write_records.as_mut(), ); } else if cols.prod_row == F::ONE { let prod_row_specific: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + // read max_round mem_fill_helper( mem_helper, start_timestamp, prod_row_specific.read_records[0].as_mut(), ); - mem_fill_helper( - mem_helper, - start_timestamp + 1, - prod_row_specific.read_records[1].as_mut(), - ); - mem_fill_helper( - mem_helper, - start_timestamp + 2, - prod_row_specific.write_record.as_mut(), - ); + if cols.within_round_limit == F::ONE { + // read p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + prod_row_specific.write_record.as_mut(), + ); + } } else if cols.logup_row == F::ONE { let logup_row_specific: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + // read max_round mem_fill_helper( mem_helper, start_timestamp, logup_row_specific.read_records[0].as_mut(), ); - mem_fill_helper( - mem_helper, - start_timestamp + 1, - logup_row_specific.read_records[1].as_mut(), - ); - mem_fill_helper( - mem_helper, - start_timestamp + 2, - logup_row_specific.write_records[0].as_mut(), - ); - mem_fill_helper( - mem_helper, - start_timestamp + 3, - logup_row_specific.write_records[1].as_mut(), - ); + if cols.within_round_limit == F::ONE { + // read p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 3, + logup_row_specific.write_records[1].as_mut(), + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a4b4f484f3..7202e57b00 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -11,7 +11,10 @@ use openvm_stark_backend::p3_field::PrimeField32; use crate::{ field_extension::{FieldExtension, EXT_DEG}, fri::elem_to_ext, - sumcheck::chip::{calculate_3d_ext_idx, NativeSumcheckExecutor}, + sumcheck::chip::{ + calculate_3d_ext_idx, NativeSumcheckExecutor, CONTEXT_ARR_BASE_LEN, CURRENT_LAYER_MODE, + NEXT_LAYER_MODE, + }, }; #[derive(AlignedBytesBorrow, Clone)] @@ -209,7 +212,7 @@ unsafe fn execute_e12_impl( .map(|x: F| x.as_canonical_u32()); let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = ctx; - let challenges: [F; EXT_DEG * 3] = + let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -219,9 +222,10 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); + let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; for i in 0..num_prod_spec { let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, ctx_ptr_u32 + 8) + .vm_read(NATIVE_AS, prod_offset + i) .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( @@ -238,17 +242,17 @@ unsafe fn execute_e12_impl( let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let eval = match mode { - 1 => FieldExtension::multiply(p1, p2), - 0 => FieldExtension::add( + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + NEXT_LAYER_MODE => FieldExtension::add( FieldExtension::multiply(p1, c1), FieldExtension::multiply(p2, c2), ), - _ => unreachable!("mode can only be 0 or 1"), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), }; - exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + 1 + i, &eval); + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval); - if round + mode < max_round - 1 { + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { // update eval_acc eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); } @@ -259,10 +263,11 @@ unsafe fn execute_e12_impl( height += 1; } + let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; for i in 0..num_logup_spec { // read max_round let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, ctx_ptr_u32 + 8 + num_prod_spec + i) + .vm_read(NATIVE_AS, logup_offset + i) .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( logup_specs_inner_len, @@ -272,6 +277,9 @@ unsafe fn execute_e12_impl( 0, ); + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + let alpha_numerator = alpha_acc; + if round < max_round - 1 { // read logup_evals let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); @@ -282,23 +290,23 @@ unsafe fn execute_e12_impl( // compute p_eval and q_eval let p_eval = match mode { - 1 => FieldExtension::add( + CURRENT_LAYER_MODE => FieldExtension::add( FieldExtension::multiply(p1, q2), FieldExtension::multiply(p2, q1), ), - 0 => FieldExtension::add( + NEXT_LAYER_MODE => FieldExtension::add( FieldExtension::multiply(p1, c1), FieldExtension::multiply(p2, c2), ), - _ => unreachable!("mode can only be 0 or 1"), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), }; let q_eval = match mode { - 1 => FieldExtension::multiply(q1, q2), - 0 => FieldExtension::add( + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + NEXT_LAYER_MODE => FieldExtension::add( FieldExtension::multiply(q1, c1), FieldExtension::multiply(q2, c2), ), - _ => unreachable!("mode can only be 0 or 1"), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), }; // write eval to r_evals @@ -313,20 +321,18 @@ unsafe fn execute_e12_impl( &q_eval, ); - let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); - let alpha_numerator = alpha_acc; - - if round + mode < max_round - 1 { + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { // update eval_acc - eval_acc = FieldExtension::add( - FieldExtension::multiply(alpha_numerator, p_eval), - FieldExtension::multiply(alpha_denominator, q_eval), - ); + eval_acc = FieldExtension::add(eval_acc, eval_rlc); } } // update alpha_acc - alpha_acc = FieldExtension::multiply(alpha_acc, FieldExtension::multiply(alpha, alpha)); + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); height += 1; } From 6cbfc49349e2935cd574e1f14b14bd347130594a Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 29 Nov 2025 17:19:40 +0800 Subject: [PATCH 12/18] wip10 --- extensions/native/circuit/src/sumcheck/chip.rs | 7 ++++--- extensions/native/recursion/tests/sumcheck.rs | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index cbe38fded4..a5a89042e2 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -201,6 +201,7 @@ where head_specific.read_records[6].as_mut(), ); cur_timestamp += 7; // 5 register reads + ctx read + challenges read + head_row.challenges.copy_from_slice(&challenges); // challenges = [alpha, c1=r, c2=1-r] let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -335,7 +336,7 @@ where let alpha_numerator = alpha_acc; let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); - logup_row.challenges[2 * EXT_DEG..(3 * EXT_DEG)].copy_from_slice(&alpha_denominator); + logup_row.challenges[3 * EXT_DEG..(4 * EXT_DEG)].copy_from_slice(&alpha_denominator); let max_round = max_round.as_canonical_u32(); if round < max_round - 1 { @@ -409,7 +410,7 @@ where state.memory, r_evals_ptr.as_canonical_u32() + (1 + num_prod_spec + num_logup_spec + i as u32) * (EXT_DEG as u32), - p_eval, + q_eval, &mut logup_specific.write_records[1], ); cur_timestamp += 3; // 1 read, 2 writes @@ -432,6 +433,7 @@ where let head_row = &mut rows[0]; head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); + head_row.eval_acc = eval_acc; let head_specific: &mut HeaderSpecificCols = head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); @@ -460,7 +462,6 @@ impl TraceFiller for NativeSumcheckFiller { let start_timestamp = cols.start_timestamp.as_canonical_u32(); let last_timestamp = cols.last_timestamp.as_canonical_u32(); - println!("start_timestamp: {}, cols.header_row: {:?}, prod_row: {:?}, logup_row: {:?}", start_timestamp, cols.header_row, cols.prod_row, cols.logup_row); if cols.header_row == F::ONE { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index d6981c4fdb..a4039028bc 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -81,6 +81,7 @@ fn build_test_program(builder: &mut Builder) { 548478283u32, 456436544, 1716290291, 791326976, 1829717553, 1422025771, 1917123958, 727015942, 183548369, 591240150, 96141963, 1286249979, + 0, 0, 0, 0, ]; let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { From d65cfde1593d52f6b1ff9ee973855c40a672e8b9 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 29 Nov 2025 21:49:07 +0800 Subject: [PATCH 13/18] wip11 --- extensions/native/circuit/src/sumcheck/chip.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index a5a89042e2..dc528322a0 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -431,6 +431,10 @@ where alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); } + if let Some(last_row) = rows.last_mut() { + last_row.is_end = F::ONE; + } + let head_row = &mut rows[0]; head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); head_row.eval_acc = eval_acc; @@ -534,4 +538,10 @@ impl TraceFiller for NativeSumcheckFiller { } } } + + fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + + cols.is_end = F::ONE; + } } From b6f9c546647f7be9f458b91077a20a2f40e942d2 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 29 Nov 2025 22:14:46 +0800 Subject: [PATCH 14/18] wip12 --- extensions/native/circuit/src/sumcheck/air.rs | 93 ++++++++++--------- .../native/circuit/src/sumcheck/chip.rs | 5 + 2 files changed, 54 insertions(+), 44 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 90110347f6..4161026343 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -116,14 +116,12 @@ impl Air for NativeSumcheckAir { let next_enabled = next.header_row + next.prod_row + next.logup_row; builder.assert_bool(enabled.clone()); - builder.assert_eq::( - or::( - or::( - and(prod_row, next.header_row), - and(logup_row, next.header_row), - ), - not::(next_enabled), - ), + // TODO: handle last row properly + + builder.when_transition().assert_eq::( + prod_row * next.header_row + + logup_row * next.header_row + + not::(next_enabled), is_end.into(), ); @@ -238,25 +236,27 @@ impl Air for NativeSumcheckAir { ); // Termination condition - assert_array_eq( - &mut builder.when::(is_end.into()), - eval_acc, - [AB::F::ZERO; 4], - ); + // TODO: enable this + // assert_array_eq( + // &mut builder.when::(is_end.into()), + // eval_acc, + // [AB::F::ZERO; 4], + // ); // Randomness transition assert_array_eq( - &mut builder.when(and(header_row, or(next.prod_row, next.logup_row))), + &mut builder.when(and(header_row, next.prod_row + next.logup_row)), next.challenges[0..EXT_DEG].try_into().unwrap(), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], ); assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(and(prod_row, next.prod_row)), - prod_next_alpha, - next_alpha1, - ); + // TODO: reduce the degree + // assert_array_eq::<_, _, _, { EXT_DEG }>( + // &mut builder.when(and(prod_row, next.prod_row)), + // prod_next_alpha, + // next_alpha1, + // ); // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, { EXT_DEG }>( @@ -265,11 +265,12 @@ impl Air for NativeSumcheckAir { alpha2, ); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); - assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(and(logup_row, next.logup_row)), - logup_next_alpha, - next_alpha1, - ); + // TODO: reduce the degree + // assert_array_eq::<_, _, _, { EXT_DEG }>( + // &mut builder.when(and(logup_row, next.logup_row)), + // logup_next_alpha, + // next_alpha1, + // ); /////////////////////////////////////// // Header @@ -359,11 +360,12 @@ impl Air for NativeSumcheckAir { ) .eval(builder, prod_row); - builder.when(prod_row * within_round_limit).assert_eq( - prod_row_specific.data_ptr, - (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) - * AB::F::from_canonical_usize(EXT_DEG), - ); + // TODO: reduce the degree + // builder.when(prod_row * within_round_limit).assert_eq( + // prod_row_specific.data_ptr, + // (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) + // * AB::F::from_canonical_usize(EXT_DEG), + // ); builder.assert_eq( prod_row * within_round_limit * in_round, prod_in_round_evaluation, @@ -428,11 +430,12 @@ impl Air for NativeSumcheckAir { prod_row_specific.eval_rlc, eval_rlc, ); - assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(next.prod_acc), - FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), - eval_acc, - ); + // TODO: enable this + // assert_array_eq::<_, _, _, { EXT_DEG }>( + // &mut builder.when(next.prod_acc), + // FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + // eval_acc, + // ); /////////////////////////////////////// // Logup spec evaluation @@ -457,11 +460,12 @@ impl Air for NativeSumcheckAir { ) .eval(builder, logup_row); - builder.when(logup_row * within_round_limit).assert_eq( - logup_row_specific.data_ptr, - (logup_nested_len * (curr_logup_n - AB::F::ONE) + logup_spec_inner_inner_len * round) - * AB::F::from_canonical_usize(EXT_DEG), - ); + // TODO: reduce the degree + // builder.when(logup_row * within_round_limit).assert_eq( + // logup_row_specific.data_ptr, + // (logup_nested_len * (curr_logup_n - AB::F::ONE) + logup_spec_inner_inner_len * round) + // * AB::F::from_canonical_usize(EXT_DEG), + // ); builder.assert_eq( logup_row * within_round_limit * in_round, logup_in_round_evaluation, @@ -570,10 +574,11 @@ impl Air for NativeSumcheckAir { // Accumulate into global accumulator `eval_acc` // when round < max_round - 2 - assert_array_eq::<_, _, _, { EXT_DEG }>( - &mut builder.when(next.logup_acc), - FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), - eval_acc, - ); + // TODO: enable this + // assert_array_eq::<_, _, _, { EXT_DEG }>( + // &mut builder.when(next.logup_acc), + // FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + // eval_acc, + // ); } } diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index dc528322a0..bc2524af75 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -217,6 +217,10 @@ where row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); row.alpha = alpha; row.ctx = ctx; + row.prod_nested_len = + F::from_canonical_u32(prod_specs_inner_len * prod_specs_inner_inner_len); + row.logup_nested_len = + F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); row.register_ptrs[0] = ctx_ptr; row.register_ptrs[1] = challenges_ptr; row.register_ptrs[2] = prod_evals_ptr; @@ -309,6 +313,7 @@ where if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { eval_acc = FieldExtension::add(eval_acc, eval_rlc); prod_row.should_acc = F::ONE; + prod_row.prod_acc = F::ONE; prod_row.eval_acc = eval_acc; } } From db585cc38504c8ee0f54f4a9767cde24190c0542 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 29 Nov 2025 22:50:40 +0800 Subject: [PATCH 15/18] wip13 --- extensions/native/circuit/src/sumcheck/air.rs | 30 +++++++++++-------- .../native/circuit/src/sumcheck/chip.rs | 10 +++++++ .../native/circuit/src/sumcheck/columns.rs | 3 ++ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 4161026343..b653b92330 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -67,6 +67,8 @@ impl Air for NativeSumcheckAir { logup_row, is_end, + prod_continued, + logup_continued, // What type of evaluation is performed // mainly for reducing constraint degree prod_in_round_evaluation, @@ -116,6 +118,12 @@ impl Air for NativeSumcheckAir { let next_enabled = next.header_row + next.prod_row + next.logup_row; builder.assert_bool(enabled.clone()); + builder + .when_transition() + .assert_eq(prod_row * next.prod_row, prod_continued); + builder + .when_transition() + .assert_eq(logup_row * next.logup_row, logup_continued); // TODO: handle last row properly builder.when_transition().assert_eq::( @@ -251,12 +259,11 @@ impl Air for NativeSumcheckAir { ); assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - // TODO: reduce the degree - // assert_array_eq::<_, _, _, { EXT_DEG }>( - // &mut builder.when(and(prod_row, next.prod_row)), - // prod_next_alpha, - // next_alpha1, - // ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_continued), + prod_next_alpha, + next_alpha1, + ); // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, { EXT_DEG }>( @@ -265,12 +272,11 @@ impl Air for NativeSumcheckAir { alpha2, ); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); - // TODO: reduce the degree - // assert_array_eq::<_, _, _, { EXT_DEG }>( - // &mut builder.when(and(logup_row, next.logup_row)), - // logup_next_alpha, - // next_alpha1, - // ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_continued), + logup_next_alpha, + next_alpha1, + ); /////////////////////////////////////// // Header diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index bc2524af75..8f31f25263 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -239,6 +239,11 @@ where prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); prod_row.prod_row = F::ONE; + prod_row.prod_continued = if i < (num_prod_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); @@ -327,6 +332,11 @@ where logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); logup_row.logup_row = F::ONE; + logup_row.logup_continued = if i < (num_logup_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 0cee93c381..b3e6bf4f25 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -19,6 +19,9 @@ pub struct NativeSumcheckCols { /// Indicates that this row is the end of the entire layer sum operation pub is_end: T, + pub prod_continued: T, + pub logup_continued: T, + /// Indicates what type of evaluation constraints should be applied pub prod_in_round_evaluation: T, pub prod_next_round_evaluation: T, From 67896e78452569123c88a5097f70fbe833304920 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 30 Nov 2025 15:38:54 +0800 Subject: [PATCH 16/18] wip14 --- extensions/native/circuit/src/sumcheck/air.rs | 31 ++++++++++++------- .../native/circuit/src/sumcheck/chip.rs | 4 +-- extensions/native/compiler/src/ir/sumcheck.rs | 6 +--- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index b653b92330..64e6594bfa 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -366,12 +366,15 @@ impl Air for NativeSumcheckAir { ) .eval(builder, prod_row); - // TODO: reduce the degree - // builder.when(prod_row * within_round_limit).assert_eq( - // prod_row_specific.data_ptr, - // (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) - // * AB::F::from_canonical_usize(EXT_DEG), - // ); + // prod_row * within_round_limit = + // prod_in_round_evaluation + prod_next_round_evaluation + builder + .when(prod_in_round_evaluation + prod_next_round_evaluation) + .assert_eq( + prod_row_specific.data_ptr, + (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); builder.assert_eq( prod_row * within_round_limit * in_round, prod_in_round_evaluation, @@ -466,12 +469,16 @@ impl Air for NativeSumcheckAir { ) .eval(builder, logup_row); - // TODO: reduce the degree - // builder.when(logup_row * within_round_limit).assert_eq( - // logup_row_specific.data_ptr, - // (logup_nested_len * (curr_logup_n - AB::F::ONE) + logup_spec_inner_inner_len * round) - // * AB::F::from_canonical_usize(EXT_DEG), - // ); + // logup_row * within_round_limit = + // logup_in_round_evaluation + logup_next_round_evaluation + builder + .when(logup_in_round_evaluation + logup_next_round_evaluation) + .assert_eq( + logup_row_specific.data_ptr, + (logup_nested_len * (curr_logup_n - AB::F::ONE) + + logup_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); builder.assert_eq( logup_row * within_round_limit * in_round, logup_in_round_evaluation, diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 8f31f25263..8b130201e7 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -357,8 +357,8 @@ where if round < max_round - 1 { logup_row.within_round_limit = F::ONE; let start = calculate_3d_ext_idx( - prod_specs_inner_inner_len, - prod_specs_inner_len, + logup_specs_inner_inner_len, + logup_specs_inner_len, i as u32, round, 0, diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 1649c3f11d..0237fd6740 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,8 +1,4 @@ -use openvm_native_compiler_derive::iter_zip; -use openvm_stark_backend::p3_field::FieldAlgebra; - -use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; -use crate::ir::Variable; +use super::{Array, Builder, Config, DslIr, Ext, Usize}; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer From c7e9a8b108f0fc59c57ddc348fdd7adc546dafae Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 30 Nov 2025 20:42:54 +0800 Subject: [PATCH 17/18] wip15 --- extensions/native/circuit/src/sumcheck/air.rs | 33 +++++++++---------- .../native/circuit/src/sumcheck/chip.rs | 19 ++++++++++- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 64e6594bfa..2d87a316fb 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -244,12 +244,11 @@ impl Air for NativeSumcheckAir { ); // Termination condition - // TODO: enable this - // assert_array_eq( - // &mut builder.when::(is_end.into()), - // eval_acc, - // [AB::F::ZERO; 4], - // ); + assert_array_eq( + &mut builder.when::(is_end.into()), + eval_acc, + [AB::F::ZERO; 4], + ); // Randomness transition assert_array_eq( @@ -439,12 +438,11 @@ impl Air for NativeSumcheckAir { prod_row_specific.eval_rlc, eval_rlc, ); - // TODO: enable this - // assert_array_eq::<_, _, _, { EXT_DEG }>( - // &mut builder.when(next.prod_acc), - // FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), - // eval_acc, - // ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.prod_acc), + FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + eval_acc, + ); /////////////////////////////////////// // Logup spec evaluation @@ -587,11 +585,10 @@ impl Air for NativeSumcheckAir { // Accumulate into global accumulator `eval_acc` // when round < max_round - 2 - // TODO: enable this - // assert_array_eq::<_, _, _, { EXT_DEG }>( - // &mut builder.when(next.logup_acc), - // FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), - // eval_acc, - // ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.logup_acc), + FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + eval_acc, + ); } } diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 8b130201e7..17e99c442a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -452,7 +452,6 @@ where let head_row = &mut rows[0]; head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); - head_row.eval_acc = eval_acc; let head_specific: &mut HeaderSpecificCols = head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); @@ -464,6 +463,24 @@ where &mut head_specific.write_records, ); + for row in rows.iter_mut() { + if row.header_row == F::ONE { + row.eval_acc = eval_acc; + } else if row.prod_row == F::ONE { + let specific: &mut ProdSpecificCols = + row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } else if row.logup_row == F::ONE { + let specific: &mut LogupSpecificCols = + row.specific[..LogupSpecificCols::::width()].borrow_mut(); + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } + } + assert_eq!(eval_acc, elem_to_ext(F::from_canonical_u32(0)),); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); Ok(()) } From dd912a2e90d3a687a21d1cac10d760b59d85fe78 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 30 Nov 2025 20:51:15 +0800 Subject: [PATCH 18/18] clippy --- extensions/native/circuit/src/poseidon2/chip.rs | 6 ++---- extensions/native/circuit/src/sumcheck/air.rs | 9 +++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index d3fe6e3e38..770efc7307 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -3,10 +3,8 @@ use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::*, system::{ - memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, - native_adapter::util::{ - memory_read_native, tracing_read_native, tracing_write_native_inplace, - }, + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, }, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 2d87a316fb..1cae3847ca 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -1,17 +1,14 @@ use std::borrow::Borrow; use openvm_circuit::{ - arch::{ContinuationVmProof, ExecutionBridge, ExecutionState}, + arch::{ExecutionBridge, ExecutionState}, system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; -use openvm_circuit_primitives::{ - utils::{and, assert_array_eq, not, or}, - var_range::VariableRangeCheckerBus, -}; +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; use openvm_instructions::{LocalOpcode, NATIVE_AS}; use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra}, p3_matrix::Matrix,