Skip to content

Commit

Permalink
Combine HAL kernels (#889)
Browse files Browse the repository at this point in the history
This combines the `batch_expand` and `batch_evaluate_ntt` kernels into a
single HAL function. They are always called together, and treating them
as a single unit allows for better GPU optimization.
  • Loading branch information
flaub committed Sep 19, 2023
1 parent b7ac357 commit c623f4c
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 163 deletions.
53 changes: 29 additions & 24 deletions risc0/zkp/src/hal/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,36 +295,41 @@ impl<F: Field> Hal for CpuHal<F> {
}

#[tracing::instrument(skip_all)]
fn batch_expand(
fn batch_expand_into_evaluate_ntt(
&self,
output: &Self::Buffer<Self::Elem>,
input: &Self::Buffer<Self::Elem>,
count: usize,
expand_bits: usize,
) {
let out_size = output.size() / count;
let in_size = input.size() / count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(out_size, in_size * (1 << expand_bits));
assert_eq!(out_size * count, output.size());
assert_eq!(in_size * count, input.size());
output
.as_slice_mut()
.par_chunks_exact_mut(out_size)
.zip(input.as_slice().par_chunks_exact(in_size))
.for_each(|(output, input)| {
expand(output, input, expand_bits);
});
}
// batch_expand
{
let out_size = output.size() / count;
let in_size = input.size() / count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(out_size, in_size * (1 << expand_bits));
assert_eq!(out_size * count, output.size());
assert_eq!(in_size * count, input.size());
output
.as_slice_mut()
.par_chunks_exact_mut(out_size)
.zip(input.as_slice().par_chunks_exact(in_size))
.for_each(|(output, input)| {
expand(output, input, expand_bits);
});
}

#[tracing::instrument(skip_all)]
fn batch_evaluate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize, expand_bits: usize) {
let row_size = io.size() / count;
assert_eq!(row_size * count, io.size());
io.as_slice_mut()
.par_chunks_exact_mut(row_size)
.for_each(|row| {
evaluate_ntt::<Self::Elem, Self::Elem>(row, expand_bits);
});
// batch_evaluate_ntt
{
let row_size = output.size() / count;
assert_eq!(row_size * count, output.size());
output
.as_slice_mut()
.par_chunks_exact_mut(row_size)
.for_each(|row| {
evaluate_ntt::<Self::Elem, Self::Elem>(row, expand_bits);
});
}
}

#[tracing::instrument(skip_all)]
Expand Down
104 changes: 51 additions & 53 deletions risc0/zkp/src/hal/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,62 +439,66 @@ impl<CH: CudaHash> Hal for CudaHal<CH> {
}

#[tracing::instrument(skip_all)]
fn batch_expand(
fn batch_expand_into_evaluate_ntt(
&self,
output: &Self::Buffer<Self::Elem>,
input: &Self::Buffer<Self::Elem>,
poly_count: usize,
count: usize,
expand_bits: usize,
) {
let out_size = output.size() / poly_count;
let in_size = input.size() / poly_count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(output.size(), out_size * poly_count);
assert_eq!(input.size(), in_size * poly_count);
assert_eq!(out_size, in_size * (1 << expand_bits));

let stream = Stream::new(StreamFlags::DEFAULT, None).unwrap();
let kernel = self.module.get_function("batch_expand").unwrap();
let params = self.compute_simple_params(out_size.try_into().unwrap());
unsafe {
launch!(kernel<<<params.0, params.1, 0, stream>>>(
output.as_device_ptr(),
input.as_device_ptr(),
poly_count as u32,
out_size as u32,
in_size as u32,
expand_bits as u32
))
.unwrap();
}
stream.synchronize().unwrap();
}

#[tracing::instrument(skip_all)]
fn batch_evaluate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize, expand_bits: usize) {
let row_size = io.size() / count;
assert_eq!(row_size * count, io.size());
let n_bits = log2_ceil(row_size);
assert_eq!(row_size, 1 << n_bits);
assert!(n_bits >= expand_bits);
assert!(n_bits < Self::Elem::MAX_ROU_PO2);
let rou = self.copy_from_elem("rou", Self::Elem::ROU_FWD);

let stream = Stream::new(StreamFlags::DEFAULT, None).unwrap();
let kernel = self.module.get_function("multi_ntt_fwd_step").unwrap();
for s_bits in 1 + expand_bits..=n_bits {
let params = self.compute_launch_params(n_bits as u32, s_bits as u32, count as u32);
// batch_expand
{
let out_size = output.size() / count;
let in_size = input.size() / count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(output.size(), out_size * count);
assert_eq!(input.size(), in_size * count);
assert_eq!(out_size, in_size * (1 << expand_bits));

let stream = Stream::new(StreamFlags::DEFAULT, None).unwrap();
let kernel = self.module.get_function("batch_expand").unwrap();
let params = self.compute_simple_params(out_size.try_into().unwrap());
unsafe {
launch!(kernel<<<params.0, params.1, 0, stream>>>(
io.as_device_ptr(),
rou.as_device_ptr(),
n_bits as u32,
s_bits as u32,
count as u32
output.as_device_ptr(),
input.as_device_ptr(),
count as u32,
out_size as u32,
in_size as u32,
expand_bits as u32
))
.unwrap();
}
stream.synchronize().unwrap();
}

// batch_evaluate_ntt
{
let row_size = output.size() / count;
assert_eq!(row_size * count, output.size());
let n_bits = log2_ceil(row_size);
assert_eq!(row_size, 1 << n_bits);
assert!(n_bits >= expand_bits);
assert!(n_bits < Self::Elem::MAX_ROU_PO2);
let rou = self.copy_from_elem("rou", Self::Elem::ROU_FWD);

let stream = Stream::new(StreamFlags::DEFAULT, None).unwrap();
let kernel = self.module.get_function("multi_ntt_fwd_step").unwrap();
for s_bits in 1 + expand_bits..=n_bits {
let params = self.compute_launch_params(n_bits as u32, s_bits as u32, count as u32);
unsafe {
launch!(kernel<<<params.0, params.1, 0, stream>>>(
output.as_device_ptr(),
rou.as_device_ptr(),
n_bits as u32,
s_bits as u32,
count as u32
))
.unwrap();
}
stream.synchronize().unwrap();
}
}
}

#[tracing::instrument(skip_all)]
Expand Down Expand Up @@ -861,14 +865,8 @@ mod tests {

#[test]
#[serial]
fn batch_expand() {
testutil::batch_expand(CudaHalSha256::new());
}

#[test]
#[serial]
fn batch_evaluate_ntt() {
testutil::batch_evaluate_ntt(CudaHalSha256::new());
fn batch_expand_into_evaluate_ntt() {
testutil::batch_expand_into_evaluate_ntt(CudaHalSha256::new());
}

#[test]
Expand Down
16 changes: 7 additions & 9 deletions risc0/zkp/src/hal/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,21 @@ where
BufferImpl::new(lhs, rhs)
}

fn batch_expand(
#[tracing::instrument(skip_all)]
fn batch_expand_into_evaluate_ntt(
&self,
output: &Self::Buffer<Self::Elem>,
input: &Self::Buffer<Self::Elem>,
count: usize,
expand_bits: usize,
) {
self.lhs.batch_expand(&output.lhs, &input.lhs, count);
self.rhs.batch_expand(&output.rhs, &input.rhs, count);
self.lhs
.batch_expand_into_evaluate_ntt(&output.lhs, &input.lhs, count, expand_bits);
self.rhs
.batch_expand_into_evaluate_ntt(&output.rhs, &input.rhs, count, expand_bits);
output.assert_eq();
}

fn batch_evaluate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize, expand_bits: usize) {
self.lhs.batch_evaluate_ntt(&io.lhs, count, expand_bits);
self.rhs.batch_evaluate_ntt(&io.rhs, count, expand_bits);
io.assert_eq();
}

fn batch_interpolate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize) {
self.lhs.batch_interpolate_ntt(&io.lhs, count);
self.rhs.batch_interpolate_ntt(&io.rhs, count);
Expand Down
103 changes: 51 additions & 52 deletions risc0/zkp/src/hal/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,58 +468,62 @@ impl<MH: MetalHash> Hal for MetalHal<MH> {
}

#[tracing::instrument(skip_all)]
fn batch_expand(
fn batch_expand_into_evaluate_ntt(
&self,
output: &Self::Buffer<Self::Elem>,
input: &Self::Buffer<Self::Elem>,
poly_count: usize,
count: usize,
expand_bits: usize,
) {
log::debug!(
"output: {}, input: {}, poly_count: {poly_count}",
output.size(),
input.size()
);
let out_size = output.size() / poly_count;
let in_size = input.size() / poly_count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(output.size(), out_size * poly_count);
assert_eq!(input.size(), in_size * poly_count);
assert_eq!(out_size, in_size * (1 << expand_bits));
let args = &[
output.as_arg(),
input.as_arg(),
KernelArg::Integer(poly_count as u32),
KernelArg::Integer(out_size as u32),
KernelArg::Integer(in_size as u32),
KernelArg::Integer(expand_bits as u32),
];
self.dispatch_by_name("batch_expand", args, out_size as u64);
}

#[tracing::instrument(skip_all)]
fn batch_evaluate_ntt(&self, io: &Self::Buffer<Self::Elem>, count: usize, expand_bits: usize) {
log::debug!(
"io: {}, count: {count}, expand_bits: {expand_bits}",
io.size()
);
let row_size = io.size() / count;
assert_eq!(row_size * count, io.size());
let n_bits = log2_ceil(row_size);
assert_eq!(row_size, 1 << n_bits);
assert!(n_bits >= expand_bits);
assert!(n_bits < Self::Elem::MAX_ROU_PO2);
let rou = self.copy_from_elem("rou", Self::Elem::ROU_FWD);
let kernel = self.kernels.get("multi_ntt_fwd_step").unwrap();
for s_bits in 1 + expand_bits..=n_bits {
// batch_expand
{
log::debug!(
"output: {}, input: {}, count: {count}",
output.size(),
input.size()
);
let out_size = output.size() / count;
let in_size = input.size() / count;
let expand_bits = log2_ceil(out_size / in_size);
assert_eq!(output.size(), out_size * count);
assert_eq!(input.size(), in_size * count);
assert_eq!(out_size, in_size * (1 << expand_bits));
let args = &[
io.as_arg(),
rou.as_arg(),
KernelArg::Integer(n_bits as u32),
KernelArg::Integer(s_bits as u32),
output.as_arg(),
input.as_arg(),
KernelArg::Integer(count as u32),
KernelArg::Integer(out_size as u32),
KernelArg::Integer(in_size as u32),
KernelArg::Integer(expand_bits as u32),
];
let params = compute_launch_params(n_bits as u32, s_bits as u32, count as u32);
self.dispatch(kernel, args, count as u64, Some(params));
self.dispatch_by_name("batch_expand", args, out_size as u64);
}

// batch_evaluate_ntt
{
log::debug!(
"output: {}, count: {count}, expand_bits: {expand_bits}",
output.size()
);
let row_size = output.size() / count;
assert_eq!(row_size * count, output.size());
let n_bits = log2_ceil(row_size);
assert_eq!(row_size, 1 << n_bits);
assert!(n_bits >= expand_bits);
assert!(n_bits < Self::Elem::MAX_ROU_PO2);
let rou = self.copy_from_elem("rou", Self::Elem::ROU_FWD);
let kernel = self.kernels.get("multi_ntt_fwd_step").unwrap();
for s_bits in 1 + expand_bits..=n_bits {
let args = &[
output.as_arg(),
rou.as_arg(),
KernelArg::Integer(n_bits as u32),
KernelArg::Integer(s_bits as u32),
KernelArg::Integer(count as u32),
];
let params = compute_launch_params(n_bits as u32, s_bits as u32, count as u32);
self.dispatch(kernel, args, count as u64, Some(params));
}
}
}

Expand Down Expand Up @@ -771,13 +775,8 @@ mod tests {
}

#[test]
fn batch_evaluate_ntt() {
testutil::batch_evaluate_ntt(MetalHalSha256::new());
}

#[test]
fn batch_expand() {
testutil::batch_expand(MetalHalSha256::new());
fn batch_expand_into_evaluate_ntt() {
testutil::batch_expand_into_evaluate_ntt(MetalHalSha256::new());
}

#[test]
Expand Down

0 comments on commit c623f4c

Please sign in to comment.