Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@ pub struct TableCounts {
}

impl TableCounts {
/// Validate that all required tables have at least one chunk.
///
/// A zero count for any table would remove its constraints from verification,
/// allowing a malicious prover to bypass soundness checks.
/// Sum of all chunk counts across split tables.
/// Sum of all chunk counts across the split tables.
pub fn total(&self) -> usize {
self.cpu
+ self.lt
Expand Down Expand Up @@ -454,10 +450,8 @@ pub(crate) fn replay_transcript_phase_a(
for (air, proof) in airs.iter().zip(&multi_proof.proofs) {
if air.is_preprocessed() {
transcript.append_bytes(&air.precomputed_commitment());
transcript.append_bytes(&proof.lde_trace_main_merkle_root);
} else {
transcript.append_bytes(&proof.lde_trace_main_merkle_root);
}
transcript.append_bytes(&proof.lde_trace_main_merkle_root);
}
let z: FieldElement<E> = transcript.sample_field_element();
let alpha: FieldElement<E> = transcript.sample_field_element();
Expand Down Expand Up @@ -486,15 +480,27 @@ pub(crate) fn compute_commit_bus_offset(
let bus_id = FieldElement::<E>::from(BusId::Commit as u64);
let alpha_sq = alpha * alpha;

let mut total = FieldElement::<E>::zero();
for (i, &value) in public_output.iter().enumerate() {
let linear_combination = bus_id
+ (FieldElement::<E>::from(i as u64) * alpha)
+ (FieldElement::<E>::from(value as u64) * alpha_sq);
let fingerprint = z - linear_combination;
total += fingerprint.inv().ok()?;
}
Some(total)
// fingerprint_i = z - (BusId::Commit + i·α + value_i·α²)
let mut fingerprints: Vec<FieldElement<E>> = public_output
.iter()
.enumerate()
.map(|(i, &value)| {
let linear_combination = bus_id
+ (FieldElement::<E>::from(i as u64) * alpha)
+ (FieldElement::<E>::from(value as u64) * alpha_sq);
z - linear_combination
})
.collect();

// Batch inversion: 1 inversion + O(3N) muls instead of N field inversions.
// `Err` iff some fingerprint is zero (a collision) — treat as failure.
FieldElement::inplace_batch_inverse(&mut fingerprints).ok()?;

Some(
fingerprints
.iter()
.fold(FieldElement::<E>::zero(), |acc, term| acc + term),
)
}

/// Compute the expected COMMIT bus balance for a `MultiProof`.
Expand Down
100 changes: 100 additions & 0 deletions prover/src/tests/compute_commit_bus_offset_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! Unit tests for `compute_commit_bus_offset`.
//!
//! Pins the three behaviours the verify-path helper must preserve:
//! empty input short-circuit, success-path equivalence with a naive
//! per-element-inverse reference, and the zero-fingerprint failure path.

use math::field::element::FieldElement;

use crate::compute_commit_bus_offset;
use crate::tables::types::{BusId, GoldilocksExtension};

type E = GoldilocksExtension;

/// Reference implementation: one `inv()` per fingerprint, then sum.
/// Mirrors the original loop bit-for-bit modulo addition order, so any
/// future refactor of the batched routine must remain equivalent to this.
fn naive_offset(
public_output: &[u8],
z: &FieldElement<E>,
alpha: &FieldElement<E>,
) -> Option<FieldElement<E>> {
let bus_id = FieldElement::<E>::from(BusId::Commit as u64);
let alpha_sq = alpha * alpha;
let mut total = FieldElement::<E>::zero();
for (i, &value) in public_output.iter().enumerate() {
let lc = bus_id
+ (FieldElement::<E>::from(i as u64) * alpha)
+ (FieldElement::<E>::from(value as u64) * alpha_sq);
let fingerprint = z - lc;
total += fingerprint.inv().ok()?;
}
Some(total)
}

#[test]
fn test_empty_public_output_returns_zero() {
let z = FieldElement::<E>::from(7u64);
let alpha = FieldElement::<E>::from(11u64);
assert_eq!(
compute_commit_bus_offset(&[], &z, &alpha),
Some(FieldElement::<E>::zero()),
);
}

#[test]
fn test_non_empty_matches_naive_per_element_inverse() {
let z = FieldElement::<E>::from(987_654_321u64);
let alpha = FieldElement::<E>::from(31_415_926u64);
let public_output: [u8; 5] = [0x01, 0x02, 0xff, 0x10, 0x80];

let batched = compute_commit_bus_offset(&public_output, &z, &alpha);
let naive = naive_offset(&public_output, &z, &alpha);

assert_eq!(batched, naive);
assert!(batched.is_some(), "no fingerprint should collide here");
}

#[test]
fn test_longer_input_matches_naive() {
let z = FieldElement::<E>::from(0xdead_beefu64);
let alpha = FieldElement::<E>::from(0xcafe_babeu64);
let public_output: Vec<u8> = (0..=255u16).map(|x| x as u8).collect();

let batched = compute_commit_bus_offset(&public_output, &z, &alpha);
let naive = naive_offset(&public_output, &z, &alpha);

assert_eq!(batched, naive);
assert!(batched.is_some());
}

#[test]
fn test_zero_fingerprint_returns_none() {
// Craft fingerprint_0 = 0: i = 0, value = 0, then
// fingerprint_0 = z - (BusId::Commit + 0·α + 0·α²) = z - BusId::Commit.
// Setting z = BusId::Commit forces the collision regardless of alpha.
let z = FieldElement::<E>::from(BusId::Commit as u64);
let alpha = FieldElement::<E>::from(42u64);
let public_output: [u8; 1] = [0];

assert_eq!(
compute_commit_bus_offset(&public_output, &z, &alpha),
None,
"zero fingerprint must propagate as None",
);
}

#[test]
fn test_zero_fingerprint_in_middle_returns_none() {
// Same idea at i = 2, so some valid fingerprints precede the zero one.
let alpha = FieldElement::<E>::from(5u64);
let alpha_sq = alpha * alpha;
let bus_id = FieldElement::<E>::from(BusId::Commit as u64);
// value = 3 at index 2 → z = BusId + 2α + 3α² forces fingerprint_2 = 0.
let z = bus_id
+ (FieldElement::<E>::from(2u64) * alpha)
+ (FieldElement::<E>::from(3u64) * alpha_sq);
let public_output: [u8; 4] = [1, 2, 3, 4];

assert_eq!(compute_commit_bus_offset(&public_output, &z, &alpha), None,);
}
2 changes: 2 additions & 0 deletions prover/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub mod branch_constraints_tests;
#[cfg(test)]
pub mod commit_tests;
#[cfg(test)]
pub mod compute_commit_bus_offset_tests;
#[cfg(test)]
pub mod constraints_tests;
#[cfg(all(test, feature = "disk-spill"))]
pub mod count_table_lengths_drift_tests;
Expand Down
Loading