Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extensions/native/circuit/cuda/include/native/sumcheck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ template <typename T> struct NativeSumcheckCols {

T eval_acc[EXT_DEG];

T is_hint_src_id;
T is_writeback;

T specific[COL_SPECIFIC_WIDTH];
};
Expand Down
71 changes: 46 additions & 25 deletions extensions/native/circuit/cuda/src/sumcheck.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,55 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
);
} else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) {
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
);
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
);
} else {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
);
}
}
} else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) {
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 2,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
);
if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 2,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
);
} else {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
);
}
}
}
}
Expand Down
65 changes: 8 additions & 57 deletions extensions/native/circuit/src/sumcheck/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::borrow::Borrow;
use openvm_circuit::{
arch::{ExecutionBridge, ExecutionState},
system::memory::{
offline_checker::{MemoryBridge, MemoryReadAuxCols},
offline_checker::MemoryBridge,
MemoryAddress,
},
};
Expand All @@ -26,9 +26,6 @@ use crate::{
},
};

pub const NUM_RWS_FOR_PRODUCT: usize = 2;
pub const NUM_RWS_FOR_LOGUP: usize = 3;

#[derive(Clone, Debug)]
pub struct NativeSumcheckAir {
pub execution_bridge: ExecutionBridge,
Expand Down Expand Up @@ -105,7 +102,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
within_round_limit,
should_acc,
eval_acc,
is_hint_src_id,
is_writeback,
specific,
} = local;

Expand Down Expand Up @@ -235,22 +232,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
next.start_timestamp,
start_timestamp + AB::F::from_canonical_usize(8),
);
builder
.when(prod_row)
.when(next.prod_row + next.logup_row)
.assert_eq(
next.start_timestamp,
start_timestamp
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT),
);
builder
.when(logup_row)
.when(next.prod_row + next.logup_row)
.assert_eq(
next.start_timestamp,
start_timestamp
+ within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP),
);

// Termination condition
assert_array_eq(
Expand Down Expand Up @@ -349,7 +330,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
native_as,
register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN),
),
[max_round, is_hint_src_id],
[max_round, is_writeback],
first_timestamp + AB::F::from_canonical_usize(7),
&header_row_specific.read_records[7],
)
Expand Down Expand Up @@ -392,21 +373,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
);
builder.assert_eq(prod_row * should_acc, prod_acc);

// Read p1, p2 from witness arrays
self.memory_bridge
.read(
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
prod_row_specific.p,
start_timestamp,
&MemoryReadAuxCols {
base: prod_row_specific.ps_record.base,
},
)
.eval(
builder,
(prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id),
);

// Obtain p1, p2 from hint space and write back to witness arrays
self.memory_bridge
.write(
Expand All @@ -417,7 +383,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
)
.eval(
builder,
(prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id,
(prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback,
);

let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap();
Expand All @@ -432,7 +398,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
),
prod_row_specific.p_evals,
start_timestamp + AB::F::ONE,
start_timestamp + is_writeback * AB::F::ONE,
&prod_row_specific.write_record,
)
.eval(builder, prod_row * within_round_limit);
Expand Down Expand Up @@ -499,21 +465,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
);
builder.assert_eq(logup_row * should_acc, logup_acc);

// Read p1, p2, q1, q2 from witness arrays
self.memory_bridge
.read(
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
logup_row_specific.pq,
start_timestamp,
&MemoryReadAuxCols {
base: logup_row_specific.pqs_record.base,
},
)
.eval(
builder,
(logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id),
);

// Obtain p1, p2, q1, q2 from hint space
self.memory_bridge
.write(
Expand All @@ -524,7 +475,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
)
.eval(
builder,
(logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id,
(logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback,
);
let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap();
let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)]
Expand All @@ -546,7 +497,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.p_evals,
start_timestamp + AB::F::ONE,
start_timestamp + is_writeback * AB::F::ONE,
&logup_row_specific.write_records[0],
)
.eval(builder, logup_row * within_round_limit);
Expand All @@ -561,7 +512,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
* AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.q_evals,
start_timestamp + AB::F::TWO,
start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE,
&logup_row_specific.write_records[1],
)
.eval(builder, logup_row * within_round_limit);
Expand Down
Loading
Loading