Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "insecure sha compress 256" and "cycle count" GPIOs (host cpp support, guest rust support) #260

Merged
merged 9 commits into from
Aug 30, 2022
26 changes: 24 additions & 2 deletions risc0/zkp/rust/src/core/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub const DIGEST_WORDS: usize = 8;
/// The size of a word within a [Digest] (32-bits = 4 bytes).
pub const DIGEST_WORD_SIZE: usize = mem::size_of::<u32>();

/// Standard SHA initialization vector .
pub static SHA256_INIT: Digest = Digest::new([
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]);
shkoo marked this conversation as resolved.
Show resolved Hide resolved

/// The result of a SHA-256 hashing function.
// TODO(nils): Remove 'Copy' trait on Digest; these are not small and
// we don't want to copy them around accidentally.
Expand All @@ -44,7 +49,7 @@ pub struct Digest([u32; DIGEST_WORDS]);

impl Digest {
/// Create a new [Digest] from an existing array of words.
pub fn new(data: [u32; DIGEST_WORDS]) -> Digest {
pub const fn new(data: [u32; DIGEST_WORDS]) -> Digest {
Digest(data)
}

Expand Down Expand Up @@ -162,8 +167,25 @@ pub trait Sha: Clone + Debug {
/// length.
fn hash_raw_words(&self, words: &[u32]) -> Self::DigestPtr;

/// Update a SHA digest with zero or more new blocks, zero padded
/// up to the next block boundry. Not all implementations provide
/// this.
fn update(&self, state: &Digest, bytes: &[u8]) -> Self::DigestPtr;

/// Generate a SHA from a pair of [Digests](Digest).
fn hash_pair(&self, a: &Digest, b: &Digest) -> Self::DigestPtr;
fn hash_pair(&self, a: &Digest, b: &Digest) -> Self::DigestPtr {
self.compress(&SHA256_INIT, a, b)
}

/// Execute the sha256 "compress" operation. The block is
/// specified as two half-blocks. Not all implementations provide
/// this.
fn compress(
&self,
_state: &Digest,
shkoo marked this conversation as resolved.
Show resolved Hide resolved
block_half1: &Digest,
block_half2: &Digest,
) -> Self::DigestPtr;

/// Generate a SHA from a slice of anything that can be
/// represented as plain old data. Pads up to the Sha block
Expand Down
41 changes: 37 additions & 4 deletions risc0/zkp/rust/src/core/sha_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,50 @@ impl Sha for Impl {
}

// Digest two digest into one
fn hash_pair(&self, a: &Digest, b: &Digest) -> Self::DigestPtr {
let mut state = INIT_256;
fn compress(
&self,
orig_state: &Digest,
block_half1: &Digest,
block_half2: &Digest,
) -> Self::DigestPtr {
let mut state: [u32; DIGEST_WORDS] = *orig_state.get();
let mut block: GenericArray<u8, U64> = GenericArray::default();
for i in 0..8 {
set_word(block.as_mut_slice(), i, a.as_slice()[i]);
set_word(block.as_mut_slice(), 8 + i, b.as_slice()[i]);
set_word(block.as_mut_slice(), i, block_half1.as_slice()[i]);
set_word(block.as_mut_slice(), 8 + i, block_half2.as_slice()[i]);
}
compress256(&mut state, slice::from_ref(&block));
Box::new(Digest::new(state))
}

fn update(&self, orig_state: &Digest, bytes: &[u8]) -> Self::DigestPtr {
let mut state = *orig_state.get();
let mut block: GenericArray<u8, U64> = GenericArray::default();

let mut u8s = bytes.iter().fuse().cloned();
let mut off = 0;
while let Some(b1) = u8s.next() {
let b2 = u8s.next().unwrap_or(0);
let b3 = u8s.next().unwrap_or(0);
let b4 = u8s.next().unwrap_or(0);
set_word(
block.as_mut_slice(),
off,
u32::from_le_bytes([b1, b2, b3, b4]),
);
off += 1;
if off == 16 {
compress256(&mut state, slice::from_ref(&block));
off = 0;
}
}
if off != 0 {
block[off * 4..].fill(0);
compress256(&mut state, slice::from_ref(&block));
}
Box::new(Digest::new(state))
}

// Generate a new digest by mixing two digests together via XOR,
// and stores it back in the pool.
fn mix(&self, pool: &mut Self::DigestPtr, val: &Digest) {
Expand Down
43 changes: 43 additions & 0 deletions risc0/zkvm/platform/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ constexpr size_t kGPIO_SendRecvChannel = 0x01F00014;
constexpr size_t kGPIO_SendRecvSize = 0x01F00018;
constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C;

constexpr size_t kGPIO_CycleCount = 0x01F00020;
constexpr size_t kGPIO_InsecureShaCompress = 0x01F00024;
constexpr size_t kGPIO_InsecureShaHash = 0x01F00028;

// Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs.

// Request the initial input to the guest.
Expand Down Expand Up @@ -121,6 +125,45 @@ inline const char* volatile* GPIO_Log() {
return reinterpret_cast<const char* volatile*>(kGPIO_Log);
}

// To have the host execute a sha256 "compress" operation, write the
// address of an InsecureShaCompressDescriptor to
// kGPIO_InsecureShaCompress. The output state (64 bytes) is written
// to the guest's INPUT area. WARNING: The host calculates this
// independently and does not include the calculation in the proof, so
// this should not be used when security is a concern.
struct InsecureShaCompressDescriptor {
uint32_t state; // Pointer to input state, 64 bytes
uint32_t block_half1; // Pointer to first half of block
uint32_t block_half2; // Pointer to second half of block
};

inline const InsecureShaCompressDescriptor* volatile* GPIO_InsecureShaCompress() {
return reinterpret_cast<const InsecureShaCompressDescriptor* volatile*>(
kGPIO_InsecureShaCompress);
}

// To have the host zero-pad and hash zero or more blocks,write the
// address of a InsecureShaHashDescriptor to kGPIO_InsecureShaHash.
// The output state (64 bytes) is written to the guest's INPUT area.
// WARNING: The host calculates this independently and does not
// include the calculation in the proof, so this should not be used
// when security is a concern.
struct InsecureShaHashDescriptor {
uint32_t state; // Pointer to input state, 64 bytes
uint32_t start; // Pointer to beginning of data region
uint32_t len; // Number of bytes in the region to be hashed
};

inline const InsecureShaHashDescriptor* volatile* GPIO_InsecureShaHash() {
return reinterpret_cast<const InsecureShaHashDescriptor* volatile*>(kGPIO_InsecureShaHash);
}

// To get the current cycle count of the ZKVM, write a 0 to
// GPIO_CycleCount and read one word from the guest's INPUT area.
inline const uint32_t volatile* GPIO_CycleCount() {
return reinterpret_cast<const uint32_t volatile*>(kGPIO_CycleCount);
}

// TODO(nils) Document GetKey.
struct GetKeyDescriptor {
uint32_t name;
Expand Down
94 changes: 84 additions & 10 deletions risc0/zkvm/prove/io_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ void MemoryHandler::onInit(MemoryState& mem) {
}
}

void MemoryHandler::sendToGuest(MemoryState& mem, const BufferU32& words) {
if ((cur_host_to_guest_offset + words.size()) >= kMemInputEnd) {
throw(std::runtime_error("Read buffer overrun"));
}
LOG(1, "Filling " << words.size() << " words of guest input area");
for (uint32_t word : words) {
LOG(1, "... " << word);
mem.store(cur_host_to_guest_offset, word);
cur_host_to_guest_offset += sizeof(uint32_t);
}
}

void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) {
LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value));
switch (addr) {
Expand Down Expand Up @@ -133,20 +145,82 @@ void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uin
LOG(1,
"MemoryHandler::onWrite> GPIO_SendReceive, host replied with " << result.size()
<< " bytes");
size_t aligned_len = align(result.size());
if ((cur_host_to_guest_offset + sizeof(uint32_t) + aligned_len) >= kMemInputEnd) {
throw(std::runtime_error("Read buffer overrun"));
}
mem.store(cur_host_to_guest_offset, result.size());
cur_host_to_guest_offset += sizeof(uint32_t);
for (size_t i = 0; i < result.size(); ++i) {
mem.storeByte(cur_host_to_guest_offset + i, result[i]);
}
cur_host_to_guest_offset += aligned_len;

// Send length in bytes
uint32_t result_bytes = result.size();
BufferU32 words;
words.push_back(result_bytes);
sendToGuest(mem, words);

// Send the buffer.
words.clear();
words.resize((result.size() + sizeof(uint32_t) - 1) / sizeof(uint32_t), 0);
memcpy(words.data(), result.data(), result_bytes);
sendToGuest(mem, words);
} else {
throw std::runtime_error("SendRecv called with no IO handler set");
}
} break;
case kGPIO_CycleCount: {
LOG(1, "MemoryHandler::onWrite> GPIO_CycleCount, cycle = " << cycle);
if (value != 0) {
throw std::runtime_error("CycleCount GPIO should only be written as zero.");
}

BufferU32 cyclebuf;
cyclebuf.push_back(cycle);
sendToGuest(mem, cyclebuf);
} break;
case kGPIO_InsecureShaCompress: {
LOG(1, "MemoryHandler::onWrite> GPIO_InsecureSha256Compress, descriptor at " << hex(value));
InsecureShaCompressDescriptor desc;
mem.loadRegion(value, &desc, sizeof(desc));
if (!io) {
throw std::runtime_error("InsecureSHA called with no IO handler set");
}
constexpr size_t kDigestWords = 8;

ShaDigest state;
mem.loadRegion(desc.state, &state, sizeof(state));
uint32_t chunk[kDigestWords * 2];
mem.loadRegion(desc.block_half1, &chunk[0], sizeof(chunk) / 2);
mem.loadRegion(desc.block_half2, &chunk[kDigestWords], sizeof(chunk) / 2);

for (auto& i : chunk) {
i = impl::convertU32(i);
}
impl::compress(state, chunk);

BufferU32 result(std::begin(state.words), std::end(state.words));
sendToGuest(mem, result);
} break;
case kGPIO_InsecureShaHash: {
LOG(1, "MemoryHandler::onWrite> GPIO_InsecureSha256Hash, descriptor at " << hex(value));
InsecureShaHashDescriptor desc;
mem.loadRegion(value, &desc, sizeof(desc));
if (!io) {
throw std::runtime_error("InsecureSHA called with no IO handler set");
}
constexpr size_t kDigestWords = 8;

ShaDigest state;
mem.loadRegion(desc.state, &state, sizeof(state));
constexpr size_t kBlockLen = kDigestWords * 2 * sizeof(uint32_t);
for (size_t i = 0; i < desc.len; i += kBlockLen) {
uint32_t chunk[kBlockLen];
size_t chunk_len = std::min<size_t>(i + kBlockLen, desc.len) - i;
mem.loadRegion(desc.start + i, &chunk, chunk_len);
memset(reinterpret_cast<uint8_t*>(&chunk) + chunk_len, 0, kBlockLen - chunk_len);

// Convert to bigendian for impl::compress
for (auto& i : chunk) {
i = impl::convertU32(i);
}
impl::compress(state, chunk);
}
BufferU32 result(std::begin(state.words), std::end(state.words));
sendToGuest(mem, result);
} break;
}
}

Expand Down
3 changes: 3 additions & 0 deletions risc0/zkvm/prove/step.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class MemoryHandler {
virtual void onHalt(const MemoryState& mem, const std::array<uint32_t, 8>& output) {}

private:
// Copies the given words to the guest's input area, and advances cur_host_to_guest_offset.
void sendToGuest(MemoryState& me, const BufferU32& words);

IoHandler* io;

// Memory address of current host->guest transmission. The host can only
Expand Down
20 changes: 20 additions & 0 deletions risc0/zkvm/sdk/cpp/guest/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,26 @@ TEST(CoreTests, Memset) {
TEST(CoreTests, SHAAccel) {
MethodId methodId = loadMethodId("risc0/zkvm/sdk/rust/methods/test_sha_accel.id");
Prover prover("risc0/zkvm/sdk/rust/methods/test_sha_accel", methodId);
prover.writeInput(0); // Test risc0_zkvm_guest::sha::Impl
prover.writeInput(0); // Compute an empty digest
Receipt receipt = prover.run();
receipt.verify(methodId);
}

TEST(CoreTests, InsecureSHAAccel) {
MethodId methodId = loadMethodId("risc0/zkvm/sdk/rust/methods/test_sha_accel.id");
Prover prover("risc0/zkvm/sdk/rust/methods/test_sha_accel", methodId);
prover.writeInput(1); // Test risc0_zkvm_guest::sha::InsecureImpl
prover.writeInput(0); // Compute an empty digest
Receipt receipt = prover.run();
receipt.verify(methodId);
}

// Test simulated SHA cycle count; should be 76
TEST(CoreTests, ShaCycleCount) {
MethodId methodId = loadMethodId("risc0/zkvm/sdk/rust/methods/test_sha_accel.id");
Prover prover("risc0/zkvm/sdk/rust/methods/test_sha_accel", methodId);
prover.writeInput(2);
prover.writeInput(0); // Compute an empty digest
Receipt receipt = prover.run();
receipt.verify(methodId);
Expand Down
25 changes: 22 additions & 3 deletions risc0/zkvm/sdk/rust/guest/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@ use core::{cell::UnsafeCell, mem::MaybeUninit, slice};
use risc0_zkp::core::sha::Digest;
use risc0_zkvm::serde::{Deserializer, Serializer, Slice};
use risc0_zkvm_platform::{
io::{IoDescriptor, GPIO_COMMIT, SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT},
memory, WORD_SIZE,
io::{
IoDescriptor, GPIO_COMMIT, GPIO_CYCLECOUNT, SENDRECV_CHANNEL_INITIAL_INPUT,
SENDRECV_CHANNEL_STDOUT,
},
memory,
rt::host_io::host_recv,
WORD_SIZE,
};
use serde::{Deserialize, Serialize};

use crate::{align_up, memory_barrier, sha};

// Re-export for easy use by user programs.
#[cfg(target_os = "zkvm")]
pub use risc0_zkvm_platform::rt::host_sendrecv;
pub use risc0_zkvm_platform::rt::host_io::host_sendrecv;

#[cfg(not(target_os = "zkvm"))]
// Bazel really wants to compile this file for the host too, so provide a stub.
Expand Down Expand Up @@ -118,6 +123,12 @@ pub fn commit<T: Serialize>(data: &T) {
ENV.get().commit(data);
}

/// Returns the number of processor cycles that have occured since the guest
/// began.
pub fn get_cycle_count() -> usize {
ENV.get().get_cycle_count()
}

impl Env {
fn new() -> Self {
Env {
Expand Down Expand Up @@ -216,4 +227,12 @@ impl Env {
};
sha::finalize();
}

fn get_cycle_count(&self) -> usize {
unsafe { GPIO_CYCLECOUNT.as_ptr().write_volatile(0) }
match host_recv(1) {
&[nbytes] => nbytes as usize,
_ => unreachable!(),
}
}
}
4 changes: 4 additions & 0 deletions risc0/zkvm/sdk/rust/guest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub mod env;
/// Functions for computing SHA-256 hashes.
pub mod sha;

/// Faster than "sha", but delegates to host so should not be trusted
/// to prove anything.
pub mod sha_insecure;

use core::{arch::asm, mem, panic::PanicInfo, ptr};

extern "C" {
Expand Down