Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: serializable execution record #328

Merged
merged 31 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
238 changes: 173 additions & 65 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[dependencies]
bincode = "1.3.3"
serde = { version = "1.0", features = ["derive", "rc"] }
elf = "0.7.4"
itertools = "0.12.0"
lazy_static = "1.4"
Expand All @@ -31,7 +32,6 @@ p3-symmetric = {workspace = true}
p3-uni-stark = {workspace = true}
p3-util = {workspace = true}
rrs-lib = {git = "https://github.com/GregAC/rrs.git"}
serde = {version = "1.0", features = ["derive"]}
sp1-derive = {path = "../derive"}

anyhow = "1.0.79"
Expand All @@ -46,6 +46,7 @@ hashbrown = "0.14.3"
hex = "0.4.3"
k256 = {version = "0.13.3", features = ["expose-field"]}
num_cpus = "1.16.0"
serde_with = "3.6.1"
petgraph = "0.6.4"
serde_json = {version = "1.0.113", default-features = false, features = [
"alloc",
Expand Down
2 changes: 2 additions & 0 deletions core/src/alu/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ impl<F: PrimeField> MachineAir<F> for LtChip {
"Lt".to_string()
}

fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) {}

#[instrument(name = "generate lt trace", skip_all)]
fn generate_trace(
&self,
Expand Down
4 changes: 3 additions & 1 deletion core/src/alu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ pub use sll::*;
pub use sr::*;
pub use sub::*;

use serde::{Deserialize, Serialize};

use crate::runtime::Opcode;

/// A standard format for describing ALU operations that need to be proven.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct AluEvent {
/// The clock cycle that the operation occurs on.
pub clk: u32,
Expand Down
3 changes: 2 additions & 1 deletion core/src/bytes/event.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::ByteOpcode;
use serde::{Deserialize, Serialize};

/// A byte lookup event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ByteLookupEvent {
/// The opcode of the operation.
pub opcode: ByteOpcode,
Expand Down
3 changes: 2 additions & 1 deletion core/src/bytes/opcode.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use p3_field::Field;
use serde::{Deserialize, Serialize};

use crate::{bytes::NUM_BYTE_OPS, runtime::Opcode};

/// A byte opcode which the chip can process.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ByteOpcode {
/// Bitwise AND.
AND = 0,
Expand Down
4 changes: 3 additions & 1 deletion core/src/cpu/event.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use serde::{Deserialize, Serialize};

use crate::runtime::Instruction;

use super::memory::MemoryRecordEnum;

/// A standard format for describing CPU operations that need to be proven.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub struct CpuEvent {
/// The current shard.
pub shard: u32,
Expand Down
10 changes: 6 additions & 4 deletions core/src/cpu/memory.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
#[derive(Debug, Copy, Clone)]
use serde::{Deserialize, Serialize};

#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum MemoryRecordEnum {
Read(MemoryReadRecord),
Write(MemoryWriteRecord),
}

#[derive(Debug, Copy, Clone, Default)]
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct MemoryRecord {
pub value: u32,
pub shard: u32,
pub timestamp: u32,
}

#[derive(Debug, Copy, Clone, Default)]
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct MemoryReadRecord {
pub value: u32,
Expand All @@ -21,7 +23,7 @@ pub struct MemoryReadRecord {
pub prev_timestamp: u32,
}

#[derive(Debug, Copy, Clone, Default)]
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct MemoryWriteRecord {
pub value: u32,
Expand Down
49 changes: 20 additions & 29 deletions core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,44 +72,35 @@ impl<F: PrimeField> MachineAir<F> for CpuChip {

#[instrument(name = "generate CPU dependencies", skip_all)]
fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
let mut new_alu_events = HashMap::with_capacity(input.cpu_events.len());
let mut new_blu_events = Vec::with_capacity(input.cpu_events.len());
let mut new_field_events: Vec<FieldEvent> = Vec::with_capacity(input.cpu_events.len());

// Generate the trace rows for each event.
let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);
let events = input
.cpu_events
.par_chunks(chunk_size)
.map(|ops: &[CpuEvent]| {
ops.iter()
.map(|op| {
let (_, alu_events, blu_events, field_events) = self.event_to_row::<F>(*op);
(alu_events, blu_events, field_events)
})
.collect::<Vec<_>>()
let mut alu = HashMap::new();
let mut blu: Vec<_> = Vec::default();
let mut field: Vec<_> = Vec::default();
ops.iter().for_each(|op| {
let (_, alu_events, blu_events, field_events) = self.event_to_row::<F>(*op);
alu_events.into_iter().for_each(|(key, value)| {
alu.entry(key).or_insert(Vec::default()).extend(value);
});
blu.extend(blu_events);
field.extend(field_events);
});
(alu, blu, field)
})
.flatten()
.collect::<Vec<_>>();

events.into_iter().for_each(|e| {
let (alu_events, blu_events, field_events) = e;
for (key, value) in alu_events {
new_alu_events
.entry(key)
.and_modify(|op_new_events: &mut Vec<AluEvent>| {
op_new_events.extend(value.clone())
})
.or_insert(value);
}
new_blu_events.extend(blu_events);
new_field_events.extend(field_events);
});

// Add the dependency events to the shard.
output.add_alu_events(new_alu_events);
output.add_byte_lookup_events(new_blu_events);
output.add_field_events(&new_field_events);
events
.into_iter()
.for_each(|(alu_events, blu_events, field_events)| {
// Add the dependency events to the shard.
output.add_alu_events(alu_events);
output.add_byte_lookup_events(blu_events);
output.add_field_events(&field_events);
});
}
}

Expand Down
4 changes: 3 additions & 1 deletion core/src/field/event.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

/// A standard format for proving operations over a triplet of field elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct FieldEvent {
pub ltu: bool,
pub b: u32,
Expand Down
11 changes: 5 additions & 6 deletions core/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,19 @@ impl<F: PrimeField> MachineAir<F> for FieldLtuChip {
"FieldLTU".to_string()
}

fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) {}

#[instrument(name = "generate FieldLTU trace", skip_all)]
fn generate_trace(
&self,
input: &ExecutionRecord,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let rows = input
let cols = input
.field_events
.par_iter()
.map(|event| {
.flat_map_iter(|event| {
let mut row = [F::zero(); NUM_FIELD_COLS];
let cols: &mut FieldLtuCols<F> = row.as_mut_slice().borrow_mut();
let diff = event.b.wrapping_sub(event.c).wrapping_add(1 << LTU_NB_BITS);
Expand All @@ -82,10 +84,7 @@ impl<F: PrimeField> MachineAir<F> for FieldLtuChip {
.collect::<Vec<_>>();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_FIELD_COLS,
);
let mut trace = RowMajorMatrix::new(cols, NUM_FIELD_COLS);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_FIELD_COLS, F>(&mut trace.values);
Expand Down
4 changes: 3 additions & 1 deletion core/src/runtime/instruction.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use core::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::Opcode;

/// An instruction specifies an operation to execute and the operands.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Serialize, Deserialize)]
pub struct Instruction {
pub opcode: Opcode,
pub op_a: u32,
Expand Down
1 change: 1 addition & 0 deletions core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl Runtime {

#[inline]
fn validate_memory_access(&self, addr: u32, position: AccessPosition) {
#[cfg(debug_assertions)]
if position == AccessPosition::Memory {
assert_eq!(addr % 4, 0, "addr is not aligned");
let _ = BabyBear::from_canonical_u32(addr);
Expand Down
3 changes: 2 additions & 1 deletion core/src/runtime/opcode.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::fmt::Display;

use p3_field::Field;
use serde::{Deserialize, Serialize};

/// An opcode specifies which operation to execute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum Opcode {
// Arithmetic instructions.
Expand Down
3 changes: 2 additions & 1 deletion core/src/runtime/program.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

use super::Instruction;

/// A program that can be executed by the VM.
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Program {
/// The instructions of the program.
pub instructions: Vec<Instruction>,
Expand Down