Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 86 additions & 53 deletions crypto/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;
use std::time::Instant;

use crypto::fiat_shamir::is_transcript::IsStarkTranscript;
use math::fft::cpu::bit_reversing::reverse_index;
use math::fft::cpu::bit_reversing::{in_place_bit_reverse_permute, reverse_index};
use math::fft::cpu::bowers_fft::LayerTwiddles;
use math::fft::errors::FFTError;

Expand Down Expand Up @@ -1032,12 +1032,18 @@ pub trait IsStarkProver<
// === Trace polynomials: barycentric evaluation via LDE ===
// Uses get_trace_evaluations_from_lde which performs barycentric interpolation
// on the LDE trace data, avoiding the need for coefficient-form trace_polys.
// Reuses coset_points, coset_offset_pow_n, domain_size_inv, g_n_inv already
// computed above for composition poly evaluation — avoids redundant work.
let trace_ood_evaluations = crate::trace::get_trace_evaluations_from_lde(
&round_1_result.lde_trace,
domain,
z,
&air.context().transition_offsets,
air.step_size(),
&coset_points,
&coset_offset_pow_n,
&domain_size_inv,
&g_n_inv,
);

Round3 {
Expand Down Expand Up @@ -1101,24 +1107,17 @@ pub trait IsStarkProver<
#[cfg(feature = "instruments")]
let other_dur_1 = t_sub.elapsed();

// Extend N trace-coset evaluations to 2N LDE-coset evaluations via standard LDE.
// deep_evals[i] = h(offset·ω_N^i) = f(ω_N^i) where f(x) = h(offset·x).
// Standard iFFT+FFT recovers f and evaluates on the 2N-th roots: f(Ω^j) = h(offset·Ω^j).
// DEEP evaluations are already at 2N LDE points — just bit-reverse for FRI.
// No iFFT+FFT extension needed (Plonky3-style direct LDE computation).
let domain_size = domain.lde_roots_of_unity_coset.len();
#[cfg(feature = "instruments")]
let t_sub = Instant::now();
let deep_poly =
Polynomial::interpolate_fft::<Field>(&deep_evals).expect("iFFT should succeed");
// FRI commit_phase consumes bit-reversed evaluations natively. Request them
// directly from evaluate_fft_bit_reversed to avoid a pair of redundant permutes
// (evaluate_fft's internal natural-order permute + an external re-bit-reverse).
let lde_evals =
Polynomial::evaluate_fft_bit_reversed::<Field>(&deep_poly, 1, Some(domain_size))
.expect("FFT should succeed");
let mut lde_evals = deep_evals;
in_place_bit_reverse_permute(&mut lde_evals);
#[cfg(feature = "instruments")]
let r4_fft_dur = t_sub.elapsed();

// FRI commit phase from pre-computed evaluations (no initial FFT)
// FRI commit phase from pre-computed evaluations
#[cfg(feature = "instruments")]
let t_sub = Instant::now();
let (fri_last_value, fri_layers) =
Expand Down Expand Up @@ -1183,10 +1182,11 @@ pub trait IsStarkProver<
.collect::<Vec<usize>>()
}

/// Computes the DEEP composition polynomial as evaluations on the trace-size coset.
/// Computes the DEEP composition polynomial at all 2N LDE points (Plonky3-style).
///
/// Evaluates `deep(x_i)` at N points (every bf-th point of the LDE coset).
/// The caller extends to the full 2N-point LDE domain before feeding to FRI.
/// Evaluates directly on the full LDE domain, eliminating the iFFT(N)+FFT(2N)
/// extension that was needed when computing at only N trace-coset points.
/// The result is ready for FRI after bit-reversal — no FFT needed.
///
/// The DEEP polynomial is:
/// deep(X) = Σ_j γ_j * (H_j(X) - H_j(z^K)) / (X - z^K)
Expand All @@ -1206,8 +1206,6 @@ pub trait IsStarkProver<
FieldElement<Field>: AsBytes,
FieldElement<FieldExtension>: AsBytes,
{
let domain_size = domain.interpolation_domain_size;
let blowup_factor = domain.blowup_factor;
let num_parts = round_2_result.lde_composition_poly_evaluations.len();
let z_power = z.pow(num_parts); // pole for H terms

Expand All @@ -1230,68 +1228,103 @@ pub trait IsStarkProver<
let num_main_cols = lde_trace.num_main_cols();
let num_aux_cols = lde_trace.num_aux_cols();

// Precompute all inverse denominators via batch inversion.
let num_denoms = domain_size * (1 + num_eval_points);
// Precompute all inverse denominators at ALL LDE points via batch inversion.
let lde_size = domain.lde_roots_of_unity_coset.len();
let num_denoms = lde_size * (1 + num_eval_points);
let mut denoms: Vec<FieldElement<FieldExtension>> = Vec::with_capacity(num_denoms);

// H-term denominators: x_i - z^K
for i in 0..domain_size {
let x_i = &domain.lde_roots_of_unity_coset[i * blowup_factor];
// H-term denominators: x_i - z^K (all 2N LDE points)
for i in 0..lde_size {
let x_i = &domain.lde_roots_of_unity_coset[i];
denoms.push(x_i - &z_power);
}

// Trace-term denominators: x_i - z_shifted[k]
// Trace-term denominators: x_i - z_shifted[k] (all 2N LDE points)
for z_k in z_shifted.iter().take(num_eval_points) {
for i in 0..domain_size {
let x_i = &domain.lde_roots_of_unity_coset[i * blowup_factor];
for i in 0..lde_size {
let x_i = &domain.lde_roots_of_unity_coset[i];
denoms.push(x_i - z_k);
}
}

FieldElement::inplace_batch_inverse(&mut denoms)
.expect("Denominators should be non-zero: coset points are base field, poles are extension field");

let inv_h = &denoms[0..domain_size];
let inv_h = &denoms[0..lde_size];

// OOD evaluations
let h_ood = &round_3_result.composition_poly_parts_ood_evaluation;
let trace_ood_columns = round_3_result.trace_ood_evaluations.columns();
let num_total_cols = num_main_cols + num_aux_cols;

// === Phase 1: Column compression (Plonky3-style) ===
// Instead of iterating all ~95 columns per row in the hot loop, we precompute:
// compressed_k[i] = Σ_j gamma[j][k] * lde_trace.get_main(i, j) for i in 0..lde_size
// ood_compressed_k = Σ_j gamma[j][k] * ood[j][k]
// This moves the column sum outside the hot loop. Since the new path evaluates
// DEEP directly at all 2N LDE points, no stride is needed — every row is used.

// Precompute OOD compressed values (one per eval point)
let mut ood_compressed: Vec<FieldElement<FieldExtension>> =
vec![FieldElement::zero(); num_eval_points];
for j in 0..num_total_cols {
let ood_evals_j = &trace_ood_columns[j];
let gammas_j = &trace_terms_gammas[j];
for k in 0..num_eval_points {
ood_compressed[k] += &gammas_j[k] * &ood_evals_j[k];
}
}

// Compressed traces at ALL 2N LDE points (Plonky3-style).
// Eliminates the iFFT(N)+FFT(2N) extension by computing directly at LDE size.
let compressed: Vec<Vec<FieldElement<FieldExtension>>> = (0..num_eval_points)
.map(|k| {
let main_gammas: Vec<&FieldElement<FieldExtension>> = (0..num_main_cols)
.map(|j| &trace_terms_gammas[j][k])
.collect();
let aux_gammas: Vec<&FieldElement<FieldExtension>> = (0..num_aux_cols)
.map(|j| &trace_terms_gammas[num_main_cols + j][k])
.collect();

#[cfg(feature = "parallel")]
let iter = (0..lde_size).into_par_iter();
#[cfg(not(feature = "parallel"))]
let iter = 0..lde_size;

iter.map(|i| {
let mut sum = FieldElement::<FieldExtension>::zero();
for (j, gamma) in main_gammas.iter().enumerate() {
sum += lde_trace.get_main(i, j) * *gamma;
}
for (j, gamma) in aux_gammas.iter().enumerate() {
sum += lde_trace.get_aux(i, j) * *gamma;
}
sum
})
.collect()
})
.collect();

// Compute deep(x_i) for each trace-size coset point
// Hot loop at all 2N LDE points — no FFT extension needed.
#[cfg(feature = "parallel")]
let iter = (0..domain_size).into_par_iter();
let iter = (0..lde_size).into_par_iter();
#[cfg(not(feature = "parallel"))]
let iter = 0..domain_size;
let iter = 0..lde_size;

iter.map(|i| {
let row_idx = i * blowup_factor; // LDE row index

// H terms: Σ_j γ_j * (H_j(x_i) - H_j(z^K)) * inv_h[i]
let mut result = FieldElement::<FieldExtension>::zero();

// H terms
for j in 0..num_parts {
let h_j_val = &round_2_result.lde_composition_poly_evaluations[j][row_idx];
let h_j_val = &round_2_result.lde_composition_poly_evaluations[j][i];
let h_j_ood = &h_ood[j];
let numerator = h_j_val - h_j_ood;
result += &composition_poly_gammas[j] * numerator * &inv_h[i];
result += &composition_poly_gammas[j] * (h_j_val - h_j_ood) * &inv_h[i];
}

// Trace terms: Σ_{j,k} γ'_{j,k} * (t_j(x_i) - t_j(z·w^k)) * inv_t_k[i]
let num_total_cols = num_main_cols + num_aux_cols;
for j in 0..num_total_cols {
let gammas_j = &trace_terms_gammas[j];
let ood_evals_j = &trace_ood_columns[j];

for k in 0..num_eval_points {
let inv_t_k_i = &denoms[(1 + k) * domain_size + i];

let t_j_ood = &ood_evals_j[k];
let numerator: FieldElement<FieldExtension> = if j < num_main_cols {
lde_trace.get_main(row_idx, j) - t_j_ood
} else {
lde_trace.get_aux(row_idx, j - num_main_cols) - t_j_ood
};
result += &gammas_j[k] * numerator * inv_t_k_i;
}
// Trace terms (compressed)
for k in 0..num_eval_points {
let inv_t_k_i = &denoms[(1 + k) * lde_size + i];
result += inv_t_k_i * (&compressed[k][i] - &ood_compressed[k]);
}

result
Expand Down
153 changes: 152 additions & 1 deletion crypto/stark/src/tests/prover_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,28 @@ fn barycentric_trace_eval_matches_horner_trace_eval() {
step_size,
);

// Precompute shared barycentric scalars
let n = domain.interpolation_domain_size;
let bf = domain.blowup_factor;
let coset_points: Vec<Felt> = (0..n)
.map(|i| domain.lde_roots_of_unity_coset[i * bf])
.collect();
let coset_offset_pow_n: Felt = domain.coset_offset.pow(n);
let n_inv: Felt = Felt::from(n as u64).inv().expect("n is a power of two");
let g_n_inv: Felt = coset_offset_pow_n.inv().expect("non-zero");

// Barycentric evaluation (new path)
let result = get_trace_evaluations_from_lde(&lde_trace, &domain, &z, &frame_offsets, step_size);
let result = get_trace_evaluations_from_lde(
&lde_trace,
&domain,
&z,
&frame_offsets,
step_size,
&coset_points,
&coset_offset_pow_n,
&n_inv,
&g_n_inv,
);

assert_eq!(result.width, expected.width);
assert_eq!(result.height, expected.height);
Expand Down Expand Up @@ -384,3 +404,134 @@ fn test_multi_prove_dedups_shared_domain_params() {
"verification should succeed when AIRs share all domain parameters"
);
}

/// Differential check: evaluating the DEEP composition polynomial directly on
/// the full LDE coset must agree with the "evaluate on N points, then extend"
/// route.
///
/// `compute_deep_composition_poly_evaluations` produces `deep(x)` at all 2N LDE
/// points. The alternative route evaluates `deep` only on the N-point trace
/// coset and then extends to 2N points with an iFFT of size N followed by an
/// FFT of size 2N.
///
/// The two routes are interchangeable because every quotient in `deep(X)` has a
/// numerator that vanishes at the root of its denominator, so the poles cancel
/// and `deep(X)` is an honest polynomial of degree < N. Such a polynomial is
/// uniquely determined by its values on any N distinct points, hence
/// interpolating from the trace coset and re-evaluating on the LDE coset has to
/// reproduce the directly computed values.
///
/// The scenario is fully synthetic: trace polynomials, composition-polynomial
/// parts, OOD values and gammas are all deterministic, and `deep(x)` is computed
/// by a naive reference closure rather than by the optimized prover code.
#[test]
fn test_deep_poly_direct_2n_matches_interpolate_fft_extend() {
    let n = 16usize;
    let blowup_factor = 2usize;
    let lde_len = n * blowup_factor;

    let options = ProofOptions {
        blowup_factor: blowup_factor as u8,
        fri_number_of_queries: 1,
        coset_offset: 3,
        grinding_factor: 0,
    };
    let air = QuadraticAIR::<GoldilocksField>::new(&options);
    let domain = Domain::new(&air, n);

    // Two trace columns, each given by N deterministic coefficients (degree < N).
    let num_trace_cols = 2usize;
    let mut trace_polys: Vec<Polynomial<Felt>> = Vec::with_capacity(num_trace_cols);
    for col in 0..num_trace_cols {
        let coefficients: Vec<Felt> = (0..n)
            .map(|i| Felt::from(((i + 1) * (col + 2) * 11 + 7) as u64))
            .collect();
        trace_polys.push(Polynomial::new(&coefficients));
    }

    // Two composition-polynomial parts, again with N coefficients each.
    let num_parts = 2usize;
    let mut composition_parts: Vec<Polynomial<Felt>> = Vec::with_capacity(num_parts);
    for part in 0..num_parts {
        let coefficients: Vec<Felt> = (0..n)
            .map(|i| Felt::from(((i + 3) * (part + 5) * 19 + 31) as u64))
            .collect();
        composition_parts.push(Polynomial::new(&coefficients));
    }

    // Out-of-domain point z, the pole z^K of the H terms, and the poles z·w^k
    // of the trace terms.
    let z = Felt::from(12345u64);
    let h_pole = z.pow(num_parts);

    let num_eval_points = 2usize;
    let mut trace_poles: Vec<Felt> = Vec::with_capacity(num_eval_points);
    for k in 0..num_eval_points {
        trace_poles.push(domain.trace_primitive_root.pow(k) * &z);
    }

    // OOD values: H_j(z^K) for every part, t_j(z·w^k) for every column/pole pair.
    let h_at_pole: Vec<Felt> = composition_parts
        .iter()
        .map(|part| part.evaluate(&h_pole))
        .collect();
    let trace_at_poles: Vec<Vec<Felt>> = trace_polys
        .iter()
        .map(|poly| trace_poles.iter().map(|pole| poly.evaluate(pole)).collect())
        .collect();

    // Deterministic stand-ins for the verifier's random batching coefficients.
    let h_gammas: Vec<Felt> = (0..num_parts)
        .map(|j| Felt::from((j as u64 + 1) * 100))
        .collect();
    let trace_gammas: Vec<Vec<Felt>> = (0..num_trace_cols)
        .map(|j| {
            (0..num_eval_points)
                .map(|k| Felt::from((((j + 1) * (k + 1)) as u64) * 200))
                .collect()
        })
        .collect();

    // Naive reference implementation of deep(x) at a single point: the same
    // formula as the prover, one quotient at a time, with no batching tricks.
    let compute_deep = |x: &Felt| -> Felt {
        let mut acc = Felt::zero();

        // H terms: γ_j · (H_j(x) − H_j(z^K)) / (x − z^K)
        for ((part, ood_value), gamma) in composition_parts.iter().zip(&h_at_pole).zip(&h_gammas)
        {
            let numerator = part.evaluate(x) - ood_value;
            let inv_denominator = (x - &h_pole).inv().expect("z^K not on coset");
            acc += gamma * &numerator * &inv_denominator;
        }

        // Trace terms: γ'_{j,k} · (t_j(x) − t_j(z·w^k)) / (x − z·w^k)
        for ((poly, ood_row), gamma_row) in
            trace_polys.iter().zip(&trace_at_poles).zip(&trace_gammas)
        {
            for ((pole, ood_value), gamma) in trace_poles.iter().zip(ood_row).zip(gamma_row) {
                let numerator = poly.evaluate(x) - ood_value;
                let inv_denominator = (x - pole).inv().expect("z·w^k not on coset");
                acc += gamma * &numerator * &inv_denominator;
            }
        }

        acc
    };

    // Route A: evaluate deep(x) at every one of the 2N LDE coset points.
    let direct_evals: Vec<Felt> = domain
        .lde_roots_of_unity_coset
        .iter()
        .map(compute_deep)
        .collect();

    // Route B: evaluate only on the trace coset {g·ω^i} = lde_coset[i·bf],
    // recover the coefficients with an offset iFFT, then re-evaluate on the
    // whole LDE coset.
    let trace_coset_evals: Vec<Felt> = (0..n)
        .map(|i| compute_deep(&domain.lde_roots_of_unity_coset[i * blowup_factor]))
        .collect();
    let deep_poly = Polynomial::interpolate_offset_fft(&trace_coset_evals, &domain.coset_offset)
        .expect("interpolation should succeed on trace-coset evaluations");
    let extended_evals =
        evaluate_polynomial_on_lde_domain(&deep_poly, blowup_factor, n, &domain.coset_offset)
            .expect("LDE extension should succeed");

    assert_eq!(direct_evals.len(), lde_len);
    assert_eq!(extended_evals.len(), lde_len);
    for (i, (direct, extended)) in direct_evals.iter().zip(&extended_evals).enumerate() {
        assert_eq!(
            direct, extended,
            "deep evaluation mismatch at LDE index {i}: direct-2N path diverges from \
             iFFT+FFT-extended path"
        );
    }
}
Loading
Loading