Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: derive MachineAir, chip and machine cleanup #278

Merged
merged 21 commits into from
Feb 21, 2024
Merged
17 changes: 17 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ k256 = { version = "0.13.3", features = ["expose-field"] }
elliptic-curve = "0.13.8"
anyhow = "1.0.79"
serial_test = "3.0.0"
petgraph = "0.6.4"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }

[dev-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use p3_matrix::dense::RowMajorMatrix;

use crate::runtime::{ExecutionRecord, Program};

pub use sp1_derive::MachineAir;

/// An AIR that is part of a Risc-V AIR arithmetization.
pub trait MachineAir<F: Field>: BaseAir<F> {
/// A unique identifier for this AIR as part of a machine.
Expand Down
4 changes: 2 additions & 2 deletions core/src/bytes/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ pub(crate) const BYTE_COL_MAP: ByteCols<usize> = make_col_map();
/// The multiplicity indices for each byte operation.
pub(crate) const BYTE_MULT_INDICES: [usize; NUM_BYTE_OPS] = BYTE_COL_MAP.multiplicities;

impl<F: Field> BaseAir<F> for ByteChip {
impl<F: Field> BaseAir<F> for ByteChip<F> {
fn width(&self) -> usize {
NUM_BYTE_COLS
}
}

impl<AB: SP1AirBuilder> Air<AB> for ByteChip {
impl<AB: SP1AirBuilder> Air<AB> for ByteChip<AB::F> {
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &ByteCols<AB::Var> = main.row_slice(0).borrow();
Expand Down
8 changes: 4 additions & 4 deletions core/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub use event::ByteLookupEvent;
use itertools::Itertools;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
use std::marker::PhantomData;

use self::columns::{ByteCols, NUM_BYTE_COLS};
use self::utils::shr_carry;
Expand All @@ -26,17 +27,16 @@ pub const NUM_BYTE_OPS: usize = 9;
/// The chip contains a preprocessed table of all possible byte operations. Other chips can then
/// use lookups into this table to compute their own operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct ByteChip;
pub struct ByteChip<F>(PhantomData<F>);

impl ByteChip {
impl<F: Field> ByteChip<F> {
/// Creates the preprocessed byte trace and event map.
///
/// This function returns a pair `(trace, map)`, where:
/// - `trace` is a matrix containing all possible byte operations.
/// - `map` is a map map from a byte lookup to the corresponding row it appears in the table and
/// the index of the result in the array of multiplicities.
pub fn trace_and_map<F: Field>(
) -> (RowMajorMatrix<F>, BTreeMap<ByteLookupEvent, (usize, usize)>) {
pub fn trace_and_map() -> (RowMajorMatrix<F>, BTreeMap<ByteLookupEvent, (usize, usize)>) {
// A map from a byte lookup to its corresponding row in the table and index in the array of
// multiplicities.
let mut event_map = BTreeMap::new();
Expand Down
4 changes: 2 additions & 2 deletions core/src/bytes/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{air::MachineAir, runtime::ExecutionRecord};

pub const NUM_ROWS: usize = 1 << 16;

impl<F: Field> MachineAir<F> for ByteChip {
impl<F: Field> MachineAir<F> for ByteChip<F> {
fn name(&self) -> String {
"Byte".to_string()
}
Expand All @@ -16,7 +16,7 @@ impl<F: Field> MachineAir<F> for ByteChip {
input: &ExecutionRecord,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let (mut trace, event_map) = ByteChip::trace_and_map::<F>();
let (mut trace, event_map) = ByteChip::trace_and_map();

for (lookup, mult) in input.byte_lookups.iter() {
let (row, index) = event_map[lookup];
Expand Down
8 changes: 4 additions & 4 deletions core/src/lookup/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use p3_matrix::Matrix;

use crate::air::MachineAir;
use crate::runtime::ExecutionRecord;
use crate::stark::{ChipRef, StarkGenericConfig};
use crate::stark::{RiscvChip, StarkGenericConfig};

use super::InteractionKind;

Expand Down Expand Up @@ -44,7 +44,7 @@ fn babybear_to_int(n: BabyBear) -> i32 {
}

pub fn debug_interactions<SC: StarkGenericConfig>(
chip: &ChipRef<SC>,
chip: &RiscvChip<SC>,
record: &ExecutionRecord,
interaction_kinds: Vec<InteractionKind>,
) -> (
Expand Down Expand Up @@ -110,14 +110,14 @@ pub fn debug_interactions<SC: StarkGenericConfig>(
/// Calculate the the number of times we send and receive each event of the given interaction type,
/// and print out the ones for which the set of sends and receives don't match.
pub fn debug_interactions_with_all_chips<SC: StarkGenericConfig<Val = BabyBear>>(
chips: &[ChipRef<SC>],
chips: &[RiscvChip<SC>],
segment: &ExecutionRecord,
interaction_kinds: Vec<InteractionKind>,
) -> bool {
let mut final_map = BTreeMap::new();

for chip in chips.iter() {
let (_, count) = debug_interactions(chip, segment, interaction_kinds.clone());
let (_, count) = debug_interactions::<SC>(chip, segment, interaction_kinds.clone());

tracing::debug!("{} chip has {} distinct events", chip.name(), count.len());
for (key, value) in count.iter() {
Expand Down
2 changes: 1 addition & 1 deletion core/src/lookup/interaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct Interaction<F: Field> {
}

/// The type of interaction for a lookup argument.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum InteractionKind {
/// Interaction with the memory table, such as read and write.
Memory = 1,
Expand Down
8 changes: 4 additions & 4 deletions core/src/memory/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ mod tests {
runtime.run();

let machine = RiscvStark::new(BabyBearPoseidon2::new());
debug_interactions_with_all_chips(
&machine.chips(),
debug_interactions_with_all_chips::<BabyBearPoseidon2>(
machine.chips(),
&runtime.record,
vec![InteractionKind::Memory],
);
Expand All @@ -219,8 +219,8 @@ mod tests {
runtime.run();

let machine = RiscvStark::new(BabyBearPoseidon2::new());
debug_interactions_with_all_chips(
&machine.chips(),
debug_interactions_with_all_chips::<BabyBearPoseidon2>(
machine.chips(),
&runtime.record,
vec![InteractionKind::Byte],
);
Expand Down
198 changes: 198 additions & 0 deletions core/src/stark/air.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use crate::air::MachineAir;
pub use crate::air::SP1AirBuilder;
use crate::memory::MemoryChipKind;
use crate::runtime::ExecutionRecord;
use p3_field::PrimeField32;
pub use riscv_chips::*;

/// A module for importing all the different RISC-V chips.
pub(crate) mod riscv_chips {
pub use crate::alu::AddChip;
pub use crate::alu::BitwiseChip;
pub use crate::alu::DivRemChip;
pub use crate::alu::LtChip;
pub use crate::alu::MulChip;
pub use crate::alu::ShiftLeft;
pub use crate::alu::ShiftRightChip;
pub use crate::alu::SubChip;
pub use crate::bytes::ByteChip;
pub use crate::cpu::CpuChip;
pub use crate::field::FieldLTUChip;
pub use crate::memory::MemoryGlobalChip;
pub use crate::program::ProgramChip;
pub use crate::syscall::precompiles::blake3::Blake3CompressInnerChip;
pub use crate::syscall::precompiles::edwards::EdAddAssignChip;
pub use crate::syscall::precompiles::edwards::EdDecompressChip;
pub use crate::syscall::precompiles::k256::K256DecompressChip;
pub use crate::syscall::precompiles::keccak256::KeccakPermuteChip;
pub use crate::syscall::precompiles::sha256::ShaCompressChip;
pub use crate::syscall::precompiles::sha256::ShaExtendChip;
pub use crate::syscall::precompiles::weierstrass::WeierstrassAddAssignChip;
pub use crate::syscall::precompiles::weierstrass::WeierstrassDoubleAssignChip;
pub use crate::utils::ec::edwards::ed25519::Ed25519Parameters;
pub use crate::utils::ec::edwards::EdwardsCurve;
pub use crate::utils::ec::weierstrass::secp256k1::Secp256k1Parameters;
pub use crate::utils::ec::weierstrass::SWCurve;
}

/// An AIR for encoding RISC-V execution.
///
/// This enum contains all the different AIRs that are used in the Sp1 RISC-V IOP. Each variant is
/// a different AIR that is used to encode a different part of the RISC-V execution, and the
/// different AIR variants have a joint lookup argument.
#[derive(MachineAir)]
pub enum RiscvAir<F: PrimeField32> {
/// An AIR that containts a preprocessed program table and a lookup for the instructions.
Program(ProgramChip),
/// An AIR for the RISC-V CPU. Each row represents a cpu cycle.
Cpu(CpuChip),
/// An AIR for the RISC-V Add instruction.
Add(AddChip),
/// An AIR for the RISC-V Sub instruction.
Sub(SubChip),
/// An AIR for RISC-V Bitwise instructions.
Bitwise(BitwiseChip),
/// An AIR for RISC-V Mul instruction.
Mul(MulChip),
/// An AIR for RISC-V Div and Rem instructions.
DivRem(DivRemChip),
/// An AIR for RISC-V Lt instruction.
Lt(LtChip),
/// An AIR for RISC-V SLL instruction.
ShiftLeft(ShiftLeft),
/// An AIR for RISC-V SRL and SRA instruction.
ShiftRight(ShiftRightChip),
/// A lookup table for byte operations.
ByteLookup(ByteChip<F>),
/// An table for `less than` operation on field elements.
FieldLTU(FieldLTUChip),
/// A table for initializing the memory state.
MemoryInit(MemoryGlobalChip),
/// A table for finalizing the memory state.
MemoryFinal(MemoryGlobalChip),
/// A table for initializing the program memory.
ProgramMemory(MemoryGlobalChip),
/// A precompile for sha256 extend.
Sha256Extend(ShaExtendChip),
/// A precompile for sha256 compress.
Sha256Compress(ShaCompressChip),
/// A precompile for addition on the Elliptic curve ed25519.
Ed25519Add(EdAddAssignChip<EdwardsCurve<Ed25519Parameters>>),
/// A precompile for decompressing a point on the Edwards curve ed25519.
Ed25519Decompress(EdDecompressChip<Ed25519Parameters>),
/// A precompile for decompressing a point on the K256 curve.
K256Decompress(K256DecompressChip),
/// A precompile for addition on the Elliptic curve secp256k1.
Secp256k1Add(WeierstrassAddAssignChip<SWCurve<Secp256k1Parameters>>),
/// A precompile for doubling a point on the Elliptic curve secp256k1.
Secp256k1Double(WeierstrassDoubleAssignChip<SWCurve<Secp256k1Parameters>>),
/// A precompile for the Keccak permutation.
KeccakP(KeccakPermuteChip),
/// A precompile for the Blake3 compression function.
Blake3Compress(Blake3CompressInnerChip),
}

impl<F: PrimeField32> RiscvAir<F> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like if you are going to go the derive route, then included should also be derived?

/// Get all the different RISC-V AIRs.
pub fn get_all() -> Vec<Self> {
// The order of the chips is important, as it is used to determine the order of trace
// generation. In the future, we will detect that order automatically.
let mut chips = vec![];
let cpu = CpuChip::default();
chips.push(RiscvAir::Cpu(cpu));
let program = ProgramChip::default();
chips.push(RiscvAir::Program(program));
let sha_extend = ShaExtendChip::default();
chips.push(RiscvAir::Sha256Extend(sha_extend));
let sha_compress = ShaCompressChip::default();
chips.push(RiscvAir::Sha256Compress(sha_compress));
let ed_add_assign = EdAddAssignChip::<EdwardsCurve<Ed25519Parameters>>::new();
chips.push(RiscvAir::Ed25519Add(ed_add_assign));
let ed_decompress = EdDecompressChip::<Ed25519Parameters>::default();
chips.push(RiscvAir::Ed25519Decompress(ed_decompress));
let k256_decompress = K256DecompressChip::default();
chips.push(RiscvAir::K256Decompress(k256_decompress));
let weierstrass_add_assign =
WeierstrassAddAssignChip::<SWCurve<Secp256k1Parameters>>::new();
chips.push(RiscvAir::Secp256k1Add(weierstrass_add_assign));
let weierstrass_double_assign =
WeierstrassDoubleAssignChip::<SWCurve<Secp256k1Parameters>>::new();
chips.push(RiscvAir::Secp256k1Double(weierstrass_double_assign));
let keccak_permute = KeccakPermuteChip::new();
chips.push(RiscvAir::KeccakP(keccak_permute));
let blake3_compress_inner = Blake3CompressInnerChip::new();
chips.push(RiscvAir::Blake3Compress(blake3_compress_inner));
let add = AddChip::default();
chips.push(RiscvAir::Add(add));
let sub = SubChip::default();
chips.push(RiscvAir::Sub(sub));
let bitwise = BitwiseChip::default();
chips.push(RiscvAir::Bitwise(bitwise));
let div_rem = DivRemChip::default();
chips.push(RiscvAir::DivRem(div_rem));
let mul = MulChip::default();
chips.push(RiscvAir::Mul(mul));
let shift_right = ShiftRightChip::default();
chips.push(RiscvAir::ShiftRight(shift_right));
let shift_left = ShiftLeft::default();
chips.push(RiscvAir::ShiftLeft(shift_left));
let lt = LtChip::default();
chips.push(RiscvAir::Lt(lt));
let memory_init = MemoryGlobalChip::new(MemoryChipKind::Init);
chips.push(RiscvAir::MemoryInit(memory_init));
let memory_finalize = MemoryGlobalChip::new(MemoryChipKind::Finalize);
chips.push(RiscvAir::MemoryFinal(memory_finalize));
let program_memory_init = MemoryGlobalChip::new(MemoryChipKind::Program);
chips.push(RiscvAir::ProgramMemory(program_memory_init));
let field_ltu = FieldLTUChip::default();
chips.push(RiscvAir::FieldLTU(field_ltu));
let byte = ByteChip::default();
chips.push(RiscvAir::ByteLookup(byte));

chips
}

/// Returns `true` if the given `shard` includes events for this AIR.
pub fn included(&self, shard: &ExecutionRecord) -> bool {
match self {
RiscvAir::Program(_) => true,
RiscvAir::Cpu(_) => true,
RiscvAir::Add(_) => !shard.add_events.is_empty(),
RiscvAir::Sub(_) => !shard.sub_events.is_empty(),
RiscvAir::Bitwise(_) => !shard.bitwise_events.is_empty(),
RiscvAir::Mul(_) => !shard.mul_events.is_empty(),
RiscvAir::DivRem(_) => !shard.divrem_events.is_empty(),
RiscvAir::Lt(_) => !shard.lt_events.is_empty(),
RiscvAir::ShiftLeft(_) => !shard.shift_left_events.is_empty(),
RiscvAir::ShiftRight(_) => !shard.shift_right_events.is_empty(),
RiscvAir::ByteLookup(_) => !shard.byte_lookups.is_empty(),
RiscvAir::FieldLTU(_) => !shard.field_events.is_empty(),
RiscvAir::MemoryInit(_) => !shard.first_memory_record.is_empty(),
RiscvAir::MemoryFinal(_) => !shard.last_memory_record.is_empty(),
RiscvAir::ProgramMemory(_) => !shard.program_memory_record.is_empty(),
RiscvAir::Sha256Extend(_) => !shard.sha_extend_events.is_empty(),
RiscvAir::Sha256Compress(_) => !shard.sha_compress_events.is_empty(),
RiscvAir::Ed25519Add(_) => !shard.ed_add_events.is_empty(),
RiscvAir::Ed25519Decompress(_) => !shard.ed_decompress_events.is_empty(),
RiscvAir::K256Decompress(_) => !shard.k256_decompress_events.is_empty(),
RiscvAir::Secp256k1Add(_) => !shard.weierstrass_add_events.is_empty(),
RiscvAir::Secp256k1Double(_) => !shard.weierstrass_double_events.is_empty(),
RiscvAir::KeccakP(_) => !shard.keccak_permute_events.is_empty(),
RiscvAir::Blake3Compress(_) => !shard.blake3_compress_inner_events.is_empty(),
}
}
}

impl<F: PrimeField32> PartialEq for RiscvAir<F> {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name()
}
}

impl<F: PrimeField32> Eq for RiscvAir<F> {}

impl<F: PrimeField32> core::hash::Hash for RiscvAir<F> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}