diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 29ff53880..8c08a3662 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -66,3 +66,5 @@ jobs: run: taplo --version || cargo install taplo-cli - name: Run taplo run: taplo fmt --check --diff + - name: Ensure Cargo.lock not modified by build + run: git diff --exit-code Cargo.lock diff --git a/Cargo.lock b/Cargo.lock index d3cb607be..499d14457 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -939,7 +939,7 @@ dependencies = [ [[package]] name = "ceno_crypto_primitives" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" dependencies = [ "ceno_syscall 0.1.0 (git+https://github.com/scroll-tech/ceno)", "elliptic-curve", @@ -958,9 +958,12 @@ dependencies = [ "multilinear_extensions", "num-derive", "num-traits", + "rayon", "rrs-succinct", + "rustc-hash", "secp", "serde", + "smallvec", "strum", "strum_macros", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1013,7 +1016,7 @@ version = "0.1.0" [[package]] name = "ceno_syscall" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" [[package]] name = "ceno_zkvm" @@ -1041,6 +1044,7 @@ dependencies = [ "multilinear_extensions", "ndarray", "num", + "num-bigint", "once_cell", "p3", "parse-size", @@ -1049,6 +1053,7 @@ dependencies = [ "proptest", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "serde_json", "sp1-curves", @@ -1851,7 +1856,7 @@ dependencies = [ "ceno_syscall 0.1.0", "getrandom 0.3.2", "rand 0.8.5", - "revm-precompile 28.1.0", + "revm-precompile 28.1.1", "rkyv", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "substrate-bn 0.6.0 
(git+https://github.com/scroll-tech/bn?branch=ceno)", @@ -1899,7 +1904,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "once_cell", "p3", @@ -2711,7 +2716,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "bincode", "clap", @@ -2735,7 +2740,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "either", "ff_ext", @@ -3056,8 +3061,9 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ + "p3-air", "p3-baby-bear", "p3-challenger", "p3-commit", @@ -3069,12 +3075,23 @@ dependencies = [ "p3-maybe-rayon", "p3-mds", "p3-merkle-tree", + "p3-monty-31", "p3-poseidon", "p3-poseidon2", + "p3-poseidon2-air", "p3-symmetric", "p3-util", ] +[[package]] +name = "p3-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" 
+dependencies = [ + "p3-field", + "p3-matrix", +] + [[package]] name = "p3-baby-bear" version = "0.1.0" @@ -3290,6 +3307,22 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "p3-poseidon2-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-air", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-poseidon2", + "p3-util", + "rand 0.8.5", + "tikv-jemallocator", + "tracing", +] + [[package]] name = "p3-symmetric" version = "0.1.0" @@ -3465,7 +3498,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "p3", @@ -3874,9 +3907,9 @@ dependencies = [ [[package]] name = "revm-precompile" -version = "28.1.0" +version = "28.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176169b39beb1f57b11f2ea3900c404b8498a56dfd8394e66f4d24f66cea368e" +checksum = "e57aadd7a2087705f653b5aaacc8ad4f8e851f5d330661e3f4c43b5475bbceae" dependencies = [ "ark-bls12-381", "ark-bn254", @@ -3888,7 +3921,7 @@ dependencies = [ "cfg-if", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "p256", - "revm-primitives 21.0.0", + "revm-primitives 21.0.1", "ripemd", "sha2 0.10.9", ] @@ -3907,9 +3940,9 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "21.0.0" +version = "21.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38271b8b85f00154bdcf9f2ab0a3ec7a8100377d2c7a0d8eb23e19389b42c795" +checksum = "536f30e24c3c2bf0d3d7d20fa9cf99b93040ed0f021fd9301c78cddb0dacda13" dependencies = [ "alloy-primitives", "num_enum 0.7.4", @@ -4432,6 +4465,9 @@ name = "smallvec" version = 
"1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +dependencies = [ + "serde", +] [[package]] name = "snowbridge-amcl" @@ -4446,7 +4482,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "cfg-if", "dashu", @@ -4568,7 +4604,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "either", "ff_ext", @@ -4586,7 +4622,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "itertools 0.13.0", "p3", @@ -4981,7 +5017,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5253,7 +5289,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = 
"git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "bincode", "clap", @@ -5540,7 +5576,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 7acdd2d7a..a29a1807a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "v1.0.0-alpha.9" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "v1.0.0-alpha.9" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.9" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.9" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.9" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "v1.0.0-alpha.9" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.9" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.9" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.9" } -witness = { git = 
"https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "v1.0.0-alpha.9" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "chore/sw_curve_default" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "chore/sw_curve_default" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "chore/sw_curve_default" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "chore/sw_curve_default" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", branch = "chore/sw_curve_default" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", branch = "chore/sw_curve_default" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "chore/sw_curve_default" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "chore/sw_curve_default" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "chore/sw_curve_default" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "chore/sw_curve_default" } alloy-primitives = "1.3" anyhow = { version = "1.0", default-features = false } @@ -57,9 +57,17 @@ rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" rayon = "1.10" rkyv = { version = "0.8", features = ["pointer_width_32"] } +rustc-hash = "2.0.0" secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" +smallvec = { version = "1.13.2", features = [ + "const_generics", + "const_new", + "serde", + "union", + "write", +] } strum = "0.26" strum_macros = "0.26" substrate-bn = { version = "0.6.0" } @@ -91,13 +99,14 @@ lto = "thin" # 
[patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"] # ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" } -#[patch."https://github.com/scroll-tech/gkr-backend"] -#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -#whir = { path = "../gkr-backend/crates/whir", package = "whir" } -#witness = { path = "../gkr-backend/crates/witness", package = "witness" } +# [patch."https://github.com/scroll-tech/gkr-backend"] +# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +# whir = { path = "../gkr-backend/crates/whir", package = "whir" } +# witness = { path = "../gkr-backend/crates/witness", package = "witness" } diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 9632986a8..d73841080 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -78,6 +78,14 @@ pub 
struct CenoOptions { #[arg(long)] pub out_vk: Option, + /// shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + /// number of total shards. + #[arg(long, default_value = "1")] + max_num_shards: u32, + /// Profiling granularity. /// Setting any value restricts logs to profiling information #[arg(long)] @@ -337,6 +345,7 @@ fn run_elf_inner< std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?; let program = Program::load_elf(&elf_bytes, u32::MAX).context("failed to load elf")?; print_cargo_message("Loaded", format_args!("{}", elf_path.display())); + let shards = Shards::new(options.shard_id as usize, options.max_num_shards as usize); let public_io = options .read_public_io() @@ -385,6 +394,7 @@ fn run_elf_inner< create_prover(backend.clone()), program, platform, + shards, &hints, &public_io, options.max_steps, diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index b0af43fe3..6cc12cd17 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -19,9 +19,12 @@ itertools.workspace = true multilinear_extensions.workspace = true num-derive.workspace = true num-traits.workspace = true +rayon.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } +rustc-hash.workspace = true secp.workspace = true serde.workspace = true +smallvec.workspace = true strum.workspace = true strum_macros.workspace = true substrate-bn.workspace = true diff --git a/ceno_emul/src/chunked_vec.rs b/ceno_emul/src/chunked_vec.rs new file mode 100644 index 000000000..e53d51a73 --- /dev/null +++ b/ceno_emul/src/chunked_vec.rs @@ -0,0 +1,89 @@ +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use std::ops::{Index, IndexMut}; + +/// a chunked vector that grows in fixed-size chunks. +#[derive(Default, Debug, Clone)] +pub struct ChunkedVec { + chunks: Vec>, + chunk_size: usize, + len: usize, +} + +impl ChunkedVec { + /// create a new ChunkedVec with a given chunk size. 
+ pub fn new(chunk_size: usize) -> Self { + assert!(chunk_size > 0, "chunk_size must be > 0"); + Self { + chunks: Vec::new(), + chunk_size, + len: 0, + } + } + + /// get the current number of elements. + pub fn len(&self) -> usize { + self.len + } + + /// returns true if the vector is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// access element by index (immutable). + pub fn get(&self, index: usize) -> Option<&T> { + if index >= self.len { + return None; + } + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + self.chunks.get(chunk_idx)?.get(within_idx) + } + + /// access element by index (mutable). + /// get mutable reference to element at index, auto-creating default-filled chunks as needed and growing `len` to `index + 1` + pub fn get_or_create(&mut self, index: usize) -> &mut T { + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + + // Ensure enough chunks exist + if chunk_idx >= self.chunks.len() { + let to_create = chunk_idx + 1 - self.chunks.len(); + + // Create the missing chunks one at a time; rayon fills each chunk's default elements in parallel + let mut new_chunks: Vec> = (0..to_create) + .map(|_| { + (0..self.chunk_size) + .into_par_iter() + .map(|_| Default::default()) + .collect::>() + }) + .collect(); + + self.chunks.append(&mut new_chunks); + } + + let chunk = &mut self.chunks[chunk_idx]; + + // Update the overall length + if index >= self.len { + self.len = index + 1; + } + + &mut chunk[within_idx] + } +} + +impl Index for ChunkedVec { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + self.get(index).expect("index out of bounds") + } +} + +impl IndexMut for ChunkedVec { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.get_or_create(index) + } +} diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 8f439d036..3d88484fa 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -7,7 +7,9 @@ mod platform; pub use platform::{CENO_PLATFORM, Platform}; mod tracer; -pub use tracer::{Change, 
MemOp, ReadOp, StepRecord, Tracer, WriteOp}; +pub use tracer::{ + Change, MemOp, NextAccessPair, NextCycleAccess, ReadOp, StepRecord, Tracer, WriteOp, +}; mod vm_state; pub use vm_state::VMState; @@ -44,4 +46,5 @@ pub mod utils; pub mod test_utils; +mod chunked_vec; pub mod host_utils; diff --git a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs index 75c70a055..3fa98f368 100644 --- a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs +++ b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs @@ -12,6 +12,7 @@ use crate::{ use super::types::{BN254_FP_WORDS, BN254_FP2_WORDS}; pub struct Bn254FpAddSpec; + impl SyscallSpec for Bn254FpAddSpec { const NAME: &'static str = "BN254_FP_ADD"; diff --git a/ceno_emul/src/syscalls/secp256k1.rs b/ceno_emul/src/syscalls/secp256k1.rs index 2facffba4..fafabe78c 100644 --- a/ceno_emul/src/syscalls/secp256k1.rs +++ b/ceno_emul/src/syscalls/secp256k1.rs @@ -6,7 +6,9 @@ use std::iter; use super::{SyscallEffects, SyscallSpec, SyscallWitness}; pub struct Secp256k1AddSpec; + pub struct Secp256k1DoubleSpec; + pub struct Secp256k1DecompressSpec; impl SyscallSpec for Secp256k1AddSpec { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 8280e8351..c36bd5bef 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,13 +1,13 @@ -use std::{ - collections::{BTreeMap, HashMap}, - fmt, mem, -}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; +use std::{collections::BTreeMap, fmt, mem}; use ceno_rt::WORD_SIZE; use crate::{ CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, + chunked_vec::ChunkedVec, encode_rv32, syscalls::{SyscallEffects, SyscallWitness}, }; @@ -39,6 +39,10 @@ pub struct StepRecord { syscall: Option, } +pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; +pub type NextCycleAccess = ChunkedVec; +const ACCESSED_CHUNK_SIZE: usize = 1 << 20; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct 
MemOp { /// Virtual Memory Address. @@ -305,7 +309,8 @@ pub struct Tracer { // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, - latest_accesses: HashMap, + latest_accesses: FxHashMap, + next_accesses: NextCycleAccess, } impl Default for Tracer { @@ -362,7 +367,8 @@ impl Tracer { cycle: Self::SUBCYCLES_PER_INSN, ..StepRecord::default() }, - latest_accesses: HashMap::new(), + latest_accesses: FxHashMap::default(), + next_accesses: NextCycleAccess::new(ACCESSED_CHUNK_SIZE), } } @@ -471,16 +477,24 @@ impl Tracer { /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { - self.latest_accesses - .insert(addr, self.record.cycle + subcycle) - .unwrap_or(0) + let cur_cycle = self.record.cycle + subcycle; + let prev_cycle = self.latest_accesses.insert(addr, cur_cycle).unwrap_or(0); + self.next_accesses + .get_or_create(prev_cycle as usize) + .push((addr, cur_cycle)); + prev_cycle } /// Return all the addresses that were accessed and the cycle when they were last accessed. - pub fn final_accesses(&self) -> &HashMap { + pub fn final_accesses(&self) -> &FxHashMap { &self.latest_accesses } + /// Consume the tracer and return the next-access table: entry `prev_cycle` (0 for a first access) lists each `(addr, cycle)` at which an address accessed at `prev_cycle` was accessed next. + pub fn next_accesses(self) -> NextCycleAccess { + self.next_accesses + } + /// Return the cycle of the pending instruction (after the last completed step). 
pub fn cycle(&self) -> Cycle { self.record.cycle diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 51057c2b0..eaac9d639 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -68,6 +68,10 @@ impl VMState { &self.tracer } + pub fn take_tracer(self) -> Tracer { + self.tracer + } + pub fn platform(&self) -> &Platform { &self.platform } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 74cc83d4e..14bf7a1fe 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,9 +1,7 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; +use rustc_hash::FxHashMap; +use std::{collections::BTreeMap, sync::Arc}; use ceno_emul::{ CENO_PLATFORM, Cycle, EmuContext, InsnKind, Instruction, Platform, Program, StepRecord, Tracer, @@ -111,8 +109,8 @@ fn expected_ops_fibonacci_20() -> Vec { } /// Reconstruct the last access of each register. 
-fn expected_final_accesses_fibonacci_20() -> HashMap { - let mut accesses = HashMap::new(); +fn expected_final_accesses_fibonacci_20() -> FxHashMap { + let mut accesses = FxHashMap::default(); let x = |i| WordAddr::from(Platform::register_vma(i)); const C: Cycle = Tracer::SUBCYCLES_PER_INSN; diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3347b38cc..07d1394ac 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -34,6 +34,7 @@ witness.workspace = true itertools.workspace = true ndarray.workspace = true prettytable-rs.workspace = true +rustc-hash.workspace = true strum.workspace = true strum_macros.workspace = true tracing.workspace = true @@ -47,6 +48,7 @@ derive = { path = "../derive" } generic-array.workspace = true generic_static = "0.2" num.workspace = true +num-bigint = "0.4.6" parse-size = "1.1" rand.workspace = true sp1-curves.workspace = true diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 878502f8e..325c59f46 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,7 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -54,6 +54,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -91,6 +92,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index 483b690d5..d942743db 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -9,6 +9,7 @@ use std::{fs, path::PathBuf, time::Duration}; mod 
alloc; use criterion::*; +use ceno_zkvm::e2e::Shards; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; @@ -65,6 +66,7 @@ fn fibonacci_witness(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index b55805fb7..6d66ff859 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -62,6 +63,7 @@ fn is_prime_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index c1a889594..19011d460 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -51,6 +51,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -85,6 +86,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index dc234a03a..93389c388 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod 
alloc; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -63,6 +64,7 @@ fn quadratic_sorting_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 028748058..9d8cc22e8 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -111,7 +111,8 @@ fn bench_add(c: &mut Criterion) { witness: polys, structural_witness: vec![], public_input: vec![], - num_instances, + num_instances: vec![num_instances], + has_ecc_ops: false, }; let _ = prover .create_chip_proof( diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index c7ec2b310..52df7e6da 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -4,7 +4,7 @@ use ceno_host::{CenoStdin, memory_from_file}; use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ - Checkpoint, FieldType, PcsKind, Preset, run_e2e_with_checkpoint, setup_platform, + Checkpoint, FieldType, PcsKind, Preset, Shards, run_e2e_with_checkpoint, setup_platform, setup_platform_debug, verify, }, scheme::{ @@ -108,6 +108,14 @@ struct Args { /// The security level to use. 
#[arg(short, long, value_enum, default_value_t = SecurityLevel::default())] security_level: SecurityLevel, + + // shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + // number of total shards + #[arg(long, default_value = "1")] + max_num_shards: u32, } fn main() { @@ -240,6 +248,7 @@ fn main() { .unwrap_or_default(); let max_steps = args.max_steps.unwrap_or(usize::MAX); + let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize); match (args.pcs, args.field) { (PcsKind::Basefold, FieldType::Goldilocks) => { @@ -249,6 +258,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -264,6 +274,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -279,6 +290,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -294,6 +306,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -320,6 +333,7 @@ fn run_inner< pd: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, @@ -328,7 +342,7 @@ fn run_inner< checkpoint: Checkpoint, ) { let result = run_e2e_with_checkpoint::( - pd, program, platform, hints, public_io, max_steps, checkpoint, + pd, program, platform, shards, hints, public_io, max_steps, checkpoint, ); let zkvm_proof = result diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e1ace19d0..9ea76595a 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,9 +4,10 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, GLOBAL_RW_SUM_IDX, + INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, UINT_LIMBS, }, + 
scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, }; use multilinear_extensions::{Expression, Instance}; @@ -21,7 +22,10 @@ pub trait PublicIOQuery { fn query_init_cycle(&mut self) -> Result; fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + #[allow(dead_code)] + fn query_shard_id(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -60,6 +64,10 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) } + fn query_shard_id(&mut self) -> Result { + self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX) + } + fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?, @@ -67,4 +75,23 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { .query_instance(|| "public_io_high", PUBLIC_IO_IDX + 1)?, ]) } + + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError> { + let x = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| { + self.cs + .query_instance(|| format!("global_rw_sum_x_{}", i), GLOBAL_RW_SUM_IDX + i) + }) + .collect::, CircuitBuilderError>>()?; + let y = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| { + self.cs.query_instance( + || format!("global_rw_sum_y_{}", i), + GLOBAL_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE + i, + ) + }) + .collect::, CircuitBuilderError>>()?; + + Ok([x, y].concat()) + } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 226231c2b..8421137f1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -3,6 +3,7 @@ use crate::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, scheme::{ PublicValues, ZKVMProof, + constants::SEPTIC_EXTENSION_DEGREE, hal::ProverDevice, 
mock_prover::{LkMultiplicityKey, MockProver}, prover::ZKVMProver, @@ -16,22 +17,28 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, IterAddresses, Platform, Program, - StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, NextCycleAccess, + Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, + host_utils::read_all_messages, }; use clap::ValueEnum; +use either::Either; use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::hal::ProverBackend; +use gkr_iop::{RAMType, hal::ProverBackend}; use itertools::{Itertools, MinMaxResult, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; +use multilinear_extensions::util::max_usable_threads; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use std::{ - collections::{BTreeSet, HashMap, HashSet}, + borrow::Cow, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, sync::Arc, }; use transcript::BasicTranscript as Transcript; +use witness::next_pow2_instance_padding; /// The polynomial commitment scheme kind #[derive( @@ -87,19 +94,322 @@ pub struct FullMemState { type InitMemState = FullMemState; type FinalMemState = FullMemState; -pub struct EmulationResult { +pub struct EmulationResult<'a> { pub exit_code: Option, pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, + pub shard_ctx: ShardContext<'a>, } -pub fn emulate_program( +pub struct RAMRecord { + pub ram_type: RAMType, + pub id: u64, + pub addr: WordAddr, + // prev_cycle and cycle are global cycle + pub prev_cycle: Cycle, + pub cycle: Cycle, + // shard_cycle is cycle in current local shard, which already offset by start cycle + pub shard_cycle: Cycle, + pub prev_value: Option, + pub value: Word, + // 
for global reads, `shard_id` refers to the shard that previously produced this value. + // for global write, `shard_id` refers to current shard. + pub shard_id: usize, +} + +#[derive(Clone, Debug)] +pub struct Shards { + pub shard_id: usize, + pub max_num_shards: usize, +} + +impl Shards { + pub fn new(shard_id: usize, max_num_shards: usize) -> Self { + assert!(shard_id < max_num_shards); + Self { + shard_id, + max_num_shards, + } + } + + pub fn is_first_shard(&self) -> bool { + self.shard_id == 0 + } + + pub fn is_last_shard(&self) -> bool { + self.shard_id == self.max_num_shards - 1 + } +} + +impl Default for Shards { + fn default() -> Self { + Self { + shard_id: 0, + max_num_shards: 1, + } + } +} + +pub struct ShardContext<'a> { + shards: Shards, + max_cycle: Cycle, + // TODO optimize this map as it's super huge + addr_future_accesses: Cow<'a, NextCycleAccess>, + read_thread_based_record_storage: + Either>, &'a mut BTreeMap>, + write_thread_based_record_storage: + Either>, &'a mut BTreeMap>, + pub cur_shard_cycle_range: std::ops::Range, + pub expected_inst_per_shard: usize, +} + +impl<'a> Default for ShardContext<'a> { + fn default() -> Self { + let max_threads = max_usable_threads(); + Self { + shards: Shards::default(), + max_cycle: Cycle::default(), + addr_future_accesses: Cow::Owned(Default::default()), + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, + expected_inst_per_shard: usize::MAX, + } + } +} + +impl<'a> ShardContext<'a> { + pub fn new( + shards: Shards, + executed_instructions: usize, + addr_future_accesses: NextCycleAccess, + ) -> Self { + // current strategy: at least each shard deal with one instruction + let max_num_shards = 
shards.max_num_shards.min(executed_instructions); + assert!( + shards.shard_id < max_num_shards, + "implement mechanism to skip current shard proof" + ); + + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; + let max_threads = max_usable_threads(); + let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards); + let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn + let cur_shard_cycle_range = (shards.shard_id * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + ..((shards.shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + .min(max_cycle); + + ShardContext { + shards, + max_cycle: max_cycle as Cycle, + addr_future_accesses: Cow::Owned(addr_future_accesses), + // TODO with_capacity optimisation + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + // TODO with_capacity optimisation + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range, + expected_inst_per_shard, + } + } + + pub fn get_forked(&mut self) -> Vec> { + match ( + &mut self.read_thread_based_record_storage, + &mut self.write_thread_based_record_storage, + ) { + ( + Either::Left(read_thread_based_record_storage), + Either::Left(write_thread_based_record_storage), + ) => read_thread_based_record_storage + .iter_mut() + .zip(write_thread_based_record_storage.iter_mut()) + .map(|(read, write)| ShardContext { + shards: self.shards.clone(), + max_cycle: self.max_cycle, + addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), + read_thread_based_record_storage: Either::Right(read), + write_thread_based_record_storage: Either::Right(write), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + }) + 
.collect_vec(), + _ => panic!("invalid type"), + } + } + + pub fn read_records(&self) -> &[BTreeMap] { + match &self.read_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + + pub fn write_records(&self) -> &[BTreeMap] { + match &self.write_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + + #[inline(always)] + pub fn cur_shard(&self) -> usize { + self.shards.shard_id + } + + #[inline(always)] + pub fn is_first_shard(&self) -> bool { + self.shards.shard_id == 0 + } + + #[inline(always)] + pub fn is_last_shard(&self) -> bool { + self.shards.shard_id == self.shards.max_num_shards - 1 + } + + #[inline(always)] + pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool { + self.cur_shard_cycle_range.contains(&(cycle as usize)) + } + + #[inline(always)] + pub fn extract_prev_shard_id(&self, cycle: Cycle) -> usize { + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN; + let per_shard_cycles = + (self.expected_inst_per_shard as u64).saturating_mul(subcycle_per_insn); + ((cycle.saturating_sub(subcycle_per_insn)) / per_shard_cycles) as usize + } + + #[inline(always)] + pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { + let mut ts = prev_cycle.saturating_sub(self.current_shard_offset_cycle()); + if ts < Tracer::SUBCYCLES_PER_INSN { + ts = 0 + } + ts + } + + #[inline(always)] + pub fn aligned_current_ts(&self, cycle: Cycle) -> Cycle { + cycle.saturating_sub(self.current_shard_offset_cycle()) + } + + pub fn current_shard_offset_cycle(&self) -> Cycle { + // cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN + (self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN + } + + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + // 
check read from external mem bus + // exclude first shard + if prev_cycle < self.cur_shard_cycle_range.start as Cycle + && self.is_current_shard_cycle(cycle) + && !self.is_first_shard() + { + let prev_shard_id = self.extract_prev_shard_id(prev_cycle); + let ram_record = self + .read_thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert( + addr, + RAMRecord { + ram_type, + id, + addr, + prev_cycle, + cycle, + shard_cycle: 0, + prev_value, + value, + shard_id: prev_shard_id, + }, + ); + } + + // check write to external mem bus + if let Some(future_touch_cycle) = + self.addr_future_accesses + .get(cycle as usize) + .and_then(|res| { + if res.len() == 1 { + Some(res[0].1) + } else if res.len() > 1 { + res.iter() + .find(|(m_addr, _)| *m_addr == addr) + .map(|(_, cycle)| *cycle) + } else { + None + } + }) + && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + && self.is_current_shard_cycle(cycle) + { + let shard_cycle = self.aligned_current_ts(cycle); + let ram_record = self + .write_thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert( + addr, + RAMRecord { + ram_type, + id, + addr, + prev_cycle, + cycle, + shard_cycle, + prev_value, + value, + shard_id: self.shards.shard_id, + }, + ); + } + } +} + +pub fn emulate_program<'a>( program: Arc, max_steps: usize, init_mem_state: &InitMemState, platform: &Platform, -) -> EmulationResult { + shards: &Shards, +) -> EmulationResult<'a> { let InitMemState { mem: mem_init, io: io_init, @@ -156,7 +466,9 @@ pub fn emulate_program( Tracer::SUBCYCLES_PER_INSN, vm.get_pc().into(), end_cycle, + shards.shard_id as u32, io_init.iter().map(|rec| rec.value).collect_vec(), + vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); // Find the final register values and cycles. 
@@ -167,6 +479,7 @@ pub fn emulate_program( if index < VMState::REG_COUNT { let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { + ram_type: RAMType::Register, addr: rec.addr, value: vm.peek_register(index), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -174,6 +487,7 @@ pub fn emulate_program( } else { // The table is padded beyond the number of registers. MemFinalRecord { + ram_type: RAMType::Register, addr: rec.addr, value: 0, cycle: 0, @@ -188,6 +502,7 @@ pub fn emulate_program( .map(|rec| { let vma: WordAddr = rec.addr.into(); MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -199,6 +514,7 @@ pub fn emulate_program( let io_final = io_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -209,6 +525,7 @@ pub fn emulate_program( let hints_final = hints_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -226,6 +543,7 @@ pub fn emulate_program( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -249,6 +567,7 @@ pub fn emulate_program( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -270,10 +589,13 @@ pub fn emulate_program( ), ); + let shard_ctx = ShardContext::new(shards.clone(), insts, vm.take_tracer().next_accesses()); + EmulationResult { pi, exit_code, all_records, + shard_ctx, final_mem_state: FinalMemState { reg: reg_final, io: io_final, @@ -389,17 +711,17 @@ pub fn init_static_addrs(program: &Program) -> Vec { program_addrs } -pub struct ConstraintSystemConfig { 
+pub struct ConstraintSystemConfig<'a, E: ExtensionField> { pub zkvm_cs: ZKVMConstraintSystem, pub config: Rv32imConfig, - pub mmu_config: MmuConfig, + pub mmu_config: MmuConfig<'a, E>, pub dummy_config: DummyExtraConfig, pub prog_config: ProgramTableConfig, } -pub fn construct_configs( +pub fn construct_configs<'a, E: ExtensionField>( program_params: ProgramParams, -) -> ConstraintSystemConfig { +) -> ConstraintSystemConfig<'a, E> { let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); @@ -450,7 +772,7 @@ pub fn generate_fixed_traces( pub fn generate_witness( system_config: &ConstraintSystemConfig, - emul_result: EmulationResult, + mut emul_result: EmulationResult, program: &Program, ) -> ZKVMWitnesses { let mut zkvm_witness = ZKVMWitnesses::default(); @@ -459,13 +781,19 @@ pub fn generate_witness( .config .assign_opcode_circuit( &system_config.zkvm_cs, + &mut emul_result.shard_ctx, &mut zkvm_witness, emul_result.all_records, ) .unwrap(); system_config .dummy_config - .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut emul_result.shard_ctx, + &mut zkvm_witness, + dummy_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); @@ -478,6 +806,7 @@ pub fn generate_witness( .mmu_config .assign_table_circuit( &system_config.zkvm_cs, + &emul_result.shard_ctx, &mut zkvm_witness, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, @@ -519,12 +848,13 @@ pub enum Checkpoint { pub type IntermediateState = (Option>, Option>); /// Context construct from a program and given platform -pub struct E2EProgramCtx { +pub struct E2EProgramCtx<'a, E: ExtensionField> { pub program: Arc, pub platform: Platform, + pub shards: Shards, pub static_addrs: Vec, pub pubio_len: usize, - pub system_config: ConstraintSystemConfig, + pub system_config: ConstraintSystemConfig<'a, E>, pub reg_init: Vec, 
pub io_init: Vec, pub zkvm_fixed_traces: ZKVMFixedTraces, @@ -549,12 +879,16 @@ impl> E2ECheckpointResult< } /// Set up a program with the given platform -pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { +pub fn setup_program<'a, E: ExtensionField>( + program: Program, + platform: Platform, + shards: Shards, +) -> E2EProgramCtx<'a, E> { let static_addrs = init_static_addrs(&program); let pubio_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { platform: platform.clone(), - program_size: program.instructions.len(), + program_size: next_pow2_instance_padding(program.instructions.len()), static_memory_len: static_addrs.len(), pubio_len, }; @@ -574,6 +908,7 @@ pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { program: Arc::new(program), platform, + shards, static_addrs, pubio_len, system_config, @@ -583,7 +918,7 @@ pub fn setup_program(program: Program, platform: Platform) -> } } -impl E2EProgramCtx { +impl E2EProgramCtx<'_, E> { pub fn keygen + 'static>( &self, max_num_variables: usize, @@ -666,13 +1001,14 @@ pub fn run_e2e_with_checkpoint< device: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, checkpoint: Checkpoint, ) -> E2ECheckpointResult { let start = std::time::Instant::now(); - let ctx = setup_program::(program, platform); + let ctx = setup_program::(program, platform, shards); tracing::debug!("setup_program done in {:?}", start.elapsed()); // Keygen @@ -710,6 +1046,7 @@ pub fn run_e2e_with_checkpoint< max_steps, &init_full_mem, &ctx.platform, + &ctx.shards, ); tracing::debug!("emulate done in {:?}", start.elapsed()); @@ -793,7 +1130,13 @@ pub fn run_e2e_proof< is_mock_proving: bool, ) -> ZKVMProof { // Emulate program - let emul_result = emulate_program(ctx.program.clone(), max_steps, init_full_mem, &ctx.platform); + let emul_result = emulate_program( + ctx.program.clone(), + max_steps, + init_full_mem, + 
&ctx.platform, + &ctx.shards, + ); // clone pi before consuming let pi = emul_result.pi.clone(); diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 5e429354f..a4d624568 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,6 +1,7 @@ mod div; mod field; mod is_lt; +mod poseidon2; mod signed; mod signed_ext; mod signed_limbs; @@ -13,6 +14,7 @@ pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; +pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs new file mode 100644 index 000000000..322ad675f --- /dev/null +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -0,0 +1,529 @@ +// Poseidon2 over BabyBear field + +use std::{ + borrow::{Borrow, BorrowMut}, + iter::from_fn, + mem::transmute, +}; + +use ff_ext::{BabyBearExt4, ExtensionField}; +use gkr_iop::error::CircuitBuilderError; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use num_bigint::BigUint; +use p3::{ + babybear::BabyBearInternalLayerParameters, + field::{Field, FieldAlgebra, PrimeField}, + monty_31::InternalLayerBaseParameters, + poseidon2::{GenericPoseidon2LinearLayers, MDSMat4, mds_light_permutation}, + poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, +}; + +use crate::circuit_builder::CircuitBuilder; + +// copied from poseidon2-air/src/constants.rs +// as the original one cannot be accessed here +#[derive(Debug, Clone)] +pub struct RoundConstants< + F: Field, + const WIDTH: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + pub beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + pub partial_round_constants: [F; PARTIAL_ROUNDS], 
+ pub ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], +} + +impl + From> for RoundConstants +{ + fn from(value: Vec) -> Self { + let mut iter = value.into_iter(); + let mut beginning_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + + beginning_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); + + let mut partial_round_constants = [F::ZERO; PARTIAL_ROUNDS]; + + partial_round_constants + .iter_mut() + .for_each(|arr| *arr = iter.next().expect("insufficient round constants")); + + let mut ending_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + ending_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); + + assert!(iter.next().is_none(), "round constants are too many"); + + RoundConstants { + beginning_full_round_constants, + partial_round_constants, + ending_full_round_constants, + } + } +} + +pub type Poseidon2BabyBearConfig = Poseidon2Config; +pub struct Poseidon2Config< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + cols: Vec, + constants: RoundConstants, +} + +#[derive(Debug, Clone)] +pub struct Poseidon2LinearLayers; + +impl GenericPoseidon2LinearLayers + for Poseidon2LinearLayers +{ + fn internal_linear_layer(state: &mut [F; WIDTH]) { + // this only works when F is BabyBear field for now + let babybear_prime = BigUint::from(0x7800_0001u32); + if F::order() == babybear_prime { + let diag_m1_matrix = &>::INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[F; WIDTH] = unsafe { transmute(diag_m1_matrix) }; + let sum = state.iter().cloned().sum::(); + for (input, diag_m1) in state.iter_mut().zip(diag_m1_matrix) { + *input = sum + F::from_f(*diag_m1) * *input; + } + } else { + panic!("Unsupported field"); + } + } + + fn 
external_linear_layer(state: &mut [F; WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } +} + +impl< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> Poseidon2Config +{ + // constraints taken from poseidon2_air/src/air.rs + fn eval_sbox( + sbox: &SBox, SBOX_DEGREE, SBOX_REGISTERS>, + x: &mut Expression, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + *x = match (SBOX_DEGREE, SBOX_REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + let x2: Expression = x.square(); + cb.require_zero( + || "x3 = x.cube()", + committed_x3.clone() - x2.clone() * x.clone(), + )?; + committed_x3 * x2 + } + (7, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + // TODO: avoid x^3 as x may have ~STATE_WIDTH terms after the linear layer + // we can allocate one more column to store x^2 (which has ~STATE_WIDTH^2 terms) + // then x^3 = x * x^2 + // but this will increase the number of columns (by FULL_ROUNDS * STATE_WIDTH + PARTIAL_ROUNDS) + cb.require_zero(|| "x3 = x.cube()", committed_x3.clone() - x.cube())?; + committed_x3.square() * x.clone() + } + _ => panic!( + "Unexpected (SBOX_DEGREE, SBOX_REGISTERS) of ({}, {})", + SBOX_DEGREE, SBOX_REGISTERS + ), + }; + + Ok(()) + } + + fn eval_full_round( + state: &mut [Expression; STATE_WIDTH], + full_round: &FullRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constants: &[E::BaseField], + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + for (i, (s, r)) in state.iter_mut().zip_eq(round_constants.iter()).enumerate() { + *s = s.clone() + r.expr(); + Self::eval_sbox(&full_round.sbox[i], s, cb)?; + } + Self::external_linear_layer(state); + for (state_i, post_i) in state.iter_mut().zip_eq(full_round.post.iter()) { + cb.require_zero(|| "post_i = 
state_i", state_i.clone() - post_i)?; + *state_i = post_i.clone(); + } + + Ok(()) + } + + fn eval_partial_round( + state: &mut [Expression; STATE_WIDTH], + partial_round: &PartialRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constant: &E::BaseField, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + state[0] = state[0].clone() + round_constant.expr(); + Self::eval_sbox(&partial_round.sbox, &mut state[0], cb)?; + + cb.require_zero( + || "state[0] = post_sbox", + state[0].clone() - partial_round.post_sbox.clone(), + )?; + state[0] = partial_round.post_sbox.clone(); + + Self::internal_linear_layer(state); + + Ok(()) + } + + fn external_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } + + fn internal_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + let sum: Expression = state.iter().map(|s| s.get_monomial_form()).sum(); + // reduce to monomial form + let sum = sum.get_monomial_form(); + let babybear_prime = BigUint::from(0x7800_0001u32); + if E::BaseField::order() == babybear_prime { + // BabyBear + let diag_m1_matrix_bb = + &>:: + INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[E::BaseField; STATE_WIDTH] = + unsafe { transmute(diag_m1_matrix_bb) }; + for (input, diag_m1) in state.iter_mut().zip_eq(diag_m1_matrix) { + let updated = sum.clone() + Expression::from_f(*diag_m1) * input.clone(); + // reduce to monomial form + *input = updated.get_monomial_form(); + } + } else { + panic!("Unsupported field"); + } + } + + pub fn construct( + cb: &mut CircuitBuilder, + round_constants: RoundConstants< + E::BaseField, + STATE_WIDTH, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + ) -> Self { + let num_cols = + num_cols::( + ); + let cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col"))) + .take(num_cols) + .collect::>(); + let mut col_exprs = cols + .iter() + .map(|c| c.expr()) + .collect::>>(); + + let poseidon2_cols: &mut Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + 
SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_mut_slice().borrow_mut(); + + // external linear layer + Self::external_linear_layer(&mut poseidon2_cols.inputs); + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.beginning_full_rounds[round], + &round_constants.beginning_full_round_constants[round], + cb, + ) + .unwrap(); + } + + // eval partial round + for round in 0..PARTIAL_ROUNDS { + Self::eval_partial_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.partial_rounds[round], + &round_constants.partial_round_constants[round], + cb, + ) + .unwrap(); + } + + // TODO: after the last partial round, each state_i has ~STATE_WIDTH terms + // which will make the next full round to have many terms + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.ending_full_rounds[round], + &round_constants.ending_full_round_constants[round], + cb, + ) + .unwrap(); + } + + Poseidon2Config { + cols, + constants: round_constants, + } + } + + #[inline(always)] + pub fn num_polys(&self) -> usize { + self.cols.len() + } + + pub fn inputs(&self) -> Vec> { + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + + let poseidon2_cols: &Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_slice().borrow(); + + poseidon2_cols.inputs.to_vec() + } + + pub fn output(&self) -> Vec> { + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + + let poseidon2_cols: &Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_slice().borrow(); + + poseidon2_cols + .ending_full_rounds + .last() + .map(|r| r.post.to_vec()) + .unwrap() + } + + pub fn assign_instance( + &self, + instance: &mut [E::BaseField], + state: [E::BaseField; 
STATE_WIDTH], + ) { + let poseidon2_cols: &mut Poseidon2Cols< + E::BaseField, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = instance.borrow_mut(); + + generate_trace_rows_for_perm::< + E::BaseField, + Poseidon2LinearLayers, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(poseidon2_cols, state, &self.constants); + } +} + +////////////////////////////////////////////////////////////////////////// +/// The following routines are taken from poseidon2-air/src/generation.rs +////////////////////////////////////////////////////////////////////////// +fn generate_trace_rows_for_perm< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + perm: &mut Poseidon2Cols< + F, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + mut state: [F; WIDTH], + constants: &RoundConstants, +) { + perm.export = F::ONE; + perm.inputs + .iter_mut() + .zip(state.iter()) + .for_each(|(input, &x)| { + *input = x; + }); + + LinearLayers::external_linear_layer(&mut state); + + for (full_round, constants) in perm + .beginning_full_rounds + .iter_mut() + .zip(&constants.beginning_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } + + for (partial_round, constant) in perm + .partial_rounds + .iter_mut() + .zip(&constants.partial_round_constants) + { + generate_partial_round::( + &mut state, + partial_round, + *constant, + ); + } + + for (full_round, constants) in perm + .ending_full_rounds + .iter_mut() + .zip(&constants.ending_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } +} + +#[inline] +fn generate_full_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const 
SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + full_round: &mut FullRound, + round_constants: &[F; WIDTH], +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants) { + *state_i += *const_i; + } + for (state_i, sbox_i) in state.iter_mut().zip(full_round.sbox.iter_mut()) { + generate_sbox(sbox_i, state_i); + } + LinearLayers::external_linear_layer(state); + full_round + .post + .iter_mut() + .zip(*state) + .for_each(|(post, x)| { + *post = x; + }); +} + +#[inline] +fn generate_partial_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + partial_round: &mut PartialRound, + round_constant: F, +) { + state[0] += round_constant; + generate_sbox(&mut partial_round.sbox, &mut state[0]); + partial_round.post_sbox = state[0]; + LinearLayers::internal_linear_layer(state); +} + +#[inline] +fn generate_sbox( + sbox: &mut SBox, + x: &mut F, +) { + *x = match (DEGREE, REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let x2 = x.square(); + let x3 = x2 * *x; + sbox.0[0] = x3; + x3 * x2 + } + (7, 1) => { + let x3 = x.cube(); + sbox.0[0] = x3; + x3 * x3 * *x + } + (11, 2) => { + let x2 = x.square(); + let x3 = x2 * *x; + let x9 = x3.cube(); + sbox.0[0] = x3; + sbox.0[1] = x9; + x9 * x2 + } + _ => panic!( + "Unexpected (DEGREE, REGISTERS) of ({}, {})", + DEGREE, REGISTERS + ), + } +} + +#[cfg(test)] +mod tests { + use crate::gadgets::poseidon2::Poseidon2BabyBearConfig; + use ff_ext::{BabyBearExt4, PoseidonField}; + use gkr_iop::circuit_builder::{CircuitBuilder, ConstraintSystem}; + use p3::babybear::BabyBear; + + type E = BabyBearExt4; + type F = BabyBear; + #[test] + fn test_poseidon2_gadget() { + let mut cs = ConstraintSystem::new(|| "poseidon2 gadget test"); + let mut cb = CircuitBuilder::::new(&mut cs); + + // let poseidon2_constants = horizen_round_consts(); + 
let rc = ::get_default_perm_rc().into(); + let _ = Poseidon2BabyBearConfig::construct(&mut cb, rc); + } +} diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 4591c47e3..12c137aa8 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,5 @@ use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; @@ -19,6 +19,7 @@ use rayon::{ }; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; +pub mod global; pub mod riscv; pub trait Instruction { @@ -56,7 +57,7 @@ pub trait Instruction { descending: false, }, ); - let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector let (out_evals, mut chip) = ( @@ -79,7 +80,7 @@ pub trait Instruction { cb.cs.lk_selector = Some(selector_type.clone()); cb.cs.zero_selector = Some(selector_type.clone()); - let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); Ok((config, chip.gkr_circuit())) @@ -93,8 +94,9 @@ pub trait Instruction { } // assign single instance giving step from trace - fn assign_instance( + fn assign_instance<'a>( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext<'a>, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -102,6 +104,7 @@ pub trait Instruction { fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -131,22 +134,32 @@ pub trait Instruction { let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); let 
raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(raw_structual_witin_iter) .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|((instances, structural_instance), steps)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - instances - .chunks_mut(num_witin) - .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(steps) - .map(|((instance, structural_instance), step)| { - set_val!(structural_instance, selector_witin, E::BaseField::ONE); - Self::assign_instance(config, instance, &mut lk_multiplicity, step) - }) - .collect::>() - }) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), steps), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .map(|((instance, structural_instance), step)| { + set_val!(structural_instance, selector_witin, E::BaseField::ONE); + Self::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + step, + ) + }) + .collect::>() + }, + ) .collect::>()?; raw_witin.padding_by_strategy(); diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs new file mode 100644 index 000000000..c7a4fddb9 --- /dev/null +++ b/ceno_zkvm/src/instructions/global.rs @@ -0,0 +1,870 @@ +use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; + +use crate::{ + Value, + chip_handler::general::PublicIOQuery, + e2e::RAMRecord, + error::ZKVMError, + gadgets::Poseidon2Config, + instructions::riscv::constants::UINT_LIMBS, + scheme::septic_curve::{SepticExtension, SepticPoint}, + structs::{ProgramParams, RAMType}, + tables::{RMMCollections, TableCircuit}, + witness::LkMultiplicity, +}; +use ceno_emul::WordAddr; +use ff_ext::{ExtensionField, FieldInto, PoseidonField, SmallField}; +use gkr_iop::{ + 
chip::Chip, + circuit_builder::CircuitBuilder, + error::CircuitBuilderError, + gadgets::IsLtConfig, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::{Itertools, chain}; +use multilinear_extensions::{ + Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, util::max_usable_threads, +}; +use p3::{ + field::{Field, FieldAlgebra}, + matrix::dense::RowMajorMatrix, + symmetric::Permutation, +}; +use rayon::{ + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend, + ParallelIterator, + }, + prelude::ParallelSliceMut, + slice::ParallelSlice, +}; +use std::ops::Deref; +use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; + +use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTENSION_DEGREE}; + +/// A record for a read/write into the global set +#[derive(Debug, Clone)] +pub struct GlobalRecord { + pub addr: u32, + pub ram_type: RAMType, + pub value: u32, + pub shard: u64, + pub local_clk: u64, + pub global_clk: u64, + pub is_write: bool, +} + +impl From<(&WordAddr, &RAMRecord, bool)> for GlobalRecord { + fn from((vma, record, is_write): (&WordAddr, &RAMRecord, bool)) -> Self { + let addr = match record.ram_type { + RAMType::Register => record.id as u32, + RAMType::Memory => (*vma).into(), + _ => unreachable!(), + }; + let value = record.prev_value.map_or(record.value, |v| v); + let (shard, local_clk, global_clk) = if is_write { + (record.shard_id, record.shard_cycle, record.cycle) + } else { + debug_assert_eq!(record.shard_cycle, 0); + (record.shard_id, 0, record.prev_cycle) + }; + + GlobalRecord { + addr, + ram_type: record.ram_type, + value, + shard: shard as u64, + local_clk, + global_clk, + is_write, + } + } +} +/// An EC point corresponding to a global read/write record +/// whose x-coordinate is derived from Poseidon2 hash of the record +#[derive(Clone, Debug)] +pub struct GlobalPoint { + pub nonce: u32, + pub point: 
SepticPoint, +} + +impl GlobalRecord { + pub fn to_ec_point>>( + &self, + hasher: &P, + ) -> GlobalPoint { + let mut nonce = 0; + let mut input = vec![ + E::BaseField::from_canonical_u32(self.addr), + E::BaseField::from_canonical_u32(self.ram_type as u32), + E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits + E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits + E::BaseField::from_canonical_u64(self.shard), + E::BaseField::from_canonical_u64(self.global_clk), + E::BaseField::from_canonical_u32(nonce), + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + ]; + + let prime = E::BaseField::order().to_u64_digits()[0]; + loop { + let x: SepticExtension = + hasher.permute(input.clone())[0..SEPTIC_EXTENSION_DEGREE].into(); + if let Some(p) = SepticPoint::from_x(x) { + let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); + let is_y_in_2nd_half = y6 >= (prime / 2); + + // we negate y if needed + // to ensure read => y in [0, p/2) and write => y in [p/2, p) + let negate = match (self.is_write, is_y_in_2nd_half) { + (true, false) => true, // write, y in [0, p/2) + (false, true) => true, // read, y in [p/2, p) + _ => false, + }; + + let point = if negate { -p } else { p }; + + return GlobalPoint { nonce, point }; + } else { + // try again with different nonce + nonce += 1; + input[6] = E::BaseField::from_canonical_u32(nonce); + } + } + } +} +/// opcode circuit + mem init/final table + local finalize circuit + global chip +/// global chip is used to ensure the **local** reads and writes produced by +/// opcode circuits / memory init / memory finalize table / local finalize circuit +/// can balance out. +/// +/// 1. 
For a local memory read record whose previous write is not in the same shard, +/// the global chip will read it from the **global set** and insert a local write record. +/// 2. For a local memory write record which will **not** be read in the future, +/// the local finalize circuit will consume it by inserting a local read record. +/// 3. For a local memory write record which will be read in the future, +/// the global chip will insert a local read record and write it to the **global set**. +pub struct GlobalConfig { + addr: WitIn, + is_ram_register: WitIn, + value: UInt, + shard: WitIn, + global_clk: WitIn, + local_clk: WitIn, + nonce: WitIn, + is_shard_lt_cur: IsLtConfig, + // if it's a write to global set, then insert a local read record + // s.t. local offline memory checking can cancel out + // this serves as propagating local write to global. + is_global_write: WitIn, + x: Vec, + y: Vec, + slope: Vec, + perm_config: Poseidon2Config, +} + +impl GlobalConfig { + // TODO: make `WIDTH`, `HALF_FULL_ROUNDS`, `PARTIAL_ROUNDS` generic parameters + pub fn configure(cb: &mut CircuitBuilder) -> Result { + let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("x{}", i))) + .collect(); + let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("y{}", i))) + .collect(); + let slope: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("slope{}", i))) + .collect(); + let addr = cb.create_witin(|| "addr"); + let is_ram_register = cb.create_bit(|| "is_ram_register")?; + let value = UInt::new_unchecked(|| "value", cb)?; + let shard = cb.create_witin(|| "shard"); + let global_clk = cb.create_witin(|| "global_clk"); + let local_clk = cb.create_witin(|| "local_clk"); + let nonce = cb.create_witin(|| "nonce"); + let is_global_write = cb.create_witin(|| "is_global_write"); + + let is_ram_reg: Expression = is_ram_register.expr(); + let reg: Expression = RAMType::Register.into(); + let mem: Expression = 
RAMType::Memory.into(); + let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; + + let mut input = vec![]; + input.push(addr.expr()); + input.push(ram_type.clone()); + // memory expr has same number of limbs as register expr + input.extend(value.memory_expr()); + input.push(shard.expr()); + input.push(global_clk.expr()); + // add nonce to ensure poseidon2(input) always map to a valid ec point + input.push(nonce.expr()); + input.extend(repeat_n(E::BaseField::ZERO.expr(), 16 - input.len())); + + let mut record = vec![]; + record.push(addr.expr()); + record.push(ram_type.clone()); + record.extend(value.memory_expr()); + record.push(local_clk.expr()); + + // if is_global_write = 1, then it means we are propagating a local write to global + // so we need to insert a local read record to cancel out this local write + cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; + // TODO: for all local reads, enforce they come to global writes + // TODO: for all local writes, enforce they come from global reads + + // global read => insert a local write with local_clk = 0 + cb.condition_require_zero( + || "is_global_read => local_clk = 0", + 1 - is_global_write.expr(), + local_clk.expr(), + )?; + + // if it's global write => shard == cur_shard + let cur_shard = cb.query_shard_id()?; + cb.condition_require_zero( + || "global_write = true => shard = instance.shard", + is_global_write.expr(), + shard.expr() - Expression::Instance(cur_shard), + )?; + + // global read => shard < cur_shard + let is_shard_lt_cur = IsLtConfig::construct_circuit( + cb, + || "shard < cur_shard", + shard.expr(), + Expression::Instance(cur_shard), + 16, + )?; + cb.condition_require_equal( + || "global read => shard < cur_shard", + is_global_write.expr(), + is_shard_lt_cur.expr(), + E::BaseField::ONE.expr(), // true + E::BaseField::ZERO.expr(), // false + )?; + + cb.read_rlc_record( + || "r_record", + ram_type.clone(), + record.clone(), + 
cb.rlc_chip_record(record.clone()), + )?; + cb.write_rlc_record( + || "w_record", + ram_type, + record.clone(), + cb.rlc_chip_record(record), + )?; + + // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol + let final_sum = cb.query_global_rw_sum()?; + cb.ec_sum( + x.iter().map(|xi| xi.expr()).collect::>(), + y.iter().map(|yi| yi.expr()).collect::>(), + slope.iter().map(|si| si.expr()).collect::>(), + final_sum.into_iter().map(|x| x.expr()).collect::>(), + ); + + let rc = ::get_default_perm_rc().into(); + let perm_config = Poseidon2Config::construct(cb, rc); + // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) + for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) + { + cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; + } + for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { + cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; + } + + // both (x, y) and (x, -y) are valid ec points + // if is_global_write = 1, then y should be in [0, p/2) + // if is_global_write = 0, then y should be in [p/2, p) + + // TODO: enforce 0 <= y < p/2 if is_global_write = 1 + // enforce p/2 <= y < p if is_global_write = 0 + + Ok(GlobalConfig { + x, + y, + slope, + addr, + is_ram_register, + value, + shard, + is_shard_lt_cur, + global_clk, + local_clk, + nonce, + is_global_write, + perm_config, + }) + } +} + +/// This chip is used to manage read/write into a global set +/// shared among multiple shards +#[derive(Default)] +pub struct GlobalChip { + _marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct GlobalChipInput { + pub record: GlobalRecord, + pub ec_point: GlobalPoint, +} + +impl GlobalChip { + fn assign_instance( + config: &GlobalConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + input: &GlobalChipInput, + cur_shard: usize, + ) -> Result<(), crate::error::ZKVMError> { + // 
assign basic fields + let record = &input.record; + let is_ram_register = match record.ram_type { + RAMType::Register => 1, + RAMType::Memory => 0, + _ => unreachable!(), + }; + set_val!(instance, config.addr, record.addr as u64); + set_val!(instance, config.is_ram_register, is_ram_register as u64); + let value = Value::new_unchecked(record.value); + config.value.assign_limbs(instance, value.as_u16_limbs()); + set_val!(instance, config.shard, record.shard); + set_val!(instance, config.global_clk, record.global_clk); + set_val!(instance, config.local_clk, record.local_clk); + set_val!(instance, config.is_global_write, record.is_write as u64); + + config.is_shard_lt_cur.assign_instance( + instance, + lk_multiplicity, + record.shard, + cur_shard as u64, + )?; + + // assign (x, y) and nonce + let GlobalPoint { nonce, point } = &input.ec_point; + set_val!(instance, config.nonce, *nonce as u64); + config + .x + .iter() + .chain(config.y.iter()) + .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) + .for_each(|(witin, fe)| { + instance[witin.id as usize] = *fe; + }); + + let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); + let mut input = [E::BaseField::ZERO; 16]; + + let k = UINT_LIMBS; + input[0] = E::BaseField::from_canonical_u32(record.addr); + input[1] = ram_type; + input[2..(k + 2)] + .iter_mut() + .zip(value.as_u16_limbs().iter()) + .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); + input[2 + k] = E::BaseField::from_canonical_u64(record.shard); + input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); + input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); + + let num_perm_polys = config.perm_config.num_polys(); + let offset = instance.len() - num_perm_polys; + config + .perm_config + .assign_instance(&mut instance[offset..], input); + + Ok(()) + } +} + +impl TableCircuit for GlobalChip { + type TableConfig = GlobalConfig; + type FixedInput = (); + type WitnessInput = (Vec>, usize); + + fn 
name() -> String { + "Global".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + let config = GlobalConfig::configure(cb)?; + + Ok(config) + } + + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), crate::error::ZKVMError> { + // create three selectors: selector_r, selector_w, selector_zero + let selector_r = cb.create_structural_witin( + || "selector_r", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_w = cb.create_structural_witin( + || "selector_w", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_zero = cb.create_structural_witin( + || "selector_zero", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + + let config = Self::construct_circuit(cb, param)?; + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + + let selector_r = SelectorType::Prefix(selector_r.expr()); + // note that the actual offset should be set by prover + // depending on the number of local read instances + let selector_w = SelectorType::Prefix(selector_w.expr()); + // TODO: when selector_r = 1 => selector_zero = 1 + // when selector_w = 1 => selector_zero = 1 + let selector_zero = SelectorType::Prefix(selector_zero.expr()); + + cb.cs.r_selector = Some(selector_r); + cb.cs.w_selector = Some(selector_w); + cb.cs.zero_selector = Some(selector_zero.clone()); + cb.cs.lk_selector 
= Some(selector_zero); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_len).collect_vec(), + // w_record + (r_len..r_len + w_len).collect_vec(), + // lk_record + (r_len + w_len..r_len + w_len + lk_len).collect_vec(), + // zero_record + (0..zero_len).collect_vec(), + ], + Chip::new_from_cb(cb, 0), + ); + + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _input: &Self::FixedInput, + ) -> witness::RowMajorMatrix<::BaseField> { + unimplemented!() + } + fn assign_instances<'a>( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + input: &Self::WitnessInput, + ) -> Result, ZKVMError> { + let steps = &input.0; + let cur_shard = input.1; + if steps.is_empty() { + return Ok([ + witness::RowMajorMatrix::empty(), + witness::RowMajorMatrix::empty(), + ]); + } + // FIXME selector is the only structural witness + // this is workaround, as call `construct_circuit` will not initialized selector + // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` + + assert_eq!(num_structural_witin, 3); + let selector_r_witin = WitIn { id: 0 }; + let selector_w_witin = WitIn { id: 1 }; + let selector_zero_witin = WitIn { id: 2 }; + + let nthreads = max_usable_threads(); + + // local read iff it's global write + // local reads are placed before local writes + // i.e. 
global writes are placed before global reads + let num_local_reads = steps.iter().filter(|s| s.record.is_write).count(); + tracing::debug!( + "{} local reads / {} local writes in global chip", + num_local_reads, + steps.len() - num_local_reads + ); + + let num_instance_per_batch = if steps.len() > 256 { + steps.len().div_ceil(nthreads) + } else { + steps.len() + } + .max(1); + + let n = next_pow2_instance_padding(steps.len()); + // compute the input for the binary tree for ec point summation + + let lk_multiplicity = LkMultiplicity::default(); + // *2 because we need to store the internal nodes of binary tree for ec point summation + let num_rows_padded = 2 * n; + + let mut raw_witin = { + let matrix_size = num_rows_padded * num_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_witin) + }; + let mut raw_structual_witin = { + let matrix_size = num_rows_padded * num_structural_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_structural_witin) + }; + let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] + .par_chunks_mut(num_instance_per_batch * num_witin); + let raw_structual_witin_iter = raw_structual_witin.values + [0..steps.len() * num_structural_witin] + .par_chunks_mut(num_instance_per_batch * num_structural_witin); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(steps.par_chunks(num_instance_per_batch)) + .enumerate() + .flat_map(|(chunk_idx, ((instances, structural_instance), steps))| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .enumerate() + .map(|(i, ((instance, structural_instance), step))| { + let row = chunk_idx * 
num_instance_per_batch + i; + let (sel_r, sel_w) = if row < num_local_reads { + (E::BaseField::ONE, E::BaseField::ZERO) + } else { + (E::BaseField::ZERO, E::BaseField::ONE) + }; + set_val!(structural_instance, selector_r_witin, sel_r); + set_val!(structural_instance, selector_w_witin, sel_w); + set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); + Self::assign_instance( + config, + instance, + &mut lk_multiplicity, + step, + cur_shard, + ) + }) + .collect::>() + }) + .collect::>()?; + + // allocate num_rows_padded size, fill points on first half + let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded) + .into_par_iter() + .map(|i| { + steps + .get(i) + .map(|step| step.ec_point.point.clone()) + .unwrap_or_else(SepticPoint::default) + }) + .collect(); + // raw_witin offset start from n. + // left node is at b, right node is at b + 1 + // op(left node, right node) = offset + b / 2 + let mut offset = num_rows_padded / 2; + let mut current_layer_len = cur_layer_points_buffer.len() / 2; + + // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x) + loop { + if current_layer_len <= 1 { + break; + } + let (current_layer, next_layer) = + cur_layer_points_buffer.split_at_mut(current_layer_len); + current_layer + .par_chunks(2) + .zip_eq(next_layer[..current_layer_len / 2].par_iter_mut()) + .zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin)) + .for_each(|((pair, parent), instance)| { + let p1 = &pair[0]; + let p2 = &pair[1]; + let (slope, q) = if p2.is_infinity { + // input[1,b] = bypass_left(input[b,0], input[b,1]) + (SepticExtension::zero(), p1.clone()) + } else { + // input[1,b] = affine_add(input[b,0], input[b,1]) + let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); + let q = p1.clone() + p2.clone(); + (slope, q) + }; + config + .x + .iter() + .chain(config.y.iter()) + .chain(config.slope.iter()) + .zip_eq(chain!( + q.x.deref().iter(), + q.y.deref().iter(), + slope.deref().iter(), + )) + 
.for_each(|(witin, fe)| { + set_val!(instance, *witin, *fe); + }); + *parent = q.clone(); + }); + cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len); + current_layer_len /= 2; + offset += current_layer_len; + } + + let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_witin, + InstancePaddingStrategy::Default, + ); + let raw_structual_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_structual_witin, + InstancePaddingStrategy::Default, + ); + Ok([raw_witin, raw_structual_witin]) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; + use itertools::Itertools; + use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; + use p3::babybear::BabyBear; + use rand::thread_rng; + use tracing_forest::{ForestLayer, util::LevelFilter}; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; + use transcript::BasicTranscript; + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, + scheme::{ + PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, + septic_curve::SepticPoint, verifier::ZKVMVerifier, + }, + structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, + tables::TableCircuit, + }; + use multilinear_extensions::mle::IntoMLE; + use p3::field::PrimeField32; + + type E = BabyBearExt4; + type F = BabyBear; + type Perm = ::P; + type Pcs = BasefoldDefault; + + #[test] + fn test_global_chip() { + // default filter + let default_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env_lossy(); + + Registry::default() + .with(ForestLayer::default()) + .with(default_filter) + .init(); + + // init global chip with horizen_rc_consts + let perm = ::get_default_perm(); + + let mut cs = ConstraintSystem::new(|| "global chip 
test"); + let mut cb = CircuitBuilder::new(&mut cs); + + let (config, gkr_circuit) = + GlobalChip::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + // create a bunch of random memory read/write records + let n_global_reads = 1700; + let n_global_writes = 1420; + let prev_shard = 0; + let cur_shard = 1; + let global_reads = (0..n_global_reads) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: prev_shard, + local_clk: 0, + global_clk: i, + is_write: false, + } + }) + .collect::>(); + + let global_writes = (0..n_global_writes) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: cur_shard, + local_clk: i, + global_clk: i, + is_write: true, + } + }) + .collect::>(); + + let input = global_writes // local reads + .into_iter() + .chain(global_reads) // local writes + .map(|record| { + let ec_point = record.to_ec_point::(&perm); + GlobalChipInput { record, ec_point } + }) + .collect::>(); + + let global_ec_sum: SepticPoint = input + .iter() + .map(|record| record.ec_point.point.clone()) + .sum(); + + let public_value = PublicValues::new( + 0, + 0, + 0, + 0, + 0, + cur_shard as u32, + vec![0], // dummy + global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .map(|fe| fe.as_canonical_u32()) + .collect_vec(), + ); + + tracing::debug!("num_witin: {}", cs.num_witin); + // assign witness + let witness = GlobalChip::assign_instances( + &config, + cs.num_witin as usize, + cs.num_structural_witin as usize, + &[], + &(input, cur_shard as usize), + ) + .unwrap(); + + let composed_cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + }; + let pk = composed_cs.key_gen(); + + // create chip proof for global chip + let pcs_param = Pcs::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = 
Pcs::trim(pcs_param, 1 << 20).unwrap(); + let backend = create_backend::(20, SecurityLevel::Conjecture100bits); + let pd = create_prover(backend); + + let zkvm_pk = ZKVMProvingKey::new(pp, vp); + let zkvm_vk = zkvm_pk.get_vk_slow(); + let zkvm_prover = ZKVMProver::new(zkvm_pk, pd); + let mut transcript = BasicTranscript::new(b"global chip test"); + + let public_input_mles = public_value + .to_vec::() + .into_iter() + .map(|v| Arc::new(v.into_mle())) + .collect_vec(); + let proof_input = ProofInput { + witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), + structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), + fixed: vec![], + public_input: public_input_mles.clone(), + num_instances: vec![n_global_writes as usize, n_global_reads as usize], + has_ecc_ops: true, + }; + let mut rng = thread_rng(); + let challenges = [E::random(&mut rng), E::random(&mut rng)]; + let (proof, _, point) = zkvm_prover + .create_chip_proof( + "global chip", + &pk, + proof_input, + &mut transcript, + &challenges, + ) + .unwrap(); + + let mut transcript = BasicTranscript::new(b"global chip test"); + let verifier = ZKVMVerifier::new(zkvm_vk); + let pi_evals = public_input_mles + .iter() + .map(|mle| mle.evaluate(&point[..mle.num_vars()])) + .collect_vec(); + let vrf_point = verifier + .verify_opcode_proof( + "global", + &pk.vk, + &proof, + &pi_evals, + &mut transcript, + 2, + &PointAndEval::default(), + &challenges, + ) + .expect("verify global chip proof"); + assert_eq!(vrf_point, point); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index b73abcda4..a94024b4a 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,8 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, - structs::ProgramParams, 
uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -87,13 +87,14 @@ impl Instruction for ArithInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); config @@ -186,6 +187,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = ArithInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a040681bc..4de4069d0 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -21,6 +21,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -63,6 +64,7 @@ mod test { let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 8a4722a08..11d93242c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ 
b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -58,6 +59,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -77,7 +79,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index f969a68b0..8ed175d58 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -70,6 +71,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,7 +95,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 7957f7003..3244c5d60 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -142,13 +143,14 @@ impl Instruction for AuipcInstruction { fn assign_instance( 
config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); @@ -189,6 +191,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{auipc::AuipcInstruction, constants::UInt}, @@ -239,6 +242,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); let (raw_witin, lkm) = AuipcInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 798902754..cdc1db56d 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -5,6 +5,7 @@ use super::constants::PC_STEP_SIZE; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, tables::InsnRecord, @@ -12,7 +13,6 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; - // Opcode: 1100011 // Funct3: // 000 BEQ @@ -89,12 +89,15 @@ impl BInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, 
step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Immediate set_val!( diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 8aecd50f8..2c97a12ee 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -6,6 +6,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, IsLtConfig, SignedLtConfig}, instructions::{ @@ -137,13 +138,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs2 = Value::new_unchecked(step.rs2().unwrap().value); diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 94abb56d1..386d2c286 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -68,13 +69,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs1_limbs = rs1.as_u16_limbs(); diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index aaf468127..82dbcffac 
100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -6,6 +6,7 @@ use ff_ext::{ExtensionField, GoldilocksExt2}; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, error::ZKVMError, instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, @@ -39,6 +40,7 @@ fn impl_opcode_beq(equal: bool) { let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -79,6 +81,7 @@ fn impl_opcode_bne(equal: bool) { let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -122,6 +125,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); let (raw_witin, lkm) = BltuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -166,6 +170,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, -8); let (raw_witin, lkm) = BgeuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -217,6 +222,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); let (raw_witin, lkm) = BltInstruction::assign_instances( &config, + &mut ShardContext::default(), 
circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -268,6 +274,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, -8); let (raw_witin, lkm) = BgeInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 1992f4fa3..d98412b6f 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -9,7 +9,9 @@ pub const INIT_PC_IDX: usize = 2; pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; -pub const PUBLIC_IO_IDX: usize = 6; +pub const END_SHARD_ID_IDX: usize = 6; +pub const PUBLIC_IO_IDX: usize = 7; +pub const GLOBAL_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 7ca30d2b8..966320407 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -53,6 +53,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -179,6 +180,7 @@ mod test { // values assignment let ([raw_witin, _], lkm) = Insn::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs index ef5b9d936..99a73a8a4 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs +++ 
b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs @@ -75,6 +75,7 @@ use super::{ }; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{AssertLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, Signed}, instructions::{Instruction, riscv::constants::LIMB_BITS}, @@ -310,6 +311,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction (true, true), diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 7c98e2159..1df279dd9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -9,9 +9,9 @@ use super::super::{ insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, }; use crate::{ - chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, tables::InsnRecord, uint::Value, - witness::LkMultiplicity, + chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, e2e::ShardContext, + error::ZKVMError, instructions::Instruction, structs::ProgramParams, tables::InsnRecord, + uint::Value, witness::LkMultiplicity, }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; @@ -70,11 +70,12 @@ impl Instruction for DummyInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -242,30 +243,31 @@ impl DummyConfig { pub(super) fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { // State in and out - self.vm_state.assign_instance(instance, step)?; + 
self.vm_state.assign_instance(instance, shard_ctx, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); // Registers if let Some((rs1_op, rs1_read)) = &self.rs1 { - rs1_op.assign_instance(instance, lk_multiplicity, step)?; + rs1_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_val = Value::new_unchecked(step.rs1().expect("rs1 value").value); rs1_read.assign_value(instance, rs1_val); } if let Some((rs2_op, rs2_read)) = &self.rs2 { - rs2_op.assign_instance(instance, lk_multiplicity, step)?; + rs2_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_val = Value::new_unchecked(step.rs2().expect("rs2 value").value); rs2_read.assign_value(instance, rs2_val); } if let Some((rd_op, rd_written)) = &self.rd { - rd_op.assign_instance(instance, lk_multiplicity, step)?; + rd_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_val = Value::new_unchecked(step.rd().expect("rd value").value.after); rd_written.assign_value(instance, rd_val); @@ -284,10 +286,10 @@ impl DummyConfig { mem_after.assign_value(instance, Value::new(mem_op.value.after, lk_multiplicity)); } if let Some(mem_read) = &self.mem_read { - mem_read.assign_instance(instance, lk_multiplicity, step)?; + mem_read.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; } if let Some(mem_write) = &self.mem_write { - mem_write.assign_instance::(instance, lk_multiplicity, step)?; + mem_write.assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; } let imm = InsnRecord::::imm_internal(&step.insn()).1; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 69bdd1648..9cd5cb0f3 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -8,6 +8,7 @@ use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ Value, circuit_builder::CircuitBuilder, + 
e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -84,6 +85,7 @@ impl Instruction for LargeEcallDummy fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,14 +95,14 @@ impl Instruction for LargeEcallDummy // Assign instruction. config .dummy_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; set_val!(instance, config.start_addr, u64::from(ops.mem_ops[0].addr)); // Assign registers. for ((value, writer), op) in config.reg_writes.iter().zip_eq(&ops.reg_ops) { value.assign_value(instance, Value::new_unchecked(op.value.after)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } // Assign memory. @@ -112,7 +114,7 @@ impl Instruction for LargeEcallDummy .after .assign_value(instance, Value::new(op.value.after, lk_multiplicity)); set_val!(instance, addr, u64::from(op.addr)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6f7a89f73..c6f51d142 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -4,6 +4,7 @@ use ff_ext::GoldilocksExt2; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, @@ -34,6 +35,7 @@ fn test_dummy_ecall() { let insn_code = step.insn(); let (raw_witin, lkm) = EcallDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -63,6 +65,7 @@ fn 
test_dummy_keccak() { let (step, program) = ceno_emul::test_utils::keccak_step(); let (raw_witin, lkm) = KeccakDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -90,6 +93,7 @@ fn test_dummy_r() { let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); let (raw_witin, lkm) = AddDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -125,6 +129,7 @@ fn test_dummy_b() { let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); let (raw_witin, lkm) = BeqDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index e14585727..bf38a67c4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{RegisterChipOperations, general::PublicIOQuery}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, instructions::{ @@ -70,6 +71,7 @@ impl Instruction for HaltInstruction { fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0ac2a505..dccdf34a2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -21,6 +21,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -156,6 
+157,7 @@ impl Instruction for KeccakInstruction { fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -165,6 +167,7 @@ impl Instruction for KeccakInstruction { fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -196,11 +199,13 @@ impl Instruction for KeccakInstruction { // each instance are composed of KECCAK_ROUNDS.next_power_of_two() let raw_witin_iter = raw_witin .par_batch_iter_mut(num_instance_per_batch * KECCAK_ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -218,10 +223,13 @@ impl Instruction for KeccakInstruction { [round_index as usize * num_witin..][..num_witin]; // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -238,6 +246,7 @@ impl Instruction for KeccakInstruction { )?; config.state_ptr.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -246,6 +255,7 @@ impl Instruction for KeccakInstruction { for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { writer.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), op, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 6365cfcd2..adf52683f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ 
b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -24,6 +24,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -207,6 +208,7 @@ impl Instruction fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -216,6 +218,7 @@ impl Instruction fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -255,11 +258,13 @@ impl Instruction ); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -269,10 +274,13 @@ impl Instruction let ops = &step.syscall().expect("syscall step"); // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -289,6 +297,7 @@ impl Instruction )?; config.point_ptr_0.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -301,12 +310,19 @@ impl Instruction )?; config.point_ptr_1.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[1], )?; for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { - writer.assign_op(instance, &mut lk_multiplicity, step.cycle(), op)?; + writer.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + 
step.cycle(), + op, + )?; } // fetch lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 6003f9794..250141669 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -31,6 +31,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -208,6 +209,7 @@ impl Instruction::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -217,6 +219,7 @@ impl Instruction, @@ -254,12 +257,14 @@ impl Instruction::WordsFieldElement::USIZE; // 1st pass: assign witness outside of gkr-iop scope let sign_bit_and_y_words = raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -269,9 +274,12 @@ impl Instruction Instruction Instruction Instruction::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -188,6 +190,7 @@ impl Instruction, @@ -227,11 +230,13 @@ impl Instruction Instruction Instruction OpFixedRS Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state if let Some(prev_value) = self.prev_value.as_ref() { @@ -76,17 +82,30 @@ impl OpFixedRS IInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut 
LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 5fa6cd501..c7f6cace0 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -7,6 +7,7 @@ use crate::{ witness::LkMultiplicity, }; +use crate::e2e::ShardContext; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, ToExpr}; @@ -67,14 +68,17 @@ impl IMInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_read - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 43a72f739..4877df9d1 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ 
b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -10,8 +10,10 @@ use crate::{ RegisterChipOperations, RegisterExpr, }, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, }; @@ -58,14 +60,17 @@ impl StateInOut { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &ShardContext, // lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + set_val!(instance, self.pc, step.pc().before.0 as u64); if let Some(n_pc) = self.next_pc { set_val!(instance, n_pc, step.pc().after.0 as u64); } - set_val!(instance, self.ts, step.cycle()); + set_val!(instance, self.ts, step.cycle() - current_shard_offset_cycle); Ok(()) } @@ -106,20 +111,33 @@ impl ReadRS1 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS1, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); Ok(()) } @@ -160,21 +178,35 @@ impl ReadRS2 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, 
lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS2, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + Ok(()) } } @@ -216,22 +248,27 @@ impl WriteRD { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rd().expect("rd op"); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state self.prev_value.assign_limbs( @@ -243,9 +280,18 @@ impl WriteRD { self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + 
Tracer::SUBCYCLE_RD, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); Ok(()) } @@ -284,24 +330,35 @@ impl ReadMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let op = step.memory_op().unwrap(); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; // Memory state - set_val!( - instance, - self.prev_ts, - step.memory_op().unwrap().previous_cycle - ); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Memory read self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.memory_op().unwrap().previous_cycle, - step.cycle() + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + Ok(()) } } @@ -337,29 +394,44 @@ impl WriteMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [F], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let 
current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 156aa1cd1..84cb84679 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -4,13 +4,13 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteRD}, tables::InsnRecord, witness::LkMultiplicity, }; use multilinear_extensions::ToExpr; - // Opcode: 1101111 /// This config handles the common part of the J-type instruction (JAL): @@ -55,11 +55,13 @@ impl JInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch the instruction. 
lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index a4c0a96f4..c8abc77ac 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -64,13 +65,14 @@ impl Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); config.rd_written.assign_value(instance, rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 0f67be424..545adf275 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs 
b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index f1ba94aa7..77f6ad1f8 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -111,6 +112,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -150,7 +152,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index bfec3a099..7f23ac9b6 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -5,6 +5,7 @@ use crate::{ Value, chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -135,6 +136,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -177,7 +179,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 0b379f250..899e5a035 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -2,6 +2,7 @@ use super::{JalInstruction, JalrInstruction}; use crate::{ Value, 
circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -42,6 +43,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_j_instruction( @@ -117,6 +119,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { let (raw_witin, lkm) = JalrInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f761f6102..5a2d8e404 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -53,6 +54,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -63,7 +65,7 @@ impl Instruction for LogicInstruction { step.rs2().unwrap().value as u64, ); - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -106,11 +108,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.r_insn - .assign_instance(instance, lk_multiplicity, step)?; + 
.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index dc01487d9..f68135c72 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -1,16 +1,16 @@ use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use ff_ext::GoldilocksExt2; +use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt8}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, utils::split_to_u8, }; -use super::*; - const A: Word = 0xbead1010; const B: Word = 0xef552020; @@ -32,6 +32,7 @@ fn test_opcode_and() { let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); let (raw_witin, lkm) = AndInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -74,6 +75,7 @@ fn test_opcode_or() { let insn_code = encode_rv32(InsnKind::OR, 2, 3, 4, 0); let (raw_witin, lkm) = OrInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -116,6 +118,7 @@ fn test_opcode_xor() { let insn_code = encode_rv32(InsnKind::XOR, 2, 3, 4, 0); let (raw_witin, lkm) = XorInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index aad60b43b..596792ad8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ 
b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -48,6 +49,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -58,7 +60,7 @@ impl Instruction for LogicInstruction { InsnRecord::::imm_internal(&step.insn()).0 as u64, ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -102,10 +104,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index c72f31efe..b48af7f5f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -7,6 +7,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -94,6 +95,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -115,7 +117,7 @@ impl Instruction for LogicInstruction { imm_hi.into(), ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -163,11 
+165,13 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let num_limbs = LIMB_BITS / 8; - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 23aa2d77c..68032fd41 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -4,6 +4,7 @@ use gkr_iop::circuit_builder::DebugIndex; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -70,6 +71,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); let (raw_witin, lkm) = LogicInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 2cc280f04..198bafbc5 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for LuiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, 
shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { @@ -117,6 +119,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{constants::UInt, lui::LuiInstruction}, @@ -153,6 +156,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); let (raw_witin, lkm) = LuiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 5945f26bd..41fbf0059 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -165,6 +166,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -124,7 +126,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index f07968d19..cb512975b 100644 --- 
a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -127,6 +128,7 @@ impl Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -147,7 +149,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 90c5a0273..b2a04326b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -102,6 +103,7 @@ fn impl_opcode_store::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -217,6 +219,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0); let (raw_witin, lkm) = MulhInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -300,6 +303,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULHSU, 2, 3, 4, 0); let (raw_witin, lkm) = MulhsuInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, 
vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs index bc5bc9ed4..dd919dd3e 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs @@ -86,6 +86,7 @@ use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, Signed}, instructions::{ @@ -286,6 +287,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -312,7 +314,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Assign signed values, if any, and compute low 32-bit limb of product let prod_lo_hi = match &config.sign_deps { diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index c1853d7a8..a94f63e74 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -19,6 +19,7 @@ use multilinear_extensions::{Expression, ToExpr as _, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; @@ -223,6 +224,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -241,7 +243,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + 
.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( I::INST_KIND, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 540ccaffe..a4b9bb128 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, tables::InsnRecord, @@ -63,13 +64,17 @@ impl RInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index b953fc2af..9957f2122 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -9,6 +9,7 @@ use crate::instructions::riscv::lui::LuiInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::tables::PowTableCircuit; use crate::{ + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -409,6 +410,7 @@ impl Rv32imConfig { pub fn assign_opcode_circuit( &self, cs: 
&ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result { @@ -422,38 +424,49 @@ impl Rv32imConfig { let mut secp256k1_add_records = Vec::new(); let mut secp256k1_double_records = Vec::new(); let mut secp256k1_decompress_records = Vec::new(); - steps.into_iter().for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - halt_records.push(record); + steps + .into_iter() + .filter_map(|step| { + if shard_ctx.is_current_shard_cycle(step.cycle()) { + Some(step) + } else { + None } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); + }) + .for_each(|record| { + let insn_kind = record.insn.kind; + match insn_kind { + // ecall / halt + InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { + halt_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { + keccak_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { + bn254_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { + bn254_double_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { + secp256k1_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { + secp256k1_double_records.push(record); + } + InsnKind::ECALL + if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => + { + secp256k1_decompress_records.push(record); + } + // other type of ecalls are handled by dummy ecall instruction + _ => { + // it's safe to unwrap as all_records are initialized with Vec::new() + all_records.get_mut(&insn_kind).unwrap().push(record); + } } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - 
bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => { - secp256k1_decompress_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } - } - }); + }); for (insn_kind, (_, records)) in izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) @@ -465,6 +478,7 @@ impl Rv32imConfig { ($insn_kind:ident,$instruction:ty,$config:ident) => { witness.assign_opcode_circuit::<$instruction>( cs, + shard_ctx, &self.$config, all_records.remove(&($insn_kind)).unwrap(), )?; @@ -524,35 +538,46 @@ impl Rv32imConfig { assign_opcode!(SB, SbInstruction, sb_config); // ecall / halt - witness.assign_opcode_circuit::>(cs, &self.halt_config, halt_records)?; + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.halt_config, + halt_records, + )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.keccak_config, keccak_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.bn254_add_config, bn254_add_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.bn254_double_config, bn254_double_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_add_config, secp256k1_add_records, )?; witness .assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_double_config, secp256k1_double_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_decompress_config, 
secp256k1_decompress_records, )?; @@ -671,6 +696,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { @@ -700,35 +726,46 @@ impl DummyExtraConfig { witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.secp256k1_decompress_config, secp256k1_decompress_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.sha256_extend_config, sha256_extend_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp_add_config, bn254_fp_add_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp_mul_config, bn254_fp_mul_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp2_add_config, bn254_fp2_add_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp2_mul_config, bn254_fp2_mul_steps, )?; - witness.assign_opcode_circuit::>(cs, &self.ecall_config, other_steps)?; + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.ecall_config, + other_steps, + )?; let _ = steps.remove(&INVALID); let keys: Vec<&InsnKind> = steps.keys().collect::>(); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index d8c032c7b..900672a3d 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,54 +1,64 @@ -use std::{collections::HashSet, iter::zip, ops::Range}; - -use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use itertools::{Itertools, chain}; - use crate::{ + e2e::ShardContext, error::ZKVMError, + instructions::global::GlobalChip, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, - PubIOTable, RegTable, RegTableCircuit, StackCircuit, StaticMemCircuit, 
StaticMemTable, + DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, + RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, StaticMemTable, TableCircuit, }, }; +use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain}; +use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; +use witness::InstancePaddingStrategy; -pub struct MmuConfig { +pub struct MmuConfig<'a, E: ExtensionField> { /// Initialization of registers. - pub reg_config: as TableCircuit>::TableConfig, + pub reg_init_config: as TableCircuit>::TableConfig, /// Initialization of memory with static addresses. - pub static_mem_config: as TableCircuit>::TableConfig, + pub static_mem_init_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, /// Initialization of hints. pub hints_config: as TableCircuit>::TableConfig, /// Initialization of heap. - pub heap_config: as TableCircuit>::TableConfig, + pub heap_init_config: as TableCircuit>::TableConfig, /// Initialization of stack. 
- pub stack_config: as TableCircuit>::TableConfig, + pub stack_init_config: as TableCircuit>::TableConfig, + /// finalized circuit for all MMIO + pub local_final_circuit: as TableCircuit>::TableConfig, + /// ram bus to deal with cross shard read/write + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } -impl MmuConfig { +impl MmuConfig<'_, E> { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let reg_config = cs.register_table_circuit::>(); + let reg_init_config = cs.register_table_circuit::>(); - let static_mem_config = cs.register_table_circuit::>(); + let static_mem_init_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); let hints_config = cs.register_table_circuit::>(); - let stack_config = cs.register_table_circuit::>(); - let heap_config = cs.register_table_circuit::>(); + let stack_init_config = cs.register_table_circuit::>(); + let heap_init_config = cs.register_table_circuit::>(); + let local_final_circuit = cs.register_table_circuit::>(); + let ram_bus_circuit = cs.register_table_circuit::>(); Self { - reg_config, - static_mem_config, + reg_init_config, + static_mem_init_config, public_io_config, hints_config, - stack_config, - heap_config, + stack_init_config, + heap_init_config, + local_final_circuit, + ram_bus_circuit, params: cs.params.clone(), } } @@ -72,24 +82,27 @@ impl MmuConfig { "memory addresses must be unique" ); - fixed.register_table_circuit::>(cs, &self.reg_config, reg_init); + fixed.register_table_circuit::>(cs, &self.reg_init_config, reg_init); - fixed.register_table_circuit::>( + fixed.register_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_init, ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); fixed.register_table_circuit::>(cs, &self.hints_config, &()); - fixed.register_table_circuit::>(cs, &self.stack_config, &()); - fixed.register_table_circuit::>(cs, &self.heap_config, 
&()); + fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); + fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); + fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); + // fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); } #[allow(clippy::too_many_arguments)] pub fn assign_table_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, witness: &mut ZKVMWitnesses, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], @@ -98,18 +111,59 @@ impl MmuConfig { stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; + witness.assign_table_circuit::>( + cs, + &self.reg_init_config, + reg_final, + )?; - witness.assign_table_circuit::>( + witness.assign_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_final, )?; witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; - witness.assign_table_circuit::>(cs, &self.stack_config, stack_final)?; - witness.assign_table_circuit::>(cs, &self.heap_config, heap_final)?; + witness.assign_table_circuit::>( + cs, + &self.stack_init_config, + stack_final, + )?; + witness.assign_table_circuit::>( + cs, + &self.heap_init_config, + heap_final, + )?; + + let all_records = vec![ + (InstancePaddingStrategy::Default, reg_final), + (InstancePaddingStrategy::Default, static_mem_final), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| StackTable::addr(¶ms, row as usize) as u64) + }), + stack_final, + ), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| HeapTable::addr(¶ms, row as usize) as u64) + }), + heap_final, + ), + ] + .into_iter() + .filter(|(_, record)| !record.is_empty()) + .collect_vec(); + + 
witness.assign_table_circuit::>( + cs, + &self.local_final_circuit, + &(shard_ctx, all_records.as_slice()), + )?; + witness.assign_global_chip_circuit(cs, shard_ctx, &self.ram_bus_circuit)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f46cf4c5d..f252a7c60 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, tables::InsnRecord, @@ -73,14 +74,17 @@ impl SInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_write - .assign_instance::(instance, lk_multiplicity, step)?; + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 0c53f1a4c..d09b98c89 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -45,6 +45,7 @@ mod tests { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, 
structs::ProgramParams, @@ -173,6 +174,7 @@ mod tests { let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index 87374b20e..c1d83ce87 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -151,6 +152,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -211,7 +213,7 @@ impl Instruction for ShiftLogicalInstru config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 4e929670c..fac05279e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,3 +1,4 @@ +use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -321,6 +322,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -352,7 +354,7 @@ impl Instruction for ShiftLogicalInstru ); config .r_insn - .assign_instance(instance, 
lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } @@ -419,6 +421,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -449,7 +452,7 @@ impl Instruction for ShiftImmInstructio ); config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4cf7ac155..1757a0fc7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -43,6 +43,7 @@ mod test { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -170,6 +171,7 @@ mod test { let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index 0bba35411..a2fa8d032 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -132,6 +133,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: 
&StepRecord, @@ -168,7 +170,7 @@ impl Instruction for ShiftImmInstructio config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 7b27617ad..3ba12bb39 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -38,6 +38,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -72,6 +73,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = SetLessThanInstruction::<_, I>::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs index 3ffd9de69..b9b63acaf 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedLtConfig, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 391dffb89..cd0b97ce4 100644 --- 
a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -75,11 +76,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 5802c4229..ff3a78043 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -35,6 +35,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -185,6 +186,7 @@ mod test { let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 266faeed3..8b93f593c 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -94,11 +95,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, 
step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 1085561fb..914424247 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 17ab9e72c..0ced182b8 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -26,8 +26,11 @@ impl ZKVMConstraintSystem { .remove(&c_name) .flatten() .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone().into()))?; + vm_pk + .circuit_index_fixed_num_instances + .insert(circuit_index, fixed_trace_rmm.num_instances()); fixed_traces.insert(circuit_index, fixed_trace_rmm); - }; + } let circuit_pk = cs.key_gen(); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 16e7ee821..a72c0ffe6 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,6 +1,7 @@ #![deny(clippy::cargo)] #![feature(box_patterns)] #![feature(stmt_expr_attributes)] 
+#![feature(variant_count)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index e25ee972d..51bf0092a 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -30,7 +30,7 @@ use gkr_iop::{ layer::Layer, layer_constraint_system::{LayerConstraintSystem, expansion_expr}, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::{indices_arr_with_offset, lk_multiplicity::LkMultiplicity, wits_fixed_and_eqs}, }; @@ -963,6 +963,14 @@ pub fn run_keccakf + 'stat }; let span = entered_span!("prove", profiling_1 = true); + let selector_ctxs = vec![ + SelectorContext::new(0, num_instances, log2_num_instances); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap() + ]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -972,7 +980,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -993,7 +1001,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 2fcd8de79..bb105899d 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -14,7 +14,7 @@ use gkr_iop::{ layer::Layer, mock::MockProver, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::lk_multiplicity::LkMultiplicity, }; use itertools::{Itertools, iproduct, izip, zip_eq}; @@ -40,6 +40,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{ @@ -1025,6 
+1026,7 @@ pub fn run_faster_keccakf verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = states.len(); let num_instances_rounds = num_instances * ROUNDS.next_power_of_two(); let log2_num_instance_rounds = ceil_log2(num_instances_rounds); @@ -1073,9 +1075,11 @@ pub fn run_faster_keccakf ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch * ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize * ROUNDS.next_power_of_two()) @@ -1087,6 +1091,7 @@ pub fn run_faster_keccakf .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -1095,6 +1100,7 @@ pub fn run_faster_keccakf mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { @@ -1222,6 +1228,7 @@ pub fn run_faster_keccakf } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 3]; let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -1231,7 +1238,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -1260,7 +1267,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 3be11dbd0..76df2b06a 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -63,6 +63,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -140,11 +141,12 @@ impl WeierstrassAddAssignLayout { descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; // Default expression, will be updated in build_layer_logic @@ -559,6 +561,7 @@ pub fn run_weierstrass_add< verify: bool, test_outputs: bool, ) -> Result, BackendError> { 
+ let mut shard_ctx = ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -591,9 +594,11 @@ pub fn run_weierstrass_add< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -603,6 +608,7 @@ pub fn run_weierstrass_add< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -610,6 +616,7 @@ pub fn run_weierstrass_add< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { @@ -746,6 +753,7 @@ pub fn run_weierstrass_add< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -755,7 +763,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -780,7 +788,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 52496e869..9f37a26c7 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -67,6 +67,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{ FieldOperation, field_inner_product::FieldInnerProductCols, field_op::FieldOpCols, @@ -158,11 +159,12 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< @@ -557,6 +559,7 @@ pub fn run_weierstrass_decompress< test_outputs: bool, verify: bool, ) -> Result, BackendError> { + let mut shard_ctx = 
ShardContext::default(); let num_instances = instances.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -577,9 +580,11 @@ pub fn run_weierstrass_decompress< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -589,6 +594,7 @@ pub fn run_weierstrass_decompress< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -596,6 +602,7 @@ pub fn run_weierstrass_decompress< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { @@ -726,6 +733,7 @@ pub fn run_weierstrass_decompress< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -735,7 +743,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -760,7 +768,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index e5f16ba2f..7f9a02997 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -64,6 +64,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -142,11 +143,12 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< @@ -564,6 +566,7 @@ pub fn run_weierstrass_double< verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = 
ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -593,9 +596,11 @@ pub fn run_weierstrass_double< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -605,6 +610,7 @@ pub fn run_weierstrass_double< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -612,6 +618,7 @@ pub fn run_weierstrass_double< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { @@ -748,6 +755,7 @@ pub fn run_weierstrass_double< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -757,7 +765,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -782,7 +790,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 58a9aae89..aa3928153 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,3 +1,4 @@ +use crate::structs::EccQuarkProof; use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; @@ -29,6 +30,7 @@ pub mod cpu; pub mod gpu; pub mod hal; pub mod prover; +pub mod septic_curve; pub mod utils; pub mod verifier; @@ -58,8 +60,10 @@ pub struct ZKVMChipProof { pub gkr_iop_proof: Option>, pub tower_proof: TowerProofs, + pub ecc_proof: Option>, + + pub num_instances: Vec, - pub num_instances: usize, pub fixed_in_evals: Vec, pub wits_in_evals: Vec, } @@ -72,17 +76,22 @@ pub struct PublicValues { init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, public_io: Vec, + global_sum: Vec, } impl PublicValues { + #[allow(clippy::too_many_arguments)] pub fn new( exit_code: u32, init_pc: u32, init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, public_io: Vec, + global_sum: Vec, ) -> Self { Self { exit_code, @@ -90,7 +99,9 @@ impl PublicValues { init_cycle, end_pc, end_cycle, + shard_id, public_io, + global_sum, } } pub fn to_vec(&self) -> Vec> { @@ -103,6 +114,7 @@ impl PublicValues { vec![E::BaseField::from_canonical_u64(self.init_cycle)], vec![E::BaseField::from_canonical_u32(self.end_pc)], vec![E::BaseField::from_canonical_u64(self.end_cycle)], + vec![E::BaseField::from_canonical_u32(self.shard_id)], ] .into_iter() .chain( @@ -120,6 +132,12 @@ impl PublicValues { }) .collect_vec(), ) + .chain( + self.global_sum + .iter() + .map(|value| 
vec![E::BaseField::from_canonical_u32(*value)]) + .collect_vec(), + ) .collect::>() } } @@ -193,7 +211,7 @@ impl> ZKVMProof { let halt_instance_count = self .chip_proofs .get(&halt_circuit_index) - .map_or(0, |proof| proof.num_instances); + .map_or(0, |proof| proof.num_instances.iter().sum()); if halt_instance_count > 0 { assert_eq!( halt_instance_count, 1, diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 3cc212e9f..20687183e 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -1,5 +1,4 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; -pub(crate) const SEL_DEGREE: usize = 2; pub const NUM_FANIN: usize = 2; pub const NUM_FANIN_LOGUP: usize = 2; @@ -7,3 +6,6 @@ pub const NUM_FANIN_LOGUP: usize = 2; pub const MAX_NUM_VARIABLES: usize = 24; pub const DYNAMIC_RANGE_MAX_BITS: usize = 18; + +pub const SEPTIC_EXTENSION_DEGREE: usize = 7; +pub const SEPTIC_JACOBIAN_NUM_MLES: usize = 3 * SEPTIC_EXTENSION_DEGREE; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 9b0020116..cebe79899 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -5,14 +5,15 @@ use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP}, - hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, + hal::{DeviceProvingKey, EccQuarkProver, MainSumcheckEvals, ProofInput, TowerProverSpec}, + septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, }, }, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs}, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; use either::Either; use ff_ext::ExtensionField; @@ -20,6 +21,7 @@ use gkr_iop::{ cpu::{CpuBackend, 
CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -30,7 +32,10 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, +}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, @@ -47,6 +52,258 @@ pub type TowerRelationOutput = ( Vec>, Vec>, ); + +// accumulate N=2^n EC points into one EC point using affine coordinates +// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +pub struct CpuEccProver; + +impl CpuEccProver { + pub fn create_ecc_proof<'a, E: ExtensionField>( + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> EccQuarkProof { + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + + let n = xs[0].num_vars() - 1; + tracing::debug!( + "Creating EC Summation Quark proof with {} points in {n} variables", + num_instances + ); + + let out_rt = transcript.sample_and_append_vec(b"ecc", n); + let num_threads = optimal_sumcheck_threads(out_rt.len()); + + // expression with add (3 zero constraints) and bypass (2 zero constraints) + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); + + let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); + + let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into()); + let sel_add_ctx = SelectorContext { + offset: 0, +
num_instances, + num_vars: n, + }; + let mut sel_add_mle: MultilinearExtension<'_, E> = + sel_add.compute(&out_rt, &sel_add_ctx).unwrap(); + // we construct sel_bypass witness here + // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` + let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); + match sel_add_mle.evaluations() { + FieldType::Ext(sel_add_mle) => sel_add_mle + .par_iter() + .zip_eq(sel_bypass_mle.par_iter_mut()) + .for_each(|(sel_add, sel_bypass)| { + if *sel_add != E::ZERO { + *sel_bypass = E::ZERO; + } + }), + _ => unreachable!(), + } + *sel_bypass_mle.last_mut().unwrap() = E::ZERO; + let mut sel_bypass_mle = sel_bypass_mle.into_mle(); + let sel_add_expr = expr_builder.lift(sel_add_mle.to_either()); + let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either()); + + let mut exprs_add = vec![]; + let mut exprs_bypass = vec![]; + + let filter_bj = |v: &[Arc>], j: usize| { + v.iter() + .map(|v| { + v.get_base_field_vec() + .iter() + .enumerate() + .filter(|(i, _)| *i % 2 == j) + .map(|(_, v)| v) + .cloned() + .collect_vec() + .into_mle() + }) + .collect_vec() + }; + // build x[b,0], x[b,1], y[b,0], y[b,1] + let mut x0 = filter_bj(&xs, 0); + let mut y0 = filter_bj(&ys, 0); + let mut x1 = filter_bj(&xs, 1); + let mut y1 = filter_bj(&ys, 1); + // build x[1,b], y[1,b], s[1,b] + let mut x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let mut y3 = ys.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let mut s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + + let s = SymbolicSepticExtension::new( + s.iter_mut() + .map(|s| expr_builder.lift(s.to_either())) + .collect(), + ); + let x0 = SymbolicSepticExtension::new( + x0.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y0 = SymbolicSepticExtension::new( + y0.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x1 = SymbolicSepticExtension::new( + x1.iter_mut() + .map(|x| 
expr_builder.lift(x.to_either())) + .collect(), + ); + let y1 = SymbolicSepticExtension::new( + y1.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x3 = SymbolicSepticExtension::new( + x3.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y3 = SymbolicSepticExtension::new( + y3.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + // affine addition + // zerocheck: 0 = s[1,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x1) - (&y0 - &y1)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[1,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) + exprs_add.extend( + ((&s * &s) - &x0 - &x1 - &x3) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x3) - (&y0 + &y3)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + let exprs_add = exprs_add.into_iter().sum::>() * sel_add_expr; + + // deal with bypass + // 0 = (x[1,b] - x[b,0]) + exprs_bypass.extend( + (&x3 - &x0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // 0 = (y[1,b] - y[b,0]) + exprs_bypass.extend( + (&y3 - &y0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + assert!(alpha_pows_iter.next().is_none()); + + let 
exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; + + let (zerocheck_proof, state) = IOPProverState::prove( + expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]), + transcript, + ); + + let rt = state.collect_raw_challenges(); + let evals = state.get_mle_flatten_final_evaluations(); + + assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); + // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] + assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); + + let last_evaluation_index = (1 << n) - 1; + let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + let final_sum_x: SepticExtension = (x3.iter()) + .map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum_y: SepticExtension = (y3.iter()) + .map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // y[1,...,1,0] + .collect_vec() + .into(); + let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); + + #[cfg(feature = "sanity-check")] + { + let s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let x0 = filter_bj(&xs, 0); + let y0 = filter_bj(&ys, 0); + let x1 = filter_bj(&xs, 1); + let y1 = filter_bj(&ys, 1); + + let evals = &evals[2..]; + // check evaluations + for i in 0..SEPTIC_EXTENSION_DEGREE { + assert_eq!(s[i].evaluate(&rt), evals[i]); + assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + i]); + assert_eq!(y0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 2 + i]); + assert_eq!(x1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 3 + i]); + assert_eq!(y1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 4 + i]); + assert_eq!(x3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]); + assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); + } + } + + EccQuarkProof { + zerocheck_proof, + num_instances, + evals, + rt, + sum: final_sum, + } + } +} + +impl> 
EccQuarkProver> + for CpuProver> +{ + fn prove_ec_sum_quark<'a>( + &self, + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError> { + Ok(CpuEccProver::create_ecc_proof( + num_instances, + xs, + ys, + invs, + transcript, + )) + } +} + pub struct CpuTowerProver; impl CpuTowerProver { @@ -59,7 +316,7 @@ impl CpuTowerProver { #[derive(Debug, Clone)] enum GroupedMLE<'a, E: ExtensionField> { Prod((usize, Vec>)), // usize is the index in prod_specs - Logup((usize, Vec>)), /* usize is the index in logup_specs */ + Logup((usize, Vec>)), // usize is the index in logup_specs } // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 @@ -311,8 +568,8 @@ impl> TowerProver> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); + let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = { + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = input + .witness + .par_iter() + .chain(input.fixed.par_iter()) + .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(input.witness.len()); + let wits_in_evals = evals; + exit_span!(span); + (wits_in_evals, fixed_in_evals, None, rt_tower) + }; Ok(( - rt_tower, + rt, MainSumcheckEvals { wits_in_evals, fixed_in_evals, }, - None, + main_sumcheck_proof, None, )) } @@ -713,38 +1003,38 @@ impl> OpeningProver>, points: Vec>, - mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + mut evals: Vec>>, // where each inner vec![wit_evals, fixed_evals] transcript: &mut impl Transcript, ) -> PCS::Proof { let mut rounds = vec![]; - rounds.push(( - &witness_data, - points - .iter() - .zip_eq(evals.iter_mut()) - 
.zip_eq(num_instances.iter()) - .map(|((point, evals), (chip_idx, _))| { - let (num_witin, _) = circuit_num_polys[*chip_idx]; - (point.clone(), evals.drain(..num_witin).collect_vec()) + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } }) - .collect_vec(), - )); + .collect_vec() + })); if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { - rounds.push(( - fixed_data, - points - .iter() - .zip_eq(evals.iter_mut()) - .zip_eq(num_instances.iter()) - .filter(|(_, (chip_idx, _))| { - let (_, num_fixed) = circuit_num_polys[*chip_idx]; - num_fixed > 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } PCS::batch_open(&self.backend.pp, rounds, transcript).unwrap() } @@ -830,3 +1120,125 @@ where self.backend.as_ref() } } + +#[cfg(test)] +mod tests { + use crate::scheme::{ + constants::SEPTIC_EXTENSION_DEGREE, + cpu::CpuEccProver, + septic_curve::{SepticExtension, SepticPoint}, + verifier::EccVerifier, + }; + use ff_ext::BabyBearExt4; + use itertools::Itertools; + use multilinear_extensions::{ + mle::{IntoMLE, MultilinearExtension}, + util::transpose, + }; + use p3::babybear::BabyBear; + use std::{iter::repeat_n, sync::Arc}; + use transcript::BasicTranscript; + use witness::next_pow2_instance_padding; + + #[test] + fn test_ecc_quark_prover() { + for n_points in 1..2 ^ 10 { + test_ecc_quark_prover_inner(n_points) + } + } + + fn test_ecc_quark_prover_inner(n_points: usize) { + type E = BabyBearExt4; + type F = BabyBear; + + let log2_n = next_pow2_instance_padding(n_points).ilog2(); + let mut rng = rand::thread_rng(); 
+ + let final_sum; + // generate 1 ecc add witness + let ecc_spec: Vec> = { + // sample N = 2^n points + let mut points = (0..n_points) + .map(|_| SepticPoint::::random(&mut rng)) + .collect_vec(); + points.extend(repeat_n( + SepticPoint::point_at_infinity(), + (1 << log2_n) - points.len(), + )); + let mut s = Vec::with_capacity(1 << (log2_n + 1)); + s.extend(repeat_n(SepticExtension::zero(), 1 << log2_n)); + + for layer in (1..=log2_n).rev() { + let num_inputs = 1 << layer; + let inputs = &points[points.len() - num_inputs..]; + + s.extend(inputs.chunks_exact(2).map(|chunk| { + let p = &chunk[0]; + let q = &chunk[1]; + if q.is_infinity { + SepticExtension::zero() + } else { + (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + } + })); + + points.extend( + inputs + .chunks_exact(2) + .map(|chunk| { + let p = chunk[0].clone(); + let q = chunk[1].clone(); + p + q + }) + .collect_vec(), + ); + } + final_sum = points.last().cloned().unwrap(); + + // padding to 2*N + s.push(SepticExtension::zero()); + points.push(SepticPoint::point_at_infinity()); + + assert_eq!(s.len(), 1 << (log2_n + 1)); + assert_eq!(points.len(), 1 << (log2_n + 1)); + + // transform points to row major matrix + let trace = points + .iter() + .zip_eq(s.iter()) + .map(|(p, s)| { + p.x.iter() + .chain(p.y.iter()) + .chain(s.iter()) + .copied() + .collect_vec() + }) + .collect_vec(); + + // transpose row major matrix to column major matrix + transpose(trace) + .into_iter() + .map(|v| v.into_mle()) + .collect_vec() + }; + let (xs, rest) = ecc_spec.split_at(SEPTIC_EXTENSION_DEGREE); + let (ys, s) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + + let mut transcript = BasicTranscript::new(b"test"); + let quark_proof = CpuEccProver::create_ecc_proof( + n_points, + xs.iter().cloned().map(Arc::new).collect_vec(), + ys.iter().cloned().map(Arc::new).collect_vec(), + s.iter().cloned().map(Arc::new).collect_vec(), + &mut transcript, + ); + + assert_eq!(quark_proof.sum, final_sum); + let mut transcript = 
BasicTranscript::new(b"test"); + assert!( + EccVerifier::verify_ecc_proof(&quark_proof, &mut transcript) + .inspect_err(|err| println!("err {:?}", err)) + .is_ok() + ); + } +} diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 023686e36..07e5adb4d 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -203,7 +203,7 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( zkvm_v1_css: cs, .. } = composed_cs; let num_instances_with_rotation = - input.num_instances << composed_cs.rotation_vars().unwrap_or(0); + input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); let chip_record_alpha = challenges[0]; // TODO: safety ? @@ -653,9 +653,7 @@ impl> MainSumcheckProver> OpeningProver as ProverBackend>::PcsData>>, points: Vec>, mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], transcript: &mut (impl Transcript + 'static), ) -> PCS::Proof { if std::any::TypeId::of::() != std::any::TypeId::of::() { @@ -764,32 +760,34 @@ impl> OpeningProver 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } // use ceno_gpu::{ diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 17ad6b92a..c57966217 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -4,7 +4,7 @@ use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, scheme::cpu::TowerRelationOutput, - structs::{ComposedConstrainSystem, ZKVMProvingKey}, + structs::{ComposedConstrainSystem, EccQuarkProof, ZKVMProvingKey}, }; use ff_ext::ExtensionField; use gkr_iop::{ @@ -24,6 +24,7 @@ pub trait ProverDevice: + OpeningProver + DeviceTransporter + 
ProtocolWitnessGeneratorProver + + EccQuarkProver // + FixedMLEPadder where PB: ProverBackend, @@ -37,13 +38,26 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub structural_witness: Vec>>, pub fixed: Vec>>, pub public_input: Vec>>, - pub num_instances: usize, + pub num_instances: Vec, + pub has_ecc_ops: bool, } impl<'a, PB: ProverBackend> ProofInput<'a, PB> { + pub fn num_instances(&self) -> usize { + self.num_instances.iter().sum() + } + #[inline] pub fn log2_num_instances(&self) -> usize { - ceil_log2(next_pow2_instance_padding(self.num_instances)) + let num_instance = self.num_instances(); + let log2 = ceil_log2(next_pow2_instance_padding(num_instance)); + if self.has_ecc_ops { + // the mles have one extra variable to store + // the internal partial sums for ecc additions + log2 + 1 + } else { + log2 + } } } @@ -65,6 +79,23 @@ pub trait TraceCommitter { ); } +/// Accumulate N (not necessarily power of 2) EC points into one EC point using affine coordinates +/// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +/// Note that these points are defined over the septic extension field of BabyBear. 
+/// +/// The main constraint enforced in this quark layer is: +/// p[1,b] = affine_add(p[b,0], p[b,1]) for all b < N +pub trait EccQuarkProver { + fn prove_ec_sum_quark<'a>( + &self, + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError>; +} + pub trait TowerProver { // infer read/write/logup records from the read/write/logup expressions and then // build multiple complete binary trees (tower tree) to accumulate these records @@ -147,9 +178,7 @@ pub trait OpeningProver { witness_data: PB::PcsData, fixed_data: Option>, points: Vec>, - evals: Vec>, - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + evals: Vec>>, transcript: &mut (impl Transcript + 'static), ) -> >::Proof; } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 028f844a6..edf7a63f1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,13 +26,14 @@ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; use p3::field::{Field, FieldAlgebra}; use rand::thread_rng; use std::{ cmp::max, - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, hash::Hash, @@ -42,6 +43,7 @@ use std::{ }; use strum::IntoEnumIterator; use tiny_keccak::{Hasher, Keccak}; +use witness::next_pow2_instance_padding; const MAX_CONSTRAINT_DEGREE: usize = 3; const MOCK_PROGRAM_SIZE: usize = 32; @@ -828,7 +830,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut cs = ConstraintSystem::::new(|| "mock_program"); let params = ProgramParams { platform: CENO_PLATFORM, - program_size: max(program.instructions.len(), MOCK_PROGRAM_SIZE), + program_size: max( + 
next_pow2_instance_padding(program.instructions.len()), + MOCK_PROGRAM_SIZE, + ), ..ProgramParams::default() }; let mut cb = CircuitBuilder::new(&mut cs); @@ -974,30 +979,48 @@ Hints: let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); + let circuit_index_fixed_num_instances: BTreeMap = fixed_trace + .circuit_fixed_traces + .iter() + .map(|(circuit_name, rmm)| { + ( + circuit_name.clone(), + rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0), + ) + }) + .collect(); let mut lkm_tables = LkMultiplicityRaw::::default(); let mut lkm_opcodes = LkMultiplicityRaw::::default(); // Process all circuits. - for ( - circuit_name, - ComposedConstrainSystem { + for (circuit_name, composed_cs) in &cs.circuit_css { + let ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit, - }, - ) in &cs.circuit_css - { + } = &composed_cs; let is_opcode = gkr_circuit.is_some(); let [witness, structural_witness] = witnesses .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)); - let num_rows = witness.num_instances(); + let num_rows = if witness.num_instances() > 0 { + witness.num_instances() + } else if structural_witness.num_instances() > 0 { + structural_witness.num_instances() + } else if composed_cs.is_static_circuit() { + circuit_index_fixed_num_instances + .get(circuit_name) + .copied() + .unwrap_or(0) + } else { + 0 + }; - if witness.num_instances() == 0 { + if num_rows == 0 { wit_mles.insert(circuit_name.clone(), vec![]); structural_wit_mles.insert(circuit_name.clone(), vec![]); fixed_mles.insert(circuit_name.clone(), vec![]); - num_instances.insert(circuit_name.clone(), num_rows); + num_instances.insert(circuit_name.clone(), 0); continue; } let mut witness = witness @@ -1133,21 +1156,20 @@ Hints: if *num_rows == 0 { continue; } - let w_selector: ArcMultilinearExtension<_> = if let Some(w_selector) = &cs.w_selector { 
structural_witness[w_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() }; - for ((w_rlc_expr, annotation), _) in (cs + for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs .w_expressions .iter() .chain(cs.w_table_expressions.iter().map(|expr| &expr.expr))) @@ -1157,8 +1179,19 @@ Hints: .chain(cs.w_table_expressions_namespace_map.iter()), ) .zip_eq(cs.w_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let write_rlc_records = wit_infer_by_expr( w_rlc_expr, cs.num_witin, @@ -1170,13 +1203,34 @@ Hints: &pi_mles, &challenges, ); + let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = - filter_mle_by_selector_mle(write_rlc_records, w_selector.clone()); + filter_mle_by_predicate(write_rlc_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE + }); + if write_rlc_records.is_empty() { + continue; + } let mut records = vec![]; + let mut writes_within_expr_dedup = HashSet::new(); for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error - assert_eq!(writes.insert(record_rlc), true); + assert_eq!( + writes_within_expr_dedup.insert(record_rlc), + true, + "within expression write duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation + ); + assert_eq!( + writes.insert(record_rlc), + true, + "crossing-chip write duplicated 
on RAMType {:?} annotation {:?}", + $ram_type, + annotation + ); records.push((record_rlc, row)); } writes_grp_by_annotations @@ -1205,14 +1259,14 @@ Hints: structural_witness[r_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() }; - for ((r_rlc_expr, annotation), (_, r_exprs)) in (cs + for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs .r_expressions .iter() .chain(cs.r_table_expressions.iter().map(|expr| &expr.expr))) @@ -1222,8 +1276,19 @@ Hints: .chain(cs.r_table_expressions_namespace_map.iter()), ) .zip_eq(cs.r_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let read_records = wit_infer_by_expr( r_rlc_expr, cs.num_witin, @@ -1235,8 +1300,14 @@ Hints: &pi_mles, &challenges, ); - let read_records = - filter_mle_by_selector_mle(read_records, r_selector.clone()); + let r_selector_vec = r_selector.get_base_field_vec(); + let read_records = filter_mle_by_predicate(read_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && r_selector_vec[i] == E::BaseField::ONE + }); + if read_records.is_empty() { + continue; + } if $ram_type == RAMType::GlobalState { // r_exprs = [GlobalState, pc, timestamp] @@ -1269,9 +1340,23 @@ Hints: }; let mut records = vec![]; + let mut reads_within_expr_dedup = HashSet::new(); for (row, record) in enumerate(read_records) { // TODO: return error - 
assert_eq!(reads.insert(record), true); + assert_eq!( + reads_within_expr_dedup.insert(record), + true, + "within expression read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, + ); + assert_eq!( + reads.insert(record), + true, + "crossing-chip read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, + ); records.push((record, row)); } reads_grp_by_annotations @@ -1467,6 +1552,19 @@ fn print_errors( } } +fn filter_mle_by_predicate(target_mle: ArcMultilinearExtension, mut predicate: F) -> Vec +where + E: ExtensionField, + F: FnMut(usize, &E) -> bool, +{ + target_mle + .get_ext_field_vec() + .iter() + .enumerate() + .filter_map(|(i, v)| if predicate(i, v) { Some(*v) } else { None }) + .collect_vec() +} + fn filter_mle_by_selector_mle( target_mle: ArcMultilinearExtension, selector: ArcMultilinearExtension, @@ -1487,7 +1585,6 @@ fn filter_mle_by_selector_mle( #[cfg(test)] mod tests { - use super::*; use crate::{ ROMType, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e1094d77f..187d2a708 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,12 +9,12 @@ use std::{ sync::Arc, }; -use crate::scheme::hal::MainSumcheckEvals; +use crate::scheme::{constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals}; use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, + Expression, Instance, mle::{IntoMLE, MultilinearExtension}, }; use p3::field::FieldAlgebra; @@ -113,38 +113,52 @@ impl< // only keep track of circuits that have non-zero instances let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); let mut num_instances_with_rotation = Vec::with_capacity(self.pk.circuit_pks.len()); + let mut circuit_name_num_instances_mapping = BTreeMap::new(); for (index, (circuit_name, ProvingKey { vk, .. 
})) in self.pk.circuit_pks.iter().enumerate() { // num_instance from witness might include rotation if let Some(num_instance) = witnesses - .get_opcode_witness(circuit_name) - .or_else(|| witnesses.get_table_witness(circuit_name)) - .map(|rmms| &rmms[0]) - .map(|rmm| rmm.num_instances()) + .num_instances + .get(circuit_name) + .cloned() .and_then(|num_instance| { - if num_instance > 0 { + if num_instance.iter().sum::() > 0 { Some(num_instance) } else { None } }) + .or_else(|| { + vk.get_cs().is_static_circuit().then(|| { + self.pk + .circuit_index_fixed_num_instances + .get(&index) + .copied() + .map(|num_instance| vec![num_instance]) + .unwrap_or(vec![]) + }) + }) { - num_instances.push(( - index, - num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), - )); - num_instances_with_rotation.push((index, num_instance)) + let num_instance_exclude_rotation = num_instance + .iter() + .map(|num_instance| num_instance >> vk.get_cs().rotation_vars().unwrap_or(0)) + .collect_vec(); + num_instances.push((index, num_instance_exclude_rotation.clone())); + circuit_name_num_instances_mapping + .insert(circuit_name, num_instance_exclude_rotation); + num_instances_with_rotation.push((index, num_instance)); } } // write (circuit_idx, num_var) to transcript for (circuit_idx, num_instance) in &num_instances { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&num_instance.to_le_bytes()); + for num_instance in num_instance { + transcript.append_message(&num_instance.to_le_bytes()); + } } let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); - let mut wits_instances = BTreeMap::new(); let mut wits_rmms = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); @@ -157,31 +171,19 @@ impl< } else { RowMajorMatrix::empty() }; - let rotation_vars = self - .pk - .circuit_pks - .get(&circuit_name) - .unwrap() - .vk - .get_cs() - .rotation_vars(); - let num_instances = witness_rmm.num_instances() >> 
(rotation_vars.unwrap_or(0)); - assert!( - wits_instances - .insert(circuit_name.clone(), num_instances) - .is_none() - ); - if num_instances == 0 { - continue; - } - let structural_witness = structural_witness_rmm.to_mles(); - wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); - structural_wits.insert(circuit_name, (structural_witness, num_instances)); + if witness_rmm.num_instances() > 0 { + wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); + } + if structural_witness_rmm.num_instances() > 0 { + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .unwrap(); + let structural_witness = structural_witness_rmm.to_mles(); + structural_wits.insert(circuit_name, (structural_witness, num_instances)); + } } - debug_assert_eq!(num_instances.len(), wits_rmms.len()); - // commit to witness traces in batch let (mut witness_mles, witness_data, witin_commit) = self.device.commit_traces(wits_rmms); PCS::write_commitment(&witin_commit, &mut transcript).map_err(ZKVMError::PCSError)?; @@ -208,11 +210,12 @@ impl< let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( (vec![], vec![]), |(mut points, mut evaluations), (index, (circuit_name, pk))| { - let num_instances = *wits_instances - .get(circuit_name) - .ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string().into()))?; + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .cloned() + .unwrap_or_default(); let cs = pk.get_cs(); - if num_instances == 0 { + if num_instances.is_empty() { // we need to drain respective fixed when num_instances is 0 if cs.num_fixed() > 0 { let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); @@ -237,13 +240,13 @@ impl< exit_span!(structural_witness_span); let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); - - let mut input = ProofInput { + let input = ProofInput { witness: witness_mle, fixed, structural_witness, public_input: public_input.clone(), - 
num_instances, + num_instances: num_instances.clone(), + has_ecc_ops: cs.has_ecc_ops(), }; if cs.is_opcode_circuit() { @@ -255,28 +258,35 @@ impl< &challenges, )?; tracing::trace!( - "generated proof for opcode {} with num_instances={}", + "generated proof for opcode {} with num_instances={:?}", circuit_name, num_instances ); points.push(input_opening_point); - evaluations.push(opcode_proof.wits_in_evals.clone()); + evaluations.push(vec![opcode_proof.wits_in_evals.clone()]); chip_proofs.insert(index, opcode_proof); } else { // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - input.num_instances = 1 << input.log2_num_instances(); - let (mut table_proof, pi_in_evals, input_opening_point) = self - .create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; - points.push(input_opening_point); - evaluations.push( - [ + // input.num_instances = 1 << input.log2_num_instances(); + let (table_proof, pi_in_evals, input_opening_point) = self.create_chip_proof( + circuit_name, + pk, + input, + &mut transcript, + &challenges, + )?; + if cs.num_witin() > 0 || cs.num_fixed() > 0 { + points.push(input_opening_point); + evaluations.push(vec![ table_proof.wits_in_evals.clone(), table_proof.fixed_in_evals.clone(), - ] - .concat(), - ); + ]); + } else { + assert!(table_proof.wits_in_evals.is_empty()); + assert!(table_proof.fixed_in_evals.is_empty()); + } // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - table_proof.num_instances = num_instances; + // table_proof.num_instances = num_instances; chip_proofs.insert(index, table_proof); for (idx, eval) in pi_in_evals { pi_evals[idx] = eval; @@ -289,20 +299,12 @@ impl< // batch opening pcs // generate static info from prover key for expected num variable - let circuit_num_polys = self - .pk - .circuit_pks - .values() - .map(|pk| (pk.get_cs().num_witin(), pk.get_cs().num_fixed())) - .collect_vec(); let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); let 
mpcs_opening_proof = self.device.open( witness_data, Some(device_pk.pcs_data), points, evaluations, - &circuit_num_polys, - &num_instances_with_rotation, &mut transcript, ); exit_span!(pcs_opening); @@ -336,7 +338,38 @@ impl< let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); - // println!("create_chip_proof: {}", name); + // run ecc quark prover + let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { + let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; + assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); + let mut xs_ys = ec_point_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec point's expression must be WitIn"), + }) + .collect_vec(); + let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); + let xs = xs_ys; + let slopes = cs + .zkvm_v1_css + .ec_slope_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("slope's expression must be WitIn"), + }) + .collect_vec(); + Some(self.device.prove_ec_sum_quark( + input.num_instances(), + xs, + ys, + slopes, + transcript, + )?) + } else { + None + }; // build main witness let (records, is_padded) = @@ -355,6 +388,15 @@ impl< num_var_with_rotation, ); + // TODO: batch reduction into main sumcheck + // x[rt,0] = \sum_b eq([rt,0], b) * x[b] + // x[rt,1] = \sum_b eq([rt,1], b) * x[b] + // x[1,rt] = \sum_b eq([1,rt], b) * x[b] + // y[rt,0] = \sum_b eq([rt,0], b) * y[b] + // y[rt,1] = \sum_b eq([rt,1], b) * y[b] + // y[1,rt] = \sum_b eq([1,rt], b) * y[b] + // s[0,rt] = \sum_b eq([0,rt], b) * s[b] + // 1. prove the main constraints among witness polynomials // 2. 
prove the relation between last layer in the tower and read/write/logup records let span = entered_span!("prove_main_constraints", profiling_2 = true); @@ -389,6 +431,7 @@ impl< main_sumcheck_proofs, gkr_iop_proof, tower_proof, + ecc_proof, fixed_in_evals, wits_in_evals, num_instances: input.num_instances, diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs new file mode 100644 index 000000000..f9b6b4f76 --- /dev/null +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -0,0 +1,1174 @@ +use either::Either; +use ff_ext::{ExtensionField, FromUniformBytes}; +use multilinear_extensions::Expression; +// The extension field and curve definition are adapted from +// https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs +use p3::field::{Field, FieldAlgebra}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use std::{ + iter::Sum, + ops::{Add, Deref, Mul, MulAssign, Neg, Sub}, +}; + +/// F[z] / (z^6 - z - 4) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. = PolynomialRing(F) +/// f = x^6 - x - 4 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +pub struct SexticExtension([F; 6]); + +/// F[z] / (z^7 - 2z - 5) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. 
= PolynomialRing(F) +/// f = x^7 - 2x - 5 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct SepticExtension(pub [F; 7]); + +impl From<&[F]> for SepticExtension { + fn from(slice: &[F]) -> Self { + assert!(slice.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&slice[0..7]); + Self(arr) + } +} + +impl From> for SepticExtension { + fn from(v: Vec) -> Self { + assert!(v.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&v[0..7]); + Self(arr) + } +} + +impl Deref for SepticExtension { + type Target = [F]; + + fn deref(&self) -> &[F] { + &self.0 + } +} + +impl SepticExtension { + pub fn is_zero(&self) -> bool { + self.0.iter().all(|c| *c == F::ZERO) + } + + pub fn zero() -> Self { + Self([F::ZERO; 7]) + } + + pub fn one() -> Self { + let mut arr = [F::ZERO; 7]; + arr[0] = F::ONE; + Self(arr) + } + + // returns z^{i*p} for i = 0..6 + // + // The sage script to compute z^{i*p} is as follows: + // ```sage + // p = 2^31 - 2^27 + 1 + // Fp = GF(p) + // R. 
= PolynomialRing(Fp) + // mod_poly = z^7 - 2*z - 5 + // Q = R.quotient(mod_poly) + // + // # compute z^(i*p) for i = 1..6 + // for k in range(1, 7): + // power = k * p + // z_power = Q(z)^power + // print(f"z^({k}*p) = {z_power}") + // ``` + fn z_pow_p(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 954599710, 1359279693, 566669999, 1982781815, 1735718361, 1174868538, 1120871770, + ] + .into(), + 2 => [ + 862825265, 597046311, 978840770, 1790138282, 1044777201, 835869808, 1342179023, + ] + .into(), + 3 => [ + 596273169, 658837454, 1515468261, 367059247, 781278880, 1544222616, 155490465, + ] + .into(), + 4 => [ + 557608863, 1173670028, 1749546888, 1086464137, 803900099, 1288818584, 1184677604, + ] + .into(), + 5 => [ + 763416381, 1252567168, 628856225, 1771903394, 650712211, 19417363, 57990258, + ] + .into(), + 6 => [ + 1734711039, 1749813853, 1227235221, 1707730636, 424560395, 1007029514, 498034669, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns z^{i*p^2} for i = 0..6 + // we can change the above sage script to compute z^{i*p^2} by replacing + // `power = k * p` with `power = k * p * p` + fn z_pow_p_square(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 1013489358, 1619071628, 304593143, 1949397349, 1564307636, 327761151, 415430835, + ] + .into(), + 2 => [ + 209824426, 1313900768, 38410482, 256593180, 1708830551, 1244995038, 1555324019, + ] + .into(), + 3 => [ + 1475628651, 777565847, 704492386, 1218528120, 1245363405, 475884575, 649166061, + ] + .into(), + 4 => [ + 550038364, 948935655, 68722023, 1251345762, 1692456177, 1177958698, 350232928, + ] + .into(), + 5 => [ + 882720258, 821925756, 199955840, 812002876, 1484951277, 1063138035, 491712810, + ] + .into(), + 6 => [ + 738287111, 1955364991, 552724293, 1175775744, 341623997, 1454022463, 408193320, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns self^p = (a0 + a1*z^p + ... 
+ a6*z^(6p)) + pub fn frobenius(&self) -> Self { + Self::z_pow_p(0) * self.0[0] + + Self::z_pow_p(1) * self.0[1] + + Self::z_pow_p(2) * self.0[2] + + Self::z_pow_p(3) * self.0[3] + + Self::z_pow_p(4) * self.0[4] + + Self::z_pow_p(5) * self.0[5] + + Self::z_pow_p(6) * self.0[6] + } + + // returns self^(p^2) = (a0 + a1*z^(p^2) + ... + a6*z^(6*p^2)) + pub fn double_frobenius(&self) -> Self { + Self::z_pow_p_square(0) * self.0[0] + + Self::z_pow_p_square(1) * self.0[1] + + Self::z_pow_p_square(2) * self.0[2] + + Self::z_pow_p_square(3) * self.0[3] + + Self::z_pow_p_square(4) * self.0[4] + + Self::z_pow_p_square(5) * self.0[5] + + Self::z_pow_p_square(6) * self.0[6] + } + + // returns self^(p + p^2 + ... + p^6) + fn norm_sub(&self) -> Self { + let a = self.frobenius() * self.double_frobenius(); + let b = a.double_frobenius(); + let c = b.double_frobenius(); + + a * b * c + } + + // norm = self^(1 + p + ... + p^6) + // = self^((p^7-1)/(p-1)) + // it's a field element in F since norm^p = norm + fn norm(&self) -> F { + (self.norm_sub() * self).0[0] + } + + pub fn is_square(&self) -> bool { + // since a^((p^7 - 1)/2) = norm(a)^((p-1)/2) + // to test if self^((p^7 - 1) / 2) == 1? + // we can just test if norm(a)^((p-1)/2) == 1? + let exp_digits = ((F::order() - 1u32) / 2u32).to_u64_digits(); + debug_assert!(exp_digits.len() == 1); + let exp = exp_digits[0]; + + self.norm().exp_u64(exp) == F::ONE + } + + pub fn inverse(&self) -> Option { + match self.is_zero() { + true => None, + false => { + // since norm(a)^(-1) * a^(p + p^2 + ... + p^6) * a = 1 + // it's easy to see a^(-1) = norm(a)^(-1) * a^(p + p^2 + ... 
// + p^6)
                let x = self.norm_sub();
                // (self * x) is the norm; only its constant coefficient is read
                let norm = (self * &x).0[0];
                // since self is not zero, norm is not zero
                let norm_inv = norm.try_inverse().unwrap();

                Some(x * norm_inv)
            }
        }
    }

    // Squares `self`, reducing with x^7 = 2x + 5.
    // Cross terms a_i*a_j (i < j) are added once with a factor of two,
    // then the seven diagonal terms a_i^2 are added explicitly.
    pub fn square(&self) -> Self {
        let mut result = [F::ZERO; 7];
        let two = F::from_canonical_u32(2);
        let five = F::from_canonical_u32(5);

        // i < j
        for i in 0..7 {
            for j in (i + 1)..7 {
                let term = two * self.0[i] * self.0[j];
                let mut index = i + j;
                if index < 7 {
                    result[index] += term;
                } else {
                    index -= 7;
                    // x^7 = 2x + 5
                    result[index] += five * term;
                    result[index + 1] += two * term;
                }
            }
        }
        // i == j: i \in [0, 3]
        result[0] += self.0[0] * self.0[0];
        result[2] += self.0[1] * self.0[1];
        result[4] += self.0[2] * self.0[2];
        result[6] += self.0[3] * self.0[3];
        // a4^2 * x^8 = a4^2 * (2x + 5)x = 5a4^2 * x + 2a4^2 * x^2
        let term = self.0[4] * self.0[4];
        result[1] += five * term;
        result[2] += two * term;
        // a5^2 * x^10 = a5^2 * (2x + 5)x^3 = 5a5^2 * x^3 + 2a5^2 * x^4
        let term = self.0[5] * self.0[5];
        result[3] += five * term;
        result[4] += two * term;
        // a6^2 * x^12 = a6^2 * (2x + 5)x^5 = 5a6^2 * x^5 + 2a6^2 * x^6
        let term = self.0[6] * self.0[6];
        result[5] += five * term;
        result[6] += two * term;

        Self(result)
    }

    // Raises `self` to `exp` by square-and-multiply, scanning the bits of
    // `exp` from the most significant set bit down; `exp == 0` yields one.
    pub fn pow(&self, exp: u64) -> Self {
        let mut result = Self::one();
        let num_bits = 64 - exp.leading_zeros();
        for j in (0..num_bits).rev() {
            result = result.square();
            if (exp >> j) & 1u64 == 1u64 {
                result = result * self;
            }
        }
        result
    }

    // Returns a square root of `self`, or `None` when `self` is not a square.
    pub fn sqrt(&self) -> Option {
        // the algorithm is adapted from [Cipolla's algorithm](https://en.wikipedia.org/wiki/Cipolla%27s_algorithm)
        // the code is taken from https://github.com/succinctlabs/sp1/blob/dev/crates/stark/src/septic_extension.rs#L623
        let n = self.clone();

        // zero and one are their own square roots
        if n == Self::zero() || n == Self::one() {
            return Some(n);
        }

        // norm = n^(1 + p + ...
// + p^6) = n^(p^7-1)/(p-1)
        let norm = n.norm();
        let exp = ((F::order() - 1u32) / 2u32).to_u64_digits()[0];
        // euler's criterion n^((p^7-1)/2) == 1 iff n is quadratic residue
        if norm.exp_u64(exp) != F::ONE {
            // it's not a square
            return None;
        };

        // n_power = n^((p+1)/2)
        let exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0];
        let n_power = self.pow(exp);

        // n^((p^2 + p)/2)
        let mut n_frobenius = n_power.frobenius();
        let mut denominator = n_frobenius.clone();

        // n^((p^4 + p^3)/2)
        n_frobenius = n_frobenius.double_frobenius();
        denominator *= n_frobenius.clone();
        // n^((p^6 + p^5)/2)
        n_frobenius = n_frobenius.double_frobenius();
        // d = n^((p^6 + p^5 + p^4 + p^3 + p^2 + p) / 2)
        // d^2 * n = norm
        denominator *= n_frobenius;
        // d' = d*n
        denominator *= n;

        let base = norm.inverse(); // norm^(-1)
        let g = F::GENERATOR;
        let mut a = F::ONE;
        let mut non_residue = F::ONE - base;
        let legendre_exp = (F::order() - 1u32) / 2u32; // (p-1)/2

        // non_residue = a^2 - 1/norm
        // find `a` such that non_residue is not a square in F
        // (candidates for `a` are successive powers of the generator g)
        while non_residue.exp_u64(legendre_exp.to_u64_digits()[0]) == F::ONE {
            a *= g;
            non_residue = a.square() - base;
        }

        // (p+1)/2
        let cipolla_exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0];
        // x = (a+i)^((p+1)/2) where a in Fp
        // x^2 = (a+i) * (a+i)^p = (a+i)*(a-i) = a^2 - i^2
        //     = a^2 - non_residue = 1/norm
        // therefore, x is the square root of 1/norm
        let mut x = QuadraticExtension::new(a, F::ONE, non_residue);
        x = x.pow(cipolla_exp);

        // (x*d')^2 = x^2 * d^2 * n^2 = 1/norm * norm * n
        Some(denominator * x.real)
    }
}

// a + bi where i^2 = non_residue
// `non_residue` is stored alongside the two coordinates so arithmetic on the
// element does not need any external context.
#[derive(Clone, Debug)]
pub struct QuadraticExtension {
    pub real: F,
    pub imag: F,
    pub non_residue: F,
}

impl QuadraticExtension {
    // Builds `real + imag * i` with i^2 = `non_residue`.
    pub fn new(real: F, imag: F, non_residue: F) -> Self {
        Self {
            real,
            imag,
            non_residue,
        }
    }

    // Squares the element; `non_residue` is carried over unchanged.
    pub fn square(&self) -> Self {
        // (a + bi)^2 = (a^2 + b^2*i^2) +
// 2ab*i
        let real = self.real * self.real + self.non_residue * self.imag * self.imag;
        // 2ab computed as ab + ab
        let mut imag = self.real * self.imag;
        imag += imag;

        Self {
            real,
            imag,
            non_residue: self.non_residue,
        }
    }

    // Multiplies two elements; the result keeps `self.non_residue`
    // (`other.non_residue` is not read).
    pub fn mul(&self, other: &Self) -> Self {
        // (a + bi)(c + di) = (ac + bd*i^2) + (ad + bc)i
        let real = self.real * other.real + self.non_residue * self.imag * other.imag;
        let imag = self.real * other.imag + self.imag * other.real;

        Self {
            real,
            imag,
            non_residue: self.non_residue,
        }
    }

    // Square-and-multiply exponentiation, most significant bit first;
    // `exp == 0` returns 1 + 0i.
    pub fn pow(&self, exp: u64) -> Self {
        let mut result = Self {
            real: F::ONE,
            imag: F::ZERO,
            non_residue: self.non_residue,
        };

        let num_bits = 64 - exp.leading_zeros();
        for j in (0..num_bits).rev() {
            result = result.square();
            if (exp >> j) & 1u64 == 1u64 {
                result = result.mul(self);
            }
        }

        result
    }
}

impl SepticExtension {
    // Draws each of the seven coefficients from `F::random`.
    pub fn random(mut rng: impl RngCore) -> Self {
        let mut arr = [F::ZERO; 7];
        for item in arr.iter_mut() {
            *item = F::random(&mut rng);
        }
        Self(arr)
    }
}

// Converts seven u32 coefficients (index = power of z) via `from_canonical_u32`.
impl From<[u32; 7]> for SepticExtension {
    fn from(arr: [u32; 7]) -> Self {
        let mut result = [F::ZERO; 7];
        for i in 0..7 {
            result[i] = F::from_canonical_u32(arr[i]);
        }
        Self(result)
    }
}

// Coefficient-wise addition: owned + borrowed.
impl Add<&Self> for SepticExtension {
    type Output = SepticExtension;

    fn add(self, other: &Self) -> Self {
        let mut result = [F::ZERO; 7];
        for (i, res) in result.iter_mut().enumerate() {
            *res = self.0[i] + other.0[i];
        }
        Self(result)
    }
}

// Coefficient-wise addition: borrowed + borrowed.
impl Add for &SepticExtension {
    type Output = SepticExtension;

    fn add(self, other: Self) -> SepticExtension {
        let mut result = [F::ZERO; 7];
        for (i, res) in result.iter_mut().enumerate() {
            *res = self.0[i] + other.0[i];
        }
        SepticExtension(result)
    }
}

// Owned + owned; forwards to the `Add<&Self>` implementation.
impl Add for SepticExtension {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        self.add(&other)
    }
}

// Coefficient-wise negation.
impl Neg for SepticExtension {
    type Output = Self;

    fn neg(self) -> Self {
        let mut result =
[F::ZERO; 7];
        for (res, src) in result.iter_mut().zip(self.0.iter()) {
            *res = -(*src);
        }
        Self(result)
    }
}

// Coefficient-wise subtraction: owned - borrowed.
impl Sub<&Self> for SepticExtension {
    type Output = SepticExtension;

    fn sub(self, other: &Self) -> Self {
        let mut result = [F::ZERO; 7];
        for (i, res) in result.iter_mut().enumerate() {
            *res = self.0[i] - other.0[i];
        }
        Self(result)
    }
}

// Coefficient-wise subtraction: borrowed - borrowed.
impl Sub for &SepticExtension {
    type Output = SepticExtension;

    fn sub(self, other: Self) -> SepticExtension {
        let mut result = [F::ZERO; 7];
        for (i, res) in result.iter_mut().enumerate() {
            *res = self.0[i] - other.0[i];
        }
        SepticExtension(result)
    }
}

// Owned - owned; forwards to the `Sub<&Self>` implementation.
impl Sub for SepticExtension {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        self.sub(&other)
    }
}

// Adds a base-field scalar: only the constant coefficient changes.
impl Add for &SepticExtension {
    type Output = SepticExtension;

    fn add(self, other: F) -> Self::Output {
        let mut result = self.clone();
        result.0[0] += other;

        result
    }
}

// Owned variant of scalar addition; forwards to the borrowed implementation.
impl Add for SepticExtension {
    type Output = SepticExtension;

    fn add(self, other: F) -> Self::Output {
        (&self).add(other)
    }
}

// Multiplies every coefficient by a base-field scalar.
impl Mul for &SepticExtension {
    type Output = SepticExtension;

    fn mul(self, other: F) -> Self::Output {
        let mut result = [F::ZERO; 7];
        for (i, res) in result.iter_mut().enumerate() {
            *res = self.0[i] * other;
        }
        SepticExtension(result)
    }
}

// Owned variant of scalar multiplication; forwards to the borrowed implementation.
impl Mul for SepticExtension {
    type Output = SepticExtension;

    fn mul(self, other: F) -> Self::Output {
        (&self).mul(other)
    }
}

// Schoolbook product of two elements, reduced with x^7 = 2x + 5.
impl Mul for &SepticExtension {
    type Output = SepticExtension;

    fn mul(self, other: Self) -> Self::Output {
        let mut result = [F::ZERO; 7];
        let five = F::from_canonical_u32(5);
        let two = F::from_canonical_u32(2);
        for i in 0..7 {
            for j in 0..7 {
                let term = self.0[i] * other.0[j];
                let mut index = i + j;
                if index < 7 {
                    result[index] += term;
                } else {
                    // x^(7 + k) = 5*x^k + 2*x^(k+1), with k = index - 7 <= 5
                    index -= 7;
                    // x^7 = 2x + 5
                    result[index] += five * term;
                    result[index + 1] += two * term;
                }
            }
        }
SepticExtension(result)
    }
}

// Owned * owned; forwards to the borrowed product.
impl Mul for SepticExtension {
    type Output = Self;

    fn mul(self, other: Self) -> Self {
        (&self).mul(&other)
    }
}

// Owned * borrowed; forwards to the borrowed product.
impl Mul<&Self> for SepticExtension {
    type Output = Self;

    fn mul(self, other: &Self) -> Self {
        (&self).mul(other)
    }
}

// In-place product: `self` is replaced by `self * other`.
impl MulAssign for SepticExtension {
    fn mul_assign(&mut self, other: Self) {
        *self = (&*self).mul(&other);
    }
}

// Septic-extension element whose coefficients are symbolic expressions.
// NOTE(review): the element type of the `Vec` is missing in this copy of the
// file (generic arguments appear to have been stripped) - confirm upstream.
#[derive(Clone, Debug)]
pub struct SymbolicSepticExtension(pub Vec>);

impl SymbolicSepticExtension {
    // Multiplies every coefficient expression by the constant `scalar`.
    pub fn mul_scalar(&self, scalar: Either) -> Self {
        let res = self
            .0
            .iter()
            .map(|a| a.clone() * Expression::Constant(scalar))
            .collect();

        SymbolicSepticExtension(res)
    }

    // Adds the constant `scalar` to every coefficient expression
    // (all coefficients, not only the constant term).
    pub fn add_scalar(&self, scalar: Either) -> Self {
        let res = self
            .0
            .iter()
            .map(|a| a.clone() + Expression::Constant(scalar))
            .collect();

        SymbolicSepticExtension(res)
    }
}

// Coefficient-wise symbolic addition; `zip` stops at the shorter operand.
impl Add for &SymbolicSepticExtension {
    type Output = SymbolicSepticExtension;

    fn add(self, other: Self) -> Self::Output {
        let res = self
            .0
            .iter()
            .zip(other.0.iter())
            .map(|(a, b)| a.clone() + b.clone())
            .collect();

        SymbolicSepticExtension(res)
    }
}

// Owned + borrowed; forwards to the borrowed implementation.
impl Add<&Self> for SymbolicSepticExtension {
    type Output = Self;

    fn add(self, other: &Self) -> Self {
        (&self).add(other)
    }
}

// Owned + owned; forwards to the borrowed implementation.
impl Add for SymbolicSepticExtension {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        (&self).add(&other)
    }
}

// Coefficient-wise symbolic subtraction; `zip` stops at the shorter operand.
impl Sub for &SymbolicSepticExtension {
    type Output = SymbolicSepticExtension;

    fn sub(self, other: Self) -> Self::Output {
        let res = self
            .0
            .iter()
            .zip(other.0.iter())
            .map(|(a, b)| a.clone() - b.clone())
            .collect();

        SymbolicSepticExtension(res)
    }
}

// Owned - borrowed; forwards to the borrowed implementation.
impl Sub<&Self> for SymbolicSepticExtension {
    type Output = Self;

    fn sub(self, other: &Self) -> Self {
        (&self).sub(other)
    }
}

// Owned - owned; forwards to the borrowed implementation.
impl Sub for SymbolicSepticExtension {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        (&self).sub(&other)
}
}

// Symbolic schoolbook product, reduced with x^7 = 2x + 5.
// Indexes coefficients 0..7 of both operands directly.
impl Mul for &SymbolicSepticExtension {
    type Output = SymbolicSepticExtension;

    fn mul(self, other: Self) -> Self::Output {
        let mut result = vec![Expression::Constant(Either::Left(E::BaseField::ZERO)); 7];
        let five = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(5)));
        let two = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(2)));

        for i in 0..7 {
            for j in 0..7 {
                let term = self.0[i].clone() * other.0[j].clone();
                let mut index = i + j;
                if index < 7 {
                    result[index] += term;
                } else {
                    index -= 7;
                    // x^7 = 2x + 5
                    result[index] += five.clone() * term.clone();
                    result[index + 1] += two.clone() * term.clone();
                }
            }
        }
        SymbolicSepticExtension(result)
    }
}

// Owned * borrowed; forwards to the borrowed product.
impl Mul<&Self> for SymbolicSepticExtension {
    type Output = Self;

    fn mul(self, other: &Self) -> Self {
        (&self).mul(other)
    }
}

// Owned * owned; forwards to the borrowed product.
impl Mul for SymbolicSepticExtension {
    type Output = Self;

    fn mul(self, other: Self) -> Self {
        (&self).mul(&other)
    }
}

// Multiplies every coefficient expression by a single borrowed expression.
impl Mul<&Expression> for SymbolicSepticExtension {
    type Output = SymbolicSepticExtension;

    fn mul(self, other: &Expression) -> Self::Output {
        let res = self.0.iter().map(|a| a.clone() * other.clone()).collect();
        SymbolicSepticExtension(res)
    }
}

// Owned-expression variant; forwards to the `&Expression` implementation.
impl Mul> for SymbolicSepticExtension {
    type Output = SymbolicSepticExtension;

    fn mul(self, other: Expression) -> Self::Output {
        self.mul(&other)
    }
}

impl SymbolicSepticExtension {
    // Wraps exactly seven coefficient expressions; panics on any other length.
    pub fn new(exprs: Vec>) -> Self {
        assert!(
            exprs.len() == 7,
            "exprs length must be 7, but got {}",
            exprs.len()
        );
        Self(exprs)
    }

    // Returns a clone of the coefficient expressions.
    pub fn to_exprs(&self) -> Vec> {
        self.0.clone()
    }
}

/// A point on the short Weierstrass curve defined by
///   y^2 = x^3 + 2x + 26z^5
/// over the extension field F[z] / (z^7 - 2z - 5).
///
/// Note that
/// 1. The curve's cofactor is 1
/// 2.
The curve's order is a large prime number of 31x7 bits +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct SepticPoint { + pub x: SepticExtension, + pub y: SepticExtension, + pub is_infinity: bool, +} + +impl SepticPoint { + // if there exists y such that (x, y) is on the curve, return one of them + pub fn from_x(x: SepticExtension) -> Option { + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + let y2 = x.square() * &x + (&x * a) + &b; + if y2.is_square() { + let y = y2.sqrt().unwrap(); + + Some(Self { + x, + y, + is_infinity: false, + }) + } else { + None + } + } + + pub fn from_affine(x: SepticExtension, y: SepticExtension) -> Self { + let is_infinity = x.is_zero() && y.is_zero(); + + Self { x, y, is_infinity } + } + pub fn double(&self) -> Self { + let a = F::from_canonical_u32(2); + let three = F::from_canonical_u32(3); + let two = F::from_canonical_u32(2); + + let x1 = &self.x; + let y1 = &self.y; + let x1_sqr = x1.square(); + + // x3 = (3*x1^2 + a)^2 / (2*y1)^2 - x1 - x1 + let slope = (x1_sqr * three + a) * (y1 * two).inverse().unwrap(); + let x3 = slope.square() - x1 - x1; + // y3 = slope * (x1 - x3) - y1 + let y3 = slope * (x1 - &x3) - y1; + + Self { + x: x3, + y: y3, + is_infinity: false, + } + } +} + +impl Default for SepticPoint { + fn default() -> Self { + Self { + x: SepticExtension::zero(), + y: SepticExtension::zero(), + is_infinity: true, + } + } +} + +impl Neg for SepticPoint { + type Output = SepticPoint; + + fn neg(self) -> Self::Output { + if self.is_infinity { + return self; + } + + Self { + x: self.x, + y: -self.y, + is_infinity: false, + } + } +} + +impl Add for SepticPoint { + type Output = Self; + + fn add(self, other: Self) -> Self { + if self.is_infinity { + return other; + } + + if other.is_infinity { + return self; + } + + if self.x == other.x { + if self.y == other.y { + return self.double(); + } else { + assert!((self.y + other.y).is_zero()); + + return 
Self {
                    x: SepticExtension::zero(),
                    y: SepticExtension::zero(),
                    is_infinity: true,
                };
            }
        }

        // chord through two points with distinct x:
        // slope = (y2 - y1) / (x2 - x1)
        let slope = (other.y - &self.y) * (other.x.clone() - &self.x).inverse().unwrap();
        let x = slope.square() - (&self.x + &other.x);
        let y = slope * (self.x - &x) - self.y;

        Self {
            x,
            y,
            is_infinity: false,
        }
    }
}

// Sums points by folding with `+`, starting from the point at infinity.
// NOTE(review): the generic bound of `sum` is missing in this copy of the
// file (angle-bracket contents appear to have been stripped) - confirm upstream.
impl Sum for SepticPoint {
    fn sum>(iter: I) -> Self {
        iter.fold(Self::default(), |acc, p| acc + p)
    }
}

impl SepticPoint {
    // Returns true for the canonical point at infinity (flag set and both
    // coordinates zero); otherwise checks y^2 == x^3 + 2x + b.
    pub fn is_on_curve(&self) -> bool {
        if self.is_infinity && self.x.is_zero() && self.y.is_zero() {
            return true;
        }

        let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into();
        let a: F = F::from_canonical_u32(2);

        self.y.square() == self.x.square() * &self.x + (&self.x * a) + b
    }

    // Same value as `Default::default()`.
    pub fn point_at_infinity() -> Self {
        Self::default()
    }
}

impl SepticPoint {
    // Samples random x values until `from_x` finds a matching y.
    pub fn random(mut rng: impl RngCore) -> Self {
        loop {
            let x = SepticExtension::random(&mut rng);
            if let Some(point) = Self::from_x(x) {
                return point;
            }
        }
    }
}

// Curve point with three coordinates (x, y, z); z = 0 marks the point at
// infinity (see `Default` and `into_affine`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SepticJacobianPoint {
    pub x: SepticExtension,
    pub y: SepticExtension,
    pub z: SepticExtension,
}

// Lifts an affine point by setting z = 1; infinity maps to the default point.
impl From> for SepticJacobianPoint {
    fn from(p: SepticPoint) -> Self {
        if p.is_infinity {
            Self::default()
        } else {
            Self {
                x: p.x,
                y: p.y,
                z: SepticExtension::one(),
            }
        }
    }
}

impl Default for SepticJacobianPoint {
    fn default() -> Self {
        // return the point at infinity
        Self {
            x: SepticExtension::zero(),
            y: SepticExtension::one(),
            z: SepticExtension::zero(),
        }
    }
}

impl SepticJacobianPoint {
    // Same value as `Default::default()`.
    pub fn point_at_infinity() -> Self {
        Self::default()
    }

    // Checks the curve equation in (x, y, z) coordinates.
    pub fn is_on_curve(&self) -> bool {
        // with z = 0 the point is accepted only when x = 0 and y != 0
        if self.z.is_zero() {
            return self.x.is_zero() && !self.y.is_zero();
        }

        let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into();
        let a: F = F::from_canonical_u32(2);

        let z2 = self.z.square();
        let z4 = z2.square();
        let z6 = &z4 * &z2;

        //
// y^2 = x^3 + 2x*z^4 + b*z^6
        self.y.square() == self.x.square() * &self.x + (&self.x * a * z4) + (b * &z6)
    }

    // Converts to affine coordinates: (x / z^2, y / z^3); z = 0 maps to the
    // affine point at infinity.
    pub fn into_affine(self) -> SepticPoint {
        if self.z.is_zero() {
            return SepticPoint::point_at_infinity();
        }

        // z != 0 here, so the inverse exists
        let z_inv = self.z.inverse().unwrap();
        let z_inv2 = z_inv.square();
        let z_inv3 = &z_inv2 * &z_inv;

        let x = &self.x * &z_inv2;
        let y = &self.y * &z_inv3;

        SepticPoint {
            x,
            y,
            is_infinity: false,
        }
    }
}

// Point addition in (x, y, z) coordinates, no field inversion required.
impl Add for &SepticJacobianPoint {
    type Output = SepticJacobianPoint;

    fn add(self, rhs: Self) -> Self::Output {
        // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl
        // either operand at infinity: return a clone of the other one
        if self.z.is_zero() {
            return rhs.clone();
        }

        if rhs.z.is_zero() {
            return self.clone();
        }

        let z1z1 = self.z.square();
        let z2z2 = rhs.z.square();

        // u1, u2: x coordinates scaled to a common denominator
        let u1 = &self.x * &z2z2;
        let u2 = &rhs.x * &z1z1;

        // s1, s2: y coordinates scaled to a common denominator
        let s1 = &self.y * &z2z2 * &rhs.z;
        let s2 = &rhs.y * &z1z1 * &self.z;

        if u1 == u2 {
            if s1 == s2 {
                // same point: use the doubling formula
                return self.double();
            } else {
                // same x, different y: the sum is the point at infinity
                return SepticJacobianPoint::point_at_infinity();
            }
        }

        let two = F::from_canonical_u32(2);
        let h = u2 - &u1;
        let i = (&h * two).square();
        let j = &h * &i;
        let r = (s2 - &s1) * two;
        let v = u1 * &i;

        let x3 = r.square() - &j - &v * two;
        let y3 = r * (v - &x3) - s1 * &j * two;
        let z3 = (&self.z + &rhs.z).square() - &z1z1 - &z2z2;
        let z3 = z3 * h;

        Self::Output {
            x: x3,
            y: y3,
            z: z3,
        }
    }
}

// Owned + owned; forwards to the borrowed implementation.
impl Add for SepticJacobianPoint {
    type Output = SepticJacobianPoint;

    fn add(self, rhs: Self) -> Self::Output {
        (&self).add(&rhs)
    }
}

impl SepticJacobianPoint {
    // Returns 2 * self.
    pub fn double(&self) -> Self {
        // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-2007-bl

        // y = 0 means self.order = 2
        if self.y.is_zero() {
            return SepticJacobianPoint::point_at_infinity();
        }

        let two = F::from_canonical_u32(2);
        let three = F::from_canonical_u32(3);
        let eight = F::from_canonical_u32(8);
        let a
= F::from_canonical_u32(2); // The curve coefficient a

        // xx = x1^2
        let xx = self.x.square();

        // yy = y1^2
        let yy = self.y.square();

        // yyyy = yy^2
        let yyyy = yy.square();

        // zz = z1^2
        let zz = self.z.square();

        // S = 2*((x1 + y1^2)^2 - x1^2 - y1^4)
        let s = (&self.x + &yy).square() - &xx - &yyyy;
        let s = s * two;

        // M = 3*x1^2 + a*z1^4
        let m = &xx * three + zz.square() * a;

        // T = M^2 - 2*S
        let t = m.square() - &s * two;

        // Y3 = M*(S-T)-8*y^4
        let y3 = m * (&s - &t) - &yyyy * eight;

        // X3 = T
        let x3 = t;

        // Z3 = (y1+z1)^2 - y1^2 - z1^2
        let z3 = (&self.y + &self.z).square() - &yy - &zz;

        Self {
            x: x3,
            y: y3,
            z: z3,
        }
    }
}

// Sums points by folding with `+`, starting from the point at infinity.
// NOTE(review): the generic bound of `sum` is missing in this copy of the
// file (angle-bracket contents appear to have been stripped) - confirm upstream.
impl Sum for SepticJacobianPoint {
    fn sum>(iter: I) -> Self {
        iter.fold(Self::default(), |acc, p| acc + p)
    }
}

impl SepticJacobianPoint {
    // Samples a random affine point and converts it.
    pub fn random(rng: impl RngCore) -> Self {
        SepticPoint::random(rng).into()
    }
}

#[cfg(test)]
mod tests {
    use super::SepticExtension;
    use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint};
    use p3::{babybear::BabyBear, field::Field};
    use rand::thread_rng;

    type F = BabyBear;
    // Checks multiplication against a hand-computed product, Frobenius
    // consistency, that norm_sub(c) * c lies in the base field, and that
    // sqrt(d^2) is d or -d.
    #[test]
    fn test_septic_extension_arithmetic() {
        let mut rng = thread_rng();
        // a = z, b = z^6 + z^5 + z^4
        let a: SepticExtension = SepticExtension::from([0, 1, 0, 0, 0, 0, 0]);
        let b: SepticExtension = SepticExtension::from([0, 0, 0, 0, 1, 1, 1]);

        // a * b = z^7 + z^6 + z^5 = 5 + 2z + z^5 + z^6
        let c = SepticExtension::from([5, 2, 0, 0, 0, 1, 1]);
        assert_eq!(a * b, c);

        // a^(p^2) = (a^p)^p
        assert_eq!(c.double_frobenius(), c.frobenius().frobenius());

        // norm_sub(a) * a must be in F
        let norm = c.norm_sub() * &c;
        assert!(norm.0[1..7].iter().all(|x| x.is_zero()));

        let d: SepticExtension = SepticExtension::random(&mut rng);
        let e = d.square();
        assert!(e.is_square());

        let f = e.sqrt().unwrap();
        let zero = SepticExtension::zero();
        assert!(f == d || f == zero - d);
    }

    #[test]
    fn test_septic_curve_arithmetic() {
        let mut rng
= thread_rng(); + let p1 = SepticPoint::::random(&mut rng); + let p2 = SepticPoint::::random(&mut rng); + + let j1 = SepticJacobianPoint::from(p1.clone()); + let j2 = SepticJacobianPoint::from(p2.clone()); + + let p3 = p1 + p2; + let j3 = &j1 + &j2; + + assert!(j1.is_on_curve()); + assert!(j2.is_on_curve()); + + assert!(j3.is_on_curve()); + assert!(p3.is_on_curve()); + + assert_eq!(p3, j3.clone().into_affine()); + + // 2*p3 - p3 = p3 + let p4 = p3.double(); + assert_eq!((-p3.clone() + p4.clone()), p3); + + // 2*j3 = 2*p3 + let j4 = j3.double(); + assert!(j4.is_on_curve()); + assert_eq!(j4.into_affine(), p4); + } +} diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 60dff6a99..73355017c 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -42,7 +42,7 @@ use super::{ utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::tables::DynamicRangeTableCircuit; +use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -55,6 +55,7 @@ use transcript::{BasicTranscript, Transcript}; struct TestConfig { pub(crate) reg_id: WitIn, } + struct TestCircuit { phantom: PhantomData, } @@ -90,6 +91,7 @@ impl Instruction for Test fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -118,6 +120,7 @@ fn test_rw_lk_expression_combination() { let name = TestCircuit::::name(); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = zkvm_cs.register_opcode_circuit::>(); + let mut shard_ctx = ShardContext::default(); // generate fixed traces let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -140,6 +143,7 @@ fn test_rw_lk_expression_combination() { zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, + &mut shard_ctx, &config, 
vec![StepRecord::default(); num_instances], ) @@ -194,7 +198,8 @@ fn test_rw_lk_expression_combination() { witness: wits_in, structural_witness: structural_in, public_input: vec![], - num_instances, + num_instances: vec![num_instances], + has_ecc_ops: false, }; let (proof, _, _) = prover .create_chip_proof( @@ -274,6 +279,7 @@ fn test_single_add_instance_e2e() { Pcs::setup(1 << MAX_NUM_VARIABLES, SecurityLevel::default()).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim((), 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); + let mut shard_ctx = ShardContext::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); let halt_config = zkvm_cs.register_opcode_circuit::>(); @@ -339,10 +345,20 @@ fn test_single_add_instance_e2e() { let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &add_config, + add_records, + ) .unwrap(); zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &halt_config, + halt_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); zkvm_witness @@ -356,7 +372,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8b67929e..4bafca1c1 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -156,6 +156,11 @@ macro_rules! 
tower_mle_4 { }}; } +pub fn log2_strict_usize(n: usize) -> usize { + assert!(n.is_power_of_two()); + n.trailing_zeros() as usize +} + /// infer logup witness from last layer /// return is the ([p1,p2], [q1,q2]) for each layer pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( @@ -254,45 +259,80 @@ pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( .collect_vec() } -/// infer tower witness from last layer -pub(crate) fn infer_tower_product_witness( +/// Infer tower witness from input layer (layer 0 is the output layer and layer n is the input layer). +/// The relation between layer i and layer i+1 is as follows: +/// prod[i][b] = ∏_s prod[i+1][s,b] +/// where 2^s is the fanin of the product gate `num_product_fanin`. +pub fn infer_tower_product_witness( num_vars: usize, last_layer: Vec>, num_product_fanin: usize, ) -> Vec>> { + // sanity check assert!(last_layer.len() == num_product_fanin); - assert_eq!(num_product_fanin % 2, 0); - let log2_num_product_fanin = ceil_log2(num_product_fanin); - let mut wit_layers = - (0..(num_vars / log2_num_product_fanin) - 1).fold(vec![last_layer], |mut acc, _| { - let next_layer = acc.last().unwrap(); - let cur_len = next_layer[0].evaluations().len() / num_product_fanin; - let cur_layer: Vec> = (0..num_product_fanin) - .map(|index| { - let mut evaluations = vec![E::ONE; cur_len]; - next_layer.chunks_exact(2).for_each(|f| { - match (f[0].evaluations(), f[1].evaluations()) { - (FieldType::Ext(f1), FieldType::Ext(f2)) => { - let start: usize = index * cur_len; - (start..(start + cur_len)) + assert!(num_product_fanin.is_power_of_two()); + + let log2_num_product_fanin = log2_strict_usize(num_product_fanin); + assert!(num_vars.is_multiple_of(log2_num_product_fanin)); + assert!( + last_layer + .iter() + .all(|p| p.num_vars() == num_vars - log2_num_product_fanin) + ); + + let num_layers = num_vars / log2_num_product_fanin; + + let mut wit_layers = Vec::with_capacity(num_layers); + wit_layers.push(last_layer); + + 
for _ in (0..num_layers - 1).rev() { + let input_layer = wit_layers.last().unwrap(); + let output_len = input_layer[0].evaluations().len() / num_product_fanin; + + let output_layer: Vec> = (0..num_product_fanin) + .map(|index| { + // avoid the overhead of vector initialization + let mut evaluations: Vec = Vec::with_capacity(output_len); + let remaining = evaluations.spare_capacity_mut(); + + input_layer.chunks_exact(2).enumerate().for_each(|(i, f)| { + match (f[0].evaluations(), f[1].evaluations()) { + (FieldType::Ext(f1), FieldType::Ext(f2)) => { + let start: usize = index * output_len; + + if i == 0 { + (start..(start + output_len)) + .into_par_iter() + .zip(remaining.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(index, evaluations)| { + evaluations.write(f1[index] * f2[index]); + }); + } else { + (start..(start + output_len)) .into_par_iter() - .zip(evaluations.par_iter_mut()) + .zip(remaining.par_iter_mut()) .with_min_len(MIN_PAR_SIZE) - .map(|(index, evaluations)| { - *evaluations *= f1[index] * f2[index] - }) - .collect() + .for_each(|(index, evaluations)| { + evaluations.write(f1[index] * f2[index]); + }); } - _ => unreachable!("must be extension field"), } - }); - evaluations.into_mle() - }) - .collect_vec(); - acc.push(cur_layer); - acc - }); + _ => unreachable!("must be extension field"), + } + }); + + unsafe { + evaluations.set_len(output_len); + } + evaluations.into_mle() + }) + .collect_vec(); + wit_layers.push(output_layer); + } + wit_layers.reverse(); + wit_layers } @@ -350,12 +390,16 @@ pub fn build_main_witness< } if let Some(gkr_circuit) = gkr_circuit { - // opcode must have at least one read/write/lookup + // circuit must have at least one read/write/lookup assert!( - cs.lk_expressions.is_empty() - || !cs.r_expressions.is_empty() - || !cs.w_expressions.is_empty(), - "assert opcode circuit" + cs.r_expressions.len() + + cs.w_expressions.len() + + cs.lk_expressions.len() + + cs.r_table_expressions.len() + + cs.w_table_expressions.len() 
+ + cs.lk_table_expressions.len() + > 0, + "assert circuit" ); let (_, gkr_circuit_out) = gkr_witness::( @@ -370,7 +414,7 @@ pub fn build_main_witness< } else { ( >::table_witness(device, input, cs, challenges), - false, + input.num_instances() > 1 && input.num_instances().is_power_of_two(), ) } }; @@ -462,6 +506,7 @@ pub fn gkr_witness< Either::Right(iter::empty()) }) .chain(fixed.iter().cloned()) + .chain(pub_io.iter().cloned()) .collect_vec(); // infer current layer output diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index f8c1c8a2a..4ed5a89e9 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,15 +1,34 @@ -use std::marker::PhantomData; - +use either::Either; use ff_ext::ExtensionField; +use std::marker::PhantomData; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::gkr::GKRClaims; +use crate::{ + error::ZKVMError, + scheme::{ + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, + septic_curve::SepticExtension, + }, + structs::{ + ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, + ZKVMVerifyingKey, + }, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + }, +}; +use gkr_iop::{ + gkr::GKRClaims, + selector::{SelectorContext, SelectorType}, +}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, StructuralWitIn, StructuralWitInType, + Expression, Instance, StructuralWitIn, StructuralWitInType, + StructuralWitInType::StackedConstantSequence, mle::IntoMLE, util::ceil_log2, utils::eval_by_expr_with_instance, @@ -23,16 +42,6 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; -use crate::{ - error::ZKVMError, - scheme::constants::{NUM_FANIN, 
NUM_FANIN_LOGUP, SEL_DEGREE}, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, - utils::{ - eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, - eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, - }, -}; - use super::{ZKVMChipProof, ZKVMProof}; pub struct ZKVMVerifier> { @@ -63,15 +72,15 @@ impl> ZKVMVerifier &self, vm_proof: ZKVMProof, transcript: impl ForkableTranscript, - expect_halt: bool, + _expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - let has_halt = vm_proof.has_halt(&self.vk); - if has_halt != expect_halt { - return Err(ZKVMError::VerifyError( - format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), - )); - } + // let has_halt = vm_proof.has_halt(&self.vk); + // if has_halt != expect_halt { + // return Err(ZKVMError::VerifyError( + // format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), + // )); + // } self.verify_proof_validity(vm_proof, transcript) } @@ -133,7 +142,9 @@ impl> ZKVMVerifier // write (circuit_idx, num_instance) to transcript for (circuit_idx, proof) in &vm_proof.chip_proofs { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&proof.num_instances.to_le_bytes()); + for num_instance in &proof.num_instances { + transcript.append_message(&num_instance.to_le_bytes()); + } } // write witin commitment to transcript @@ -157,11 +168,11 @@ impl> ZKVMVerifier let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); - let mut rt_points = Vec::with_capacity(vm_proof.chip_proofs.len()); - let mut evaluations = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in 
&vm_proof.chip_proofs { + let num_instance: usize = proof.num_instances.iter().sum(); + assert!(num_instance > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -219,11 +230,10 @@ impl> ZKVMVerifier // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().num_lks(); // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let num_padded_instance = (next_pow2_instance_padding(proof.num_instances) - - proof.num_instances) + let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding - let num_instance_non_selected = proof.num_instances + let num_instance_non_selected = num_instance * ((1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)) - (circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0) + 1)); dummy_table_item_multiplicity += @@ -254,22 +264,18 @@ impl> ZKVMVerifier &challenges, )? 
}; - rt_points.push((*index, input_opening_point.clone())); - evaluations.push(( - *index, - [proof.wits_in_evals.clone(), proof.fixed_in_evals.clone()].concat(), - )); - witin_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), - )); - if !proof.fixed_in_evals.is_empty() { + if circuit_vk.get_cs().num_witin() > 0 { + witin_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), proof.wits_in_evals.clone()), + )); + } + if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), (input_opening_point.clone(), proof.fixed_in_evals.clone()), )); } - prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); tracing::debug!("verified proof for circuit {}", circuit_name); @@ -353,18 +359,56 @@ impl> ZKVMVerifier zkvm_v1_css: cs, gkr_circuit, } = &composed_cs; - let num_instances = proof.num_instances; + let num_instances = proof.num_instances.iter().sum(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( - cs.r_expressions.len(), - cs.w_expressions.len(), - cs.lk_expressions.len(), + cs.r_expressions.len() + cs.r_table_expressions.len(), + cs.w_expressions.len() + cs.w_table_expressions.len(), + cs.lk_expressions.len() + cs.lk_table_expressions.len() * 2, ); let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; let next_pow2_instance = next_pow2_instance_padding(num_instances); - let log2_num_instances = ceil_log2(next_pow2_instance); + let mut log2_num_instances = ceil_log2(next_pow2_instance); + if composed_cs.has_ecc_ops() { + // for opcode circuit with ecc ops, the mles have one extra variable + // to store the internal partial sums for ecc additions + log2_num_instances += 1; + } let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + // verify ecc proof if exists + if 
composed_cs.has_ecc_ops() { + tracing::debug!("verifying ecc proof..."); + assert!(proof.ecc_proof.is_some()); + let ecc_proof = proof.ecc_proof.as_ref().unwrap(); + + // TODO: enable this + // let xy = cs + // .ec_final_sum + // .iter() + // .map(|expr| { + // eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) + // .right() + // .and_then(|v| v.as_base()) + // .unwrap() + // }) + // .collect_vec(); + // let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + // let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + // assert_eq!( + // SepticPoint { + // x, + // y, + // is_infinity: false, + // }, + // ecc_proof.sum + // ); + // assert ec sum in public input matches that in ecc proof + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; + tracing::debug!("ecc proof verified."); + } + // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -409,6 +453,44 @@ impl> ZKVMVerifier debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); let gkr_circuit = gkr_circuit.as_ref().unwrap(); + let selector_ctxs = if cs.ec_final_sum.is_empty() { + assert_eq!(proof.num_instances.len(), 1); + // it's not global chip + vec![ + SelectorContext::new(0, num_instances, num_var_with_rotation); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap_or(0) + ] + } else { + assert_eq!(proof.num_instances.len(), 2); + // it's global chip + tracing::debug!( + "num_reads: {}, num_writes: {}, total: {}", + proof.num_instances[0], + proof.num_instances[1], + proof.num_instances[0] + proof.num_instances[1], + ); + vec![ + SelectorContext { + offset: 0, + num_instances: proof.num_instances[0], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: proof.num_instances[0], + num_instances: proof.num_instances[1], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: 0, + num_instances: proof.num_instances[0] + proof.num_instances[1], + num_vars: 
num_var_with_rotation, + }, + ] + }; let GKRClaims(opening_evaluations) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), @@ -416,7 +498,7 @@ impl> ZKVMVerifier pi, challenges, transcript, - num_instances, + &selector_ctxs, )?; Ok(opening_evaluations[0].point.clone()) } @@ -437,43 +519,44 @@ impl> ZKVMVerifier let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = circuit_vk.get_cs(); - debug_assert!( - cs.r_table_expressions - .iter() - .zip_eq(cs.w_table_expressions.iter()) - .all(|(r, w)| r.table_spec.len == w.table_spec.len) - ); - - let log2_num_instances = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; - - // in table proof, we always skip same point sumcheck for now - // as tower sumcheck batch product argument/logup in same length - let is_skip_same_point_sumcheck = true; + let with_rw = !cs.r_table_expressions.is_empty() && !cs.w_table_expressions.is_empty(); + if with_rw { + debug_assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_spec.len == w.table_spec.len) + ); + } + let num_instances = proof.num_instances.iter().sum(); + let log2_num_instances = next_pow2_instance_padding(num_instances).ilog2() as usize; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; // NOTE: for all structural witness within same constrain system should got same hints num variable via `log2_num_instances` - let expected_rounds = cs - // only iterate r set, as read/write set round should match - .r_table_expressions - .iter() - .flat_map(|r| { + let expected_rounds = interleave(&cs.r_table_expressions, &cs.w_table_expressions) + .map(|set_table_expr| { // iterate through structural witins and collect max round. - let num_vars = r.table_spec.len.map(ceil_log2).unwrap_or_else(|| { - r.table_spec - .structural_witins - .iter() - .map(|StructuralWitIn { witin_type, .. 
}| { - let hint_num_vars = log2_num_instances; - assert!((1 << hint_num_vars) <= witin_type.max_len()); - hint_num_vars - }) - .max() - .unwrap() - }); + let num_vars = set_table_expr + .table_spec + .len + .map(ceil_log2) + .unwrap_or_else(|| { + set_table_expr + .table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { witin_type, .. }| { + let hint_num_vars = log2_num_instances; + assert!((1 << hint_num_vars) <= witin_type.max_len()); + hint_num_vars + }) + .max() + .unwrap() + }); assert_eq!(num_vars, log2_num_instances); - [num_vars, num_vars] // format: [read_round, write_round] + num_vars }) .chain(cs.lk_table_expressions.iter().map(|l| { // iterate through structural witins and collect max round. @@ -494,14 +577,10 @@ impl> ZKVMVerifier })) .collect_vec(); - let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( - proof - .r_out_evals - .iter() - .zip(proof.w_out_evals.iter()) - .flat_map(|(r_evals, w_evals)| [r_evals.to_vec(), w_evals.to_vec()]) + interleave(&proof.r_out_evals, &proof.w_out_evals) + .map(|eval| eval.to_vec()) .collect_vec(), proof .lk_out_evals @@ -530,13 +609,19 @@ impl> ZKVMVerifier cs.r_table_expressions.len() + cs.w_table_expressions.len(), "[prod_record] mismatch length" ); - let num_rw_records = cs.r_table_expressions.len() + cs.w_table_expressions.len(); - // evaluate the evaluation of structural mles at input_opening_point by verifier - let structural_evals = cs - .r_table_expressions - .iter() - .map(|r| &r.table_spec) + // TODO differentiate `ram_bus` via cs + let is_shard_ram_bus_circuit = false; + + let input_opening_point = if !is_shard_ram_bus_circuit { + // evaluate the evaluation of structural mles at input_opening_point by verifier + let structural_evals = if with_rw { + // only iterate r set, as read/write set round should match + Either::Left(cs.r_table_expressions.iter()) + } else { + 
Either::Right(cs.r_table_expressions.iter().chain(&cs.w_table_expressions)) + } + .map(|set_table_expr| &set_table_expr.table_spec) .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) .flat_map(|table_spec| { table_spec @@ -571,32 +656,30 @@ impl> ZKVMVerifier }) .collect_vec(); - // verify records (degree = 1) statement, thus no sumcheck - let expected_evals = interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w - ) - .map(|rw| &rw.expr) - .chain( - cs.lk_table_expressions - .iter() - .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q - ) - .map(|expr| { - eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - &structural_evals, - pi, - challenges, - expr, + // verify records (degree = 1) statement, thus no sumcheck + let expected_evals = interleave( + &cs.r_table_expressions, // r + &cs.w_table_expressions, // w ) - .right() - .unwrap() - }) - .collect_vec(); - - let input_opening_point = if is_skip_same_point_sumcheck { + .map(|rw| &rw.expr) + .chain( + cs.lk_table_expressions + .iter() + .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q + ) + .map(|expr| { + eval_by_expr_with_instance( + &proof.fixed_in_evals, + &proof.wits_in_evals, + &structural_evals, + pi, + challenges, + expr, + ) + .right() + .unwrap() + }) + .collect_vec(); for (expected_eval, eval) in expected_evals.iter().zip( prod_point_and_eval .into_iter() @@ -619,84 +702,7 @@ impl> ZKVMVerifier } rt_tower } else { - assert!(proof.main_sumcheck_proofs.is_some()); - - // verify opening same point layer sumcheck - let alpha_pow = get_challenge_pows( - cs.r_table_expressions.len() - + cs.w_table_expressions.len() - + cs.lk_table_expressions.len() * 2, // 2 for lk numerator and denominator - transcript, - ); - - // \sum_i alpha_{i} * (out_r_eval{i}) - // + \sum_i alpha_{i} * (out_w_eval{i}) - // + \sum_i alpha_{i} * (out_lk_n{i}) - // + \sum_i alpha_{i} * (out_lk_d{i}) - let claim_sum = prod_point_and_eval - .iter() - 
.zip(alpha_pow.iter()) - .map(|(point_and_eval, alpha)| *alpha * point_and_eval.eval) - .sum::() - + interleave(&logup_p_point_and_eval, &logup_q_point_and_eval) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|(point_n_eval, alpha)| *alpha * point_n_eval.eval) - .sum::(); - let sel_subclaim = IOPVerifierState::verify( - claim_sum, - &IOPProof { - proofs: proof.main_sumcheck_proofs.clone().unwrap(), - }, - &VPAuxInfo { - max_degree: SEL_DEGREE, - max_num_variables: expected_max_rounds, - phantom: PhantomData, - }, - transcript, - ); - let (input_opening_point, expected_evaluation) = ( - sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), - sel_subclaim.expected_evaluation, - ); - - let computed_evals = [ - // r, w - prod_point_and_eval - .into_iter() - .zip_eq(&expected_evals[0..num_rw_records]) - .zip(alpha_pow.iter()) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - interleave(logup_p_point_and_eval, logup_q_point_and_eval) - .zip_eq(&expected_evals[num_rw_records..]) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - ] - .iter() - .copied() - .sum::(); - - if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError( - "sel evaluation verify failed".into(), - )); - } - input_opening_point + unimplemented!("shard ram bus circuit go here"); }; // assume public io is tiny vector, so we evaluate it directly without PCS @@ -749,9 +755,9 @@ impl TowerVerify { let log2_num_fanin = ceil_log2(num_fanin); // sanity check - assert!(num_prod_spec == tower_proofs.prod_spec_size()); + assert_eq!(num_prod_spec, 
tower_proofs.prod_spec_size()); assert!(prod_out_evals.iter().all(|evals| evals.len() == num_fanin)); - assert!(num_logup_spec == tower_proofs.logup_spec_size()); + assert_eq!(num_logup_spec, tower_proofs.logup_spec_size()); assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); @@ -792,6 +798,8 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); + + // initial claim = \sum_j alpha^j * out_j[rt] let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) .sum::() @@ -804,7 +812,7 @@ impl TowerVerify { let max_num_variables = num_variables.iter().max().unwrap(); - let (next_rt, _) = (0..(max_num_variables-1)).try_fold( + let (next_rt, _) = (0..(max_num_variables - 1)).try_fold( ( PointAndEval { point: initial_rt, @@ -829,33 +837,40 @@ impl TowerVerify { // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); + let eq = eq_eval(out_rt, &rt); let expected_evaluation: E = (0..num_prod_spec) .zip(alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - eq_eval(out_rt, &rt) - * *alpha - * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { - E::ZERO - } + // prod'[b] = prod[0,b] * prod[1,b] + // prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b] + eq * *alpha + * if round < *max_round - 1 { tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product() } else { + E::ZERO + } }) .sum::() + (0..num_logup_spec) - .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(num_variables[num_prod_spec..].iter()) - .map(|((spec_index, alpha), max_round)| { - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - eq_eval(out_rt, &rt) * if round < *max_round-1 { - let evals = &tower_proofs.logup_specs_eval[spec_index][round]; - let (p1, p2, q1, q2) = - (evals[0], evals[1], evals[2], 
evals[3]); - *alpha_numerator * (p1 * q2 + p2 * q1) - + *alpha_denominator * (q1 * q2) - } else { - E::ZERO - } - }) - .sum::(); + .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) + .zip_eq(num_variables[num_prod_spec..].iter()) + .map(|((spec_index, alpha), max_round)| { + // logup_q'[b] = logup_q[0,b] * logup_q[1,b] + // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] + // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) + // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] + let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); + eq * if round < *max_round - 1 { + let evals = &tower_proofs.logup_specs_eval[spec_index][round]; + let (p1, p2, q1, q2) = + (evals[0], evals[1], evals[2], evals[3]); + *alpha_numerator * (p1 * q2 + p2 * q1) + + *alpha_denominator * (q1 * q2) + } else { + E::ZERO + } + }) + .sum::(); + if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } @@ -863,7 +878,7 @@ impl TowerVerify { // derive single eval // rt' = r_merge || rt // r_merge.len() == ceil_log2(num_product_fanin) - let r_merge =transcript.sample_and_append_vec(b"merge", log2_num_fanin); + let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin); let coeffs = build_eq_x_r_vec_sequential(&r_merge); assert_eq!(coeffs.len(), num_fanin); let rt_prime = [rt, r_merge].concat(); @@ -878,17 +893,18 @@ impl TowerVerify { .zip(next_alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + // prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt] + if round < max_round - 1 { // merged evaluation let evals = izip!( tower_proofs.prod_specs_eval[spec_index][round].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation 
prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha * evals } else { E::ZERO @@ -902,28 +918,28 @@ impl TowerVerify { .zip_eq(next_alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + if round < max_round - 1 { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); // merged evaluation let p_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][0..2].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); let q_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][2..4].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals } else { E::ZERO @@ -933,8 +949,10 @@ impl TowerVerify { } }) .sum::(); + // sum evaluation from different specs let next_eval = next_prod_spec_evals + next_logup_spec_evals; + Ok((PointAndEval { point: rt_prime, eval: next_eval, @@ -950,3 +968,134 @@ impl TowerVerify { )) } } + +pub struct EccVerifier; + +impl EccVerifier { + pub fn verify_ecc_proof( + proof: &EccQuarkProof, + transcript: &mut impl Transcript, + ) -> Result<(), ZKVMError> { + let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars); + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); + + let 
sumcheck_claim = IOPVerifierState::verify( + E::ZERO, + &proof.zerocheck_proof, + &VPAuxInfo { + max_degree: 3, + max_num_variables: num_vars, + phantom: PhantomData, + }, + transcript, + ); + + let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE].into(); + let x0: SepticExtension = + proof.evals[2..][SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y0: SepticExtension = + proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x1: SepticExtension = + proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y1: SepticExtension = + proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x3: SepticExtension = + proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y3: SepticExtension = + proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + + let rt = sumcheck_claim + .point + .iter() + .map(|c| c.elements) + .collect_vec(); + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // zerocheck: 0 = (x[1,b] - x[b,0]) + // zerocheck: 0 = (y[1,b] - y[b,0]) + // + // note that they are not septic extension field elements, + // we just want to reuse the multiply/add/sub formulas + let v1: SepticExtension = s0.clone() * (&x0 - &x1) - (&y0 - &y1); + let v2: SepticExtension = s0.square() - &x0 - &x1 - &x3; + let v3: SepticExtension = s0 * (&x0 - &x3) - (&y0 + &y3); + + let v4: SepticExtension = &x3 - &x0; + let v5: SepticExtension = &y3 - &y0; + + let [v1, v2, v3, v4, v5] = [v1, v2, v3, v4, v5].map(|v| { + v.0.into_iter() + .zip(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(c, alpha)| c * *alpha) + .collect_vec() + }); + + let sel_add_expr = SelectorType::::QuarkBinaryTreeLessThan(Expression::StructuralWitIn( + 0, + // 
this value doesn't matter, as we only need structural id + StackedConstantSequence { max_value: 0 }, + )); + let mut sel_evals = vec![E::ZERO]; + sel_add_expr.evaluate( + &mut sel_evals, + &out_rt, + &rt, + &SelectorContext { + offset: 0, + num_instances: proof.num_instances, + num_vars, + }, + 0, + ); + let expected_sel_add = sel_evals[0]; + + if proof.evals[0] != expected_sel_add { + return Err(ZKVMError::VerifyError( + (format!( + "sel_add evaluation mismatch, expected {}, got {}", + expected_sel_add, proof.evals[0] + )) + .into(), + )); + } + + // derive `sel_bypass = eq - sel_add - sel_last_onehot` + let expected_sel_bypass = eq_eval(&out_rt, &rt) + - expected_sel_add + - (out_rt.iter().copied().product::() * rt.iter().copied().product::()); + + if proof.evals[1] != expected_sel_bypass { + return Err(ZKVMError::VerifyError( + (format!( + "sel_bypass evaluation mismatch, expected {}, got {}", + expected_sel_bypass, proof.evals[1] + )) + .into(), + )); + } + + let add_evaluations = vec![v1, v2, v3].into_iter().flatten().sum::(); + let bypass_evaluations = vec![v4, v5].into_iter().flatten().sum::(); + if sumcheck_claim.expected_evaluation + != add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + { + return Err(ZKVMError::VerifyError( + (format!( + "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", + sumcheck_claim.expected_evaluation, + add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + )) + .into(), + )); + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index cd76d6fcd..b0e971ffe 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,24 +1,47 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, error::ZKVMError, - instructions::Instruction, + instructions::{ + Instruction, + global::{GlobalChip, GlobalChipInput, GlobalPoint, GlobalRecord}, + }, + scheme::septic_curve::SepticPoint, 
state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, sync::Arc, }; -use sumcheck::structs::IOPProverMessage; +use sumcheck::structs::{IOPProof, IOPProverMessage}; use witness::RowMajorMatrix; +/// proof that the sum of N=2^n EC points is equal to `sum` +/// in one layer instead of GKR layered circuit approach +/// note that this one layer IOP borrowed ideas from +/// [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct EccQuarkProof { + pub zerocheck_proof: IOPProof, + pub num_instances: usize, + pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] + pub rt: Point, + pub sum: SepticPoint, +} + #[derive(Clone, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize", @@ -108,10 +131,19 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.num_witin.into() } + pub fn num_structural_witin(&self) -> usize { + self.zkvm_v1_css.num_structural_witin.into() + } + pub fn num_fixed(&self) -> usize { self.zkvm_v1_css.num_fixed } + /// static circuit means there is only fixed column + pub fn is_static_circuit(&self) -> bool { + (self.num_witin() + self.num_structural_witin()) == 0 && self.num_fixed() > 0 + } + pub fn num_reads(&self) -> usize { self.zkvm_v1_css.r_expressions.len() + self.zkvm_v1_css.r_table_expressions.len() } @@ -120,14 +152,17 @@ impl 
ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } + pub fn has_ecc_ops(&self) -> bool { + !self.zkvm_v1_css.ec_final_sum.is_empty() + } + pub fn instance_name_map(&self) -> &HashMap { &self.zkvm_v1_css.instance_name_map } pub fn is_opcode_circuit(&self) -> bool { - self.zkvm_v1_css.lk_table_expressions.is_empty() - && self.zkvm_v1_css.r_table_expressions.is_empty() - && self.zkvm_v1_css.w_table_expressions.is_empty() + // TODO: is global chip opcode circuit?? + self.gkr_circuit.is_some() || self.has_ecc_ops() } /// return number of lookup operation @@ -209,18 +244,13 @@ impl ZKVMConstraintSystem { pub fn register_table_circuit>(&mut self) -> TC::TableConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = TC::construct_circuit(&mut circuit_builder, &self.params).unwrap(); - assert!( - self.circuit_css - .insert( - TC::name(), - ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit: None - } - ) - .is_none() - ); + let (config, gkr_iop_circuit) = + TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); + let cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: gkr_iop_circuit, + }; + assert!(self.circuit_css.insert(TC::name(), cs).is_none()); config } @@ -292,6 +322,8 @@ pub struct ZKVMWitnesses { witnesses_tables: BTreeMap>, lk_mlts: BTreeMap>, combined_lk_mlt: Option>>, + // in ram bus chip, num_instances length would be > 1 + pub num_instances: BTreeMap>, } impl ZKVMWitnesses { @@ -310,6 +342,7 @@ impl ZKVMWitnesses { pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -318,10 +351,16 @@ impl ZKVMWitnesses { let cs = cs.get_cs(&OC::name()).unwrap(); let (witness, logup_multiplicity) = OC::assign_instances( config, + shard_ctx, 
cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, records, )?; + assert!( + self.num_instances + .insert(OC::name(), vec![witness[0].num_instances()]) + .is_none() + ); assert!(self.witnesses_opcodes.insert(OC::name(), witness).is_none()); assert!(!self.witnesses_tables.contains_key(&OC::name())); assert!( @@ -375,12 +414,99 @@ impl ZKVMWitnesses { self.combined_lk_mlt.as_ref().unwrap(), input, )?; + let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances()); + assert!( + self.num_instances + .insert(TC::name(), vec![num_instances]) + .is_none() + ); assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) } + pub fn assign_global_chip_circuit( + &mut self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + config: & as TableCircuit>::TableConfig, + ) -> Result<(), ZKVMError> { + let perm = ::get_default_perm(); + let global_input = shard_ctx + .read_records() + .par_iter() + .flat_map_iter(|records| { + records.iter().map(|(vma, record)| { + let global_read: GlobalRecord = (vma, record, false).into(); + let ec_point: GlobalPoint = global_read.to_ec_point(&perm); + GlobalChipInput { + record: global_read, + ec_point, + } + }) + }) + .chain( + shard_ctx + .write_records() + .par_iter() + .flat_map_iter(|records| { + records.iter().map(|(vma, record)| { + let global_write: GlobalRecord = (vma, record, true).into(); + let ec_point: GlobalPoint = global_write.to_ec_point(&perm); + GlobalChipInput { + record: global_write, + ec_point, + } + }) + }), + ) + .collect::>(); + assert!(self.combined_lk_mlt.is_some()); + let cs = cs.get_cs(&GlobalChip::::name()).unwrap(); + let witness = GlobalChip::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + &(global_input, shard_ctx.cur_shard()), + )?; + // set num_read, num_write as 
separate instance + assert!( + self.num_instances + .insert( + GlobalChip::::name(), + vec![ + // global write -> local read + shard_ctx + .write_records() + .iter() + .map(|records| records.len()) + .sum(), + // global read -> local write + shard_ctx + .read_records() + .iter() + .map(|records| records.len()) + .sum(), + ] + ) + .is_none() + ); + assert!( + self.witnesses_tables + .insert(GlobalChip::::name(), witness) + .is_none() + ); + assert!( + !self + .witnesses_opcodes + .contains_key(&GlobalChip::::name()) + ); + + Ok(()) + } + /// Iterate opcode/table circuits, sorted by alphabetical order. pub fn into_iter_sorted( self, @@ -404,6 +530,7 @@ pub struct ZKVMProvingKey> pub circuit_pks: BTreeMap>, pub fixed_commit_wd: Option>::CommitmentWithWitness>>, pub fixed_commit: Option<>::Commitment>, + pub circuit_index_fixed_num_instances: BTreeMap, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -418,6 +545,7 @@ impl> ZKVMProvingKey { params: &ProgramParams, ) -> Result; + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + Ok((config, None)) + } + fn generate_fixed_traces( config: &Self::TableConfig, num_fixed: usize, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 41890200e..833663e74 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -182,6 +182,7 @@ impl TableCircuit for ProgramTableCircuit { cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(params.program_size.is_power_of_two()); #[cfg(not(feature = "u16limb_circuit"))] let record = InsnRecord([ cb.create_fixed(|| "pc"), @@ -214,7 +215,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", SetTableSpec { - len: Some(params.program_size.next_power_of_two()), + len: Some(params.program_size), structural_witins: vec![], 
}, ROMType::Instruction, diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index e34ce1dcc..6075b0440 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -8,6 +8,12 @@ use crate::{ mod ram_circuit; mod ram_impl; +use crate::tables::ram::{ + ram_circuit::{LocalFinalRamCircuit, RamBusCircuit}, + ram_impl::{ + DynVolatileRamTableConfig, DynVolatileRamTableInitConfig, NonVolatileInitTableConfig, + }, +}; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; #[derive(Clone)] @@ -32,7 +38,8 @@ impl DynVolatileRamTable for HeapTable { } } -pub type HeapCircuit = DynVolatileRamCircuit; +pub type HeapInitCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct StackTable; @@ -66,7 +73,8 @@ impl DynVolatileRamTable for StackTable { } } -pub type StackCircuit = DynVolatileRamCircuit; +pub type StackInitCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct HintsTable; @@ -88,7 +96,8 @@ impl DynVolatileRamTable for HintsTable { "HintsTable" } } -pub type HintsCircuit = DynVolatileRamCircuit; +pub type HintsCircuit = + DynVolatileRamCircuit>; /// RegTable, fix size without offset #[derive(Clone)] @@ -108,7 +117,8 @@ impl NonVolatileTable for RegTable { } } -pub type RegTableCircuit = NonVolatileRamCircuit; +pub type RegTableInitCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct StaticMemTable; @@ -127,7 +137,8 @@ impl NonVolatileTable for StaticMemTable { } } -pub type StaticMemCircuit = NonVolatileRamCircuit; +pub type StaticMemInitCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct PubIOTable; @@ -147,3 +158,5 @@ impl NonVolatileTable for PubIOTable { } pub type PubIOCircuit = PubIORamCircuit; +pub type LocalFinalCircuit<'a, E> = LocalFinalRamCircuit<'a, UINT_LIMBS, E>; +pub type RBCircuit<'a, E> = RamBusCircuit<'a, UINT_LIMBS, E>; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 0a8b6bf97..344a8d891 
100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,17 +1,26 @@ use std::{collections::HashMap, marker::PhantomData}; -use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use witness::{InstancePaddingStrategy, RowMajorMatrix}; - +use super::ram_impl::{ + LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, +}; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, structs::{ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, }; - -use super::ram_impl::{DynVolatileRamTableConfig, NonVolatileTableConfig, PubIOTableConfig}; +use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::{ + chip::Chip, + error::CircuitBuilderError, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::Itertools; +use multilinear_extensions::{StructuralWitInType, ToExpr}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] pub struct MemInitRecord { @@ -21,6 +30,7 @@ pub struct MemInitRecord { #[derive(Clone, Debug)] pub struct MemFinalRecord { + pub ram_type: RAMType, pub addr: Addr, pub cycle: Cycle, pub value: Word, @@ -60,12 +70,15 @@ pub trait NonVolatileTable { /// - with fixed initial content, /// - with witnessed final content that the program wrote, if WRITABLE, /// - or final content equal to initial content, if not WRITABLE. 
-pub struct NonVolatileRamCircuit(PhantomData<(E, R)>); +pub struct NonVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for NonVolatileRamCircuit +impl< + E: ExtensionField, + NVRAM: NonVolatileTable + Send + Sync + Clone, + C: NonVolatileTableConfigTrait, +> TableCircuit for NonVolatileRamCircuit { - type TableConfig = NonVolatileTableConfig; + type TableConfig = C::Config; type FixedInput = [MemInitRecord]; type WitnessInput = [MemFinalRecord]; @@ -77,10 +90,7 @@ impl TableCirc cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) } fn generate_fixed_traces( @@ -89,7 +99,7 @@ impl TableCirc init_v: &Self::FixedInput, ) -> RowMajorMatrix { // assume returned table is well-formed include padding - config.gen_init_state(num_fixed, init_v) + C::gen_init_state(config, num_fixed, init_v) } fn assign_instances( @@ -100,7 +110,12 @@ impl TableCirc final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + Ok(C::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, + )?) 
} } @@ -189,6 +204,20 @@ pub trait DynVolatileRamTable { } } +pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { + type Config: Sized + Send + Sync; + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; +} + /// DynVolatileRamCircuit initializes and finalizes memory /// - at witnessed addresses, in a contiguous range chosen by the prover, /// - with zeros as initial content if ZERO_INIT, @@ -197,12 +226,15 @@ pub trait DynVolatileRamTable { /// If not ZERO_INIT: /// - The initial content is an unconstrained prover hint. /// - The final content is equal to this initial content. -pub struct DynVolatileRamCircuit(PhantomData<(E, R)>); +pub struct DynVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for DynVolatileRamCircuit +impl< + E: ExtensionField, + DVRAM: DynVolatileRamTable + Send + Sync + Clone, + C: DynVolatileRamTableConfigTrait, +> TableCircuit for DynVolatileRamCircuit { - type TableConfig = DynVolatileRamTableConfig; + type TableConfig = C::Config; type FixedInput = (); type WitnessInput = [MemFinalRecord]; @@ -210,6 +242,57 @@ impl TableC format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) } + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) 
+ } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok( + >::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, + )?, + ) + } +} + +/// This circuit is generalized version to handle all mmio records +pub struct LocalFinalRamCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); + +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for LocalFinalRamCircuit<'a, V_LIMBS, E> +{ + type TableConfig = LocalFinalRAMTableConfig; + type FixedInput = (); + type WitnessInput = ( + &'a ShardContext<'a>, + &'a [(InstancePaddingStrategy, &'a [MemFinalRecord])], + ); + + fn name() -> String { + "LocalRAMTableFinal".to_string() + } + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -220,6 +303,49 @@ impl TableC )?) 
} + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + let r_table_len = cb.cs.r_table_expressions.len(); + + let selector = cb.create_structural_witin( + || "selector", + StructuralWitInType::EqualDistanceSequence { + // TODO determin proper size of max length + max_len: u32::MAX as usize, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_type = SelectorType::Prefix(selector.expr()); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_table_len).collect_vec(), + // w_record + vec![], + // lk_record + vec![], + // zero_record + vec![], + ], + Chip::new_from_cb(cb, 0), + ); + + // register selector to legacy constrain system + cb.cs.r_selector = Some(selector_type.clone()); + + let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + fn generate_fixed_traces( _config: &Self::TableConfig, _num_fixed: usize, @@ -233,9 +359,64 @@ impl TableC num_witin: usize, num_structural_witin: usize, _multiplicity: &[HashMap], - final_v: &Self::WitnessInput, + (shard_ctx, final_mem): &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok(Self::TableConfig::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + final_mem, + )?) 
+ } +} + +/// This circuit is generalized version to handle all mmio records +pub struct RamBusCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); + +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for RamBusCircuit<'a, V_LIMBS, E> +{ + type TableConfig = RAMBusConfig; + type FixedInput = (); + type WitnessInput = ShardContext<'a>; + + fn name() -> String { + "RamBusCircuit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb, params), + )?) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + shard_ctx: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + Ok(Self::TableConfig::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + )?) 
} } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index f92dc37cc..554c71235 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,62 +1,82 @@ -use std::{marker::PhantomData, sync::Arc}; - use ceno_emul::{Addr, Cycle, WORD_SIZE}; +use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; +use std::marker::PhantomData; +use witness::{ + InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, +}; +use super::{ + MemInitRecord, + ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, +}; use crate::{ chip_handler::general::PublicIOQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, + e2e::ShardContext, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, + tables::ram::ram_circuit::DynVolatileRamTableConfigTrait, }; use ff_ext::FieldInto; +use gkr_iop::RAMType; use multilinear_extensions::{ Expression, Fixed, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, }; +use p3::field::FieldAlgebra; -use super::{ - MemInitRecord, - ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, -}; +pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { + type Config: Sized + Send + Sync; + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + + fn gen_init_state( + config: &Self::Config, + num_fixed: usize, + init_mem: &[MemInitRecord], + ) -> RowMajorMatrix; + + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], 
CircuitBuilderError>; +} /// define a non-volatile memory with init value #[derive(Clone, Debug)] -pub struct NonVolatileTableConfig { +pub struct NonVolatileInitTableConfig { init_v: Vec, addr: Fixed, - final_v: Option>, - final_cycle: WitIn, - phantom: PhantomData, params: ProgramParams, } -impl NonVolatileTableConfig { - pub fn construct_circuit( +impl NonVolatileTableConfigTrait + for NonVolatileInitTableConfig +{ + type Config = NonVolatileInitTableConfig; + + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(NVRAM::WRITABLE); let init_v = (0..NVRAM::V_LIMBS) .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) .collect_vec(); let addr = cb.create_fixed(|| "addr"); - let final_cycle = cb.create_witin(|| "final_cycle"); - let final_v = if NVRAM::WRITABLE { - Some( - (0..NVRAM::V_LIMBS) - .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::>(), - ) - } else { - None - }; - let init_table = [ vec![(NVRAM::RAM_TYPE as usize).into()], vec![Expression::Fixed(addr)], @@ -65,18 +85,6 @@ impl NonVolatileTableConfig NonVolatileTableConfig( - &self, + fn gen_init_state( + config: &Self::Config, num_fixed: usize, init_mem: &[MemInitRecord], ) -> RowMajorMatrix { assert!( - NVRAM::len(&self.params).is_power_of_two(), + NVRAM::len(&config.params).is_power_of_two(), "{} len {} must be a power of 2", NVRAM::name(), - NVRAM::len(&self.params) + NVRAM::len(&config.params) ); let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), + NVRAM::len(&config.params), num_fixed, InstancePaddingStrategy::Default, ); @@ -129,56 +126,31 @@ impl NonVolatileTableConfig> (l * LIMB_BITS)) & LIMB_MASK; set_fixed_val!(row, limb, (val as u64).into_f()); }); } - set_fixed_val!(row, self.addr, (rec.addr as u64).into_f()); + set_fixed_val!(row, config.addr, (rec.addr as u64).into_f()); }); init_table } /// TODO consider taking RowMajorMatrix as argument to save allocations. 
- pub fn assign_instances( - &self, - num_witin: usize, + fn assign_instances( + _config: &Self::Config, + _num_witin: usize, num_structural_witin: usize, - final_mem: &[MemFinalRecord], + _final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 0); - let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), - num_witin, - InstancePaddingStrategy::Default, - ); - - final_table - .par_rows_mut() - .zip_eq(final_mem) - .for_each(|(row, rec)| { - if let Some(final_v) = &self.final_v { - if final_v.len() == 1 { - // Assign value directly. - set_val!(row, final_v[0], rec.value as u64); - } else { - // Assign value limbs. - final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - } - set_val!(row, self.final_cycle, rec.cycle); - }); - - Ok([final_table, RowMajorMatrix::empty()]) + Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]) } } @@ -311,8 +283,11 @@ pub struct DynVolatileRamTableConfig DynVolatileRamTableConfig { - pub fn construct_circuit( +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableConfig +{ + type Config = DynVolatileRamTableConfig; + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -385,59 +360,664 @@ impl DynVolatileRamTableConfig } /// TODO consider taking RowMajorMatrix as argument to save allocations. 
- pub fn assign_instances( - &self, + fn assign_instances( + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(final_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); - - let params = self.params.clone(); - let addr_id = self.addr.id as u64; - let addr_padding_fn = move |row: u64, col: u64| { - assert_eq!(col, addr_id); - DVRAM::addr(¶ms, row as usize) as u64 - }; + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } - let mut witness = - RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); + + let mut witness = RowMajorMatrix::::new( + num_instances_padded, + num_witin, + InstancePaddingStrategy::Default, + ); let mut structural_witness = RowMajorMatrix::::new( - final_mem.len(), + num_instances_padded, num_structural_witin, - InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), + InstancePaddingStrategy::Default, ); witness .par_rows_mut() - .zip(structural_witness.par_rows_mut()) - .zip(final_mem) + .zip_eq(structural_witness.par_rows_mut()) .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - assert_eq!( - rec.addr, - DVRAM::addr(&self.params, i), - "rec.addr {:x} != expected {:x}", - rec.addr, - DVRAM::addr(&self.params, i), - ); + .for_each(|(i, (row, structural_row))| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } - if self.final_v.len() == 1 { - // Assign value directly. 
- set_val!(row, self.final_v[0], rec.value as u64); - } else { - // Assign value limbs. - self.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); + if let Some(rec) = final_mem.get(i) { + if config.final_v.len() == 1 { + // Assign value directly. + set_val!(row, config.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + config.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, config.final_cycle, rec.cycle); } - set_val!(row, self.final_cycle, rec.cycle); + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); + }); + + Ok([witness, structural_witness]) + } +} + +/// volatile with all init value as 0 +/// dynamic address as witin, relied on augment of knowledge to prove address form +#[derive(Clone, Debug)] +pub struct DynVolatileRamTableInitConfig { + addr: StructuralWitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableInitConfig +{ + type Config = DynVolatileRamTableInitConfig; + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let max_len = DVRAM::max_len(params); + let addr = cb.create_structural_witin( + || "addr", + StructuralWitInType::EqualDistanceSequence { + max_len, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); + + assert!(DVRAM::ZERO_INIT); + + let init_expr = vec![Expression::ZERO; DVRAM::V_LIMBS]; + + let init_table = [ + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr.expr()], + init_expr, + vec![Expression::ZERO], // Initial cycle. 
+ ] + .concat(); + + cb.w_table_record( + || "init_table", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![addr], + }, + init_table, + )?; + + Ok(Self { + addr, + phantom: PhantomData, + params: params.clone(), + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + fn assign_instances( + config: &Self::Config, + _num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); - set_val!(structural_row, self.addr, rec.addr as u64); + let mut structural_witness = RowMajorMatrix::::new( + num_instances_padded, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + structural_witness + .par_rows_mut() + .enumerate() + .for_each(|(i, structural_row)| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); }); + Ok([RowMajorMatrix::empty(), structural_witness]) + } +} + +/// This table is generalized version to handle all mmio records +#[derive(Clone, Debug)] +pub struct LocalFinalRAMTableConfig { + addr_subset: WitIn, + ram_type: WitIn, + + final_v: Vec, + final_cycle: WitIn, +} + +impl LocalFinalRAMTableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let addr_subset = cb.create_witin(|| "addr_subset"); + let ram_type = cb.create_witin(|| "ram_type"); + + let final_v = (0..V_LIMBS) + .map(|i| 
cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::>(); + let final_cycle = cb.create_witin(|| "final_cycle"); + + let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); + let raw_final_table = [ + // a v t + vec![ram_type.expr()], + vec![addr_subset.expr()], + final_expr, + vec![final_cycle.expr()], + ] + .concat(); + let rlc_record = cb.rlc_chip_record(raw_final_table.clone()); + cb.r_table_rlc_record( + || "final_table", + // XXX we mixed all ram type here to save column allocation + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![], + }, + raw_final_table, + rlc_record, + )?; + + Ok(Self { + addr_subset, + ram_type, + final_v, + final_cycle, + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[(InstancePaddingStrategy, &[MemFinalRecord])], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + let selector_witin = WitIn { id: 0 }; + + let is_current_shard_mem_record = |record: &&MemFinalRecord| -> bool { + (shard_ctx.is_first_shard() && record.cycle == 0) + || shard_ctx.is_current_shard_cycle(record.cycle) + }; + + // collect each raw mem belong to this shard, BEFORE padding length + let current_shard_mems_len: Vec = final_mem + .par_iter() + .map(|(_, mem)| mem.par_iter().filter(is_current_shard_mem_record).count()) + .collect(); + + // deal with non-pow2 padding for first shard + // format Vec<(pad_len, pad_start_index)> + let padding_info = if shard_ctx.is_first_shard() { + final_mem + .iter() + .map(|(_, mem)| { + assert!(!mem.is_empty()); + ( + next_pow2_instance_padding(mem.len()) - mem.len(), + mem.len(), + mem[0].ram_type, + ) + }) + .collect_vec() + } else { + vec![(0, 0, RAMType::Undefined); final_mem.len()] + }; + + // 
calculate mem length + let mem_lens = current_shard_mems_len + .iter() + .zip_eq(&padding_info) + .map(|(raw_len, (pad_len, _, _))| raw_len + pad_len) + .collect_vec(); + let total_records = mem_lens.iter().sum(); + + let mut witness = + RowMajorMatrix::::new(total_records, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + total_records, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + let mut witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut structural_witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut witness_value_rest = witness.values.as_mut_slice(); + let mut structural_witness_value_rest = structural_witness.values.as_mut_slice(); + + for mem_len in mem_lens { + let witness_length = mem_len * num_witin; + let structural_witness_length = mem_len * num_structural_witin; + assert!( + witness_length <= witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_length <= structural_witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_left, witness_r) = witness_value_rest.split_at_mut(witness_length); + let (structural_witness_left, structural_witness_r) = + structural_witness_value_rest.split_at_mut(structural_witness_length); + witness_mut_slices.push(witness_left); + structural_witness_mut_slices.push(structural_witness_left); + witness_value_rest = witness_r; + structural_witness_value_rest = structural_witness_r; + } + + witness_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_mut_slices.par_iter_mut()) + .zip_eq(final_mem.par_iter()) + .zip_eq(padding_info.par_iter()) + .for_each( + |( + ((witness, structural_witness), (padding_strategy, final_mem)), + (pad_size, pad_start_index, ram_type), + )| { + let mem_record_count = witness + .chunks_mut(num_witin) + .zip_eq(structural_witness.chunks_mut(num_structural_witin)) + 
.zip(final_mem.iter().filter(is_current_shard_mem_record)) + .map(|((row, structural_row), rec)| { + if self.final_v.len() == 1 { + // Assign value directly. + set_val!(row, self.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.final_cycle, rec.cycle); + + set_val!(row, self.ram_type, rec.ram_type as u64); + set_val!(row, self.addr_subset, rec.addr as u64); + set_val!(structural_row, selector_witin, 1u64); + }) + .count(); + + if *pad_size > 0 && shard_ctx.is_first_shard() { + match padding_strategy { + InstancePaddingStrategy::Custom(pad_func) => { + witness[mem_record_count * num_witin..] + .chunks_mut(num_witin) + .zip_eq( + structural_witness + [mem_record_count * num_structural_witin..] + .chunks_mut(num_structural_witin), + ) + .zip_eq( + std::iter::successors(Some(*pad_start_index), |n| { + Some(*n + 1) + }) + .take(*pad_size), + ) + .for_each(|((row, structural_row), pad_index)| { + set_val!( + row, + self.addr_subset, + pad_func(pad_index as u64, self.addr_subset.id as u64) + ); + set_val!(row, self.ram_type, *ram_type as u64); + set_val!(structural_row, selector_witin, 1u64); + }); + } + _ => unimplemented!(), + } + } + }, + ); + + Ok([witness, structural_witness]) + } +} + +/// The general config to handle ram bus across all records +#[derive(Clone, Debug)] +pub struct RAMBusConfig { + addr_subset: WitIn, + + sel_read: StructuralWitIn, + sel_write: StructuralWitIn, + local_write_v: Vec, + local_read_v: Vec, + local_read_cycle: WitIn, +} + +impl RAMBusConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let ram_type = cb.create_witin(|| "ram_type"); + let one = Expression::Constant(Either::Left(E::BaseField::ONE)); + let addr_subset = cb.create_witin(|| "addr_subset"); + // TODO add new selector to support 
sel_rw + let sel_read = cb.create_structural_witin( + || "sel_read", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: WORD_SIZE, + descending: false, + }, + ); + let sel_write = cb.create_structural_witin( + || "sel_write", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: WORD_SIZE, + descending: false, + }, + ); + + // local write + let local_write_v = (0..V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_write_v_limb_{i}"))) + .collect::>(); + let local_write_v_expr = local_write_v.iter().map(|v| v.expr()).collect_vec(); + + // local read + let local_read_v = (0..V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_read_v_limb_{i}"))) + .collect::>(); + let local_read_v_expr: Vec> = + local_read_v.iter().map(|v| v.expr()).collect_vec(); + let local_read_cycle = cb.create_witin(|| "local_read_cycle"); + + // TODO global write + // TODO global read + + // constraints + // read from global, write to local + // W_{local} = sel_read * local_write_record + (1 - sel_read) * ONE + let local_raw_write_record = [ + vec![ram_type.expr()], + vec![addr_subset.expr()], + local_write_v_expr.clone(), + vec![Expression::ZERO], // mem bus local init cycle always 0. 
+ ] + .concat(); + let local_write_record = cb.rlc_chip_record(local_raw_write_record.clone()); + let local_write = + sel_read.expr() * local_write_record + (one.clone() - sel_read.expr()).expr(); + // local write, global read + cb.w_table_rlc_record( + || "local_write_record", + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![sel_read], + }, + local_raw_write_record, + local_write, + )?; + // TODO R_{global} = mem_bus_with_read * (sel_read * global_read + (1-sel_read) * EC_INFINITY) + (1 - mem_bus_with_read) * EC_INFINITY + + // write to global, read from local + // R_{local} = sel_write * local_read_record + (1 - sel_write) * ONE + let local_raw_read_record = [ + vec![ram_type.expr()], + vec![addr_subset.expr()], + local_read_v_expr.clone(), + vec![local_read_cycle.expr()], + ] + .concat(); + let local_read_record = cb.rlc_chip_record(local_raw_read_record.clone()); + let local_read: Expression = + sel_write.expr() * local_read_record + (one.clone() - sel_write.expr()); + + // local read, global write + cb.r_table_rlc_record( + || "local_read_record", + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![sel_write], + }, + local_raw_read_record, + local_read, + )?; + // TODO W_{local} = mem_bus_with_write * (sel_write * global_write + (1 - sel_write) * EC_INFINITY) + (1 - mem_bus_with_write) * EC_INFINITY + + Ok(Self { + addr_subset, + sel_write, + sel_read, + local_write_v, + local_read_v, + local_read_cycle, + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. 
+ pub fn assign_instances( + &self, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + let (global_read_records, global_write_records) = + (shard_ctx.read_records(), shard_ctx.write_records()); + assert_eq!(global_read_records.len(), global_write_records.len()); + let raw_write_len: usize = global_write_records.iter().map(|m| m.len()).sum(); + let raw_read_len: usize = global_read_records.iter().map(|m| m.len()).sum(); + if raw_read_len + raw_write_len == 0 { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + // TODO refactor to deal with only read/write + + let witness_length = { + let max_len = raw_read_len.max(raw_write_len); + // first half write, second half read + next_pow2_instance_padding(max_len) * 2 + }; + let mut witness = + RowMajorMatrix::::new(witness_length, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + witness_length, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let witness_mid = witness.values.len() / 2; + let (witness_write, witness_read) = witness.values.split_at_mut(witness_mid); + let structural_witness_mid = structural_witness.values.len() / 2; + let (structural_witness_write, structural_witness_read) = structural_witness + .values + .split_at_mut(structural_witness_mid); + + let mut witness_write_mut_slices = Vec::with_capacity(global_write_records.len()); + let mut witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut structural_witness_write_mut_slices = + Vec::with_capacity(global_write_records.len()); + let mut structural_witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut witness_write_value_rest = witness_write; + let mut witness_read_value_rest = witness_read; + let mut structural_witness_write_value_rest = structural_witness_write; + let mut structural_witness_read_value_rest = 
structural_witness_read; + + for (global_read_record, global_write_record) in + global_read_records.iter().zip_eq(global_write_records) + { + let witness_write_length = global_write_record.len() * num_witin; + let witness_read_length = global_read_record.len() * num_witin; + let structural_witness_write_length = global_write_record.len() * num_structural_witin; + let structural_witness_read_length = global_read_record.len() * num_structural_witin; + assert!( + witness_write_length <= witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + witness_read_length <= witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_write_length <= structural_witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_read_length <= structural_witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_write, witness_write_r) = + witness_write_value_rest.split_at_mut(witness_write_length); + witness_write_mut_slices.push(witness_write); + witness_write_value_rest = witness_write_r; + + let (witness_read, witness_read_r) = + witness_read_value_rest.split_at_mut(witness_read_length); + witness_read_mut_slices.push(witness_read); + witness_read_value_rest = witness_read_r; + + let (structural_witness_write, structural_witness_write_r) = + structural_witness_write_value_rest.split_at_mut(structural_witness_write_length); + structural_witness_write_mut_slices.push(structural_witness_write); + structural_witness_write_value_rest = structural_witness_write_r; + + let (structural_witness_read, structural_witness_read_r) = + structural_witness_read_value_rest.split_at_mut(structural_witness_read_length); + structural_witness_read_mut_slices.push(structural_witness_read); + structural_witness_read_value_rest = structural_witness_read_r; + } + + rayon::join( + // global write, local read + || { + witness_write_mut_slices + 
.par_iter_mut() + .zip_eq(structural_witness_write_mut_slices.par_iter_mut()) + .zip_eq(global_write_records.par_iter()) + .for_each( + |((witness_write, structural_witness_write), global_write_mem)| { + witness_write + .chunks_mut(num_witin) + .zip_eq(structural_witness_write.chunks_mut(num_structural_witin)) + .zip_eq(global_write_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_read_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_read_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_read_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.local_read_cycle, rec.cycle); + + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_write, 1u64); + + // TODO assign W_{global} + }); + }, + ); + }, + // global read, local write + || { + witness_read_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_read_mut_slices.par_iter_mut()) + .zip_eq(global_read_records.par_iter()) + .for_each( + |((witness_read, structural_witness_read), global_read_mem)| { + witness_read + .chunks_mut(num_witin) + .zip_eq(structural_witness_read.chunks_mut(num_structural_witin)) + .zip_eq(global_read_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_write_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_write_v[0], rec.value as u64); + } else { + // Assign value limbs. 
+ self.local_write_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_read, 1u64); + + // TODO assign R_{global} + }); + }, + ); + }, + ); + structural_witness.padding_by_strategy(); Ok([witness, structural_witness]) } @@ -456,6 +1036,7 @@ mod tests { use ceno_emul::WORD_SIZE; use ff_ext::GoldilocksExt2 as E; + use gkr_iop::RAMType; use itertools::Itertools; use multilinear_extensions::mle::MultilinearExtension; use p3::{field::FieldAlgebra, goldilocks::Goldilocks as F}; @@ -474,6 +1055,7 @@ mod tests { let some_non_2_pow = 26; let input = (0..some_non_2_pow) .map(|i| MemFinalRecord { + ram_type: RAMType::Memory, addr: HintsTable::addr(&def_params, i), cycle: 0, value: 0, diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 1b33bb1de..10048418e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -40,11 +40,17 @@ impl Chip { n_evaluations: cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed + cb.cs.num_witin as usize, final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len()) + + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2) .collect_vec(), layers: vec![], } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index e4129bfe8..e26e3682e 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -103,19 +103,23 @@ pub struct ConstraintSystem { pub instance_name_map: HashMap, + pub ec_point_exprs: Vec>, + pub ec_slope_exprs: Vec>, + pub ec_final_sum: Vec>, + pub r_selector: Option>, pub r_expressions: 
Vec>, pub r_expressions_namespace_map: Vec, // for each read expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub r_ram_types: Vec<(RAMType, Vec>)>, + pub r_ram_types: Vec<(Expression, Vec>)>, pub w_selector: Option>, pub w_expressions: Vec>, pub w_expressions_namespace_map: Vec, // for each write expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub w_ram_types: Vec<(RAMType, Vec>)>, + pub w_ram_types: Vec<(Expression, Vec>)>, /// init/final ram expression pub r_table_expressions: Vec>, @@ -167,6 +171,9 @@ impl ConstraintSystem { fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), instance_name_map: HashMap::new(), + ec_final_sum: vec![], + ec_slope_exprs: vec![], + ec_point_exprs: vec![], r_selector: None, r_expressions: vec![], r_expressions_namespace_map: vec![], @@ -329,12 +336,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.r_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.r_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -358,12 +380,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.w_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + 
table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.w_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -382,6 +419,16 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn read_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.r_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.r_expressions_namespace_map.push(path); @@ -398,13 +445,46 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn write_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); + // Since w_expression is RLC(record) and when we're debugging + // it's helpful to recover the value of record itself. 
self.w_ram_types.push((ram_type, record)); Ok(()) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + slopes: Vec>, + final_sum: Vec>, + ) { + const SEPTIC_EXTENSION_DEGREE: usize = 7; + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(slopes.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(final_sum.len(), SEPTIC_EXTENSION_DEGREE * 2); + + assert_eq!(self.ec_point_exprs.len(), 0); + self.ec_point_exprs.extend(xs); + self.ec_point_exprs.extend(ys); + + self.ec_slope_exprs = slopes; + self.ec_final_sum = final_sum; + } + pub fn require_zero, N: FnOnce() -> NR>( &mut self, name_fn: N, @@ -579,6 +659,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .r_table_record(name_fn, ram_type, table_spec, record) } + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn w_table_record( &mut self, name_fn: N, @@ -594,6 +690,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .w_table_record(name_fn, ram_type, table_spec, record) } + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn read_record( &mut self, name_fn: N, @@ -607,6 +719,21 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.read_record(name_fn, ram_type, record) } + pub fn read_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + 
.read_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn write_record( &mut self, name_fn: N, @@ -620,10 +747,35 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.write_record(name_fn, ram_type, record) } + pub fn write_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .write_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn rlc_chip_record(&self, records: Vec>) -> Expression { self.cs.rlc_chip_record(records) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + slope: Vec>, + final_sum: Vec>, + ) { + self.cs.ec_sum(xs, ys, slope, final_sum); + } + pub fn create_bit(&mut self, name_fn: N) -> Result where NR: Into, diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 7d80229fd..b06e8fe71 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -11,6 +11,7 @@ use transcript::Transcript; use crate::{ error::BackendError, hal::{ProverBackend, ProverDevice}, + selector::SelectorContext, }; pub mod booleanhypercube; @@ -77,7 +78,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> { let mut running_evals = out_evals.to_vec(); // running evals is a global referable within chip @@ -97,7 +98,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, ); exit_span!(span); res @@ -122,7 +123,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> where E: ExtensionField, @@ -141,7 +142,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, )?; } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a337dde30..22312497d 
100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -1,3 +1,4 @@ +use either::Either; use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; @@ -20,7 +21,7 @@ use crate::{ error::BackendError, evaluation::EvalExpression, hal::{MultilinearPolynomial, ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; pub mod cpu; @@ -183,7 +184,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut T, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> LayerProof { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -203,7 +204,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, ) } LayerType::Linear => { @@ -231,7 +232,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut Trans, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result<(), BackendError> { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -245,7 +246,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, )?, LayerType::Linear => { assert_eq!(eval_and_dedup_points.len(), 1); @@ -319,9 +320,9 @@ impl Layer { n_challenges: usize, out_evals: OutEvalGroups, ) -> Layer { - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); + let w_len = cb.cs.w_expressions.len() + cb.cs.w_table_expressions.len(); + let r_len = cb.cs.r_expressions.len() + cb.cs.r_table_expressions.len(); + let lk_len = cb.cs.lk_expressions.len() + cb.cs.lk_table_expressions.len() * 2; // logup lk table include p, q let zero_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -331,9 +332,12 @@ impl Layer { 
assert_eq!(lookup_evals.len(), lk_len); assert_eq!(zero_evals.len(), zero_len); - let non_zero_expr_len = cb.cs.w_expressions_namespace_map.len() - + cb.cs.r_expressions_namespace_map.len() - + cb.cs.lk_expressions.len(); + let non_zero_expr_len = cb.cs.w_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_expressions.len() + + cb.cs.lk_table_expressions.len() * 2; let zero_expr_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -341,88 +345,116 @@ impl Layer { let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); - // process r_record - let evals = - Self::dedup_last_selector_evals(cb.cs.r_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb - .cs - .r_expressions - .iter() - .zip_eq(&cb.cs.r_expressions_namespace_map) + if let Some(r_selector) = cb.cs.r_selector.as_ref() { + // process r_record + let evals = Self::dedup_last_selector_evals(r_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .r_expressions + .iter() + .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .r_expressions_namespace_map + .iter() + .chain(&cb.cs.r_table_expressions_namespace_map), + ) .zip_eq(&r_record_evals) .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + 
expr_names.push(format!("{}/{idx}", name)); + } } - // process w_record - let evals = - Self::dedup_last_selector_evals(cb.cs.w_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb - .cs - .w_expressions - .iter() - .zip_eq(&cb.cs.w_expressions_namespace_map) + if let Some(w_selector) = cb.cs.w_selector.as_ref() { + // process w_record + let evals = Self::dedup_last_selector_evals(w_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .w_expressions + .iter() + .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .w_expressions_namespace_map + .iter() + .chain(&cb.cs.w_table_expressions_namespace_map), + ) .zip_eq(&w_record_evals) .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process lookup records - let evals = - Self::dedup_last_selector_evals(cb.cs.lk_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in cb - .cs - .lk_expressions - .iter() - .zip_eq(&cb.cs.lk_expressions_namespace_map) + if let Some(lk_selector) = cb.cs.lk_selector.as_ref() { + // process lookup records + let evals = Self::dedup_last_selector_evals(lk_selector, &mut expr_evals); + for (idx, ((lookup, name), lookup_eval)) in (cb + .cs + .lk_expressions + .iter() + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) + .zip_eq(if 
cb.cs.lk_table_expressions.is_empty() { + Either::Left(cb.cs.lk_expressions_namespace_map.iter()) + } else { + // repeat expressions_namespace_map twice to deal with lk p, q + Either::Right( + cb.cs + .lk_expressions_namespace_map + .iter() + .chain(&cb.cs.lk_expressions_namespace_map), + ) + }) .zip_eq(&lookup_evals) .enumerate() - { - expressions.push(lookup - cb.cs.chip_record_alpha.clone()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - alpha (padding) - *lookup_eval, - E::BaseField::ONE.expr().into(), - cb.cs.chip_record_alpha.clone().neg().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(lookup - cb.cs.chip_record_alpha.clone()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + *lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process zero_record - let evals = - Self::dedup_last_selector_evals(cb.cs.zero_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, (zero_expr, name)) in izip!( - 0.., - chain!( - cb.cs - .assert_zero_expressions - .iter() - .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), - cb.cs - .assert_zero_sumcheck_expressions - .iter() - .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) - ) - ) { - expressions.push(zero_expr.clone()); - evals.push(EvalExpression::Zero); - expr_names.push(format!("{}/{idx}", name)); + if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { + // process zero_record + let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); + for (idx, (zero_expr, name)) in izip!( + 0.., + chain!( + cb.cs + .assert_zero_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), + cb.cs + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) + ) + ) { + expressions.push(zero_expr.clone()); + 
evals.push(EvalExpression::Zero); + expr_names.push(format!("{}/{idx}", name)); + } } // Sort expressions, expr_names, and evals according to eval.0 and classify evals. diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index fa4c33c5e..255daeed3 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -8,6 +8,7 @@ use crate::{ zerocheck_layer::RotationPoints, }, }, + selector::SelectorContext, utils::{rotation_next_base_mle, rotation_selector}, }; use either::Either; @@ -113,7 +114,7 @@ impl> ZerocheckLayerProver pub_io_evals: &[ as ProverBackend>::E], challenges: &[ as ProverBackend>::E], transcript: &mut impl Transcript< as ProverBackend>::E>, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> ( LayerProof< as ProverBackend>::E>, Point< as ProverBackend>::E>, @@ -126,6 +127,12 @@ impl> ZerocheckLayerProver layer.out_sel_and_eval_exprs.len(), out_points.len(), ); + assert_eq!( + layer.out_sel_and_eval_exprs.len(), + selector_ctxs.len(), + "selector_ctxs length {}", + selector_ctxs.len() + ); let (_, raw_rotation_exprs) = &layer.rotation_exprs; let (rotation_proof, rotation_left, rotation_right, rotation_point) = @@ -168,12 +175,16 @@ impl> ZerocheckLayerProver ) ) .collect_vec(); + // zero check eq || rotation eq let mut eqs = layer .out_sel_and_eval_exprs .par_iter() .zip(out_points.par_iter()) - .filter_map(|((sel_type, _), point)| sel_type.compute(point, num_instances)) + .zip(selector_ctxs.par_iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + sel_type.compute(point, selector_ctx) + }) // for rotation left point .chain(rotation_left.par_iter().map(|rotation_left| { MultilinearExtension::from_evaluations_ext_vec( @@ -221,15 +232,16 @@ impl> ZerocheckLayerProver layer.n_structural_witin, layer.n_fixed, ); + let builder = VirtualPolynomialsBuilder::new_with_mles(num_threads, max_num_variables, all_witins); let span = entered_span!("IOPProverState::prove", profiling_4 = 
true); let (proof, prover_state) = IOPProverState::prove( builder.to_virtual_polys_with_monomial_terms( - &layer + layer .main_sumcheck_expression_monomial_terms - .clone() + .as_ref() .unwrap(), pub_io_evals, &main_sumcheck_challenges, diff --git a/gkr_iop/src/gkr/layer/hal.rs b/gkr_iop/src/gkr/layer/hal.rs index 06508e298..c6cce26a0 100644 --- a/gkr_iop/src/gkr/layer/hal.rs +++ b/gkr_iop/src/gkr/layer/hal.rs @@ -4,6 +4,7 @@ use transcript::Transcript; use crate::{ gkr::layer::{Layer, LayerWitness, sumcheck_layer::LayerProof}, hal::ProverBackend, + selector::SelectorContext, }; pub trait LinearLayerProver { @@ -37,6 +38,6 @@ pub trait ZerocheckLayerProver { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 1d4e6c56a..d9f13a2a9 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -27,7 +27,7 @@ use crate::{ }, }, hal::{ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::rotation_selector_eval, }; @@ -58,7 +58,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); #[allow(clippy::too_many_arguments)] @@ -70,7 +70,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError>; } @@ -177,7 +177,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point) { >::prove( self, @@ -188,7 +188,7 @@ impl ZerocheckLayer for Layer { pub_io_evals, 
challenges, transcript, - num_instances, + selector_ctxs, ) } @@ -200,7 +200,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError> { assert_eq!( self.out_sel_and_eval_exprs.len(), @@ -284,17 +284,20 @@ impl ZerocheckLayer for Layer { let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); // eval eq and set to respective witin - izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each( - |((sel_type, _), (_, out_point))| { - sel_type.evaluate( - &mut main_evals, - out_point.as_ref().unwrap(), - &in_point, - num_instances, - self.n_witin, - ); - }, - ); + izip!( + &self.out_sel_and_eval_exprs, + &eval_and_dedup_points, + selector_ctxs.iter() + ) + .for_each(|((sel_type, _), (_, out_point), selector_ctx)| { + sel_type.evaluate( + &mut main_evals, + out_point.as_ref().unwrap(), + &in_point, + selector_ctx, + self.n_witin, + ); + }); let got_claim = eval_by_expr_with_instance( &[], @@ -450,10 +453,11 @@ pub fn extend_exprs_with_rotation( let expr = match sel_type { SelectorType::None => zero_check_expr, SelectorType::Whole(sel) - | SelectorType::Prefix(_, sel) + | SelectorType::Prefix(sel) | SelectorType::OrderedSparse32 { expression: sel, .. 
- } => match_expr(sel) * zero_check_expr, + } + | SelectorType::QuarkBinaryTreeLessThan(sel) => match_expr(sel) * zero_check_expr, }; zero_check_exprs.push(expr); } diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index fc69037ff..a5e20f704 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -7,6 +7,7 @@ use either::Either; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, impl_expr_from_unsigned, mle::ArcMultilinearExtension}; use std::marker::PhantomData; +use strum_macros::EnumIter; use transcript::Transcript; use witness::RowMajorMatrix; @@ -77,12 +78,13 @@ pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); -#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Copy, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { - GlobalState, + GlobalState = 0, Register, Memory, + Undefined, } impl_expr_from_unsigned!(RAMType); diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index bc57295f1..9f10d2249 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -1,16 +1,41 @@ +use std::iter::repeat_n; + use rayon::iter::IndexedParallelIterator; use ff_ext::ExtensionField; use multilinear_extensions::{ Expression, mle::{IntoMLE, MultilinearExtension, Point}, + util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; +use p3::field::FieldAlgebra; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_than}; +/// Provide context for selector's instantiation at runtime +#[derive(Clone, Debug)] +pub struct SelectorContext { + pub offset: usize, + pub num_instances: usize, + pub num_vars: usize, +} + +impl SelectorContext { + pub fn new(offset: usize, num_instances: 
usize, num_vars: usize) -> Self { + Self { + offset, + num_instances, + num_vars, + } + } +} + /// Selector selects part of the witnesses in the sumcheck protocol. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] #[serde(bound( @@ -20,42 +45,122 @@ use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_t pub enum SelectorType { None, Whole(Expression), - /// Select a prefix as the instances, padded with a field element. - Prefix(E::BaseField, Expression), + /// Select part of the instances, other parts padded with a field element. + Prefix(Expression), /// selector activates on the specified `indices`, which are assumed to be in ascending order. /// each index corresponds to a position within a fixed-size chunk (e.g., size 32), OrderedSparse32 { indices: Vec, expression: Expression, }, + /// binary tree [`quark`] from paper + QuarkBinaryTreeLessThan(Expression), } impl SelectorType { + /// Returns an MultilinearExtension with `ctx.num_vars` variables whenever applicable + pub fn to_mle(&self, ctx: &SelectorContext) -> Option> { + match self { + SelectorType::None => None, + SelectorType::Whole(_) => { + assert_eq!(ceil_log2(ctx.num_instances), ctx.num_vars); + Some( + (0..(1 << ctx.num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ONE) + .collect::>() + .into_mle(), + ) + } + SelectorType::Prefix(_) => { + assert!(ctx.offset + ctx.num_instances <= (1 << ctx.num_vars)); + let start = ctx.offset; + let end = start + ctx.num_instances; + Some( + (0..start) + .into_par_iter() + .map(|_| E::BaseField::ZERO) + .chain((start..end).into_par_iter().map(|_| E::BaseField::ONE)) + .chain( + (end..(1 << ctx.num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ZERO), + ) + .collect::>() + .into_mle(), + ) + } + SelectorType::OrderedSparse32 { + indices, + expression: _, + } => { + assert_eq!(ceil_log2(ctx.num_instances) + 5, ctx.num_vars); + Some( + (0..(1 << (ctx.num_vars - 5))) + .into_par_iter() + 
.flat_map(|chunk_index| { + if chunk_index >= ctx.num_instances { + vec![E::ZERO; 32] + } else { + let mut chunk = vec![E::ZERO; 32]; + let mut indices_iter = indices.iter().copied(); + let mut next_keep = indices_iter.next(); + + for (i, e) in chunk.iter_mut().enumerate() { + if let Some(idx) = next_keep + && i == idx + { + *e = E::ONE; + next_keep = indices_iter.next(); // Keep this one + } + } + chunk + } + }) + .collect::>() + .into_mle(), + ) + } + SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(), + } + } + /// Compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) pub fn compute( &self, out_point: &Point, - num_instances: usize, + ctx: &SelectorContext, ) -> Option> { + assert_eq!(out_point.len(), ctx.num_vars); + match self { SelectorType::None => None, - SelectorType::Whole(_expr) => Some(build_eq_x_r_vec(out_point).into_mle()), - SelectorType::Prefix(_, _expr) => { + SelectorType::Whole(_) => Some(build_eq_x_r_vec(out_point).into_mle()), + SelectorType::Prefix(_) => { + let start = ctx.offset; + let end = start + ctx.num_instances; + assert!( + end <= (1 << ctx.num_vars), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); + let mut sel = build_eq_x_r_vec(out_point); - if num_instances < sel.len() { - sel.splice( - num_instances..sel.len(), - std::iter::repeat_n(E::ZERO, sel.len() - num_instances), - ); - } + sel.splice(0..start, repeat_n(E::ZERO, start)); + sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } + // compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, .. 
} => { + assert_eq!(out_point.len(), ceil_log2(ctx.num_instances) + 5); + let mut sel = build_eq_x_r_vec(out_point); sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) .enumerate() .for_each(|(chunk_index, chunk)| { - if chunk_index >= num_instances { + if chunk_index >= ctx.num_instances { // Zero out the entire chunk if out of instance range chunk.iter_mut().for_each(|e| *e = E::ZERO); return; @@ -75,31 +180,107 @@ impl SelectorType { }); Some(sel.into_mle()) } + // also see evaluate() function for more explanation + SelectorType::QuarkBinaryTreeLessThan(_) => { + assert_eq!(ctx.offset, 0); + // num_instances: number of prefix one in leaf layer + let mut sel: Vec = build_eq_x_r_vec(out_point); + let n = sel.len(); + + let num_instances_sequence = (0..out_point.len()) + // clean up sig bits + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // the next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + + // split sel into different size of region, set tailing 0 of respective chunk size + // 1st round: take v = sel[0..sel.len()/2], zero out v[num_instances_sequence[0]..] + // 2nd round: take v = sel[sel.len()/2 .. sel.len()/4], zero out v[num_instances_sequence[1]..] + // ... + // each round: progressively smaller chunk + // example: round 0 uses first half, round 1 uses next quarter, etc. + // compute cumulative start indices: + // e.g. chunk = n/2, then start = 0, chunk, chunk + chunk/2, chunk + chunk/2 + chunk/4, ... 
+ // compute disjoint start indices and lengths + let chunks: Vec<(usize, usize)> = { + let mut result = Vec::new(); + let mut start = 0; + let mut chunk_len = n / 2; + while chunk_len > 0 { + result.push((start, chunk_len)); + start += chunk_len; + chunk_len /= 2; + } + result + }; + + for (i, (start, len)) in chunks.into_iter().enumerate() { + let slice = &mut sel[start..start + len]; + + // determine from which index to zero + let zero_start = num_instances_sequence.get(i).copied().unwrap_or(0).min(len); + + for x in &mut slice[zero_start..] { + *x = E::ZERO; + } + } + + // zero out last bh evaluations + *sel.last_mut().unwrap() = E::ZERO; + Some(sel.into_mle()) + } } } - /// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) pub fn evaluate( &self, evals: &mut Vec, out_point: &Point, in_point: &Point, - num_instances: usize, + ctx: &SelectorContext, offset_eq_id: usize, ) { + assert_eq!(in_point.len(), ctx.num_vars); + assert_eq!(out_point.len(), ctx.num_vars); + let (expr, eval) = match self { SelectorType::None => return, SelectorType::Whole(expr) => { debug_assert_eq!(out_point.len(), in_point.len()); (expr, eq_eval(out_point, in_point)) } - SelectorType::Prefix(_, expr) => { - debug_assert!(num_instances <= (1 << out_point.len())); - ( - expr, - eq_eval_less_or_equal_than(num_instances - 1, out_point, in_point), - ) + SelectorType::Prefix(expression) => { + let start = ctx.offset; + let end = start + ctx.num_instances; + + assert_eq!(in_point.len(), out_point.len()); + assert!( + end <= (1 << out_point.len()), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); + + if end == 0 { + (expression, E::ZERO) + } else { + let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); + let sel = if start > 0 { + let eq_start = eq_eval_less_or_equal_than(start - 1, out_point, in_point); + eq_end - eq_start + } else { + eq_end + }; + 
(expression, sel) + } } + // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, expression, @@ -110,10 +291,64 @@ impl SelectorType { for index in indices { eval += out_subgroup_eq[*index] * in_subgroup_eq[*index]; } - let sel = - eq_eval_less_or_equal_than(num_instances - 1, &out_point[5..], &in_point[5..]); + let sel = eq_eval_less_or_equal_than( + ctx.num_instances - 1, + &out_point[5..], + &in_point[5..], + ); (expression, eval * sel) } + SelectorType::QuarkBinaryTreeLessThan(expr) => { + // num_instances count on leaf layer + // where nodes size is 2^(N) / 2 + // out_point.len() is also log(2^(N)) - 1 + // so num_instances and 1 << out_point.len() are on same scaling + assert!(ctx.num_instances > 0); + assert!(ctx.num_instances <= (1 << out_point.len())); + assert!(!out_point.is_empty()); + assert_eq!(out_point.len(), in_point.len()); + + // we break down this special selector evaluation into recursive structure + // iterating through out_point and in_point, for each i + // next_eval = lhs * (1-out_point[i]) * (1 - in_point[i]) + prev_eval * out_point[i] * in_point[i] + // where the lhs is in consecutive prefix 1 follow by 0 + + // calculate prefix 1 length of each layer + let mut prefix_one_seq = (0..out_point.len()) + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + prefix_one_seq.reverse(); + + let mut res = if prefix_one_seq[0] == 0 { + E::ZERO + } else { + assert_eq!(prefix_one_seq[0], 1); + (E::ONE - out_point[0]) * (E::ONE - in_point[0]) + }; + for i in 1..out_point.len() { + let num_prefix_one_lhs = prefix_one_seq[i]; + let lhs_res = if num_prefix_one_lhs == 0 { + E::ZERO + } else { + (E::ONE - out_point[i]) + * (E::ONE - in_point[i]) + * 
eq_eval_less_or_equal_than( + num_prefix_one_lhs - 1, + &out_point[..i], + &in_point[..i], + ) + }; + let rhs_res = (out_point[i] * in_point[i]) * res; + res = lhs_res + rhs_res; + } + (expr, res) + } }; let Expression::StructuralWitIn(wit_id, _) = expr else { panic!("Wrong selector expression format"); @@ -137,8 +372,63 @@ impl SelectorType { match self { Self::OrderedSparse32 { expression, .. } | Self::Whole(expression) - | Self::Prefix(_, expression) => expression, + | Self::Prefix(expression) => expression, e => unimplemented!("no selector expression in {:?}", e), } } } + +#[cfg(test)] +mod tests { + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use multilinear_extensions::{ + StructuralWitIn, ToExpr, util::ceil_log2, virtual_poly::build_eq_x_r_vec, + }; + use p3::field::FieldAlgebra; + use rand::thread_rng; + + use crate::selector::{SelectorContext, SelectorType}; + + type E = BabyBearExt4; + + #[test] + fn test_quark_lt_selector() { + let mut rng = thread_rng(); + let n_points = 5; + let n_vars = ceil_log2(n_points); + let witin = StructuralWitIn { + id: 0, + witin_type: multilinear_extensions::StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + }; + let selector = SelectorType::QuarkBinaryTreeLessThan(witin.expr()); + let ctx = SelectorContext::new(0, n_points, n_vars); + let out_rt = E::random_vec(n_vars, &mut rng); + let sel_mle = selector.compute(&out_rt, &ctx).unwrap(); + + // if we have 5 points to sum, then + // in 1st layer: two additions p12 = p1 + p2, p34 = p3 + p4, p5 kept + // in 2nd layer: one addition p14 = p12 + p34, p5 kept + // in 3rd layer: one addition p15 = p14 + p5 + let eq = build_eq_x_r_vec(&out_rt); + let vec = sel_mle.get_ext_field_vec(); + assert_eq!(vec[0], eq[0]); // p1+p2 + assert_eq!(vec[1], eq[1]); // p3+p4 + assert_eq!(vec[2], E::ZERO); // p5 + assert_eq!(vec[3], E::ZERO); + assert_eq!(vec[4], eq[4]); // p1+p2+p3+p4 + assert_eq!(vec[5], E::ZERO); // p5 + 
assert_eq!(vec[6], eq[6]); // p1+p2+p3+p4+p5 + assert_eq!(vec[7], E::ZERO); + + let in_rt = E::random_vec(n_vars, &mut rng); + let mut evals = vec![]; + // TODO: avoid the param evals when we evaluate a selector + selector.evaluate(&mut evals, &out_rt, &in_rt, &ctx, 0); + assert_eq!(sel_mle.evaluate(&in_rt), evals[0]); + } +}