Skip to content

Commit

Permalink
Add support for Cuda Poseidon hal. (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbruestle committed Mar 20, 2023
1 parent 285039c commit 0953b45
Show file tree
Hide file tree
Showing 10 changed files with 410 additions and 101 deletions.
4 changes: 2 additions & 2 deletions risc0/circuit/rv32im/benches/eval_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ pub fn eval_check(c: &mut Criterion) {
#[cfg(feature = "cuda")]
for po2 in [2, 8, 16, 20, 21].iter() {
let params = EvalCheckParams::new(*po2);
let hal = std::rc::Rc::new(risc0_zkp::hal::cuda::CudaHal::new());
let eval = risc0_circuit_rv32im::cuda::CudaEvalCheck::new(hal.clone());
let hal = std::rc::Rc::new(risc0_zkp::hal::cuda::CudaHalSha256::new());
let eval = risc0_circuit_rv32im::cuda::CudaEvalCheckSha256::new(hal.clone());
group.bench_function(BenchmarkId::new("cuda", po2), |b| {
b.iter(|| {
eval_check_impl(&params, hal.as_ref(), &eval);
Expand Down
19 changes: 11 additions & 8 deletions risc0/circuit/rv32im/src/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use risc0_core::field::{
use risc0_zkp::{
core::log2_ceil,
hal::{
cuda::{BufferImpl as CudaBuffer, CudaHal},
cuda::{BufferImpl as CudaBuffer, CudaHal, CudaHash, CudaHashPoseidon, CudaHashSha256},
Buffer, EvalCheck,
},
INV_RATE,
Expand All @@ -35,20 +35,20 @@ use crate::{

const KERNELS_FATBIN: &[u8] = include_bytes!(env!("RV32IM_CUDA_PATH"));

pub struct CudaEvalCheck {
hal: Rc<CudaHal>, // retain a reference to ensure the context remains valid
pub struct CudaEvalCheck<CH: CudaHash> {
hal: Rc<CudaHal<CH>>, // retain a reference to ensure the context remains valid
module: Module,
}

impl CudaEvalCheck {
impl<CH: CudaHash> CudaEvalCheck<CH> {
#[tracing::instrument(name = "CudaEvalCheck::new", skip_all)]
pub fn new(hal: Rc<CudaHal>) -> Self {
pub fn new(hal: Rc<CudaHal<CH>>) -> Self {
let module = Module::load_from_bytes(KERNELS_FATBIN).unwrap();
Self { hal, module }
}
}

impl<'a> EvalCheck<CudaHal> for CudaEvalCheck {
impl<'a, CH: CudaHash> EvalCheck<CudaHal<CH>> for CudaEvalCheck<CH> {
#[tracing::instrument(skip_all)]
fn eval_check(
&self,
Expand Down Expand Up @@ -111,11 +111,14 @@ impl<'a> EvalCheck<CudaHal> for CudaEvalCheck {
}
}

pub type CudaEvalCheckSha256 = CudaEvalCheck<CudaHashSha256>;
pub type CudaEvalCheckPoseidon = CudaEvalCheck<CudaHashPoseidon>;

#[cfg(test)]
mod tests {
use std::rc::Rc;

use risc0_zkp::hal::{cpu::BabyBearSha256CpuHal, cuda::CudaHal};
use risc0_zkp::hal::{cpu::BabyBearSha256CpuHal, cuda::CudaHalSha256};
use test_log::test;

use crate::cpu::CpuEvalCheck;
Expand All @@ -126,7 +129,7 @@ mod tests {
let circuit = crate::CircuitImpl::new();
let cpu_hal = BabyBearSha256CpuHal::new();
let cpu_eval = CpuEvalCheck::new(&circuit);
let gpu_hal = Rc::new(CudaHal::new());
let gpu_hal = Rc::new(CudaHalSha256::new());
let gpu_eval = super::CudaEvalCheck::new(gpu_hal.clone());
crate::testutil::eval_check(&cpu_hal, cpu_eval, gpu_hal.as_ref(), gpu_eval, PO2);
}
Expand Down
1 change: 1 addition & 0 deletions risc0/sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ fn build_cuda_kernels() {
"fri.cu",
"mix.cu",
"ntt.cu",
"poseidon.cu",
"sha.cu",
"zk.cu",
"sha256.h",
Expand Down
1 change: 1 addition & 0 deletions risc0/sys/kernels/zkp/cuda/all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
#include "fri.cu"
#include "mix.cu"
#include "ntt.cu"
#include "poseidon.cu"
#include "sha.cu"
#include "zk.cu"
153 changes: 153 additions & 0 deletions risc0/sys/kernels/zkp/cuda/poseidon.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright 2023 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fp.h"

#define CELLS 24
#define ROUNDS_FULL 8
#define ROUNDS_HALF_FULL (ROUNDS_FULL / 2)
#define ROUNDS_PARTIAL 21
#define ROW_SIZE (CELLS + ROUNDS_PARTIAL)
#define CELLS_RATE 16
#define CELLS_OUT 8

__device__ void add_round_constants(const Fp* ROUND_CONSTANTS, Fp* cells, uint round) {
for (uint i = 0; i < CELLS; i++) {
cells[i] += ROUND_CONSTANTS[round * CELLS + i];
}
}

__device__ Fp sbox(Fp x) {
Fp x2 = x * x;
Fp x4 = x2 * x2;
Fp x6 = x4 * x2;
return x6 * x;
}

__device__ void do_full_sboxes(Fp* cells) {
for (uint i = 0; i < CELLS; i++) {
cells[i] = sbox(cells[i]);
}
}

__device__ void multiply_by_mds(const Fp* MDS, Fp* cells) {
Fp new_cells[CELLS];
for (uint i = 0; i < CELLS; i++) {
Fp tot = 0;
for (uint j = 0; j < CELLS; j++) {
tot += MDS[i * CELLS + j] * cells[j];
}
new_cells[i] = tot;
}
for (uint i = 0; i < CELLS; i++) {
cells[i] = new_cells[i];
}
}

__device__ void full_round(const Fp* ROUND_CONSTANTS, const Fp* MDS, Fp* cells, uint round) {
add_round_constants(ROUND_CONSTANTS, cells, round);
do_full_sboxes(cells);
multiply_by_mds(MDS, cells);
}

__device__ void poseidon_mix(const Fp* ROUND_CONSTANTS,
const Fp* MDS,
const Fp* PARTIAL_COMP_MATRIX,
const Fp* PARTIAL_COMP_OFFSET,
Fp* cells) {
uint round = 0;
for (uint i = 0; i < ROUNDS_HALF_FULL; i++) {
full_round(ROUND_CONSTANTS, MDS, cells, round);
round++;
}
Fp sboxes[ROUNDS_PARTIAL];
for (uint i = 0; i < ROUNDS_PARTIAL; i++) {
// For each sbox, compute it's input
Fp sbox_in = PARTIAL_COMP_OFFSET[CELLS + i];
for (uint j = 0; j < CELLS; j++) {
sbox_in += PARTIAL_COMP_MATRIX[(CELLS + i) * ROW_SIZE + j] * cells[j];
}
for (uint j = 0; j < i; j++) {
sbox_in += PARTIAL_COMP_MATRIX[(CELLS + i) * ROW_SIZE + CELLS + j] * sboxes[j];
}
// Run it through the sbox + record it
sboxes[i] = sbox(sbox_in);
}
// Forward output data back to cells
Fp new_cells[CELLS];
for (uint i = 0; i < CELLS; i++) {
Fp out = PARTIAL_COMP_OFFSET[i];
for (uint j = 0; j < CELLS; j++) {
out += PARTIAL_COMP_MATRIX[i * ROW_SIZE + j] * cells[j];
}
for (uint j = 0; j < ROUNDS_PARTIAL; j++) {
out += PARTIAL_COMP_MATRIX[i * ROW_SIZE + CELLS + j] * sboxes[j];
}
new_cells[i] = out;
}
round += ROUNDS_PARTIAL;
for (uint i = 0; i < CELLS; i++) {
cells[i] = new_cells[i];
}
for (uint i = 0; i < ROUNDS_HALF_FULL; i++) {
full_round(ROUND_CONSTANTS, MDS, cells, round);
round++;
}
}

extern "C" __global__ void poseidon_fold(const Fp* ROUND_CONSTANTS,
const Fp* MDS,
const Fp* PARTIAL_COMP_MATRIX,
const Fp* PARTIAL_COMP_OFFSET,
Fp* output,
const Fp* input,
uint32_t output_size) {
uint32_t gid = blockDim.x * blockIdx.x + threadIdx.x;
Fp cells[CELLS];
for (size_t i = 0; i < CELLS_OUT; i++) {
cells[i] = input[2 * gid * CELLS_OUT + i];
cells[CELLS_OUT + i] = input[(2 * gid + 1) * CELLS_OUT + i];
}
poseidon_mix(ROUND_CONSTANTS, MDS, PARTIAL_COMP_MATRIX, PARTIAL_COMP_OFFSET, cells);
for (uint i = 0; i < CELLS_OUT; i++) {
output[gid * CELLS_OUT + i] = cells[i];
}
}

extern "C" __global__ void poseidon_rows(const Fp* ROUND_CONSTANTS,
const Fp* MDS,
const Fp* PARTIAL_COMP_MATRIX,
const Fp* PARTIAL_COMP_OFFSET,
Fp* out,
const Fp* matrix,
uint32_t count,
uint32_t col_size) {
uint32_t gid = blockDim.x * blockIdx.x + threadIdx.x;
Fp cells[CELLS];
uint used = 0;
for (uint i = 0; i < col_size; i++) {
cells[used++] += matrix[i * count + gid];
if (used == CELLS_RATE) {
poseidon_mix(ROUND_CONSTANTS, MDS, PARTIAL_COMP_MATRIX, PARTIAL_COMP_OFFSET, cells);
used = 0;
}
}
if (used != 0 || count == 0) {
poseidon_mix(ROUND_CONSTANTS, MDS, PARTIAL_COMP_MATRIX, PARTIAL_COMP_OFFSET, cells);
}
for (uint i = 0; i < CELLS_OUT; i++) {
out[CELLS_OUT * gid + i] = cells[i];
}
}

0 comments on commit 0953b45

Please sign in to comment.