From e5389e7b7c3cbb447d5c8b9fb99ebd6a7e51d349 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 5 Aug 2025 21:02:12 +0800 Subject: [PATCH 01/46] fix mockprover error --- ceno_zkvm/src/instructions.rs | 14 +- ceno_zkvm/src/instructions/riscv/div.rs | 2 + .../src/instructions/riscv/ecall/keccak.rs | 10 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 2 + ceno_zkvm/src/scheme/mock_prover.rs | 149 ++++++++++++------ ceno_zkvm/src/structs.rs | 15 +- ceno_zkvm/src/tables/program.rs | 2 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 2 +- ceno_zkvm/src/uint/arithmetic.rs | 10 +- gkr_iop/src/utils/lk_multiplicity.rs | 35 ++-- 10 files changed, 151 insertions(+), 90 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 702e8ffa9..d810ac083 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -8,6 +8,7 @@ use gkr_iop::{ chip::Chip, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, + utils::lk_multiplicity::Multiplicity, }; use itertools::Itertools; use multilinear_extensions::{ToExpr, WitIn, util::max_usable_threads}; @@ -93,7 +94,7 @@ pub trait Instruction { num_witin: usize, num_structural_witin: usize, steps: Vec, - ) -> Result<(RMMCollections, LkMultiplicity), ZKVMError> { + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` @@ -121,14 +122,14 @@ pub trait Instruction { raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter - .zip(raw_structual_witin_iter) - .zip(steps.par_chunks(num_instance_per_batch)) + .zip_eq(raw_structual_witin_iter) + .zip_eq(steps.par_chunks(num_instance_per_batch)) .flat_map(|((instances, structural_instance), steps)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip(steps) + .zip_eq(steps) .map(|((instance, structural_instance), step)| { set_val!(structural_instance, selector_witin, E::BaseField::ONE); Self::assign_instance(config, instance, &mut lk_multiplicity, step) @@ -139,6 +140,9 @@ pub trait Instruction { raw_witin.padding_by_strategy(); raw_structual_witin.padding_by_strategy(); - Ok(([raw_witin, raw_structual_witin], lk_multiplicity)) + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) } } diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index b8c599d24..7f397a68c 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -604,11 +604,13 @@ mod test { let expected_errors: &[_] = if is_ok { &[] } else { &[name] }; MockProver::assert_with_expected_errors( &cb, + &[], &raw_witin .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(), + &[], &[insn_code], expected_errors, None, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 51e353dc7..2eaa9f4f7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use gkr_iop::{ ProtocolBuilder, ProtocolWitnessGenerator, gkr::{GKRCircuit, layer::Layer}, + utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; @@ -164,7 +165,7 @@ impl Instruction for KeccakInstruction { num_witin: usize, num_structural_witin: usize, steps: Vec, - ) -> Result<(RMMCollections, LkMultiplicity), ZKVMError> { + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { return Ok(( @@ -172,7 +173,7 @@ impl Instruction for KeccakInstruction { RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), RowMajorMatrix::new(0, num_structural_witin, InstancePaddingStrategy::Default), ], - lk_multiplicity, + lk_multiplicity.into_finalize_result(), )); } let nthreads = max_usable_threads(); @@ -279,6 +280,9 @@ impl Instruction for KeccakInstruction { raw_witin.padding_by_strategy(); raw_structural_witin.padding_by_strategy(); - Ok(([raw_witin, raw_structural_witin], lk_multiplicity)) + Ok(( + [raw_witin, raw_structural_witin], + lk_multiplicity.into_finalize_result(), + )) } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 89c132b94..d9bc62104 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -560,12 +560,14 @@ mod test { } MockProver::assert_with_expected_errors( &cb, + &[], &raw_witin .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(), &[], + &[], if is_ok { &[] } else { &["mid_u14"] }, None, None, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 7ef50ae0d..c10ebce86 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -23,7 +23,7 @@ use gkr_iop::{ OpsTable, ops::{AndTable, LtuTable, OrTable, PowTable, XorTable}, }, - utils::lk_multiplicity::LkMultiplicityRaw, + utils::lk_multiplicity::{LkMultiplicityRaw, Multiplicity}, }; use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ @@ -469,46 +469,72 @@ fn load_once_tables( impl<'a, E: ExtensionField + Hash> MockProver { pub fn run_with_challenge( cb: &CircuitBuilder, + fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], challenge: [E; 2], - lkm: Option, + lkm: Option>, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, &[], &[], Some(challenge), lkm) + Self::run_maybe_challenge( + cb, + fixed, + wits_in, + structural_witin, + &[], + &[], + Some(challenge), + lkm, + ) } pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], - lkm: Option, + lkm: Option>, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, program, &[], None, lkm) + Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], None, lkm) } + #[allow(clippy::too_many_arguments)] fn run_maybe_challenge( cb: &CircuitBuilder, + fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], pi: &[ArcMultilinearExtension<'a, E>], challenge: Option<[E; 2]>, - lkm: Option, + lkm: Option>, ) -> Result<(), Vec>> { let program = Program::from(program); let (table, challenge) = Self::load_tables_with_program(cb.cs, &program, challenge); - Self::run_maybe_challenge_with_table(cb.cs, &table, wits_in, pi, 1, challenge, lkm) - .map(|_| ()) + Self::run_maybe_challenge_with_table( + cb.cs, + &table, + fixed, + wits_in, + structural_witin, + pi, + 1, + challenge, + lkm, + ) + .map(|_| ()) } #[allow(clippy::too_many_arguments)] fn run_maybe_challenge_with_table( cs: &ConstraintSystem, table: &HashSet>, + fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], pi: &[ArcMultilinearExtension<'a, E>], num_instances: usize, challenge: [E; 2], - expected_lkm: Option, + expected_lkm: Option>, ) -> Result, Vec>> { let mut shared_lkm = LkMultiplicityRaw::::default(); let mut errors = vec![]; @@ -544,9 +570,9 @@ impl<'a, E: ExtensionField + Hash> MockProver { cs.num_witin, cs.num_structural_witin, cs.num_fixed as WitnessId, - &[], + fixed, wits_in, - &[], + structural_witin, pi, &challenge, ); @@ -557,9 +583,9 @@ impl<'a, E: ExtensionField + Hash> MockProver { cs.num_witin, cs.num_structural_witin, cs.num_fixed as WitnessId, - &[], + fixed, wits_in, - &[], + structural_witin, pi, &challenge, ); @@ -587,9 +613,9 @@ impl<'a, E: ExtensionField + Hash> MockProver { cs.num_witin, cs.num_structural_witin, cs.num_fixed as WitnessId, - &[], + fixed, wits_in, - &[], + structural_witin, pi, &challenge, ); @@ -620,9 +646,9 @@ impl<'a, E: ExtensionField + Hash> MockProver { cs.num_witin, cs.num_structural_witin, cs.num_fixed as WitnessId, - &[], + fixed, wits_in, - &[], + structural_witin, pi, &challenge, ); @@ -660,22 +686,25 @@ impl<'a, E: ExtensionField + Hash> MockProver { cs.num_witin, cs.num_structural_witin, cs.num_fixed as WitnessId, - &[], + fixed, wits_in, - &[], + structural_witin, pi, &challenge, ); let mut arg_eval = arg_eval .get_ext_field_vec() .iter() - .map(|v| v.to_canonical_u64_vec()[0]) + .map(|v| { + let v = v.to_canonical_u64_vec(); + assert!(v[1..].iter().all(|x| *x == 0)); + v[0] + }) .take(num_instances) .collect_vec(); // Constant terms will have single element in `args_expr_evaluated`, so let's fix that. - if arg_expr.is_constant() { - assert_eq!(arg_eval.len(), 1); + if arg_expr.is_constant() && arg_eval.len() == 1 { arg_eval.resize(num_instances, arg_eval[0]) } arg_eval @@ -707,7 +736,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } - errors.extend(compare_lkm(lkm_from_cs, lkm_from_assignment)); + errors.extend(compare_lkm( + lkm_from_cs.into_finalize_result(), + lkm_from_assignment, + )); } if errors.is_empty() { @@ -755,20 +787,23 @@ impl<'a, E: ExtensionField + Hash> MockProver { t_vec } + #[allow(clippy::too_many_arguments)] /// Run and check errors /// /// Panic, unless we see exactly the expected errors. /// (Expecting no errors is a valid expectation.) pub fn assert_with_expected_errors( cb: &CircuitBuilder, + fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], constraint_names: &[&str], challenge: Option<[E; 2]>, - lkm: Option, + lkm: Option>, ) { let error_groups = if let Some(challenge) = challenge { - Self::run_with_challenge(cb, wits_in, challenge, lkm) + Self::run_with_challenge(cb, fixed, wits_in, structural_witin, challenge, lkm) } else { Self::run(cb, wits_in, program, lkm) } @@ -805,27 +840,43 @@ Hints: pub fn assert_satisfied_raw( cb: &CircuitBuilder, - [raw_witin, _raw_structural_witin]: RMMCollections, + [raw_witin, raw_structural_witin]: RMMCollections, program: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, - lkm: Option, + lkm: Option>, ) { let wits_in = raw_witin .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(); - Self::assert_satisfied(cb, &wits_in, program, challenge, lkm); + let structural_witin = raw_structural_witin + .to_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(); + Self::assert_satisfied(cb, &wits_in, &structural_witin, program, challenge, lkm); } pub fn assert_satisfied( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, - lkm: Option, + lkm: Option>, ) { - Self::assert_with_expected_errors(cb, wits_in, program, &[], challenge, lkm); + assert_eq!(cb.cs.num_fixed, 0); + Self::assert_with_expected_errors( + cb, + &[], + wits_in, + structural_witin, + program, + &[], + challenge, + lkm, + ); } pub fn assert_satisfied_full( @@ -913,14 +964,15 @@ Hints: ); // Assert opcode and check single opcode lk multiplicity // Also combine multiplicity in lkm_opcodes - let lkm_from_assignments = witnesses - .get_lk_mlt(circuit_name) - .map(LkMultiplicityRaw::deep_clone); + let lkm_from_assignments = witnesses.get_lk_mlt(circuit_name).cloned(); + match Self::run_maybe_challenge_with_table( cs, &lookup_table, + &fixed, &witness, - &[], + &structural_witness, + &pi_mles, num_rows, challenges, lkm_from_assignments, @@ -987,7 +1039,10 @@ Hints: } // Assert lkm between all tables and combined opcode circuits - let errors: Vec> = compare_lkm(lkm_tables, lkm_opcodes); + let errors: Vec> = compare_lkm( + lkm_tables.into_finalize_result(), + lkm_opcodes.into_finalize_result(), + ); if errors.is_empty() { tracing::info!("Mock proving successful for tables"); @@ -1014,6 +1069,7 @@ Hints: { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); + let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1037,7 +1093,7 @@ Hints: cs.num_fixed as WitnessId, fixed, witness, - &[], + structural_witness, &pi_mles, &challenges, ) @@ -1058,7 +1114,7 @@ Hints: cs.num_fixed as WitnessId, fixed, witness, - &[], + structural_witness, &pi_mles, &challenges, ); @@ -1095,6 +1151,7 @@ Hints: { let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); + let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1118,7 +1175,7 @@ Hints: cs.num_fixed as WitnessId, fixed, witness, - &[], + structural_witness, &pi_mles, &challenges, ) @@ -1284,19 +1341,13 @@ Hints: } } -fn compare_lkm( - lkm_a: LkMultiplicityRaw, - lkm_b: LkMultiplicityRaw, -) -> Vec> +fn compare_lkm(lkm_a: Multiplicity, lkm_b: Multiplicity) -> Vec> where E: ExtensionField, K: LkMultiplicityKey + Default + Ord, { - let lkm_a = lkm_a.into_finalize_result(); - let lkm_b = lkm_b.into_finalize_result(); - // Compare each LK Multiplicity. - izip!(ROMType::iter(), &lkm_a, &lkm_b) + izip!(ROMType::iter(), &lkm_a.0, &lkm_b.0) .flat_map(|(rom_type, a_map, b_map)| { // We use a BTreeSet, instead of a HashSet, to ensure deterministic order. let keys: BTreeSet<_> = chain!(a_map.keys(), b_map.keys()).collect(); @@ -1399,7 +1450,7 @@ mod tests { .map(|f| f.into_mle().into()) .collect_vec(); - MockProver::assert_satisfied(&builder, &wits_in, &[], None, None); + MockProver::assert_satisfied(&builder, &wits_in, &[], &[], None, None); } #[derive(Debug)] @@ -1435,7 +1486,7 @@ mod tests { ]; let challenge = [1.into_f(), 1000.into_f()]; - MockProver::assert_satisfied(&builder, &wits_in, &[], Some(challenge), None); + MockProver::assert_satisfied(&builder, &wits_in, &[], &[], Some(challenge), None); } #[test] @@ -1449,7 +1500,7 @@ mod tests { let wits_in = vec![(vec![123u64.into_f()] as Vec).into_mle().into()]; let challenge = [2.into_f(), 1000.into_f()]; - let result = MockProver::run_with_challenge(&builder, &wits_in, challenge, None); + let result = MockProver::run_with_challenge(&builder, &[], &wits_in, &[], challenge, None); assert!(result.is_err(), "Expected error"); let err = result.unwrap_err(); assert_eq!( diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index fee796d3d..77ccf394b 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -4,11 +4,10 @@ use crate::{ instructions::Instruction, state::StateCircuit, tables::{RMMCollections, TableCircuit}, - witness::LkMultiplicity, }; use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; use ff_ext::ExtensionField; -use gkr_iop::{gkr::GKRCircuit, tables::LookupTable}; +use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; @@ -279,7 +278,7 @@ impl ZKVMFixedTraces { pub struct ZKVMWitnesses { witnesses_opcodes: BTreeMap>, witnesses_tables: BTreeMap>, - lk_mlts: BTreeMap, + lk_mlts: BTreeMap>, combined_lk_mlt: Option>>, } @@ -292,7 +291,7 @@ impl ZKVMWitnesses { self.witnesses_tables.get(name) } - pub fn get_lk_mlt(&self, name: &String) -> Option<&LkMultiplicity> { + pub fn get_lk_mlt(&self, name: &String) -> Option<&Multiplicity> { self.lk_mlts.get(name) } @@ -332,13 +331,9 @@ impl ZKVMWitnesses { for name in keys { let lk_mlt = if is_keep_raw_lk_mlts { // mock prover needs the lk_mlt for processing, so we do not remove it - self.lk_mlts - .get(&name) - .unwrap() - .deep_clone() - .into_finalize_result() + self.lk_mlts.get(&name).cloned().unwrap() } else { - self.lk_mlts.remove(&name).unwrap().into_finalize_result() + self.lk_mlts.remove(&name).unwrap() }; if combined_lk_mlt.is_empty() { diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 638264400..b679b1195 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -246,7 +246,7 @@ mod tests { &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - &lkm, + &lkm.0, &program, ) .unwrap(); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index b93b0019f..fd1506f36 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -479,7 +479,7 @@ mod tests { &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - &lkm, + &lkm.0, &input, ) .unwrap(); diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 136073dd2..c9218f612 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -773,7 +773,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); } #[test] @@ -823,7 +823,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_g) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); } #[test] @@ -862,7 +862,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); } #[test] @@ -901,7 +901,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); } #[test] @@ -938,7 +938,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_c) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); } } } diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 7170f30f9..c9068948c 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -5,7 +5,7 @@ use std::{ fmt::Debug, hash::Hash, mem::{self}, - ops::AddAssign, + ops::{AddAssign, Deref, DerefMut}, sync::Arc, }; use thread_local::ThreadLocal; @@ -20,6 +20,20 @@ pub type MultiplicityRaw = [HashMap; mem::variant_count::(pub MultiplicityRaw); +impl Deref for Multiplicity { + type Target = MultiplicityRaw; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Multiplicity { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + /// A lock-free thread safe struct to count logup multiplicity for each ROM type /// Lock-free by thread-local such that each thread will only have its local copy /// struct is cloneable, for internallly it use Arc so the clone will be low cost @@ -34,7 +48,7 @@ where K: Copy + Clone + Debug + Default + Eq + Hash + Send, { fn add_assign(&mut self, rhs: Self) { - *self += Multiplicity(rhs.into_finalize_result()); + *self += Multiplicity(rhs.into_finalize_result().0); } } @@ -91,12 +105,12 @@ where impl LkMultiplicityRaw { /// Merge result from multiple thread local to single result. - pub fn into_finalize_result(self) -> MultiplicityRaw { + pub fn into_finalize_result(self) -> Multiplicity { let mut results = Multiplicity::default(); for y in Arc::try_unwrap(self.multiplicity).unwrap() { results += y.into_inner(); } - results.0 + results } pub fn increment(&mut self, rom_type: LookupTable, key: K) { @@ -115,17 +129,6 @@ impl LkMultiplicityRaw table.insert(key, count); } } - - /// Clone inner, expensive operation. - pub fn deep_clone(&self) -> Self { - let multiplicity = self.multiplicity.get_or_default(); - let deep_cloned = multiplicity.borrow().clone(); - let thread_local = ThreadLocal::new(); - thread_local.get_or(|| RefCell::new(deep_cloned)); - LkMultiplicityRaw { - multiplicity: Arc::new(thread_local), - } - } } /// Default LkMultiplicity with u64 key. @@ -212,6 +215,6 @@ mod tests { } let res = lkm.into_finalize_result(); // check multiplicity counts of assert_byte - assert_eq!(res[LookupTable::U8 as usize][&8], thread_count); + assert_eq!(res.0[LookupTable::U8 as usize][&8], thread_count); } } From 1359174466f579b55ce17f9a3458995263209f8a Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 5 Aug 2025 21:17:54 +0800 Subject: [PATCH 02/46] ci mock proving to debug build --- .github/workflows/integration.yml | 11 ++++++----- ceno_zkvm/src/scheme/mock_prover.rs | 16 ++++++---------- ff_ext/src/lib.rs | 7 +++++++ 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index b275433f4..af80eb654 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -3,7 +3,7 @@ name: Integrations on: merge_group: pull_request: - types: [synchronize, opened, reopened, ready_for_review] + types: [ synchronize, opened, reopened, ready_for_review ] push: branches: - master @@ -14,7 +14,7 @@ concurrency: jobs: skip_check: - runs-on: [self-hosted, Linux, X64] + runs-on: [ self-hosted, Linux, X64 ] outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} steps: @@ -27,14 +27,14 @@ jobs: do_not_skip: '["pull_request", "workflow_dispatch", "schedule", "merge_group"]' integration: - needs: [skip_check] + needs: [ skip_check ] if: | github.event.pull_request.draft == false && (github.event.action == 'ready_for_review' || needs.skip_check.outputs.should_skip != 'true') name: Integration testing timeout-minutes: 30 - runs-on: [self-hosted, Linux, X64] + runs-on: [ self-hosted, Linux, X64 ] steps: - uses: actions/checkout@v4 @@ -52,7 +52,8 @@ jobs: env: RUST_LOG: debug RUSTFLAGS: "-C opt-level=3" - run: cargo run --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci + MOCK_PROVING: 1 + run: cargo run --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci - name: Run fibonacci (release) env: diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index c10ebce86..c5f8fbe6f 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -695,11 +695,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut arg_eval = arg_eval .get_ext_field_vec() .iter() - .map(|v| { - let v = v.to_canonical_u64_vec(); - assert!(v[1..].iter().all(|x| *x == 0)); - v[0] - }) + .map(E::to_canonical_u64) .take(num_instances) .collect_vec(); @@ -1027,7 +1023,7 @@ Hints: lkm_tables.set_count( *rom_type, key, - multiplicity.to_canonical_u64_vec()[0] as usize, + multiplicity.to_canonical_u64() as usize, ); } } @@ -1218,9 +1214,9 @@ Hints: .take(10) .for_each(|(_, row)| { let pc = - gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64_vec()[0]); + gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); let ts = - gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64_vec()[0]); + gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); tracing::error!( "{} at row {} (pc={:x},ts={}) not found in {:?} writes", annotation, @@ -1256,9 +1252,9 @@ Hints: .take(10) .for_each(|(_, row)| { let pc = - gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64_vec()[0]); + gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); let ts = - gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64_vec()[0]); + gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); tracing::error!( "{} at row {} (pc={:x},ts={}) not found in {:?} reads", annotation, diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index 5916806a9..4c2c37008 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -146,4 +146,11 @@ pub trait ExtensionField: /// Convert a field elements to a u64 vector fn to_canonical_u64_vec(&self) -> Vec; + + /// retrive first field elements to u64 + fn to_canonical_u64(&self) -> u64 { + let res = self.to_canonical_u64_vec(); + assert!(res[1..].iter().all(|v| *v == 0)); + res[0] + } } From 7e7ee26855d786c4f5325bba0139862d13036081 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 5 Aug 2025 21:41:22 +0800 Subject: [PATCH 03/46] better coding style --- ceno_zkvm/src/scheme/mock_prover.rs | 14 ++++----- gkr_iop/src/utils/lk_multiplicity.rs | 43 +++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index c5f8fbe6f..5eefbf315 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1213,10 +1213,8 @@ Hints: .filter(|(read, _)| !$writes.contains(read)) .take(10) .for_each(|(_, row)| { - let pc = - gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); - let ts = - gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); + let pc = gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); + let ts = gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); tracing::error!( "{} at row {} (pc={:x},ts={}) not found in {:?} writes", annotation, @@ -1251,10 +1249,8 @@ Hints: .filter(|(write, _)| !$reads.contains(write)) .take(10) .for_each(|(_, row)| { - let pc = - gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); - let ts = - gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); + let pc = gs_of_circuit.map_or(0, |gs| gs[*row][0].to_canonical_u64()); + let ts = gs_of_circuit.map_or(0, |gs| gs[*row][1].to_canonical_u64()); tracing::error!( "{} at row {} (pc={:x},ts={}) not found in {:?} reads", annotation, @@ -1343,7 +1339,7 @@ where K: LkMultiplicityKey + Default + Ord, { // Compare each LK Multiplicity. - izip!(ROMType::iter(), &lkm_a.0, &lkm_b.0) + izip!(ROMType::iter(), &lkm_a, &lkm_b) .flat_map(|(rom_type, a_map, b_map)| { // We use a BTreeSet, instead of a HashSet, to ensure deterministic order. let keys: BTreeSet<_> = chain!(a_map.keys(), b_map.keys()).collect(); diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index c9068948c..2b93662f9 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -34,6 +34,47 @@ impl DerefMut for Multiplicity { } } +/// for consuming the wrapper +impl IntoIterator for Multiplicity +where + MultiplicityRaw: IntoIterator, +{ + type Item = as IntoIterator>::Item; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// for immutable references +impl<'a, K> IntoIterator for &'a Multiplicity +where + &'a MultiplicityRaw: IntoIterator, +{ + type Item = <&'a MultiplicityRaw as IntoIterator>::Item; + type IntoIter = <&'a MultiplicityRaw as IntoIterator>::IntoIter; + + #[allow(clippy::into_iter_on_ref)] + fn into_iter(self) -> Self::IntoIter { + (&self.0).into_iter() + } +} + +/// for mutable references +impl<'a, K> IntoIterator for &'a mut Multiplicity +where + &'a mut MultiplicityRaw: IntoIterator, +{ + type Item = <&'a mut MultiplicityRaw as IntoIterator>::Item; + type IntoIter = <&'a mut MultiplicityRaw as IntoIterator>::IntoIter; + + #[allow(clippy::into_iter_on_ref)] + fn into_iter(self) -> Self::IntoIter { + (&mut self.0).into_iter() + } +} + /// A lock-free thread safe struct to count logup multiplicity for each ROM type /// Lock-free by thread-local such that each thread will only have its local copy /// struct is cloneable, for internallly it use Arc so the clone will be low cost @@ -215,6 +256,6 @@ mod tests { } let res = lkm.into_finalize_result(); // check multiplicity counts of assert_byte - assert_eq!(res.0[LookupTable::U8 as usize][&8], thread_count); + assert_eq!(res[LookupTable::U8 as usize][&8], thread_count); } } From d2d8031b26d8c660c4d498e105004043f98c3902 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 11:08:07 +0800 Subject: [PATCH 04/46] wip --- ceno_zkvm/Cargo.toml | 3 +- ceno_zkvm/src/instructions/riscv/branch.rs | 41 ++++- .../riscv/branch/branch_circuit_v2.rs | 163 ++++++++++++++++++ 3 files changed, 197 insertions(+), 10 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 59f888737..43e8855ed 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,7 +67,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow"] +default = ["forbid_overflow", "u16limb_circuit"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] @@ -83,6 +83,7 @@ nightly-features = [ "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] +u16limb_circuit = [] [[bench]] harness = false diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index f5a41de8d..fa273fc80 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -1,9 +1,8 @@ -mod branch_circuit; - use super::RIVInstruction; -use branch_circuit::BranchCircuit; use ceno_emul::InsnKind; +mod branch_circuit; +mod branch_circuit_v2; #[cfg(test)] mod test; @@ -11,34 +10,58 @@ pub struct BeqOp; impl RIVInstruction for BeqOp { const INST_KIND: InsnKind = InsnKind::BEQ; } -pub type BeqInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BeqInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BeqInstruction = branch_circuit::BranchCircuit; pub struct BneOp; impl RIVInstruction for BneOp { const INST_KIND: InsnKind = InsnKind::BNE; } -pub type BneInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BneInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BneInstruction = branch_circuit::BranchCircuit; pub struct BltuOp; impl RIVInstruction for BltuOp { const INST_KIND: InsnKind = InsnKind::BLTU; } -pub type BltuInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BltuInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BltuInstruction = branch_circuit::BranchCircuit; pub struct BgeuOp; impl RIVInstruction for BgeuOp { const INST_KIND: InsnKind = InsnKind::BGEU; } -pub type BgeuInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BgeuInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BgeuInstruction = branch_circuit::BranchCircuit; pub struct BltOp; impl RIVInstruction for BltOp { const INST_KIND: InsnKind = InsnKind::BLT; } -pub type BltInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BltInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BltInstruction = branch_circuit_v2::BranchCircuit; pub struct BgeOp; impl RIVInstruction for BgeOp { const INST_KIND: InsnKind = InsnKind::BGE; } -pub type BgeInstruction = BranchCircuit; +#[cfg(feature = "u16limb_circuit")] +// TODO use branch_circuit_v2 +pub type BgeInstruction = branch_circuit::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BgeInstruction = branch_circuit_v2::BranchCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs new file mode 100644 index 000000000..bb286cab7 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -0,0 +1,163 @@ +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::SignedLtConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + b_insn::BInstructionConfig, + constants::{UINT_LIMBS, UInt}, + }, + }, + structs::ProgramParams, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use gkr_iop::gadgets::{IsEqualConfig, IsLtConfig}; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::{array, marker::PhantomData}; + +pub struct BranchCircuit(PhantomData<(E, I)>); + +pub struct BranchConfig { + pub b_insn: BInstructionConfig, + pub read_rs1: UInt, + pub read_rs2: UInt, + + // Most significant limb of a and b respectively as a field element, will be range + // checked to be within [-128, 127) if signed and [0, 256) if unsigned. + pub read_rs1_msb_f: WitIn, + pub read_rs2_msb_f: WitIn, + + // 1 at the most significant index i such that read_rs1[i] != read_rs2[i], otherwise 0. If such + // an i exists, diff_val = read_rs2[i] - read_rs1[i]. + pub diff_marker: [WitIn; UINT_LIMBS], + pub diff_val: WitIn, + phantom: PhantomData, +} + +impl Instruction for BranchCircuit { + type InstructionConfig = BranchConfig; + + fn name() -> String { + todo!() + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result { + let read_rs1 = UInt::new_unchecked(|| "rs1_limbs", circuit_builder)?; + let read_rs2 = UInt::new_unchecked(|| "rs2_limbs", circuit_builder)?; + + let read_rs1_expr = read_rs1.expr(); + let read_rs2_expr = read_rs2.expr(); + + let read_rs1_msb_f = circuit_builder.create_witin(|| "read_rs1_msb_f"); + let read_rs2_msb_f = circuit_builder.create_witin(|| "read_rs2_msb_f"); + let diff_marker: [WitIn; UINT_LIMBS] = + array::from_fn(|_| circuit_builder.create_witin(|| "diff_maker")); + let diff_val = circuit_builder.create_witin(|| "diff_val"); + + // Check if a_msb_f and b_msb_f are signed values of read_rs1[NUM_LIMBS - 1] and read_rs2[NUM_LIMBS - 1] in prime field F. + let a_diff = read_rs1_expr[UINT_LIMBS - 1].expr() - read_rs1_msb_f.expr(); + let b_diff = read_rs2_expr[UINT_LIMBS - 1].expr() - read_rs2_msb_f.expr(); + + let (branch_taken_bit, is_equal, is_signed_lt, is_unsigned_lt) = match I::INST_KIND { + InsnKind::BEQ => { + let equal = IsEqualConfig::construct_circuit( + circuit_builder, + || "rs1!=rs2", + read_rs2.value(), + read_rs1.value(), + )?; + (equal.expr(), Some(equal), None, None) + } + InsnKind::BNE => { + let equal = IsEqualConfig::construct_circuit( + circuit_builder, + || "rs1==rs2", + read_rs2.value(), + read_rs1.value(), + )?; + (Expression::ONE - equal.expr(), Some(equal), None, None) + } + InsnKind::BLT => { + let signed_lt = SignedLtConfig::construct_circuit( + circuit_builder, + || "rs1 { + let signed_lt = SignedLtConfig::construct_circuit( + circuit_builder, + || "rs1>=rs2", + &read_rs1, + &read_rs2, + )?; + ( + Expression::ONE - signed_lt.expr(), + None, + Some(signed_lt), + None, + ) + } + InsnKind::BLTU => { + let unsigned_lt = IsLtConfig::construct_circuit( + circuit_builder, + || "rs1 { + let unsigned_lt = IsLtConfig::construct_circuit( + circuit_builder, + || "rs1 >= rs2", + read_rs1.value(), + read_rs2.value(), + UINT_LIMBS, + )?; + ( + Expression::ONE - unsigned_lt.expr(), + None, + None, + Some(unsigned_lt), + ) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let b_insn = BInstructionConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + read_rs1.register_expr(), + read_rs2.register_expr(), + branch_taken_bit, + )?; + + // Ok(BranchConfig { + // b_insn, + // read_rs1, + // read_rs2, + // .. + // }) + todo!() + } + + fn assign_instance( + _config: &Self::InstructionConfig, + _instance: &mut [E::BaseField], + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + todo!() + } +} From cae1b09e91428cbffa9adc21c08397c33604d3f9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 15:43:46 +0800 Subject: [PATCH 05/46] complete v2 circuit --- .../riscv/branch/branch_circuit_v2.rs | 198 +++++++++++------- gkr_iop/src/circuit_builder.rs | 28 +++ 2 files changed, 155 insertions(+), 71 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index bb286cab7..1d4e51b69 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -7,7 +7,7 @@ use crate::{ riscv::{ RIVInstruction, b_insn::BInstructionConfig, - constants::{UINT_LIMBS, UInt}, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, }, }, structs::ProgramParams, @@ -17,7 +17,8 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use gkr_iop::gadgets::{IsEqualConfig, IsLtConfig}; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use std::{array, marker::PhantomData}; +use p3::field::FieldAlgebra; +use std::{array, marker::PhantomData, ops::Neg}; pub struct BranchCircuit(PhantomData<(E, I)>); @@ -35,6 +36,10 @@ pub struct BranchConfig { // an i exists, diff_val = read_rs2[i] - read_rs1[i]. pub diff_marker: [WitIn; UINT_LIMBS], pub diff_val: WitIn, + + // 1 if read_rs1 < read_rs2, 0 otherwise. + pub cmp_lt: WitIn, + phantom: PhantomData, } @@ -47,8 +52,11 @@ impl Instruction for BranchCircuit, - param: &ProgramParams, + _param: &ProgramParams, ) -> Result { + // 1 if a < b, 0 otherwise. + let cmp_lt = circuit_builder.create_bit(|| "cmp_lt")?; + let read_rs1 = UInt::new_unchecked(|| "rs1_limbs", circuit_builder)?; let read_rs2 = UInt::new_unchecked(|| "rs2_limbs", circuit_builder)?; @@ -57,80 +65,123 @@ impl Instruction for BranchCircuit { - let equal = IsEqualConfig::construct_circuit( - circuit_builder, - || "rs1!=rs2", - read_rs2.value(), - read_rs1.value(), - )?; - (equal.expr(), Some(equal), None, None) - } - InsnKind::BNE => { - let equal = IsEqualConfig::construct_circuit( - circuit_builder, - || "rs1==rs2", - read_rs2.value(), - read_rs1.value(), - )?; - (Expression::ONE - equal.expr(), Some(equal), None, None) - } + // Check if read_rs1_msb_f and read_rs2_msb_f are signed values of read_rs1[NUM_LIMBS - 1] and read_rs2[NUM_LIMBS - 1] in prime field F. + let read_rs1_diff = read_rs1_expr[UINT_LIMBS - 1].expr() - read_rs1_msb_f.expr(); + let read_rs2_diff = read_rs2_expr[UINT_LIMBS - 1].expr() - read_rs2_msb_f.expr(); + + circuit_builder.require_zero( + || "read_rs1_diff", + read_rs1_diff.expr() + * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - read_rs1_diff.expr()), + )?; + circuit_builder.require_zero( + || "read_rs2_diff", + read_rs2_diff.expr() + * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - read_rs2_diff.expr()), + )?; + + let mut prefix_sum = Expression::ZERO; + + for i in (0..UINT_LIMBS).rev() { + let diff = (if i == UINT_LIMBS - 1 { + read_rs2_msb_f.expr() - read_rs1_msb_f.expr() + } else { + read_rs2_expr[i].expr() - read_rs1_expr[i].expr() + }) * (E::BaseField::from_canonical_u8(2).expr() * cmp_lt.expr() + - E::BaseField::ONE.expr()); + prefix_sum += diff_marker[i].expr(); + circuit_builder.require_zero( + || format!("prefix_diff_zero_{i}"), + (prefix_sum.clone() * diff.clone()).neg(), + )?; + circuit_builder.condition_require_zero( + || format!("diff_maker_conditional_equal_{i}"), + diff_marker[i].expr(), + diff_val.expr() - diff.expr(), + )?; + } + + // - If x != y, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index where diff != 0. + // Constrains that diff == diff_val where diff_val is non-zero. + // - If x == y, then prefix_sum = 0 and cmp_lt = 0. + // Here, prefix_sum cannot be 1 because all diff are zero, making diff == diff_val fails. + + circuit_builder.assert_bit(|| "prefix_sum_bit", prefix_sum.expr())?; + circuit_builder.condition_require_zero( + || "cmp_lt_conditional_zero", + prefix_sum.expr().neg(), + cmp_lt.expr(), + )?; + + // Range check to ensure diff_val is non-zero. + circuit_builder.assert_ux::<_, _, 8>( + || "diff_val is non-zero", + prefix_sum.expr() * (diff_val.expr() - E::BaseField::ONE.expr()), + )?; + + let branch_taken_bit = match I::INST_KIND { InsnKind::BLT => { - let signed_lt = SignedLtConfig::construct_circuit( - circuit_builder, - || "rs1( + || "read_rs1_msb_f_signed_range_check", + read_rs1_msb_f.expr() + + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), )?; - (signed_lt.expr(), None, Some(signed_lt), None) + + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs2_msb_f_signed_range_check", + read_rs2_msb_f.expr() + + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), + )?; + cmp_lt.expr() } InsnKind::BGE => { - let signed_lt = SignedLtConfig::construct_circuit( - circuit_builder, - || "rs1>=rs2", - &read_rs1, - &read_rs2, + // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs1_msb_f_signed_range_check", + read_rs1_msb_f.expr() + + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), + )?; + + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs2_msb_f_signed_range_check", + read_rs2_msb_f.expr() + + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), )?; - ( - Expression::ONE - signed_lt.expr(), - None, - Some(signed_lt), - None, - ) + Expression::ONE - cmp_lt.expr() } InsnKind::BLTU => { - let unsigned_lt = IsLtConfig::construct_circuit( - circuit_builder, - || "rs1( + || "read_rs1_msb_f_signed_range_check", + read_rs1_msb_f.expr(), + )?; + + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs2_msb_f_signed_range_check", + read_rs2_msb_f.expr(), )?; - (unsigned_lt.expr(), None, None, Some(unsigned_lt)) + cmp_lt.expr() } InsnKind::BGEU => { - let unsigned_lt = IsLtConfig::construct_circuit( - circuit_builder, - || "rs1 >= rs2", - read_rs1.value(), - read_rs2.value(), - UINT_LIMBS, + // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs1_msb_f_signed_range_check", + read_rs1_msb_f.expr(), )?; - ( - Expression::ONE - unsigned_lt.expr(), - None, - None, - Some(unsigned_lt), - ) + + circuit_builder.assert_ux::<_, _, 8>( + || "read_rs2_msb_f_signed_range_check", + read_rs2_msb_f.expr(), + )?; + Expression::ONE - cmp_lt.expr() } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; @@ -143,13 +194,18 @@ impl Instruction for BranchCircuit CircuitBuilder<'a, E> { self.cs.rlc_chip_record(records) } + pub fn create_bit(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let bit = self.cs.create_witin(name_fn.clone()); + self.assert_bit(name_fn, bit.expr())?; + + Ok(bit) + } + pub fn create_u8(&mut self, name_fn: N) -> Result where NR: Into, @@ -748,6 +759,23 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) } + pub fn condition_require_zero( + &mut self, + name_fn: N, + cond: Expression, + expr: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + // cond * expr + self.namespace( + || "cond_require_zero", + |cb| cb.cs.require_zero(name_fn, cond * expr.expr()), + ) + } + pub fn select( &mut self, cond: &Expression, From 6dea5b72b689df549171d7beb6f01e9e748a57f5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 17:13:53 +0800 Subject: [PATCH 06/46] branch v2 witness assignment --- ceno_zkvm/src/instructions/riscv/branch.rs | 12 +- .../riscv/branch/branch_circuit_v2.rs | 121 ++++++++++++++++-- ceno_zkvm/src/scheme/mock_prover.rs | 2 +- gkr_iop/src/circuit_builder.rs | 6 - 4 files changed, 118 insertions(+), 23 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index fa273fc80..8afd52ea0 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -32,7 +32,7 @@ impl RIVInstruction for BltuOp { } #[cfg(feature = "u16limb_circuit")] // TODO use branch_circuit_v2 -pub type BltuInstruction = branch_circuit::BranchCircuit; +pub type BltuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltuInstruction = branch_circuit::BranchCircuit; @@ -42,7 +42,7 @@ impl RIVInstruction for BgeuOp { } #[cfg(feature = "u16limb_circuit")] // TODO use branch_circuit_v2 -pub type BgeuInstruction = branch_circuit::BranchCircuit; +pub type BgeuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BgeuInstruction = branch_circuit::BranchCircuit; @@ -52,9 +52,9 @@ impl RIVInstruction for BltOp { } #[cfg(feature = "u16limb_circuit")] // TODO use branch_circuit_v2 -pub type BltInstruction = branch_circuit::BranchCircuit; -#[cfg(not(feature = "u16limb_circuit"))] pub type BltInstruction = branch_circuit_v2::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BltInstruction = branch_circuit::BranchCircuit; pub struct BgeOp; impl RIVInstruction for BgeOp { @@ -62,6 +62,6 @@ impl RIVInstruction for BgeOp { } #[cfg(feature = "u16limb_circuit")] // TODO use branch_circuit_v2 -pub type BgeInstruction = branch_circuit::BranchCircuit; -#[cfg(not(feature = "u16limb_circuit"))] pub type BgeInstruction = branch_circuit_v2::BranchCircuit; +#[cfg(not(feature = "u16limb_circuit"))] +pub type BgeInstruction = branch_circuit::BranchCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 1d4e51b69..a02950720 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -1,7 +1,7 @@ use crate::{ + Value, circuit_builder::CircuitBuilder, error::ZKVMError, - gadgets::SignedLtConfig, instructions::{ Instruction, riscv::{ @@ -14,11 +14,11 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::ExtensionField; -use gkr_iop::gadgets::{IsEqualConfig, IsLtConfig}; +use ff_ext::{ExtensionField, FieldInto, SmallField}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use std::{array, marker::PhantomData, ops::Neg}; +use witness::set_val; pub struct BranchCircuit(PhantomData<(E, I)>); @@ -47,9 +47,10 @@ impl Instruction for BranchCircuit; fn name() -> String { - todo!() + format!("{:?}", I::INST_KIND) } + /// circuit implementation refer from https://github.com/openvm-org/openvm/blob/ca36de3803213da664b03d111801ab903d55e360/extensions/rv32im/circuit/src/branch_lt/core.rs fn construct_circuit( circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, @@ -99,7 +100,7 @@ impl Instruction for BranchCircuit Instruction for BranchCircuit Result<(), ZKVMError> { - todo!() + config + .b_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_limbs = rs1.as_u16_limbs(); + let rs2 = Value::new_unchecked(step.rs2().unwrap().value); + let rs2_limbs = rs2.as_u16_limbs(); + config.read_rs1.assign_limbs(instance, rs1_limbs); + config.read_rs2.assign_limbs(instance, rs2_limbs); + + let (cmp_result, diff_idx, rs1_sign, rs2_sign) = + run_cmp::(step.insn.kind, rs1_limbs, rs2_limbs); + config + .diff_marker + .iter() + .enumerate() + .for_each(|(i, witin)| { + set_val!(instance, witin, (i == diff_idx) as u64); + }); + + let is_signed = matches!(step.insn().kind, InsnKind::BLT | InsnKind::BGE); + let is_ge = matches!(step.insn().kind, InsnKind::BGE | InsnKind::BGEU); + + let cmp_lt = cmp_result ^ is_ge; + set_val!(instance, config.cmp_lt, cmp_lt as u64); + + // We range check (read_rs1_msb_f + 128) and (read_rs2_msb_f + 128) if signed, + // read_rs1_msb_f and read_rs2_msb_f if not + let (read_rs1_msb_f, a_msb_range) = if rs1_sign { + ( + -E::BaseField::from_canonical_u32( + (1 << LIMB_BITS) - rs1_limbs[UINT_LIMBS - 1] as u32, + ), + rs1_limbs[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + ) + } else { + ( + E::BaseField::from_canonical_u16(rs1_limbs[UINT_LIMBS - 1]), + rs1_limbs[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + ) + }; + let (read_rs2_msb_f, b_msb_range) = if rs2_sign { + ( + -E::BaseField::from_canonical_u32( + (1 << LIMB_BITS) - rs2_limbs[UINT_LIMBS - 1] as u32, + ), + rs2_limbs[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + ) + } else { + ( + E::BaseField::from_canonical_u16(rs2_limbs[UINT_LIMBS - 1]), + rs2_limbs[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + ) + }; + + set_val!(instance, config.read_rs1_msb_f, read_rs1_msb_f); + set_val!(instance, config.read_rs2_msb_f, read_rs2_msb_f); + + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == (UINT_LIMBS - 1) { + if cmp_lt { + read_rs2_msb_f - read_rs1_msb_f + } else { + read_rs1_msb_f - read_rs2_msb_f + } + .to_canonical_u64() as u16 + } else if cmp_lt { + rs2_limbs[diff_idx] - rs1_limbs[diff_idx] + } else { + rs1_limbs[diff_idx] - rs2_limbs[diff_idx] + }; + set_val!(instance, config.diff_val, diff_val as u64); + + if diff_idx != UINT_LIMBS { + lk_multiplicity.assert_ux::<8>((diff_val - 1) as u64); + } + + lk_multiplicity.assert_ux::<8>(a_msb_range as u64); + lk_multiplicity.assert_ux::<8>(b_msb_range as u64); + + Ok(()) + } +} + +// returns (cmp_result, diff_idx, x_sign, y_sign) +pub(super) fn run_cmp( + local_opcode: InsnKind, + x: &[u16], + y: &[u16], +) -> (bool, usize, bool, bool) { + let signed = matches!(local_opcode, InsnKind::BLT | InsnKind::BGE); + let ge_op = matches!(local_opcode, InsnKind::BGE | InsnKind::BGEU); + let x_sign = (x[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; + let y_sign = (y[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; + for i in (0..UINT_LIMBS).rev() { + if x[i] != y[i] { + return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign); + } } + (ge_op, UINT_LIMBS, x_sign, y_sign) } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 7ef50ae0d..00511cf0a 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -45,7 +45,7 @@ use std::{ use strum::IntoEnumIterator; use tiny_keccak::{Hasher, Keccak}; -const MAX_CONSTRAINT_DEGREE: usize = 2; +const MAX_CONSTRAINT_DEGREE: usize = 3; const MOCK_PROGRAM_SIZE: usize = 32; pub const MOCK_PC_START: ByteAddr = ByteAddr({ // This needs to be a static, because otherwise the compiler complains diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 7084dfa1e..8673daa28 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -265,12 +265,6 @@ impl ConstraintSystem { .chain(record.clone()) .collect(), ); - assert_eq!( - rlc_record.degree(), - 1, - "rlc lk_record degree ({})", - name_fn().into() - ); self.lk_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.lk_expressions_namespace_map.push(path); From 85699f3cda849bb7433dccc786ffb005b7873c96 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 19:37:14 +0800 Subject: [PATCH 07/46] all test pass --- .../riscv/branch/branch_circuit_v2.rs | 38 ++++++++++--------- .../src/instructions/riscv/branch/test.rs | 1 - 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index a02950720..57415e31d 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -17,7 +17,7 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto, SmallField}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; -use std::{array, marker::PhantomData, ops::Neg}; +use std::{array, marker::PhantomData}; use witness::set_val; pub struct BranchCircuit(PhantomData<(E, I)>); @@ -66,9 +66,9 @@ impl Instruction for BranchCircuit Instruction for BranchCircuit Instruction for BranchCircuit( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "diff_val is non-zero", prefix_sum.expr() * (diff_val.expr() - E::BaseField::ONE.expr()), )?; let branch_taken_bit = match I::INST_KIND { InsnKind::BLT => { - // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, 8>( + // Check if read_rs1_msb_f and read_rs2_msb_f are in [-32768, 32767) if signed, [0, 65536) if unsigned. + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs1_msb_f_signed_range_check", read_rs1_msb_f.expr() + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), )?; - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs2_msb_f_signed_range_check", read_rs2_msb_f.expr() + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), @@ -145,13 +145,13 @@ impl Instruction for BranchCircuit { // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs1_msb_f_signed_range_check", read_rs1_msb_f.expr() + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), )?; - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs2_msb_f_signed_range_check", read_rs2_msb_f.expr() + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), @@ -160,12 +160,12 @@ impl Instruction for BranchCircuit { // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs1_msb_f_signed_range_check", read_rs1_msb_f.expr(), )?; - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs2_msb_f_signed_range_check", read_rs2_msb_f.expr(), )?; @@ -173,12 +173,12 @@ impl Instruction for BranchCircuit { // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs1_msb_f_signed_range_check", read_rs1_msb_f.expr(), )?; - circuit_builder.assert_ux::<_, _, 8>( + circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "read_rs2_msb_f_signed_range_check", read_rs2_msb_f.expr(), )?; @@ -291,11 +291,13 @@ impl Instruction for BranchCircuit((diff_val - 1) as u64); + lk_multiplicity.assert_ux::((diff_val - 1) as u64); + } else { + lk_multiplicity.assert_ux::(0); } - lk_multiplicity.assert_ux::<8>(a_msb_range as u64); - lk_multiplicity.assert_ux::<8>(b_msb_range as u64); + lk_multiplicity.assert_ux::(a_msb_range as u64); + lk_multiplicity.assert_ux::(b_msb_range as u64); Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 6b5774fb0..fd9da41a1 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -118,7 +118,6 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { }; let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); - println!("{:?}", insn_code); let (raw_witin, lkm) = BltuInstruction::assign_instances( &config, circuit_builder.cs.num_witin as usize, From 9f5408cbb5105faa5c0f27b181be478d8662c516 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 21:09:35 +0800 Subject: [PATCH 08/46] slt/sltu with limb style circuit --- ceno_zkvm/src/gadgets/mod.rs | 2 + ceno_zkvm/src/gadgets/signed_limbs.rs | 231 +++++++++++++++ ceno_zkvm/src/instructions/riscv/branch.rs | 12 - .../riscv/branch/branch_circuit_v2.rs | 270 ++---------------- ceno_zkvm/src/instructions/riscv/slt.rs | 145 ++-------- .../src/instructions/riscv/slt/slt_circuit.rs | 124 ++++++++ .../instructions/riscv/slt/slt_circuit_v2.rs | 107 +++++++ 7 files changed, 506 insertions(+), 385 deletions(-) create mode 100644 ceno_zkvm/src/gadgets/signed_limbs.rs create mode 100644 ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 630858fbc..6660a9339 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -2,6 +2,7 @@ mod div; mod is_lt; mod signed; mod signed_ext; +mod signed_limbs; pub use div::DivConfig; pub use gkr_iop::gadgets::{ @@ -10,3 +11,4 @@ pub use gkr_iop::gadgets::{ pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; +pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs new file mode 100644 index 000000000..269eae87e --- /dev/null +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -0,0 +1,231 @@ +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS, UInt}, +}; +use ff_ext::{ExtensionField, FieldInto, SmallField}; +use gkr_iop::error::CircuitBuilderError; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::FieldAlgebra; +use std::{array, marker::PhantomData}; +use witness::set_val; + +pub struct UIntLimbsLTConfig { + // Most significant limb of a and b respectively as a field element, will be range + // checked to be within [-32768, 32767) if signed and [0, 65536) if unsigned. + pub a_msb_f: WitIn, + pub b_msb_f: WitIn, + + // 1 at the most significant index i such that a[i] != b[i], otherwise 0. If such + // an i exists, diff_val = a[i] - b[i]. + pub diff_marker: [WitIn; UINT_LIMBS], + pub diff_val: WitIn, + + // 1 if a < b, 0 otherwise. + pub cmp_lt: WitIn, + phantom: PhantomData, +} + +impl UIntLimbsLTConfig { + pub fn is_lt(&self) -> Expression { + self.cmp_lt.expr() + } +} + +pub struct UIntLimbsLT { + phantom: PhantomData, +} + +impl UIntLimbsLT { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + a: &UInt, + b: &UInt, + is_signed: bool, + ) -> Result, ZKVMError> { + // 1 if a < b, 0 otherwise. + let cmp_lt = circuit_builder.create_bit(|| "cmp_lt")?; + + let a_expr = a.expr(); + let b_expr = b.expr(); + + let a_msb_f = circuit_builder.create_witin(|| "a_msb_f"); + let b_msb_f = circuit_builder.create_witin(|| "b_msb_f"); + let diff_marker: [WitIn; UINT_LIMBS] = array::from_fn(|i| { + circuit_builder + .create_bit(|| format!("diff_maker_{i}")) + .expect("create_bit_error") + }); + let diff_val = circuit_builder.create_witin(|| "diff_val"); + + // Check if a_msb_f and b_msb_f are signed values of a[NUM_LIMBS - 1] and b[NUM_LIMBS - 1] in prime field F. + let a_diff = a_expr[UINT_LIMBS - 1].expr() - a_msb_f.expr(); + let b_diff = b_expr[UINT_LIMBS - 1].expr() - b_msb_f.expr(); + + circuit_builder.require_zero( + || "a_diff", + a_diff.expr() + * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - a_diff.expr()), + )?; + circuit_builder.require_zero( + || "b_diff", + b_diff.expr() + * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - b_diff.expr()), + )?; + + let mut prefix_sum = Expression::ZERO; + + for i in (0..UINT_LIMBS).rev() { + let diff = (if i == UINT_LIMBS - 1 { + b_msb_f.expr() - a_msb_f.expr() + } else { + b_expr[i].expr() - a_expr[i].expr() + }) * (E::BaseField::from_canonical_u8(2).expr() * cmp_lt.expr() + - E::BaseField::ONE.expr()); + prefix_sum += diff_marker[i].expr(); + circuit_builder.require_zero( + || format!("prefix_diff_zero_{i}"), + (E::BaseField::ONE.expr() - prefix_sum.expr()) * diff.clone(), + )?; + circuit_builder.condition_require_zero( + || format!("diff_maker_conditional_equal_{i}"), + diff_marker[i].expr(), + diff_val.expr() - diff.expr(), + )?; + } + + // - If x != y, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index where diff != 0. + // Constrains that diff == diff_val where diff_val is non-zero. + // - If x == y, then prefix_sum = 0 and cmp_lt = 0. + // Here, prefix_sum cannot be 1 because all diff are zero, making diff == diff_val fails. + + circuit_builder.assert_bit(|| "prefix_sum_bit", prefix_sum.expr())?; + circuit_builder.condition_require_zero( + || "cmp_lt_conditional_zero", + E::BaseField::ONE.expr() - prefix_sum.expr(), + cmp_lt.expr(), + )?; + + // Range check to ensure diff_val is non-zero. + circuit_builder.assert_ux::<_, _, LIMB_BITS>( + || "diff_val is non-zero", + prefix_sum.expr() * (diff_val.expr() - E::BaseField::ONE.expr()), + )?; + + circuit_builder.assert_ux::<_, _, LIMB_BITS>( + || "a_msb_f_signed_range_check", + a_msb_f.expr() + + if is_signed { + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() + } else { + Expression::ZERO + }, + )?; + + circuit_builder.assert_ux::<_, _, LIMB_BITS>( + || "b_msb_f_signed_range_check", + b_msb_f.expr() + + if is_signed { + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() + } else { + Expression::ZERO + }, + )?; + + Ok(UIntLimbsLTConfig { + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + cmp_lt, + phantom: PhantomData, + }) + } + + pub fn assign( + config: &UIntLimbsLTConfig, + instance: &mut [E::BaseField], + lkm: &mut gkr_iop::utils::lk_multiplicity::LkMultiplicity, + a: &[u16], + b: &[u16], + is_signed: bool, + ) -> Result<(), CircuitBuilderError> { + let (cmp_lt, diff_idx, a_sign, b_sign) = run_cmp(is_signed, a, b); + config + .diff_marker + .iter() + .enumerate() + .for_each(|(i, witin)| { + set_val!(instance, witin, (i == diff_idx) as u64); + }); + set_val!(instance, config.cmp_lt, cmp_lt as u64); + + // We range check (read_rs1_msb_f + 128) and (read_rs2_msb_f + 128) if signed, + // read_rs1_msb_f and read_rs2_msb_f if not + let (a_msb_f, a_msb_range) = if a_sign { + ( + -E::BaseField::from_canonical_u32((1 << LIMB_BITS) - a[UINT_LIMBS - 1] as u32), + a[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + ) + } else { + ( + E::BaseField::from_canonical_u16(a[UINT_LIMBS - 1]), + a[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + ) + }; + let (b_msb_f, b_msb_range) = if b_sign { + ( + -E::BaseField::from_canonical_u32((1 << LIMB_BITS) - b[UINT_LIMBS - 1] as u32), + b[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + ) + } else { + ( + E::BaseField::from_canonical_u16(b[UINT_LIMBS - 1]), + b[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + ) + }; + + set_val!(instance, config.a_msb_f, a_msb_f); + set_val!(instance, config.b_msb_f, b_msb_f); + + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == (UINT_LIMBS - 1) { + if cmp_lt { + b_msb_f - a_msb_f + } else { + a_msb_f - b_msb_f + } + .to_canonical_u64() as u16 + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + set_val!(instance, config.diff_val, diff_val as u64); + + if diff_idx != UINT_LIMBS { + lkm.assert_ux::((diff_val - 1) as u64); + } else { + lkm.assert_ux::(0); + } + + lkm.assert_ux::(a_msb_range as u64); + lkm.assert_ux::(b_msb_range as u64); + + Ok(()) + } +} + +// returns (cmp_lt, diff_idx, x_sign, y_sign) +// cmp_lt = true if a < b else false +pub fn run_cmp(signed: bool, x: &[u16], y: &[u16]) -> (bool, usize, bool, bool) { + let x_sign = (x[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; + let y_sign = (y[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; + for i in (0..UINT_LIMBS).rev() { + if x[i] != y[i] { + return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign); + } + } + (false, UINT_LIMBS, x_sign, y_sign) +} diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index 8afd52ea0..dc2c8c9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -10,20 +10,12 @@ pub struct BeqOp; impl RIVInstruction for BeqOp { const INST_KIND: InsnKind = InsnKind::BEQ; } -#[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 -pub type BeqInstruction = branch_circuit::BranchCircuit; -#[cfg(not(feature = "u16limb_circuit"))] pub type BeqInstruction = branch_circuit::BranchCircuit; pub struct BneOp; impl RIVInstruction for BneOp { const INST_KIND: InsnKind = InsnKind::BNE; } -#[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 -pub type BneInstruction = branch_circuit::BranchCircuit; -#[cfg(not(feature = "u16limb_circuit"))] pub type BneInstruction = branch_circuit::BranchCircuit; pub struct BltuOp; @@ -31,7 +23,6 @@ impl RIVInstruction for BltuOp { const INST_KIND: InsnKind = InsnKind::BLTU; } #[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 pub type BltuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltuInstruction = branch_circuit::BranchCircuit; @@ -41,7 +32,6 @@ impl RIVInstruction for BgeuOp { const INST_KIND: InsnKind = InsnKind::BGEU; } #[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 pub type BgeuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BgeuInstruction = branch_circuit::BranchCircuit; @@ -51,7 +41,6 @@ impl RIVInstruction for BltOp { const INST_KIND: InsnKind = InsnKind::BLT; } #[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 pub type BltInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltInstruction = branch_circuit::BranchCircuit; @@ -61,7 +50,6 @@ impl RIVInstruction for BgeOp { const INST_KIND: InsnKind = InsnKind::BGE; } #[cfg(feature = "u16limb_circuit")] -// TODO use branch_circuit_v2 pub type BgeInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BgeInstruction = branch_circuit::BranchCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 57415e31d..6fc4e121e 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -2,23 +2,18 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, + gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ Instruction, - riscv::{ - RIVInstruction, - b_insn::BInstructionConfig, - constants::{LIMB_BITS, UINT_LIMBS, UInt}, - }, + riscv::{RIVInstruction, b_insn::BInstructionConfig, constants::UInt}, }, structs::ProgramParams, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::{ExtensionField, FieldInto, SmallField}; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; -use std::{array, marker::PhantomData}; -use witness::set_val; +use ff_ext::ExtensionField; +use multilinear_extensions::Expression; +use std::marker::PhantomData; pub struct BranchCircuit(PhantomData<(E, I)>); @@ -27,19 +22,7 @@ pub struct BranchConfig { pub read_rs1: UInt, pub read_rs2: UInt, - // Most significant limb of a and b respectively as a field element, will be range - // checked to be within [-128, 127) if signed and [0, 256) if unsigned. - pub read_rs1_msb_f: WitIn, - pub read_rs2_msb_f: WitIn, - - // 1 at the most significant index i such that read_rs1[i] != read_rs2[i], otherwise 0. If such - // an i exists, diff_val = read_rs2[i] - read_rs1[i]. - pub diff_marker: [WitIn; UINT_LIMBS], - pub diff_val: WitIn, - - // 1 if read_rs1 < read_rs2, 0 otherwise. - pub cmp_lt: WitIn, - + pub uint_lt_config: UIntLimbsLTConfig, phantom: PhantomData, } @@ -55,138 +38,18 @@ impl Instruction for BranchCircuit, _param: &ProgramParams, ) -> Result { - // 1 if a < b, 0 otherwise. - let cmp_lt = circuit_builder.create_bit(|| "cmp_lt")?; - let read_rs1 = UInt::new_unchecked(|| "rs1_limbs", circuit_builder)?; let read_rs2 = UInt::new_unchecked(|| "rs2_limbs", circuit_builder)?; - let read_rs1_expr = read_rs1.expr(); - let read_rs2_expr = read_rs2.expr(); - - let read_rs1_msb_f = circuit_builder.create_witin(|| "read_rs1_msb_f"); - let read_rs2_msb_f = circuit_builder.create_witin(|| "read_rs2_msb_f"); - let diff_marker: [WitIn; UINT_LIMBS] = array::from_fn(|i| { - circuit_builder - .create_bit(|| format!("diff_maker_{i}")) - .expect("create_bit_error") - }); - let diff_val = circuit_builder.create_witin(|| "diff_val"); - - // Check if read_rs1_msb_f and read_rs2_msb_f are signed values of read_rs1[NUM_LIMBS - 1] and read_rs2[NUM_LIMBS - 1] in prime field F. - let read_rs1_diff = read_rs1_expr[UINT_LIMBS - 1].expr() - read_rs1_msb_f.expr(); - let read_rs2_diff = read_rs2_expr[UINT_LIMBS - 1].expr() - read_rs2_msb_f.expr(); - - circuit_builder.require_zero( - || "read_rs1_diff", - read_rs1_diff.expr() - * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - read_rs1_diff.expr()), - )?; - circuit_builder.require_zero( - || "read_rs2_diff", - read_rs2_diff.expr() - * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - read_rs2_diff.expr()), - )?; - - let mut prefix_sum = Expression::ZERO; - - for i in (0..UINT_LIMBS).rev() { - let diff = (if i == UINT_LIMBS - 1 { - read_rs2_msb_f.expr() - read_rs1_msb_f.expr() - } else { - read_rs2_expr[i].expr() - read_rs1_expr[i].expr() - }) * (E::BaseField::from_canonical_u8(2).expr() * cmp_lt.expr() - - E::BaseField::ONE.expr()); - prefix_sum += diff_marker[i].expr(); - circuit_builder.require_zero( - || format!("prefix_diff_zero_{i}"), - (E::BaseField::ONE.expr() - prefix_sum.expr()) * diff.clone(), - )?; - circuit_builder.condition_require_zero( - || format!("diff_maker_conditional_equal_{i}"), - diff_marker[i].expr(), - diff_val.expr() - diff.expr(), - )?; - } - - // - If x != y, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index where diff != 0. - // Constrains that diff == diff_val where diff_val is non-zero. - // - If x == y, then prefix_sum = 0 and cmp_lt = 0. - // Here, prefix_sum cannot be 1 because all diff are zero, making diff == diff_val fails. - - circuit_builder.assert_bit(|| "prefix_sum_bit", prefix_sum.expr())?; - circuit_builder.condition_require_zero( - || "cmp_lt_conditional_zero", - E::BaseField::ONE.expr() - prefix_sum.expr(), - cmp_lt.expr(), - )?; - - // Range check to ensure diff_val is non-zero. - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "diff_val is non-zero", - prefix_sum.expr() * (diff_val.expr() - E::BaseField::ONE.expr()), - )?; - - let branch_taken_bit = match I::INST_KIND { - InsnKind::BLT => { - // Check if read_rs1_msb_f and read_rs2_msb_f are in [-32768, 32767) if signed, [0, 65536) if unsigned. - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs1_msb_f_signed_range_check", - read_rs1_msb_f.expr() - + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), - )?; - - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs2_msb_f_signed_range_check", - read_rs2_msb_f.expr() - + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), - )?; - cmp_lt.expr() - } - InsnKind::BGE => { - // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs1_msb_f_signed_range_check", - read_rs1_msb_f.expr() - + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), - )?; - - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs2_msb_f_signed_range_check", - read_rs2_msb_f.expr() - + E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(), - )?; - Expression::ONE - cmp_lt.expr() - } - InsnKind::BLTU => { - // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs1_msb_f_signed_range_check", - read_rs1_msb_f.expr(), - )?; - - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs2_msb_f_signed_range_check", - read_rs2_msb_f.expr(), - )?; - cmp_lt.expr() - } - InsnKind::BGEU => { - // Check if read_rs1_msb_f and read_rs2_msb_f are in [-128, 127) if signed, [0, 256) if unsigned. - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs1_msb_f_signed_range_check", - read_rs1_msb_f.expr(), - )?; - - circuit_builder.assert_ux::<_, _, LIMB_BITS>( - || "read_rs2_msb_f_signed_range_check", - read_rs2_msb_f.expr(), - )?; - Expression::ONE - cmp_lt.expr() - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + let is_signed = matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE); + let is_ge = matches!(I::INST_KIND, InsnKind::BGEU | InsnKind::BGE); + let uint_lt_config = + UIntLimbsLT::::construct_circuit(circuit_builder, &read_rs1, &read_rs2, is_signed)?; + let branch_taken_bit = if is_ge { + Expression::ONE - uint_lt_config.is_lt() + } else { + uint_lt_config.is_lt() }; - let b_insn = BInstructionConfig::construct_circuit( circuit_builder, I::INST_KIND, @@ -199,12 +62,7 @@ impl Instruction for BranchCircuit Instruction for BranchCircuit(step.insn.kind, rs1_limbs, rs2_limbs); - config - .diff_marker - .iter() - .enumerate() - .for_each(|(i, witin)| { - set_val!(instance, witin, (i == diff_idx) as u64); - }); - let is_signed = matches!(step.insn().kind, InsnKind::BLT | InsnKind::BGE); - let is_ge = matches!(step.insn().kind, InsnKind::BGE | InsnKind::BGEU); - - let cmp_lt = cmp_result ^ is_ge; - set_val!(instance, config.cmp_lt, cmp_lt as u64); - - // We range check (read_rs1_msb_f + 128) and (read_rs2_msb_f + 128) if signed, - // read_rs1_msb_f and read_rs2_msb_f if not - let (read_rs1_msb_f, a_msb_range) = if rs1_sign { - ( - -E::BaseField::from_canonical_u32( - (1 << LIMB_BITS) - rs1_limbs[UINT_LIMBS - 1] as u32, - ), - rs1_limbs[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), - ) - } else { - ( - E::BaseField::from_canonical_u16(rs1_limbs[UINT_LIMBS - 1]), - rs1_limbs[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), - ) - }; - let (read_rs2_msb_f, b_msb_range) = if rs2_sign { - ( - -E::BaseField::from_canonical_u32( - (1 << LIMB_BITS) - rs2_limbs[UINT_LIMBS - 1] as u32, - ), - rs2_limbs[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), - ) - } else { - ( - E::BaseField::from_canonical_u16(rs2_limbs[UINT_LIMBS - 1]), - rs2_limbs[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), - ) - }; - - set_val!(instance, config.read_rs1_msb_f, read_rs1_msb_f); - set_val!(instance, config.read_rs2_msb_f, read_rs2_msb_f); - - let diff_val = if diff_idx == UINT_LIMBS { - 0 - } else if diff_idx == (UINT_LIMBS - 1) { - if cmp_lt { - read_rs2_msb_f - read_rs1_msb_f - } else { - read_rs1_msb_f - read_rs2_msb_f - } - .to_canonical_u64() as u16 - } else if cmp_lt { - rs2_limbs[diff_idx] - rs1_limbs[diff_idx] - } else { - rs1_limbs[diff_idx] - rs2_limbs[diff_idx] - }; - set_val!(instance, config.diff_val, diff_val as u64); - - if diff_idx != UINT_LIMBS { - lk_multiplicity.assert_ux::((diff_val - 1) as u64); - } else { - lk_multiplicity.assert_ux::(0); - } - - lk_multiplicity.assert_ux::(a_msb_range as u64); - lk_multiplicity.assert_ux::(b_msb_range as u64); - + UIntLimbsLT::::assign( + &config.uint_lt_config, + instance, + lk_multiplicity, + rs1_limbs, + rs2_limbs, + is_signed, + )?; Ok(()) } } - -// returns (cmp_result, diff_idx, x_sign, y_sign) -pub(super) fn run_cmp( - local_opcode: InsnKind, - x: &[u16], - y: &[u16], -) -> (bool, usize, bool, bool) { - let signed = matches!(local_opcode, InsnKind::BLT | InsnKind::BGE); - let ge_op = matches!(local_opcode, InsnKind::BGE | InsnKind::BGEU); - let x_sign = (x[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; - let y_sign = (y[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; - for i in (0..UINT_LIMBS).rev() { - if x[i] != y[i] { - return ((x[i] < y[i]) ^ x_sign ^ y_sign ^ ge_op, i, x_sign, y_sign); - } - } - (ge_op, UINT_LIMBS, x_sign, y_sign) -} diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 34e99c073..81213b66f 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -1,139 +1,27 @@ -use std::marker::PhantomData; +mod slt_circuit; +mod slt_circuit_v2; -use ceno_emul::{InsnKind, SWord, StepRecord}; -use ff_ext::ExtensionField; +use ceno_emul::InsnKind; -use super::{ - RIVInstruction, - constants::{UINT_LIMBS, UInt}, - r_insn::RInstructionConfig, -}; -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - gadgets::{IsLtConfig, SignedLtConfig}, - instructions::Instruction, - structs::ProgramParams, - uint::Value, - witness::LkMultiplicity, -}; - -pub struct SetLessThanInstruction(PhantomData<(E, I)>); +use super::RIVInstruction; pub struct SltOp; impl RIVInstruction for SltOp { const INST_KIND: InsnKind = InsnKind::SLT; } -pub type SltInstruction = SetLessThanInstruction; +#[cfg(feature = "u16limb_circuit")] +pub type SltInstruction = slt_circuit_v2::SetLessThanInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +pub type SltInstruction = slt_circuit::SetLessThanInstruction; pub struct SltuOp; impl RIVInstruction for SltuOp { const INST_KIND: InsnKind = InsnKind::SLTU; } -pub type SltuInstruction = SetLessThanInstruction; - -/// This config handles R-Instructions that represent registers values as 2 * u16. -pub struct SetLessThanConfig { - r_insn: RInstructionConfig, - - rs1_read: UInt, - rs2_read: UInt, - #[cfg_attr(not(test), allow(dead_code))] - rd_written: UInt, - - deps: SetLessThanDependencies, -} - -enum SetLessThanDependencies { - Slt { signed_lt: SignedLtConfig }, - Sltu { is_lt: IsLtConfig }, -} - -impl Instruction for SetLessThanInstruction { - type InstructionConfig = SetLessThanConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0 - let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let rs2_read = UInt::new_unchecked(|| "rs2_read", cb)?; - - let (deps, rd_written) = match I::INST_KIND { - InsnKind::SLT => { - let signed_lt = - SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; - let rd_written = UInt::from_exprs_unchecked(vec![signed_lt.expr()]); - (SetLessThanDependencies::Slt { signed_lt }, rd_written) - } - InsnKind::SLTU => { - let is_lt = IsLtConfig::construct_circuit( - cb, - || "rs1 < rs2", - rs1_read.value(), - rs2_read.value(), - UINT_LIMBS, - )?; - let rd_written = UInt::from_exprs_unchecked(vec![is_lt.expr()]); - (SetLessThanDependencies::Sltu { is_lt }, rd_written) - } - _ => unreachable!(), - }; - - let r_insn = RInstructionConfig::::construct_circuit( - cb, - I::INST_KIND, - rs1_read.register_expr(), - rs2_read.register_expr(), - rd_written.register_expr(), - )?; - - Ok(SetLessThanConfig { - r_insn, - rs1_read, - rs2_read, - rd_written, - deps, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lkm: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; - - let rs1 = step.rs1().unwrap().value; - let rs2 = step.rs2().unwrap().value; - - let rs1_read = Value::new_unchecked(rs1); - let rs2_read = Value::new_unchecked(rs2); - config - .rs1_read - .assign_limbs(instance, rs1_read.as_u16_limbs()); - config - .rs2_read - .assign_limbs(instance, rs2_read.as_u16_limbs()); - - match &config.deps { - SetLessThanDependencies::Slt { signed_lt } => { - signed_lt.assign_instance(instance, lkm, rs1 as SWord, rs2 as SWord)? - } - SetLessThanDependencies::Sltu { is_lt } => { - is_lt.assign_instance(instance, lkm, rs1.into(), rs2.into())? - } - } - - Ok(()) - } -} +#[cfg(feature = "u16limb_circuit")] +pub type SltuInstruction = slt_circuit_v2::SetLessThanInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +pub type SltuInstruction = slt_circuit::SetLessThanInstruction; #[cfg(test)] mod test { @@ -144,10 +32,16 @@ mod test { use super::*; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, + structs::ProgramParams, }; + #[cfg(not(feature = "u16limb_circuit"))] + use slt_circuit::SetLessThanInstruction; + #[cfg(feature = "u16limb_circuit")] + use slt_circuit_v2::SetLessThanInstruction; fn verify(name: &'static str, rs1: Word, rs2: Word, rd: Word) { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -232,7 +126,6 @@ mod test { let mut rng = rand::thread_rng(); let a: i32 = rng.next_u32() as i32; let b: i32 = rng.next_u32() as i32; - println!("random: {} ("random 1", a as Word, b as Word, (a < b) as u32); verify::("random 2", b as Word, a as Word, (a >= b) as u32); } diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs new file mode 100644 index 000000000..38cd20e85 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -0,0 +1,124 @@ +use crate::{ + Value, + error::ZKVMError, + gadgets::SignedLtConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{UINT_LIMBS, UInt}, + r_insn::RInstructionConfig, + }, + }, + structs::ProgramParams, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, SWord, StepRecord}; +use ff_ext::ExtensionField; +use gkr_iop::{circuit_builder::CircuitBuilder, gadgets::IsLtConfig}; +use std::marker::PhantomData; + +pub struct SetLessThanInstruction(PhantomData<(E, I)>); + +/// This config handles R-Instructions that represent registers values as 2 * u16. +pub struct SetLessThanConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + #[allow(dead_code)] + pub(crate) rd_written: UInt, + + deps: SetLessThanDependencies, +} + +enum SetLessThanDependencies { + Slt { signed_lt: SignedLtConfig }, + Sltu { is_lt: IsLtConfig }, +} + +impl Instruction for SetLessThanInstruction { + type InstructionConfig = SetLessThanConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", cb)?; + + let (deps, rd_written) = match I::INST_KIND { + InsnKind::SLT => { + let signed_lt = + SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; + let rd_written = UInt::from_exprs_unchecked(vec![signed_lt.expr()]); + (SetLessThanDependencies::Slt { signed_lt }, rd_written) + } + InsnKind::SLTU => { + let is_lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < rs2", + rs1_read.value(), + rs2_read.value(), + UINT_LIMBS, + )?; + let rd_written = UInt::from_exprs_unchecked(vec![is_lt.expr()]); + (SetLessThanDependencies::Sltu { is_lt }, rd_written) + } + _ => unreachable!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + cb, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + Ok(SetLessThanConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + deps, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.r_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let rs2 = step.rs2().unwrap().value; + + let rs1_read = Value::new_unchecked(rs1); + let rs2_read = Value::new_unchecked(rs2); + config + .rs1_read + .assign_limbs(instance, rs1_read.as_u16_limbs()); + config + .rs2_read + .assign_limbs(instance, rs2_read.as_u16_limbs()); + + match &config.deps { + SetLessThanDependencies::Slt { signed_lt } => { + signed_lt.assign_instance(instance, lkm, rs1 as SWord, rs2 as SWord)? + } + SetLessThanDependencies::Sltu { is_lt } => { + is_lt.assign_instance(instance, lkm, rs1.into(), rs2.into())? + } + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs new file mode 100644 index 000000000..306690f4e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -0,0 +1,107 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + instructions::{ + Instruction, + riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, + }, + structs::ProgramParams, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use std::marker::PhantomData; + +pub struct SetLessThanInstruction(PhantomData<(E, I)>); + +/// This config handles R-Instructions that represent registers values as 2 * u16. +pub struct SetLessThanConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + #[cfg_attr(not(test), allow(dead_code))] + pub(crate) rd_written: UInt, + + uint_lt_config: UIntLimbsLTConfig, +} +impl Instruction for SetLessThanInstruction { + type InstructionConfig = SetLessThanConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // If rs1_read < rs2_read, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", cb)?; + + let (rd_written, uint_lt_config) = match I::INST_KIND { + InsnKind::SLT => { + let config = UIntLimbsLT::construct_circuit(cb, &rs1_read, &rs2_read, true)?; + let rd_written = UInt::from_exprs_unchecked(vec![config.is_lt()]); + (rd_written, config) + } + InsnKind::SLTU => { + let config = UIntLimbsLT::construct_circuit(cb, &rs1_read, &rs2_read, false)?; + let rd_written = UInt::from_exprs_unchecked(vec![config.is_lt()]); + (rd_written, config) + } + _ => unreachable!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + cb, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + Ok(SetLessThanConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + uint_lt_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.r_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let rs2 = step.rs2().unwrap().value; + + let rs1_read = Value::new_unchecked(rs1); + let rs2_read = Value::new_unchecked(rs2); + config + .rs1_read + .assign_limbs(instance, rs1_read.as_u16_limbs()); + config + .rs2_read + .assign_limbs(instance, rs2_read.as_u16_limbs()); + + let is_signed = matches!(step.insn().kind, InsnKind::SLT); + UIntLimbsLT::::assign( + &config.uint_lt_config, + instance, + lkm, + rs1_read.as_u16_limbs(), + rs2_read.as_u16_limbs(), + is_signed, + )?; + Ok(()) + } +} From 6a44ba35843a04a059563ba0168c6c47b3680e97 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 21:45:13 +0800 Subject: [PATCH 09/46] add ci steps --- .github/workflows/integration.yml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index b275433f4..c4216ea6b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -3,7 +3,7 @@ name: Integrations on: merge_group: pull_request: - types: [synchronize, opened, reopened, ready_for_review] + types: [ synchronize, opened, reopened, ready_for_review ] push: branches: - master @@ -14,7 +14,7 @@ concurrency: jobs: skip_check: - runs-on: [self-hosted, Linux, X64] + runs-on: [ self-hosted, Linux, X64 ] outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} steps: @@ -27,14 +27,14 @@ jobs: do_not_skip: '["pull_request", "workflow_dispatch", "schedule", "merge_group"]' integration: - needs: [skip_check] + needs: [ skip_check ] if: | github.event.pull_request.draft == false && (github.event.action == 'ready_for_review' || needs.skip_check.outputs.should_skip != 'true') name: Integration testing timeout-minutes: 30 - runs-on: [self-hosted, Linux, X64] + runs-on: [ self-hosted, Linux, X64 ] steps: - uses: actions/checkout@v4 @@ -54,6 +54,13 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci + - name: Run fibonacci (debug) + feature u16limb_circuit + env: + RUST_LOG: debug + RUSTFLAGS: "-C opt-level=3" + run: cargo run --package ceno_zkvm --features u16limb_circuit --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci + + - name: Run fibonacci (release) env: RUSTFLAGS: "-C opt-level=3" From f72ffb628de6c4195af63c5270f76d79b3d1e073 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 21:47:16 +0800 Subject: [PATCH 10/46] update comments --- ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/gadgets/signed_limbs.rs | 1 + ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs | 1 - ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 43e8855ed..29191b4ef 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,7 +67,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow", "u16limb_circuit"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs index 269eae87e..a73a72b8c 100644 --- a/ceno_zkvm/src/gadgets/signed_limbs.rs +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -1,3 +1,4 @@ +/// circuit implementation refer from https://github.com/openvm-org/openvm/blob/ca36de3803213da664b03d111801ab903d55e360/extensions/rv32im/circuit/src/branch_lt/core.rs use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 6fc4e121e..94abb56d1 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -33,7 +33,6 @@ impl Instruction for BranchCircuit, _param: &ProgramParams, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 306690f4e..391dffb89 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -22,7 +22,7 @@ pub struct SetLessThanConfig { rs1_read: UInt, rs2_read: UInt, - #[cfg_attr(not(test), allow(dead_code))] + #[allow(dead_code)] pub(crate) rd_written: UInt, uint_lt_config: UIntLimbsLTConfig, From 6ba167b64a0fa70238177b0b7823b3028c9ed465 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 7 Aug 2025 23:29:00 +0800 Subject: [PATCH 11/46] wip --- ceno_zkvm/src/instructions/riscv/slti.rs | 107 +-------------- .../instructions/riscv/slti/slti_circuit.rs | 127 +++++++++++++++++ .../riscv/slti/slti_circuit_v2.rs | 129 ++++++++++++++++++ 3 files changed, 260 insertions(+), 103 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 736d157b1..e58c4a522 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,3 +1,6 @@ +mod slti_circuit; +mod slti_circuit_v2; + use std::marker::PhantomData; use ceno_emul::{InsnKind, SWord, StepRecord, Word}; @@ -21,22 +24,7 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; - -#[derive(Debug)] -pub struct SetLessThanImmConfig { - i_insn: IInstructionConfig, - - rs1_read: UInt, - imm: WitIn, - #[allow(dead_code)] - rd_written: UInt, - lt: IsLtConfig, - - // SLTI - is_rs1_neg: Option>, -} - -pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); +use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; pub struct SltiOp; impl RIVInstruction for SltiOp { @@ -50,93 +38,6 @@ impl RIVInstruction for SltiuOp { } pub type SltiuInstruction = SetLessThanImmInstruction; -impl Instruction for SetLessThanImmInstruction { - type InstructionConfig = SetLessThanImmConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 - let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; - let imm = cb.create_witin(|| "imm"); - - let (value_expr, is_rs1_neg) = match I::INST_KIND { - InsnKind::SLTIU => (rs1_read.value(), None), - InsnKind::SLTI => { - let is_rs1_neg = rs1_read.is_negative(cb)?; - (rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg)) - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - let lt = - IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); - - let i_insn = IInstructionConfig::::construct_circuit( - cb, - I::INST_KIND, - imm.expr(), - rs1_read.register_expr(), - rd_written.register_expr(), - false, - )?; - - Ok(SetLessThanImmConfig { - i_insn, - rs1_read, - imm, - rd_written, - is_rs1_neg, - lt, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [E::BaseField], - lkm: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; - - let rs1 = step.rs1().unwrap().value; - let rs1_value = Value::new_unchecked(rs1 as Word); - config - .rs1_read - .assign_value(instance, Value::new_unchecked(rs1)); - - let imm = InsnRecord::imm_internal(&step.insn()); - set_val!(instance, config.imm, i64_to_base::(imm)); - - match I::INST_KIND { - InsnKind::SLTIU => { - config - .lt - .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; - } - InsnKind::SLTI => { - config.is_rs1_neg.as_ref().unwrap().assign_instance( - instance, - lkm, - *rs1_value.as_u16_limbs().last().unwrap() as u64, - )?; - let (rs1, imm) = (rs1 as SWord, imm as SWord); - config - .lt - .assign_instance_signed(instance, lkm, rs1 as i64, imm as i64)?; - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - } - - Ok(()) - } -} #[cfg(test)] mod test { diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs new file mode 100644 index 000000000..73badf3a4 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -0,0 +1,127 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; +use ff_ext::ExtensionField; +use gkr_iop::{gadgets::IsLtConfig, utils::i64_to_base}; +use multilinear_extensions::{ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +#[derive(Debug)] +pub struct SetLessThanImmConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + #[allow(dead_code)] + rd_written: UInt, + lt: IsLtConfig, + + // SLTI + is_rs1_neg: Option>, +} + +pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); + +impl Instruction for SetLessThanImmInstruction { + type InstructionConfig = SetLessThanImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let imm = cb.create_witin(|| "imm"); + + let (value_expr, is_rs1_neg) = match I::INST_KIND { + InsnKind::SLTIU => (rs1_read.value(), None), + InsnKind::SLTI => { + let is_rs1_neg = rs1_read.is_negative(cb)?; + (rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg)) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let lt = + IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + I::INST_KIND, + imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(SetLessThanImmConfig { + i_insn, + rs1_read, + imm, + rd_written, + is_rs1_neg, + lt, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.i_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let rs1_value = Value::new_unchecked(rs1 as Word); + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + + let imm = InsnRecord::imm_internal(&step.insn()); + set_val!(instance, config.imm, i64_to_base::(imm)); + + match I::INST_KIND { + InsnKind::SLTIU => { + config + .lt + .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; + } + InsnKind::SLTI => { + config.is_rs1_neg.as_ref().unwrap().assign_instance( + instance, + lkm, + *rs1_value.as_u16_limbs().last().unwrap() as u64, + )?; + let (rs1, imm) = (rs1 as SWord, imm as SWord); + config + .lt + .assign_instance_signed(instance, lkm, rs1 as i64, imm as i64)?; + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs new file mode 100644 index 000000000..f13abc33f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -0,0 +1,129 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::{SignedExtendConfig, UIntLimbsLT, UIntLimbsLTConfig}, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; +use ff_ext::ExtensionField; +use gkr_iop::{gadgets::IsLtConfig, utils::i64_to_base}; +use multilinear_extensions::{ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +#[derive(Debug)] +pub struct SetLessThanImmConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + #[allow(dead_code)] + rd_written: UInt, + + uint_lt_config: UIntLimbsLTConfig, +} + +pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); + +impl Instruction for SetLessThanImmInstruction { + type InstructionConfig = SetLessThanImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let imm = cb.create_witin(|| "imm"); + let imm_uint = UInt::from_exprs_unchecked(vec![imm.expr()]); + + let uint_lt_config = match I::INST_KIND { + InsnKind::SLTIU => UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_uint, false)?, + InsnKind::SLTI => UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_uint, true)?, + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let rd_written = UInt::from_exprs_unchecked(vec![uint_lt_config.is_lt()]); + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + I::INST_KIND, + imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(SetLessThanImmConfig { + i_insn, + rs1_read, + imm, + rd_written, + uint_lt_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.i_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let rs1_value = Value::new_unchecked(rs1 as Word); + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + + let imm = InsnRecord::imm_internal(&step.insn()); + set_val!(instance, config.imm, i64_to_base::(imm)); + + let is_signed = matches!(step.insn().kind, InsnKind::SLT); + UIntLimbsLT::::assign( + &config.uint_lt_config, + instance, + lkm, + rs1_read.as_u16_limbs(), + rs2_read.as_u16_limbs(), + is_signed, + )?; + match I::INST_KIND { + InsnKind::SLTIU => { + config + .lt + .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; + } + InsnKind::SLTI => { + config.is_rs1_neg.as_ref().unwrap().assign_instance( + instance, + lkm, + *rs1_value.as_u16_limbs().last().unwrap() as u64, + )?; + let (rs1, imm) = (rs1 as SWord, imm as SWord); + config + .lt + .assign_instance_signed(instance, lkm, rs1 as i64, imm as i64)?; + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + } + + Ok(()) + } +} From 14c5c39cec324fec3f1f9b989e3bec68846cc7e8 Mon Sep 17 00:00:00 2001 From: Wu Sung-Ming Date: Fri, 8 Aug 2025 15:43:10 +0800 Subject: [PATCH 12/46] finish slti/sltiu logic --- ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/gadgets/signed_limbs.rs | 1 + ceno_zkvm/src/instructions/riscv/i_insn.rs | 3 + ceno_zkvm/src/instructions/riscv/slti.rs | 13 ++- .../instructions/riscv/slti/slti_circuit.rs | 7 +- .../riscv/slti/slti_circuit_v2.rs | 86 +++++++++++-------- ceno_zkvm/src/tables/program.rs | 73 +++++++++++++--- ceno_zkvm/src/utils.rs | 32 ++++++- 8 files changed, 164 insertions(+), 53 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 29191b4ef..43e8855ed 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,7 +67,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow"] +default = ["forbid_overflow", "u16limb_circuit"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs index a73a72b8c..f79710c91 100644 --- a/ceno_zkvm/src/gadgets/signed_limbs.rs +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -11,6 +11,7 @@ use p3::field::FieldAlgebra; use std::{array, marker::PhantomData}; use witness::set_val; +#[derive(Debug)] pub struct UIntLimbsLTConfig { // Most significant limb of a and b respectively as a field element, will be range // checked to be within [-32768, 32767) if signed and [0, 65536) if unsigned. diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index f3d13d408..4f51fde7d 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -28,6 +28,7 @@ impl IInstructionConfig { circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, imm: Expression, + #[cfg(feature = "u16limb_circuit")] imm_sign: Expression, rs1_read: RegisterExpr, rd_written: RegisterExpr, branching: bool, @@ -49,6 +50,8 @@ impl IInstructionConfig { rs1.id.expr(), 0.into(), imm.clone(), + #[cfg(feature = "u16limb_circuit")] + imm_sign, ))?; Ok(IInstructionConfig { vm_state, rs1, rd }) diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index e58c4a522..a164a59cd 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,6 +1,15 @@ -mod slti_circuit; +#[cfg(feature = "u16limb_circuit")] mod slti_circuit_v2; +#[cfg(not(feature = "u16limb_circuit"))] +mod slti_circuit; + +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmInstruction; + +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; + use std::marker::PhantomData; use ceno_emul::{InsnKind, SWord, StepRecord, Word}; @@ -24,7 +33,6 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; -use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; pub struct SltiOp; impl RIVInstruction for SltiOp { @@ -38,7 +46,6 @@ impl RIVInstruction for SltiuOp { } pub type SltiuInstruction = SetLessThanImmInstruction; - #[cfg(test)] mod test { use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, encode_rv32}; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 73badf3a4..114f04bc2 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -16,9 +16,10 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, SWord, StepRecord, Word}; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::{gadgets::IsLtConfig, utils::i64_to_base}; use multilinear_extensions::{ToExpr, WitIn}; +use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; @@ -29,7 +30,7 @@ pub struct SetLessThanImmConfig { rs1_read: UInt, imm: WitIn, #[allow(dead_code)] - rd_written: UInt, + pub(crate) rd_written: UInt, lt: IsLtConfig, // SLTI @@ -70,6 +71,8 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm.expr(), + #[cfg(feature = "u16limb_circuit")] + E::BaseField::ZERO.expr(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index f13abc33f..a953b4bbf 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -2,7 +2,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - gadgets::{SignedExtendConfig, UIntLimbsLT, UIntLimbsLTConfig}, + gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ Instruction, riscv::{ @@ -12,13 +12,13 @@ use crate::{ }, }, structs::ProgramParams, - tables::InsnRecord, + utils::{imm_sign_extend, imm_sign_extend_circuit}, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, SWord, StepRecord, Word}; -use ff_ext::ExtensionField; -use gkr_iop::{gadgets::IsLtConfig, utils::i64_to_base}; +use ceno_emul::{InsnKind, StepRecord, Word}; +use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; +use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; @@ -28,8 +28,10 @@ pub struct SetLessThanImmConfig { rs1_read: UInt, imm: WitIn, + // 0 positive, 1 negative + imm_sign: Option, #[allow(dead_code)] - rd_written: UInt, + pub(crate) rd_written: UInt, uint_lt_config: UIntLimbsLTConfig, } @@ -47,14 +49,34 @@ impl Instruction for SetLessThanImmInst cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { + assert_eq!(UINT_LIMBS, 2); // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; let imm = cb.create_witin(|| "imm"); - let imm_uint = UInt::from_exprs_unchecked(vec![imm.expr()]); - let uint_lt_config = match I::INST_KIND { - InsnKind::SLTIU => UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_uint, false)?, - InsnKind::SLTI => UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_uint, true)?, + let (uint_lt_config, imm_sign_extend, imm_sign) = match I::INST_KIND { + InsnKind::SLTIU => { + let imm_sign_extend = UInt::from_exprs_unchecked( + imm_sign_extend_circuit::(false, E::BaseField::ZERO.expr(), imm.expr()) + .to_vec(), + ); + ( + UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, false)?, + imm_sign_extend, + None, + ) + } + InsnKind::SLTI => { + let imm_sign = cb.create_bit(|| "imm_sign")?; + let imm_sign_extend = UInt::from_exprs_unchecked( + imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), + ); + ( + UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, true)?, + imm_sign_extend, + Some(imm_sign), + ) + } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; @@ -63,7 +85,10 @@ impl Instruction for SetLessThanImmInst let i_insn = IInstructionConfig::::construct_circuit( cb, I::INST_KIND, - imm.expr(), + imm_sign_extend.expr().remove(0), + imm_sign + .map(|imm_sign| imm_sign.expr()) + .unwrap_or(E::BaseField::ZERO.expr()), rs1_read.register_expr(), rd_written.register_expr(), false, @@ -73,6 +98,7 @@ impl Instruction for SetLessThanImmInst i_insn, rs1_read, imm, + imm_sign, rd_written, uint_lt_config, }) @@ -92,38 +118,26 @@ impl Instruction for SetLessThanImmInst .rs1_read .assign_value(instance, Value::new_unchecked(rs1)); - let imm = InsnRecord::imm_internal(&step.insn()); - set_val!(instance, config.imm, i64_to_base::(imm)); - + let imm = step.insn().imm as i16 as u16; let is_signed = matches!(step.insn().kind, InsnKind::SLT); + set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + let imm_sign_extend = imm_sign_extend(is_signed, step.insn().imm as i16); + if is_signed { + set_val!( + instance, + config.imm_sign.as_ref().unwrap(), + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + } + UIntLimbsLT::::assign( &config.uint_lt_config, instance, lkm, - rs1_read.as_u16_limbs(), - rs2_read.as_u16_limbs(), + rs1_value.as_u16_limbs(), + &imm_sign_extend, is_signed, )?; - match I::INST_KIND { - InsnKind::SLTIU => { - config - .lt - .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; - } - InsnKind::SLTI => { - config.is_rs1_neg.as_ref().unwrap().assign_instance( - instance, - lkm, - *rs1_value.as_u16_limbs().last().unwrap() as u64, - )?; - let (rs1, imm) = (rs1 as SWord, imm as SWord); - config - .lt - .assign_instance_signed(instance, lkm, rs1 as i64, imm as i64)?; - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - } - Ok(()) } } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 638264400..02b7be44d 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -1,5 +1,3 @@ -use std::{collections::HashMap, marker::PhantomData}; - use super::RMMCollections; use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, @@ -16,12 +14,20 @@ use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use std::{collections::HashMap, marker::PhantomData}; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. + +#[cfg(not(feature = "u16limb_circuit"))] #[derive(Clone, Debug)] pub struct InsnRecord([T; 6]); +#[cfg(feature = "u16limb_circuit")] +#[derive(Clone, Debug)] +pub struct InsnRecord([T; 7]); + impl InsnRecord { + #[cfg(not(feature = "u16limb_circuit"))] pub fn new(pc: T, kind: T, rd: Option, rs1: T, rs2: T, imm_internal: T) -> Self where T: From, @@ -30,6 +36,15 @@ impl InsnRecord { InsnRecord([pc, kind, rd, rs1, rs2, imm_internal]) } + #[cfg(feature = "u16limb_circuit")] + pub fn new(pc: T, kind: T, rd: Option, rs1: T, rs2: T, imm_internal: T, imm_sign: T) -> Self + where + T: From, + { + let rd = rd.unwrap_or_else(|| T::from(Instruction::RD_NULL)); + InsnRecord([pc, kind, rd, rs1, rs2, imm_internal, imm_sign]) + } + pub fn as_slice(&self) -> &[T] { &self.0 } @@ -37,14 +52,30 @@ impl InsnRecord { impl InsnRecord { fn from_decoded(pc: u32, insn: &Instruction) -> Self { - InsnRecord([ - (pc as u64).into_f(), - (insn.kind as u64).into_f(), - (insn.rd_internal() as u64).into_f(), - (insn.rs1_or_zero() as u64).into_f(), - (insn.rs2_or_zero() as u64).into_f(), - i64_to_base(InsnRecord::imm_internal(insn)), - ]) + #[cfg(not(feature = "u16limb_circuit"))] + { + InsnRecord([ + (pc as u64).into_f(), + (insn.kind as u64).into_f(), + (insn.rd_internal() as u64).into_f(), + (insn.rs1_or_zero() as u64).into_f(), + (insn.rs2_or_zero() as u64).into_f(), + i64_to_base(InsnRecord::imm_internal(insn)), + ]) + } + + #[cfg(feature = "u16limb_circuit")] + { + InsnRecord([ + (pc as u64).into_f(), + (insn.kind as u64).into_f(), + (insn.rd_internal() as u64).into_f(), + (insn.rs1_or_zero() as u64).into_f(), + (insn.rs2_or_zero() as u64).into_f(), + F::from_canonical_u16(insn.imm as i16 as u16), + F::from_bool(InsnRecord::imm_signed_internal(insn)), + ]) + } } } @@ -69,6 +100,16 @@ impl InsnRecord<()> { _ => insn.imm as i64, } } + + pub fn imm_signed_internal(insn: &Instruction) -> bool { + match (insn.kind, InsnFormat::from(insn.kind)) { + (SLLI | SRLI | SRAI, _) => false, + // Unsigned view. + (_, R | U) | (ADDI | SLTIU | ANDI | XORI | ORI, _) => false, + // Signed view. + _ => true, + } + } } #[derive(Clone, Debug)] @@ -96,6 +137,17 @@ impl TableCircuit for ProgramTableCircuit { cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + #[cfg(not(feature = "u16limb_circuit"))] + let record = InsnRecord([ + cb.create_fixed(|| "pc"), + cb.create_fixed(|| "kind"), + cb.create_fixed(|| "rd"), + cb.create_fixed(|| "rs1"), + cb.create_fixed(|| "rs2"), + cb.create_fixed(|| "imm_internal"), + ]); + + #[cfg(feature = "u16limb_circuit")] let record = InsnRecord([ cb.create_fixed(|| "pc"), cb.create_fixed(|| "kind"), @@ -103,6 +155,7 @@ impl TableCircuit for ProgramTableCircuit { cb.create_fixed(|| "rs1"), cb.create_fixed(|| "rs2"), cb.create_fixed(|| "imm_internal"), + cb.create_fixed(|| "imm_sign"), ]); let mlt = cb.create_witin(|| "mlt"); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index d757d4ee6..35fc7ecc1 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,3 +1,4 @@ +use multilinear_extensions::ToExpr; use std::{ collections::HashMap, fmt::Display, @@ -5,10 +6,12 @@ use std::{ panic::{self, PanicHookInfo}, }; +use crate::instructions::riscv::constants::UINT_LIMBS; use ff_ext::ExtensionField; pub use gkr_iop::utils::i64_to_base; use itertools::Itertools; -use p3::field::Field; +use multilinear_extensions::Expression; +use p3::field::{Field, FieldAlgebra}; pub fn split_to_u8>(value: u32) -> Vec { (0..(u32::BITS / 8)) @@ -128,6 +131,33 @@ where result } +pub fn imm_sign_extend_circuit( + require_signed: bool, + is_signed: Expression, + imm: Expression, +) -> [Expression; UINT_LIMBS] { + if !require_signed { + [imm, E::BaseField::ZERO.expr()] + } else { + [ + imm, + is_signed * E::BaseField::from_canonical_u16(0xffff).expr(), + ] + } +} +#[inline(always)] +pub fn imm_sign_extend(is_signed: bool, imm: i16) -> [u16; UINT_LIMBS] { + if !is_signed { + [imm as u16, 0] + } else { + if imm > 0 { + [imm as u16, 0u16] + } else { + [imm as u16, 0xffff] + } + } +} + #[cfg(all(feature = "jemalloc", unix, not(test)))] pub fn print_allocated_bytes() { use tikv_jemalloc_ctl::{epoch, stats}; From 31406524e68bf3294296367b228bb495f46ff956 Mon Sep 17 00:00:00 2001 From: Wu Sung-Ming Date: Fri, 8 Aug 2025 16:09:37 +0800 Subject: [PATCH 13/46] finish addi logic --- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 80 ++------------ .../riscv/arith_imm/arith_imm_circuit.rs | 84 +++++++++++++++ .../riscv/arith_imm/arith_imm_circuit_v2.rs | 100 ++++++++++++++++++ .../instructions/riscv/dummy/dummy_circuit.rs | 3 + .../src/instructions/riscv/ecall/keccak.rs | 2 + .../src/instructions/riscv/ecall_insn.rs | 2 + .../riscv/logic_imm/logic_imm_circuit.rs | 2 + ceno_zkvm/src/instructions/riscv/r_insn.rs | 2 + ceno_zkvm/src/instructions/riscv/s_insn.rs | 2 + .../instructions/riscv/slti/slti_circuit.rs | 1 - 10 files changed, 208 insertions(+), 70 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 14088f14e..abab44807 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -1,3 +1,14 @@ +#[cfg(not(feature = "u16limb_circuit"))] +mod arith_imm_circuit; +#[cfg(feature = "u16limb_circuit")] +mod arith_imm_circuit_v2; + +#[cfg(feature = "u16limb_circuit")] +pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruction; + +#[cfg(not(feature = "u16limb_circuit"))] +pub use crate::instructions::riscv::arith_imm::arith_imm_circuit::AddiInstruction; + use std::marker::PhantomData; use ceno_emul::StepRecord; @@ -10,79 +21,10 @@ use crate::{ use super::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}; -pub struct AddiInstruction(PhantomData); - impl RIVInstruction for AddiInstruction { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::ADDI; } -pub struct InstructionConfig { - i_insn: IInstructionConfig, - - rs1_read: UInt, - imm: UInt, - rd_written: UInt, -} - -impl Instruction for AddiInstruction { - type InstructionConfig = InstructionConfig; - - fn name() -> String { - format!("{:?}", Self::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let imm = UInt::new(|| "imm", circuit_builder)?; - let rd_written = rs1_read.add(|| "rs1_read + imm", circuit_builder, &imm, true)?; - - let i_insn = IInstructionConfig::::construct_circuit( - circuit_builder, - Self::INST_KIND, - imm.value(), - rs1_read.register_expr(), - rd_written.register_expr(), - false, - )?; - - Ok(InstructionConfig { - i_insn, - rs1_read, - imm, - rd_written, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - let imm = Value::new( - InsnRecord::imm_internal(&step.insn()) as u32, - lk_multiplicity, - ); - - let result = rs1_read.add(&imm, lk_multiplicity, true); - - config.rs1_read.assign_value(instance, rs1_read); - config.imm.assign_value(instance, imm); - - config.rd_written.assign_add_outcome(instance, &result); - - config - .i_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} - #[cfg(test)] mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs new file mode 100644 index 000000000..26c1ecfd9 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -0,0 +1,84 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{constants::UInt, i_insn::IInstructionConfig}, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::StepRecord; +use ff_ext::ExtensionField; +use std::marker::PhantomData; + +pub struct AddiInstruction(PhantomData); + +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: UInt, + rd_written: UInt, +} + +impl Instruction for AddiInstruction { + type InstructionConfig = crate::instructions::riscv::arith_imm::InstructionConfig; + + fn name() -> String { + format!("{:?}", Self::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let imm = UInt::new(|| "imm", circuit_builder)?; + let rd_written = rs1_read.add(|| "rs1_read + imm", circuit_builder, &imm, true)?; + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + Self::INST_KIND, + imm.value(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(crate::instructions::riscv::arith_imm::InstructionConfig { + i_insn, + rs1_read, + imm, + rd_written, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let imm = Value::new( + InsnRecord::imm_internal(&step.insn()) as u32, + lk_multiplicity, + ); + + let result = rs1_read.add(&imm, lk_multiplicity, true); + + config.rs1_read.assign_value(instance, rs1_read); + config.imm.assign_value(instance, imm); + + config.rd_written.assign_add_outcome(instance, &result); + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs new file mode 100644 index 000000000..8b14062c2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -0,0 +1,100 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, + }, + structs::ProgramParams, + utils::{imm_sign_extend, imm_sign_extend_circuit}, + witness::LkMultiplicity, +}; +use ceno_emul::StepRecord; +use ff_ext::{ExtensionField, FieldInto}; +use multilinear_extensions::{ToExpr, WitIn}; +use p3::field::FieldAlgebra; +use std::marker::PhantomData; +use witness::set_val; + +pub struct AddiInstruction(PhantomData); + +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + // 0 positive, 1 negative + imm_sign: WitIn, + rd_written: UInt, +} + +impl Instruction for AddiInstruction { + type InstructionConfig = InstructionConfig; + + fn name() -> String { + format!("{:?}", Self::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm"); + let imm_sign = circuit_builder.create_bit(|| "imm_sign")?; + let imm_sign_extend = UInt::from_exprs_unchecked( + imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), + ); + let rd_written = + rs1_read.add(|| "rs1_read + imm", circuit_builder, &imm_sign_extend, true)?; + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + Self::INST_KIND, + imm_sign_extend.expr().remove(0), + imm_sign.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(InstructionConfig { + i_insn, + rs1_read, + imm, + imm_sign, + rd_written, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + + let imm = step.insn().imm as i16 as u16; + set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + + set_val!( + instance, + config.imm_sign, + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + + let imm_sign_extend = Value::from_limb_slice_unchecked(&imm_sign_extend); + let result = rs1_read.add(&imm_sign_extend, lk_multiplicity, true); + config.rs1_read.assign_value(instance, rs1_read); + config.rd_written.assign_add_outcome(instance, &result); + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 6de3ee10c..898ff856f 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -16,6 +16,7 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; +use p3::field::FieldAlgebra; use witness::set_val; /// DummyInstruction can handle any instruction and produce its side-effects. @@ -198,6 +199,8 @@ impl DummyConfig { rs1_id, rs2_id, imm.expr(), + #[cfg(feature = "u16limb_circuit")] + E::BaseField::ZERO.expr(), ))?; Ok(DummyConfig { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 51e353dc7..5530eee87 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -98,6 +98,8 @@ impl Instruction for KeccakInstruction { E::BaseField::ZERO.expr(), E::BaseField::ZERO.expr(), E::BaseField::ZERO.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; let mut layout = as gkr_iop::ProtocolBuilder>::build_layer_logic( diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index e975b81b7..b7eb20ea4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -44,6 +44,8 @@ impl EcallInstructionConfig { 0.into(), 0.into(), 0.into(), // imm = 0 + #[cfg(feature = "u16limb_circuit")] + 0.into(), // imm_sign = 0 ))?; let prev_x5_ts = cb.create_witin(|| "prev_x5_ts"); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 23485f0e9..259ccf7ab 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -92,6 +92,8 @@ impl LogicConfig { cb, insn_kind, imm.value(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index f0f72d06f..540ccaffe 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -48,6 +48,8 @@ impl RInstructionConfig { rs1.id.expr(), rs2.id.expr(), 0.into(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; Ok(RInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index dd1d5035b..bb2a994d9 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -48,6 +48,8 @@ impl SInstructionConfig { rs1.id.expr(), rs2.id.expr(), imm.clone(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; // Memory diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 114f04bc2..91ce2d99c 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -71,7 +71,6 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm.expr(), - #[cfg(feature = "u16limb_circuit")] E::BaseField::ZERO.expr(), rs1_read.register_expr(), rd_written.register_expr(), From 01d31d6064bfeeb4053854cd7dccfce3646b0d60 Mon Sep 17 00:00:00 2001 From: Wu Sung-Ming Date: Fri, 8 Aug 2025 16:12:34 +0800 Subject: [PATCH 14/46] skip imm_sign range check --- .../src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 8b14062c2..f969a68b0 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -42,7 +42,7 @@ impl Instruction for AddiInstruction { ) -> Result { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; let imm = circuit_builder.create_witin(|| "imm"); - let imm_sign = circuit_builder.create_bit(|| "imm_sign")?; + let imm_sign = circuit_builder.create_witin(|| "imm_sign"); let imm_sign_extend = UInt::from_exprs_unchecked( imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), ); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index a953b4bbf..5a76fb2c5 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -67,7 +67,7 @@ impl Instruction for SetLessThanImmInst ) } InsnKind::SLTI => { - let imm_sign = cb.create_bit(|| "imm_sign")?; + let imm_sign = cb.create_witin(|| "imm_sign"); let imm_sign_extend = UInt::from_exprs_unchecked( imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), ); @@ -88,7 +88,7 @@ impl Instruction for SetLessThanImmInst imm_sign_extend.expr().remove(0), imm_sign .map(|imm_sign| imm_sign.expr()) - .unwrap_or(E::BaseField::ZERO.expr()), + .unwrap_or(0.into()), rs1_read.register_expr(), rd_written.register_expr(), false, From 79a8379d6834bec14a22fc260757dba66701c289 Mon Sep 17 00:00:00 2001 From: Wu Sung-Ming Date: Fri, 8 Aug 2025 19:10:42 +0800 Subject: [PATCH 15/46] branch imm could be 1 limb --- ceno_zkvm/src/instructions/riscv/b_insn.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 71754c7fe..04f0920bb 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -65,6 +65,8 @@ impl BInstructionConfig { rs1.id.expr(), rs2.id.expr(), imm.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; // Branch program counter From 64c291aa6ae582d4363e6820659dea45baa2226c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 11 Aug 2025 09:51:32 +0800 Subject: [PATCH 16/46] addi test pass --- ceno_zkvm/src/instructions/riscv/im_insn.rs | 1 + ceno_zkvm/src/instructions/riscv/j_insn.rs | 2 ++ ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 9 +++++++-- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 ++ ceno_zkvm/src/tables/program.rs | 4 ++-- 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 41dd5ce0f..f7e0f8393 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -50,6 +50,7 @@ impl IMInstructionConfig { rs1.id.expr(), 0.into(), imm.clone(), + #[cfg(feature = "u16limb_circuit")] 0.into(), ))?; Ok(IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 40dc98f90..156aa1cd1 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -45,6 +45,8 @@ impl JInstructionConfig { 0.into(), 0.into(), vm_state.next_pc.unwrap().expr() - vm_state.pc.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; Ok(JInstructionConfig { vm_state, rd }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index f8995dd63..3cb3c9b20 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -12,7 +12,7 @@ use crate::{ }, structs::ProgramParams, tables::InsnRecord, - utils::i64_to_base, + utils::{i64_to_base, imm_sign_extend_circuit}, witness::{LkMultiplicity, set_val}, }; use ceno_emul::{InsnKind, PC_STEP_SIZE}; @@ -48,12 +48,17 @@ impl Instruction for JalrInstruction { ) -> Result, ZKVMError> { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value + let imm_sign = circuit_builder.create_witin(|| "imm_sign"); + let imm_sign_extend = UInt::from_exprs_unchecked( + imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), + ); let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let i_insn = IInstructionConfig::construct_circuit( circuit_builder, InsnKind::JALR, - imm.expr(), + imm_sign_extend.expr().remove(0), + #[cfg(feature = "u16limb_circuit")] imm_sign.expr(), rs1_read.register_expr(), rd_written.register_expr(), true, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 9d88b0a2e..cc5f557f7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -128,6 +128,8 @@ impl Instruction for ShiftImmInstructio circuit_builder, I::INST_KIND, imm.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 02b7be44d..c00947f58 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -105,9 +105,9 @@ impl InsnRecord<()> { match (insn.kind, InsnFormat::from(insn.kind)) { (SLLI | SRLI | SRAI, _) => false, // Unsigned view. - (_, R | U) | (ADDI | SLTIU | ANDI | XORI | ORI, _) => false, + (_, R | U) | (SLTIU | ANDI | XORI | ORI, _) => false, // Signed view. - _ => true, + _ => insn.imm < 0, } } } From a35031b4fb0869b5c5af3a868622fe521d95810b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 09:40:35 +0800 Subject: [PATCH 17/46] slti(u) test pass --- ceno_zkvm/src/gadgets/signed_limbs.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 12 +------- ceno_zkvm/src/instructions/riscv/slti.rs | 26 ++--------------- .../riscv/slti/slti_circuit_v2.rs | 29 +++++++++---------- ceno_zkvm/src/tables/program.rs | 3 +- ceno_zkvm/src/utils.rs | 2 +- 6 files changed, 21 insertions(+), 53 deletions(-) diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs index f79710c91..93fcefad4 100644 --- a/ceno_zkvm/src/gadgets/signed_limbs.rs +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -162,7 +162,7 @@ impl UIntLimbsLT { }); set_val!(instance, config.cmp_lt, cmp_lt as u64); - // We range check (read_rs1_msb_f + 128) and (read_rs2_msb_f + 128) if signed, + // We range check (read_rs1_msb_f + 32768) and (read_rs2_msb_f + 32768) if signed, // read_rs1_msb_f and read_rs2_msb_f if not let (a_msb_f, a_msb_range) = if a_sign { ( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index abab44807..32ce74007 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -9,17 +9,7 @@ pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruc #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::arith_imm::arith_imm_circuit::AddiInstruction; -use std::marker::PhantomData; - -use ceno_emul::StepRecord; -use ff_ext::ExtensionField; - -use crate::{ - Value, circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, - structs::ProgramParams, tables::InsnRecord, witness::LkMultiplicity, -}; - -use super::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}; +use super::RIVInstruction; impl RIVInstruction for AddiInstruction { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::ADDI; diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index a164a59cd..c9fb46b83 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -10,29 +10,8 @@ use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmInstruction #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; -use std::marker::PhantomData; - -use ceno_emul::{InsnKind, SWord, StepRecord, Word}; -use ff_ext::ExtensionField; - -use super::{ - RIVInstruction, - constants::{UINT_LIMBS, UInt}, - i_insn::IInstructionConfig, -}; -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - gadgets::{IsLtConfig, SignedExtendConfig}, - instructions::Instruction, - structs::ProgramParams, - tables::InsnRecord, - uint::Value, - utils::i64_to_base, - witness::{LkMultiplicity, set_val}, -}; -use ff_ext::FieldInto; -use multilinear_extensions::{ToExpr, WitIn}; +use super::{RIVInstruction, constants::UInt}; +use crate::{structs::ProgramParams, uint::Value}; pub struct SltiOp; impl RIVInstruction for SltiOp { @@ -74,6 +53,7 @@ mod test { verify("lt = true, 0 < u32::MAX-1", 0, -1); verify("lt = true, 1 < u32::MAX-1", 1, -1); verify("lt = true, 0 < imm lower bondary", 0, -2048); + verify("lt = true, 65535 < imm lower bondary", 65535, -1); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 5a76fb2c5..2f5cd55a7 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -56,14 +56,14 @@ impl Instruction for SetLessThanImmInst let (uint_lt_config, imm_sign_extend, imm_sign) = match I::INST_KIND { InsnKind::SLTIU => { + let imm_sign = cb.create_witin(|| "imm_sign"); let imm_sign_extend = UInt::from_exprs_unchecked( - imm_sign_extend_circuit::(false, E::BaseField::ZERO.expr(), imm.expr()) - .to_vec(), + imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), ); ( UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, false)?, imm_sign_extend, - None, + Some(imm_sign), ) } InsnKind::SLTI => { @@ -86,9 +86,7 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm_sign_extend.expr().remove(0), - imm_sign - .map(|imm_sign| imm_sign.expr()) - .unwrap_or(0.into()), + imm_sign.map(|imm_sign| imm_sign.expr()).unwrap_or(0.into()), rs1_read.register_expr(), rd_written.register_expr(), false, @@ -119,16 +117,17 @@ impl Instruction for SetLessThanImmInst .assign_value(instance, Value::new_unchecked(rs1)); let imm = step.insn().imm as i16 as u16; - let is_signed = matches!(step.insn().kind, InsnKind::SLT); + let is_signed = matches!(step.insn().kind, InsnKind::SLTI); set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); - let imm_sign_extend = imm_sign_extend(is_signed, step.insn().imm as i16); - if is_signed { - set_val!( - instance, - config.imm_sign.as_ref().unwrap(), - E::BaseField::from_bool(imm_sign_extend[1] > 0) - ); - } + // accroding to riscvim32 spec, imm always do signed extension + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + // if is_signed { + set_val!( + instance, + config.imm_sign.as_ref().unwrap(), + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + // } UIntLimbsLT::::assign( &config.uint_lt_config, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index c00947f58..c799f109c 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -9,7 +9,6 @@ use ceno_emul::{ InsnFormat, InsnFormat::*, InsnKind::*, Instruction, PC_STEP_SIZE, Program, WORD_SIZE, }; use ff_ext::{ExtensionField, FieldInto, SmallField}; -use gkr_iop::utils::i64_to_base; use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -105,7 +104,7 @@ impl InsnRecord<()> { match (insn.kind, InsnFormat::from(insn.kind)) { (SLLI | SRLI | SRAI, _) => false, // Unsigned view. - (_, R | U) | (SLTIU | ANDI | XORI | ORI, _) => false, + (_, R | U) | (ANDI | XORI | ORI, _) => false, // Signed view. _ => insn.imm < 0, } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 35fc7ecc1..b15ffdc56 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -150,7 +150,7 @@ pub fn imm_sign_extend(is_signed: bool, imm: i16) -> [u16; UINT_LIMBS] { if !is_signed { [imm as u16, 0] } else { - if imm > 0 { + if imm >= 0 { [imm as u16, 0u16] } else { [imm as u16, 0xffff] From 9754a6aadc2256f92574a696e65b40e2492be381 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 09:46:41 +0800 Subject: [PATCH 18/46] refactor slti(u) properly --- ceno_zkvm/src/gadgets/signed_limbs.rs | 14 +++---- ceno_zkvm/src/instructions/riscv/im_insn.rs | 3 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 3 +- .../riscv/slti/slti_circuit_v2.rs | 41 ++++++------------- ceno_zkvm/src/utils.rs | 4 +- 5 files changed, 26 insertions(+), 39 deletions(-) diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs index 93fcefad4..e30a3a4ab 100644 --- a/ceno_zkvm/src/gadgets/signed_limbs.rs +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -43,7 +43,7 @@ impl UIntLimbsLT { circuit_builder: &mut CircuitBuilder, a: &UInt, b: &UInt, - is_signed: bool, + is_sign_comparison: bool, ) -> Result, ZKVMError> { // 1 if a < b, 0 otherwise. let cmp_lt = circuit_builder.create_bit(|| "cmp_lt")?; @@ -117,7 +117,7 @@ impl UIntLimbsLT { circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "a_msb_f_signed_range_check", a_msb_f.expr() - + if is_signed { + + if is_sign_comparison { E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() } else { Expression::ZERO @@ -127,7 +127,7 @@ impl UIntLimbsLT { circuit_builder.assert_ux::<_, _, LIMB_BITS>( || "b_msb_f_signed_range_check", b_msb_f.expr() - + if is_signed { + + if is_sign_comparison { E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() } else { Expression::ZERO @@ -150,9 +150,9 @@ impl UIntLimbsLT { lkm: &mut gkr_iop::utils::lk_multiplicity::LkMultiplicity, a: &[u16], b: &[u16], - is_signed: bool, + is_sign_comparison: bool, ) -> Result<(), CircuitBuilderError> { - let (cmp_lt, diff_idx, a_sign, b_sign) = run_cmp(is_signed, a, b); + let (cmp_lt, diff_idx, a_sign, b_sign) = run_cmp(is_sign_comparison, a, b); config .diff_marker .iter() @@ -172,7 +172,7 @@ impl UIntLimbsLT { } else { ( E::BaseField::from_canonical_u16(a[UINT_LIMBS - 1]), - a[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + a[UINT_LIMBS - 1] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)), ) }; let (b_msb_f, b_msb_range) = if b_sign { @@ -183,7 +183,7 @@ impl UIntLimbsLT { } else { ( E::BaseField::from_canonical_u16(b[UINT_LIMBS - 1]), - b[UINT_LIMBS - 1] + ((is_signed as u16) << (LIMB_BITS - 1)), + b[UINT_LIMBS - 1] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)), ) }; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index f7e0f8393..c87737150 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -50,7 +50,8 @@ impl IMInstructionConfig { rs1.id.expr(), 0.into(), imm.clone(), - #[cfg(feature = "u16limb_circuit")] 0.into(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), ))?; Ok(IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 3cb3c9b20..71fc196ac 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -58,7 +58,8 @@ impl Instruction for JalrInstruction { circuit_builder, InsnKind::JALR, imm_sign_extend.expr().remove(0), - #[cfg(feature = "u16limb_circuit")] imm_sign.expr(), + #[cfg(feature = "u16limb_circuit")] + imm_sign.expr(), rs1_read.register_expr(), rd_written.register_expr(), true, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 2f5cd55a7..1085561fb 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -29,7 +29,7 @@ pub struct SetLessThanImmConfig { rs1_read: UInt, imm: WitIn, // 0 positive, 1 negative - imm_sign: Option, + imm_sign: WitIn, #[allow(dead_code)] pub(crate) rd_written: UInt, @@ -53,29 +53,17 @@ impl Instruction for SetLessThanImmInst // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; let imm = cb.create_witin(|| "imm"); - - let (uint_lt_config, imm_sign_extend, imm_sign) = match I::INST_KIND { + // a bool witness to mark sign extend of imm no matter sign/unsign + let imm_sign = cb.create_witin(|| "imm_sign"); + let imm_sign_extend = UInt::from_exprs_unchecked( + imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), + ); + let uint_lt_config = match I::INST_KIND { InsnKind::SLTIU => { - let imm_sign = cb.create_witin(|| "imm_sign"); - let imm_sign_extend = UInt::from_exprs_unchecked( - imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), - ); - ( - UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, false)?, - imm_sign_extend, - Some(imm_sign), - ) + UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, false)? } InsnKind::SLTI => { - let imm_sign = cb.create_witin(|| "imm_sign"); - let imm_sign_extend = UInt::from_exprs_unchecked( - imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), - ); - ( - UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, true)?, - imm_sign_extend, - Some(imm_sign), - ) + UIntLimbsLT::construct_circuit(cb, &rs1_read, &imm_sign_extend, true)? } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; @@ -86,7 +74,7 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm_sign_extend.expr().remove(0), - imm_sign.map(|imm_sign| imm_sign.expr()).unwrap_or(0.into()), + imm_sign.expr(), rs1_read.register_expr(), rd_written.register_expr(), false, @@ -117,17 +105,14 @@ impl Instruction for SetLessThanImmInst .assign_value(instance, Value::new_unchecked(rs1)); let imm = step.insn().imm as i16 as u16; - let is_signed = matches!(step.insn().kind, InsnKind::SLTI); set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); - // accroding to riscvim32 spec, imm always do signed extension + // according to riscvim32 spec, imm always do signed extension let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); - // if is_signed { set_val!( instance, - config.imm_sign.as_ref().unwrap(), + config.imm_sign, E::BaseField::from_bool(imm_sign_extend[1] > 0) ); - // } UIntLimbsLT::::assign( &config.uint_lt_config, @@ -135,7 +120,7 @@ impl Instruction for SetLessThanImmInst lkm, rs1_value.as_u16_limbs(), &imm_sign_extend, - is_signed, + matches!(step.insn().kind, InsnKind::SLTI), )?; Ok(()) } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index b15ffdc56..499ce56fe 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -146,8 +146,8 @@ pub fn imm_sign_extend_circuit( } } #[inline(always)] -pub fn imm_sign_extend(is_signed: bool, imm: i16) -> [u16; UINT_LIMBS] { - if !is_signed { +pub fn imm_sign_extend(is_signed_extension: bool, imm: i16) -> [u16; UINT_LIMBS] { + if !is_signed_extension { [imm as u16, 0] } else { if imm >= 0 { From e708e6e8b06e8361d917dde9fc345ef1a9c82371 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 10:06:20 +0800 Subject: [PATCH 19/46] fix clippy --- ceno_zkvm/src/instructions/riscv/slti.rs | 10 +++++++--- ceno_zkvm/src/utils.rs | 9 ++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index c9fb46b83..9e3e99a65 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -10,8 +10,7 @@ use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmInstruction #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; -use super::{RIVInstruction, constants::UInt}; -use crate::{structs::ProgramParams, uint::Value}; +use super::RIVInstruction; pub struct SltiOp; impl RIVInstruction for SltiOp { @@ -34,12 +33,17 @@ mod test { use super::*; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{ Instruction, - riscv::test_utils::{i32_extra, imm_extra, immu_extra, u32_extra}, + riscv::{ + constants::UInt, + test_utils::{i32_extra, imm_extra, immu_extra, u32_extra}, + }, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, + structs::ProgramParams, }; #[test] diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 499ce56fe..fda483942 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -147,14 +147,13 @@ pub fn imm_sign_extend_circuit( } #[inline(always)] pub fn imm_sign_extend(is_signed_extension: bool, imm: i16) -> [u16; UINT_LIMBS] { + #[allow(clippy::if_same_then_else)] if !is_signed_extension { [imm as u16, 0] + } else if imm >= 0 { + [imm as u16, 0u16] } else { - if imm >= 0 { - [imm as u16, 0u16] - } else { - [imm as u16, 0xffff] - } + [imm as u16, 0xffff] } } From eb4b9a7ecd1a42343bc834f7e5038cf8750d050f Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 12:49:43 +0800 Subject: [PATCH 20/46] combine imm_internal + i64_base into one --- ceno_zkvm/Cargo.toml | 18 +++++----- .../riscv/arith_imm/arith_imm_circuit.rs | 8 ++--- ceno_zkvm/src/instructions/riscv/b_insn.rs | 7 ++-- .../src/instructions/riscv/branch/test.rs | 14 ++++---- .../instructions/riscv/dummy/dummy_circuit.rs | 2 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 8 ++--- .../riscv/logic_imm/logic_imm_circuit.rs | 5 +-- .../src/instructions/riscv/memory/load.rs | 6 ++-- .../src/instructions/riscv/memory/store.rs | 6 ++-- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 +- .../instructions/riscv/slti/slti_circuit.rs | 10 +++--- ceno_zkvm/src/tables/program.rs | 36 ++++++++++++++----- 12 files changed, 70 insertions(+), 52 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 43e8855ed..950de39a2 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,20 +67,20 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow", "u16limb_circuit"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] u16limb_circuit = [] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 26c1ecfd9..8a4722a08 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -4,7 +4,7 @@ use crate::{ error::ZKVMError, instructions::{ Instruction, - riscv::{constants::UInt, i_insn::IInstructionConfig}, + riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, }, structs::ProgramParams, tables::InsnRecord, @@ -25,7 +25,7 @@ pub struct InstructionConfig { } impl Instruction for AddiInstruction { - type InstructionConfig = crate::instructions::riscv::arith_imm::InstructionConfig; + type InstructionConfig = InstructionConfig; fn name() -> String { format!("{:?}", Self::INST_KIND) @@ -48,7 +48,7 @@ impl Instruction for AddiInstruction { false, )?; - Ok(crate::instructions::riscv::arith_imm::InstructionConfig { + Ok(InstructionConfig { i_insn, rs1_read, imm, @@ -64,7 +64,7 @@ impl Instruction for AddiInstruction { ) -> Result<(), ZKVMError> { let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); let imm = Value::new( - InsnRecord::imm_internal(&step.insn()) as u32, + InsnRecord::::imm_internal(&step.insn()).0 as u32, lk_multiplicity, ); diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 04f0920bb..39fb5c6b7 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -97,12 +97,9 @@ impl BInstructionConfig { self.rs1.assign_instance(instance, lk_multiplicity, step)?; self.rs2.assign_instance(instance, lk_multiplicity, step)?; + println!("&step.insn() {:?}", &step.insn()); // Immediate - set_val!( - instance, - self.imm, - i64_to_base::(InsnRecord::imm_internal(&step.insn())) - ); + set_val!(instance, self.imm, InsnRecord::::imm_internal(&step.insn()).1); // Fetch the instruction. lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index fd9da41a1..68483e50c 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -184,13 +184,13 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { #[test] fn test_blt_circuit() -> Result<(), ZKVMError> { impl_blt_circuit(false, 0, 0)?; - impl_blt_circuit(true, 0, 1)?; - - impl_blt_circuit(false, 1, -10)?; - impl_blt_circuit(false, -10, -10)?; - impl_blt_circuit(false, -9, -10)?; - impl_blt_circuit(true, -9, 1)?; - impl_blt_circuit(true, -10, -9)?; + // impl_blt_circuit(true, 0, 1)?; + // + // impl_blt_circuit(false, 1, -10)?; + // impl_blt_circuit(false, -10, -10)?; + // impl_blt_circuit(false, -9, -10)?; + // impl_blt_circuit(true, -9, 1)?; + // impl_blt_circuit(true, -10, -9)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 898ff856f..6b282cb77 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -261,7 +261,7 @@ impl DummyConfig { mem_write.assign_instance::(instance, lk_multiplicity, step)?; } - let imm = i64_to_base::(InsnRecord::imm_internal(&step.insn())); + let imm = InsnRecord::::imm_internal(&step.insn()).1; set_val!(instance, self.imm, imm); Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 71fc196ac..adebc313f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -125,10 +125,10 @@ impl Instruction for JalrInstruction { let insn = step.insn(); let rs1 = step.rs1().unwrap().value; - let imm = InsnRecord::imm_internal(&insn); + let imm = InsnRecord::::imm_internal(&insn); let rd = step.rd().unwrap().value.after; - let (sum, overflowing) = rs1.overflowing_add_signed(imm as i32); + let (sum, overflowing) = rs1.overflowing_add_signed(imm.0 as i32); config .rs1_read @@ -137,14 +137,14 @@ impl Instruction for JalrInstruction { .rd_written .assign_value(instance, Value::new(rd, lk_multiplicity)); - set_val!(instance, config.imm, i64_to_base::(imm)); + set_val!(instance, config.imm, imm.1); config .next_pc_addr .assign_instance(instance, lk_multiplicity, sum)?; if let Some((overflow_cfg, tmp_cfg)) = &config.overflow { - let (overflow, tmp) = match (overflowing, imm < 0) { + let (overflow, tmp) = match (overflowing, imm.0 < 0) { (false, _) => (E::BaseField::ZERO, E::BaseField::ONE), (true, false) => (E::BaseField::ONE, E::BaseField::ZERO), (true, true) => (-E::BaseField::ONE, E::BaseField::ZERO), diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 259ccf7ab..ab66963d8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -61,7 +61,7 @@ impl Instruction for LogicInstruction { UInt8::::logic_assign::( lkm, step.rs1().unwrap().value.into(), - InsnRecord::imm_internal(&step.insn()) as u64, + InsnRecord::::imm_internal(&step.insn()).0 as u64, ); config.assign_instance(instance, lkm, step) @@ -118,7 +118,8 @@ impl LogicConfig { let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); - let imm = split_to_u8::(InsnRecord::imm_internal(&step.insn()) as u32); + let imm = + split_to_u8::(InsnRecord::::imm_internal(&step.insn()).0 as u32); self.imm.assign_limbs(instance, &imm); let rd_written = split_to_u8(step.rd().unwrap().value.after); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index d65756dc7..47b56cdd0 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -204,15 +204,15 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); let unaligned_addr = - ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm as i32)); + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); let shift = unaligned_addr.shift(); let addr_low_bits = [shift & 0x01, (shift >> 1) & 0x01]; let target_limb = memory_read.as_u16_limbs()[addr_low_bits[1] as usize]; let mut target_limb_bytes = target_limb.to_le_bytes(); - set_val!(instance, config.imm, i64_to_base::(imm)); + set_val!(instance, config.imm, imm.1); config .im_insn .assign_instance(instance, lk_multiplicity, step)?; diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index 4e053ee8e..0cf049b88 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -136,16 +136,16 @@ impl Instruction let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs2 = Value::new_unchecked(step.rs2().unwrap().value); let memory_op = step.memory_op().unwrap(); - let imm = InsnRecord::imm_internal(&step.insn()); + let imm = InsnRecord::::imm_internal(&step.insn()); let prev_mem_value = Value::new(memory_op.value.before, lk_multiplicity); - let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm as i32)); + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn .assign_instance(instance, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); - set_val!(instance, config.imm, i64_to_base::(imm)); + set_val!(instance, config.imm, imm.1); config .prev_memory_value .assign_value(instance, prev_mem_value); diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index cc5f557f7..e4a951921 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -153,7 +153,7 @@ impl Instruction for ShiftImmInstructio step: &StepRecord, ) -> Result<(), ZKVMError> { // imm_internal is a precomputed 2**shift. - let imm = InsnRecord::imm_internal(&step.insn()) as u64; + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 91ce2d99c..74e272da4 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -71,7 +71,7 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm.expr(), - E::BaseField::ZERO.expr(), + #[cfg(feature = "u16limb_circuit")] E::BaseField::ZERO.expr(), rs1_read.register_expr(), rd_written.register_expr(), false, @@ -101,14 +101,14 @@ impl Instruction for SetLessThanImmInst .rs1_read .assign_value(instance, Value::new_unchecked(rs1)); - let imm = InsnRecord::imm_internal(&step.insn()); - set_val!(instance, config.imm, i64_to_base::(imm)); + let imm = InsnRecord::::imm_internal(&step.insn()); + set_val!(instance, config.imm, imm.1); match I::INST_KIND { InsnKind::SLTIU => { config .lt - .assign_instance(instance, lkm, rs1 as u64, imm as u64)?; + .assign_instance(instance, lkm, rs1 as u64, imm.0 as u64)?; } InsnKind::SLTI => { config.is_rs1_neg.as_ref().unwrap().assign_instance( @@ -116,7 +116,7 @@ impl Instruction for SetLessThanImmInst lkm, *rs1_value.as_u16_limbs().last().unwrap() as u64, )?; - let (rs1, imm) = (rs1 as SWord, imm as SWord); + let (rs1, imm) = (rs1 as SWord, imm.0 as SWord); config .lt .assign_instance_signed(instance, lkm, rs1 as i64, imm as i64)?; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index c799f109c..066319c13 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -9,6 +9,7 @@ use ceno_emul::{ InsnFormat, InsnFormat::*, InsnKind::*, Instruction, PC_STEP_SIZE, Program, WORD_SIZE, }; use ff_ext::{ExtensionField, FieldInto, SmallField}; +use gkr_iop::utils::i64_to_base; use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -59,7 +60,7 @@ impl InsnRecord { (insn.rd_internal() as u64).into_f(), (insn.rs1_or_zero() as u64).into_f(), (insn.rs2_or_zero() as u64).into_f(), - i64_to_base(InsnRecord::imm_internal(insn)), + InsnRecord::imm_internal(insn).1, ]) } @@ -78,7 +79,7 @@ impl InsnRecord { } } -impl InsnRecord<()> { +impl InsnRecord { /// The internal view of the immediate in the program table. /// This is encoded in a way that is efficient for circuits, depending on the instruction. /// @@ -86,17 +87,33 @@ impl InsnRecord<()> { /// - `as u32` and `as i32` as usual. /// - `i64_to_base(imm)` gives the field element going into the program table. /// - `as u64` in unsigned cases. - pub fn imm_internal(insn: &Instruction) -> i64 { + #[cfg(not(feature = "u16limb_circuit"))] + pub fn imm_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { // Prepare the immediate for ShiftImmInstruction. // The shift is implemented as a multiplication/division by 1 << immediate. - (SLLI | SRLI | SRAI, _) => 1 << insn.imm, + (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), // Unsigned view. - // For example, u32::MAX is `u32::MAX mod p` in the finite field. - (_, R | U) | (ADDI | SLTIU | ANDI | XORI | ORI, _) => insn.imm as u32 as i64, + // For example, u32::MAX is `u32::MAX mod p` in the finite field + (_, R | U) | (ADDI | SLTIU | ANDI | XORI | ORI, _) => { + (insn.imm as u32 as i64, i64_to_base(insn.imm as u32 as i64)) + } // Signed view. // For example, u32::MAX is `-1 mod p` in the finite field. - _ => insn.imm as i64, + _ => (insn.imm as i64, i64_to_base(insn.imm as i64)), + } + } + + #[cfg(feature = "u16limb_circuit")] + pub fn imm_internal(insn: &Instruction) -> i64 { + match (insn.kind, InsnFormat::from(insn.kind)) { + // Prepare the immediate for ShiftImmInstruction. + // The shift is implemented as a multiplication/division by 1 << immediate. + (SLLI | SRLI | SRAI, _) => 1 << insn.imm, + // for imm operate with program counter => convert to field value + (BLT, _) => i64_to_base(insn.imm as u32 as i64), + // for default imm to operate with register value + _ => F::from_canonical_u16(insn.imm as i16 as u16), } } @@ -105,7 +122,10 @@ impl InsnRecord<()> { (SLLI | SRLI | SRAI, _) => false, // Unsigned view. (_, R | U) | (ANDI | XORI | ORI, _) => false, - // Signed view. + // in particular imm operated with program counter + // encode as field element, which do not need extra sign extension of imm + (BLT, _) => false, + // Signed views _ => insn.imm < 0, } } From 8cb9a85ea326fae63445bb0845c4b85ffbf833d5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 12:50:28 +0800 Subject: [PATCH 21/46] reformat code --- ceno_zkvm/Cargo.toml | 18 +++++++++--------- ceno_zkvm/src/instructions/riscv/b_insn.rs | 6 +++++- .../instructions/riscv/slti/slti_circuit.rs | 3 ++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 950de39a2..43e8855ed 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,20 +67,20 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow"] +default = ["forbid_overflow", "u16limb_circuit"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] u16limb_circuit = [] diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 39fb5c6b7..ea39ddcd2 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -99,7 +99,11 @@ impl BInstructionConfig { println!("&step.insn() {:?}", &step.insn()); // Immediate - set_val!(instance, self.imm, InsnRecord::::imm_internal(&step.insn()).1); + set_val!( + instance, + self.imm, + InsnRecord::::imm_internal(&step.insn()).1 + ); // Fetch the instruction. lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 74e272da4..63fe08ffc 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -71,7 +71,8 @@ impl Instruction for SetLessThanImmInst cb, I::INST_KIND, imm.expr(), - #[cfg(feature = "u16limb_circuit")] E::BaseField::ZERO.expr(), + #[cfg(feature = "u16limb_circuit")] + E::BaseField::ZERO.expr(), rs1_read.register_expr(), rd_written.register_expr(), false, From caf4e27bf5afc06ceda633e8793d0e2c640e8af6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 14:02:21 +0800 Subject: [PATCH 22/46] fix clippy --- ceno_zkvm/src/instructions/riscv/b_insn.rs | 1 - .../src/instructions/riscv/dummy/dummy_circuit.rs | 1 - ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 2 +- ceno_zkvm/src/instructions/riscv/memory/load.rs | 1 - ceno_zkvm/src/instructions/riscv/memory/store.rs | 1 - ceno_zkvm/src/tables/program.rs | 15 +++++++++------ 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index ea39ddcd2..95a092e33 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -8,7 +8,6 @@ use crate::{ error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, tables::InsnRecord, - utils::i64_to_base, witness::{LkMultiplicity, set_val}, }; use ff_ext::FieldInto; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 6b282cb77..abb92a165 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use ceno_emul::{InsnCategory, InsnFormat, InsnKind, StepRecord}; use ff_ext::ExtensionField; -use gkr_iop::utils::i64_to_base; use super::super::{ RIVInstruction, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index adebc313f..46ca78a48 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -12,7 +12,7 @@ use crate::{ }, structs::ProgramParams, tables::InsnRecord, - utils::{i64_to_base, imm_sign_extend_circuit}, + utils::imm_sign_extend_circuit, witness::{LkMultiplicity, set_val}, }; use ceno_emul::{InsnKind, PC_STEP_SIZE}; diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 47b56cdd0..615bffdf3 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -11,7 +11,6 @@ use crate::{ }, structs::ProgramParams, tables::InsnRecord, - utils::i64_to_base, witness::{LkMultiplicity, set_val}, }; use ceno_emul::{ByteAddr, InsnKind, StepRecord}; diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index 0cf049b88..dd08064b7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -11,7 +11,6 @@ use crate::{ }, structs::ProgramParams, tables::InsnRecord, - utils::i64_to_base, witness::{LkMultiplicity, set_val}, }; use ceno_emul::{ByteAddr, InsnKind, StepRecord}; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 066319c13..cc16b0658 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -72,8 +72,8 @@ impl InsnRecord { (insn.rd_internal() as u64).into_f(), (insn.rs1_or_zero() as u64).into_f(), (insn.rs2_or_zero() as u64).into_f(), - F::from_canonical_u16(insn.imm as i16 as u16), - F::from_bool(InsnRecord::imm_signed_internal(insn)), + InsnRecord::imm_internal(insn).1, + F::from_bool(InsnRecord::::imm_signed_internal(insn)), ]) } } @@ -105,15 +105,18 @@ impl InsnRecord { } #[cfg(feature = "u16limb_circuit")] - pub fn imm_internal(insn: &Instruction) -> i64 { + pub fn imm_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { // Prepare the immediate for ShiftImmInstruction. // The shift is implemented as a multiplication/division by 1 << immediate. - (SLLI | SRLI | SRAI, _) => 1 << insn.imm, + (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), // for imm operate with program counter => convert to field value - (BLT, _) => i64_to_base(insn.imm as u32 as i64), + (BLT, _) => (insn.imm as u32 as i64, i64_to_base(insn.imm as u32 as i64)), // for default imm to operate with register value - _ => F::from_canonical_u16(insn.imm as i16 as u16), + _ => ( + insn.imm as u32 as i64, + F::from_canonical_u16(insn.imm as i16 as u16), + ), } } From f947856977a65bc7d1d1baf1b1f8974acc27bdb9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 12 Aug 2025 14:23:41 +0800 Subject: [PATCH 23/46] bge/blt test pass --- ceno_zkvm/src/instructions/riscv/b_insn.rs | 1 - ceno_zkvm/src/instructions/riscv/branch/test.rs | 13 ++++++------- ceno_zkvm/src/tables/program.rs | 4 ++-- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 95a092e33..798902754 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -96,7 +96,6 @@ impl BInstructionConfig { self.rs1.assign_instance(instance, lk_multiplicity, step)?; self.rs2.assign_instance(instance, lk_multiplicity, step)?; - println!("&step.insn() {:?}", &step.insn()); // Immediate set_val!( instance, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 68483e50c..ee8e60615 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -184,13 +184,12 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { #[test] fn test_blt_circuit() -> Result<(), ZKVMError> { impl_blt_circuit(false, 0, 0)?; - // impl_blt_circuit(true, 0, 1)?; - // - // impl_blt_circuit(false, 1, -10)?; - // impl_blt_circuit(false, -10, -10)?; - // impl_blt_circuit(false, -9, -10)?; - // impl_blt_circuit(true, -9, 1)?; - // impl_blt_circuit(true, -10, -9)?; + impl_blt_circuit(true, 0, 1)?; + impl_blt_circuit(false, 1, -10)?; + impl_blt_circuit(false, -10, -10)?; + impl_blt_circuit(false, -9, -10)?; + impl_blt_circuit(true, -9, 1)?; + impl_blt_circuit(true, -10, -9)?; Ok(()) } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index cc16b0658..9ab5590a0 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -111,7 +111,7 @@ impl InsnRecord { // The shift is implemented as a multiplication/division by 1 << immediate. (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), // for imm operate with program counter => convert to field value - (BLT, _) => (insn.imm as u32 as i64, i64_to_base(insn.imm as u32 as i64)), + (_, B) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // for default imm to operate with register value _ => ( insn.imm as u32 as i64, @@ -127,7 +127,7 @@ impl InsnRecord { (_, R | U) | (ANDI | XORI | ORI, _) => false, // in particular imm operated with program counter // encode as field element, which do not need extra sign extension of imm - (BLT, _) => false, + (_, B) => false, // Signed views _ => insn.imm < 0, } From ea6a827570ffdc93908793de1a3dc2624a598573 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 00:17:18 +0800 Subject: [PATCH 24/46] logic imm test pass --- ceno_zkvm/src/instructions/riscv/logic_imm.rs | 22 ++- .../riscv/logic_imm/logic_imm_circuit.rs | 10 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 180 ++++++++++++++++++ .../src/instructions/riscv/logic_imm/test.rs | 92 +++++++++ ceno_zkvm/src/tables/program.rs | 24 ++- ceno_zkvm/src/uint/logic.rs | 4 +- 6 files changed, 314 insertions(+), 18 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs create mode 100644 ceno_zkvm/src/instructions/riscv/logic_imm/test.rs diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index 97a628cac..a4b46edcc 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -1,8 +1,28 @@ +#[cfg(not(feature = "u16limb_circuit"))] mod logic_imm_circuit; + +#[cfg(feature = "u16limb_circuit")] +mod logic_imm_circuit_v2; + +#[cfg(not(feature = "u16limb_circuit"))] +pub use crate::instructions::riscv::logic_imm::logic_imm_circuit::LogicInstruction; + +#[cfg(feature = "u16limb_circuit")] +pub use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicInstruction; + +#[cfg(test)] +mod test; + +/// This trait defines a logic instruction, connecting an instruction type to a lookup table. +pub trait LogicOp { + const INST_KIND: InsnKind; + type OpsTable: OpsTable; +} + use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; -use logic_imm_circuit::{LogicInstruction, LogicOp}; use ceno_emul::InsnKind; +use gkr_iop::tables::OpsTable; pub struct AndiOp; impl LogicOp for AndiOp { diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index ab66963d8..cfb6bbc59 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -9,7 +9,7 @@ use crate::{ error::ZKVMError, instructions::{ Instruction, - riscv::{constants::UInt8, i_insn::IInstructionConfig}, + riscv::{constants::UInt8, i_insn::IInstructionConfig, logic_imm::LogicOp}, }, structs::ProgramParams, tables::InsnRecord, @@ -18,12 +18,6 @@ use crate::{ }; use ceno_emul::{InsnKind, StepRecord}; -/// This trait defines a logic instruction, connecting an instruction type to a lookup table. -pub trait LogicOp { - const INST_KIND: InsnKind; - type OpsTable: OpsTable; -} - /// The Instruction circuit for a given LogicOp. pub struct LogicInstruction(PhantomData<(E, I)>); @@ -92,8 +86,6 @@ impl LogicConfig { cb, insn_kind, imm.value(), - #[cfg(feature = "u16limb_circuit")] - 0.into(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs new file mode 100644 index 000000000..855409a41 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -0,0 +1,180 @@ +//! The circuit implementation of logic instructions. + +use ff_ext::ExtensionField; +use gkr_iop::tables::OpsTable; +use itertools::Itertools; +use std::marker::PhantomData; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{LIMB_BITS, LIMB_MASK, UInt8}, + i_insn::IInstructionConfig, + logic_imm::LogicOp, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + uint::UIntLimbs, + utils::split_to_u8, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; +use multilinear_extensions::ToExpr; + +/// The Instruction circuit for a given LogicOp. +pub struct LogicInstruction(PhantomData<(E, I)>); + +impl Instruction for LogicInstruction { + type InstructionConfig = LogicConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let config = LogicConfig::construct_circuit(cb, I::INST_KIND)?; + + // Constrain the registers based on the given lookup table. + // lo + UIntLimbs::<{ LIMB_BITS }, 8, E>::logic( + cb, + I::OpsTable::ROM_TYPE, + &UIntLimbs::from_exprs_unchecked( + config.rs1_read.expr().into_iter().take(2).collect_vec(), + ), + &config.imm_lo, + &UIntLimbs::from_exprs_unchecked( + config.rd_written.expr().into_iter().take(2).collect_vec(), + ), + )?; + // hi + UIntLimbs::<{ LIMB_BITS }, 8, E>::logic( + cb, + I::OpsTable::ROM_TYPE, + &UIntLimbs::from_exprs_unchecked( + config + .rs1_read + .expr() + .into_iter() + .skip(2) + .take(2) + .collect_vec(), + ), + &config.imm_hi, + &UIntLimbs::from_exprs_unchecked( + config + .rd_written + .expr() + .into_iter() + .skip(2) + .take(2) + .collect_vec(), + ), + )?; + + Ok(config) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; + let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; + let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; + let imm_hi = (InsnRecord::::imm_signed_internal(&step.insn()).0 as u32 + >> LIMB_BITS) + & LIMB_MASK; + UIntLimbs::<{ LIMB_BITS }, 8, E>::logic_assign::( + lkm, + rs1_lo.into(), + imm_lo.into(), + ); + UIntLimbs::<{ LIMB_BITS }, 8, E>::logic_assign::( + lkm, + rs1_hi.into(), + imm_hi.into(), + ); + + config.assign_instance(instance, lkm, step) + } +} + +/// This config implements I-Instructions that represent registers values as 4 * u8. +/// Non-generic code shared by several circuits. +#[derive(Debug)] +pub struct LogicConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt8, + pub(crate) rd_written: UInt8, + imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, + imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, +} + +impl LogicConfig { + fn construct_circuit( + cb: &mut CircuitBuilder, + insn_kind: InsnKind, + ) -> Result { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", cb)?; + let rd_written = UInt8::new_unchecked(|| "rd_written", cb)?; + let imm_lo = UIntLimbs::<{ LIMB_BITS }, 8, E>::new_unchecked(|| "imm_lo", cb)?; + let imm_hi = UIntLimbs::<{ LIMB_BITS }, 8, E>::new_unchecked(|| "imm_hi", cb)?; + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + insn_kind, + imm_lo.value(), + imm_hi.value(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(Self { + i_insn, + rs1_read, + imm_lo, + imm_hi, + rd_written, + }) + } + + fn assign_instance( + &self, + instance: &mut [::BaseField], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + self.i_insn.assign_instance(instance, lkm, step)?; + + let rs1_read = split_to_u8(step.rs1().unwrap().value); + self.rs1_read.assign_limbs(instance, &rs1_read); + + let imm_lo = + split_to_u8::(InsnRecord::::imm_internal(&step.insn()).0 as u32) + [0..2] + .to_vec(); + let imm_hi = split_to_u8::( + InsnRecord::::imm_signed_internal(&step.insn()).0 as u32, + )[2..] + .to_vec(); + self.imm_lo.assign_limbs(instance, &imm_lo); + self.imm_hi.assign_limbs(instance, &imm_hi); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + self.rd_written.assign_limbs(instance, &rd_written); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs new file mode 100644 index 000000000..23aa2d77c --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -0,0 +1,92 @@ +use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; +use ff_ext::GoldilocksExt2; +use gkr_iop::circuit_builder::DebugIndex; + +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{ + constants::UInt8, + logic_imm::{AndiOp, LogicInstruction, LogicOp, OriOp, XoriOp}, + }, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + structs::ProgramParams, + utils::split_to_u8, +}; + +/// An arbitrary test value. +const TEST: u32 = 0xabed_5eff; +/// An example of a sign-extended negative immediate value. +const NEG: u32 = 0xffff_ff55; + +#[test] +fn test_opcode_andi() { + verify::("basic", 0x0000_0011, 3, 0x0000_0011 & 3); + verify::("zero result", 0x0000_0100, 3, 0x0000_0100 & 3); + verify::("negative imm", TEST, NEG, TEST & NEG); +} + +#[test] +fn test_opcode_ori() { + verify::("basic", 0x0000_0011, 3, 0x0000_0011 | 3); + verify::("basic2", 0x0000_0100, 3, 0x0000_0100 | 3); + verify::("negative imm", TEST, NEG, TEST | NEG); +} + +#[test] +fn test_opcode_xori() { + verify::("basic", 0x0000_0011, 3, 0x0000_0011 ^ 3); + verify::("non-overlap", 0x0000_0100, 3, 0x0000_0100 ^ 3); + verify::("negative imm", TEST, NEG, TEST ^ NEG); +} + +fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let (prefix, rd_written) = match I::INST_KIND { + InsnKind::ANDI => ("ANDI", rs1_read & imm), + InsnKind::ORI => ("ORI", rs1_read | imm), + InsnKind::XORI => ("XORI", rs1_read ^ imm), + _ => unreachable!(), + }; + + let config = cb + .namespace( + || format!("{prefix}_({name})"), + |cb| { + let config = LogicInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + ); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); + let (raw_witin, lkm) = LogicInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + rs1_read, + Change::new(0, rd_written), + 0, + )], + ) + .unwrap(); + + let expected = UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal(|| "assert_rd_written", rd_written_expr, expected.value()) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); +} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 9ab5590a0..a80ff4bf4 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -2,6 +2,7 @@ use super::RMMCollections; use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, + instructions::riscv::constants::LIMB_BITS, structs::{ProgramParams, ROMType}, tables::TableCircuit, }; @@ -16,6 +17,7 @@ use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, marker::PhantomData}; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; + /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. #[cfg(not(feature = "u16limb_circuit"))] @@ -73,7 +75,7 @@ impl InsnRecord { (insn.rs1_or_zero() as u64).into_f(), (insn.rs2_or_zero() as u64).into_f(), InsnRecord::imm_internal(insn).1, - F::from_bool(InsnRecord::::imm_signed_internal(insn)), + InsnRecord::::imm_signed_internal(insn).1, ]) } } @@ -110,6 +112,11 @@ impl InsnRecord { // Prepare the immediate for ShiftImmInstruction. // The shift is implemented as a multiplication/division by 1 << immediate. (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), + // logic imm + (XORI | ORI | ANDI, _) => ( + insn.imm as i16 as i64, + F::from_canonical_u16(insn.imm as u16), + ), // for imm operate with program counter => convert to field value (_, B) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // for default imm to operate with register value @@ -120,16 +127,21 @@ impl InsnRecord { } } - pub fn imm_signed_internal(insn: &Instruction) -> bool { + pub fn imm_signed_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { - (SLLI | SRLI | SRAI, _) => false, + (SLLI | SRLI | SRAI, _) => (false as i64, F::from_bool(false)), + // logic imm + (XORI | ORI | ANDI, _) => ( + (insn.imm >> LIMB_BITS) as i16 as i64, + F::from_canonical_u16((insn.imm >> LIMB_BITS) as u16), + ), // Unsigned view. - (_, R | U) | (ANDI | XORI | ORI, _) => false, + (_, R | U) => (false as i64, F::from_bool(false)), // in particular imm operated with program counter // encode as field element, which do not need extra sign extension of imm - (_, B) => false, + (_, B) => (false as i64, F::from_bool(false)), // Signed views - _ => insn.imm < 0, + _ => ((insn.imm < 0) as i64, F::from_bool(insn.imm < 0)), } } } diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index 210929920..ee0622839 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -17,8 +17,8 @@ impl UIntLimbs { b: &Self, c: &Self, ) -> Result<(), ZKVMError> { - for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { - cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; + for (a_byte_expr, b_byte_expr, c_byte_expr) in izip!(a.expr(), b.expr(), c.expr()) { + cb.logic_u8(rom_type, a_byte_expr, b_byte_expr, c_byte_expr)?; } Ok(()) } From 58a44df4c874e367d92d167ac3d9cadfa98dc067 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 11:22:39 +0800 Subject: [PATCH 25/46] code cosmetics --- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 855409a41..c72f31efe 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -39,19 +39,29 @@ impl Instruction for LogicInstruction { cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { + let num_limbs = LIMB_BITS / 8; let config = LogicConfig::construct_circuit(cb, I::INST_KIND)?; - // Constrain the registers based on the given lookup table. // lo UIntLimbs::<{ LIMB_BITS }, 8, E>::logic( cb, I::OpsTable::ROM_TYPE, &UIntLimbs::from_exprs_unchecked( - config.rs1_read.expr().into_iter().take(2).collect_vec(), + config + .rs1_read + .expr() + .into_iter() + .take(num_limbs) + .collect_vec(), ), &config.imm_lo, &UIntLimbs::from_exprs_unchecked( - config.rd_written.expr().into_iter().take(2).collect_vec(), + config + .rd_written + .expr() + .into_iter() + .take(num_limbs) + .collect_vec(), ), )?; // hi @@ -63,8 +73,8 @@ impl Instruction for LogicInstruction { .rs1_read .expr() .into_iter() - .skip(2) - .take(2) + .skip(num_limbs) + .take(num_limbs) .collect_vec(), ), &config.imm_hi, @@ -73,8 +83,8 @@ impl Instruction for LogicInstruction { .rd_written .expr() .into_iter() - .skip(2) - .take(2) + .skip(num_limbs) + .take(num_limbs) .collect_vec(), ), )?; @@ -156,6 +166,7 @@ impl LogicConfig { lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let num_limbs = LIMB_BITS / 8; self.i_insn.assign_instance(instance, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); @@ -163,7 +174,7 @@ impl LogicConfig { let imm_lo = split_to_u8::(InsnRecord::::imm_internal(&step.insn()).0 as u32) - [0..2] + [..num_limbs] .to_vec(); let imm_hi = split_to_u8::( InsnRecord::::imm_signed_internal(&step.insn()).0 as u32, From 00affa72cfda54bf066edee5aabfe7bf2c6ee7fc Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 11:42:22 +0800 Subject: [PATCH 26/46] refactor memory opcode for migration --- ceno_zkvm/src/instructions/riscv/memory.rs | 70 +++++- .../src/instructions/riscv/memory/load.rs | 32 --- .../src/instructions/riscv/memory/load_v2.rs | 220 ++++++++++++++++++ .../src/instructions/riscv/memory/store.rs | 18 -- .../src/instructions/riscv/memory/store_v2.rs | 142 +++++++++++ .../src/instructions/riscv/memory/test.rs | 10 +- ceno_zkvm/src/tables/program.rs | 2 +- 7 files changed, 434 insertions(+), 60 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/memory/load_v2.rs create mode 100644 ceno_zkvm/src/instructions/riscv/memory/store_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index 4b850a39c..8ccab9203 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -1,9 +1,75 @@ mod gadget; + +#[cfg(not(feature = "u16limb_circuit"))] pub mod load; +#[cfg(feature = "u16limb_circuit")] pub mod store; +#[cfg(feature = "u16limb_circuit")] +mod load_v2; +#[cfg(feature = "u16limb_circuit")] +mod store_v2; #[cfg(test)] mod test; -pub use load::{LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LwInstruction}; -pub use store::{SbInstruction, ShInstruction, SwInstruction}; +use crate::instructions::riscv::RIVInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +pub use crate::instructions::riscv::memory::load::LoadInstruction; +#[cfg(feature = "u16limb_circuit")] +pub use crate::instructions::riscv::memory::load_v2::LoadInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +pub use crate::instructions::riscv::memory::store::StoreInstruction; +#[cfg(feature = "u16limb_circuit")] +pub use crate::instructions::riscv::memory::store_v2::StoreInstruction; + +use ceno_emul::InsnKind; + +pub struct LwOp; + +impl RIVInstruction for LwOp { + const INST_KIND: InsnKind = InsnKind::LW; +} + +pub type LwInstruction = LoadInstruction; + +pub struct LhOp; +impl RIVInstruction for LhOp { + const INST_KIND: InsnKind = InsnKind::LH; +} +pub type LhInstruction = LoadInstruction; + +pub struct LhuOp; +impl RIVInstruction for LhuOp { + const INST_KIND: InsnKind = InsnKind::LHU; +} +pub type LhuInstruction = LoadInstruction; + +pub struct LbOp; +impl RIVInstruction for LbOp { + const INST_KIND: InsnKind = InsnKind::LB; +} +pub type LbInstruction = LoadInstruction; + +pub struct LbuOp; +impl RIVInstruction for LbuOp { + const INST_KIND: InsnKind = InsnKind::LBU; +} +pub type LbuInstruction = LoadInstruction; + +pub struct SWOp; +impl RIVInstruction for SWOp { + const INST_KIND: InsnKind = InsnKind::SW; +} +pub type SwInstruction = StoreInstruction; + +pub struct SHOp; +impl RIVInstruction for SHOp { + const INST_KIND: InsnKind = InsnKind::SH; +} +pub type ShInstruction = StoreInstruction; + +pub struct SBOp; +impl RIVInstruction for SBOp { + const INST_KIND: InsnKind = InsnKind::SB; +} +pub type SbInstruction = StoreInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 3ba83bdec..f76ddad6f 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -35,38 +35,6 @@ pub struct LoadConfig { pub struct LoadInstruction(PhantomData<(E, I)>); -pub struct LwOp; - -impl RIVInstruction for LwOp { - const INST_KIND: InsnKind = InsnKind::LW; -} - -pub type LwInstruction = LoadInstruction; - -pub struct LhOp; -impl RIVInstruction for LhOp { - const INST_KIND: InsnKind = InsnKind::LH; -} -pub type LhInstruction = LoadInstruction; - -pub struct LhuOp; -impl RIVInstruction for LhuOp { - const INST_KIND: InsnKind = InsnKind::LHU; -} -pub type LhuInstruction = LoadInstruction; - -pub struct LbOp; -impl RIVInstruction for LbOp { - const INST_KIND: InsnKind = InsnKind::LB; -} -pub type LbInstruction = LoadInstruction; - -pub struct LbuOp; -impl RIVInstruction for LbuOp { - const INST_KIND: InsnKind = InsnKind::LBU; -} -pub type LbuInstruction = LoadInstruction; - impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs new file mode 100644 index 000000000..f76ddad6f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -0,0 +1,220 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::{LkMultiplicity, set_val}, +}; +use ceno_emul::{ByteAddr, InsnKind, StepRecord}; +use ff_ext::{ExtensionField, FieldInto}; +use itertools::izip; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::FieldAlgebra; +use std::marker::PhantomData; + +pub struct LoadConfig { + im_insn: IMInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + memory_addr: MemAddr, + + memory_read: UInt, + target_limb: Option, + target_limb_bytes: Option>, + signed_extend_config: Option>, +} + +pub struct LoadInstruction(PhantomData<(E, I)>); + +impl Instruction for LoadInstruction { + type InstructionConfig = LoadConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value + // skip read range check, assuming constraint in write. + let memory_read = UInt::new_unchecked(|| "memory_read", circuit_builder)?; + + let memory_addr = match I::INST_KIND { + InsnKind::LW => MemAddr::construct_align4(circuit_builder), + InsnKind::LH | InsnKind::LHU => MemAddr::construct_align2(circuit_builder), + InsnKind::LB | InsnKind::LBU => MemAddr::construct_unaligned(circuit_builder), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }?; + + circuit_builder.require_equal( + || "memory_addr = rs1_read + imm", + memory_addr.expr_unaligned(), + rs1_read.value() + imm.expr(), + )?; + + let addr_low_bits = memory_addr.low_bit_exprs(); + let memory_value = memory_read.expr(); + + // get target limb from memory word for load instructions except LW + let target_limb = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU => { + let target_limb = circuit_builder.create_witin(|| "target_limb"); + circuit_builder.condition_require_equal( + || "target_limb = memory_value[low_bits[1]]", + addr_low_bits[1].clone(), + target_limb.expr(), + memory_value[1].clone(), + memory_value[0].clone(), + )?; + Some(target_limb) + } + _ => None, + }; + + // get target byte from memory word for LB and LBU + let (target_byte_expr, target_limb_bytes) = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => { + let target_byte = circuit_builder.create_u8(|| "limb.le_bytes[low_bits[0]]")?; + let dummy_byte = circuit_builder.create_u8(|| "limb.le_bytes[1-low_bits[0]]")?; + + circuit_builder.condition_require_equal( + || "target_byte = target_limb[low_bits[0]]", + addr_low_bits[0].clone(), + target_limb.unwrap().expr(), + target_byte.expr() * (1<<8) + dummy_byte.expr(), // target_byte = limb.le_bytes[1] + dummy_byte.expr() * (1<<8) + target_byte.expr(), // target_byte = limb.le_bytes[0] + )?; + + ( + Some(target_byte.expr()), + Some(vec![target_byte, dummy_byte]), + ) + } + _ => (None, None), + }; + let (signed_extend_config, rd_written) = match I::INST_KIND { + InsnKind::LW => (None, memory_read.clone()), + InsnKind::LH => { + let val = target_limb.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_limb(circuit_builder, val.expr())?; + let rd_written = signed_extend_config.signed_extended_value(val.expr()); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LHU => { + ( + None, + // it's safe to unwrap as `UInt::from_exprs_unchecked` never return error + UInt::from_exprs_unchecked(vec![ + target_limb.as_ref().map(|limb| limb.expr()).unwrap(), + Expression::ZERO, + ]), + ) + } + InsnKind::LB => { + let val = target_byte_expr.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_byte(circuit_builder, val.clone())?; + let rd_written = signed_extend_config.signed_extended_value(val); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LBU => ( + None, + UInt::from_exprs_unchecked(vec![target_byte_expr.unwrap(), Expression::ZERO]), + ), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let im_insn = IMInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &imm.expr(), + rs1_read.register_expr(), + memory_read.memory_expr(), + memory_addr.expr_align4(), + rd_written.register_expr(), + )?; + + Ok(LoadConfig { + im_insn, + rs1_read, + imm, + memory_addr, + memory_read, + target_limb, + target_limb_bytes, + signed_extend_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let memory_value = step.memory_op().unwrap().value.before; + let memory_read = Value::new_unchecked(memory_value); + // imm is signed 12-bit value + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + let shift = unaligned_addr.shift(); + let addr_low_bits = [shift & 0x01, (shift >> 1) & 0x01]; + let target_limb = memory_read.as_u16_limbs()[addr_low_bits[1] as usize]; + let mut target_limb_bytes = target_limb.to_le_bytes(); + + set_val!(instance, config.imm, imm.1); + config + .im_insn + .assign_instance(instance, lk_multiplicity, step)?; + config.rs1_read.assign_value(instance, rs1); + config.memory_read.assign_value(instance, memory_read); + config + .memory_addr + .assign_instance(instance, lk_multiplicity, unaligned_addr.into())?; + if let Some(&limb) = config.target_limb.as_ref() { + set_val!( + instance, + limb, + E::BaseField::from_canonical_u16(target_limb) + ); + } + if let Some(limb_bytes) = config.target_limb_bytes.as_ref() { + if addr_low_bits[0] == 1 { + // target_limb_bytes[0] = target_limb.to_le_bytes[1] + // target_limb_bytes[1] = target_limb.to_le_bytes[0] + target_limb_bytes.reverse(); + } + for (&col, byte) in izip!(limb_bytes.iter(), target_limb_bytes.into_iter()) { + lk_multiplicity.assert_ux::<8>(byte as u64); + set_val!(instance, col, E::BaseField::from_canonical_u8(byte)); + } + } + let val = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => target_limb_bytes[0] as u64, + InsnKind::LH | InsnKind::LHU => target_limb as u64, + _ => 0, + }; + if let Some(signed_ext_config) = config.signed_extend_config.as_ref() { + signed_ext_config.assign_instance(instance, lk_multiplicity, val)?; + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index aaf5f545f..e4e16d0f9 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -32,24 +32,6 @@ pub struct StoreConfig { pub struct StoreInstruction(PhantomData<(E, I)>); -pub struct SWOp; -impl RIVInstruction for SWOp { - const INST_KIND: InsnKind = InsnKind::SW; -} -pub type SwInstruction = StoreInstruction; - -pub struct SHOp; -impl RIVInstruction for SHOp { - const INST_KIND: InsnKind = InsnKind::SH; -} -pub type ShInstruction = StoreInstruction; - -pub struct SBOp; -impl RIVInstruction for SBOp { - const INST_KIND: InsnKind = InsnKind::SB; -} -pub type SbInstruction = StoreInstruction; - impl Instruction for StoreInstruction { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs new file mode 100644 index 000000000..dd992071f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -0,0 +1,142 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, constants::UInt, insn_base::MemAddr, memory::gadget::MemWordUtil, + s_insn::SInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::{LkMultiplicity, set_val}, +}; +use ceno_emul::{ByteAddr, InsnKind, StepRecord}; +use ff_ext::{ExtensionField, FieldInto}; +use multilinear_extensions::{ToExpr, WitIn}; +use std::marker::PhantomData; + +pub struct StoreConfig { + s_insn: SInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + imm: WitIn, + prev_memory_value: UInt, + + memory_addr: MemAddr, + next_memory_value: Option>, +} + +pub struct StoreInstruction(PhantomData<(E, I)>); + +impl Instruction +for StoreInstruction +{ + type InstructionConfig = StoreConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let prev_memory_value = UInt::new(|| "prev_memory_value", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value + + let memory_addr = match I::INST_KIND { + InsnKind::SW => MemAddr::construct_align4(circuit_builder), + InsnKind::SH => MemAddr::construct_align2(circuit_builder), + InsnKind::SB => MemAddr::construct_unaligned(circuit_builder), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }?; + + if cfg!(feature = "forbid_overflow") { + const MAX_RAM_ADDR: u32 = u32::MAX - 0x7FF; // max positive imm is 0x7FF + const MIN_RAM_ADDR: u32 = 0x800; // min negative imm is -0x800 + assert!( + !params.platform.can_write(MAX_RAM_ADDR + 1) + && !params.platform.can_write(MIN_RAM_ADDR - 1) + ); + } + circuit_builder.require_equal( + || "memory_addr = rs1_read + imm", + memory_addr.expr_unaligned(), + rs1_read.value() + imm.expr(), + )?; + + let (next_memory_value, next_memory) = match I::INST_KIND { + InsnKind::SW => (rs2_read.memory_expr(), None), + InsnKind::SH | InsnKind::SB => { + let next_memory = MemWordUtil::::construct_circuit( + circuit_builder, + &memory_addr, + &prev_memory_value, + &rs2_read, + )?; + (next_memory.as_lo_hi().clone(), Some(next_memory)) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let s_insn = SInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &imm.expr(), + rs1_read.register_expr(), + rs2_read.register_expr(), + memory_addr.expr_align4(), + prev_memory_value.memory_expr(), + next_memory_value, + )?; + + Ok(StoreConfig { + s_insn, + rs1_read, + rs2_read, + imm, + prev_memory_value, + memory_addr, + next_memory_value: next_memory, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let rs2 = Value::new_unchecked(step.rs2().unwrap().value); + let memory_op = step.memory_op().unwrap(); + let imm = InsnRecord::::imm_internal(&step.insn()); + let prev_mem_value = Value::new(memory_op.value.before, lk_multiplicity); + + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .s_insn + .assign_instance(instance, lk_multiplicity, step)?; + config.rs1_read.assign_value(instance, rs1); + config.rs2_read.assign_value(instance, rs2); + set_val!(instance, config.imm, imm.1); + config + .prev_memory_value + .assign_value(instance, prev_mem_value); + + config + .memory_addr + .assign_instance(instance, lk_multiplicity, addr.into())?; + if let Some(change) = config.next_memory_value.as_ref() { + change.assign_instance(instance, lk_multiplicity, step, addr.shift())?; + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index dcba4424a..409d940a2 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -4,15 +4,11 @@ use crate::{ instructions::{ Instruction, riscv::{ - RIVInstruction, + LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, RIVInstruction, constants::UInt, memory::{ - LwInstruction, SbInstruction, ShInstruction, SwInstruction, - load::{ - LbInstruction, LbOp, LbuInstruction, LbuOp, LhInstruction, LhOp, - LhuInstruction, LhuOp, LwOp, - }, - store::{SBOp, SHOp, SWOp}, + LbOp, LbuOp, LhOp, LhuOp, LwInstruction, LwOp, SBOp, SHOp, SWOp, SbInstruction, + ShInstruction, SwInstruction, }, }, }, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index a80ff4bf4..a60488588 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -121,7 +121,7 @@ impl InsnRecord { (_, B) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // for default imm to operate with register value _ => ( - insn.imm as u32 as i64, + insn.imm as i16 as i64, F::from_canonical_u16(insn.imm as i16 as u16), ), } From b1b1cddd0eb025b052646f9a6a5263e8311786b4 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 14:55:39 +0800 Subject: [PATCH 27/46] all load/store test pass --- ceno_zkvm/src/instructions/riscv/im_insn.rs | 1 - ceno_zkvm/src/instructions/riscv/memory.rs | 2 +- ceno_zkvm/src/instructions/riscv/memory/store_v2.rs | 2 +- ceno_zkvm/src/tables/program.rs | 6 ++++++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index c87737150..7d03e64a7 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -50,7 +50,6 @@ impl IMInstructionConfig { rs1.id.expr(), 0.into(), imm.clone(), - #[cfg(feature = "u16limb_circuit")] 0.into(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index 8ccab9203..bb29491f7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -2,7 +2,7 @@ mod gadget; #[cfg(not(feature = "u16limb_circuit"))] pub mod load; -#[cfg(feature = "u16limb_circuit")] +#[cfg(not(feature = "u16limb_circuit"))] pub mod store; #[cfg(feature = "u16limb_circuit")] diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index dd992071f..e4e16d0f9 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -33,7 +33,7 @@ pub struct StoreConfig { pub struct StoreInstruction(PhantomData<(E, I)>); impl Instruction -for StoreInstruction + for StoreInstruction { type InstructionConfig = StoreConfig; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index a60488588..384d03bfe 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -112,6 +112,10 @@ impl InsnRecord { // Prepare the immediate for ShiftImmInstruction. // The shift is implemented as a multiplication/division by 1 << immediate. (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), + // TODO convert to 2 limbs to support smaller field + (LB | LH | LW | LBU | LHU | SB | SH | SW, _) => { + (insn.imm as i64, i64_to_base(insn.imm as i64)) + } // logic imm (XORI | ORI | ANDI, _) => ( insn.imm as i16 as i64, @@ -130,6 +134,8 @@ impl InsnRecord { pub fn imm_signed_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { (SLLI | SRLI | SRAI, _) => (false as i64, F::from_bool(false)), + // TODO convert to 2 limbs to support smaller field + (LB | LH | LW | LBU | LHU | SB | SH | SW, _) => (false as i64, F::from_bool(false)), // logic imm (XORI | ORI | ANDI, _) => ( (insn.imm >> LIMB_BITS) as i16 as i64, From 19be4cf5ce25cb072eced948b5a3b3563986859d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 15:52:49 +0800 Subject: [PATCH 28/46] make JALR as TODO --- ceno_zkvm/src/tables/program.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 384d03bfe..4545032d9 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -122,7 +122,9 @@ impl InsnRecord { F::from_canonical_u16(insn.imm as u16), ), // for imm operate with program counter => convert to field value - (_, B) => (insn.imm as i64, i64_to_base(insn.imm as i64)), + (_, B | J) => (insn.imm as i64, i64_to_base(insn.imm as i64)), + // TODO JALR need to connecting register (2 limb) with pc (1 limb) + (JALR, _) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // for default imm to operate with register value _ => ( insn.imm as i16 as i64, @@ -145,7 +147,9 @@ impl InsnRecord { (_, R | U) => (false as i64, F::from_bool(false)), // in particular imm operated with program counter // encode as field element, which do not need extra sign extension of imm - (_, B) => (false as i64, F::from_bool(false)), + (_, B | J) => (false as i64, F::from_bool(false)), + // TODO JALR need to connecting register (2 limb) with pc (1 limb) + (JALR, _) => (false as i64, F::from_bool(false)), // Signed views _ => ((insn.imm < 0) as i64, F::from_bool(insn.imm < 0)), } From f03c6f1263b0b683492a4a7b3f177c0112de7551 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 16:05:44 +0800 Subject: [PATCH 29/46] fix clippy --- ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs | 1 + ceno_zkvm/src/instructions/riscv/im_insn.rs | 1 + ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs | 3 +-- ceno_zkvm/src/tables/program.rs | 2 -- ceno_zkvm/src/utils.rs | 2 ++ 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index ae3db9f32..d972bd904 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -61,7 +61,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow", "u16limb_circuit"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 42acc1231..7c98e2159 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -15,6 +15,7 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; +#[cfg(feature = "u16limb_circuit")] use p3::field::FieldAlgebra; use witness::set_val; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 7d03e64a7..c87737150 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -50,6 +50,7 @@ impl IMInstructionConfig { rs1.id.expr(), 0.into(), imm.clone(), + #[cfg(feature = "u16limb_circuit")] 0.into(), ))?; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 63fe08ffc..632bf873e 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -17,9 +17,8 @@ use crate::{ }; use ceno_emul::{InsnKind, SWord, StepRecord, Word}; use ff_ext::{ExtensionField, FieldInto}; -use gkr_iop::{gadgets::IsLtConfig, utils::i64_to_base}; +use gkr_iop::gadgets::IsLtConfig; use multilinear_extensions::{ToExpr, WitIn}; -use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 4545032d9..741fc4f67 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -19,11 +19,9 @@ use std::{collections::HashMap, marker::PhantomData}; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. - #[cfg(not(feature = "u16limb_circuit"))] #[derive(Clone, Debug)] pub struct InsnRecord([T; 6]); - #[cfg(feature = "u16limb_circuit")] #[derive(Clone, Debug)] pub struct InsnRecord([T; 7]); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index fda483942..7a703850f 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -145,6 +145,8 @@ pub fn imm_sign_extend_circuit( ] } } + +#[cfg(feature = "u16limb_circuit")] #[inline(always)] pub fn imm_sign_extend(is_signed_extension: bool, imm: i16) -> [u16; UINT_LIMBS] { #[allow(clippy::if_same_then_else)] From a84dc9115e641c3b778a1fd908ab98734e6b1f20 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 16:18:57 +0800 Subject: [PATCH 30/46] clean up jalr code --- ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index d972bd904..ae3db9f32 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -61,7 +61,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow"] +default = ["forbid_overflow", "u16limb_circuit"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 46ca78a48..fe077d464 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -12,7 +12,6 @@ use crate::{ }, structs::ProgramParams, tables::InsnRecord, - utils::imm_sign_extend_circuit, witness::{LkMultiplicity, set_val}, }; use ceno_emul::{InsnKind, PC_STEP_SIZE}; @@ -48,18 +47,14 @@ impl Instruction for JalrInstruction { ) -> Result, ZKVMError> { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value - let imm_sign = circuit_builder.create_witin(|| "imm_sign"); - let imm_sign_extend = UInt::from_exprs_unchecked( - imm_sign_extend_circuit::(true, imm_sign.expr(), imm.expr()).to_vec(), - ); let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let i_insn = IInstructionConfig::construct_circuit( circuit_builder, InsnKind::JALR, - imm_sign_extend.expr().remove(0), + imm.expr(), #[cfg(feature = "u16limb_circuit")] - imm_sign.expr(), + 0.into(), rs1_read.register_expr(), rd_written.register_expr(), true, From 46c32cbe334a3c6563c41aed438e8cc2ac94e832 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 13 Aug 2025 18:21:45 +0800 Subject: [PATCH 31/46] refactor arith_imm --- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 68 +++++++------------ 1 file changed, 23 insertions(+), 45 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 32ce74007..0b98ed64d 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -17,20 +17,25 @@ impl RIVInstruction for AddiInstruction { #[cfg(test)] mod test { - use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; - use ff_ext::GoldilocksExt2; - + use super::AddiInstruction; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; - - use super::AddiInstruction; + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ff_ext::GoldilocksExt2; + use gkr_iop::circuit_builder::DebugIndex; #[test] - fn test_opcode_addi() { + fn test_opcode_addi_v1() { + test_opcode_addi(1000, 1003, 3); + test_opcode_addi(1000, 997, -3); + } + + fn test_opcode_addi(rs1: u32, rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb @@ -47,7 +52,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, 3); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, @@ -56,48 +61,21 @@ mod test { 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, - 1000, - Change::new(0, 1003), + rs1, + Change::new(0, rd), 0, )], ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); - } - - #[test] - fn test_opcode_addi_sub() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "addi", - |cb| { - let config = AddiInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, -3); - - let (raw_witin, lkm) = AddiInstruction::::assign_instances( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![StepRecord::new_i_instruction( - 3, - Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), - insn_code, - 1000, - Change::new(0, 997), - 0, - )], + // verify rd_written + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + expected_rd_written.value(), ) .unwrap(); From c39e3ae4c006b9fbfb86361c341c159f88d7a241 Mon Sep 17 00:00:00 2001 From: Wu Sung-Ming Date: Fri, 15 Aug 2025 14:25:29 +0800 Subject: [PATCH 32/46] wip --- ceno_emul/src/rv32im.rs | 5 +- ceno_zkvm/src/instructions/riscv.rs | 1 + ceno_zkvm/src/instructions/riscv/constants.rs | 2 + ceno_zkvm/src/instructions/riscv/jump.rs | 9 ++ .../src/instructions/riscv/jump/jal_lui_v2.rs | 128 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/lui.rs | 98 ++++++++++++++ 6 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs create mode 100644 ceno_zkvm/src/instructions/riscv/lui.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 752f2af93..384102a47 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -195,6 +195,7 @@ pub enum InsnKind { LW, LBU, LHU, + LUI, SB, SH, SW, @@ -212,7 +213,7 @@ impl From for InsnCategory { | MULHU | DIV | DIVU | REM | REMU => Compute, ADDI | XORI | ORI | ANDI | SLLI | SRLI | SRAI | SLTI | SLTIU => Compute, BEQ | BNE | BLT | BGE | BLTU | BGEU => Branch, - JAL | JALR => Compute, + JAL | JALR | LUI => Compute, LB | LH | LW | LBU | LHU => Load, SB | SH | SW => Store, ECALL => System, @@ -231,6 +232,7 @@ impl From for InsnFormat { JAL => J, JALR => I, LB | LH | LW | LBU | LHU => I, + LUI => U, SB | SH | SW => S, ECALL => I, INVALID => I, @@ -306,6 +308,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) match kind { ADDI => rs1.wrapping_add(imm_i), + LUI => imm_i << 12, XORI => rs1 ^ imm_i, ORI => rs1 | imm_i, ANDI => rs1 & imm_i, diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 20e8a460f..3f97fa798 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -33,6 +33,7 @@ mod r_insn; mod ecall_insn; mod im_insn; +mod lui; mod memory; mod s_insn; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index b604f7c92..334e8364c 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -16,6 +16,8 @@ pub const LIMB_MASK: u32 = 0xFFFF; pub const BIT_WIDTH: usize = 32usize; +pub const PC_BITS: usize = 30; + pub type UInt = UIntLimbs; pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, LIMB_BITS, E>; /// use UInt for x bits limb size diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index b57aadbbb..67bc3f97e 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,7 +1,16 @@ +#[cfg(not(feature = "u16limb_circuit"))] mod jal; +#[cfg(feature = "u16limb_circuit")] +mod jal_lui_v2; + mod jalr; +#[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; + +#[cfg(feature = "u16limb_circuit")] +pub use jal_lui_v2::JalInstruction; + pub use jalr::JalrInstruction; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs new file mode 100644 index 000000000..009181ada --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs @@ -0,0 +1,128 @@ +use std::marker::PhantomData; + +use ff_ext::ExtensionField; + +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{BIT_WIDTH, PC_BITS, UInt, UInt8}, + j_insn::JInstructionConfig, + }, + }, + structs::ProgramParams, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use gkr_iop::tables::LookupTable; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::FieldAlgebra; + +pub struct JalConfig { + pub j_insn: JInstructionConfig, + pub imm: Option, + pub rd_written: UInt8, +} + +pub struct JalInstruction(PhantomData<(E, I)>); + +/// JAL instruction circuit +/// +/// Note: does not validate that next_pc is aligned by 4-byte increments, which +/// should be verified by lookup argument of the next execution step against +/// the program table +/// +/// Assumption: values for valid initial program counter must lie between +/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static +/// program lookup table. If this assumption does not hold, then resulting +/// value for next_pc may not correctly wrap mod 2^32 because of the use +/// of native WitIn values for address space arithmetic. +impl Instruction for JalInstruction { + type InstructionConfig = JalConfig; + + fn name() -> String { + format!("{:?}", InsnKind::JAL) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result, ZKVMError> { + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + let rd_exprs = rd_written.expr(); + let imm = circuit_builder.create_witin(|| "imm"); + + let intermed_val = + rd_exprs + .iter() + .skip(1) + .enumerate() + .fold(Expression::ZERO, |acc, (i, &val)| { + acc + val.expr() + * E::BaseField::from_canonical_u32(1 << (i * UInt8::LIMB_BITS)).expr() + }); + + match I::INST_KIND { + InsnKind::JAL => { + let j_insn = JInstructionConfig::construct_circuit( + circuit_builder, + InsnKind::JAL, + rd_written.register_expr(), + )?; + // constrain rd_exprs[PC_BITS..u32::BITS] are all 0 via xor + let last_limb_bits = PC_BITS - UInt8::LIMB_BITS * (UInt8::NUM_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + let additional_bits = E::BaseField::from_canonical_u32(additional_bits); + circuit_builder.logic_u8( + LookupTable::Xor, + rd_exprs[3].expr(), + additional_bits.expr(), + rd_exprs[3].expr() + additional_bits.expr(), + )?; + circuit_builder.require_equal( + intermed_val, + j_insn + AB::F::from_canonical_u32(DEFAULT_PC_STEP), + ); + } + InsnKind::LUI => { + // constrain rd[1..4] = imm * 2^4 + circuit_builder.require_equal( + || "constrain rd[1..4] = imm * 2^4", + intermed_val, + imm.expr() + * E::BaseField::from_canonical_u32(1 << (12 - UInt8::LIMB_BITS)).expr(), + )?; + } + other => panic!("invalid kind {}", other), + } + + circuit_builder.require_equal( + || "jal rd_written", + rd_written.value(), + j_insn.vm_state.pc.expr() + PC_STEP_SIZE, + )?; + + Ok(JalConfig { j_insn, rd_written }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + config.rd_written.assign_value(instance, rd_written); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs new file mode 100644 index 000000000..743e98519 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -0,0 +1,98 @@ +use std::marker::PhantomData; + +use ff_ext::ExtensionField; + +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{BIT_WIDTH, PC_BITS, UInt, UInt8}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use gkr_iop::tables::LookupTable; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::FieldAlgebra; + +pub struct LuiConfig { + pub j_insn: IInstructionConfig, + pub imm: WitIn, + pub rd_written: UInt8, +} + +pub struct LuiInstruction(PhantomData); + +impl Instruction for LuiInstruction { + type InstructionConfig = LuiConfig; + + fn name() -> String { + format!("{:?}", InsnKind::LUI) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result, ZKVMError> { + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + let rd_exprs = rd_written.expr(); + let imm = circuit_builder.create_witin(|| "imm"); + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + InsnKind::LUI, + imm.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), + [0.into(), 0.into()], + rd_written.register_expr(), + false, + )?; + + let intermed_val = + rd_exprs + .iter() + .skip(1) + .enumerate() + .fold(Expression::ZERO, |acc, (i, &val)| { + acc + val.expr() + * E::BaseField::from_canonical_u32(1 << (i * UInt8::LIMB_BITS)).expr() + }); + + // imm * 2^4 is the correct composition of intermed_val in case of LUI + circuit_builder.require_equal( + || "imm * 2^4 is the correct composition of intermed_val in case of LUI", + intermed_val.expr(), + imm * E::BaseField::from_canonical_u32(1 << (12 - UInt8::LIMB_BITS)), + )?; + + circuit_builder.require_equal( + || "jal rd_written", + rd_written.value(), + i_insn.vm_state.pc.expr() + PC_STEP_SIZE, + )?; + + Ok(JalConfig { j_insn, rd_written }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + config.rd_written.assign_value(instance, rd_written); + + Ok(()) + } +} From 38530782cdccc810612641b42f217ca2d141d4db Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 15 Aug 2025 16:23:49 +0800 Subject: [PATCH 33/46] migrated lui and test pass --- ceno_emul/Cargo.toml | 1 + ceno_emul/src/disassemble/mod.rs | 46 +++--- ceno_emul/src/rv32im.rs | 11 +- ceno_zkvm/Cargo.toml | 18 +-- ceno_zkvm/src/instructions/riscv.rs | 1 + ceno_zkvm/src/instructions/riscv/constants.rs | 1 + ceno_zkvm/src/instructions/riscv/jump.rs | 9 -- .../src/instructions/riscv/jump/jal_lui_v2.rs | 128 ----------------- ceno_zkvm/src/instructions/riscv/lui.rs | 133 ++++++++++++++---- ceno_zkvm/src/instructions/riscv/rv32im.rs | 21 ++- ceno_zkvm/src/tables/program.rs | 5 + ceno_zkvm/src/utils.rs | 13 +- 12 files changed, 189 insertions(+), 198 deletions(-) delete mode 100644 ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index bda3dd5d2..7fb4b2809 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -30,3 +30,4 @@ tracing.workspace = true [features] default = ["forbid_overflow"] forbid_overflow = [] +u16limb_circuit = [] diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 94709e74b..f4af91d7d 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -249,24 +249,38 @@ impl InstructionProcessor for InstructionTranspiler { } } - /// Convert LUI to ADDI. - /// - /// RiscV's load-upper-immediate instruction is necessary to build arbitrary constants, - /// because its ADDI can only have a relatively small immediate value: there's just not - /// enough space in the 32 bits for more. - /// - /// Our internal ADDI does not have this limitation, so we can convert LUI to ADDI. - /// See [`InstructionTranspiler::process_auipc`] for more background on the conversion. fn process_lui(&mut self, dec_insn: UType) -> Self::InstructionResult { // Verify assumption that the immediate is already shifted left by 12 bits. assert_eq!(dec_insn.imm & 0xfff, 0); - Instruction { - kind: InsnKind::ADDI, - rd: dec_insn.rd, - rs1: 0, - rs2: 0, - imm: dec_insn.imm, - raw: self.word, + #[cfg(not(feature = "u16limb_circuit"))] + { + // Convert LUI to ADDI. + // + // RiscV's load-upper-immediate instruction is necessary to build arbitrary constants, + // because its ADDI can only have a relatively small immediate value: there's just not + // enough space in the 32 bits for more. + // + // Our internal ADDI does not have this limitation, so we can convert LUI to ADDI. + // See [`InstructionTranspiler::process_auipc`] for more background on the conversion. + Instruction { + kind: InsnKind::ADDI, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } + } + #[cfg(feature = "u16limb_circuit")] + { + Instruction { + kind: InsnKind::LUI, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } } } @@ -289,8 +303,6 @@ impl InstructionProcessor for InstructionTranspiler { /// In any case, AUIPC and LUI together make up ~0.1% of instructions executed in typical /// real world scenarios like a `reth` run. /// - /// TODO(Matthias): run benchmarks to verify the impact on recursion, once we have a working - /// recursion. fn process_auipc(&mut self, dec_insn: UType) -> Self::InstructionResult { let pc = self.pc; // Verify our assumption that the immediate is already shifted left by 12 bits. diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 384102a47..ec3f356c8 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -195,6 +195,7 @@ pub enum InsnKind { LW, LBU, LHU, + #[cfg(feature = "u16limb_circuit")] LUI, SB, SH, @@ -213,10 +214,12 @@ impl From for InsnCategory { | MULHU | DIV | DIVU | REM | REMU => Compute, ADDI | XORI | ORI | ANDI | SLLI | SRLI | SRAI | SLTI | SLTIU => Compute, BEQ | BNE | BLT | BGE | BLTU | BGEU => Branch, - JAL | JALR | LUI => Compute, + JAL | JALR => Compute, LB | LH | LW | LBU | LHU => Load, SB | SH | SW => Store, ECALL => System, + #[cfg(feature = "u16limb_circuit")] + LUI => Compute, } } } @@ -232,10 +235,11 @@ impl From for InsnFormat { JAL => J, JALR => I, LB | LH | LW | LBU | LHU => I, - LUI => U, SB | SH | SW => S, ECALL => I, INVALID => I, + #[cfg(feature = "u16limb_circuit")] + LUI => U, } } } @@ -308,7 +312,8 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) match kind { ADDI => rs1.wrapping_add(imm_i), - LUI => imm_i << 12, + #[cfg(feature = "u16limb_circuit")] + LUI => imm_i, XORI => rs1 ^ imm_i, ORI => rs1 | imm_i, ANDI => rs1 & imm_i, diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index ae3db9f32..392ec1a69 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,17 +67,17 @@ forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] -u16limb_circuit = [] +u16limb_circuit = ["ceno_emul/u16limb_circuit"] [[bench]] harness = false diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 3f97fa798..73ec31650 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -33,6 +33,7 @@ mod r_insn; mod ecall_insn; mod im_insn; +#[cfg(feature = "u16limb_circuit")] mod lui; mod memory; mod s_insn; diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 334e8364c..a5395e500 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -23,3 +23,4 @@ pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, LIMB_BITS, E>; /// use UInt for x bits limb size pub type UInt8 = UIntLimbs; pub const UINT_LIMBS: usize = BIT_WIDTH.div_ceil(LIMB_BITS); +pub const UINT_BYTE_LIMBS: usize = BIT_WIDTH.div_ceil(8); \ No newline at end of file diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 67bc3f97e..b57aadbbb 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,16 +1,7 @@ -#[cfg(not(feature = "u16limb_circuit"))] mod jal; -#[cfg(feature = "u16limb_circuit")] -mod jal_lui_v2; - mod jalr; -#[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; - -#[cfg(feature = "u16limb_circuit")] -pub use jal_lui_v2::JalInstruction; - pub use jalr::JalrInstruction; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs deleted file mode 100644 index 009181ada..000000000 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_lui_v2.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::marker::PhantomData; - -use ff_ext::ExtensionField; - -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - instructions::{ - Instruction, - riscv::{ - RIVInstruction, - constants::{BIT_WIDTH, PC_BITS, UInt, UInt8}, - j_insn::JInstructionConfig, - }, - }, - structs::ProgramParams, - witness::LkMultiplicity, -}; -use ceno_emul::{InsnKind, PC_STEP_SIZE}; -use gkr_iop::tables::LookupTable; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; - -pub struct JalConfig { - pub j_insn: JInstructionConfig, - pub imm: Option, - pub rd_written: UInt8, -} - -pub struct JalInstruction(PhantomData<(E, I)>); - -/// JAL instruction circuit -/// -/// Note: does not validate that next_pc is aligned by 4-byte increments, which -/// should be verified by lookup argument of the next execution step against -/// the program table -/// -/// Assumption: values for valid initial program counter must lie between -/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static -/// program lookup table. If this assumption does not hold, then resulting -/// value for next_pc may not correctly wrap mod 2^32 because of the use -/// of native WitIn values for address space arithmetic. -impl Instruction for JalInstruction { - type InstructionConfig = JalConfig; - - fn name() -> String { - format!("{:?}", InsnKind::JAL) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result, ZKVMError> { - let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; - let rd_exprs = rd_written.expr(); - let imm = circuit_builder.create_witin(|| "imm"); - - let intermed_val = - rd_exprs - .iter() - .skip(1) - .enumerate() - .fold(Expression::ZERO, |acc, (i, &val)| { - acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::LIMB_BITS)).expr() - }); - - match I::INST_KIND { - InsnKind::JAL => { - let j_insn = JInstructionConfig::construct_circuit( - circuit_builder, - InsnKind::JAL, - rd_written.register_expr(), - )?; - // constrain rd_exprs[PC_BITS..u32::BITS] are all 0 via xor - let last_limb_bits = PC_BITS - UInt8::LIMB_BITS * (UInt8::NUM_LIMBS - 1); - let additional_bits = - (last_limb_bits..UInt8::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); - let additional_bits = E::BaseField::from_canonical_u32(additional_bits); - circuit_builder.logic_u8( - LookupTable::Xor, - rd_exprs[3].expr(), - additional_bits.expr(), - rd_exprs[3].expr() + additional_bits.expr(), - )?; - circuit_builder.require_equal( - intermed_val, - j_insn + AB::F::from_canonical_u32(DEFAULT_PC_STEP), - ); - } - InsnKind::LUI => { - // constrain rd[1..4] = imm * 2^4 - circuit_builder.require_equal( - || "constrain rd[1..4] = imm * 2^4", - intermed_val, - imm.expr() - * E::BaseField::from_canonical_u32(1 << (12 - UInt8::LIMB_BITS)).expr(), - )?; - } - other => panic!("invalid kind {}", other), - } - - circuit_builder.require_equal( - || "jal rd_written", - rd_written.value(), - j_insn.vm_state.pc.expr() + PC_STEP_SIZE, - )?; - - Ok(JalConfig { j_insn, rd_written }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [E::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .j_insn - .assign_instance(instance, lk_multiplicity, step)?; - - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - config.rd_written.assign_value(instance, rd_written); - - Ok(()) - } -} diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 743e98519..fec3e63f7 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -1,30 +1,32 @@ +use ff_ext::{ExtensionField, FieldInto}; +use itertools::{Itertools, izip}; use std::marker::PhantomData; -use ff_ext::ExtensionField; - use crate::{ - Value, circuit_builder::CircuitBuilder, error::ZKVMError, instructions::{ Instruction, riscv::{ - constants::{BIT_WIDTH, PC_BITS, UInt, UInt8}, + constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, }, structs::ProgramParams, + tables::InsnRecord, + utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, PC_STEP_SIZE}; -use gkr_iop::tables::LookupTable; +use ceno_emul::InsnKind; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; +use witness::set_val; pub struct LuiConfig { - pub j_insn: IInstructionConfig, + pub i_insn: IInstructionConfig, pub imm: WitIn, - pub rd_written: UInt8, + // for rd, we skip lsb byte as it's always zero + pub rd_written: [WitIn; UINT_BYTE_LIMBS - 1], } pub struct LuiInstruction(PhantomData); @@ -40,8 +42,15 @@ impl Instruction for LuiInstruction { circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { - let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; - let rd_exprs = rd_written.expr(); + let rd_written = std::array::from_fn(|i| { + circuit_builder + .create_u8(|| format!("rd_written_limb_{}", i)) + .unwrap() + }); + // rd lsb byte is always zero + let rd_exprs = std::iter::once(0.into()) + .chain(rd_written.map(|w| w.expr())) + .collect_vec(); let imm = circuit_builder.create_witin(|| "imm"); let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, @@ -50,7 +59,7 @@ impl Instruction for LuiInstruction { #[cfg(feature = "u16limb_circuit")] 0.into(), [0.into(), 0.into()], - rd_written.register_expr(), + UInt8::from_exprs_unchecked(rd_exprs.clone()).register_expr(), false, )?; @@ -59,25 +68,23 @@ impl Instruction for LuiInstruction { .iter() .skip(1) .enumerate() - .fold(Expression::ZERO, |acc, (i, &val)| { + .fold(Expression::ZERO, |acc, (i, val)| { acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::LIMB_BITS)).expr() + * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() }); // imm * 2^4 is the correct composition of intermed_val in case of LUI circuit_builder.require_equal( || "imm * 2^4 is the correct composition of intermed_val in case of LUI", intermed_val.expr(), - imm * E::BaseField::from_canonical_u32(1 << (12 - UInt8::LIMB_BITS)), - )?; - - circuit_builder.require_equal( - || "jal rd_written", - rd_written.value(), - i_insn.vm_state.pc.expr() + PC_STEP_SIZE, + imm.expr() * E::BaseField::from_canonical_u32(1 << (12 - UInt8::::LIMB_BITS)).expr(), )?; - Ok(JalConfig { j_insn, rd_written }) + Ok(LuiConfig { + i_insn, + imm, + rd_written, + }) } fn assign_instance( @@ -87,12 +94,90 @@ impl Instruction for LuiInstruction { step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config - .j_insn + .i_insn .assign_instance(instance, lk_multiplicity, step)?; - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - config.rd_written.assign_value(instance, rd_written); + let rd_written = split_to_u8(step.rd().unwrap().value.after); + for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { + lk_multiplicity.assert_ux::<8>(*val as u64); + set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + } + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; + set_val!(instance, config.imm, imm); Ok(()) } } + +#[cfg(test)] +mod tests { + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; + use gkr_iop::circuit_builder::DebugIndex; + + use crate::{ + Value, + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{constants::UInt, lui::LuiInstruction}, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + structs::ProgramParams, + }; + + #[test] + fn test_lui() { + let cases = vec![(0, 0), (0x1000, 1), (0xfffff000, 0xfffff)]; + for &(rd, imm) in &cases { + test_opcode_lui::(rd, imm); + // #[cfg(feature = "u16limb_circuit")] + // test_opcode_lui::(rd, imm); + } + } + + fn test_opcode_lui(rd: u32, imm: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "lui", + |cb| { + let config = + LuiInstruction::::construct_circuit(cb, &ProgramParams::default()); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); + let (raw_witin, lkm) = LuiInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + 0, + Change::new(0, rd), + 0, + )], + ) + .unwrap(); + + // verify rd_written + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + expected_rd_written.value(), + ) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 1ad7968ec..6d410013d 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -1,3 +1,9 @@ +use super::{ + arith::AddInstruction, branch::BltuInstruction, ecall::HaltInstruction, jump::JalInstruction, + memory::LwInstruction, +}; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::lui::LuiInstruction; use crate::{ error::ZKVMError, instructions::{ @@ -45,11 +51,6 @@ use std::{ }; use strum::IntoEnumIterator; -use super::{ - arith::AddInstruction, branch::BltuInstruction, ecall::HaltInstruction, jump::JalInstruction, - memory::LwInstruction, -}; - pub mod mmu; pub struct Rv32imConfig { @@ -83,6 +84,8 @@ pub struct Rv32imConfig { pub srai_config: as Instruction>::InstructionConfig, pub slti_config: as Instruction>::InstructionConfig, pub sltiu_config: as Instruction>::InstructionConfig, + #[cfg(feature = "u16limb_circuit")] + pub lui_config: as Instruction>::InstructionConfig, // Branching Opcodes pub beq_config: as Instruction>::InstructionConfig, @@ -154,6 +157,8 @@ impl Rv32imConfig { let srai_config = cs.register_opcode_circuit::>(); let slti_config = cs.register_opcode_circuit::>(); let sltiu_config = cs.register_opcode_circuit::>(); + #[cfg(feature = "u16limb_circuit")] + let lui_config = cs.register_opcode_circuit::>(); // branching opcodes let beq_config = cs.register_opcode_circuit::>(); @@ -222,6 +227,8 @@ impl Rv32imConfig { srai_config, slti_config, sltiu_config, + #[cfg(feature = "u16limb_circuit")] + lui_config, // branching opcodes beq_config, bne_config, @@ -291,6 +298,8 @@ impl Rv32imConfig { fixed.register_opcode_circuit::>(cs, &self.srai_config); fixed.register_opcode_circuit::>(cs, &self.slti_config); fixed.register_opcode_circuit::>(cs, &self.sltiu_config); + #[cfg(feature = "u16limb_circuit")] + fixed.register_opcode_circuit::>(cs, &self.lui_config); // branching fixed.register_opcode_circuit::>(cs, &self.beq_config); fixed.register_opcode_circuit::>(cs, &self.bne_config); @@ -402,6 +411,8 @@ impl Rv32imConfig { assign_opcode!(SRAI, SraiInstruction, srai_config); assign_opcode!(SLTI, SltiInstruction, slti_config); assign_opcode!(SLTIU, SltiuInstruction, sltiu_config); + #[cfg(feature = "u16limb_circuit")] + assign_opcode!(LUI, LuiInstruction, lui_config); // branching assign_opcode!(BEQ, BeqInstruction, beq_config); assign_opcode!(BNE, BneInstruction, bne_config); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 5add74a22..f56e54ca2 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -121,6 +121,11 @@ impl InsnRecord { ), // for imm operate with program counter => convert to field value (_, B | J) => (insn.imm as i64, i64_to_base(insn.imm as i64)), + // U type + (_, U) => ( + (insn.imm as u32 & 0xfffff) as i64, + F::from_wrapped_u32(insn.imm as u32 & 0xfffff), + ), // TODO JALR need to connecting register (2 limb) with pc (1 limb) (JALR, _) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // for default imm to operate with register value diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 7a703850f..1990bae39 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,4 +1,3 @@ -use multilinear_extensions::ToExpr; use std::{ collections::HashMap, fmt::Display, @@ -6,12 +5,19 @@ use std::{ panic::{self, PanicHookInfo}, }; -use crate::instructions::riscv::constants::UINT_LIMBS; use ff_ext::ExtensionField; pub use gkr_iop::utils::i64_to_base; use itertools::Itertools; +use p3::field::Field; + +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::constants::UINT_LIMBS; +#[cfg(feature = "u16limb_circuit")] use multilinear_extensions::Expression; -use p3::field::{Field, FieldAlgebra}; +#[cfg(feature = "u16limb_circuit")] +use multilinear_extensions::ToExpr; +#[cfg(feature = "u16limb_circuit")] +use p3::field::FieldAlgebra; pub fn split_to_u8>(value: u32) -> Vec { (0..(u32::BITS / 8)) @@ -131,6 +137,7 @@ where result } +#[cfg(feature = "u16limb_circuit")] pub fn imm_sign_extend_circuit( require_signed: bool, is_signed: Expression, From 4c38c3821d6d3394d0d411469b61ad844a100a99 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 15 Aug 2025 17:42:39 +0800 Subject: [PATCH 34/46] add auipc --- ceno_emul/src/rv32im.rs | 8 + ceno_zkvm/src/instructions/riscv.rs | 2 + ceno_zkvm/src/instructions/riscv/auipc.rs | 250 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/constants.rs | 2 +- 4 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/auipc.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index ec3f356c8..b711ae532 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -197,6 +197,8 @@ pub enum InsnKind { LHU, #[cfg(feature = "u16limb_circuit")] LUI, + #[cfg(feature = "u16limb_circuit")] + AUIPC, SB, SH, SW, @@ -220,6 +222,8 @@ impl From for InsnCategory { ECALL => System, #[cfg(feature = "u16limb_circuit")] LUI => Compute, + #[cfg(feature = "u16limb_circuit")] + AUIPC => Compute, } } } @@ -240,6 +244,8 @@ impl From for InsnFormat { INVALID => I, #[cfg(feature = "u16limb_circuit")] LUI => U, + #[cfg(feature = "u16limb_circuit")] + AUIPC => U, } } } @@ -314,6 +320,8 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) ADDI => rs1.wrapping_add(imm_i), #[cfg(feature = "u16limb_circuit")] LUI => imm_i, + #[cfg(feature = "u16limb_circuit")] + AUIPC => pc.wrapping_add(insn.imm as u32).0, XORI => rs1 ^ imm_i, ORI => rs1 | imm_i, ANDI => rs1 & imm_i, diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 73ec31650..69c656148 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,6 +32,8 @@ mod r_insn; mod ecall_insn; +#[cfg(feature = "u16limb_circuit")] +mod auipc; mod im_insn; #[cfg(feature = "u16limb_circuit")] mod lui; diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs new file mode 100644 index 000000000..42942332c --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -0,0 +1,250 @@ +use ff_ext::{ExtensionField, FieldInto}; +use itertools::{Itertools, izip}; +use std::marker::PhantomData; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + utils::split_to_u8, + witness::LkMultiplicity, +}; +use ceno_emul::InsnKind; +use gkr_iop::tables::LookupTable; +use gkr_iop::tables::ops::XorTable; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::{Field, FieldAlgebra}; +use witness::set_val; + +pub struct AuipcConfig { + pub i_insn: IInstructionConfig, + // The limbs of the immediate except the least significant limb since it is always 0 + pub imm_limbs: [WitIn; UINT_BYTE_LIMBS - 1], + // The limbs of the PC except the most significant and the least significant limbs + pub pc_limbs: [WitIn; UINT_BYTE_LIMBS - 2], + pub rd_written: UInt8, +} + +pub struct AuipcInstruction(PhantomData); + +impl Instruction for AuipcInstruction { + type InstructionConfig = AuipcConfig; + + fn name() -> String { + format!("{:?}", InsnKind::AUIPC) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result, ZKVMError> { + let rd_written = UInt8::::new(|| "rd_written", circuit_builder)?; + let rd_exprs = rd_written.expr(); + let pc_limbs = std::array::from_fn(|i| { + circuit_builder + .create_u8(|| format!("pc_limbs_{}", i)) + .unwrap() + }); + let imm_limbs = std::array::from_fn(|i| { + circuit_builder + .create_u8(|| format!("imm_limbs_{}", i)) + .unwrap() + }); + let imm = imm_limbs + .iter() + .enumerate() + .fold(E::BaseField::ZERO.expr(), |acc, (i, &val)| { + acc + val.expr() + * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() + }); + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + InsnKind::AUIPC, + imm.expr(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), + [0.into(), 0.into()], + UInt8::from_exprs_unchecked(rd_exprs.clone()).register_expr(), + false, + )?; + + let intermed_val = rd_exprs[0].expr() + + pc_limbs + .iter() + .enumerate() + .fold(E::BaseField::ZERO.expr(), |acc, (i, val)| { + acc + val.expr() + * E::BaseField::from_canonical_u32(1 << ((i + 1) * UInt8::::LIMB_BITS)) + .expr() + }); + + // Compute the most significant limb of PC + let pc_msl = (i_insn.vm_state.pc.expr() - intermed_val.expr()) + * E::BaseField::from_canonical_usize( + 1 << (UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1)), + ) + .inverse() + .expr(); + + // The vector pc_limbs contains the actual limbs of PC in little endian order + let pc_limbs_expr = [rd_exprs[0].expr()] + .into_iter() + .chain(pc_limbs.iter().map(|w| w.expr())) + .map(|x| x.expr()) + .chain([pc_msl.expr()]) + .collect::>(); + assert_eq!(pc_limbs_expr.len(), UINT_BYTE_LIMBS); + + // Range check the most significant limb of pc to be in [0, 2^{PC_BITS-(RV32_REGISTER_NUM_LIMBS-1)*RV32_CELL_BITS}) + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + let additional_bits = E::BaseField::from_canonical_u32(additional_bits); + circuit_builder.logic_u8( + LookupTable::Xor, + pc_limbs_expr[3].expr(), + additional_bits.expr(), + pc_limbs_expr[3].expr() + additional_bits.expr(), + )?; + + let mut carry: [Expression; UINT_BYTE_LIMBS] = + std::array::from_fn(|_| E::BaseField::ZERO.expr()); + let carry_divide = E::BaseField::from_canonical_usize(1 << UInt8::::LIMB_BITS) + .inverse() + .expr(); + + // Don't need to constrain the least significant limb of the addition + // since we already know that rd_data[0] = pc_limbs[0] and the least significant limb of imm is 0 + // Note: imm_limbs doesn't include the least significant limb so imm_limbs[i - 1] means the i-th limb of imm + for i in 1..UINT_BYTE_LIMBS { + carry[i] = carry_divide.expr() + * (pc_limbs_expr[i].expr() + imm_limbs[i - 1].expr() - rd_exprs[i].expr() + + carry[i - 1].expr()); + circuit_builder.assert_bit(|| format!("carry_bit_{i}"), carry[i].expr())?; + } + + Ok(AuipcConfig { + i_insn, + imm_limbs, + pc_limbs, + rd_written, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + config.rd_written.assign_limbs(instance, &rd_written); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + let pc = split_to_u8(step.pc().before.0); + for (val, witin) in izip!(pc.iter().skip(1), config.pc_limbs) { + lk_multiplicity.assert_ux::<8>(*val as u64); + set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + } + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; + let imm = split_to_u8(imm << 4); + for (val, witin) in izip!(imm.iter(), config.imm_limbs) { + lk_multiplicity.assert_ux::<8>(*val as u64); + set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + } + // constrain pc msb limb range via xor + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + lk_multiplicity.logic_u8::(pc[3] as u64, additional_bits as u64); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; + use gkr_iop::circuit_builder::DebugIndex; + + use crate::{ + Value, + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{auipc::AuipcInstruction, constants::UInt}, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + structs::ProgramParams, + }; + + #[test] + fn test_auipc() { + let cases = vec![(MOCK_PC_START.0 + 0, 0), (MOCK_PC_START.0 + 0x1000, 1)]; + for &(rd, imm) in &cases { + test_opcode_auipc::(rd, imm); + // #[cfg(feature = "u16limb_circuit")] + // test_opcode_Auipc::(rd, imm); + } + } + + fn test_opcode_auipc(rd: u32, imm: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "auipc", + |cb| { + let config = + AuipcInstruction::::construct_circuit(cb, &ProgramParams::default()); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); + let (raw_witin, lkm) = AuipcInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + 0, + Change::new(0, rd), + 0, + )], + ) + .unwrap(); + + // verify rd_written + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + expected_rd_written.value(), + ) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index a5395e500..b775b62d3 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -23,4 +23,4 @@ pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, LIMB_BITS, E>; /// use UInt for x bits limb size pub type UInt8 = UIntLimbs; pub const UINT_LIMBS: usize = BIT_WIDTH.div_ceil(LIMB_BITS); -pub const UINT_BYTE_LIMBS: usize = BIT_WIDTH.div_ceil(8); \ No newline at end of file +pub const UINT_BYTE_LIMBS: usize = BIT_WIDTH.div_ceil(8); From 635c424caa12d63bc245f48768ae52ce15129e1c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 15 Aug 2025 18:01:50 +0800 Subject: [PATCH 35/46] auipc test pass --- ceno_emul/src/disassemble/mod.rs | 65 ++++++++++++++--------- ceno_zkvm/src/instructions/riscv/auipc.rs | 5 +- ceno_zkvm/src/tables/program.rs | 5 ++ 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index f4af91d7d..0b3e789c1 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -284,36 +284,49 @@ impl InstructionProcessor for InstructionTranspiler { } } - /// Convert AUIPC to ADDI. - /// - /// RiscV's instructions are designed to be (mosty) position-independent. AUIPC is used - /// to get access to the current program counter, even if the code has been moved around - /// by the linker. - /// - /// Our conversion here happens after the linker has done its job, so we can safely hardcode - /// the current program counter into the immediate value of our internal ADDI. - /// - /// Note that our internal ADDI can have arbitrary intermediate values, not just 12 bits. - /// - /// ADDI is slightly more general than LUI or AUIPC, because you can also specify an - /// input register rs1. That generality might cost us sligthtly in the non-recursive proof, - /// but we suspect decreasing the total number of different instruction kinds will speed up - /// the recursive proof. - /// - /// In any case, AUIPC and LUI together make up ~0.1% of instructions executed in typical - /// real world scenarios like a `reth` run. - /// fn process_auipc(&mut self, dec_insn: UType) -> Self::InstructionResult { let pc = self.pc; // Verify our assumption that the immediate is already shifted left by 12 bits. assert_eq!(dec_insn.imm & 0xfff, 0); - Instruction { - kind: InsnKind::ADDI, - rd: dec_insn.rd, - rs1: 0, - rs2: 0, - imm: dec_insn.imm.wrapping_add(pc as i32), - raw: self.word, + #[cfg(not(feature = "u16limb_circuit"))] + { + // Convert AUIPC to ADDI. + // + // RiscV's instructions are designed to be (mosty) position-independent. AUIPC is used + // to get access to the current program counter, even if the code has been moved around + // by the linker. + // + // Our conversion here happens after the linker has done its job, so we can safely hardcode + // the current program counter into the immediate value of our internal ADDI. + // + // Note that our internal ADDI can have arbitrary intermediate values, not just 12 bits. + // + // ADDI is slightly more general than LUI or AUIPC, because you can also specify an + // input register rs1. That generality might cost us sligthtly in the non-recursive proof, + // but we suspect decreasing the total number of different instruction kinds will speed up + // the recursive proof. + // + // In any case, AUIPC and LUI together make up ~0.1% of instructions executed in typical + // real world scenarios like a `reth` run. + Instruction { + kind: InsnKind::ADDI, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm.wrapping_add(pc as i32), + raw: self.word, + } + } + #[cfg(feature = "u16limb_circuit")] + { + Instruction { + kind: InsnKind::AUIPC, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } } } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 42942332c..e052977f5 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -18,8 +18,7 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::InsnKind; -use gkr_iop::tables::LookupTable; -use gkr_iop::tables::ops::XorTable; +use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; @@ -161,7 +160,7 @@ impl Instruction for AuipcInstruction { set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); } let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; - let imm = split_to_u8(imm << 4); + let imm = split_to_u8(imm); for (val, witin) in izip!(imm.iter(), config.imm_limbs) { lk_multiplicity.assert_ux::<8>(*val as u64); set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index f56e54ca2..28834b4c3 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -121,6 +121,11 @@ impl InsnRecord { ), // for imm operate with program counter => convert to field value (_, B | J) => (insn.imm as i64, i64_to_base(insn.imm as i64)), + // AUIPC + (AUIPC, U) => ( + ((insn.imm as u32 & 0xfffff) << 4) as i64, + F::from_wrapped_u32((insn.imm as u32 & 0xfffff) << 4), + ), // U type (_, U) => ( (insn.imm as u32 & 0xfffff) as i64, From b3876d2ceb4182b634343b21d6a9abdd7e6921c0 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 18 Aug 2025 11:05:16 +0800 Subject: [PATCH 36/46] all test pass --- ceno_emul/src/disassemble/mod.rs | 3 +- ceno_emul/src/rv32im.rs | 2 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 32 +++++++++++++++------- ceno_zkvm/src/instructions/riscv/lui.rs | 7 ++--- ceno_zkvm/src/instructions/riscv/rv32im.rs | 12 ++++++++ ceno_zkvm/src/tables/program.rs | 10 ++++--- 6 files changed, 46 insertions(+), 20 deletions(-) diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 0b3e789c1..8332a6d6f 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -8,6 +8,7 @@ use rrs_lib::{ /// A transpiler that converts the 32-bit encoded instructions into instructions. pub(crate) struct InstructionTranspiler { + #[allow(dead_code)] pc: u32, word: u32, } @@ -285,11 +286,11 @@ impl InstructionProcessor for InstructionTranspiler { } fn process_auipc(&mut self, dec_insn: UType) -> Self::InstructionResult { - let pc = self.pc; // Verify our assumption that the immediate is already shifted left by 12 bits. assert_eq!(dec_insn.imm & 0xfff, 0); #[cfg(not(feature = "u16limb_circuit"))] { + let pc = self.pc; // Convert AUIPC to ADDI. // // RiscV's instructions are designed to be (mosty) position-independent. AUIPC is used diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index b711ae532..5098fd523 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -321,7 +321,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) #[cfg(feature = "u16limb_circuit")] LUI => imm_i, #[cfg(feature = "u16limb_circuit")] - AUIPC => pc.wrapping_add(insn.imm as u32).0, + AUIPC => pc.wrapping_add(imm_i).0, XORI => rs1 ^ imm_i, ORI => rs1 | imm_i, ANDI => rs1 & imm_i, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index e052977f5..5c123f77b 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -1,5 +1,5 @@ use ff_ext::{ExtensionField, FieldInto}; -use itertools::{Itertools, izip}; +use itertools::izip; use std::marker::PhantomData; use crate::{ @@ -69,7 +69,6 @@ impl Instruction for AuipcInstruction { circuit_builder, InsnKind::AUIPC, imm.expr(), - #[cfg(feature = "u16limb_circuit")] 0.into(), [0.into(), 0.into()], UInt8::from_exprs_unchecked(rd_exprs.clone()).register_expr(), @@ -88,10 +87,10 @@ impl Instruction for AuipcInstruction { // Compute the most significant limb of PC let pc_msl = (i_insn.vm_state.pc.expr() - intermed_val.expr()) - * E::BaseField::from_canonical_usize( + * (E::BaseField::from_canonical_usize( 1 << (UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1)), ) - .inverse() + .inverse()) .expr(); // The vector pc_limbs contains the actual limbs of PC in little endian order @@ -104,7 +103,7 @@ impl Instruction for AuipcInstruction { assert_eq!(pc_limbs_expr.len(), UINT_BYTE_LIMBS); // Range check the most significant limb of pc to be in [0, 2^{PC_BITS-(RV32_REGISTER_NUM_LIMBS-1)*RV32_CELL_BITS}) - let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); let additional_bits = (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); let additional_bits = E::BaseField::from_canonical_u32(additional_bits); @@ -128,6 +127,7 @@ impl Instruction for AuipcInstruction { carry[i] = carry_divide.expr() * (pc_limbs_expr[i].expr() + imm_limbs[i - 1].expr() - rd_exprs[i].expr() + carry[i - 1].expr()); + // carry[i] * 2^(UInt8::LIMB_BITS) + rd_exprs[i].expr() = pc_limbs_expr[i] + imm_limbs[i].expr() + carry[i - 1].expr() circuit_builder.assert_bit(|| format!("carry_bit_{i}"), carry[i].expr())?; } @@ -166,7 +166,7 @@ impl Instruction for AuipcInstruction { set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); } // constrain pc msb limb range via xor - let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); let additional_bits = (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); lk_multiplicity.logic_u8::(pc[3] as u64, additional_bits as u64); @@ -194,15 +194,27 @@ mod tests { #[test] fn test_auipc() { - let cases = vec![(MOCK_PC_START.0 + 0, 0), (MOCK_PC_START.0 + 0x1000, 1)]; - for &(rd, imm) in &cases { - test_opcode_auipc::(rd, imm); + let cases = vec![ + // imm without lower 12 bits zero + 0, 0x1, + // imm = -1 → all 1’s in 20-bit imm + // rd = PC - 0x1000 + -1i32, 0x12345, // imm = 0x12345 + // max positive imm + 0xfffff, + ]; + for imm in &cases { + test_opcode_auipc::( + MOCK_PC_START.0.wrapping_add((*imm as u32) << 12), + imm << 12, + ); // #[cfg(feature = "u16limb_circuit")] - // test_opcode_Auipc::(rd, imm); + // test_opcode_auipc::(rd, imm); } } fn test_opcode_auipc(rd: u32, imm: i32) { + use ceno_emul::ByteAddr; let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index fec3e63f7..a305dea14 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -56,7 +56,6 @@ impl Instruction for LuiInstruction { circuit_builder, InsnKind::LUI, imm.expr(), - #[cfg(feature = "u16limb_circuit")] 0.into(), [0.into(), 0.into()], UInt8::from_exprs_unchecked(rd_exprs.clone()).register_expr(), @@ -128,9 +127,9 @@ mod tests { #[test] fn test_lui() { - let cases = vec![(0, 0), (0x1000, 1), (0xfffff000, 0xfffff)]; - for &(rd, imm) in &cases { - test_opcode_lui::(rd, imm); + let cases = vec![0, 0x1, 0xfffff]; + for imm in &cases { + test_opcode_lui::((*imm as u32) << 12, imm << 12); // #[cfg(feature = "u16limb_circuit")] // test_opcode_lui::(rd, imm); } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 6d410013d..55acd2b9e 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -3,6 +3,8 @@ use super::{ memory::LwInstruction, }; #[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::auipc::AuipcInstruction; +#[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::lui::LuiInstruction; use crate::{ error::ZKVMError, @@ -86,6 +88,8 @@ pub struct Rv32imConfig { pub sltiu_config: as Instruction>::InstructionConfig, #[cfg(feature = "u16limb_circuit")] pub lui_config: as Instruction>::InstructionConfig, + #[cfg(feature = "u16limb_circuit")] + pub auipc_config: as Instruction>::InstructionConfig, // Branching Opcodes pub beq_config: as Instruction>::InstructionConfig, @@ -159,6 +163,8 @@ impl Rv32imConfig { let sltiu_config = cs.register_opcode_circuit::>(); #[cfg(feature = "u16limb_circuit")] let lui_config = cs.register_opcode_circuit::>(); + #[cfg(feature = "u16limb_circuit")] + let auipc_config = cs.register_opcode_circuit::>(); // branching opcodes let beq_config = cs.register_opcode_circuit::>(); @@ -229,6 +235,8 @@ impl Rv32imConfig { sltiu_config, #[cfg(feature = "u16limb_circuit")] lui_config, + #[cfg(feature = "u16limb_circuit")] + auipc_config, // branching opcodes beq_config, bne_config, @@ -300,6 +308,8 @@ impl Rv32imConfig { fixed.register_opcode_circuit::>(cs, &self.sltiu_config); #[cfg(feature = "u16limb_circuit")] fixed.register_opcode_circuit::>(cs, &self.lui_config); + #[cfg(feature = "u16limb_circuit")] + fixed.register_opcode_circuit::>(cs, &self.auipc_config); // branching fixed.register_opcode_circuit::>(cs, &self.beq_config); fixed.register_opcode_circuit::>(cs, &self.bne_config); @@ -413,6 +423,8 @@ impl Rv32imConfig { assign_opcode!(SLTIU, SltiuInstruction, sltiu_config); #[cfg(feature = "u16limb_circuit")] assign_opcode!(LUI, LuiInstruction, lui_config); + #[cfg(feature = "u16limb_circuit")] + assign_opcode!(AUIPC, AuipcInstruction, auipc_config); // branching assign_opcode!(BEQ, BeqInstruction, beq_config); assign_opcode!(BNE, BneInstruction, bne_config); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 28834b4c3..305ec9765 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -123,13 +123,15 @@ impl InsnRecord { (_, B | J) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // AUIPC (AUIPC, U) => ( - ((insn.imm as u32 & 0xfffff) << 4) as i64, - F::from_wrapped_u32((insn.imm as u32 & 0xfffff) << 4), + // riv32 u type lower 12 bits are 0 + // take all except for least significant limb (8 bit) + (insn.imm as u32 >> 8) as i64, + F::from_wrapped_u32(insn.imm as u32 >> 8), ), // U type (_, U) => ( - (insn.imm as u32 & 0xfffff) as i64, - F::from_wrapped_u32(insn.imm as u32 & 0xfffff), + (insn.imm as u32 >> 12) as i64, + F::from_wrapped_u32(insn.imm as u32 >> 12), ), // TODO JALR need to connecting register (2 limb) with pc (1 limb) (JALR, _) => (insn.imm as i64, i64_to_base(insn.imm as i64)), From 6ef6c7884cba36e7197c036ae86dcfdcea5e6163 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 18 Aug 2025 14:48:14 +0800 Subject: [PATCH 37/46] format fix --- ceno_zkvm/Cargo.toml | 16 ++++++++-------- ceno_zkvm/src/instructions/riscv/auipc.rs | 8 +++++--- ceno_zkvm/src/instructions/riscv/lui.rs | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 392ec1a69..7e5aab876 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,14 +67,14 @@ forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] u16limb_circuit = ["ceno_emul/u16limb_circuit"] diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 5c123f77b..d9a7dd053 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -208,13 +208,15 @@ mod tests { MOCK_PC_START.0.wrapping_add((*imm as u32) << 12), imm << 12, ); - // #[cfg(feature = "u16limb_circuit")] - // test_opcode_auipc::(rd, imm); + #[cfg(feature = "u16limb_circuit")] + test_opcode_auipc::( + MOCK_PC_START.0.wrapping_add((*imm as u32) << 12), + imm << 12, + ); } } fn test_opcode_auipc(rd: u32, imm: i32) { - use ceno_emul::ByteAddr; let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 41052b887..2cc280f04 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -131,7 +131,7 @@ mod tests { for imm in &cases { test_opcode_lui::((*imm as u32) << 12, imm << 12); #[cfg(feature = "u16limb_circuit")] - test_opcode_lui::(rd, imm); + test_opcode_lui::((*imm as u32) << 12, imm << 12); } } From 155987f7fcc7ab88cb599b1fae9587b9d3de3550 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 18 Aug 2025 14:59:27 +0800 Subject: [PATCH 38/46] format fix --- ceno_emul/src/rv32im.rs | 8 +- .../src/instructions/riscv/jump/jal_v2.rs | 118 ------------------ 2 files changed, 2 insertions(+), 124 deletions(-) delete mode 100644 ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 5098fd523..1726d3009 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -221,9 +221,7 @@ impl From for InsnCategory { SB | SH | SW => Store, ECALL => System, #[cfg(feature = "u16limb_circuit")] - LUI => Compute, - #[cfg(feature = "u16limb_circuit")] - AUIPC => Compute, + LUI | AUIPC => Compute, } } } @@ -243,9 +241,7 @@ impl From for InsnFormat { ECALL => I, INVALID => I, #[cfg(feature = "u16limb_circuit")] - LUI => U, - #[cfg(feature = "u16limb_circuit")] - AUIPC => U, + LUI | AUIPC => U, } } } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs deleted file mode 100644 index 42d0b3d01..000000000 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::marker::PhantomData; - -use ff_ext::ExtensionField; - -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - instructions::{ - Instruction, - riscv::{ - RIVInstruction, - constants::{BIT_WIDTH, PC_BITS, UInt, UInt8}, - j_insn::JInstructionConfig, - }, - }, - structs::ProgramParams, - witness::LkMultiplicity, -}; -use ceno_emul::{InsnKind, PC_STEP_SIZE}; -use gkr_iop::tables::LookupTable; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; - -pub struct JalConfig { - pub j_insn: JInstructionConfig, - pub imm: WitIn, - pub rd_written: UInt8, -} - -pub struct JalInstruction(PhantomData); - -/// JAL instruction circuit -/// -/// Note: does not validate that next_pc is aligned by 4-byte increments, which -/// should be verified by lookup argument of the next execution step against -/// the program table -/// -/// Assumption: values for valid initial program counter must lie between -/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static -/// program lookup table. If this assumption does not hold, then resulting -/// value for next_pc may not correctly wrap mod 2^32 because of the use -/// of native WitIn values for address space arithmetic. -impl Instruction for JalInstruction { - type InstructionConfig = JalConfig; - - fn name() -> String { - format!("{:?}", InsnKind::JAL) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result, ZKVMError> { - let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; - let rd_exprs = rd_written.expr(); - let imm = circuit_builder.create_witin(|| "imm"); - - let intermed_val = - rd_exprs - .iter() - .skip(1) - .enumerate() - .fold(Expression::ZERO, |acc, (i, val)| { - acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() - }); - - let j_insn = JInstructionConfig::construct_circuit( - circuit_builder, - InsnKind::JAL, - rd_written.register_expr(), - )?; - - // constrain rd_exprs [PC_BITS .. u32::BITS] are all 0 via xor - let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); - let additional_bits = (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); - let additional_bits = E::BaseField::from_canonical_u32(additional_bits); - circuit_builder.logic_u8( - LookupTable::Xor, - rd_exprs[3].expr(), - additional_bits.expr(), - rd_exprs[3].expr() + additional_bits.expr(), - )?; - // circuit_builder.require_equal( - // intermed_val, - // j_insn. + E::BaseField::from_canonical_u32(DEFAULT_PC_STEP), - // ); - - circuit_builder.require_equal( - || "jal rd_written", - rd_written.value(), - j_insn.vm_state.pc.expr() + PC_STEP_SIZE, - )?; - - Ok(JalConfig { - j_insn, - imm, - rd_written, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [E::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .j_insn - .assign_instance(instance, lk_multiplicity, step)?; - - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - config.rd_written.assign_value(instance, rd_written); - - Ok(()) - } -} From d1a60407b6762941f3bd51c29595463ddf0e9af6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 18 Aug 2025 15:36:58 +0800 Subject: [PATCH 39/46] babybear test for addi/logici --- .github/workflows/tests.yml | 4 +- Makefile.toml | 16 ++ ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 30 ++-- .../src/instructions/riscv/branch/test.rs | 4 +- .../riscv/logic_imm/logic_imm_circuit.rs | 63 ++++++-- ceno_zkvm/src/instructions/riscv/slt.rs | 4 +- ceno_zkvm/src/instructions/riscv/slti.rs | 146 +++++++++++------- .../instructions/riscv/slti/slti_circuit.rs | 11 +- 9 files changed, 190 insertions(+), 90 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e89b3e875..ddf01b269 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,7 +3,7 @@ name: Tests on: merge_group: pull_request: - types: [synchronize, opened, reopened, ready_for_review] + types: [ synchronize, opened, reopened, ready_for_review ] push: branches: - master @@ -44,3 +44,5 @@ jobs: cargo make --version || cargo install cargo-make - name: run test run: cargo make tests + - name: run test + feature u16limb_circuit + run: cargo make tests_v2 diff --git a/Makefile.toml b/Makefile.toml index 15dd542ea..03e6d4ee4 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -14,6 +14,22 @@ args = [ command = "cargo" workspace = false +[tasks.tests_v2] +args = [ + "test", + # Run everything but 'benches'. + "--lib", + "--bins", + "--tests", + "--examples", + "--workspace", + "--features", + "u16limb_circuit", +] +command = "cargo" +workspace = false + + [tasks.riscv_stats] args = ["run", "--bin", "riscv_stats"] command = "cargo" diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 7e5aab876..9ada3e96e 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -61,7 +61,7 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow", "u16limb_circuit"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 0b98ed64d..a040681bc 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -26,26 +26,34 @@ mod test { structs::ProgramParams, }; use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; - use ff_ext::GoldilocksExt2; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; + use ff_ext::{ExtensionField, GoldilocksExt2}; use gkr_iop::circuit_builder::DebugIndex; #[test] - fn test_opcode_addi_v1() { - test_opcode_addi(1000, 1003, 3); - test_opcode_addi(1000, 997, -3); + fn test_opcode_addi() { + let cases = vec![ + (1000, 1003, 3), // positive immediate + (1000, 997, -3), // negative immediate + ]; + + for &(rs1, expected, imm) in &cases { + test_opcode_addi_internal::(rs1, expected, imm); + #[cfg(feature = "u16limb_circuit")] + test_opcode_addi_internal::(rs1, expected, imm); + } } - fn test_opcode_addi(rs1: u32, rd: u32, imm: i32) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + fn test_opcode_addi_internal(rs1: u32, rd: u32, imm: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb .namespace( || "addi", |cb| { - let config = AddiInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = + AddiInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -53,7 +61,7 @@ mod test { .unwrap(); let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); - let (raw_witin, lkm) = AddiInstruction::::assign_instances( + let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 9697b0422..aaf468127 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -1,5 +1,7 @@ use ceno_emul::{ByteAddr, Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; -use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; +#[cfg(feature = "u16limb_circuit")] +use ff_ext::BabyBearExt4; +use ff_ext::{ExtensionField, GoldilocksExt2}; use super::*; use crate::{ diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index cfb6bbc59..aad60b43b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -124,7 +124,9 @@ impl LogicConfig { #[cfg(test)] mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; - use ff_ext::GoldilocksExt2; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; + use ff_ext::{ExtensionField, GoldilocksExt2}; use gkr_iop::circuit_builder::DebugIndex; use crate::{ @@ -150,27 +152,56 @@ mod test { #[test] fn test_opcode_andi() { - verify::("basic", 0x0000_0011, 3, 0x0000_0011 & 3); - verify::("zero result", 0x0000_0100, 3, 0x0000_0100 & 3); - verify::("negative imm", TEST, NEG, TEST & NEG); + let cases = vec![ + ("basic", 0x0000_0011, 3, 0x0000_0011 & 3), + ("zero result", 0x0000_0100, 3, 0x0000_0100 & 3), + ("negative imm", TEST, NEG, TEST & NEG), + ]; + + for &(name, rs1, imm, expected) in &cases { + verify::(name, rs1, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, rs1, imm, expected); + } } #[test] fn test_opcode_ori() { - verify::("basic", 0x0000_0011, 3, 0x0000_0011 | 3); - verify::("basic2", 0x0000_0100, 3, 0x0000_0100 | 3); - verify::("negative imm", TEST, NEG, TEST | NEG); + let cases = vec![ + ("basic", 0x0000_0011, 3, 0x0000_0011 | 3), + ("basic2", 0x0000_0100, 3, 0x0000_0100 | 3), + ("negative imm", TEST, NEG, TEST | NEG), + ]; + + for &(name, rs1, imm, expected) in &cases { + verify::(name, rs1, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, rs1, imm, expected); + } } #[test] fn test_opcode_xori() { - verify::("basic", 0x0000_0011, 3, 0x0000_0011 ^ 3); - verify::("non-overlap", 0x0000_0100, 3, 0x0000_0100 ^ 3); - verify::("negative imm", TEST, NEG, TEST ^ NEG); + let cases = vec![ + ("basic", 0x0000_0011, 3, 0x0000_0011 ^ 3), + ("non-overlap", 0x0000_0100, 3, 0x0000_0100 ^ 3), + ("negative imm", TEST, NEG, TEST ^ NEG), + ]; + + for &(name, rs1, imm, expected) in &cases { + verify::(name, rs1, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, rs1, imm, expected); + } } - fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + fn verify( + name: &'static str, + rs1_read: u32, + imm: u32, + expected_rd_written: u32, + ) { + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let (prefix, rd_written) = match I::INST_KIND { @@ -184,10 +215,8 @@ mod test { .namespace( || format!("{prefix}_({name})"), |cb| { - let config = LogicInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = + LogicInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -195,7 +224,7 @@ mod test { .unwrap(); let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); - let (raw_witin, lkm) = LogicInstruction::::assign_instances( + let (raw_witin, lkm) = LogicInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index e5303ced2..0ff3e230d 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -26,7 +26,9 @@ pub type SltuInstruction = slt_circuit::SetLessThanInstruction; #[cfg(test)] mod test { use ceno_emul::{Change, StepRecord, Word, encode_rv32}; - use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; + use ff_ext::{ExtensionField, GoldilocksExt2}; use rand::RngCore; diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 9e3e99a65..5802c4229 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -27,7 +27,7 @@ pub type SltiuInstruction = SetLessThanImmInstruction; #[cfg(test)] mod test { use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, encode_rv32}; - use ff_ext::GoldilocksExt2; + use ff_ext::{ExtensionField, GoldilocksExt2}; use proptest::proptest; @@ -45,35 +45,50 @@ mod test { scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; #[test] fn test_sltiu_true() { - let verify = |name, a, imm| verify::(name, a, imm, true); - verify("lt = true, 0 < 1", 0, 1); - verify("lt = true, 1 < 2", 1, 2); - verify("lt = true, 10 < 20", 10, 20); - verify("lt = true, 0 < imm upper boundary", 0, 2047); - // negative imm is treated as positive - verify("lt = true, 0 < u32::MAX-1", 0, -1); - verify("lt = true, 1 < u32::MAX-1", 1, -1); - verify("lt = true, 0 < imm lower bondary", 0, -2048); - verify("lt = true, 65535 < imm lower bondary", 65535, -1); + let cases = vec![ + ("lt = true, 0 < 1", 0, 1i32), + ("lt = true, 1 < 2", 1, 2), + ("lt = true, 10 < 20", 10, 20), + ("lt = true, 0 < imm upper boundary", 0, 2047), + // negative imm is treated as positive + ("lt = true, 0 < u32::MAX-1", 0, -1), + ("lt = true, 1 < u32::MAX-1", 1, -1), + ("lt = true, 0 < imm lower boundary", 0, -2048), + ("lt = true, 65535 < imm lower boundary", 65535, -1), + ]; + + for &(name, a, imm) in &cases { + verify::(name, a, imm, true); + #[cfg(feature = "u16limb_circuit")] + verify::(name, a, imm, true); + } } #[test] fn test_sltiu_false() { - let verify = |name, a, imm| verify::(name, a, imm, false); - verify("lt = false, 1 < 0", 1, 0); - verify("lt = false, 2 < 1", 2, 1); - verify("lt = false, 100 < 50", 100, 50); - verify("lt = false, 500 < 100", 500, 100); - verify("lt = false, 100000 < 2047", 100000, 2047); - verify("lt = false, 100000 < 0", 100000, 0); - verify("lt = false, 0 == 0", 0, 0); - verify("lt = false, 1 == 1", 1, 1); - verify("lt = false, imm upper bondary", u32::MAX, 2047); - // negative imm is treated as positive - verify("lt = false, imm lower bondary", u32::MAX, -2048); + let cases = vec![ + ("lt = false, 1 < 0", 1, 0i32), + ("lt = false, 2 < 1", 2, 1), + ("lt = false, 100 < 50", 100, 50), + ("lt = false, 500 < 100", 500, 100), + ("lt = false, 100000 < 2047", 100_000, 2047), + ("lt = false, 100000 < 0", 100_000, 0), + ("lt = false, 0 == 0", 0, 0), + ("lt = false, 1 == 1", 1, 1), + ("lt = false, imm upper boundary", u32::MAX, 2047), + ("lt = false, imm lower boundary", u32::MAX, -2048), /* negative imm treated as positive */ + ]; + + for &(name, a, imm) in &cases { + verify::(name, a, imm, false); + #[cfg(feature = "u16limb_circuit")] + verify::(name, a, imm, false); + } } proptest! { @@ -82,37 +97,53 @@ mod test { a in u32_extra(), imm in immu_extra(12), ) { - verify::("random SltiuOp", a, imm as i32, a < imm); + verify::("random SltiuOp", a, imm as i32, a < imm); + #[cfg(feature = "u16limb_circuit")] + verify::("random SltiuOp", a, imm as i32, a < imm); } } #[test] fn test_slti_true() { - let verify = |name, a: i32, imm| verify::(name, a as u32, imm, true); - verify("lt = true, 0 < 1", 0, 1); - verify("lt = true, 1 < 2", 1, 2); - verify("lt = true, -1 < 0", -1, 0); - verify("lt = true, -1 < 1", -1, 1); - verify("lt = true, -2 < -1", -2, -1); - // -2048 <= imm <= 2047 - verify("lt = true, imm upper bondary", i32::MIN, 2047); - verify("lt = true, imm lower bondary", i32::MIN, -2048); + let cases = vec![ + ("lt = true, 0 < 1", 0, 1), + ("lt = true, 1 < 2", 1, 2), + ("lt = true, -1 < 0", -1, 0), + ("lt = true, -1 < 1", -1, 1), + ("lt = true, -2 < -1", -2, -1), + // -2048 <= imm <= 2047 + ("lt = true, imm upper boundary", i32::MIN, 2047), + ("lt = true, imm lower boundary", i32::MIN, -2048), + ]; + + for &(name, a, imm) in &cases { + verify::(name, a as u32, imm, true); + #[cfg(feature = "u16limb_circuit")] + verify::(name, a as u32, imm, true); + } } #[test] fn test_slti_false() { - let verify = |name, a: i32, imm| verify::(name, a as u32, imm, false); - verify("lt = false, 1 < 0", 1, 0); - verify("lt = false, 2 < 1", 2, 1); - verify("lt = false, 0 < -1", 0, -1); - verify("lt = false, 1 < -1", 1, -1); - verify("lt = false, -1 < -2", -1, -2); - verify("lt = false, 0 == 0", 0, 0); - verify("lt = false, 1 == 1", 1, 1); - verify("lt = false, -1 == -1", -1, -1); - // -2048 <= imm <= 2047 - verify("lt = false, imm upper bondary", i32::MAX, 2047); - verify("lt = false, imm lower bondary", i32::MAX, -2048); + let cases = vec![ + ("lt = false, 1 < 0", 1, 0), + ("lt = false, 2 < 1", 2, 1), + ("lt = false, 0 < -1", 0, -1), + ("lt = false, 1 < -1", 1, -1), + ("lt = false, -1 < -2", -1, -2), + ("lt = false, 0 == 0", 0, 0), + ("lt = false, 1 == 1", 1, 1), + ("lt = false, -1 == -1", -1, -1), + // -2048 <= imm <= 2047 + ("lt = false, imm upper boundary", i32::MAX, 2047), + ("lt = false, imm lower boundary", i32::MAX, -2048), + ]; + + for &(name, a, imm) in &cases { + verify::(name, a as u32, imm, false); + #[cfg(feature = "u16limb_circuit")] + verify::(name, a as u32, imm, false); + } } proptest! { @@ -121,13 +152,20 @@ mod test { a in i32_extra(), imm in imm_extra(12), ) { - verify::("random SltiOp", a as u32, imm, a < imm); + verify::("random SltiOp", a as u32, imm, a < imm); + #[cfg(feature = "u16limb_circuit")] + verify::("random SltiOp", a as u32, imm, a < imm); } } - fn verify(name: &'static str, rs1_read: u32, imm: i32, expected_rd: bool) { + fn verify( + name: &'static str, + rs1_read: u32, + imm: i32, + expected_rd: bool, + ) { let expected_rd = expected_rd as u32; - let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); @@ -136,18 +174,16 @@ mod test { .namespace( || format!("{:?}_({name})", I::INST_KIND), |cb| { - Ok( - SetLessThanImmInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ), - ) + Ok(SetLessThanImmInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) }, ) .unwrap() .unwrap(); - let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( + let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 632bf873e..266faeed3 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -7,7 +7,7 @@ use crate::{ Instruction, riscv::{ RIVInstruction, - constants::{UINT_LIMBS, UInt}, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, }, @@ -62,8 +62,13 @@ impl Instruction for SetLessThanImmInst _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; - let lt = - IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?; + let lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < imm", + value_expr, + imm.expr(), + UINT_LIMBS * LIMB_BITS, + )?; let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); let i_insn = IInstructionConfig::::construct_circuit( From 8151d7250e3fe953fab6a4cd1579b8dbb3b5c7da Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 18 Aug 2025 21:14:09 +0800 Subject: [PATCH 40/46] jal migrated --- ceno_zkvm/src/instructions/riscv/jump.rs | 7 ++ .../src/instructions/riscv/jump/jal_v2.rs | 114 ++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index b57aadbbb..01161ceff 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,7 +1,14 @@ +#[cfg(not(feature = "u16limb_circuit"))] mod jal; +#[cfg(feature = "u16limb_circuit")] +mod jal_v2; mod jalr; +#[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; +#[cfg(feature = "u16limb_circuit")] +pub use jal_v2::JalInstruction; + pub use jalr::JalrInstruction; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs new file mode 100644 index 000000000..e89590c6b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -0,0 +1,114 @@ +use std::marker::PhantomData; + +use ff_ext::ExtensionField; + +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{BIT_WIDTH, PC_BITS, UINT_BYTE_LIMBS, UInt8}, + j_insn::JInstructionConfig, + }, + }, + structs::ProgramParams, + utils::split_to_u8, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use gkr_iop::tables::{LookupTable, ops::XorTable}; +use multilinear_extensions::{Expression, ToExpr}; +use p3::field::FieldAlgebra; + +pub struct JalConfig { + pub j_insn: JInstructionConfig, + pub rd_written: UInt8, +} + +pub struct JalInstruction(PhantomData); + +/// JAL instruction circuit +/// +/// Note: does not validate that next_pc is aligned by 4-byte increments, which +/// should be verified by lookup argument of the next execution step against +/// the program table +/// +/// Assumption: values for valid initial program counter must lie between +/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static +/// program lookup table. If this assumption does not hold, then resulting +/// value for next_pc may not correctly wrap mod 2^32 because of the use +/// of native WitIn values for address space arithmetic. +impl Instruction for JalInstruction { + type InstructionConfig = JalConfig; + + fn name() -> String { + format!("{:?}", InsnKind::JAL) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result, ZKVMError> { + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + let rd_exprs = rd_written.expr(); + + let j_insn = JInstructionConfig::construct_circuit( + circuit_builder, + InsnKind::JAL, + rd_written.register_expr(), + )?; + + // constrain rd_exprs [PC_BITS .. u32::BITS] are all 0 via xor + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + let additional_bits = E::BaseField::from_canonical_u32(additional_bits); + circuit_builder.logic_u8( + LookupTable::Xor, + rd_exprs[3].expr(), + additional_bits.expr(), + rd_exprs[3].expr() + additional_bits.expr(), + )?; + + circuit_builder.require_equal( + || "jal rd_written", + rd_exprs + .iter() + .enumerate() + .fold(Expression::ZERO, |acc, (i, val)| { + acc + val.expr() + * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() + }), + j_insn.vm_state.pc.expr() + PC_STEP_SIZE, + )?; + + Ok(JalConfig { j_insn, rd_written }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + config.rd_written.assign_limbs(instance, &rd_written); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + + // constrain pc msb limb range via xor + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + lk_multiplicity.logic_u8::(rd_written[3] as u64, additional_bits as u64); + + Ok(()) + } +} From 2ccf4ed50b077478ffe53b31ea86aeeab26268ad Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 19 Aug 2025 15:59:53 +0800 Subject: [PATCH 41/46] migrate jal/jalr --- ceno_zkvm/src/instructions/riscv/insn_base.rs | 47 ++++- ceno_zkvm/src/instructions/riscv/jump.rs | 7 + .../src/instructions/riscv/jump/jal_v2.rs | 3 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 2 - .../src/instructions/riscv/jump/jalr_v2.rs | 184 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/jump/test.rs | 86 ++++++-- ceno_zkvm/src/tables/program.rs | 8 +- gkr_iop/src/circuit_builder.rs | 20 ++ gkr_iop/src/utils/lk_multiplicity.rs | 14 ++ 9 files changed, 337 insertions(+), 34 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index bca216b66..0412dfb09 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, FieldInto, SmallField}; use itertools::Itertools; use p3::field::{Field, FieldAlgebra}; -use super::constants::{PC_STEP_SIZE, UINT_LIMBS, UInt}; +use super::constants::{BIT_WIDTH, PC_STEP_SIZE, UINT_LIMBS, UInt}; use crate::{ chip_handler::{ AddressExpr, GlobalStateRegisterMachineChipOperations, MemoryChipOperations, MemoryExpr, @@ -368,6 +368,7 @@ impl WriteMEM { pub struct MemAddr { addr: UInt, low_bits: Vec, + max_bits: usize, } impl MemAddr { @@ -393,6 +394,17 @@ impl MemAddr { self.addr.address_expr() } + pub fn uint_unaligned(&self) -> UInt { + UInt::from_exprs_unchecked(self.addr.expr()) + } + + pub fn uint_align2(&self) -> UInt { + UInt::from_exprs_unchecked(vec![ + self.addr.limbs[0].expr() - &self.low_bit_exprs()[0], + self.addr.limbs[1].expr(), + ]) + } + /// Represent the address aligned to 2 bytes. pub fn expr_align2(&self) -> AddressExpr { self.addr.address_expr() - &self.low_bit_exprs()[0] @@ -404,6 +416,14 @@ impl MemAddr { self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0] } + pub fn uint_align4(&self) -> UInt { + let low_bits = self.low_bit_exprs(); + UInt::from_exprs_unchecked(vec![ + self.addr.limbs[0].expr() - &low_bits[1] * 2 - &low_bits[0], + self.addr.limbs[1].expr(), + ]) + } + /// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1]. pub fn low_bit_exprs(&self) -> Vec> { iter::repeat_n(Expression::ZERO, self.n_zeros()) @@ -412,6 +432,14 @@ impl MemAddr { } fn construct(cb: &mut CircuitBuilder, n_zeros: usize) -> Result { + Self::construct_with_max_bits(cb, n_zeros, BIT_WIDTH) + } + + pub fn construct_with_max_bits( + cb: &mut CircuitBuilder, + n_zeros: usize, + max_bits: usize, + ) -> Result { assert!(n_zeros <= Self::N_LOW_BITS); // The address as two u16 limbs. @@ -442,11 +470,19 @@ impl MemAddr { cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; // Range check the high limb. - for high_u16 in limbs.iter().skip(1) { - cb.assert_ux::<_, _, 16>(|| "high_u16", high_u16.clone())?; + for (i, high_limb) in limbs.iter().enumerate().skip(1) { + cb.assert_ux_v2( + || "high_limb", + high_limb.clone(), + (max_bits - i * 16).min(16), + )?; } - Ok(MemAddr { addr, low_bits }) + Ok(MemAddr { + addr, + low_bits, + max_bits, + }) } pub fn assign_instance( @@ -470,7 +506,8 @@ impl MemAddr { // Range check the high limb. for i in 1..UINT_LIMBS { let high_u16 = (addr >> (i * 16)) & 0xffff; - lkm.assert_ux::<16>(high_u16 as u64); + println!("assignment max bit {}", (self.max_bits - i * 16).min(16)); + lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16)); } Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 01161ceff..7bf1a41f6 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -2,14 +2,21 @@ mod jal; #[cfg(feature = "u16limb_circuit")] mod jal_v2; + +#[cfg(not(feature = "u16limb_circuit"))] mod jalr; +#[cfg(feature = "u16limb_circuit")] +mod jalr_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; #[cfg(feature = "u16limb_circuit")] pub use jal_v2::JalInstruction; +#[cfg(not(feature = "u16limb_circuit"))] pub use jalr::JalrInstruction; +#[cfg(feature = "u16limb_circuit")] +pub use jalr_v2::JalrInstruction; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index e89590c6b..cd6ad194b 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -3,13 +3,12 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; use crate::{ - Value, circuit_builder::CircuitBuilder, error::ZKVMError, instructions::{ Instruction, riscv::{ - constants::{BIT_WIDTH, PC_BITS, UINT_BYTE_LIMBS, UInt8}, + constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, }, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index fe077d464..f1ba94aa7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -53,8 +53,6 @@ impl Instruction for JalrInstruction { circuit_builder, InsnKind::JALR, imm.expr(), - #[cfg(feature = "u16limb_circuit")] - 0.into(), rs1_read.register_expr(), rd_written.register_expr(), true, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs new file mode 100644 index 000000000..66bb06c26 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -0,0 +1,184 @@ +use ff_ext::ExtensionField; +use std::marker::PhantomData; + +use crate::{ + Value, + chip_handler::general::InstFetch, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{PC_BITS, UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, + insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + utils::imm_sign_extend, + witness::{LkMultiplicity, set_val}, +}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use ff_ext::FieldInto; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::{Field, FieldAlgebra}; + +pub struct JalrConfig { + pub i_insn: IInstructionConfig, + pub rs1_read: UInt, + pub imm: WitIn, + pub imm_sign: WitIn, + pub jump_pc_addr: MemAddr, + pub rd_high: WitIn, +} + +pub struct JalrInstruction(PhantomData); + +/// JALR instruction circuit +/// NOTE: does not validate that next_pc is aligned by 4-byte increments, which +/// should be verified by lookup argument of the next execution step against +/// the program table +impl Instruction for JalrInstruction { + type InstructionConfig = JalrConfig; + + fn name() -> String { + format!("{:?}", InsnKind::JALR) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result, ZKVMError> { + assert_eq!(UINT_LIMBS, 2); + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value + let imm_sign = circuit_builder.create_witin(|| "imm_sign"); + // State in and out + let vm_state = StateInOut::construct_circuit(circuit_builder, true)?; + let rd_high = circuit_builder.create_witin(|| "rd_high"); + let rd_low: Expression<_> = vm_state.pc.expr() + + E::BaseField::from_canonical_usize(PC_STEP_SIZE).expr() + - rd_high.expr() * E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).expr(); + // rd range check + // rd_low + circuit_builder.assert_ux_v2(|| "rd_low_u16", rd_low.expr(), UInt::::LIMB_BITS)?; + // rd_high + circuit_builder.assert_ux_v2( + || "rd_high_range", + rd_high.expr(), + PC_BITS - UInt::::LIMB_BITS, + )?; + let rd_uint = UInt::from_exprs_unchecked(vec![rd_low.expr(), rd_high.expr()]); + + let jump_pc_addr = MemAddr::construct_with_max_bits(circuit_builder, 0, PC_BITS)?; + + // Registers + let rs1 = + ReadRS1::construct_circuit(circuit_builder, rs1_read.register_expr(), vm_state.ts)?; + let rd = WriteRD::construct_circuit(circuit_builder, rd_uint.register_expr(), vm_state.ts)?; + + // Fetch the instruction. + circuit_builder.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + InsnKind::JALR.into(), + Some(rd.id.expr()), + rs1.id.expr(), + 0.into(), + imm.expr(), + imm_sign.expr(), + ))?; + + let i_insn = IInstructionConfig { vm_state, rs1, rd }; + + // Next pc is obtained by rounding rs1+imm down to an even value. + // To implement this, check three conditions: + // 1. rs1 + imm = jump_pc_addr + overflow*2^32 + // 3. next_pc = jump_pc_addr aligned to even value (round down) + + let inv = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + + let carry = (rs1_read.expr()[0].expr() + imm.expr() + - jump_pc_addr.uint_unaligned().expr()[0].expr()) + * inv.expr(); + circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?; + + let imm_extend_limb = imm_sign.expr() + * E::BaseField::from_canonical_u32((1 << UInt::::LIMB_BITS) - 1).expr(); + let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry + - jump_pc_addr.uint_unaligned().expr()[1].expr()) + * inv.expr(); + circuit_builder.assert_bit(|| "overflow_bit", carry)?; + + circuit_builder.require_equal( + || "jump_pc_addr = next_pc", + jump_pc_addr.expr_align2(), + i_insn.vm_state.next_pc.unwrap().expr(), + )?; + + // write pc+4 to rd + circuit_builder.require_equal( + || "rd_written = pc+4", + rd_uint.value(), // this operation is safe + i_insn.vm_state.pc.expr() + PC_STEP_SIZE, + )?; + + Ok(JalrConfig { + i_insn, + rs1_read, + imm, + imm_sign, + jump_pc_addr, + rd_high, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let insn = step.insn(); + + let rs1 = step.rs1().unwrap().value; + let imm = InsnRecord::::imm_internal(&insn); + set_val!(instance, config.imm, imm.1); + // according to riscvim32 spec, imm always do signed extension + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + set_val!( + instance, + config.imm_sign, + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + let rd = Value::new_unchecked(step.rd().unwrap().value.after); + let rd_limb = rd.as_u16_limbs(); + lk_multiplicity.assert_ux_v2(rd_limb[0] as u64, 16); + lk_multiplicity.assert_ux_v2(rd_limb[1] as u64, PC_BITS - 16); + + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + set_val!( + instance, + config.rd_high, + E::BaseField::from_canonical_u16(rd_limb[1]) + ); + + let (sum, _) = rs1.overflowing_add_signed(i32::from_ne_bytes([ + imm_sign_extend[0] as u8, + (imm_sign_extend[0] >> 8) as u8, + imm_sign_extend[1] as u8, + (imm_sign_extend[1] >> 8) as u8, + ])); + config + .jump_pc_addr + .assign_instance(instance, lk_multiplicity, sum)?; + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 5a7036ab7..0b379f250 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -1,37 +1,46 @@ -use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; -use ff_ext::GoldilocksExt2; - +use super::{JalInstruction, JalrInstruction}; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; - -use super::{JalInstruction, JalrInstruction}; +use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; +#[cfg(feature = "u16limb_circuit")] +use ff_ext::BabyBearExt4; +use ff_ext::{ExtensionField, GoldilocksExt2}; +use gkr_iop::circuit_builder::DebugIndex; #[test] fn test_opcode_jal() { - let mut cs = ConstraintSystem::::new(|| "riscv"); + verify_test_opcode_jal::(-8); + verify_test_opcode_jal::(8); + + #[cfg(feature = "u16limb_circuit")] + { + verify_test_opcode_jal::(-8); + verify_test_opcode_jal::(8); + } +} + +fn verify_test_opcode_jal(pc_offset: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb .namespace( || "jal", |cb| { - let config = JalInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = JalInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) .unwrap() .unwrap(); - let pc_offset: i32 = -8i32; let new_pc: ByteAddr = ByteAddr(MOCK_PC_START.0.wrapping_add_signed(pc_offset)); let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); - let (raw_witin, lkm) = JalInstruction::::assign_instances( + let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -45,33 +54,68 @@ fn test_opcode_jal() { ) .unwrap(); + // verify rd_written + let expected_rd_written = UInt::from_const_unchecked( + Value::new_unchecked(MOCK_PC_START.0 + PC_STEP_SIZE as u32) + .as_u16_limbs() + .to_vec(), + ); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + expected_rd_written.value(), + ) + .unwrap(); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } #[test] fn test_opcode_jalr() { - let mut cs = ConstraintSystem::::new(|| "riscv"); + verify_test_opcode_jalr::(100, 3); + verify_test_opcode_jalr::(100, -3); + + #[cfg(feature = "u16limb_circuit")] + { + verify_test_opcode_jalr::(100, 3); + verify_test_opcode_jalr::(100, -3); + } +} + +fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb .namespace( || "jalr", |cb| { - let config = JalrInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = JalrInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) .unwrap() .unwrap(); - let imm = -15i32; - let rs1_read: Word = 100u32; + // trim lower bit to 0 let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1)); let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm); - let (raw_witin, lkm) = JalrInstruction::::assign_instances( + // verify rd_written + let expected_rd_written = UInt::from_const_unchecked( + Value::new_unchecked(MOCK_PC_START.0 + PC_STEP_SIZE as u32) + .as_u16_limbs() + .to_vec(), + ); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal( + || "assert_rd_written", + rd_written_expr, + expected_rd_written.value(), + ) + .unwrap(); + + let (raw_witin, lkm) = JalrInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 305ec9765..092c4b560 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -133,8 +133,10 @@ impl InsnRecord { (insn.imm as u32 >> 12) as i64, F::from_wrapped_u32(insn.imm as u32 >> 12), ), - // TODO JALR need to connecting register (2 limb) with pc (1 limb) - (JALR, _) => (insn.imm as i64, i64_to_base(insn.imm as i64)), + (JALR, _) => ( + insn.imm as i16 as i64, + F::from_canonical_u16(insn.imm as i16 as u16), + ), // for default imm to operate with register value _ => ( insn.imm as i16 as i64, @@ -158,8 +160,6 @@ impl InsnRecord { // in particular imm operated with program counter // encode as field element, which do not need extra sign extension of imm (_, B | J) => (false as i64, F::from_bool(false)), - // TODO JALR need to connecting register (2 limb) with pc (1 limb) - (JALR, _) => (false as i64, F::from_bool(false)), // Signed views _ => ((insn.imm < 0) as i64, F::from_bool(insn.imm < 0)), } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index f3db364ee..8a336c007 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -785,6 +785,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { } } + /// to replace `assert_ux` + pub fn assert_ux_v2( + &mut self, + name_fn: N, + expr: Expression, + max_bits: usize, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + match max_bits { + 16 => self.assert_u16(name_fn, expr), + 14 => self.assert_u14(name_fn, expr), + 8 => self.assert_byte(name_fn, expr), + 5 => self.assert_u5(name_fn, expr), + c => panic!("Unsupported bit range {c}"), + } + } + /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. /// In general it can be done by two U16 checks: one for `value` and one for /// `value << (16 - size)`. diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 2b93662f9..c69558bbd 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -201,6 +201,20 @@ impl LkMultiplicity { } } + #[inline(always)] + pub fn assert_ux_v2(&mut self, v: u64, max_bits: usize) { + self.increment( + match max_bits { + 16 => LookupTable::U16, + 14 => LookupTable::U14, + 8 => LookupTable::U8, + 5 => LookupTable::U5, + _ => panic!("Unsupported bit range"), + }, + v, + ); + } + /// Track a lookup into a logic table (AndTable, etc). pub fn logic_u8(&mut self, a: u64, b: u64) { self.increment(OP::ROM_TYPE, OP::pack(a, b)); From 89974361427525d7e9e578f4b24914aebaabfc84 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 20 Aug 2025 00:27:07 +0800 Subject: [PATCH 42/46] migrated shift opcode --- ceno_zkvm/Cargo.toml | 18 +- ceno_zkvm/src/instructions/riscv/shift.rs | 284 ++--------- .../instructions/riscv/shift/shift_circuit.rs | 218 +++++++++ .../riscv/shift/shift_circuit_v2.rs | 454 ++++++++++++++++++ ceno_zkvm/src/utils.rs | 10 + 5 files changed, 729 insertions(+), 255 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 9ada3e96e..392ec1a69 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -61,20 +61,20 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow"] +default = ["forbid_overflow", "u16limb_circuit"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] u16limb_circuit = ["ceno_emul/u16limb_circuit"] diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index d6d25f2b3..510b1b654 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -1,42 +1,15 @@ -use std::marker::PhantomData; +#[cfg(not(feature = "u16limb_circuit"))] +mod shift_circuit; +#[cfg(feature = "u16limb_circuit")] +mod shift_circuit_v2; use ceno_emul::InsnKind; -use ff_ext::ExtensionField; -use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; -use crate::{ - Value, - error::ZKVMError, - gadgets::{AssertLtConfig, SignedExtendConfig}, - instructions::{ - Instruction, - riscv::constants::{LIMB_BITS, UINT_LIMBS}, - }, - structs::ProgramParams, -}; -use ff_ext::FieldInto; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use witness::set_val; - -pub struct ShiftConfig { - r_insn: RInstructionConfig, - - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, - - rs2_high: UInt, - rs2_low5: WitIn, - pow2_rs2_low5: WitIn, - - outflow: WitIn, - assert_lt_config: AssertLtConfig, - - // SRA - signed_extend_config: Option>, -} - -pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); +use super::RIVInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::shift::shift_circuit::ShiftLogicalInstruction; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftLogicalInstruction; pub struct SllOp; impl RIVInstruction for SllOp { @@ -56,189 +29,12 @@ impl RIVInstruction for SraOp { } pub type SraInstruction = ShiftLogicalInstruction; -impl Instruction for ShiftLogicalInstruction { - type InstructionConfig = ShiftConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut crate::circuit_builder::CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa - // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, - // as the divisor/multiplier is a power of 2. - // - // example: right shift (bit flow from left to right) - // inflow || rs1_read == rd_written || outflow - // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). - // - // for left shifts, the inflow is always 0: - // rs1_read || inflow == outflow || rd_written - // - // additional constraint: outflow < (1 << shift), which lead to unique solution - - // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) - // rd_written * pow2_rs2_low5 + outflow == inflow * 2**32 + rs1_read - // 32 + 31. 31. 31 + 32. 32. (Bit widths) - - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); - // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below - let pow2_rs2_low5 = circuit_builder.create_witin(|| "pow2_rs2_low5"); - // rs2 = rs2_high | rs2_low5 - let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; - - let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLtConfig::construct_circuit( - circuit_builder, - || "outflow < pow2_rs2_low5", - outflow.expr(), - pow2_rs2_low5.expr(), - UINT_LIMBS * LIMB_BITS, - )?; - - let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); - - let signed_extend_config = match I::INST_KIND { - InsnKind::SLL => { - circuit_builder.require_equal( - || "shift check", - rs1_read.value() * pow2_rs2_low5.expr(), - outflow.expr() * two_pow_total_bits + rd_written.value(), - )?; - None - } - InsnKind::SRL | InsnKind::SRA => { - let (inflow, signed_extend_config) = match I::INST_KIND { - InsnKind::SRA => { - let signed_extend_config = rs1_read.is_negative(circuit_builder)?; - let msb_expr = signed_extend_config.expr(); - let ones = pow2_rs2_low5.expr() - Expression::ONE; - (msb_expr * ones, Some(signed_extend_config)) - } - InsnKind::SRL => (Expression::ZERO, None), - _ => unreachable!(), - }; - - circuit_builder.require_equal( - || "shift check", - rd_written.value() * pow2_rs2_low5.expr() + outflow.expr(), - inflow * two_pow_total_bits + rs1_read.value(), - )?; - signed_extend_config - } - _ => unreachable!(), - }; - - let r_insn = RInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - rs1_read.register_expr(), - rs2_read.register_expr(), - rd_written.register_expr(), - )?; - - circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.expr())?; - circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?; - circuit_builder.require_equal( - || "rs2 == rs2_high * 2^5 + rs2_low5", - rs2_read.value(), - (rs2_high.value() << 5) + rs2_low5.expr(), - )?; - - Ok(ShiftConfig { - r_insn, - rs1_read, - rs2_read, - rd_written, - rs2_high, - rs2_low5, - pow2_rs2_low5, - outflow, - assert_lt_config, - signed_extend_config, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - // rs2 & its derived values - let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); - let rs2_low5 = rs2_read.as_u64() & 0b11111; - lk_multiplicity.assert_ux::<5>(rs2_low5); - lk_multiplicity.lookup_pow2(rs2_low5); - - let pow2_rs2_low5 = 1u64 << rs2_low5; - - let rs2_high = Value::new( - ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, - lk_multiplicity, - ); - config.rs2_high.assign_value(instance, rs2_high); - config.rs2_read.assign_value(instance, rs2_read); - - set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); - set_val!(instance, config.rs2_low5, rs2_low5); - - // rs1 - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - - // rd - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - - // outflow - let outflow = match I::INST_KIND { - InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, - InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), - InsnKind::SRA => { - let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { - Err(ZKVMError::CircuitError)? - }; - signed_ext_config.assign_instance( - instance, - lk_multiplicity, - *rs1_read.as_u16_limbs().last().unwrap() as u64, - )?; - rs1_read.as_u64() & (pow2_rs2_low5 - 1) - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - set_val!(instance, config.outflow, outflow); - - config.rs1_read.assign_value(instance, rs1_read); - config.rd_written.assign_value(instance, rd_written); - - config.assert_lt_config.assign_instance( - instance, - lk_multiplicity, - outflow, - pow2_rs2_low5, - )?; - - config - .r_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} - #[cfg(test)] mod tests { use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; - use ff_ext::GoldilocksExt2; + use ff_ext::{ExtensionField, GoldilocksExt2}; + use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -250,56 +46,54 @@ mod tests { structs::ProgramParams, }; - use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; - #[test] fn test_opcode_sll() { - verify::("basic", 0b_0001, 3, 0b_1000); + verify::("basic", 0b_0001, 3, 0b_1000); // 33 << 33 === 33 << 1 - verify::("rs2 over 5-bits", 0b_0001, 33, 0b_0010); - verify::("bit loss", (1 << 31) | 1, 1, 0b_0010); - verify::("zero shift", 0b_0001, 0, 0b_0001); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + verify::("rs2 over 5-bits", 0b_0001, 33, 0b_0010); + verify::("bit loss", (1 << 31) | 1, 1, 0b_0010); + verify::("zero shift", 0b_0001, 0, 0b_0001); + verify::("all zeros", 0b_0000, 0, 0b_0000); + verify::("base is zero", 0b_0000, 1, 0b_0000); } #[test] fn test_opcode_srl() { - verify::("basic", 0b_1000, 3, 0b_0001); + verify::("basic", 0b_1000, 3, 0b_0001); // 33 >> 33 === 33 >> 1 - verify::("rs2 over 5-bits", 0b_1010, 33, 0b_0101); - verify::("bit loss", 0b_1001, 1, 0b_0100); - verify::("zero shift", 0b_1000, 0, 0b_1000); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + verify::("rs2 over 5-bits", 0b_1010, 33, 0b_0101); + verify::("bit loss", 0b_1001, 1, 0b_0100); + verify::("zero shift", 0b_1000, 0, 0b_1000); + verify::("all zeros", 0b_0000, 0, 0b_0000); + verify::("base is zero", 0b_0000, 1, 0b_0000); } #[test] fn test_opcode_sra() { // positive rs1 // rs2 = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); // rs2 = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); // negative rs1 // rs2 = 3 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); - verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); + verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); // rs2 = 31 - verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); - verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); + verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); } - fn verify( + fn verify( name: &'static str, rs1_read: u32, rs2_read: u32, expected_rd_written: u32, ) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let shift = rs2_read & 0b11111; @@ -326,12 +120,10 @@ mod tests { .namespace( || format!("{prefix}_({name})"), |cb| { - Ok( - ShiftLogicalInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ), - ) + Ok(ShiftLogicalInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) }, ) .unwrap() @@ -350,7 +142,7 @@ mod tests { ) .unwrap(); - let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs new file mode 100644 index 000000000..87374b20e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -0,0 +1,218 @@ +use crate::{ + Value, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, + r_insn::RInstructionConfig, + }, + }, + structs::ProgramParams, +}; +use ceno_emul::InsnKind; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +pub struct ShiftConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + pub rd_written: UInt, + + rs2_high: UInt, + rs2_low5: WitIn, + pow2_rs2_low5: WitIn, + + outflow: WitIn, + assert_lt_config: AssertLtConfig, + + // SRA + signed_extend_config: Option>, +} + +pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftLogicalInstruction { + type InstructionConfig = ShiftConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa + // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, + // as the divisor/multiplier is a power of 2. + // + // example: right shift (bit flow from left to right) + // inflow || rs1_read == rd_written || outflow + // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). + // + // for left shifts, the inflow is always 0: + // rs1_read || inflow == outflow || rd_written + // + // additional constraint: outflow < (1 << shift), which lead to unique solution + + // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) + // rd_written * pow2_rs2_low5 + outflow == inflow * 2**32 + rs1_read + // 32 + 31. 31. 31 + 32. 32. (Bit widths) + + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); + // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below + let pow2_rs2_low5 = circuit_builder.create_witin(|| "pow2_rs2_low5"); + // rs2 = rs2_high | rs2_low5 + let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow"); + let assert_lt_config = AssertLtConfig::construct_circuit( + circuit_builder, + || "outflow < pow2_rs2_low5", + outflow.expr(), + pow2_rs2_low5.expr(), + UINT_LIMBS * LIMB_BITS, + )?; + + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let signed_extend_config = match I::INST_KIND { + InsnKind::SLL => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * pow2_rs2_low5.expr(), + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRL | InsnKind::SRA => { + let (inflow, signed_extend_config) = match I::INST_KIND { + InsnKind::SRA => { + let signed_extend_config = rs1_read.is_negative(circuit_builder)?; + let msb_expr = signed_extend_config.expr(); + let ones = pow2_rs2_low5.expr() - Expression::ONE; + (msb_expr * ones, Some(signed_extend_config)) + } + InsnKind::SRL => (Expression::ZERO, None), + _ => unreachable!(), + }; + + circuit_builder.require_equal( + || "shift check", + rd_written.value() * pow2_rs2_low5.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + signed_extend_config + } + _ => unreachable!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.expr())?; + circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?; + circuit_builder.require_equal( + || "rs2 == rs2_high * 2^5 + rs2_low5", + rs2_read.value(), + (rs2_high.value() << 5) + rs2_low5.expr(), + )?; + + Ok(ShiftConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + rs2_high, + rs2_low5, + pow2_rs2_low5, + outflow, + assert_lt_config, + signed_extend_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + // rs2 & its derived values + let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); + let rs2_low5 = rs2_read.as_u64() & 0b11111; + lk_multiplicity.assert_ux::<5>(rs2_low5); + lk_multiplicity.lookup_pow2(rs2_low5); + + let pow2_rs2_low5 = 1u64 << rs2_low5; + + let rs2_high = Value::new( + ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, + lk_multiplicity, + ); + config.rs2_high.assign_value(instance, rs2_high); + config.rs2_read.assign_value(instance, rs2_read); + + set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); + set_val!(instance, config.rs2_low5, rs2_low5); + + // rs1 + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + + // rd + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + + // outflow + let outflow = match I::INST_KIND { + InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, + InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), + InsnKind::SRA => { + let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { + Err(ZKVMError::CircuitError)? + }; + signed_ext_config.assign_instance( + instance, + lk_multiplicity, + *rs1_read.as_u16_limbs().last().unwrap() as u64, + )?; + rs1_read.as_u64() & (pow2_rs2_low5 - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + + config.rs1_read.assign_value(instance, rs1_read); + config.rd_written.assign_value(instance, rd_written); + + config.assert_lt_config.assign_instance( + instance, + lk_multiplicity, + outflow, + pow2_rs2_low5, + )?; + + config + .r_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs new file mode 100644 index 000000000..76f7db2b6 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -0,0 +1,454 @@ +use crate::{ + Value, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, + r_insn::RInstructionConfig, + }, + }, + structs::ProgramParams, + utils::split_to_limb, +}; +use ceno_emul::InsnKind; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::gadgets::AssertLtConfig; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::{Field, FieldAlgebra}; +use std::{array, marker::PhantomData}; +use witness::{InstancePaddingStrategy::Default, set_val}; + +pub struct ShiftBaseConfig { + // bit_multiplier = 2^bit_shift + pub bit_multiplier_left: WitIn, + pub bit_multiplier_right: WitIn, + + // Sign of x for SRA + pub b_sign: WitIn, + + // Boolean columns that are 1 exactly at the index of the bit/limb shift amount + pub bit_shift_marker: [WitIn; LIMB_BITS], + pub limb_shift_marker: [WitIn; NUM_LIMBS], + + // Part of each x[i] that gets bit shifted to the next limb + pub bit_shift_carry: [WitIn; NUM_LIMBS], + pub phantom: PhantomData, +} + +impl + ShiftBaseConfig +{ + pub fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + kind: InsnKind, + a: [Expression; NUM_LIMBS], + b: [Expression; NUM_LIMBS], + c: [Expression; NUM_LIMBS], + ) -> Result { + let bit_shift_marker = + array::from_fn(|i| circuit_builder.create_witin(|| format!("bit_shift_marker_{}", i))); + let limb_shift_marker = + array::from_fn(|i| circuit_builder.create_witin(|| format!("limb_shift_marker_{}", i))); + let bit_multiplier_left = circuit_builder.create_witin(|| "bit_multiplier_left"); + let bit_multiplier_right = circuit_builder.create_witin(|| "bit_multiplier_right"); + let b_sign = circuit_builder.create_bit(|| "b_sign")?; + let bit_shift_carry = + array::from_fn(|i| circuit_builder.create_witin(|| format!("bit_shift_carry_{}", i))); + + // Constrain that bit_shift, bit_multiplier are correct, i.e. that bit_multiplier = + // 1 << bit_shift. Because the sum of all bit_shift_marker[i] is constrained to be + // 1, bit_shift is guaranteed to be in range. + let mut bit_marker_sum = Expression::ZERO; + let mut bit_shift = Expression::ZERO; + + for i in 0..LIMB_BITS { + circuit_builder.assert_bit( + || format!("bit_shift_marker_{i}_assert_bit"), + bit_shift_marker[i].expr(), + )?; + bit_marker_sum += bit_shift_marker[i].expr(); + bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker[i].expr(); + + match kind { + InsnKind::SLL => { + circuit_builder.condition_require_zero( + || "bit_multiplier_left_condition", + bit_shift_marker[i].expr(), + bit_multiplier_left.expr() + - E::BaseField::from_canonical_usize(1 << i).expr(), + )?; + } + InsnKind::SRL | InsnKind::SRA => { + circuit_builder.condition_require_zero( + || "bit_multiplier_right_condition", + bit_shift_marker[i].expr(), + bit_multiplier_right.expr() + - E::BaseField::from_canonical_usize(1 << i).expr(), + )?; + } + _ => unreachable!(), + } + } + circuit_builder.require_one(|| "bit_marker_sum_as_1", bit_marker_sum.expr())?; + + // Check that a[i] = b[i] <> c[i] both on the bit and limb shift level if c < + // NUM_LIMBS * LIMB_BITS. + let mut limb_marker_sum = Expression::ZERO; + let mut limb_shift = Expression::ZERO; + for i in 0..NUM_LIMBS { + circuit_builder.assert_bit( + || format!("limb_shift_marker_{i}_assert_bit"), + limb_shift_marker[i].expr(), + )?; + limb_marker_sum += limb_shift_marker[i].expr(); + limb_shift += + E::BaseField::from_canonical_usize(i).expr() * limb_shift_marker[i].expr(); + + for j in 0..NUM_LIMBS { + match kind { + InsnKind::SLL => { + if j < i { + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_{j}"), + limb_shift_marker[i].expr(), + a[j].expr(), + )?; + } else { + let expected_a_left = if j - i == 0 { + Expression::ZERO + } else { + bit_shift_carry[j - i - 1].expr() + } + b[j - i].expr() * bit_multiplier_left.expr() + - E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() + * bit_shift_carry[j - i].expr(); + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_expected_a_left_{j}",), + limb_shift_marker[i].expr(), + a[j].expr() - expected_a_left, + )?; + } + } + InsnKind::SRL | InsnKind::SRA => { + // SRL and SRA constraints. Combining with above would require an additional column. + if j + i > NUM_LIMBS - 1 { + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_{j}"), + limb_shift_marker[i].expr(), + b_sign.expr() + * E::BaseField::from_canonical_usize((1 << LIMB_BITS) - 1) + .expr(), + )?; + } else { + let expected_a_right = + if j + i == NUM_LIMBS - 1 { + b_sign.expr() * (bit_multiplier_right.expr() - Expression::ONE) + } else { + bit_shift_carry[j + i + 1].expr() + } * E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() + + (b[j + i].expr() - bit_shift_carry[j + i].expr()); + + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_expected_a_left_{j}",), + limb_shift_marker[i].expr(), + a[j].expr() * bit_multiplier_right.expr() - expected_a_right, + )?; + } + } + _ => unimplemented!(), + } + } + } + circuit_builder.require_one(|| "limb_marker_sum_as_1", limb_marker_sum.expr())?; + + // Check that bit_shift and limb_shift are correct. + let num_bits = E::BaseField::from_canonical_usize(NUM_LIMBS * LIMB_BITS); + circuit_builder.assert_ux_v2( + || "bit_shift_vs_limb_shift", + (c[0].expr() + - limb_shift * E::BaseField::from_canonical_usize(LIMB_BITS).expr() + - bit_shift.expr()) + * num_bits.inverse().expr(), + LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize, + )?; + if !matches!(kind, InsnKind::SRA) { + circuit_builder.require_zero(|| "b_sign_zero", b_sign.expr())?; + } else { + let mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(); + let b_sign_shifted = b_sign.expr() * mask.expr(); + circuit_builder.lookup_xor_byte( + b[NUM_LIMBS - 1].expr(), + mask.expr(), + b[NUM_LIMBS - 1].expr() + mask.expr() + - (E::BaseField::from_canonical_u32(2).expr()) * b_sign_shifted.expr(), + )?; + } + + for (i, carry) in bit_shift_carry.iter().enumerate() { + // TODO replace `LIMB_BITS` with `bit_shift` so we can support more strict range check + // `bit_shift` could be expression + // TODO refactor range check to support dynamic range + circuit_builder.assert_ux_v2( + || format!("bit_shift_carry_range_check_{i}"), + carry.expr(), + LIMB_BITS, + )?; + } + + Ok(Self { + bit_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + limb_shift_marker, + bit_shift_carry, + b_sign, + phantom: PhantomData, + }) + } + + pub fn assign_instances( + config: &Self, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::<_, LIMB_BITS>(b); + let c = split_to_limb::<_, LIMB_BITS>(c); + let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + match kind { + InsnKind::SLL => set_val!( + instance, + config.bit_multiplier_left, + E::BaseField::from_canonical_usize(1 << bit_shift) + ), + _ => set_val!( + instance, + config.bit_multiplier_right, + E::BaseField::from_canonical_usize(1 << bit_shift) + ), + }; + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for (val, witin) in bit_shift_carry.iter().zip_eq(&config.bit_shift_carry) { + set_val!(instance, witin, E::BaseField::from_canonical_u32(*val)); + } + for (i, witin) in config.bit_shift_marker.iter().enumerate() { + set_val!(instance, witin, E::BaseField::from_bool(i == bit_shift)); + } + for (i, witin) in config.limb_shift_marker.iter().enumerate() { + set_val!(instance, witin, E::BaseField::from_bool(i == limb_shift)); + } + + let mut b_sign = 0; + if matches!(kind, InsnKind::SRA) { + b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); + lk_multiplicity.lookup_xor_byte(b[NUM_LIMBS - 1] as u64, 1 << (LIMB_BITS - 1)); + } + set_val!( + instance, + config.b_sign, + E::BaseField::from_bool(b_sign != 0) + ); + } +} + +pub struct ShiftConfig { + shift_base_config: ShiftBaseConfig, + rs1_read: UInt, + rs2_read: UInt, + pub rd_written: UInt, + r_insn: RInstructionConfig, +} + +pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftLogicalInstruction { + type InstructionConfig = ShiftConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let (rd_written, rs1_read, rs2_read) = match I::INST_KIND { + InsnKind::SLL => { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + (rd_written, rs1_read, rs2_read) + } + _ => unimplemented!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + let shift_base_config = ShiftBaseConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.register_expr(), + rs1_read.register_expr(), + rs2_read.register_expr(), + )?; + + Ok(ShiftConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + shift_base_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + // rs2 & its derived values + let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); + // let rs2_low5 = rs2_read.as_u64() & 0b11111; + // lk_multiplicity.assert_ux::<5>(rs2_low5); + // lk_multiplicity.lookup_pow2(rs2_low5); + // + // let pow2_rs2_low5 = 1u64 << rs2_low5; + // + // let rs2_high = Value::new( + // ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, + // lk_multiplicity, + // ); + // config.rs2_high.assign_value(instance, rs2_high); + // config.rs2_read.assign_value(instance, rs2_read); + // + // set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); + // set_val!(instance, config.rs2_low5, rs2_low5); + // + // // rs1 + // let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + // + // // rd + // let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + // + // // outflow + // let outflow = match I::INST_KIND { + // InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, + // InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), + // InsnKind::SRA => { + // let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { + // Err(ZKVMError::CircuitError)? + // }; + // signed_ext_config.assign_instance( + // instance, + // lk_multiplicity, + // *rs1_read.as_u16_limbs().last().unwrap() as u64, + // )?; + // rs1_read.as_u64() & (pow2_rs2_low5 - 1) + // } + // _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + // }; + // + // set_val!(instance, config.outflow, outflow); + // + // config.rs1_read.assign_value(instance, rs1_read); + // config.rd_written.assign_value(instance, rd_written); + // + // config.assert_lt_config.assign_instance( + // instance, + // lk_multiplicity, + // outflow, + // pow2_rs2_low5, + // )?; + // + // config + // .r_insn + // .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} + +fn run_shift( + kind: InsnKind, + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], +) -> ([u32; NUM_LIMBS], usize, usize) { + match kind { + InsnKind::SLL => run_shift_left::(x, y), + InsnKind::SRL => run_shift_right::(x, y, true), + InsnKind::SRA => run_shift_right::(x, y, false), + _ => unreachable!(), + } +} + +fn run_shift_left( + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], +) -> ([u32; NUM_LIMBS], usize, usize) { + let mut result = [0u32; NUM_LIMBS]; + + let (limb_shift, bit_shift) = get_shift::(y); + + for i in limb_shift..NUM_LIMBS { + result[i] = if i > limb_shift { + ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } else { + (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS) + }; + } + (result, limb_shift, bit_shift) +} + +fn run_shift_right( + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], + logical: bool, +) -> ([u32; NUM_LIMBS], usize, usize) { + let fill = if logical { + 0 + } else { + ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) + }; + let mut result = [fill; NUM_LIMBS]; + + let (limb_shift, bit_shift) = get_shift::(y); + + for i in 0..(NUM_LIMBS - limb_shift) { + result[i] = if i + limb_shift + 1 < NUM_LIMBS { + ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } else { + ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } + } + (result, limb_shift, bit_shift) +} + +fn get_shift(y: &[u32]) -> (usize, usize) { + // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so so the shift is defined + // entirely in y[0]. + let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS); + (shift / LIMB_BITS, shift % LIMB_BITS) +} diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 1990bae39..a0b09feef 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -29,6 +29,16 @@ pub fn split_to_u8>(value: u32) -> Vec { .collect_vec() } +pub fn split_to_limb, const LIMB_BITS: usize>(value: u32) -> Vec { + (0..(u32::BITS as usize / LIMB_BITS)) + .scan(value, |acc, _| { + let limb = ((*acc & ((1 << LIMB_BITS) - 1)) as u8).into(); + *acc >>= LIMB_BITS; + Some(limb) + }) + .collect_vec() +} + /// Compile time evaluated minimum function /// returns min(a, b) pub(crate) const fn const_min(a: usize, b: usize) -> usize { From ff89e7aea29439dc7ff0b9154d6588b5c88beac0 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 20 Aug 2025 17:22:14 +0800 Subject: [PATCH 43/46] shift r-type test pass --- ceno_zkvm/src/instructions/riscv/shift.rs | 10 +- .../riscv/shift/shift_circuit_v2.rs | 112 ++++++------------ 2 files changed, 42 insertions(+), 80 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 510b1b654..ed7910671 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -36,14 +36,14 @@ mod tests { use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt}, + riscv::{RIVInstruction, constants::UInt8}, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, + utils::split_to_u8, }; #[test] @@ -134,11 +134,7 @@ mod tests { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, - &UInt::from_const_unchecked( - Value::new_unchecked(expected_rd_written) - .as_u16_limbs() - .to_vec(), - ), + &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 559e592ce..47a6efc6a 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,16 +1,14 @@ use crate::{ - Value, - error::ZKVMError, instructions::{ Instruction, riscv::{ RIVInstruction, - constants::{LIMB_BITS, UINT_LIMBS, UInt}, + constants::{UINT_BYTE_LIMBS, UInt8}, r_insn::RInstructionConfig, }, }, structs::ProgramParams, - utils::split_to_limb, + utils::{split_to_limb, split_to_u8}, }; use ceno_emul::InsnKind; use ff_ext::{ExtensionField, FieldInto}; @@ -63,19 +61,19 @@ impl let mut bit_marker_sum = Expression::ZERO; let mut bit_shift = Expression::ZERO; - for i in 0..LIMB_BITS { + for (i, bit_shift_marker_i) in bit_shift_marker.iter().enumerate().take(LIMB_BITS) { circuit_builder.assert_bit( || format!("bit_shift_marker_{i}_assert_bit"), - bit_shift_marker[i].expr(), + bit_shift_marker_i.expr(), )?; - bit_marker_sum += bit_shift_marker[i].expr(); - bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker[i].expr(); + bit_marker_sum += bit_shift_marker_i.expr(); + bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker_i.expr(); match kind { InsnKind::SLL => { circuit_builder.condition_require_zero( || "bit_multiplier_left_condition", - bit_shift_marker[i].expr(), + bit_shift_marker_i.expr(), bit_multiplier_left.expr() - E::BaseField::from_canonical_usize(1 << i).expr(), )?; @@ -83,7 +81,7 @@ impl InsnKind::SRL | InsnKind::SRA => { circuit_builder.condition_require_zero( || "bit_multiplier_right_condition", - bit_shift_marker[i].expr(), + bit_shift_marker_i.expr(), bit_multiplier_right.expr() - E::BaseField::from_canonical_usize(1 << i).expr(), )?; @@ -111,7 +109,7 @@ impl InsnKind::SLL => { if j < i { circuit_builder.condition_require_zero( - || format!("limb_shift_marker_a_{j}"), + || format!("limb_shift_marker_a_{i}_{j}"), limb_shift_marker[i].expr(), a[j].expr(), )?; @@ -124,7 +122,7 @@ impl - E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() * bit_shift_carry[j - i].expr(); circuit_builder.condition_require_zero( - || format!("limb_shift_marker_a_expected_a_left_{j}",), + || format!("limb_shift_marker_a_expected_a_left_{i}_{j}",), limb_shift_marker[i].expr(), a[j].expr() - expected_a_left, )?; @@ -134,11 +132,12 @@ impl // SRL and SRA constraints. Combining with above would require an additional column. if j + i > NUM_LIMBS - 1 { circuit_builder.condition_require_zero( - || format!("limb_shift_marker_a_{j}"), + || format!("limb_shift_marker_a_{i}_{j}"), limb_shift_marker[i].expr(), - b_sign.expr() - * E::BaseField::from_canonical_usize((1 << LIMB_BITS) - 1) - .expr(), + a[j].expr() + - b_sign.expr() + * E::BaseField::from_canonical_usize((1 << LIMB_BITS) - 1) + .expr(), )?; } else { let expected_a_right = @@ -150,7 +149,7 @@ impl + (b[j + i].expr() - bit_shift_carry[j + i].expr()); circuit_builder.condition_require_zero( - || format!("limb_shift_marker_a_expected_a_left_{j}",), + || format!("limb_shift_marker_a_expected_a_right_{i}_{j}",), limb_shift_marker[i].expr(), a[j].expr() * bit_multiplier_right.expr() - expected_a_right, )?; @@ -262,8 +261,7 @@ impl let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); lk_multiplicity.assert_ux_in_u16( LIMB_BITS - num_bits_log as usize, - (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) - >> num_bits_log) as u64, + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64, ); let mut b_sign = 0; @@ -276,10 +274,10 @@ impl } pub struct ShiftConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt, - rs2_read: UInt, - pub rd_written: UInt, + shift_base_config: ShiftBaseConfig, + rs1_read: UInt8, + rs2_read: UInt8, + pub rd_written: UInt8, r_insn: RInstructionConfig, } @@ -297,10 +295,10 @@ impl Instruction for ShiftLogicalInstru _params: &ProgramParams, ) -> Result { let (rd_written, rs1_read, rs2_read) = match I::INST_KIND { - InsnKind::SLL => { - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + InsnKind::SLL | InsnKind::SRL | InsnKind::SRA => { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt8::new_unchecked(|| "rs2_read", circuit_builder)?; + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; (rd_written, rs1_read, rs2_read) } _ => unimplemented!(), @@ -317,9 +315,9 @@ impl Instruction for ShiftLogicalInstru let shift_base_config = ShiftBaseConfig::construct_circuit( circuit_builder, I::INST_KIND, - rd_written.register_expr(), - rs1_read.register_expr(), - rs2_read.register_expr(), + rd_written.expr().try_into().unwrap(), + rs1_read.expr().try_into().unwrap(), + rs2_read.expr().try_into().unwrap(), )?; Ok(ShiftConfig { @@ -337,51 +335,19 @@ impl Instruction for ShiftLogicalInstru lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), crate::error::ZKVMError> { - // rs2 & its derived values - let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); - // let rs2_low5 = rs2_read.as_u64() & 0b11111; - // lk_multiplicity.assert_ux::<5>(rs2_low5); - // lk_multiplicity.lookup_pow2(rs2_low5); - // - // let pow2_rs2_low5 = 1u64 << rs2_low5; - // - // let rs2_high = Value::new( - // ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, - // lk_multiplicity, - // ); - // config.rs2_high.assign_value(instance, rs2_high); - // config.rs2_read.assign_value(instance, rs2_read); - // - // set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); - // set_val!(instance, config.rs2_low5, rs2_low5); - // + // rs2 + let rs2_read = split_to_u8::(step.rs2().unwrap().value); // rs1 - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_read = split_to_u8::(step.rs1().unwrap().value); // rd - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - // // outflow - // let outflow = match I::INST_KIND { - // InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, - // InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), - // InsnKind::SRA => { - // let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { - // Err(ZKVMError::CircuitError)? - // }; - // signed_ext_config.assign_instance( - // instance, - // lk_multiplicity, - // *rs1_read.as_u16_limbs().last().unwrap() as u64, - // )?; - // rs1_read.as_u64() & (pow2_rs2_low5 - 1) - // } - // _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - // }; - // - // set_val!(instance, config.outflow, outflow); - // - config.rs1_read.assign_value(instance, rs1_read); - config.rs2_read.assign_value(instance, rs2_read); - config.rd_written.assign_value(instance, rd_written); + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + + config.rs1_read.assign_limbs(instance, &rs1_read); + config.rs2_read.assign_limbs(instance, &rs2_read); + config.rd_written.assign_limbs(instance, &rd_written); config.shift_base_config.assign_instances( instance, From 4932c7e381d49189ab1d71ca63e6cdf48c4e2be7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 20 Aug 2025 19:48:19 +0800 Subject: [PATCH 44/46] add babybear test --- ceno_zkvm/Cargo.toml | 18 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 1 - ceno_zkvm/src/instructions/riscv/shift.rs | 84 +++-- .../riscv/shift/shift_circuit_v2.rs | 132 ++++++-- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 287 +++++------------- .../riscv/shift_imm/shift_imm_circuit.rs | 175 +++++++++++ ceno_zkvm/src/tables/program.rs | 3 - 7 files changed, 418 insertions(+), 282 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 392ec1a69..9ada3e96e 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -61,20 +61,20 @@ ceno-examples = { path = "../examples-builder" } glob = "0.3" [features] -default = ["forbid_overflow", "u16limb_circuit"] +default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ - "p3/nightly-features", - "ff_ext/nightly-features", - "mpcs/nightly-features", - "multilinear_extensions/nightly-features", - "poseidon/nightly-features", - "sumcheck/nightly-features", - "transcript/nightly-features", - "witness/nightly-features", + "p3/nightly-features", + "ff_ext/nightly-features", + "mpcs/nightly-features", + "multilinear_extensions/nightly-features", + "poseidon/nightly-features", + "sumcheck/nightly-features", + "transcript/nightly-features", + "witness/nightly-features", ] sanity-check = ["mpcs/sanity-check"] u16limb_circuit = ["ceno_emul/u16limb_circuit"] diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 0412dfb09..3d013a2e8 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -506,7 +506,6 @@ impl MemAddr { // Range check the high limb. for i in 1..UINT_LIMBS { let high_u16 = (addr >> (i * 16)) & 0xffff; - println!("assignment max bit {}", (self.max_bits - i * 16).min(16)); lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16)); } diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index ed7910671..8e8a43165 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] -mod shift_circuit; +pub mod shift_circuit; #[cfg(feature = "u16limb_circuit")] -mod shift_circuit_v2; +pub mod shift_circuit_v2; use ceno_emul::InsnKind; @@ -45,46 +45,68 @@ mod tests { structs::ProgramParams, utils::split_to_u8, }; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; #[test] fn test_opcode_sll() { - verify::("basic", 0b_0001, 3, 0b_1000); - // 33 << 33 === 33 << 1 - verify::("rs2 over 5-bits", 0b_0001, 33, 0b_0010); - verify::("bit loss", (1 << 31) | 1, 1, 0b_0010); - verify::("zero shift", 0b_0001, 0, 0b_0001); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + let cases = [ + ("basic 1", 32, 3, 32 << 3), + ("basic 2", 0b_0001, 3, 0b_1000), + // 33 << 33 === 33 << 1 + ("rs2 over 5-bits", 0b_0001, 33, 0b_0010), + ("bit loss", (1 << 31) | 1, 1, 0b_0010), + ("zero shift", 0b_0001, 0, 0b_0001), + ("all zeros", 0b_0000, 0, 0b_0000), + ("base is zero", 0b_0000, 1, 0b_0000), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } #[test] fn test_opcode_srl() { - verify::("basic", 0b_1000, 3, 0b_0001); - // 33 >> 33 === 33 >> 1 - verify::("rs2 over 5-bits", 0b_1010, 33, 0b_0101); - verify::("bit loss", 0b_1001, 1, 0b_0100); - verify::("zero shift", 0b_1000, 0, 0b_1000); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + let cases = [ + ("basic", 0b_1000, 3, 0b_0001), + // 33 >> 33 === 33 >> 1 + ("rs2 over 5-bits", 0b_1010, 33, 0b_0101), + ("bit loss", 0b_1001, 1, 0b_0100), + ("zero shift", 0b_1000, 0, 0b_1000), + ("all zeros", 0b_0000, 0, 0b_0000), + ("base is zero", 0b_0000, 1, 0b_0000), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } #[test] fn test_opcode_sra() { - // positive rs1 - // rs2 = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // rs2 = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - - // negative rs1 - // rs2 = 3 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); - verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); - // rs2 = 31 - verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); - verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + let cases = [ + // positive rs1 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + // negative rs1 + ("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32), + ("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32), + ("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32), + ("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } fn verify( diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 47a6efc6a..06c4a28ad 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,9 +1,11 @@ +/// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ Instruction, riscv::{ RIVInstruction, constants::{UINT_BYTE_LIMBS, UInt8}, + i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, }, @@ -70,7 +72,7 @@ impl bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker_i.expr(); match kind { - InsnKind::SLL => { + InsnKind::SLL | InsnKind::SLLI => { circuit_builder.condition_require_zero( || "bit_multiplier_left_condition", bit_shift_marker_i.expr(), @@ -78,7 +80,7 @@ impl - E::BaseField::from_canonical_usize(1 << i).expr(), )?; } - InsnKind::SRL | InsnKind::SRA => { + InsnKind::SRL | InsnKind::SRLI | InsnKind::SRA | InsnKind::SRAI => { circuit_builder.condition_require_zero( || "bit_multiplier_right_condition", bit_shift_marker_i.expr(), @@ -106,7 +108,7 @@ impl for j in 0..NUM_LIMBS { match kind { - InsnKind::SLL => { + InsnKind::SLL | InsnKind::SLLI => { if j < i { circuit_builder.condition_require_zero( || format!("limb_shift_marker_a_{i}_{j}"), @@ -128,7 +130,7 @@ impl )?; } } - InsnKind::SRL | InsnKind::SRA => { + InsnKind::SRL | InsnKind::SRLI | InsnKind::SRA | InsnKind::SRAI => { // SRL and SRA constraints. Combining with above would require an additional column. if j + i > NUM_LIMBS - 1 { circuit_builder.condition_require_zero( @@ -180,7 +182,7 @@ impl - bit_shift.expr()) * num_bits.inverse().expr(), )?; - if !matches!(kind, InsnKind::SRA) { + if !matches!(kind, InsnKind::SRA | InsnKind::SRAI) { circuit_builder.require_zero(|| "b_sign_zero", b_sign.expr())?; } else { let mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(); @@ -232,7 +234,7 @@ impl ); match kind { - InsnKind::SLL => set_val!( + InsnKind::SLL | InsnKind::SLLI => set_val!( instance, self.bit_multiplier_left, E::BaseField::from_canonical_usize(1 << bit_shift) @@ -245,7 +247,7 @@ impl }; let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { - InsnKind::SLL => b[i] >> (LIMB_BITS - bit_shift), + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), _ => b[i] % (1 << bit_shift), }); for (val, witin) in bit_shift_carry.iter().zip_eq(&self.bit_shift_carry) { @@ -265,7 +267,7 @@ impl ); let mut b_sign = 0; - if matches!(kind, InsnKind::SRA) { + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); lk_multiplicity.lookup_xor_byte(b[NUM_LIMBS - 1] as u64, 1 << (LIMB_BITS - 1)); } @@ -273,7 +275,7 @@ impl } } -pub struct ShiftConfig { +pub struct ShiftRTypeConfig { shift_base_config: ShiftBaseConfig, rs1_read: UInt8, rs2_read: UInt8, @@ -284,7 +286,7 @@ pub struct ShiftConfig { pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { - type InstructionConfig = ShiftConfig; + type InstructionConfig = ShiftRTypeConfig; fn name() -> String { format!("{:?}", I::INST_KIND) @@ -320,7 +322,7 @@ impl Instruction for ShiftLogicalInstru rs2_read.expr().try_into().unwrap(), )?; - Ok(ShiftConfig { + Ok(ShiftRTypeConfig { r_insn, rs1_read, rs2_read, @@ -330,7 +332,7 @@ impl Instruction for ShiftLogicalInstru } fn assign_instance( - config: &ShiftConfig, + config: &ShiftRTypeConfig, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -356,13 +358,6 @@ impl Instruction for ShiftLogicalInstru step.rs1().unwrap().value, step.rs2().unwrap().value, ); - // config.assert_lt_config.assign_instance( - // instance, - // lk_multiplicity, - // outflow, - // pow2_rs2_low5, - // )?; - // config .r_insn .assign_instance(instance, lk_multiplicity, step)?; @@ -371,15 +366,108 @@ impl Instruction for ShiftLogicalInstru } } +pub struct ShiftImmConfig { + shift_base_config: ShiftBaseConfig, + rs1_read: UInt8, + pub rd_written: UInt8, + i_insn: IInstructionConfig, + imm: WitIn, +} + +pub struct ShiftImmInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftImmInstruction { + type InstructionConfig = ShiftImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let (rd_written, rs1_read, imm) = match I::INST_KIND { + InsnKind::SLLI | InsnKind::SRLI | InsnKind::SRAI => { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm"); + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + (rd_written, rs1_read, imm) + } + _ => unimplemented!(), + }; + let uint8_imm = UInt8::from_exprs_unchecked(vec![imm.expr(), 0.into(), 0.into(), 0.into()]); + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + imm.expr(), + 0.into(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + let shift_base_config = ShiftBaseConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.expr().try_into().unwrap(), + rs1_read.expr().try_into().unwrap(), + uint8_imm.expr().try_into().unwrap(), + )?; + + Ok(ShiftImmConfig { + i_insn, + imm, + rs1_read, + rd_written, + shift_base_config, + }) + } + + fn assign_instance( + config: &ShiftImmConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let imm = step.insn().imm as i16 as u16; + set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + // rs1 + let rs1_read = split_to_u8::(step.rs1().unwrap().value); + // rd + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + + config.rs1_read.assign_limbs(instance, &rs1_read); + config.rd_written.assign_limbs(instance, &rd_written); + + config.shift_base_config.assign_instances( + instance, + lk_multiplicity, + I::INST_KIND, + step.rs1().unwrap().value, + imm as u32, + ); + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} + fn run_shift( kind: InsnKind, x: &[u32; NUM_LIMBS], y: &[u32; NUM_LIMBS], ) -> ([u32; NUM_LIMBS], usize, usize) { match kind { - InsnKind::SLL => run_shift_left::(x, y), - InsnKind::SRL => run_shift_right::(x, y, true), - InsnKind::SRA => run_shift_right::(x, y, false), + InsnKind::SLL | InsnKind::SLLI => run_shift_left::(x, y), + InsnKind::SRL | InsnKind::SRLI => run_shift_right::(x, y, true), + InsnKind::SRA | InsnKind::SRAI => run_shift_right::(x, y, false), _ => unreachable!(), } } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 9eb759ffa..bfac94a33 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -1,44 +1,17 @@ -use super::RIVInstruction; -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - gadgets::{AssertLtConfig, SignedExtendConfig}, - instructions::{ - Instruction, - riscv::{ - constants::{LIMB_BITS, UINT_LIMBS, UInt}, - i_insn::IInstructionConfig, - }, - }, - structs::ProgramParams, - tables::InsnRecord, - witness::LkMultiplicity, -}; -use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::{ExtensionField, FieldInto}; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use std::marker::PhantomData; -use witness::set_val; - -pub struct ShiftImmConfig { - i_insn: IInstructionConfig, +#[cfg(not(feature = "u16limb_circuit"))] +mod shift_imm_circuit; - imm: WitIn, - rs1_read: UInt, - rd_written: UInt, - outflow: WitIn, - assert_lt_config: AssertLtConfig, - - // SRAI - is_lt_config: Option>, -} +use super::RIVInstruction; +use ceno_emul::InsnKind; -pub struct ShiftImmInstruction(PhantomData<(E, I)>); +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::shift_imm::shift_imm_circuit::ShiftImmInstruction; pub struct SlliOp; impl RIVInstruction for SlliOp { - const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; + const INST_KIND: InsnKind = InsnKind::SLLI; } pub type SlliInstruction = ShiftImmInstruction; @@ -54,210 +27,96 @@ impl RIVInstruction for SrliOp { } pub type SrliInstruction = ShiftImmInstruction; -impl Instruction for ShiftImmInstruction { - type InstructionConfig = ShiftImmConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa - // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, - // as the divisor/multiplier is a power of 2. - // - // example: right shift (bit flow from left to right) - // inflow || rs1_read == rd_written || outflow - // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). - // - // for left shifts, the inflow is always 0: - // rs1_read || inflow == outflow || rd_written - // - // additional constraint: outflow < (1 << shift), which lead to unique solution - - // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) - // rd * imm + outflow == inflow * 2**32 + rs1 - // 32 + 31. 31. 31 + 32. 32. (Bit widths) - - // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. - let imm = circuit_builder.create_witin(|| "imm"); - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLtConfig::construct_circuit( - circuit_builder, - || "outflow < imm", - outflow.expr(), - imm.expr(), - UINT_LIMBS * LIMB_BITS, - )?; - - let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); - - let is_lt_config = match I::INST_KIND { - InsnKind::SLLI => { - circuit_builder.require_equal( - || "shift check", - rs1_read.value() * imm.expr(), // inflow is zero for this case - outflow.expr() * two_pow_total_bits + rd_written.value(), - )?; - None - } - InsnKind::SRAI | InsnKind::SRLI => { - let (inflow, is_lt_config) = match I::INST_KIND { - InsnKind::SRAI => { - let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; - let ones = imm.expr() - 1; - (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) - } - InsnKind::SRLI => (Expression::ZERO, None), - _ => unreachable!(), - }; - circuit_builder.require_equal( - || "shift check", - rd_written.value() * imm.expr() + outflow.expr(), - inflow * two_pow_total_bits + rs1_read.value(), - )?; - is_lt_config - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - let i_insn = IInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - imm.expr(), - #[cfg(feature = "u16limb_circuit")] - 0.into(), - rs1_read.register_expr(), - rd_written.register_expr(), - false, - )?; - - Ok(ShiftImmConfig { - i_insn, - imm, - rs1_read, - rd_written, - outflow, - assert_lt_config, - is_lt_config, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - // imm_internal is a precomputed 2**shift. - let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - - set_val!(instance, config.imm, imm); - config.rs1_read.assign_value(instance, rs1_read.clone()); - config.rd_written.assign_value(instance, rd_written); - - let outflow = match I::INST_KIND { - InsnKind::SLLI => (rs1_read.as_u64() * imm) >> UInt::::TOTAL_BITS, - InsnKind::SRAI | InsnKind::SRLI => { - if I::INST_KIND == InsnKind::SRAI { - config.is_lt_config.as_ref().unwrap().assign_instance( - instance, - lk_multiplicity, - *rs1_read.as_u16_limbs().last().unwrap() as u64, - )?; - } - - rs1_read.as_u64() & (imm - 1) - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - set_val!(instance, config.outflow, outflow); - config - .assert_lt_config - .assign_instance(instance, lk_multiplicity, outflow, imm)?; - - config - .i_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} - #[cfg(test)] mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; - use ff_ext::GoldilocksExt2; + use ff_ext::{ExtensionField, GoldilocksExt2}; use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt}, + riscv::{RIVInstruction, constants::UInt8}, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, + utils::split_to_u8, }; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; #[test] fn test_opcode_slli() { - // imm = 3 - verify::("32 << 3", 32, 3, 32 << 3); - verify::("33 << 3", 33, 3, 33 << 3); - // imm = 31 - verify::("32 << 31", 32, 31, 32 << 31); - verify::("33 << 31", 33, 31, 33 << 31); + let cases = [ + // imm = 3 + ("32 << 3", 32, 3, 32 << 3), + ("33 << 3", 33, 3, 33 << 3), + + // imm = 31 + ("32 << 31", 32, 31, 32 << 31), + ("33 << 31", 33, 31, 33 << 31), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } #[test] fn test_opcode_srai() { - // positive rs1 - // imm = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // imm = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - - // negative rs1 - // imm = 3 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); - verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); - // imm = 31 - verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); - verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + let cases = [ + // positive rs1 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + + // negative rs1 + ("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32), + ("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32), + ("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32), + ("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } #[test] fn test_opcode_srli() { - // imm = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // imm = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - // rs1 top bit is 1 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32) as u32 >> 3); + let cases = [ + // imm = 3 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + + // imm = 31 + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + + // rs1 top bit is 1 + ("-32 >> 3", (-32_i32) as u32, 3, ((-32_i32) as u32) >> 3), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } - fn verify( + fn verify( name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32, ) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let (prefix, insn_code, rd_written) = match I::INST_KIND { @@ -283,7 +142,7 @@ mod test { .namespace( || format!("{prefix}_({name})"), |cb| { - let config = ShiftImmInstruction::::construct_circuit( + let config = ShiftImmInstruction::::construct_circuit( cb, &ProgramParams::default(), ); @@ -298,15 +157,11 @@ mod test { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, - &UInt::from_const_unchecked( - Value::new_unchecked(expected_rd_written) - .as_u16_limbs() - .to_vec(), - ), + &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); - let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs new file mode 100644 index 000000000..0bba35411 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -0,0 +1,175 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +pub struct ShiftImmInstruction(PhantomData<(E, I)>); + +pub struct ShiftImmConfig { + i_insn: IInstructionConfig, + + imm: WitIn, + rs1_read: UInt, + pub rd_written: UInt, + outflow: WitIn, + assert_lt_config: AssertLtConfig, + + // SRAI + is_lt_config: Option>, +} + +impl Instruction for ShiftImmInstruction { + type InstructionConfig = ShiftImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa + // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, + // as the divisor/multiplier is a power of 2. + // + // example: right shift (bit flow from left to right) + // inflow || rs1_read == rd_written || outflow + // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). + // + // for left shifts, the inflow is always 0: + // rs1_read || inflow == outflow || rd_written + // + // additional constraint: outflow < (1 << shift), which lead to unique solution + + // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) + // rd * imm + outflow == inflow * 2**32 + rs1 + // 32 + 31. 31. 31 + 32. 32. (Bit widths) + + // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. + let imm = circuit_builder.create_witin(|| "imm"); + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow"); + let assert_lt_config = AssertLtConfig::construct_circuit( + circuit_builder, + || "outflow < imm", + outflow.expr(), + imm.expr(), + UINT_LIMBS * LIMB_BITS, + )?; + + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let is_lt_config = match I::INST_KIND { + InsnKind::SLLI => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * imm.expr(), // inflow is zero for this case + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRAI | InsnKind::SRLI => { + let (inflow, is_lt_config) = match I::INST_KIND { + InsnKind::SRAI => { + let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; + let ones = imm.expr() - 1; + (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) + } + InsnKind::SRLI => (Expression::ZERO, None), + _ => unreachable!(), + }; + circuit_builder.require_equal( + || "shift check", + rd_written.value() * imm.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + is_lt_config + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(ShiftImmConfig { + i_insn, + imm, + rs1_read, + rd_written, + outflow, + assert_lt_config, + is_lt_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // imm_internal is a precomputed 2**shift. + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + + set_val!(instance, config.imm, imm); + config.rs1_read.assign_value(instance, rs1_read.clone()); + config.rd_written.assign_value(instance, rd_written); + + let outflow = match I::INST_KIND { + InsnKind::SLLI => (rs1_read.as_u64() * imm) >> UInt::::TOTAL_BITS, + InsnKind::SRAI | InsnKind::SRLI => { + if I::INST_KIND == InsnKind::SRAI { + config.is_lt_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + *rs1_read.as_u16_limbs().last().unwrap() as u64, + )?; + } + + rs1_read.as_u64() & (imm - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + config + .assert_lt_config + .assign_instance(instance, lk_multiplicity, outflow, imm)?; + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 092c4b560..6ed08d51f 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -107,9 +107,6 @@ impl InsnRecord { #[cfg(feature = "u16limb_circuit")] pub fn imm_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { - // Prepare the immediate for ShiftImmInstruction. - // The shift is implemented as a multiplication/division by 1 << immediate. - (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), // TODO convert to 2 limbs to support smaller field (LB | LH | LW | LBU | LHU | SB | SH | SW, _) => { (insn.imm as i64, i64_to_base(insn.imm as i64)) From b1099cf29066fb23233b64576915b3cb5d3de368 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 21 Aug 2025 11:09:55 +0800 Subject: [PATCH 45/46] clippy fix --- ceno_zkvm/src/instructions/riscv/shift.rs | 15 ++++++++++-- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 24 ++++++++++++------- ceno_zkvm/src/utils.rs | 1 + 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 8e8a43165..b2bfb493a 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -35,15 +35,19 @@ mod tests { use ff_ext::{ExtensionField, GoldilocksExt2}; use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; + #[cfg(feature = "u16limb_circuit")] + use crate::instructions::riscv::constants::UInt8; + #[cfg(feature = "u16limb_circuit")] + use crate::utils::split_to_u8; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt8}, + riscv::{RIVInstruction, constants::UInt}, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, - utils::split_to_u8, }; #[cfg(feature = "u16limb_circuit")] use ff_ext::BabyBearExt4; @@ -156,6 +160,13 @@ mod tests { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, + #[cfg(not(feature = "u16limb_circuit"))] + &UInt::from_const_unchecked( + Value::new_unchecked(expected_rd_written) + .as_u16_limbs() + .to_vec(), + ), + #[cfg(feature = "u16limb_circuit")] &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index bfac94a33..d7ea2054c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -33,15 +33,18 @@ mod test { use ff_ext::{ExtensionField, GoldilocksExt2}; use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::instructions::riscv::constants::UInt; + #[cfg(feature = "u16limb_circuit")] + use crate::instructions::riscv::constants::UInt8; + #[cfg(feature = "u16limb_circuit")] + use crate::utils::split_to_u8; use crate::{ + Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{ - Instruction, - riscv::{RIVInstruction, constants::UInt8}, - }, + instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, - utils::split_to_u8, }; #[cfg(feature = "u16limb_circuit")] use ff_ext::BabyBearExt4; @@ -52,7 +55,6 @@ mod test { // imm = 3 ("32 << 3", 32, 3, 32 << 3), ("33 << 3", 33, 3, 33 << 3), - // imm = 31 ("32 << 31", 32, 31, 32 << 31), ("33 << 31", 33, 31, 33 << 31), @@ -73,7 +75,6 @@ mod test { ("33 >> 3", 33, 3, 33 >> 3), ("32 >> 31", 32, 31, 32 >> 31), ("33 >> 31", 33, 31, 33 >> 31), - // negative rs1 ("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32), ("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32), @@ -94,11 +95,9 @@ mod test { // imm = 3 ("32 >> 3", 32, 3, 32 >> 3), ("33 >> 3", 33, 3, 33 >> 3), - // imm = 31 ("32 >> 31", 32, 31, 32 >> 31), ("33 >> 31", 33, 31, 33 >> 31), - // rs1 top bit is 1 ("-32 >> 3", (-32_i32) as u32, 3, ((-32_i32) as u32) >> 3), ]; @@ -157,6 +156,13 @@ mod test { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, + #[cfg(not(feature = "u16limb_circuit"))] + &UInt::from_const_unchecked( + Value::new_unchecked(expected_rd_written) + .as_u16_limbs() + .to_vec(), + ), + #[cfg(feature = "u16limb_circuit")] &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index a0b09feef..0041776be 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -29,6 +29,7 @@ pub fn split_to_u8>(value: u32) -> Vec { .collect_vec() } +#[allow(dead_code)] pub fn split_to_limb, const LIMB_BITS: usize>(value: u32) -> Vec { (0..(u32::BITS as usize / LIMB_BITS)) .scan(value, |acc, _| { From 8f630f5607d8dfa0d4e60b0dd21ded5a2dca818d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 21 Aug 2025 15:18:52 +0800 Subject: [PATCH 46/46] clippy fix --- ceno_zkvm/src/instructions/riscv/shift.rs | 10 +++++----- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index b2bfb493a..0c53f1a4c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -35,17 +35,17 @@ mod tests { use ff_ext::{ExtensionField, GoldilocksExt2}; use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::Value; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::instructions::riscv::constants::UInt; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::constants::UInt8; #[cfg(feature = "u16limb_circuit")] use crate::utils::split_to_u8; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{ - Instruction, - riscv::{RIVInstruction, constants::UInt}, - }, + instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index d7ea2054c..4cf7ac155 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -34,13 +34,14 @@ mod test { use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; #[cfg(not(feature = "u16limb_circuit"))] + use crate::Value; + #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::constants::UInt; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::constants::UInt8; #[cfg(feature = "u16limb_circuit")] use crate::utils::split_to_u8; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver},