diff --git a/.gitignore b/.gitignore index c6e6aa2049..c87b73f198 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index 5e22562675..fd4491ddec 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -27,12 +27,12 @@ where E: StarkFriEngine, NativeBuilder: VmBuilder, { - leaf_prover: VmInstance, - leaf_controller: LeafProvingController, + pub leaf_prover: VmInstance, + pub leaf_controller: LeafProvingController, pub internal_prover: VmInstance, #[cfg(feature = "evm-prove")] - root_prover: RootVerifierLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, } diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index a86cdb1bd2..9f3e2035ad 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -17,7 +17,8 @@ use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscrimi use openvm_native_compiler::{ CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, + BLOCK_LOAD_STORE_SIZE, }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::BranchEqualCoreAir; @@ -61,6 +62,10 @@ use crate::{ chip::{NativePoseidon2Executor, NativePoseidon2Filler}, NativePoseidon2Chip, }, + sumcheck::{ + air::NativeSumcheckAir, + chip::{NativeSumcheckChip, NativeSumcheckExecutor, NativeSumcheckFiller}, + }, }; cfg_if::cfg_if! { @@ -94,6 +99,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionExecutor), FriReducedOpening(FriReducedOpeningExecutor), VerifyBatch(NativePoseidon2Executor), + TowerVerify(NativeSumcheckExecutor), } impl VmExecutionExtension for Native { @@ -169,6 +175,12 @@ impl VmExecutionExtension for Native { ], )?; + let tower_verify = NativeSumcheckExecutor::new(); + inventory.add_executor( + tower_verify, + [SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode()], + )?; + inventory.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), @@ -262,6 +274,9 @@ where ); inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + inventory.add_air(tower_evaluate); + Ok(()) } } @@ -342,6 +357,9 @@ where ); inventory.add_executor_chip(poseidon2); + let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + inventory.add_executor_chip(tower_verify); + Ok(()) } } diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index 5afaf74af5..a7d535a14b 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -254,10 +254,10 @@ pub(crate) struct FieldExtension; impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 1e1ec65cb8..4a0d3847d5 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -542,7 +542,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index b5db4f0010..ce257c9c22 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -42,9 +42,11 @@ mod fri; mod jal_rangecheck; mod loadstore; mod poseidon2; +mod sumcheck; mod extension; pub use extension::*; +pub use field_extension::EXT_DEG; mod utils; #[cfg(any(test, feature = "test-utils"))] diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index aecff9f10f..770efc7307 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -3,10 +3,8 @@ use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::*, system::{ - memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, - native_adapter::util::{ - memory_read_native, tracing_read_native, tracing_write_native_inplace, - }, + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, }, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; @@ -23,12 +21,16 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelSliceMut, *}, }; -use crate::poseidon2::{ - columns::{ - InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, +use crate::{ + mem_fill_helper, + poseidon2::{ + columns::{ + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, + SimplePoseidonSpecificCols, TopLevelSpecificCols, + }, + CHUNK, }, - CHUNK, + tracing_read_native_helper, }; #[derive(Clone)] @@ -1240,24 +1242,3 @@ impl NativePoseidon2Filler( - memory: &mut TracingMemory, - ptr: u32, - base_aux: &mut MemoryBaseAuxCols, -) -> [F; BLOCK_SIZE] { - let mut prev_ts = 0; - let ret = tracing_read_native(memory, ptr, &mut prev_ts); - base_aux.set_prev(F::from_canonical_u32(prev_ts)); - ret -} - -/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. -fn mem_fill_helper( - mem_helper: &MemoryAuxColsFactory, - timestamp: u32, - base_aux: &mut MemoryBaseAuxCols, -) { - let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); - mem_helper.fill(prev_ts, timestamp, base_aux); -} diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..1cae3847ca --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,591 @@ +use std::borrow::Borrow; + +use openvm_circuit::{ + arch::{ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, +}; +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; +use openvm_instructions::{LocalOpcode, NATIVE_AS}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + sumcheck::{ + chip::CONTEXT_ARR_BASE_LEN, + columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, + }, +}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, +} + +impl NativeSumcheckAir { + pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + Self { + execution_bridge, + memory_bridge, + } + } +} + +impl BaseAir for NativeSumcheckAir { + fn width(&self) -> usize { + NativeSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues for NativeSumcheckAir {} + +impl PartitionedBaseAir for NativeSumcheckAir {} + +impl Air for NativeSumcheckAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &NativeSumcheckCols = (*next).borrow(); + let native_as = AB::F::from_canonical_u32(NATIVE_AS); + + let &NativeSumcheckCols { + // Row indicators + header_row, + prod_row, + logup_row, + is_end, + + prod_continued, + logup_continued, + // What type of evaluation is performed + // mainly for reducing constraint degree + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + + // Timestamps + first_timestamp, + start_timestamp, + last_timestamp, + + // Results from reading registers + register_ptrs, + ctx, + prod_nested_len, + logup_nested_len, + + // Challenges + alpha, + challenges, + + curr_prod_n, + curr_logup_n, + + max_round, + within_round_limit, + should_acc, + eval_acc, + specific, + } = local; + + let [round, num_prod_spec, num_logup_spec, _prod_spec_inner_len, prod_spec_inner_inner_len, _logup_spec_inner_len, logup_spec_inner_inner_len, in_round] = + ctx; + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + builder.assert_bool(within_round_limit); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); + + let enabled = header_row + prod_row + logup_row; + let next_enabled = next.header_row + next.prod_row + next.logup_row; + builder.assert_bool(enabled.clone()); + + builder + .when_transition() + .assert_eq(prod_row * next.prod_row, prod_continued); + builder + .when_transition() + .assert_eq(logup_row * next.logup_row, logup_continued); + // TODO: handle last row properly + + builder.when_transition().assert_eq::( + prod_row * next.header_row + + logup_row * next.header_row + + not::(next_enabled), + is_end.into(), + ); + + // TODO: within_round_limit = true => round < max_round + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().unwrap(); + let c2: [_; EXT_DEG] = challenges[{ EXT_DEG * 2 }..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let alpha2: [_; EXT_DEG] = challenges[{ EXT_DEG * 3 }..{ EXT_DEG * 4 }] + .try_into() + .unwrap(); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().unwrap(); + + // Carry along columns + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + register_ptrs, + next.register_ptrs, + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + ctx, + next.ctx, + ); + // c1, c2 remain the same + assert_array_eq::<_, _, _, { EXT_DEG * 2 }>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)] + .try_into() + .expect(""), + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + alpha, + next.alpha, + ); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_nested_len, next.prod_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_nested_len, next.logup_nested_len); + + //////////////////////////////////////////////////////////////// + // Row transitions from current to next row + // The basic pattern is + // header_row -> prod_row -> ... -> prod_row + // -> logup_row -> ... -> logup_row + //////////////////////////////////////////////////////////////// + + // (curr_prod_n, curr_logup_n) start at 0 + builder.when(header_row).assert_zero(curr_prod_n); + builder + .when(header_row + prod_row) + .assert_zero(curr_logup_n); + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + // if header row is followed by another header row + // then num_prod_spec and num_logup_spec should be zero + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_prod_spec); + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_logup_spec); + // if header row is followed by a logup row, + // then num_prod_spec should be zero + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(num_prod_spec); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(curr_prod_n, num_prod_spec); + builder + .when(logup_row) + .when(next.header_row) + .assert_eq(curr_logup_n, num_logup_spec); + + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::from_canonical_usize(7), + ); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO, + ); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3), + ); + + // Termination condition + assert_array_eq( + &mut builder.when::(is_end.into()), + eval_acc, + [AB::F::ZERO; 4], + ); + + // Randomness transition + assert_array_eq( + &mut builder.when(and(header_row, next.prod_row + next.logup_row)), + next.challenges[0..EXT_DEG].try_into().unwrap(), + [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], + ); + assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_continued), + prod_next_alpha, + next_alpha1, + ); + // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_row), + alpha_denominator, + alpha2, + ); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_continued), + logup_next_alpha, + next_alpha1, + ); + + /////////////////////////////////////// + // Header + /////////////////////////////////////// + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + native_as.into(), + native_as.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp - first_timestamp, + ) + .eval(builder, header_row); + + // Read registers + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(native_as, registers[i]), + [register_ptrs[i]], + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + // Read ctx + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + // Write final result + self.memory_bridge + .write( + MemoryAddress::new(native_as, register_ptrs[4]), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + + /////////////////////////////////////// + // Prod spec evaluation + /////////////////////////////////////// + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + + (curr_prod_n - AB::F::ONE), + ), // curr_prod_n starts at 1. + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + // prod_row * within_round_limit = + // prod_in_round_evaluation + prod_next_round_evaluation + builder + .when(prod_in_round_evaluation + prod_next_round_evaluation) + .assert_eq( + prod_row_specific.data_ptr, + (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + prod_row * within_round_limit * in_round, + prod_in_round_evaluation, + ); + builder.assert_eq( + prod_row * within_round_limit * not(in_round), + prod_next_round_evaluation, + ); + builder.assert_eq(prod_row * should_acc, prod_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row * within_round_limit); + + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_in_round_evaluation), + in_round_p_evals, + prod_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_next_round_evaluation), + next_round_p_evals, + prod_row_specific.p_evals, + ); + + // TODO: add constraint on should_acc + + // Accumulate `eval_rlc` into global accumulator `eval_acc` + // when round < max_round - 2 + let eval_rlc = + FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_acc), + prod_row_specific.eval_rlc, + eval_rlc, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.prod_acc), + FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + eval_acc, + ); + + /////////////////////////////////////// + // Logup spec evaluation + /////////////////////////////////////// + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(EXT_DEG * 2) + + num_prod_spec + + (curr_logup_n - AB::F::ONE), + ), // curr_logup_n starts at 1. + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + // logup_row * within_round_limit = + // logup_in_round_evaluation + logup_next_round_evaluation + builder + .when(logup_in_round_evaluation + logup_next_round_evaluation) + .assert_eq( + logup_row_specific.data_ptr, + (logup_nested_len * (curr_logup_n - AB::F::ONE) + + logup_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + logup_row * within_round_limit * in_round, + logup_in_round_evaluation, + ); + builder.assert_eq( + logup_row * within_round_limit * not(in_round), + logup_next_round_evaluation, + ); + builder.assert_eq(logup_row * should_acc, logup_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)] + .try_into() + .unwrap(); + + // write p_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row * within_round_limit); + + // write q_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + num_logup_spec + curr_logup_n) + * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_p_evals, + logup_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_p_evals, + logup_row_specific.p_evals, + ); + + let next_round_q_evals = FieldExtension::add( + FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_q_evals, + logup_row_specific.q_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_q_evals, + logup_row_specific.q_evals, + ); + + // Accumulate evaluation + let eval_rlc = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_acc), + logup_row_specific.eval_rlc, + eval_rlc, + ); + + // Accumulate into global accumulator `eval_acc` + // when round < max_round - 2 + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.logup_acc), + FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + eval_acc, + ); + } +} diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..17e99c442a --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,579 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::{ + arch::{ + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, TraceFiller, VmChipWrapper, VmStateMut, + }, + system::{ + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, + }, +}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + mem_fill_helper, + sumcheck::columns::{ + HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, + }, + tracing_read_native_helper, +}; + +pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; +pub(crate) const CURRENT_LAYER_MODE: u32 = 1; +pub(crate) const NEXT_LAYER_MODE: u32 = 0; + +pub(crate) fn calculate_3d_ext_idx( + inner_inner_len: u32, + inner_len: u32, + outer_idx: u32, + inner_idx: u32, + inner_inner_idx: u32, +) -> u32 { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) + * EXT_DEG as u32 +} + +#[derive(Debug, Clone, Default)] +pub struct NativeSumcheckMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for NativeSumcheckMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +type NativeSumcheckRecordLayout = MultiRowLayout; + +pub struct NativeSumcheckRecordMut<'a, F>(&'a mut [NativeSumcheckCols]); + +impl<'a, F: PrimeField32> + CustomBorrow<'a, NativeSumcheckRecordMut<'a, F>, NativeSumcheckRecordLayout> for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativeSumcheckRecordLayout, + ) -> NativeSumcheckRecordMut<'a, F> { + // SAFETY: + // - align_to_mut() ensures proper alignment for NativeSumcheckCols + // - Layout guarantees sufficient length for num_rows records + // - Slice bounds validated by taking only num_rows elements + let arr = unsafe { self.align_to_mut::>().1 }; + NativeSumcheckRecordMut(&mut arr[..layout.metadata.num_rows]) + } + + unsafe fn extract_layout(&self) -> NativeSumcheckRecordLayout { + // Each instruction record consists solely of some number of contiguously + // stored NativeSumcheckCols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // as a single record. + NativeSumcheckRecordLayout { + metadata: NativeSumcheckMetadata { num_rows: 1 }, + } + } +} + +#[derive(derive_new::new, Copy, Clone)] +pub struct NativeSumcheckExecutor; + +#[derive(derive_new::new)] +pub struct NativeSumcheckFiller; + +pub type NativeSumcheckChip = VmChipWrapper; + +impl Default for NativeSumcheckExecutor { + fn default() -> Self { + Self::new() + } +} + +impl PreflightExecutor for NativeSumcheckExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, NativeSumcheckRecordLayout, NativeSumcheckRecordMut<'buf, F>>, +{ + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode: op, + a: r_evals_reg, + b: ctx_reg, + c: challenges_reg, + d: data_address_space, + e: register_address_space, + f: prod_evals_reg, + g: logup_evals_reg, + } = instruction; + + // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] + // and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + assert_eq!(op, SUMCHECK_LAYER_EVAL.global_opcode()); + assert_eq!(data_address_space.as_canonical_u32(), NATIVE_AS); + assert_eq!(register_address_space.as_canonical_u32(), NATIVE_AS); + + let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); + let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) + .map(|x: F| x.as_canonical_u32()); + + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + // allocate n rows + let num_rows = (1 + num_prod_spec + num_logup_spec) as usize; + let rows = state + .ctx + .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) + .0; + + let mut cur_timestamp = state.memory.timestamp(); + // head row + let head_row: &mut NativeSumcheckCols = &mut rows[0]; + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + head_row.header_row = F::ONE; + head_row.first_timestamp = F::from_canonical_u32(cur_timestamp); + head_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + head_specific.pc = F::from_canonical_u32(*state.pc); + + head_specific.registers[0] = ctx_reg; + head_specific.registers[1] = challenges_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; + + // read pointers + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_reg.as_canonical_u32(), + head_specific.read_records[0].as_mut(), + ); + let [challenges_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + challenges_reg.as_canonical_u32(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), + ); + let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + r_evals_reg.as_canonical_u32(), + head_specific.read_records[4].as_mut(), + ); + + let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_specific.read_records[5].as_mut(), + ); + + let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + challenges_ptr.as_canonical_u32(), + head_specific.read_records[6].as_mut(), + ); + cur_timestamp += 7; // 5 register reads + ctx read + challenges read + head_row.challenges.copy_from_slice(&challenges); + + // challenges = [alpha, c1=r, c2=1-r] + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + // all rows share same register values, ctx, challenges + for row in rows.iter_mut() { + // c1, c2 are same during the entire execution + row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); + row.alpha = alpha; + row.ctx = ctx; + row.prod_nested_len = + F::from_canonical_u32(prod_specs_inner_len * prod_specs_inner_inner_len); + row.logup_nested_len = + F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); + row.register_ptrs[0] = ctx_ptr; + row.register_ptrs[1] = challenges_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; + } + + // product rows + for (i, prod_row) in rows + .iter_mut() + .skip(1) + .take(num_prod_spec as usize) + .enumerate() + { + let prod_specific: &mut ProdSpecificCols = + prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + prod_row.prod_row = F::ONE; + prod_row.prod_continued = if i < (num_prod_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 + prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + // read max_round + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, + prod_specific.read_records[0].as_mut(), + ); + cur_timestamp += 1; + + prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + prod_row.max_round = max_round; + + let max_round = max_round.as_canonical_u32(); + // round starts from 0 + if round < max_round - 1 { + prod_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i as u32, + round, + 0, + ); + prod_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2 + let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + + prod_specific.p = ps; + + // compute expected eval + let eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + prod_specific.p_evals = eval; + + match mode { + NEXT_LAYER_MODE => { + prod_row.prod_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + prod_row.prod_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + // write p eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + (1 + i as u32) * (EXT_DEG as u32), + eval, + &mut prod_specific.write_record, + ); + cur_timestamp += 2; + + let eval_rlc = FieldExtension::multiply(alpha_acc, eval); + prod_specific.eval_rlc = eval_rlc; + + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + prod_row.should_acc = F::ONE; + prod_row.prod_acc = F::ONE; + prod_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + } + + // logup rows + for (i, logup_row) in rows.iter_mut().skip(1 + num_prod_spec as usize).enumerate() { + let logup_specific: &mut LogupSpecificCols = + logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); + + logup_row.logup_row = F::ONE; + logup_row.logup_continued = if i < (num_logup_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 + logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, + logup_specific.read_records[0].as_mut(), + ); + logup_row.max_round = max_round; + cur_timestamp += 1; + + let alpha_numerator = alpha_acc; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + logup_row.challenges[3 * EXT_DEG..(4 * EXT_DEG)].copy_from_slice(&alpha_denominator); + + let max_round = max_round.as_canonical_u32(); + if round < max_round - 1 { + logup_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i as u32, + round, + 0, + ); + logup_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().unwrap(); + + logup_specific.pq = pqs; + + // compute expected evals + let p_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + match mode { + NEXT_LAYER_MODE => { + logup_row.logup_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + logup_row.logup_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + logup_specific.p_evals = p_eval; + logup_specific.q_evals = q_eval; + + // write p_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + i as u32) * (EXT_DEG as u32), + p_eval, + &mut logup_specific.write_records[0], + ); + // write q_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + num_logup_spec + i as u32) * (EXT_DEG as u32), + q_eval, + &mut logup_specific.write_records[1], + ); + cur_timestamp += 3; // 1 read, 2 writes + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + logup_specific.eval_rlc = eval_rlc; + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + logup_row.should_acc = F::ONE; + logup_row.logup_acc = F::ONE; + logup_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + } + + if let Some(last_row) = rows.last_mut() { + last_row.is_end = F::ONE; + } + + let head_row = &mut rows[0]; + head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); + + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32(), + eval_acc, + &mut head_specific.write_records, + ); + + for row in rows.iter_mut() { + if row.header_row == F::ONE { + row.eval_acc = eval_acc; + } else if row.prod_row == F::ONE { + let specific: &mut ProdSpecificCols = + row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } else if row.logup_row == F::ONE { + let specific: &mut LogupSpecificCols = + row.specific[..LogupSpecificCols::::width()].borrow_mut(); + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } + } + assert_eq!(eval_acc, elem_to_ext(F::from_canonical_u32(0)),); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } + + // GKR layered IOP for product and logup relations + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, SUMCHECK_LAYER_EVAL.global_opcode().as_usize()); + String::from("SUMCHECK_LAYER_EVAL") + } +} + +impl TraceFiller for NativeSumcheckFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + let start_timestamp = cols.start_timestamp.as_canonical_u32(); + let last_timestamp = cols.last_timestamp.as_canonical_u32(); + + if cols.header_row == F::ONE { + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + for i in 0..7usize { + mem_fill_helper( + mem_helper, + start_timestamp + i as u32, + header.read_records[i].as_mut(), + ); + } + mem_fill_helper( + mem_helper, + last_timestamp - 1, + header.write_records.as_mut(), + ); + } else if cols.prod_row == F::ONE { + let prod_row_specific: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + prod_row_specific.write_record.as_mut(), + ); + } + } else if cols.logup_row == F::ONE { + let logup_row_specific: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 3, + logup_row_specific.write_records[1].as_mut(), + ); + } + } + } + + fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + + cols.is_end = F::ONE; + } +} diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs new file mode 100644 index 0000000000..b3e6bf4f25 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -0,0 +1,135 @@ +use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit_primitives_derive::AlignedBorrow; + +use crate::{field_extension::EXT_DEG, utils::const_max}; + +const fn max3(a: usize, b: usize, c: usize) -> usize { + const_max(a, const_max(b, c)) +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct NativeSumcheckCols { + /// Indicates that this row is the header for a layer sum operation + pub header_row: T, + /// Indicates that this row is a step for prod_spec in the layer sum operation + pub prod_row: T, + /// Indicates that this row is a step for logup_spec in the layer sum operation + pub logup_row: T, + /// Indicates that this row is the end of the entire layer sum operation + pub is_end: T, + + pub prod_continued: T, + pub logup_continued: T, + + /// Indicates what type of evaluation constraints should be applied + pub prod_in_round_evaluation: T, + pub prod_next_round_evaluation: T, + pub logup_in_round_evaluation: T, + pub logup_next_round_evaluation: T, + + /// Indicates if evaluations are accumulated + pub prod_acc: T, + pub logup_acc: T, + + /// Timestamps + pub first_timestamp: T, + pub start_timestamp: T, + pub last_timestamp: T, + + // Register values + pub register_ptrs: [T; 5], + + // Context variables + // [ + // round, + // num_prod_spec, + // num_logup_spec, + // prod_spec_inner_len, + // prod_spec_inner_inner_len, + // logup_spec_inner_len, + // logup_spec_inner_inner_len, + // in_layer, + // ] + pub ctx: [T; EXT_DEG * 2], + + pub prod_nested_len: T, + pub logup_nested_len: T, + + pub curr_prod_n: T, + pub curr_logup_n: T, + + pub alpha: [T; EXT_DEG], + // alpha1, c1, c2, alpha2 (for logup rows) + pub challenges: [T; EXT_DEG * 4], + + // Specific to each row + pub max_round: T, + // Is this round within max_round + pub within_round_limit: T, + // Should the evaluation be accumualted + pub should_acc: T, + + // The current final evaluation accumulator. Extension element. + pub eval_acc: [T; EXT_DEG], + + // /// 1. For header row, 5 registers, ctx, challenges + // /// 2. For the rest: max_variables, p1, p2, q1, q2 + // pub read_records: [MemoryReadAuxCols; 7], + // /// 1. For header row, write final result + // /// 2. For prod rows: write prod_evals + // /// 3. For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, + /// p_evals * alpha^i + pub eval_rlc: [T; EXT_DEG], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub eval_rlc: [T; EXT_DEG], +} diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -0,0 +1 @@ + diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs new file mode 100644 index 0000000000..7202e57b00 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -0,0 +1,345 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, NATIVE_AS}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + sumcheck::chip::{ + calculate_3d_ext_idx, NativeSumcheckExecutor, CONTEXT_ARR_BASE_LEN, CURRENT_LAYER_MODE, + NEXT_LAYER_MODE, + }, +}; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeSumcheckPreCompute { + r_evals_reg: u32, + ctx_reg: u32, + challenges_reg: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, +} + +impl NativeSumcheckExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut NativeSumcheckPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let r_evals_reg = a.as_canonical_u32(); + let ctx_reg = b.as_canonical_u32(); + let challenges_reg = c.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); + + if d.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = NativeSumcheckPreCompute { + r_evals_reg, + ctx_reg, + challenges_reg, + prod_evals_reg, + logup_evals_reg, + }; + + Ok(()) + } +} + +impl Executor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_handler; + Ok(fn_ptr) + } + + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_handler; + Ok(fn_ptr) + } +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &NativeSumcheckPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, instret, pc, exec_state); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &NativeSumcheckPreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); + let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); + let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + + let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); + let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + + let ctx: [u32; 8] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32) + .map(|x: F| x.as_canonical_u32()); + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + let challenges: [F; EXT_DEG * 4] = + exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + + let mut height = 1; + let mut alpha_acc = elem_to_ext(F::ONE); + let mut eval_acc = elem_to_ext(F::ZERO); + + let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; + for i in 0..num_prod_spec { + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, prod_offset + i) + .map(|x: F| x.as_canonical_u32()); + + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + 0, + ); + + if round < max_round - 1 { + let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + + let eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval); + + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + height += 1; + } + + let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; + for i in 0..num_logup_spec { + // read max_round + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, logup_offset + i) + .map(|x: F| x.as_canonical_u32()); + let start = calculate_3d_ext_idx( + logup_specs_inner_len, + logup_specs_inner_inner_len, + i, + round, + 0, + ); + + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + let alpha_numerator = alpha_acc; + + if round < max_round - 1 { + // read logup_evals + let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); + + // compute p_eval and q_eval + let p_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + // write eval to r_evals + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + i) * EXT_DEG as u32, + &p_eval, + ); + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + num_logup_spec + i) * EXT_DEG as u32, + &q_eval, + ); + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + height += 1; + } + + *pc += DEFAULT_PC_STEP; + *instret += 1; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32, &eval_acc); + // return height delta + height +} diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..8c35a0a7aa --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,11 @@ +pub mod air; +pub mod chip; +mod columns; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; + +// mod tests; +mod execution; diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index c38e656e74..3d05656f16 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,9 +1,36 @@ +use openvm_circuit::system::{ + memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::tracing_read_native, +}; +use p3_field::PrimeField32; + pub(crate) const CASTF_MAX_BITS: usize = 30; pub(crate) const fn const_max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. +pub(crate) fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +pub(crate) fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 42e32f3a7d..c80615ca7a 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -637,6 +637,18 @@ impl + TwoAdicField> AsmCo ); } } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval( + input_ctx.fp(), + challenges.fp(), + prod_ptr.fp(), + logup_ptr.fp(), + r_ptr.fp(), + ), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index ae0875b83a..3d498c00f4 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -171,6 +171,15 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + // Native opcode for calculating sumcheck layer evaluation + // SumcheckLayerEval(reg_a, reg_b, reg_c, ... , reg_f, reg_g) + // - reg_a: Output ptr for next layer's evaluations + // - reg_b: Context variables + // - reg_c: Challenge values (alpha, coeff) + // - reg_g: GKR product IOP evaluations + // - reg_f: GKR logup IOP evaluations + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -407,6 +416,13 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!( + f, + "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", + ctx, cs, p_ptr, l_ptr, r_ptr + ) + } } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 82ee912703..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -12,7 +12,7 @@ use crate::{ asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -535,7 +535,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: i32_f(r_ptr), + b: i32_f(ctx), + c: i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 3b30a45ad6..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -320,6 +320,31 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Native operation for calculating a sumcheck layer's evaluation + /// This op supports two modes: + /// 1. for computing expected evaluation for current layer, output = [ \sum_i alpha^i * + /// prod[i][0] * prod[i][1] + \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* + /// logup_p[i][0] * logup_q[i][1] + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected evaluation of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) + /// * p[i][1]. + SumcheckLayerEval( + Ptr, // Context variables: + // 0. round, + // 1. number of product + // 2. number of logup + // 3. (3D array description) prod_specs_eval inner length + // 4. (3D array description) prod_specs_eval inner_inner length + // 5. (3D array description) logup_spec_eval inner length + // 6. (3D array description) logup_spec_eval inner length + // 7. Operational mode indicator + // 8+. usize-type variables indicating maximum rounds + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr, // output + ), } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..f708318c34 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -18,6 +18,7 @@ mod instructions; mod poseidon; mod ptr; mod select; +mod sumcheck; mod symbolic; mod types; mod utils; diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..0237fd6740 --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,48 @@ +use super::{Array, Builder, Config, DslIr, Ext, Usize}; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + /// This opcode supports two modes (indicated by a context variable): + /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup) + /// 2. calculate the expected value of next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1] + /// + /// Context variables + /// + /// 0: round, + /// 1: number of product + /// 2. number of logup + /// 3. (3D array description) prod_specs_eval inner length + /// 4. (3D array description) prod_specs_eval inner_inner length + /// 5. (3D array description) logup_spec_eval inner length + /// 6. (3D array description) logup_spec_eval inner length + /// 7. Operational mode indicator + /// 8+ Additional usize-type variables indicating maximum rounds + /// + /// Output + /// + /// 1. for computing expected evaluation, output = [ \sum_i alpha^i * prod[i][0] * prod[i][1] + + /// \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* logup_p[i][0] * logup_q[i][1] + /// + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected eval of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) * + /// p[i][1]. + pub fn sumcheck_layer_eval( + &mut self, + input_ctx: &Array>, // Context variables + challenges: &Array>, // Challenges + prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + * from 3D array. */ + logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ + r_evals: &Array>, /* Next layer's evaluations (pointer used for + * storing opcode output) */ + ) { + self.operations.push(DslIr::SumcheckLayerEval( + input_ctx.ptr(), + challenges.ptr(), + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), + r_evals.ptr(), + )); + } +} diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index 66c786fbd9..efb45b0159 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -212,3 +212,18 @@ pub enum VerifyBatchOpcode { /// per column polynomial, per opening point VERIFY_BATCH, } + +/// Opcodes for sumcheck. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x180] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum SumcheckOpcode { + /// Compute the expected evaluation for each layer in the tower structure that GKR product IOP + /// and logup IOP uses Supports two modes of operation: + /// 1. Calculate current layer's expected evaluation + /// 2. Calculate next layer's evaluation + SUMCHECK_LAYER_EVAL, +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs new file mode 100644 index 0000000000..a4039028bc --- /dev/null +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -0,0 +1,431 @@ +use openvm_circuit::arch::instructions::program::Program; +#[cfg(not(feature = "cuda"))] +use openvm_circuit::utils::air_test_impl; +use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; +use openvm_native_compiler::{ + asm::{AsmBuilder, AsmCompiler}, + conversion::{convert_program, CompilerOptions}, + ir::{Ext, Usize}, + prelude::*, +}; +use openvm_stark_backend::p3_field::{ + extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, +}; +use openvm_stark_sdk::{ + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, + }, + p3_baby_bear::BabyBear, +}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; + +#[test] +fn test_sumcheck_layer_eval() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + let program: Program<_> = convert_program(asm_code, compilation_options); + let sumcheck_max_constraint_degree = 3; + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vb = NativeBuilder::default(); + #[cfg(not(feature = "cuda"))] + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); + #[cfg(feature = "cuda")] + { + air_test_impl::( + fri_params, + vb, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); + } +} + +fn build_test_program(builder: &mut Builder) { + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); + for (idx, n) in ctx_u32s.into_iter().enumerate() { + builder.set(&ctx, idx, Usize::from(n as usize)); + } + + #[rustfmt::skip] + let challenges_u32s = [ + 548478283u32, 456436544, 1716290291, 791326976, + 1829717553, 1422025771, 1917123958, 727015942, + 183548369, 591240150, 96141963, 1286249979, + 0, 0, 0, 0, + ]; + let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); + for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&challenges, idx, e); + } + + #[rustfmt::skip] + let prod_spec_eval_u32s = [ + 1538906710u32, 637535518, 1753132406, 1395236651, + 278806441, 1722910382, 1475548665, 1117874675, + 1578586709, 1826764884, 384068476, 1852240363, + 707958906, 1960944944, 183554399, 1259273357, + 227285124, 243066436, 1718037317, 369721963, + 1752968006, 1061013677, 775617499, 1464907431, + 544300429, 871461966, 135151545, 1343592602, + 1622220528, 643966158, 3932580, 434948358, + 540553922, 1446502052, 153298741, 1191216273, + 265936762, 1463035257, 1237633339, 1797346310, + 1355791584, 389527741, 1741650463, 1728913415, + 1825739540, 1790924136, 460776743, 29536554, + 6842036, 252495270, 1968285155, 299467416, + 49085744, 1499815729, 1098802236, 644489275, + 1827273105, 1888401527, 390077051, 565528894, + 1366177188, 67441791, 958486301, 402056716, + 590379691, 462035406, 633459131, 843304872, + 584100013, 1932496508, 250656031, 146983915, + 1835173157, 939973454, 1844873638, 1916054832, + 1601784696, 167251717, 409107688, 1062925788, + 1291319514, 1790529531, 495655592, 1093359708, + 790197205, 674458164, 195988318, 399764452, + 106865258, 967050329, 350035523, 1109292118, + 1815460301, 281986036, 900636603, 1121197008, + 1228976590, 1879998708, 1924332706, 434695844, + 1159360621, 471397106, 473371067, 1009065094, + 1320176846, 168020789, 1265321929, 1901808675, + 223657700, 1480150183, 1779968584, 144416591, + 304407746, 1864498679, 1482460119, 1554376965, + 1479261548, 1657723043, 1039345063, 1053923521, + 442080513, 1964082352, 691664908, 1941008321, + 1007729002, 860529393, 849697342, 754485488, + 584295923, 1072251466, 1105105254, 996079746, + 1305909868, 1348028973, 122275988, 464050036, + 692807777, 1098809324, 397235220, 596459886, + 1663209783, 720230826, 1422510715, 1760654694, + 544197700, 1417744567, 1938716517, 1571826328, + 1591430185, 1173137446, 175285007, 1541718596, + 1715958587, 1429966110, 583013357, 1667787861, + 109891172, 668253167, 161783842, 296183397, + 1681897325, 1054396117, 264741948, 464026995, + 1907686022, 1532786783, 394869458, 1766734740, + 136047179, 536856195, 376188855, 700633625, + 515518419, 531043483, 60673499, 556496527, + 1743028981, 873954569, 1371062291, 632169731, + 1353239206, 526507035, 1894490088, 589441599, + 1610487168, 1074160583, 366366374, 247602990, + 1535354896, 894493713, 1555870413, 1389854934, + 1897251683, 1525812801, 675621735, 697919636, + 1690274072, 1466810921, 1221110784, 1741995587, + 1877169764, 390876982, 1794129810, 297662156, + 144295349, 417037264, 1290835727, 1654971513, + 1674131303, 1625667423, 1471248832, 1676797844, + 1172916558, 1707775403, 423725211, 1643279661, + 1695774264, 378140395, 1517569394, 1666625392, + 1803981250, 439036260, 247966130, 709534816, + 361144100, 1546096548, 1240886454, 1898161518, + 843262057, 1709259464, 1301015977, 1997626928, + 677153173, 1606710353, 1216038070, 435565562, + 98686333, 1773787396, 267051994, 99395396, + 545509105, 782289675, 1289865975, 1707775075, + 1158993015, 1506576588, 993215179, 1523099397, + 923914455, 1895162386, 284489994, 1444139016, + 1943825680, 466202724, 1632522710, 1384015062, + 723147188, 1284031324, 1430481515, 341213007, + 171192499, 1061688239, 808927167, 83182639, + 759209907, 1728321272, 976049976, 1652071995, + 1002877840, 69880246, 1095135165, 677588420, + 1384715290, 829619452, 170122781, 1958173727, + 13389238, 789379698, 1883383039, 1279195174, + 1618672336, 1192839317, 1348311124, 758896285, + 1939775389, 684108413, 1838340479, 1332232130, + 1070486028, 549228790, 868851698, 1678207843, + 1754321489, 637000403, 647901906, 45343322, + 1768524074, 1167955205, 1816497210, 1609414096, + 1985231742, 1540534482, 232730819, 232221968, + 1509637836, 1480860627, 884647789, 1096458024, + 163721583, 1248032262, 436419506, 1737102298, + 651105860, 452298073, 1064372507, 1792838683, + 619243471, 860127631, 721724708, 950768433, + 279913448, 339693210, 47730422, 1952683911, + 1316500770, 675944216, 386902809, 619333956, + 1194800389, 43989936, 1944372656, 666045666, + 1155873844, 522696968, 58874730, 1497238023, + 421619994, 1980672127, 1657191856, 1913792631, + 1784663131, 1118400672, 1828104993, 1637808383, + 414755472, 775410449, 747132157, 136820101, + 1082674285, 93190395, 357955402, 335652723, + 1192102705, 480365232, 1354935730, 1391829361, + 966662991, 1601510445, 569528575, 545490940, + 1753711688, 807025222, 580374183, 587718008, + 977546290, 1055719519, 1157107032, 562799608, + 859466927, 840450024, 815325134, 936576801, + 1010587056, 246624382, 1808049797, 1098183398, + 1005077390, 772432546, 1976629565, 1003772218, + 1655315418, 1767931114, 982008720, 785023351, + ]; + + let prod_spec_evals: Array> = + builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&prod_spec_evals, idx, e); + } + + #[rustfmt::skip] + let logup_spec_eval_u32s = [ + 1522353967u32, 457603397, 421847521, 1352563318, + 1746817766, 737872688, 1087008622, 1850835028, + 456475558, 892966330, 638163666, 148568548, + 678863061, 1334386850, 1896333039, 154585769, + 433618446, 1186936470, 970218722, 1213827097, + 1798557019, 861757965, 119285527, 395360622, + 226164366, 1330279872, 66561048, 785421608, + 1950755756, 1559889596, 348449876, 1090789452, + 257578851, 273164442, 1644906, 295600924, + 1187949602, 1168249609, 469763604, 60929061, + 291163036, 403842501, 1421902433, 1700188477, + 1046093370, 921059131, 1638991894, 464012042, + 96905857, 1370999592, 271896041, 13595534, + 1489760970, 1650552701, 133367846, 25680377, + 377631580, 652729291, 645763356, 426747355, + 482475486, 1877299223, 103226636, 1333832358, + 1399609097, 458536972, 976248802, 1109365280, + 515164588, 1579426417, 1601829549, 607169702, + 852817956, 1980537127, 134138338, 913344050, + 737880920, 476360275, 61624034, 1610624252, + 264461991, 546933535, 937769429, 293346965, + 1522058041, 1012551797, 994330314, 23333322, + 1969510890, 974351570, 2012030621, 120742000, + 450250620, 180547360, 642746933, 1815029950, + 629489142, 1176992624, 723354779, 572648755, + 1218615348, 648847054, 351903235, 723149764, + 248065753, 243829448, 1283393001, 1912627886, + 581641342, 702465306, 205969758, 1061911274, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1703043252, 1467887451, 1714319214, 907866644, + 1542426838, 742609036, 1814393459, 448706641, + 1960340767, 46490834, 186512520, 363973095, + 846448854, 463742343, 2012517527, 40473617, + 9472552, 263483342, 105738598, 586389136, + 254290990, 625150844, 960233097, 1488303724, + 1700231692, 1471714612, 1540211186, 1590246915, + 945341972, 1343225515, 179976237, 34857822, + 276912528, 984309272, 1277293398, 1520924162, + 1823117694, 604836357, 1460812009, 600052559, + 970469338, 1771022707, 181855831, 1445947220, + 467514809, 1514677498, 947030389, 170390653, + 415409007, 1601463730, 204153427, 904614278, + 1855419512, 2009471607, 1352607379, 576586082, + 1343812879, 1176377580, 1166188815, 1592289048, + 761793881, 1529621462, 193034837, 344011596, + 1669461833, 1356800025, 314186361, 586497329, + 1832810846, 1288092861, 1619454491, 732529408, + 737934269, 909504928, 769680420, 1437893101, + 1727002258, 1618231110, 535125583, 153412473, + 1917760929, 588586507, 564531165, 1790797737, + 1666283994, 1366948884, 117673690, 476470378, + 2012274032, 1951406668, 1739767532, 1273142151, + 1591812317, 1900205312, 1912608761, 1734766024, + 1265002082, 1450462894, 749810837, 1329222552, + 745081805, 1231519431, 1420957967, 883846107, + 1995463911, 407795592, 161655852, 125886157, + 995318920, 484905024, 284135318, 551493419, + 406742309, 1089024446, 637339867, 1858138403, + 1230680117, 187078889, 1929517480, 1125646261, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1610035932, 462442436, 831412555, 44798862, + 1748147276, 1911945531, 1329343740, 971894393, + 362147969, 1583335926, 1528700112, 426908674, + 847905883, 447889090, 1050883911, 1883537469, + 1487501632, 964178870, 1818828551, 1980840799, + 340372118, 1697179193, 215113037, 1893217470, + 1138628493, 1788052486, 443362955, 1349213730, + 589553425, 562526667, 1006040406, 1194546769, + 1831034644, 612004157, 730213913, 1068905440, + 371983982, 502900790, 802785198, 822377635, + 1477528437, 501356237, 684668525, 1306043781, + 621032592, 1971342708, 1411586583, 733418745, + 186045462, 1559301855, 323758310, 453170140, + 498381240, 976247416, 631213663, 898017829, + 501459603, 609703046, 1379288251, 177682695, + 912381595, 121915494, 1137416430, 504054388, + 1138277238, 1603388253, 1838013301, 1700271853, + 20488607, 58775264, 217974275, 979141729, + 53136584, 1331566240, 1460303356, 525812787, + 718385521, 1477919263, 1663622276, 1089788203, + 1204483837, 54225863, 290660186, 1441441958, + 134168813, 349638823, 1867912015, 1579183319, + 55528656, 1602973359, 194297109, 949763297, + 101931919, 242300116, 1610052257, 1351823848, + 174522860, 776955925, 1706962365, 808187490, + 1487253852, 431806906, 213982593, 1170647308, + 1776840400, 295916317, 378708073, 381270341, + 457494568, 705823997, 1407301442, 1693003013, + 700310785, 1349874247, 1284363817, 1566253815, + 1014298154, 215294365, 1070968678, 871641358, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1302679751, 1121894357, 368587356, 1564724097, + 733815591, 2012670011, 1146780092, 1439780227, + 1801628424, 838692317, 932318853, 213634365, + 155292454, 1644317110, 1599846194, 978829059, + 1282095862, 1780431647, 527412087, 1024583705, + 804423802, 951808322, 689345230, 180304167, + 1784562773, 1514653374, 2009396440, 1143778943, + 235299446, 1553017484, 475425117, 758292254, + 716575432, 517083432, 1728864125, 418010549, + 43202592, 507659742, 433077118, 1268144019, + 1462778342, 1928073362, 1330130180, 1749624351, + 827401013, 1236194147, 1875519726, 1437946791, + 607293265, 309229599, 1009445595, 1725229718, + 1436309341, 1952606463, 943149111, 291680468, + 1989684076, 1944713370, 1285294139, 399758737, + 1572979232, 213817406, 214840530, 184898060, + 1483844295, 1536616777, 494816009, 217625163, + 529448032, 786640964, 1766471731, 1424140424, + 1721961711, 740275169, 169908711, 913969302, + 1359358267, 1328322971, 593228769, 771095186, + 801680440, 450930656, 1796349530, 1824428677, + 1111258504, 1741666629, 1098430204, 1792001884, + 1679003061, 590088446, 647614538, 1324461639, + 818996796, 229187928, 74288115, 1158900266, + 1512606270, 1381672753, 785927403, 493453164, + 425259497, 1367873539, 931023744, 221202218, + 669580668, 424996238, 1840425275, 1873362670, + 967642716, 263556335, 578560519, 1558449223, + 607579284, 1724012378, 333582342, 1195784167, + 1419727276, 199294290, 138807165, 1061030752, + 1, 0, 0, 0, + 1, 0, 0, 0, + 776332180, 1333076185, 1855163818, 1897408938, + 799274251, 950452503, 691904988, 1205387466, + 659107883, 434394982, 129587940, 639018629, + 659238594, 1957584892, 864291238, 589178070, + 1267157231, 48925338, 200093884, 1953762869, + 1227617341, 1471420621, 193077633, 1007876111, + 228491220, 1377349503, 1889411060, 1807513892, + 1593042934, 1240864695, 1472870721, 583021932, + 598239104, 1862008818, 1811242869, 780768026, + 520870395, 292016292, 322246659, 868240490, + 1715620331, 1183509209, 2010262726, 1003957251, + 264895455, 307755941, 201990485, 1662471178, + 1643997923, 1573129362, 277821143, 388834470, + 943361405, 1449402196, 614413575, 1504113993, + 1860552739, 1755127315, 1734129760, 1232115188, + 803035456, 360488092, 271342171, 1269544258, + 290642673, 660703582, 986842267, 870891877, + 454573044, 1999346236, 701614601, 820253867, + 883282765, 137247873, 1727164949, 1320585493, + 1738664600, 1900116905, 472215154, 1114994489, + 104218174, 1694603079, 771486383, 935361143, + 92277671, 881040480, 925124484, 1464396527, + 100625197, 65290355, 1001454341, 134627585, + 58629702, 1541542242, 568583607, 1706262052, + 530687550, 1303187245, 1010302462, 264001857, + 789816678, 561378226, 827432508, 801307507, + 1613508315, 1650822853, 1603502703, 439320335, + 15283580, 1244486577, 254345266, 1745653280, + 1648250354, 1528271018, 528366563, 1078707735, + 1430767759, 1890467731, 2001894083, 799949326, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1341839494, 1092219735, 755644898, 966729319, + 1914277278, 1545367697, 1765189119, 1693413008, + ]; + + let logup_spec_evals: Array> = + builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&logup_spec_evals, idx, e); + } + + #[rustfmt::skip] + let r_evals_u32s = [ + 941378355u32, 1078920879, 696738840, 496039492, + 1555445457, 184545404, 905938226, 1847966044, + 1024875886, 1782716223, 1625644635, 266865456, + 465953066, 1663531470, 757423849, 1957075986, + 1919693393, 839104130, 127480221, 1527842912, + 918650796, 921462354, 575456073, 696646705, + 1585912361, 258186488, 353168830, 1111094691, + 1401166558, 1905942163, 1923083163, 393037255, + 1042127700, 1126793296, 895794165, 1124924482, + 1324266058, 722406365, 1963838171, 968504459, + 1934378800, 714588691, 6465911, 1168379648, + 903786009, 1326035939, 518289228, 418998914, + 1513133474, 1578096058, 617547414, 1658315126, + 68556894, 1697802593, 1346510664, 1709381671, + 345062962, 1254089535, 1002281845, 1882822096, + 700581748, 1431345304, 489112954, 98435728, + 1799886007, 479788390, 223111065, 631662309, + ]; + + let next_layer_evals: Array> = + builder.dyn_array(r_evals_u32s.len() / EXT_DEG); + for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&next_layer_evals, idx, e); + } + + builder.sumcheck_layer_eval( + &ctx, + &challenges, + &prod_spec_evals, + &logup_spec_evals, + &next_layer_evals, + ); + + builder.halt(); +}