Skip to content

Commit

Permalink
Batched PolyOps interface
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Apr 16, 2024
1 parent a5e69d9 commit 8669469
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 199 deletions.
10 changes: 8 additions & 2 deletions crates/prover/benches/eval_at_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {
let point = CirclePoint { x, y };
c.bench_function("cpu eval_at_secure_field_point 2^20", |b| {
b.iter(|| {
black_box(<CPUBackend as PolyOps>::eval_at_point(&poly, point));
black_box(<CPUBackend as PolyOps>::eval_at_points(
&[&poly],
&[vec![point]],
));
})
});
}
Expand Down Expand Up @@ -79,7 +82,10 @@ pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
let point = CirclePoint { x, y };
c.bench_function("avx eval_at_secure_field_point 2^20", |b| {
b.iter(|| {
black_box(<AVX512Backend as PolyOps>::eval_at_point(&poly, point));
black_box(<AVX512Backend as PolyOps>::eval_at_points(
&[&poly],
&[vec![point]],
));
})
});
}
Expand Down
240 changes: 134 additions & 106 deletions crates/prover/src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter::zip;

use bytemuck::{cast_slice, Zeroable};
use num_traits::One;

Expand Down Expand Up @@ -115,57 +117,11 @@ impl AVX512Backend {
fn advance_twiddle<F: Field>(twiddle: F, steps: &[F], curr_idx: usize) -> F {
twiddle * steps[curr_idx.trailing_ones() as usize]
}
}

// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for AVX512Backend {
// The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation
// requries one of the numbers to be shifted left by 1 bit. This is not a reduced
// representation of the field.
type Twiddles = Vec<i32>;

fn new_canonical_ordered(
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
CircleEvaluation::new(
eval.domain,
Col::<AVX512Backend, BaseField>::from_iter(eval.values),
)
}

fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
twiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
let mut values = eval.values;
let log_size = values.length.ilog2();

let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);

// Safe because [PackedBaseField] is aligned on 64 bytes.
unsafe {
ifft::ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddles,
log_size as usize,
);
}

// TODO(spapini): Fuse this multiplication / rotation.
let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse();
let inv = PackedBaseField::from_array([inv; 16]);
for x in values.data.iter_mut() {
*x *= inv;
}

CirclePoly::new(values)
}

fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
fn eval_at_point(
poly: &CirclePoly<AVX512Backend>,
point: CirclePoint<SecureField>,
) -> SecureField {
// If the polynomial is small, fallback to evaluate directly.
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
if poly.log_size() <= 8 {
Expand Down Expand Up @@ -212,70 +168,142 @@ impl PolyOps for AVX512Backend {

(sum * twiddle_lows).pointwise_sum()
}
}

// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for AVX512Backend {
// The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation
// requries one of the numbers to be shifted left by 1 bit. This is not a reduced
// representation of the field.
type Twiddles = Vec<i32>;

fn new_canonical_ordered(
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Optimize.
let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values));
CircleEvaluation::new(
eval.domain,
Col::<AVX512Backend, BaseField>::from_iter(eval.values),
)
}

fn interpolate_batch(
evals: Vec<CircleEvaluation<Self, BaseField, BitReversedOrder>>,
twiddles: &TwiddleTree<Self>,
) -> Vec<CirclePoly<Self>> {
evals
.into_iter()
.map(|eval| {
let mut values = eval.values;
let log_size = values.length.ilog2();

let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);

// Safe because [PackedBaseField] is aligned on 64 bytes.
unsafe {
ifft::ifft(
std::mem::transmute(values.data.as_mut_ptr()),
&twiddles,
log_size as usize,
);
}

// TODO(spapini): Fuse this multiplication / rotation.
let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse();
let inv = PackedBaseField::from_array([inv; 16]);
for x in values.data.iter_mut() {
*x *= inv;
}

CirclePoly::new(values)
})
.collect()
}

fn eval_at_points(
poly: &[&CirclePoly<Self>],
points: &[Vec<CirclePoint<SecureField>>],
) -> Vec<Vec<SecureField>> {
zip(poly, points)
.map(|(poly, points)| {
points
.iter()
.map(|point| Self::eval_at_point(poly, *point))
.collect()
})
.collect()
}

fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
// TODO(spapini): Optimize or get rid of extend.
poly.evaluate(CanonicCoset::new(log_size).circle_domain())
.interpolate()
}

fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
fn evaluate_batch(
polys: &[&CirclePoly<Self>],
domains: &[CircleDomain],
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
let log_size = domain.log_size() as usize;
let fft_log_size = poly.log_size() as usize;
assert!(
log_size >= fft_log_size,
"Can only evaluate on larger domains"
);

let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);

// Evaluate on a big domains by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;

// Alllocate the destination buffer without initializing.
let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE);
#[allow(clippy::uninit_vec)]
unsafe {
values.set_len(domain.size() >> VECS_LOG_SIZE)
};

for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (0..(fft_log_size - 1))
.map(|layer_i| {
&twiddles[layer_i]
[i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)]
})
.collect::<Vec<_>>();

// FFT from the coefficients buffer to the values chunk.
unsafe {
rfft::fft(
std::mem::transmute(poly.coeffs.data.as_ptr()),
std::mem::transmute(
values[i << (fft_log_size - VECS_LOG_SIZE)
..(i + 1) << (fft_log_size - VECS_LOG_SIZE)]
.as_mut_ptr(),
),
&subdomain_twiddles,
fft_log_size,
) -> Vec<CircleEvaluation<Self, BaseField, BitReversedOrder>> {
zip(polys, domains)
.map(|(poly, domain)| {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
let log_size = domain.log_size() as usize;
let fft_log_size = poly.log_size() as usize;
assert!(
log_size >= fft_log_size,
"Can only evaluate on larger domains"
);
}
}

CircleEvaluation::new(
domain,
BaseFieldVec {
data: values,
length: domain.size(),
},
)
let twiddles = domain_line_twiddles_from_tree(*domain, &twiddles.twiddles);

// Evaluate on a big domains by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;

// Alllocate the destination buffer without initializing.
let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE);
#[allow(clippy::uninit_vec)]
unsafe {
values.set_len(domain.size() >> VECS_LOG_SIZE)
};

for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (0..(fft_log_size - 1))
.map(|layer_i| {
&twiddles[layer_i][i << (fft_log_size - 2 - layer_i)
..(i + 1) << (fft_log_size - 2 - layer_i)]
})
.collect::<Vec<_>>();

// FFT from the coefficients buffer to the values chunk.
unsafe {
rfft::fft(
std::mem::transmute(poly.coeffs.data.as_ptr()),
std::mem::transmute(
values[i << (fft_log_size - VECS_LOG_SIZE)
..(i + 1) << (fft_log_size - VECS_LOG_SIZE)]
.as_mut_ptr(),
),
&subdomain_twiddles,
fft_log_size,
);
}
}

CircleEvaluation::new(
*domain,
BaseFieldVec {
data: values,
length: domain.size(),
},
)
})
.collect()
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
Expand Down Expand Up @@ -454,14 +482,14 @@ mod tests {
let p = CirclePoint { x, y };

assert_eq!(
<AVX512Backend as PolyOps>::eval_at_point(&poly, p),
<AVX512Backend as PolyOps>::eval_at_points(&[&poly], &[vec![p]])[0][0],
slow_eval_at_point(&poly, p),
"log_size = {log_size}"
);

println!(
"log_size = {log_size} passed, eval{}",
<AVX512Backend as PolyOps>::eval_at_point(&poly, p)
<AVX512Backend as PolyOps>::eval_at_points(&[&poly], &[vec![p]])[0][0]
);
}
}
Expand Down
Loading

0 comments on commit 8669469

Please sign in to comment.