diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index a7d6eee536..86d122f476 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -82,7 +82,7 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; - T is_hint_src_id; + T is_writeback; T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 139a56473f..9ad666e183 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -32,34 +32,55 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } else { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 2, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } else { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index c9bbf1279e..ea4aeb5ec1 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, + offline_checker::MemoryBridge, MemoryAddress, }, }; @@ -26,9 +26,6 @@ use crate::{ }, }; -pub const NUM_RWS_FOR_PRODUCT: usize = 2; -pub const NUM_RWS_FOR_LOGUP: usize = 3; - #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -105,7 +102,7 @@ impl Air for NativeSumcheckAir { within_round_limit, should_acc, eval_acc, - is_hint_src_id, + is_writeback, specific, } = local; @@ -235,22 +232,6 @@ impl Air for NativeSumcheckAir { next.start_timestamp, start_timestamp + AB::F::from_canonical_usize(8), ); - builder - .when(prod_row) - .when(next.prod_row + next.logup_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), - ); - builder - .when(logup_row) - .when(next.prod_row + next.logup_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), - ); // Termination condition assert_array_eq( @@ -349,7 +330,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round, is_hint_src_id], + [max_round, is_writeback], first_timestamp + AB::F::from_canonical_usize(7), &header_row_specific.read_records[7], ) @@ -392,21 +373,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - // Read p1, p2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - prod_row_specific.p, - start_timestamp, - &MemoryReadAuxCols { - base: prod_row_specific.ps_record.base, - }, - ) - .eval( - builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2 from hint space and write back to witness arrays self.memory_bridge .write( @@ -417,7 +383,7 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id, + (prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback, ); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); @@ -432,7 +398,7 @@ impl Air for NativeSumcheckAir { register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); @@ -499,21 +465,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - // Read p1, p2, q1, q2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - logup_row_specific.pq, - start_timestamp, - &MemoryReadAuxCols { - base: logup_row_specific.pqs_record.base, - }, - ) - .eval( - builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2, q1, q2 from hint space self.memory_bridge .write( @@ -524,7 +475,7 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id, + (logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback, ); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] @@ -546,7 +497,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -561,7 +512,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::TWO, + start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d4fbf2524d..9efaf62a48 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -207,7 +207,7 @@ where challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - let [max_round, is_hint_src_id]: [F; 2] = tracing_read_native_helper( + let [max_round, is_writeback]: [F; 2] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, head_specific.read_records[7].as_mut(), @@ -242,21 +242,13 @@ where row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; - row.is_hint_src_id = is_hint_src_id; + row.is_writeback = is_writeback; } - // Load hints if source is a ptr - let is_hint_src_id = is_hint_src_id > F::ZERO; let prod_evals_id = prod_evals_id.as_canonical_u32(); let logup_evals_id = logup_evals_id.as_canonical_u32(); - let (prod_evals, logup_evals) = if is_hint_src_id { - ( - state.streams.hint_space[prod_evals_id as usize].clone(), - state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); // product rows for (i, prod_row) in rows @@ -292,24 +284,16 @@ where prod_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2 - let ps: [F; EXT_DEG * 2] = if is_hint_src_id { - prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.ps_record.as_mut(), - ) - }; + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); prod_specific.p = ps; // If p values come from the hint stream, write back to the actual witness array - if is_hint_src_id { + if is_writeback > F::ZERO { tracing_write_native_inplace( state.memory, prod_evals_ptr.as_canonical_u32() + start, @@ -346,7 +330,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) + cur_timestamp += if is_writeback > F::ZERO { 2 } else { 1 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -394,17 +378,9 @@ where logup_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.pqs_record.as_mut(), - ) - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -413,7 +389,7 @@ where logup_specific.pq = pqs; // write pqs - if is_hint_src_id { + if is_writeback > F::ZERO { tracing_write_native_inplace( state.memory, logup_evals_ptr.as_canonical_u32() + start, @@ -472,7 +448,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes (witness array case) or 3 writes (hint space ptr case) + cur_timestamp += if is_writeback > F::ZERO { 3 } else { 2 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -568,42 +544,66 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.ps_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - prod_row_specific.write_record.as_mut(), - ); + if cols.is_writeback == F::ONE { + // writeback p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.ps_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.write_record.as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.write_record.as_mut(), + ); + } } } else if cols.logup_row == F::ONE { let logup_row_specific: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.pqs_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - logup_row_specific.write_records[0].as_mut(), - ); - // write q_eval - mem_fill_helper( - mem_helper, - start_timestamp + 2, - logup_row_specific.write_records[1].as_mut(), - ); + if cols.is_writeback == F::ONE { + // writeback p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.pqs_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[1].as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[1].as_mut(), + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index f02f154cf2..eeb81df134 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -73,8 +73,8 @@ pub struct NativeSumcheckCols { // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], - // Indicator for an alternative source form of the inputs prod_evals/logup_evals - pub is_hint_src_id: T, + // Indicate whether the values read from hint slices should be written back to a witness array + pub is_writeback: T, // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a311e634b8..7cbd1fe319 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -217,7 +217,7 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round, is_hint_space_ids]: [u32; 2] = exec_state + let [max_round, is_writeback]: [u32; 2] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32) .map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -228,14 +228,8 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { - ( - exec_state.streams.hint_space[prod_evals_id as usize].clone(), - exec_state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( @@ -247,16 +241,11 @@ unsafe fn execute_e12_impl( ); if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = if is_hint_space_ids > 0 { - prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 2 }>(NATIVE_AS, prod_evals_ptr + start).try_into().unwrap() - }; - + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, prod_evals_ptr + start, &ps); } @@ -297,17 +286,13 @@ unsafe fn execute_e12_impl( if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = if is_hint_space_ids > 0 { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 4 }>(NATIVE_AS, logup_evals_ptr + start).try_into().unwrap() - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, logup_evals_ptr + start, &pqs); }