diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 0412dfb09..3d013a2e8 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -506,7 +506,6 @@ impl MemAddr { // Range check the high limb. for i in 1..UINT_LIMBS { let high_u16 = (addr >> (i * 16)) & 0xffff; - println!("assignment max bit {}", (self.max_bits - i * 16).min(16)); lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16)); } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 96e892fa7..4cd9875db 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -6,7 +6,6 @@ use super::{ use crate::instructions::riscv::auipc::AuipcInstruction; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::lui::LuiInstruction; - use crate::{ error::ZKVMError, instructions::{ diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index d6d25f2b3..0c53f1a4c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -1,42 +1,15 @@ -use std::marker::PhantomData; +#[cfg(not(feature = "u16limb_circuit"))] +pub mod shift_circuit; +#[cfg(feature = "u16limb_circuit")] +pub mod shift_circuit_v2; use ceno_emul::InsnKind; -use ff_ext::ExtensionField; -use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; -use crate::{ - Value, - error::ZKVMError, - gadgets::{AssertLtConfig, SignedExtendConfig}, - instructions::{ - Instruction, - riscv::constants::{LIMB_BITS, UINT_LIMBS}, - }, - structs::ProgramParams, -}; -use ff_ext::FieldInto; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use witness::set_val; - -pub struct ShiftConfig { - r_insn: RInstructionConfig, - - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, - - rs2_high: UInt, - rs2_low5: WitIn, - pow2_rs2_low5: 
WitIn, - - outflow: WitIn, - assert_lt_config: AssertLtConfig, - - // SRA - signed_extend_config: Option>, -} - -pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); +use super::RIVInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::shift::shift_circuit::ShiftLogicalInstruction; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftLogicalInstruction; pub struct SllOp; impl RIVInstruction for SllOp { @@ -56,250 +29,97 @@ impl RIVInstruction for SraOp { } pub type SraInstruction = ShiftLogicalInstruction; -impl Instruction for ShiftLogicalInstruction { - type InstructionConfig = ShiftConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut crate::circuit_builder::CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa - // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, - // as the divisor/multiplier is a power of 2. - // - // example: right shift (bit flow from left to right) - // inflow || rs1_read == rd_written || outflow - // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). - // - // for left shifts, the inflow is always 0: - // rs1_read || inflow == outflow || rd_written - // - // additional constraint: outflow < (1 << shift), which lead to unique solution - - // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) - // rd_written * pow2_rs2_low5 + outflow == inflow * 2**32 + rs1_read - // 32 + 31. 31. 31 + 32. 32. 
(Bit widths) - - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); - // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below - let pow2_rs2_low5 = circuit_builder.create_witin(|| "pow2_rs2_low5"); - // rs2 = rs2_high | rs2_low5 - let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; - - let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLtConfig::construct_circuit( - circuit_builder, - || "outflow < pow2_rs2_low5", - outflow.expr(), - pow2_rs2_low5.expr(), - UINT_LIMBS * LIMB_BITS, - )?; - - let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); - - let signed_extend_config = match I::INST_KIND { - InsnKind::SLL => { - circuit_builder.require_equal( - || "shift check", - rs1_read.value() * pow2_rs2_low5.expr(), - outflow.expr() * two_pow_total_bits + rd_written.value(), - )?; - None - } - InsnKind::SRL | InsnKind::SRA => { - let (inflow, signed_extend_config) = match I::INST_KIND { - InsnKind::SRA => { - let signed_extend_config = rs1_read.is_negative(circuit_builder)?; - let msb_expr = signed_extend_config.expr(); - let ones = pow2_rs2_low5.expr() - Expression::ONE; - (msb_expr * ones, Some(signed_extend_config)) - } - InsnKind::SRL => (Expression::ZERO, None), - _ => unreachable!(), - }; - - circuit_builder.require_equal( - || "shift check", - rd_written.value() * pow2_rs2_low5.expr() + outflow.expr(), - inflow * two_pow_total_bits + rs1_read.value(), - )?; - signed_extend_config - } - _ => unreachable!(), - }; - - let r_insn = RInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - rs1_read.register_expr(), - rs2_read.register_expr(), - rd_written.register_expr(), - )?; - - 
circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.expr())?; - circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?; - circuit_builder.require_equal( - || "rs2 == rs2_high * 2^5 + rs2_low5", - rs2_read.value(), - (rs2_high.value() << 5) + rs2_low5.expr(), - )?; - - Ok(ShiftConfig { - r_insn, - rs1_read, - rs2_read, - rd_written, - rs2_high, - rs2_low5, - pow2_rs2_low5, - outflow, - assert_lt_config, - signed_extend_config, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - // rs2 & its derived values - let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); - let rs2_low5 = rs2_read.as_u64() & 0b11111; - lk_multiplicity.assert_ux::<5>(rs2_low5); - lk_multiplicity.lookup_pow2(rs2_low5); - - let pow2_rs2_low5 = 1u64 << rs2_low5; - - let rs2_high = Value::new( - ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, - lk_multiplicity, - ); - config.rs2_high.assign_value(instance, rs2_high); - config.rs2_read.assign_value(instance, rs2_read); - - set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); - set_val!(instance, config.rs2_low5, rs2_low5); - - // rs1 - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - - // rd - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - - // outflow - let outflow = match I::INST_KIND { - InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, - InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), - InsnKind::SRA => { - let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { - Err(ZKVMError::CircuitError)? 
- }; - signed_ext_config.assign_instance( - instance, - lk_multiplicity, - *rs1_read.as_u16_limbs().last().unwrap() as u64, - )?; - rs1_read.as_u64() & (pow2_rs2_low5 - 1) - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - set_val!(instance, config.outflow, outflow); - - config.rs1_read.assign_value(instance, rs1_read); - config.rd_written.assign_value(instance, rd_written); - - config.assert_lt_config.assign_instance( - instance, - lk_multiplicity, - outflow, - pow2_rs2_low5, - )?; - - config - .r_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} - #[cfg(test)] mod tests { use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; - use ff_ext::GoldilocksExt2; + use ff_ext::{ExtensionField, GoldilocksExt2}; + use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::Value; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::instructions::riscv::constants::UInt; + #[cfg(feature = "u16limb_circuit")] + use crate::instructions::riscv::constants::UInt8; + #[cfg(feature = "u16limb_circuit")] + use crate::utils::split_to_u8; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{ - Instruction, - riscv::{RIVInstruction, constants::UInt}, - }, + instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; - - use super::{ShiftLogicalInstruction, SllOp, SraOp, SrlOp}; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; #[test] fn test_opcode_sll() { - verify::("basic", 0b_0001, 3, 0b_1000); - // 33 << 33 === 33 << 1 - verify::("rs2 over 5-bits", 0b_0001, 33, 0b_0010); - verify::("bit loss", (1 << 31) | 1, 1, 0b_0010); - verify::("zero shift", 0b_0001, 0, 0b_0001); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + let cases = [ + ("basic 1", 32, 3, 32 << 3), + ("basic 2", 0b_0001, 3, 
0b_1000), + // 33 << 33 === 33 << 1 + ("rs2 over 5-bits", 0b_0001, 33, 0b_0010), + ("bit loss", (1 << 31) | 1, 1, 0b_0010), + ("zero shift", 0b_0001, 0, 0b_0001), + ("all zeros", 0b_0000, 0, 0b_0000), + ("base is zero", 0b_0000, 1, 0b_0000), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } #[test] fn test_opcode_srl() { - verify::("basic", 0b_1000, 3, 0b_0001); - // 33 >> 33 === 33 >> 1 - verify::("rs2 over 5-bits", 0b_1010, 33, 0b_0101); - verify::("bit loss", 0b_1001, 1, 0b_0100); - verify::("zero shift", 0b_1000, 0, 0b_1000); - verify::("all zeros", 0b_0000, 0, 0b_0000); - verify::("base is zero", 0b_0000, 1, 0b_0000); + let cases = [ + ("basic", 0b_1000, 3, 0b_0001), + // 33 >> 33 === 33 >> 1 + ("rs2 over 5-bits", 0b_1010, 33, 0b_0101), + ("bit loss", 0b_1001, 1, 0b_0100), + ("zero shift", 0b_1000, 0, 0b_1000), + ("all zeros", 0b_0000, 0, 0b_0000), + ("base is zero", 0b_0000, 1, 0b_0000), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } #[test] fn test_opcode_sra() { - // positive rs1 - // rs2 = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // rs2 = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - - // negative rs1 - // rs2 = 3 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); - verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); - // rs2 = 31 - verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); - verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + let cases = [ + // positive rs1 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + // negative rs1 + ("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as 
u32), + ("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32), + ("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32), + ("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32), + ]; + + for (name, lhs, rhs, expected) in cases { + verify::(name, lhs, rhs, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, rhs, expected); + } } - fn verify( + fn verify( name: &'static str, rs1_read: u32, rs2_read: u32, expected_rd_written: u32, ) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let shift = rs2_read & 0b11111; @@ -326,12 +146,10 @@ mod tests { .namespace( || format!("{prefix}_({name})"), |cb| { - Ok( - ShiftLogicalInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ), - ) + Ok(ShiftLogicalInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) }, ) .unwrap() @@ -342,15 +160,18 @@ mod tests { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, + #[cfg(not(feature = "u16limb_circuit"))] &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) .as_u16_limbs() .to_vec(), ), + #[cfg(feature = "u16limb_circuit")] + &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); - let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs new file mode 100644 index 000000000..87374b20e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -0,0 +1,218 @@ +use crate::{ + Value, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, + 
r_insn::RInstructionConfig, + }, + }, + structs::ProgramParams, +}; +use ceno_emul::InsnKind; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +pub struct ShiftConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + pub rd_written: UInt, + + rs2_high: UInt, + rs2_low5: WitIn, + pow2_rs2_low5: WitIn, + + outflow: WitIn, + assert_lt_config: AssertLtConfig, + + // SRA + signed_extend_config: Option>, +} + +pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftLogicalInstruction { + type InstructionConfig = ShiftConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa + // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, + // as the divisor/multiplier is a power of 2. + // + // example: right shift (bit flow from left to right) + // inflow || rs1_read == rd_written || outflow + // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). + // + // for left shifts, the inflow is always 0: + // rs1_read || inflow == outflow || rd_written + // + // additional constraint: outflow < (1 << shift), which lead to unique solution + + // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) + // rd_written * pow2_rs2_low5 + outflow == inflow * 2**32 + rs1_read + // 32 + 31. 31. 31 + 32. 32. 
(Bit widths) + + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5"); + // pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below + let pow2_rs2_low5 = circuit_builder.create_witin(|| "pow2_rs2_low5"); + // rs2 = rs2_high | rs2_low5 + let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow"); + let assert_lt_config = AssertLtConfig::construct_circuit( + circuit_builder, + || "outflow < pow2_rs2_low5", + outflow.expr(), + pow2_rs2_low5.expr(), + UINT_LIMBS * LIMB_BITS, + )?; + + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let signed_extend_config = match I::INST_KIND { + InsnKind::SLL => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * pow2_rs2_low5.expr(), + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRL | InsnKind::SRA => { + let (inflow, signed_extend_config) = match I::INST_KIND { + InsnKind::SRA => { + let signed_extend_config = rs1_read.is_negative(circuit_builder)?; + let msb_expr = signed_extend_config.expr(); + let ones = pow2_rs2_low5.expr() - Expression::ONE; + (msb_expr * ones, Some(signed_extend_config)) + } + InsnKind::SRL => (Expression::ZERO, None), + _ => unreachable!(), + }; + + circuit_builder.require_equal( + || "shift check", + rd_written.value() * pow2_rs2_low5.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + signed_extend_config + } + _ => unreachable!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + 
circuit_builder.lookup_pow2(rs2_low5.expr(), pow2_rs2_low5.expr())?; + circuit_builder.assert_ux::<_, _, 5>(|| "rs2_low5 in u5", rs2_low5.expr())?; + circuit_builder.require_equal( + || "rs2 == rs2_high * 2^5 + rs2_low5", + rs2_read.value(), + (rs2_high.value() << 5) + rs2_low5.expr(), + )?; + + Ok(ShiftConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + rs2_high, + rs2_low5, + pow2_rs2_low5, + outflow, + assert_lt_config, + signed_extend_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + // rs2 & its derived values + let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); + let rs2_low5 = rs2_read.as_u64() & 0b11111; + lk_multiplicity.assert_ux::<5>(rs2_low5); + lk_multiplicity.lookup_pow2(rs2_low5); + + let pow2_rs2_low5 = 1u64 << rs2_low5; + + let rs2_high = Value::new( + ((rs2_read.as_u64() - rs2_low5) >> 5) as u32, + lk_multiplicity, + ); + config.rs2_high.assign_value(instance, rs2_high); + config.rs2_read.assign_value(instance, rs2_read); + + set_val!(instance, config.pow2_rs2_low5, pow2_rs2_low5); + set_val!(instance, config.rs2_low5, rs2_low5); + + // rs1 + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + + // rd + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + + // outflow + let outflow = match I::INST_KIND { + InsnKind::SLL => (rs1_read.as_u64() * pow2_rs2_low5) >> UInt::::TOTAL_BITS, + InsnKind::SRL => rs1_read.as_u64() & (pow2_rs2_low5 - 1), + InsnKind::SRA => { + let Some(signed_ext_config) = config.signed_extend_config.as_ref() else { + Err(ZKVMError::CircuitError)? 
+ }; + signed_ext_config.assign_instance( + instance, + lk_multiplicity, + *rs1_read.as_u16_limbs().last().unwrap() as u64, + )?; + rs1_read.as_u64() & (pow2_rs2_low5 - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + + config.rs1_read.assign_value(instance, rs1_read); + config.rd_written.assign_value(instance, rd_written); + + config.assert_lt_config.assign_instance( + instance, + lk_multiplicity, + outflow, + pow2_rs2_low5, + )?; + + config + .r_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs new file mode 100644 index 000000000..06c4a28ad --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -0,0 +1,525 @@ +/// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs +use crate::{ + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{UINT_BYTE_LIMBS, UInt8}, + i_insn::IInstructionConfig, + r_insn::RInstructionConfig, + }, + }, + structs::ProgramParams, + utils::{split_to_limb, split_to_u8}, +}; +use ceno_emul::InsnKind; +use ff_ext::{ExtensionField, FieldInto}; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::{Field, FieldAlgebra}; +use std::{array, marker::PhantomData}; +use witness::set_val; + +pub struct ShiftBaseConfig { + // bit_multiplier = 2^bit_shift + pub bit_multiplier_left: WitIn, + pub bit_multiplier_right: WitIn, + + // Sign of x for SRA + pub b_sign: WitIn, + + // Boolean columns that are 1 exactly at the index of the bit/limb shift amount + pub bit_shift_marker: [WitIn; LIMB_BITS], + pub limb_shift_marker: [WitIn; NUM_LIMBS], + + // Part of each x[i] that gets bit shifted to the next limb + pub bit_shift_carry: [WitIn; NUM_LIMBS], + 
pub phantom: PhantomData, +} + +impl + ShiftBaseConfig +{ + pub fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + kind: InsnKind, + a: [Expression; NUM_LIMBS], + b: [Expression; NUM_LIMBS], + c: [Expression; NUM_LIMBS], + ) -> Result { + let bit_shift_marker = + array::from_fn(|i| circuit_builder.create_witin(|| format!("bit_shift_marker_{}", i))); + let limb_shift_marker = + array::from_fn(|i| circuit_builder.create_witin(|| format!("limb_shift_marker_{}", i))); + let bit_multiplier_left = circuit_builder.create_witin(|| "bit_multiplier_left"); + let bit_multiplier_right = circuit_builder.create_witin(|| "bit_multiplier_right"); + let b_sign = circuit_builder.create_bit(|| "b_sign")?; + let bit_shift_carry = + array::from_fn(|i| circuit_builder.create_witin(|| format!("bit_shift_carry_{}", i))); + + // Constrain that bit_shift, bit_multiplier are correct, i.e. that bit_multiplier = + // 1 << bit_shift. Because the sum of all bit_shift_marker[i] is constrained to be + // 1, bit_shift is guaranteed to be in range. 
+ let mut bit_marker_sum = Expression::ZERO; + let mut bit_shift = Expression::ZERO; + + for (i, bit_shift_marker_i) in bit_shift_marker.iter().enumerate().take(LIMB_BITS) { + circuit_builder.assert_bit( + || format!("bit_shift_marker_{i}_assert_bit"), + bit_shift_marker_i.expr(), + )?; + bit_marker_sum += bit_shift_marker_i.expr(); + bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker_i.expr(); + + match kind { + InsnKind::SLL | InsnKind::SLLI => { + circuit_builder.condition_require_zero( + || "bit_multiplier_left_condition", + bit_shift_marker_i.expr(), + bit_multiplier_left.expr() + - E::BaseField::from_canonical_usize(1 << i).expr(), + )?; + } + InsnKind::SRL | InsnKind::SRLI | InsnKind::SRA | InsnKind::SRAI => { + circuit_builder.condition_require_zero( + || "bit_multiplier_right_condition", + bit_shift_marker_i.expr(), + bit_multiplier_right.expr() + - E::BaseField::from_canonical_usize(1 << i).expr(), + )?; + } + _ => unreachable!(), + } + } + circuit_builder.require_one(|| "bit_marker_sum_one_hot", bit_marker_sum.expr())?; + + // Check that a[i] = b[i] <> c[i] both on the bit and limb shift level if c < + // NUM_LIMBS * LIMB_BITS. 
+ let mut limb_marker_sum = Expression::ZERO; + let mut limb_shift = Expression::ZERO; + for i in 0..NUM_LIMBS { + circuit_builder.assert_bit( + || format!("limb_shift_marker_{i}_assert_bit"), + limb_shift_marker[i].expr(), + )?; + limb_marker_sum += limb_shift_marker[i].expr(); + limb_shift += + E::BaseField::from_canonical_usize(i).expr() * limb_shift_marker[i].expr(); + + for j in 0..NUM_LIMBS { + match kind { + InsnKind::SLL | InsnKind::SLLI => { + if j < i { + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_{i}_{j}"), + limb_shift_marker[i].expr(), + a[j].expr(), + )?; + } else { + let expected_a_left = if j - i == 0 { + Expression::ZERO + } else { + bit_shift_carry[j - i - 1].expr() + } + b[j - i].expr() * bit_multiplier_left.expr() + - E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() + * bit_shift_carry[j - i].expr(); + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_expected_a_left_{i}_{j}",), + limb_shift_marker[i].expr(), + a[j].expr() - expected_a_left, + )?; + } + } + InsnKind::SRL | InsnKind::SRLI | InsnKind::SRA | InsnKind::SRAI => { + // SRL and SRA constraints. Combining with above would require an additional column. 
+ if j + i > NUM_LIMBS - 1 { + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_{i}_{j}"), + limb_shift_marker[i].expr(), + a[j].expr() + - b_sign.expr() + * E::BaseField::from_canonical_usize((1 << LIMB_BITS) - 1) + .expr(), + )?; + } else { + let expected_a_right = + if j + i == NUM_LIMBS - 1 { + b_sign.expr() * (bit_multiplier_right.expr() - Expression::ONE) + } else { + bit_shift_carry[j + i + 1].expr() + } * E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() + + (b[j + i].expr() - bit_shift_carry[j + i].expr()); + + circuit_builder.condition_require_zero( + || format!("limb_shift_marker_a_expected_a_right_{i}_{j}",), + limb_shift_marker[i].expr(), + a[j].expr() * bit_multiplier_right.expr() - expected_a_right, + )?; + } + } + _ => unimplemented!(), + } + } + } + circuit_builder.require_one(|| "limb_marker_sum_one_hot", limb_marker_sum.expr())?; + + // Check that bit_shift and limb_shift are correct. + let num_bits = E::BaseField::from_canonical_usize(NUM_LIMBS * LIMB_BITS); + // TODO switch to assert_ux_v2 once support dynamic table range check + // circuit_builder.assert_ux_v2( + // || "bit_shift_vs_limb_shift", + // (c[0].expr() + // - limb_shift * E::BaseField::from_canonical_usize(LIMB_BITS).expr() + // - bit_shift.expr()) + // * num_bits.inverse().expr(), + // LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize, + // )?; + circuit_builder.assert_ux_in_u16( + || "bit_shift_vs_limb_shift", + LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize, + (c[0].expr() + - limb_shift * E::BaseField::from_canonical_usize(LIMB_BITS).expr() + - bit_shift.expr()) + * num_bits.inverse().expr(), + )?; + if !matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + circuit_builder.require_zero(|| "b_sign_zero", b_sign.expr())?; + } else { + let mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(); + let b_sign_shifted = b_sign.expr() * mask.expr(); + circuit_builder.lookup_xor_byte( + b[NUM_LIMBS - 1].expr(), + 
mask.expr(), + b[NUM_LIMBS - 1].expr() + mask.expr() + - (E::BaseField::from_canonical_u32(2).expr()) * b_sign_shifted.expr(), + )?; + } + + for (i, carry) in bit_shift_carry.iter().enumerate() { + // TODO replace `LIMB_BITS` with `bit_shift` so we can support more strict range check + // `bit_shift` could be expression + // TODO refactor range check to support dynamic range + circuit_builder.assert_ux_v2( + || format!("bit_shift_carry_range_check_{i}"), + carry.expr(), + LIMB_BITS, + )?; + } + + Ok(Self { + bit_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + limb_shift_marker, + bit_shift_carry, + b_sign, + phantom: PhantomData, + }) + } + + pub fn assign_instances( + &self, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::<_, LIMB_BITS>(b); + let c = split_to_limb::<_, LIMB_BITS>(c); + let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + match kind { + InsnKind::SLL | InsnKind::SLLI => set_val!( + instance, + self.bit_multiplier_left, + E::BaseField::from_canonical_usize(1 << bit_shift) + ), + _ => set_val!( + instance, + self.bit_multiplier_right, + E::BaseField::from_canonical_usize(1 << bit_shift) + ), + }; + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for (val, witin) in bit_shift_carry.iter().zip_eq(&self.bit_shift_carry) { + set_val!(instance, witin, E::BaseField::from_canonical_u32(*val)); + lk_multiplicity.assert_ux_v2(*val as u64, LIMB_BITS); + } + for (i, witin) in self.bit_shift_marker.iter().enumerate() { + set_val!(instance, witin, E::BaseField::from_bool(i == bit_shift)); + } + for (i, witin) in self.limb_shift_marker.iter().enumerate() { + set_val!(instance, witin, E::BaseField::from_bool(i == limb_shift)); + } + 
let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + lk_multiplicity.assert_ux_in_u16( + LIMB_BITS - num_bits_log as usize, + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64, + ); + + let mut b_sign = 0; + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); + lk_multiplicity.lookup_xor_byte(b[NUM_LIMBS - 1] as u64, 1 << (LIMB_BITS - 1)); + } + set_val!(instance, self.b_sign, E::BaseField::from_bool(b_sign != 0)); + } +} + +pub struct ShiftRTypeConfig { + shift_base_config: ShiftBaseConfig, + rs1_read: UInt8, + rs2_read: UInt8, + pub rd_written: UInt8, + r_insn: RInstructionConfig, +} + +pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftLogicalInstruction { + type InstructionConfig = ShiftRTypeConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let (rd_written, rs1_read, rs2_read) = match I::INST_KIND { + InsnKind::SLL | InsnKind::SRL | InsnKind::SRA => { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt8::new_unchecked(|| "rs2_read", circuit_builder)?; + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + (rd_written, rs1_read, rs2_read) + } + _ => unimplemented!(), + }; + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + let shift_base_config = ShiftBaseConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.expr().try_into().unwrap(), + rs1_read.expr().try_into().unwrap(), + rs2_read.expr().try_into().unwrap(), + )?; + + Ok(ShiftRTypeConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + shift_base_config, + }) + } + + fn assign_instance( + config: &ShiftRTypeConfig, + instance: &mut 
[::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + // rs2 + let rs2_read = split_to_u8::(step.rs2().unwrap().value); + // rs1 + let rs1_read = split_to_u8::(step.rs1().unwrap().value); + // rd + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + + config.rs1_read.assign_limbs(instance, &rs1_read); + config.rs2_read.assign_limbs(instance, &rs2_read); + config.rd_written.assign_limbs(instance, &rd_written); + + config.shift_base_config.assign_instances( + instance, + lk_multiplicity, + I::INST_KIND, + step.rs1().unwrap().value, + step.rs2().unwrap().value, + ); + config + .r_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} + +pub struct ShiftImmConfig { + shift_base_config: ShiftBaseConfig, + rs1_read: UInt8, + pub rd_written: UInt8, + i_insn: IInstructionConfig, + imm: WitIn, +} + +pub struct ShiftImmInstruction(PhantomData<(E, I)>); + +impl Instruction for ShiftImmInstruction { + type InstructionConfig = ShiftImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let (rd_written, rs1_read, imm) = match I::INST_KIND { + InsnKind::SLLI | InsnKind::SRLI | InsnKind::SRAI => { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm"); + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; + (rd_written, rs1_read, imm) + } + _ => unimplemented!(), + }; + let uint8_imm = UInt8::from_exprs_unchecked(vec![imm.expr(), 0.into(), 0.into(), 0.into()]); + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + imm.expr(), + 0.into(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + 
)?; + + let shift_base_config = ShiftBaseConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rd_written.expr().try_into().unwrap(), + rs1_read.expr().try_into().unwrap(), + uint8_imm.expr().try_into().unwrap(), + )?; + + Ok(ShiftImmConfig { + i_insn, + imm, + rs1_read, + rd_written, + shift_base_config, + }) + } + + fn assign_instance( + config: &ShiftImmConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let imm = step.insn().imm as i16 as u16; + set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + // rs1 + let rs1_read = split_to_u8::(step.rs1().unwrap().value); + // rd + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in &rd_written { + lk_multiplicity.assert_ux::<8>(*val as u64); + } + + config.rs1_read.assign_limbs(instance, &rs1_read); + config.rd_written.assign_limbs(instance, &rd_written); + + config.shift_base_config.assign_instances( + instance, + lk_multiplicity, + I::INST_KIND, + step.rs1().unwrap().value, + imm as u32, + ); + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} + +fn run_shift( + kind: InsnKind, + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], +) -> ([u32; NUM_LIMBS], usize, usize) { + match kind { + InsnKind::SLL | InsnKind::SLLI => run_shift_left::(x, y), + InsnKind::SRL | InsnKind::SRLI => run_shift_right::(x, y, true), + InsnKind::SRA | InsnKind::SRAI => run_shift_right::(x, y, false), + _ => unreachable!(), + } +} + +fn run_shift_left( + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], +) -> ([u32; NUM_LIMBS], usize, usize) { + let mut result = [0u32; NUM_LIMBS]; + + let (limb_shift, bit_shift) = get_shift::(y); + + for i in limb_shift..NUM_LIMBS { + result[i] = if i > limb_shift { + ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } else { + (x[i - limb_shift] 
<< bit_shift) % (1 << LIMB_BITS) + }; + } + (result, limb_shift, bit_shift) +} + +fn run_shift_right( + x: &[u32; NUM_LIMBS], + y: &[u32; NUM_LIMBS], + logical: bool, +) -> ([u32; NUM_LIMBS], usize, usize) { + let fill = if logical { + 0 + } else { + ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) + }; + let mut result = [fill; NUM_LIMBS]; + + let (limb_shift, bit_shift) = get_shift::(y); + + for i in 0..(NUM_LIMBS - limb_shift) { + result[i] = if i + limb_shift + 1 < NUM_LIMBS { + ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } else { + ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) + % (1 << LIMB_BITS) + } + } + (result, limb_shift, bit_shift) +} + +fn get_shift(y: &[u32]) -> (usize, usize) { + // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so the shift is defined + // entirely in y[0]. + let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS); + (shift / LIMB_BITS, shift % LIMB_BITS) +} diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 9eb759ffa..4cf7ac155 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -1,44 +1,17 @@ -use super::RIVInstruction; -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - gadgets::{AssertLtConfig, SignedExtendConfig}, - instructions::{ - Instruction, - riscv::{ - constants::{LIMB_BITS, UINT_LIMBS, UInt}, - i_insn::IInstructionConfig, - }, - }, - structs::ProgramParams, - tables::InsnRecord, - witness::LkMultiplicity, -}; -use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::{ExtensionField, FieldInto}; -use multilinear_extensions::{Expression, ToExpr, WitIn}; -use std::marker::PhantomData; -use witness::set_val; - -pub struct ShiftImmConfig { - i_insn: IInstructionConfig, - - imm: WitIn, - rs1_read: UInt, - rd_written: UInt, - outflow: WitIn, - assert_lt_config:
AssertLtConfig, +#[cfg(not(feature = "u16limb_circuit"))] +mod shift_imm_circuit; - // SRAI - is_lt_config: Option>, -} +use super::RIVInstruction; +use ceno_emul::InsnKind; -pub struct ShiftImmInstruction(PhantomData<(E, I)>); +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmInstruction; +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::shift_imm::shift_imm_circuit::ShiftImmInstruction; pub struct SlliOp; impl RIVInstruction for SlliOp { - const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; + const INST_KIND: InsnKind = InsnKind::SLLI; } pub type SlliInstruction = ShiftImmInstruction; @@ -54,210 +27,96 @@ impl RIVInstruction for SrliOp { } pub type SrliInstruction = ShiftImmInstruction; -impl Instruction for ShiftImmInstruction { - type InstructionConfig = ShiftImmConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa - // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, - // as the divisor/multiplier is a power of 2. - // - // example: right shift (bit flow from left to right) - // inflow || rs1_read == rd_written || outflow - // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). - // - // for left shifts, the inflow is always 0: - // rs1_read || inflow == outflow || rd_written - // - // additional constraint: outflow < (1 << shift), which lead to unique solution - - // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) - // rd * imm + outflow == inflow * 2**32 + rs1 - // 32 + 31. 31. 31 + 32. 32. 
(Bit widths) - - // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. - let imm = circuit_builder.create_witin(|| "imm"); - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - let outflow = circuit_builder.create_witin(|| "outflow"); - let assert_lt_config = AssertLtConfig::construct_circuit( - circuit_builder, - || "outflow < imm", - outflow.expr(), - imm.expr(), - UINT_LIMBS * LIMB_BITS, - )?; - - let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); - - let is_lt_config = match I::INST_KIND { - InsnKind::SLLI => { - circuit_builder.require_equal( - || "shift check", - rs1_read.value() * imm.expr(), // inflow is zero for this case - outflow.expr() * two_pow_total_bits + rd_written.value(), - )?; - None - } - InsnKind::SRAI | InsnKind::SRLI => { - let (inflow, is_lt_config) = match I::INST_KIND { - InsnKind::SRAI => { - let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; - let ones = imm.expr() - 1; - (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) - } - InsnKind::SRLI => (Expression::ZERO, None), - _ => unreachable!(), - }; - circuit_builder.require_equal( - || "shift check", - rd_written.value() * imm.expr() + outflow.expr(), - inflow * two_pow_total_bits + rs1_read.value(), - )?; - is_lt_config - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - let i_insn = IInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - imm.expr(), - #[cfg(feature = "u16limb_circuit")] - 0.into(), - rs1_read.register_expr(), - rd_written.register_expr(), - false, - )?; - - Ok(ShiftImmConfig { - i_insn, - imm, - rs1_read, - rd_written, - outflow, - assert_lt_config, - is_lt_config, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - // imm_internal is 
a precomputed 2**shift. - let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - - set_val!(instance, config.imm, imm); - config.rs1_read.assign_value(instance, rs1_read.clone()); - config.rd_written.assign_value(instance, rd_written); - - let outflow = match I::INST_KIND { - InsnKind::SLLI => (rs1_read.as_u64() * imm) >> UInt::::TOTAL_BITS, - InsnKind::SRAI | InsnKind::SRLI => { - if I::INST_KIND == InsnKind::SRAI { - config.is_lt_config.as_ref().unwrap().assign_instance( - instance, - lk_multiplicity, - *rs1_read.as_u16_limbs().last().unwrap() as u64, - )?; - } - - rs1_read.as_u64() & (imm - 1) - } - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - set_val!(instance, config.outflow, outflow); - config - .assert_lt_config - .assign_instance(instance, lk_multiplicity, outflow, imm)?; - - config - .i_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} - #[cfg(test)] mod test { use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; - use ff_ext::GoldilocksExt2; + use ff_ext::{ExtensionField, GoldilocksExt2}; use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::Value; + #[cfg(not(feature = "u16limb_circuit"))] + use crate::instructions::riscv::constants::UInt; + #[cfg(feature = "u16limb_circuit")] + use crate::instructions::riscv::constants::UInt8; + #[cfg(feature = "u16limb_circuit")] + use crate::utils::split_to_u8; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{ - Instruction, - riscv::{RIVInstruction, constants::UInt}, - }, + instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; + #[cfg(feature = "u16limb_circuit")] + use ff_ext::BabyBearExt4; #[test] fn 
test_opcode_slli() { - // imm = 3 - verify::("32 << 3", 32, 3, 32 << 3); - verify::("33 << 3", 33, 3, 33 << 3); - // imm = 31 - verify::("32 << 31", 32, 31, 32 << 31); - verify::("33 << 31", 33, 31, 33 << 31); + let cases = [ + // imm = 3 + ("32 << 3", 32, 3, 32 << 3), + ("33 << 3", 33, 3, 33 << 3), + // imm = 31 + ("32 << 31", 32, 31, 32 << 31), + ("33 << 31", 33, 31, 33 << 31), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } #[test] fn test_opcode_srai() { - // positive rs1 - // imm = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // imm = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - - // negative rs1 - // imm = 3 - verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); - verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); - // imm = 31 - verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); - verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + let cases = [ + // positive rs1 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + // negative rs1 + ("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32), + ("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32), + ("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32), + ("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } #[test] fn test_opcode_srli() { - // imm = 3 - verify::("32 >> 3", 32, 3, 32 >> 3); - verify::("33 >> 3", 33, 3, 33 >> 3); - // imm = 31 - verify::("32 >> 31", 32, 31, 32 >> 31); - verify::("33 >> 31", 33, 31, 33 >> 31); - // rs1 top bit is 1 - verify::("-32 >> 3", (-32_i32) as 
u32, 3, (-32_i32) as u32 >> 3); + let cases = [ + // imm = 3 + ("32 >> 3", 32, 3, 32 >> 3), + ("33 >> 3", 33, 3, 33 >> 3), + // imm = 31 + ("32 >> 31", 32, 31, 32 >> 31), + ("33 >> 31", 33, 31, 33 >> 31), + // rs1 top bit is 1 + ("-32 >> 3", (-32_i32) as u32, 3, ((-32_i32) as u32) >> 3), + ]; + + for (name, lhs, imm, expected) in cases { + verify::(name, lhs, imm, expected); + #[cfg(feature = "u16limb_circuit")] + verify::(name, lhs, imm, expected); + } } - fn verify( + fn verify( name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32, ) { - let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let (prefix, insn_code, rd_written) = match I::INST_KIND { @@ -283,7 +142,7 @@ mod test { .namespace( || format!("{prefix}_({name})"), |cb| { - let config = ShiftImmInstruction::::construct_circuit( + let config = ShiftImmInstruction::::construct_circuit( cb, &ProgramParams::default(), ); @@ -298,15 +157,18 @@ mod test { .require_equal( || format!("{prefix}_({name})_assert_rd_written"), &mut cb, + #[cfg(not(feature = "u16limb_circuit"))] &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) .as_u16_limbs() .to_vec(), ), + #[cfg(feature = "u16limb_circuit")] + &UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)), ) .unwrap(); - let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs new file mode 100644 index 000000000..0bba35411 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -0,0 +1,175 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + gadgets::SignedExtendConfig, + instructions::{ + 
Instruction, + riscv::{ + RIVInstruction, + constants::{LIMB_BITS, UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::marker::PhantomData; +use witness::set_val; + +pub struct ShiftImmInstruction(PhantomData<(E, I)>); + +pub struct ShiftImmConfig { + i_insn: IInstructionConfig, + + imm: WitIn, + rs1_read: UInt, + pub rd_written: UInt, + outflow: WitIn, + assert_lt_config: AssertLtConfig, + + // SRAI + is_lt_config: Option>, +} + +impl Instruction for ShiftImmInstruction { + type InstructionConfig = ShiftImmConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + // treat bit shifting as a bit "inflow" and "outflow" process, flowing from left to right or vice versa + // this approach simplifies constraint and witness allocation compared to using multiplication/division gadget, + // as the divisor/multiplier is a power of 2. + // + // example: right shift (bit flow from left to right) + // inflow || rs1_read == rd_written || outflow + // in this case, inflow consists of either all 0s or all 1s for sign extension (if the value is signed). + // + // for left shifts, the inflow is always 0: + // rs1_read || inflow == outflow || rd_written + // + // additional constraint: outflow < (1 << shift), which leads to a unique solution + + // soundness: take Goldilocks as example, both sides of the equation are 63 bits numbers (<2**63) + // rd * imm + outflow == inflow * 2**32 + rs1 + // 32 + 31. 31. 31 + 32. 32. (Bit widths) + + // Note: `imm` witness is set to 2**imm (up to 32 bits) just for efficient verification.
+ let imm = circuit_builder.create_witin(|| "imm"); + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow"); + let assert_lt_config = AssertLtConfig::construct_circuit( + circuit_builder, + || "outflow < imm", + outflow.expr(), + imm.expr(), + UINT_LIMBS * LIMB_BITS, + )?; + + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let is_lt_config = match I::INST_KIND { + InsnKind::SLLI => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * imm.expr(), // inflow is zero for this case + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRAI | InsnKind::SRLI => { + let (inflow, is_lt_config) = match I::INST_KIND { + InsnKind::SRAI => { + let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; + let ones = imm.expr() - 1; + (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) + } + InsnKind::SRLI => (Expression::ZERO, None), + _ => unreachable!(), + }; + circuit_builder.require_equal( + || "shift check", + rd_written.value() * imm.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + is_lt_config + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(ShiftImmConfig { + i_insn, + imm, + rs1_read, + rd_written, + outflow, + assert_lt_config, + is_lt_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // imm_internal is a precomputed 2**shift. 
+ let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + + set_val!(instance, config.imm, imm); + config.rs1_read.assign_value(instance, rs1_read.clone()); + config.rd_written.assign_value(instance, rd_written); + + let outflow = match I::INST_KIND { + InsnKind::SLLI => (rs1_read.as_u64() * imm) >> UInt::::TOTAL_BITS, + InsnKind::SRAI | InsnKind::SRLI => { + if I::INST_KIND == InsnKind::SRAI { + config.is_lt_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + *rs1_read.as_u16_limbs().last().unwrap() as u64, + )?; + } + + rs1_read.as_u64() & (imm - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + config + .assert_lt_config + .assign_instance(instance, lk_multiplicity, outflow, imm)?; + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 092c4b560..6ed08d51f 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -107,9 +107,6 @@ impl InsnRecord { #[cfg(feature = "u16limb_circuit")] pub fn imm_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { - // Prepare the immediate for ShiftImmInstruction. - // The shift is implemented as a multiplication/division by 1 << immediate. 
- (SLLI | SRLI | SRAI, _) => (1 << insn.imm, i64_to_base(1 << insn.imm)), // TODO convert to 2 limbs to support smaller field (LB | LH | LW | LBU | LHU | SB | SH | SW, _) => { (insn.imm as i64, i64_to_base(insn.imm as i64)) diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 1990bae39..0041776be 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -29,6 +29,17 @@ pub fn split_to_u8>(value: u32) -> Vec { .collect_vec() } +#[allow(dead_code)] +pub fn split_to_limb, const LIMB_BITS: usize>(value: u32) -> Vec { + (0..(u32::BITS as usize / LIMB_BITS)) + .scan(value, |acc, _| { + let limb = ((*acc & ((1 << LIMB_BITS) - 1)) as u8).into(); + *acc >>= LIMB_BITS; + Some(limb) + }) + .collect_vec() +} + /// Compile time evaluated minimum function /// returns min(a, b) pub(crate) const fn const_min(a: usize, b: usize) -> usize {