From a094c65bae8b4030e88092f4537b2f5337a3c1b7 Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Tue, 1 Apr 2025 15:14:42 +1100 Subject: [PATCH 1/8] prepare PR for block-multiplier --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 08f6230c9..2636de644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["noir-r1cs", "delegated-spartan", "merkle-hash-bench", "prover"] +members = ["block-multiplier", "noir-r1cs", "delegated-spartan", "merkle-hash-bench", "prover"] [workspace.package] edition = "2021" From 94b8b4090263d69cb0f0824daca18ff5d8970bf1 Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Tue, 1 Apr 2025 15:15:52 +1100 Subject: [PATCH 2/8] add block-multiplier crate --- block-multiplier/Cargo.toml | 16 + block-multiplier/benches/bench.rs | 91 ++++++ block-multiplier/src/constants.rs | 126 ++++++++ block-multiplier/src/lib.rs | 471 ++++++++++++++++++++++++++++++ block-multiplier/src/rtz.rs | 208 +++++++++++++ 5 files changed, 912 insertions(+) create mode 100644 block-multiplier/Cargo.toml create mode 100644 block-multiplier/benches/bench.rs create mode 100644 block-multiplier/src/constants.rs create mode 100644 block-multiplier/src/lib.rs create mode 100644 block-multiplier/src/rtz.rs diff --git a/block-multiplier/Cargo.toml b/block-multiplier/Cargo.toml new file mode 100644 index 000000000..09f82a71c --- /dev/null +++ b/block-multiplier/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "block-multiplier" +version = "0.1.0" +edition = "2024" + +[dependencies] +seq-macro = "0.3.5" + +[dev-dependencies] +rand = "0.9.0" +primitive-types = "0.13.1" +criterion = "0.5.1" + +[[bench]] +name = "bench" +harness = false \ No newline at end of file diff --git a/block-multiplier/benches/bench.rs b/block-multiplier/benches/bench.rs new file mode 100644 index 000000000..6888a6c0f --- /dev/null +++ b/block-multiplier/benches/bench.rs @@ -0,0 +1,91 @@ +use criterion::{Criterion, 
black_box, criterion_group, criterion_main};
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+
+fn bench_block_multiplier(c: &mut Criterion) {
+    let mut group = c.benchmark_group("block_multiplier");
+
+    let seed: u64 = rand::random();
+    println!("Using random seed for benchmark: {}", seed);
+    let mut rng = StdRng::seed_from_u64(seed);
+
+    let s0_a = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+    let s0_b = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+
+    let v0_a = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+    let v0_b = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+    let v1_a = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+    let v1_b = [
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+        rng.random::<u64>(),
+    ];
+
+    let rtz = block_multiplier::rtz::RTZ::set().unwrap();
+
+    group.bench_function("block_multiplier", |bencher| {
+        bencher.iter(|| {
+            block_multiplier::block_multiplier(
+                &rtz,
+                black_box(s0_a),
+                black_box(s0_b),
+                black_box(v0_a),
+                black_box(v0_b),
+                black_box(v1_a),
+                black_box(v1_b),
+            )
+        })
+    });
+
+    group.finish();
+}
+
+fn bench_rtz(c: &mut Criterion) {
+    let mut group = c.benchmark_group("rtz");
+    group.bench_function("rtz", |bencher| {
+        bencher.iter(|| {
+            let rtz = block_multiplier::rtz::RTZ::set();
+            black_box(rtz.is_some());
+            drop(rtz);
+        })
+    });
+
+    group.finish();
+}
+
+criterion_group!(
+    name = benches;
+    config = Criterion::default()
+        .sample_size(5000)
+        // Warm up is warm because it literally warms up the pi
+        .warm_up_time(std::time::Duration::new(1,0))
+        .measurement_time(std::time::Duration::new(10,0));
+    targets = bench_block_multiplier, bench_rtz
+);
+criterion_main!(benches);
diff --git a/block-multiplier/src/constants.rs b/block-multiplier/src/constants.rs
new file mode 100644
index 000000000..7422f6886
--- /dev/null
+++ 
b/block-multiplier/src/constants.rs @@ -0,0 +1,126 @@ +pub const NP0: u64 = 0xc2e1f593efffffff; + +pub const P: [u64; 4] = [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, +]; + +// R mod P +pub const R: [u64; 4] = [ + 0xac96341c4ffffffb, + 0x36fc76959f60cd29, + 0x666ea36f7879462e, + 0x0e0a77c19a07df2f, +]; + +// R^2 mod P +pub const R2: [u64; 4] = [ + 0x1BB8E645AE216DA7, + 0x53FE3AB1E35C59E3, + 0x8C49833D53BB8085, + 0x0216D0B17F4E44A5, +]; + +// R^-1 mod P +pub const R_INV: [u64; 4] = [ + 0xdc5ba0056db1194e, + 0x090ef5a9e111ec87, + 0xc8260de4aeb85d5d, + 0x15ebf95182c5551c, +]; + +pub const U52_NP0: u64 = 0x1F593EFFFFFFF; +pub const U52_R2: [u64; 5] = [ + 0x0B852D16DA6F5, + 0xC621620CDDCE3, + 0xAF1B95343FFB6, + 0xC3C15E103E7C2, + 0x00281528FA122, +]; + +pub const U52_P: [u64; 5] = [ + 0x1F593F0000001, + 0x4879B9709143E, + 0x181585D2833E8, + 0xA029B85045B68, + 0x030644E72E131, +]; + +pub const F52_P: [f64; 5] = [ + 0x1F593F0000001_u64 as f64, + 0x4879B9709143E_u64 as f64, + 0x181585D2833E8_u64 as f64, + 0xA029B85045B68_u64 as f64, + 0x030644E72E131_u64 as f64, +]; + +pub const MASK52: u64 = 2_u64.pow(52) - 1; +pub const MASK48: u64 = 2_u64.pow(48) - 1; + +pub const U64_I1: [u64; 4] = [ + 0x2d3e8053e396ee4d, + 0xca478dbeab3c92cd, + 0xb2d8f06f77f52a93, + 0x24d6ba07f7aa8f04, +]; +pub const U64_I2: [u64; 4] = [ + 0x18ee753c76f9dc6f, + 0x54ad7e14a329e70f, + 0x2b16366f4f7684df, + 0x133100d71fdf3579, +]; + +pub const U64_I3: [u64; 4] = [ + 0x9BACB016127CBE4E, + 0x0B2051FA31944124, + 0xB064EEA46091C76C, + 0x2B062AAA49F80C7D, +]; +pub const U64_MU0: u64 = 0xc2e1f593efffffff; + +// -- [FP SIMD CONSTANTS] -------------------------------------------------------------------------- +pub const RHO_1: [u64; 5] = [ + 0x82e644ee4c3d2, + 0xf93893c98b1de, + 0xd46fe04d0a4c7, + 0x8f0aad55e2a1f, + 0x005ed0447de83, +]; + +pub const RHO_2: [u64; 5] = [ + 0x74eccce9a797a, + 0x16ddcc30bd8a4, + 0x49ecd3539499e, + 0xb23a6fcc592b8, + 0x00e3bd49f6ee5, +]; 
+ +pub const RHO_3: [u64; 5] = [ + 0x0E8C656567D77, + 0x430D05713AE61, + 0xEA3BA6B167128, + 0xA7DAE55C5A296, + 0x01B4AFD513572, +]; + +pub const RHO_4: [u64; 5] = [ + 0x22E2400E2F27D, + 0x323B46EA19686, + 0xE6C43F0DF672D, + 0x7824014C39E8B, + 0x00C6B48AFE1B8, +]; + +pub const C1: f64 = pow_2(104); // 2.0^104 +pub const C2: f64 = pow_2(104) + pow_2(52); // 2.0^104 + 2.0^52 +// const C3: f64 = pow_2(52); // 2.0^52 +// ------------------------------------------------------------------------------------------------- + +const fn pow_2(n: u32) -> f64 { + // Unfortunately we can't use f64::powi in const fn yet + // This is a workaround that creates the bit pattern directly + let exp = ((n as u64 + 1023) & 0x7FF) << 52; + f64::from_bits(exp) +} diff --git a/block-multiplier/src/lib.rs b/block-multiplier/src/lib.rs new file mode 100644 index 000000000..81fd43974 --- /dev/null +++ b/block-multiplier/src/lib.rs @@ -0,0 +1,471 @@ +#![feature(portable_simd)] + +pub mod constants; +pub mod rtz; + +use crate::constants::*; +use rtz::RTZ; +use seq_macro::seq; +use std::arch::aarch64::vcvtq_f64_u64; +use std::ops::BitAnd; +use std::simd::{Simd, StdFloat, num::SimdFloat}; + +/// Macro to extract a subarray from an array. +/// +/// # Arguments +/// +/// * `$t` - The source array +/// * `$b` - The starting index (base) in the source array +/// * `$l` - The length of the subarray to extract +/// +/// This should be used over t[N..].try_into().unwrap() in getting a subarray. Using try_into+unwrap +/// introduces the eh_personality (exception handling) +/// +/// # Example +/// +/// ``` +/// use block_multiplier::subarray; +/// let array = [1, 2, 3, 4, 5]; +/// let sub = subarray!(array, 1, 3); // Creates [2, 3, 4] +/// ``` +#[macro_export] +macro_rules! 
subarray {
+
+    ($t:expr, $b: literal, $l: literal) => {
+        {
+            use seq_macro::seq;
+            let t = $t;
+            let mut s = [0;$l];
+
+            // The compiler does not detect out-of-bounds when using `for` therefore `seq!` is used here
+            seq!(i in 0..$l {
+                s[i] = t[$b+i];
+            });
+            s
+        }
+    };
+}
+
+#[inline]
+pub fn block_multiplier(
+    _rtz: &RTZ, // Proof that the mode has been set to RTZ
+    s0_a: [u64; 4],
+    s0_b: [u64; 4],
+    v0_a: [u64; 4],
+    v0_b: [u64; 4],
+    v1_a: [u64; 4],
+    v1_b: [u64; 4],
+) -> ([u64; 4], [u64; 4], [u64; 4]) {
+    // -- [VECTOR] ---------------------------------------------------------------------------------
+    let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a]));
+    let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b]));
+
+    let mut t: [Simd<u64, 2>; 10] = [Simd::splat(0); 10];
+    t[0] = Simd::splat(make_initial(1, 0));
+    t[9] = Simd::splat(make_initial(0, 6));
+    t[1] = Simd::splat(make_initial(2, 1));
+    t[8] = Simd::splat(make_initial(6, 7));
+    t[2] = Simd::splat(make_initial(3, 2));
+    t[7] = Simd::splat(make_initial(7, 8));
+    t[3] = Simd::splat(make_initial(4, 3));
+    t[6] = Simd::splat(make_initial(8, 9));
+    t[4] = Simd::splat(make_initial(10, 4));
+    t[5] = Simd::splat(make_initial(9, 10));
+
+    let avi: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() };
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[0 + 0 + 1] += p_hi.to_bits();
+    t[0 + 0] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[0 + 1 + 1] += p_hi.to_bits();
+    t[0 + 1] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[0 + 2 + 1] += p_hi.to_bits();
+    t[0 + 2] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[0 + 3 + 1] += p_hi.to_bits();
+    t[0 + 3] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[0 + 4 + 1] += p_hi.to_bits();
+    t[0 + 4] += p_lo.to_bits();
+    let avi: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() };
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[1 + 0 + 1] += p_hi.to_bits();
+    t[1 + 0] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[1 + 1 + 1] += p_hi.to_bits();
+    t[1 + 1] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[1 + 2 + 1] += p_hi.to_bits();
+    t[1 + 2] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[1 + 3 + 1] += p_hi.to_bits();
+    t[1 + 3] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[1 + 4 + 1] += p_hi.to_bits();
+    t[1 + 4] += p_lo.to_bits();
+    let avi: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() };
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[2 + 0 + 1] += p_hi.to_bits();
+    t[2 + 0] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[2 + 1 + 1] += p_hi.to_bits();
+    t[2 + 1] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[2 + 2 + 1] += p_hi.to_bits();
+    t[2 + 2] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[2 + 3 + 1] += p_hi.to_bits();
+    t[2 + 3] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[2 + 4 + 1] += p_hi.to_bits();
+    t[2 + 4] += p_lo.to_bits();
+    let avi: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() };
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[3 + 0 + 1] += p_hi.to_bits();
+    t[3 + 0] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[3 + 1 + 1] += p_hi.to_bits();
+    t[3 + 1] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[3 + 2 + 1] += p_hi.to_bits();
+    t[3 + 2] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[3 + 3 + 1] += p_hi.to_bits();
+    t[3 + 3] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[3 + 4 + 1] += p_hi.to_bits();
+    t[3 + 4] += p_lo.to_bits();
+    let avi: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() };
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[4 + 0 + 1] += p_hi.to_bits();
+    t[4 + 0] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[4 + 1 + 1] += p_hi.to_bits();
+    t[4 + 1] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[4 + 2 + 1] += p_hi.to_bits();
+    t[4 + 2] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[4 + 3 + 1] += p_hi.to_bits();
+    t[4 + 3] += p_lo.to_bits();
+    let bvj: Simd<f64, 2> = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() };
+    let p_hi = avi.mul_add(bvj, Simd::splat(C1));
+    let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi);
+    t[4 + 4 + 1] += p_hi.to_bits();
+    t[4 + 4] += p_lo.to_bits();
+
+    t[1] += t[0] >> 52;
+    t[2] += t[1] >> 52;
+    t[3] += t[2] >> 52;
+    t[4] += t[3] >> 52;
+
+    let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4);
+    let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3);
+    let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2);
+    let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1);
+
+    let s = [t[4], t[5], t[6], t[7], t[8], t[9]];
+
+    let s = addv_simd(r3, addv_simd(addv_simd(s, r0), addv_simd(r1, r2)));
+
+    let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52));
+    let mp = smult_noinit_simd(m, U52_P);
+
+    let resolve = resolve_simd_add_truncate(s, mp);
+    let u256_result = u260_to_u256_simd(resolve);
+    let v = transpose_simd_to_u256(u256_result);
+
+    // ---------------------------------------------------------------------------------------------
+    // -- [SCALAR] ---------------------------------------------------------------------------------
+    let mut s0_t = [0_u64; 8];
+    let mut carry = 0;
+    (s0_t[0], carry) = carrying_mul_add(s0_a[0], s0_b[0], s0_t[0], carry);
+    (s0_t[1], carry) = carrying_mul_add(s0_a[0], s0_b[1], s0_t[1], carry);
+    (s0_t[2], carry) = carrying_mul_add(s0_a[0], s0_b[2], s0_t[2], carry);
+    (s0_t[3], carry) = carrying_mul_add(s0_a[0], s0_b[3], s0_t[3], carry);
+    s0_t[4] = carry;
+    carry = 0;
+    (s0_t[1], carry) = carrying_mul_add(s0_a[1], s0_b[0], s0_t[1], carry);
+    (s0_t[2], carry) = carrying_mul_add(s0_a[1], s0_b[1], s0_t[2], carry);
+    (s0_t[3], carry) = carrying_mul_add(s0_a[1], s0_b[2], s0_t[3], carry);
+    (s0_t[4], carry) = carrying_mul_add(s0_a[1], s0_b[3], s0_t[4], carry);
+    s0_t[5] = carry;
+    carry = 0;
+    (s0_t[2], carry) = carrying_mul_add(s0_a[2], s0_b[0], s0_t[2], carry);
+    (s0_t[3], carry) = carrying_mul_add(s0_a[2], s0_b[1], s0_t[3], carry);
+    (s0_t[4], carry) = carrying_mul_add(s0_a[2], s0_b[2], s0_t[4], carry);
+    (s0_t[5], carry) = carrying_mul_add(s0_a[2], s0_b[3], s0_t[5], carry);
+    s0_t[6] = carry;
+    carry = 0;
+    (s0_t[3], carry) = carrying_mul_add(s0_a[3], s0_b[0], s0_t[3], carry);
+    (s0_t[4], carry) = carrying_mul_add(s0_a[3], s0_b[1], s0_t[4], carry);
+    (s0_t[5], carry) = carrying_mul_add(s0_a[3], s0_b[2], s0_t[5], carry);
+    (s0_t[6], carry) = carrying_mul_add(s0_a[3], s0_b[3], s0_t[6], carry);
+    s0_t[7] = carry;
+
+    let mut s0_r1 = [0_u64; 5];
+    (s0_r1[0], s0_r1[1]) = carrying_mul_add(s0_t[0], U64_I3[0], s0_r1[0], 0);
+    (s0_r1[1], s0_r1[2]) = carrying_mul_add(s0_t[0], U64_I3[1], s0_r1[1], 0);
+    (s0_r1[2], s0_r1[3]) = carrying_mul_add(s0_t[0], U64_I3[2], s0_r1[2], 0);
+    (s0_r1[3], s0_r1[4]) = carrying_mul_add(s0_t[0], U64_I3[3], s0_r1[3], 0);
+
+    let mut s0_r2 = [0_u64; 5];
+    (s0_r2[0], s0_r2[1]) = carrying_mul_add(s0_t[1], U64_I2[0], s0_r2[0], 0);
+    (s0_r2[1], s0_r2[2]) = carrying_mul_add(s0_t[1], U64_I2[1], s0_r2[1], 0);
+    (s0_r2[2], s0_r2[3]) = carrying_mul_add(s0_t[1], U64_I2[2], s0_r2[2], 0);
+    (s0_r2[3], s0_r2[4]) = carrying_mul_add(s0_t[1], U64_I2[3], s0_r2[3], 0);
+
+    let mut s0_r3 = [0_u64; 5];
+    (s0_r3[0], s0_r3[1]) = carrying_mul_add(s0_t[2], U64_I1[0], s0_r3[0], 0);
+    (s0_r3[1], s0_r3[2]) = carrying_mul_add(s0_t[2], U64_I1[1], s0_r3[1], 0);
+    (s0_r3[2], s0_r3[3]) = carrying_mul_add(s0_t[2], U64_I1[2], s0_r3[2], 0);
+    (s0_r3[3], s0_r3[4]) = carrying_mul_add(s0_t[2], U64_I1[3], s0_r3[3], 0);
+
+    let s0_s = addv(addv(subarray!(s0_t, 3, 5), s0_r1), addv(s0_r2, s0_r3));
+
+    let s0_m = U64_MU0.wrapping_mul(s0_s[0]);
+    let mut s0_mp = [0_u64; 5];
+    (s0_mp[0], s0_mp[1]) = carrying_mul_add(s0_m, P[0], s0_mp[0], 0);
+    (s0_mp[1], s0_mp[2]) = carrying_mul_add(s0_m, P[1], s0_mp[1], 0);
+    (s0_mp[2], s0_mp[3]) = carrying_mul_add(s0_m, P[2], s0_mp[2], 0);
+    (s0_mp[3], s0_mp[4]) = carrying_mul_add(s0_m, P[3], s0_mp[3], 0);
+
+    let s0 = subarray!(addv(s0_s, s0_mp), 1, 4);
+    // ---------------------------------------------------------------------------------------------
+    (s0, v[0], v[1])
+}
+// -------------------------------------------------------------------------------------------------
+
+#[inline(always)]
+fn addv<const N: usize>(mut a: [u64; N], b: [u64; N]) -> [u64; N] {
+    let mut carry = 0u64;
+    for i in 0..N {
+        let (sum1, overflow1) = a[i].overflowing_add(b[i]);
+        let (sum2, overflow2) = sum1.overflowing_add(carry);
+        a[i] = sum2;
+        carry = (overflow1 as u64) + (overflow2 as u64);
+    }
+    a
+}
+
+// -- [SIMD UTILS] ---------------------------------------------------------------------------------
+
+#[inline(always)]
+const fn make_initial(low_count: usize, high_count: usize) -> u64 {
+    let 
val = high_count * 0x467 + low_count * 0x433;
+    -((val as i64 & 0xFFF) << 52) as u64
+}
+
+#[inline(always)]
+fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd<u64, 2>; 4] {
+    // This does not issue multiple ldp and zip which might be marginally faster.
+    [
+        Simd::from_array([limbs[0][0], limbs[1][0]]),
+        Simd::from_array([limbs[0][1], limbs[1][1]]),
+        Simd::from_array([limbs[0][2], limbs[1][2]]),
+        Simd::from_array([limbs[0][3], limbs[1][3]]),
+    ]
+}
+
+#[inline(always)]
+fn transpose_simd_to_u256(limbs: [Simd<u64, 2>; 4]) -> [[u64; 4]; 2] {
+    let mut result = [[0; 4]; 2];
+    for i in 0..limbs.len() {
+        let tmp = limbs[i].to_array();
+        result[0][i] = tmp[0];
+        result[1][i] = tmp[1];
+    }
+    result
+}
+
+#[inline(always)]
+fn u256_to_u260_shl2_simd(limbs: [Simd<u64, 2>; 4]) -> [Simd<u64, 2>; 5] {
+    let [l0, l1, l2, l3] = limbs;
+    [
+        (l0 << 2) & Simd::splat(MASK52),
+        ((l0 >> 50) | (l1 << 14)) & Simd::splat(MASK52),
+        ((l1 >> 38) | (l2 << 26)) & Simd::splat(MASK52),
+        ((l2 >> 26) | (l3 << 38)) & Simd::splat(MASK52),
+        l3 >> 14,
+    ]
+}
+
+#[inline(always)]
+fn u260_to_u256_simd(limbs: [Simd<u64, 2>; 5]) -> [Simd<u64, 2>; 4] {
+    let [l0, l1, l2, l3, l4] = limbs;
+    [
+        l0 | (l1 << 52),
+        (l1 >> 12) | (l2 << 40),
+        (l2 >> 24) | (l3 << 28),
+        (l3 >> 36) | (l4 << 16),
+    ]
+}
+
+#[inline(always)]
+fn smult_noinit_simd(s: Simd<u64, 2>, v: [u64; 5]) -> [Simd<u64, 2>; 6] {
+    let mut t = [Simd::splat(0); 6];
+    let s: Simd<f64, 2> = unsafe { vcvtq_f64_u64(s.into()).into() };
+
+    for i in 0..v.len() {
+        let p_hi = s.mul_add(Simd::splat(v[i] as f64), Simd::splat(C1));
+        let p_lo = s.mul_add(Simd::splat(v[i] as f64), Simd::splat(C2) - p_hi);
+        t[i + 1] += p_hi.to_bits();
+        t[i] += p_lo.to_bits();
+    }
+    t
+}
+
+#[inline(always)]
+fn addv_simd<const N: usize>(
+    mut va: [Simd<u64, 2>; N],
+    vb: [Simd<u64, 2>; N],
+) -> [Simd<u64, 2>; N] {
+    for i in 0..va.len() {
+        va[i] += vb[i];
+    }
+    va
+}
+
+#[inline(always)]
+fn resolve_simd_add_truncate(s: [Simd<u64, 2>; 6], mp: [Simd<u64, 2>; 6]) -> [Simd<u64, 2>; 5] {
+    let mut out = [Simd::splat(0); 5];
+    let mut carry = (s[0] + mp[0]) >> 52;
+    for i in 0..5 {
+        let 
tmp = s[i + 1] + mp[i + 1] + carry; + out[i] = tmp.bitand(Simd::splat(MASK52)); + carry = tmp >> 52; + } + out +} + +// ------------------------------------------------------------------------------------------------- + +#[inline(always)] +fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { + let c: u128 = a as u128 * b as u128 + carry as u128 + add as u128; + (c as u64, (c >> 64) as u64) +} + +#[cfg(test)] +mod tests { + use crate::{block_multiplier, constants, rtz::RTZ}; + use primitive_types::U256; + use rand::{Rng, SeedableRng, rngs}; + + const OUTPUT_MAX: [u64; 4] = [ + 0x783c14d81ffffffe, + 0xaf982f6f0c8d1edd, + 0x8f5f7492fcfd4f45, + 0x9f37631a3d9cbfac, + ]; + + fn mod_mul(a: U256, b: U256) -> U256 { + let p = U256(constants::P); + let mut c = [0u64; 4]; + c.copy_from_slice(&(a.full_mul(b) % p).0[0..4]); + U256(c) + } + + #[test] + fn test_block_multiplier() { + let mut rng = rngs::StdRng::seed_from_u64(0); + let p = U256(constants::P); + let r = U256(constants::R); + let r_inv = U256(constants::R_INV); + + let mut s0_a_bytes = [0u8; 32]; + let mut s0_b_bytes = [0u8; 32]; + let mut v0_a_bytes = [0u8; 32]; + let mut v0_b_bytes = [0u8; 32]; + let mut v1_a_bytes = [0u8; 32]; + let mut v1_b_bytes = [0u8; 32]; + + let rtz = RTZ::set().unwrap(); + + for _ in 0..100000 { + rng.fill(&mut s0_a_bytes); + rng.fill(&mut s0_b_bytes); + rng.fill(&mut v0_a_bytes); + rng.fill(&mut v0_b_bytes); + rng.fill(&mut v1_a_bytes); + rng.fill(&mut v1_b_bytes); + let s0_a = U256::from_little_endian(&s0_a_bytes) % p; + let s0_b = U256::from_little_endian(&s0_b_bytes) % p; + let v0_a = U256::from_little_endian(&v0_a_bytes) % p; + let v0_b = U256::from_little_endian(&v0_b_bytes) % p; + let v1_a = U256::from_little_endian(&v1_a_bytes) % p; + let v1_b = U256::from_little_endian(&v1_b_bytes) % p; + let s0_a_mont = mod_mul(s0_a, r); + let s0_b_mont = mod_mul(s0_b, r); + let v0_a_mont = mod_mul(v0_a, r); + let v0_b_mont = mod_mul(v0_b, r); + let v1_a_mont = mod_mul(v1_a, 
r);
+            let v1_b_mont = mod_mul(v1_b, r);
+
+            let (s0, v0, v1) = block_multiplier(
+                &rtz,
+                s0_a_mont.0,
+                s0_b_mont.0,
+                v0_a_mont.0,
+                v0_b_mont.0,
+                v1_a_mont.0,
+                v1_b_mont.0,
+            );
+            assert!(U256(s0) < U256(OUTPUT_MAX));
+            assert!(U256(v0) < U256(OUTPUT_MAX));
+            assert!(U256(v1) < U256(OUTPUT_MAX));
+            assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_b));
+            assert_eq!(mod_mul(U256(v0), r_inv), mod_mul(v0_a, v0_b));
+            assert_eq!(mod_mul(U256(v1), r_inv), mod_mul(v1_a, v1_b));
+        }
+    }
+}
\ No newline at end of file
diff --git a/block-multiplier/src/rtz.rs b/block-multiplier/src/rtz.rs
new file mode 100644
index 000000000..1f8050d9a
--- /dev/null
+++ b/block-multiplier/src/rtz.rs
@@ -0,0 +1,208 @@
+use std::marker::PhantomData;
+
+/// Mask selecting round-toward-zero mode in FPCR (bits 22-23 set to 0b11)
+const FPCR_RMODE_BITS: u64 = 0b11 << 22;
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+enum FPCRState {
+    /// FPCR is idle and available for modification
+    Idle,
+    /// FPCR is actively being used for RTZ operations
+    Active,
+}
+
+/// Proof that Round Toward Zero (RTZ) has been set
+///
+/// This struct must be passed as an (unused) reference to any function that requires round toward zero
+/// for correct operation. The struct serves as a proof that RTZ is set and we rely on the lifetime introduced
+/// by the reference to enforce the ordering of the FPCR operations relative to the multiplication. This
+/// way we can prevent the reset of FPCR from bubbling up in front of the multiplication.
+///
+/// This type provides RAII-style management of the AArch64 FPCR (Floating-point Control Register),
+/// specifically for controlling the rounding mode. When created, it sets the rounding mode to
+/// "round toward zero" and restores the previous mode when dropped.
+///
+/// # Safety
+///
+/// This type is not Send because FPCR is a per-core / per OS-thread register. The PhantomData<*mut ()> ensures this.
+/// Only one instance can exist per thread at a time, enforced by the FPCR_OWNED thread-local.
+///
+/// # Design Notes
+///
+/// This type is a guard: `set` stores the FPCR value it observed and `Drop` writes it back.
+#[derive(Debug)]
+pub struct RTZ {
+    prev_fpcr: u64,
+    _no_send: PhantomData<*mut ()>,
+}
+
+thread_local! {
+    /// Thread-local flag to ensure only one RTZ instance exists per thread.
+    /// This prevents multiple concurrent modifications to the FPCR register.
+    static FPCR_OWNED: std::cell::Cell<FPCRState> = std::cell::Cell::new(FPCRState::Idle);
+}
+
+impl RTZ {
+    /// Attempts to create a new RTZ instance, setting the FPCR to round-toward-zero mode.
+    ///
+    /// Returns None if another RTZ instance already exists in this thread.
+    ///
+    /// # Safety
+    ///
+    /// This function uses inline assembly to modify the FPCR register. The operations are safe
+    /// when used as intended through this API, as it maintains the following invariants:
+    /// - Only one instance can exist per thread
+    /// - The previous FPCR value is always restored on drop
+    /// - The type cannot be sent between threads
+    #[cfg(target_arch = "aarch64")]
+    #[inline]
+    pub fn set() -> Option<Self> {
+        // Try to acquire ownership of FPCR
+        let state = FPCR_OWNED.with(|owned| {
+            let observed_state = owned.get();
+            if observed_state == FPCRState::Idle {
+                owned.set(FPCRState::Active);
+            }
+            observed_state
+        });
+
+        match state {
+            FPCRState::Idle => {
+                let mut prev_fpcr: u64;
+                unsafe {
+                    // Read current FPCR value and set round-toward-zero mode
+                    core::arch::asm!(
+                        // Save FPCR, OR in the RTZ rounding-mode bits, write it back
+                        "mrs {prev_fpcr}, fpcr",
+                        "orr {tmp}, {prev_fpcr}, {rmode}",
+                        "msr fpcr, {tmp}",
+                        prev_fpcr = out(reg) prev_fpcr,
+                        tmp = out(reg) _,
+                        rmode = const FPCR_RMODE_BITS,
+                    );
+                }
+
+                Some(Self {
+                    prev_fpcr,
+                    _no_send: PhantomData,
+                })
+            }
+            FPCRState::Active => None,
+        }
+    }
+
+    #[cfg(not(target_arch = "aarch64"))]
+    #[inline]
+    pub fn set() -> Option<Self> {
+        todo!()
+    }
+
+    /// Reads the current value of the FPCR register.
+ /// + /// This method is primarily intended for debugging and verification purposes. + #[cfg(target_arch = "aarch64")] + #[inline] + pub fn read(&self) -> u64 { + let mut value: u64; + unsafe { + core::arch::asm!( + "mrs {}, fpcr", + out(reg) value, + options(nomem, nostack, preserves_flags) + ); + } + value + } + + #[cfg(not(target_arch = "aarch64"))] + #[inline] + pub fn read(&self) -> u64 { + todo!() + } + + /// Writes a new value to the FPCR register. + /// + /// This is a low-level operation that directly modifies the FPCR register. + /// It should be used with caution as improper values can affect floating-point behavior. + /// + /// # Safety + /// + /// This operation is safe because: + /// - The RTZ instance proves we have exclusive access to FPCR + /// - The write operation is atomic + #[cfg(target_arch = "aarch64")] + #[inline] + pub fn write(&self, value: u64) { + unsafe { + core::arch::asm!( + "msr fpcr, {}", + in(reg) value, + options(nomem, nostack, preserves_flags) + ); + } + } + + #[cfg(not(target_arch = "aarch64"))] + #[inline] + pub fn write(&self, _value: u64) { + todo!() + } +} + +impl Drop for RTZ { + /// Restores the original FPCR value and releases the thread-local lock. + /// + /// This ensures that the floating-point environment is restored to its + /// previous state when the RTZ instance is dropped, maintaining the + /// RAII pattern. 
+ fn drop(&mut self) { + // Restore the original FPCR value + self.write(self.prev_fpcr); + // Release the thread-local lock + FPCR_OWNED.with(|owned| owned.set(FPCRState::Idle)); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg(target_arch = "aarch64")] + fn test_rtz_single_instance() { + // First instance should succeed + let rtz1 = RTZ::set(); + assert!(rtz1.is_some()); + let rtz1 = rtz1.unwrap(); + let beginning_state = rtz1.prev_fpcr; + + // Second instance should fail + let rtz2 = RTZ::set(); + assert!(rtz2.is_none()); + + // Drop first instance + drop(rtz1); + + // Now we should be able to create a new instance + let rtz3 = RTZ::set(); + assert!(rtz3.is_some()); + let rtz3 = rtz3.unwrap(); + + assert_eq!(beginning_state, rtz3.prev_fpcr); + } + + #[test] + #[cfg(target_arch = "aarch64")] + fn test_rtz_read_write() { + let rtz = RTZ::set().unwrap(); + let initial = rtz.read(); + + // Verify that the rounding mode bits are set + assert_eq!(initial & FPCR_RMODE_BITS, FPCR_RMODE_BITS); + + // Test write and read back + let test_value = initial & !FPCR_RMODE_BITS; // Clear rounding mode bits + rtz.write(test_value); + assert_eq!(rtz.read(), test_value); + } +} From aca72a72e4ec2090a79fd0f80382e9d7ed4934f0 Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Mon, 7 Apr 2025 10:40:49 +0100 Subject: [PATCH 3/8] implement subtract 2P --- block-multiplier/benches/bench.rs | 24 +- block-multiplier/src/constants.rs | 25 +- block-multiplier/src/lib.rs | 547 ++++++++++++++++++++++++------ 3 files changed, 485 insertions(+), 111 deletions(-) diff --git a/block-multiplier/benches/bench.rs b/block-multiplier/benches/bench.rs index 6888a6c0f..32ed1bdeb 100644 --- a/block-multiplier/benches/bench.rs +++ b/block-multiplier/benches/bench.rs @@ -49,9 +49,29 @@ fn bench_block_multiplier(c: &mut Criterion) { let rtz = block_multiplier::rtz::RTZ::set().unwrap(); - group.bench_function("block_multiplier", |bencher| { + group.bench_function("scalar_mul", |bencher| { 
bencher.iter(|| { - block_multiplier::block_multiplier( + block_multiplier::scalar_mul( + black_box(s0_a), + black_box(s0_b), + ) + }) + }); + + group.bench_function("simd_mul", |bencher| { + bencher.iter(|| { + block_multiplier::simd_mul( + black_box(v0_a), + black_box(v0_b), + black_box(v1_a), + black_box(v1_b), + ) + }) + }); + + group.bench_function("block_mul", |bencher| { + bencher.iter(|| { + block_multiplier::block_mul( &rtz, black_box(s0_a), black_box(s0_b), diff --git a/block-multiplier/src/constants.rs b/block-multiplier/src/constants.rs index 7422f6886..3e54cb437 100644 --- a/block-multiplier/src/constants.rs +++ b/block-multiplier/src/constants.rs @@ -1,14 +1,21 @@ -pub const NP0: u64 = 0xc2e1f593efffffff; +pub const U64_NP0: u64 = 0xc2e1f593efffffff; -pub const P: [u64; 4] = [ +pub const U64_P: [u64; 4] = [ 0x43e1f593f0000001, 0x2833e84879b97091, 0xb85045b68181585d, 0x30644e72e131a029, ]; +pub const U64_2P: [u64; 4] = [ + 0x87c3eb27e0000002, + 0x5067d090f372e122, + 0x70a08b6d0302b0ba, + 0x60c89ce5c2634053, +]; + // R mod P -pub const R: [u64; 4] = [ +pub const U64_R: [u64; 4] = [ 0xac96341c4ffffffb, 0x36fc76959f60cd29, 0x666ea36f7879462e, @@ -16,7 +23,7 @@ pub const R: [u64; 4] = [ ]; // R^2 mod P -pub const R2: [u64; 4] = [ +pub const U64_R2: [u64; 4] = [ 0x1BB8E645AE216DA7, 0x53FE3AB1E35C59E3, 0x8C49833D53BB8085, @@ -24,7 +31,7 @@ pub const R2: [u64; 4] = [ ]; // R^-1 mod P -pub const R_INV: [u64; 4] = [ +pub const U64_R_INV: [u64; 4] = [ 0xdc5ba0056db1194e, 0x090ef5a9e111ec87, 0xc8260de4aeb85d5d, @@ -48,6 +55,14 @@ pub const U52_P: [u64; 5] = [ 0x030644E72E131, ]; +pub const U52_2P: [u64; 5] = [ + 0x3EB27E0000002, + 0x90F372E12287C, + 0x302B0BA5067D0, + 0x405370A08B6D0, + 0x060C89CE5C263, +]; + pub const F52_P: [f64; 5] = [ 0x1F593F0000001_u64 as f64, 0x4879B9709143E_u64 as f64, diff --git a/block-multiplier/src/lib.rs b/block-multiplier/src/lib.rs index 81fd43974..832254251 100644 --- a/block-multiplier/src/lib.rs +++ b/block-multiplier/src/lib.rs 
@@ -1,4 +1,5 @@ #![feature(portable_simd)] +#![feature(bigint_helper_methods)] pub mod constants; pub mod rtz; @@ -9,6 +10,10 @@ use seq_macro::seq; use std::arch::aarch64::vcvtq_f64_u64; use std::ops::BitAnd; use std::simd::{Simd, StdFloat, num::SimdFloat}; +use std::simd::cmp::SimdPartialEq; +use std::simd::num::SimdUint; +use std::simd::num::SimdInt; +use std::array; /// Macro to extract a subarray from an array. /// @@ -46,17 +51,78 @@ macro_rules! subarray { }; } -#[inline] -pub fn block_multiplier( - _rtz: &RTZ, // Proof that the mode has been set to RTZ - s0_a: [u64; 4], - s0_b: [u64; 4], +#[inline(always)] +pub fn scalar_mul( + a: [u64; 4], + b: [u64; 4] +) -> [u64; 4] { + // -- [SCALAR] --------------------------------------------------------------------------------- + let mut t = [0_u64; 8]; + + let mut carry = 0; + (t[0], carry) = carrying_mul_add(a[0], b[0], t[0], carry); + (t[1], carry) = carrying_mul_add(a[0], b[1], t[1], carry); + (t[2], carry) = carrying_mul_add(a[0], b[2], t[2], carry); + (t[3], carry) = carrying_mul_add(a[0], b[3], t[3], carry); + t[4] = carry; + carry = 0; + (t[1], carry) = carrying_mul_add(a[1], b[0], t[1], carry); + (t[2], carry) = carrying_mul_add(a[1], b[1], t[2], carry); + (t[3], carry) = carrying_mul_add(a[1], b[2], t[3], carry); + (t[4], carry) = carrying_mul_add(a[1], b[3], t[4], carry); + t[5] = carry; + carry = 0; + (t[2], carry) = carrying_mul_add(a[2], b[0], t[2], carry); + (t[3], carry) = carrying_mul_add(a[2], b[1], t[3], carry); + (t[4], carry) = carrying_mul_add(a[2], b[2], t[4], carry); + (t[5], carry) = carrying_mul_add(a[2], b[3], t[5], carry); + t[6] = carry; + carry = 0; + (t[3], carry) = carrying_mul_add(a[3], b[0], t[3], carry); + (t[4], carry) = carrying_mul_add(a[3], b[1], t[4], carry); + (t[5], carry) = carrying_mul_add(a[3], b[2], t[5], carry); + (t[6], carry) = carrying_mul_add(a[3], b[3], t[6], carry); + t[7] = carry; + + let mut s_r1 = [0_u64; 5]; + (s_r1[0], s_r1[1]) = carrying_mul_add(t[0], 
U64_I3[0], 0, 0); + (s_r1[1], s_r1[2]) = carrying_mul_add(t[0], U64_I3[1], s_r1[1], 0); + (s_r1[2], s_r1[3]) = carrying_mul_add(t[0], U64_I3[2], s_r1[2], 0); + (s_r1[3], s_r1[4]) = carrying_mul_add(t[0], U64_I3[3], s_r1[3], 0); + + let mut s_r2 = [0_u64; 5]; + (s_r2[0], s_r2[1]) = carrying_mul_add(t[1], U64_I2[0], 0, 0); + (s_r2[1], s_r2[2]) = carrying_mul_add(t[1], U64_I2[1], s_r2[1], 0); + (s_r2[2], s_r2[3]) = carrying_mul_add(t[1], U64_I2[2], s_r2[2], 0); + (s_r2[3], s_r2[4]) = carrying_mul_add(t[1], U64_I2[3], s_r2[3], 0); + + let mut s_r3 = [0_u64; 5]; + (s_r3[0], s_r3[1]) = carrying_mul_add(t[2], U64_I1[0], 0, 0); + (s_r3[1], s_r3[2]) = carrying_mul_add(t[2], U64_I1[1], s_r3[1], 0); + (s_r3[2], s_r3[3]) = carrying_mul_add(t[2], U64_I1[2], s_r3[2], 0); + (s_r3[3], s_r3[4]) = carrying_mul_add(t[2], U64_I1[3], s_r3[3], 0); + + let s = addv(addv(subarray!(t, 3, 5), s_r1), addv(s_r2, s_r3)); + + let m = U64_MU0.wrapping_mul(s[0]); + let mut mp = [0_u64; 5]; + (mp[0], mp[1]) = carrying_mul_add(m, U64_P[0], mp[0], 0); + (mp[1], mp[2]) = carrying_mul_add(m, U64_P[1], mp[1], 0); + (mp[2], mp[3]) = carrying_mul_add(m, U64_P[2], mp[2], 0); + (mp[3], mp[4]) = carrying_mul_add(m, U64_P[3], mp[3], 0); + + let r = reduce_ct(subarray!(addv(s, mp), 1, 4)); + // --------------------------------------------------------------------------------------------- + r +} + +#[inline(always)] +pub fn simd_mul( v0_a: [u64; 4], v0_b: [u64; 4], v1_a: [u64; 4], - v1_b: [u64; 4], -) -> ([u64; 4], [u64; 4], [u64; 4]) { - // -- [VECTOR] --------------------------------------------------------------------------------- + v1_b: [u64; 4] +) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b])); @@ -98,6 +164,7 @@ pub fn block_multiplier( let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); t[0 + 4 + 1] += p_hi.to_bits(); t[0 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { 
vcvtq_f64_u64(v0_a[1].into()).into() }; let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; let p_hi = avi.mul_add(bvj, Simd::splat(C1)); @@ -124,6 +191,7 @@ pub fn block_multiplier( let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); t[1 + 4 + 1] += p_hi.to_bits(); t[1 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; let p_hi = avi.mul_add(bvj, Simd::splat(C1)); @@ -150,6 +218,7 @@ pub fn block_multiplier( let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); t[2 + 4 + 1] += p_hi.to_bits(); t[2 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; let p_hi = avi.mul_add(bvj, Simd::splat(C1)); @@ -176,6 +245,7 @@ pub fn block_multiplier( let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); t[3 + 4 + 1] += p_hi.to_bits(); t[3 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; let p_hi = avi.mul_add(bvj, Simd::splat(C1)); @@ -213,77 +283,267 @@ pub fn block_multiplier( let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); - let s = [t[4], t[5], t[6], t[7], t[8], t[9]]; - - let s = addv_simd(r3, addv_simd(addv_simd(s, r0), addv_simd(r1, r2))); + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); let mp = smult_noinit_simd(m, U52_P); - let resolve = resolve_simd_add_truncate(s, mp); - let u256_result = u260_to_u256_simd(resolve); + let reduced = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = 
u260_to_u256_simd(reduced); let v = transpose_simd_to_u256(u256_result); + (v[0], v[1]) +} - // --------------------------------------------------------------------------------------------- - // -- [SCALAR] --------------------------------------------------------------------------------- - let mut s0_t = [0_u64; 8]; +#[inline(always)] +pub fn block_mul( + _rtz: &RTZ, // Proof that the mode has been set to RTZ + s0_a: [u64; 4], + s0_b: [u64; 4], + v0_a: [u64; 4], + v0_b: [u64; 4], + v1_a: [u64; 4], + v1_b: [u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + // -- [SCALAR AB MULT] -------------------------------------------------------------------------------------------- + let mut s_t = [0_u64; 8]; let mut carry = 0; - (s0_t[0], carry) = carrying_mul_add(s0_a[0], s0_b[0], s0_t[0], carry); - (s0_t[1], carry) = carrying_mul_add(s0_a[0], s0_b[1], s0_t[1], carry); - (s0_t[2], carry) = carrying_mul_add(s0_a[0], s0_b[2], s0_t[2], carry); - (s0_t[3], carry) = carrying_mul_add(s0_a[0], s0_b[3], s0_t[3], carry); - s0_t[4] = carry; + (s_t[0], carry) = carrying_mul_add(s0_a[0], s0_b[0], s_t[0], carry); + (s_t[1], carry) = carrying_mul_add(s0_a[0], s0_b[1], s_t[1], carry); + (s_t[2], carry) = carrying_mul_add(s0_a[0], s0_b[2], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[0], s0_b[3], s_t[3], carry); + s_t[4] = carry; carry = 0; - (s0_t[1], carry) = carrying_mul_add(s0_a[1], s0_b[0], s0_t[1], carry); - (s0_t[2], carry) = carrying_mul_add(s0_a[1], s0_b[1], s0_t[2], carry); - (s0_t[3], carry) = carrying_mul_add(s0_a[1], s0_b[2], s0_t[3], carry); - (s0_t[4], carry) = carrying_mul_add(s0_a[1], s0_b[3], s0_t[4], carry); - s0_t[5] = carry; + (s_t[1], carry) = carrying_mul_add(s0_a[1], s0_b[0], s_t[1], carry); + (s_t[2], carry) = carrying_mul_add(s0_a[1], s0_b[1], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[1], s0_b[2], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[1], s0_b[3], s_t[4], carry); + s_t[5] = carry; carry = 0; - (s0_t[2], carry) = 
carrying_mul_add(s0_a[2], s0_b[0], s0_t[2], carry); - (s0_t[3], carry) = carrying_mul_add(s0_a[2], s0_b[1], s0_t[3], carry); - (s0_t[4], carry) = carrying_mul_add(s0_a[2], s0_b[2], s0_t[4], carry); - (s0_t[5], carry) = carrying_mul_add(s0_a[2], s0_b[3], s0_t[5], carry); - s0_t[6] = carry; + (s_t[2], carry) = carrying_mul_add(s0_a[2], s0_b[0], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[2], s0_b[1], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[2], s0_b[2], s_t[4], carry); + (s_t[5], carry) = carrying_mul_add(s0_a[2], s0_b[3], s_t[5], carry); + s_t[6] = carry; carry = 0; - (s0_t[3], carry) = carrying_mul_add(s0_a[3], s0_b[0], s0_t[3], carry); - (s0_t[4], carry) = carrying_mul_add(s0_a[3], s0_b[1], s0_t[4], carry); - (s0_t[5], carry) = carrying_mul_add(s0_a[3], s0_b[2], s0_t[5], carry); - (s0_t[6], carry) = carrying_mul_add(s0_a[3], s0_b[3], s0_t[6], carry); - s0_t[7] = carry; - - let mut s0_r1 = [0_u64; 5]; - (s0_r1[0], s0_r1[1]) = carrying_mul_add(s0_t[0], U64_I3[0], s0_r1[0], 0); - (s0_r1[1], s0_r1[2]) = carrying_mul_add(s0_t[0], U64_I3[1], s0_r1[1], 0); - (s0_r1[2], s0_r1[3]) = carrying_mul_add(s0_t[0], U64_I3[2], s0_r1[2], 0); - (s0_r1[3], s0_r1[4]) = carrying_mul_add(s0_t[0], U64_I3[3], s0_r1[3], 0); - - let mut s0_r2 = [0_u64; 5]; - (s0_r2[0], s0_r2[1]) = carrying_mul_add(s0_t[1], U64_I2[0], s0_r2[0], 0); - (s0_r2[1], s0_r2[2]) = carrying_mul_add(s0_t[1], U64_I2[1], s0_r2[1], 0); - (s0_r2[2], s0_r2[3]) = carrying_mul_add(s0_t[1], U64_I2[2], s0_r2[2], 0); - (s0_r2[3], s0_r2[4]) = carrying_mul_add(s0_t[1], U64_I2[3], s0_r2[3], 0); - - let mut s0_r3 = [0_u64; 5]; - (s0_r3[0], s0_r3[1]) = carrying_mul_add(s0_t[2], U64_I1[0], s0_r3[0], 0); - (s0_r3[1], s0_r3[2]) = carrying_mul_add(s0_t[2], U64_I1[1], s0_r3[1], 0); - (s0_r3[2], s0_r3[3]) = carrying_mul_add(s0_t[2], U64_I1[2], s0_r3[2], 0); - (s0_r3[3], s0_r3[4]) = carrying_mul_add(s0_t[2], U64_I1[3], s0_r3[3], 0); - - let s0_s = addv(addv(subarray!(s0_t, 3, 5), s0_r1), addv(s0_r2, s0_r3)); 
- - let s0_m = U64_MU0.wrapping_mul(s0_s[0]); - let mut s0_mp = [0_u64; 5]; - (s0_mp[0], s0_mp[1]) = carrying_mul_add(s0_m, P[0], s0_mp[0], 0); - (s0_mp[1], s0_mp[2]) = carrying_mul_add(s0_m, P[1], s0_mp[1], 0); - (s0_mp[2], s0_mp[3]) = carrying_mul_add(s0_m, P[2], s0_mp[2], 0); - (s0_mp[3], s0_mp[4]) = carrying_mul_add(s0_m, P[3], s0_mp[3], 0); - - let s0 = subarray!(addv(s0_s, s0_mp), 1, 4); - // --------------------------------------------------------------------------------------------- + (s_t[3], carry) = carrying_mul_add(s0_a[3], s0_b[0], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[3], s0_b[1], s_t[4], carry); + (s_t[5], carry) = carrying_mul_add(s0_a[3], s0_b[2], s_t[5], carry); + (s_t[6], carry) = carrying_mul_add(s0_a[3], s0_b[3], s_t[6], carry); + s_t[7] = carry; + // ---------------------------------------------------------------------------------------------------------------- + // -- [VECTOR AB MULT] -------------------------------------------------------------------------------------------- + let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); + let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 0 + 1] += p_hi.to_bits(); + t[0 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { 
vcvtq_f64_u64(v0_b[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 1 + 1] += p_hi.to_bits(); + t[0 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 2 + 1] += p_hi.to_bits(); + t[0 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 3 + 1] += p_hi.to_bits(); + t[0 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 4 + 1] += p_hi.to_bits(); + t[0 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 0 + 1] += p_hi.to_bits(); + t[1 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits(); + t[1 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits(); + t[1 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits(); + t[1 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() }; + let p_hi = 
avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits(); + t[1 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 0 + 1] += p_hi.to_bits(); + t[2 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1 + 1] += p_hi.to_bits(); + t[2 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits(); + t[2 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits(); + t[2 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits(); + t[2 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 0 + 1] += p_hi.to_bits(); + t[3 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1 + 1] += p_hi.to_bits(); + t[3 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() }; + let 
p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 2 + 1] += p_hi.to_bits(); + t[3 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits(); + t[3 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits(); + t[3 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 0 + 1] += p_hi.to_bits(); + t[4 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1 + 1] += p_hi.to_bits(); + t[4 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 2 + 1] += p_hi.to_bits(); + t[4 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 3 + 1] += p_hi.to_bits(); + t[4 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_b[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits(); + t[4 + 4] += p_lo.to_bits(); + // ---------------------------------------------------------------------------------------------------------------- + // -- [VECTOR 
REDUCE] --------------------------------------------------------------------------------------------- + t[1] += t[0] >> 52; + t[2] += t[1] >> 52; + t[3] += t[2] >> 52; + t[4] += t[3] >> 52; + + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); + let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + // ---------------------------------------------------------------------------------------------------------------- + // -- [SCALAR REDUCE] --------------------------------------------------------------------------------------------- + let mut s_r1 = [0_u64; 5]; + (s_r1[0], s_r1[1]) = carrying_mul_add(s_t[0], U64_I3[0], 0, 0); + (s_r1[1], s_r1[2]) = carrying_mul_add(s_t[0], U64_I3[1], s_r1[1], 0); + (s_r1[2], s_r1[3]) = carrying_mul_add(s_t[0], U64_I3[2], s_r1[2], 0); + (s_r1[3], s_r1[4]) = carrying_mul_add(s_t[0], U64_I3[3], s_r1[3], 0); + + let mut s_r2 = [0_u64; 5]; + (s_r2[0], s_r2[1]) = carrying_mul_add(s_t[1], U64_I2[0], 0, 0); + (s_r2[1], s_r2[2]) = carrying_mul_add(s_t[1], U64_I2[1], s_r2[1], 0); + (s_r2[2], s_r2[3]) = carrying_mul_add(s_t[1], U64_I2[2], s_r2[2], 0); + (s_r2[3], s_r2[4]) = carrying_mul_add(s_t[1], U64_I2[3], s_r2[3], 0); + + let mut s_r3 = [0_u64; 5]; + (s_r3[0], s_r3[1]) = carrying_mul_add(s_t[2], U64_I1[0], 0, 0); + (s_r3[1], s_r3[2]) = carrying_mul_add(s_t[2], U64_I1[1], s_r3[1], 0); + (s_r3[2], s_r3[3]) = carrying_mul_add(s_t[2], U64_I1[2], s_r3[2], 0); + (s_r3[3], s_r3[4]) = carrying_mul_add(s_t[2], U64_I1[3], s_r3[3], 0); + + let s_s = addv(addv(subarray!(s_t, 3, 5), s_r1), addv(s_r2, s_r3)); + // 
---------------------------------------------------------------------------------------------------------------- + // -- [FINAL] ----------------------------------------------------------------------------------------------------- + let s_m = U64_MU0.wrapping_mul(s_s[0]); + let mut s_mp = [0_u64; 5]; + (s_mp[0], s_mp[1]) = carrying_mul_add(s_m, U64_P[0], 0, 0); + (s_mp[1], s_mp[2]) = carrying_mul_add(s_m, U64_P[1], s_mp[1], 0); + (s_mp[2], s_mp[3]) = carrying_mul_add(s_m, U64_P[2], s_mp[2], 0); + (s_mp[3], s_mp[4]) = carrying_mul_add(s_m, U64_P[3], s_mp[3], 0); + let s0 = reduce_ct(subarray!(addv(s_s, s_mp), 1, 4)); + + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let mp = smult_noinit_simd(m, U52_P); + let resolve = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = u260_to_u256_simd(resolve); + let v = transpose_simd_to_u256(u256_result); + // ---------------------------------------------------------------------------------------------------------------- (s0, v[0], v[1]) } -// ------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- #[inline(always)] fn addv(mut a: [u64; N], b: [u64; N]) -> [u64; N] { @@ -306,7 +566,7 @@ const fn make_initial(low_count: usize, high_count: usize) -> u64 { } #[inline(always)] -fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { +pub fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { // This does not issue multiple ldp and zip which might be marginally faster. 
[ Simd::from_array([limbs[0][0], limbs[1][0]]), @@ -318,17 +578,18 @@ fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { #[inline(always)] fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { - let mut result = [[0; 4]; 2]; - for i in 0..limbs.len() { - let tmp = limbs[i].to_array(); - result[0][i] = tmp[0]; - result[1][i] = tmp[1]; - } - result + let tmp0 = limbs[0].to_array(); + let tmp1 = limbs[1].to_array(); + let tmp2 = limbs[2].to_array(); + let tmp3 = limbs[3].to_array(); + [ + [tmp0[0], tmp1[0], tmp2[0], tmp3[0]], + [tmp0[1], tmp1[1], tmp2[1], tmp3[1]], + ] } #[inline(always)] -fn u256_to_u260_shl2_simd(limbs: [Simd; 4]) -> [Simd; 5] { +pub fn u256_to_u260_shl2_simd(limbs: [Simd; 4]) -> [Simd; 5] { let [l0, l1, l2, l3] = limbs; [ (l0 << 2) & Simd::splat(MASK52), @@ -355,15 +616,78 @@ fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { let mut t = [Simd::splat(0); 6]; let s: Simd = unsafe { vcvtq_f64_u64(s.into()).into() }; - for i in 0..v.len() { - let p_hi = s.mul_add(Simd::splat(v[i] as f64), Simd::splat(C1)); - let p_lo = s.mul_add(Simd::splat(v[i] as f64), Simd::splat(C2) - p_hi); - t[i + 1] += p_hi.to_bits(); - t[i] += p_lo.to_bits(); - } + let p_hi_0 = s.mul_add(Simd::splat(v[0] as f64), Simd::splat(C1)); + let p_lo_0 = s.mul_add(Simd::splat(v[0] as f64), Simd::splat(C2) - p_hi_0); + t[1] += p_hi_0.to_bits(); + t[0] += p_lo_0.to_bits(); + + let p_hi_1 = s.mul_add(Simd::splat(v[1] as f64), Simd::splat(C1)); + let p_lo_1 = s.mul_add(Simd::splat(v[1] as f64), Simd::splat(C2) - p_hi_1); + t[2] += p_hi_1.to_bits(); + t[1] += p_lo_1.to_bits(); + + let p_hi_2 = s.mul_add(Simd::splat(v[2] as f64), Simd::splat(C1)); + let p_lo_2 = s.mul_add(Simd::splat(v[2] as f64), Simd::splat(C2) - p_hi_2); + t[3] += p_hi_2.to_bits(); + t[2] += p_lo_2.to_bits(); + + let p_hi_3 = s.mul_add(Simd::splat(v[3] as f64), Simd::splat(C1)); + let p_lo_3 = s.mul_add(Simd::splat(v[3] as f64), Simd::splat(C2) - p_hi_3); + t[4] += p_hi_3.to_bits(); + t[3] += 
p_lo_3.to_bits(); + + let p_hi_4 = s.mul_add(Simd::splat(v[4] as f64), Simd::splat(C1)); + let p_lo_4 = s.mul_add(Simd::splat(v[4] as f64), Simd::splat(C2) - p_hi_4); + t[5] += p_hi_4.to_bits(); + t[4] += p_lo_4.to_bits(); + t } +#[inline(always)] +/// Resolve the carry bits in the upper parts 12b and reduce the result to within < 3p +pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { + // The lowest limb contains carries that still need to be applied. + let mut borrow: Simd = (red[0] >> 52).cast(); + let a = [red[1], red[2], red[3], red[4], red[5]]; + + // To reduce Check whether the most significant bit is set + let mask = (a[4] >> 47).bitand(Simd::splat(1)).simd_eq(Simd::splat(0)); + + // Select values based on the mask: if mask lane is true, use zeros, else use U52_2P + let zeros = [Simd::splat(0); 5]; + let twop = U52_2P.map(|pi| Simd::splat(pi)); + let b: [_; 5] = array::from_fn(|i| mask.select(zeros[i], twop[i])); + + let mut c = [Simd::splat(0); 5]; + for i in 0..c.len() { + let tmp: Simd = a[i].cast::() - b[i].cast() + borrow; + c[i] = tmp.cast().bitand(Simd::splat(MASK52)); + borrow = tmp >> 52 + } + + c +} + +#[inline(always)] +pub fn reduce_ct(a: [u64; 4]) -> [u64; 4] { + let b = [[0_u64; 4], U64_2P]; + let msb = (a[3] >> 63) & 1; + sub(a, b[msb as usize]) +} + +#[inline(always)] +pub fn sub(a: [u64; N], b: [u64; N]) -> [u64; N] { + let mut borrow: i128 = 0; + let mut c = [0; N]; + for i in 0..N { + let tmp = a[i] as i128 - b[i] as i128 + borrow as i128; + c[i] = tmp as u64; + borrow = tmp >> 64 + } + c +} + #[inline(always)] fn addv_simd( mut va: [Simd; N], @@ -374,21 +698,7 @@ fn addv_simd( } va } - -#[inline(always)] -fn resolve_simd_add_truncate(s: [Simd; 6], mp: [Simd; 6]) -> [Simd; 5] { - let mut out = [Simd::splat(0); 5]; - let mut carry = (s[0] + mp[0]) >> 52; - for i in 0..5 { - let tmp = s[i + 1] + mp[i + 1] + carry; - out[i] = tmp.bitand(Simd::splat(MASK52)); - carry = tmp >> 52; - } - out -} - // 
------------------------------------------------------------------------------------------------- - #[inline(always)] fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { let c: u128 = a as u128 * b as u128 + carry as u128 + add as u128; @@ -397,7 +707,7 @@ fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { #[cfg(test)] mod tests { - use crate::{block_multiplier, constants, rtz::RTZ}; + use crate::{block_mul, scalar_mul, constants, rtz::RTZ}; use primitive_types::U256; use rand::{Rng, SeedableRng, rngs}; @@ -409,18 +719,18 @@ mod tests { ]; fn mod_mul(a: U256, b: U256) -> U256 { - let p = U256(constants::P); + let p = U256(constants::U64_P); let mut c = [0u64; 4]; c.copy_from_slice(&(a.full_mul(b) % p).0[0..4]); U256(c) } #[test] - fn test_block_multiplier() { + fn test_block_mul() { let mut rng = rngs::StdRng::seed_from_u64(0); - let p = U256(constants::P); - let r = U256(constants::R); - let r_inv = U256(constants::R_INV); + let p = U256(constants::U64_P); + let r = U256(constants::U64_R); + let r_inv = U256(constants::U64_R_INV); let mut s0_a_bytes = [0u8; 32]; let mut s0_b_bytes = [0u8; 32]; @@ -451,7 +761,7 @@ mod tests { let v1_a_mont = mod_mul(v1_a, r); let v1_b_mont = mod_mul(v1_b, r); - let (s0, v0, v1) = block_multiplier( + let (s0, v0, v1) = block_mul( &rtz, s0_a_mont.0, s0_b_mont.0, @@ -468,4 +778,33 @@ mod tests { assert_eq!(mod_mul(U256(v1), r_inv), mod_mul(v1_a, v1_b)); } } -} \ No newline at end of file + + #[test] + fn test_scalar_mul() { + let mut rng = rngs::StdRng::seed_from_u64(0); + let p = U256(constants::U64_P); + let r = U256(constants::U64_R); + let r_inv = U256(constants::U64_R_INV); + + let mut s0_a_bytes = [0u8; 32]; + let mut s0_b_bytes = [0u8; 32]; + + let _rtz = RTZ::set().unwrap(); + + for _ in 0..100000 { + rng.fill(&mut s0_a_bytes); + rng.fill(&mut s0_b_bytes); + let s0_a = U256::from_little_endian(&s0_a_bytes) % p; + let s0_b = U256::from_little_endian(&s0_b_bytes) % p; + let s0_a_mont 
= mod_mul(s0_a, r); + let s0_b_mont = mod_mul(s0_b, r); + + let s0 = scalar_mul( + s0_a_mont.0, + s0_b_mont.0, + ); + assert!(U256(s0) < U256(OUTPUT_MAX)); + assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_b)); + } + } +} From 8a52aa7da086691bbe37addee832c1f4913ed68e Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Thu, 10 Apr 2025 05:16:57 +0100 Subject: [PATCH 4/8] super naive squaring, but good performance! --- block-multiplier/benches/bench.rs | 28 ++ block-multiplier/src/lib.rs | 552 +++++++++++++++++++++++++++++- 2 files changed, 579 insertions(+), 1 deletion(-) diff --git a/block-multiplier/benches/bench.rs b/block-multiplier/benches/bench.rs index 32ed1bdeb..524c9130e 100644 --- a/block-multiplier/benches/bench.rs +++ b/block-multiplier/benches/bench.rs @@ -58,6 +58,23 @@ fn bench_block_multiplier(c: &mut Criterion) { }) }); + group.bench_function("scalar_sqr", |bencher| { + bencher.iter(|| { + block_multiplier::scalar_sqr( + black_box(s0_a), + ) + }) + }); + + group.bench_function("simd_sqr", |bencher| { + bencher.iter(|| { + block_multiplier::simd_sqr( + black_box(v0_a), + black_box(v1_a), + ) + }) + }); + group.bench_function("simd_mul", |bencher| { bencher.iter(|| { block_multiplier::simd_mul( @@ -83,6 +100,17 @@ fn bench_block_multiplier(c: &mut Criterion) { }) }); + group.bench_function("block_sqr", |bencher| { + bencher.iter(|| { + block_multiplier::block_sqr( + &rtz, + black_box(s0_a), + black_box(v0_a), + black_box(v1_a), + ) + }) + }); + group.finish(); } diff --git a/block-multiplier/src/lib.rs b/block-multiplier/src/lib.rs index 832254251..aec75a99a 100644 --- a/block-multiplier/src/lib.rs +++ b/block-multiplier/src/lib.rs @@ -51,6 +51,70 @@ macro_rules! 
subarray { }; } +#[inline(always)] +pub fn scalar_sqr( + a: [u64; 4], +) -> [u64; 4] { + // -- [SCALAR] --------------------------------------------------------------------------------- + let mut t = [0_u64; 8]; + + let mut carry = 0; + (t[0], carry) = carrying_mul_add(a[0], a[0], t[0], carry); + (t[1], carry) = carrying_mul_add(a[0], a[1], t[1], carry); + (t[2], carry) = carrying_mul_add(a[0], a[2], t[2], carry); + (t[3], carry) = carrying_mul_add(a[0], a[3], t[3], carry); + t[4] = carry; + carry = 0; + (t[1], carry) = carrying_mul_add(a[1], a[0], t[1], carry); + (t[2], carry) = carrying_mul_add(a[1], a[1], t[2], carry); + (t[3], carry) = carrying_mul_add(a[1], a[2], t[3], carry); + (t[4], carry) = carrying_mul_add(a[1], a[3], t[4], carry); + t[5] = carry; + carry = 0; + (t[2], carry) = carrying_mul_add(a[2], a[0], t[2], carry); + (t[3], carry) = carrying_mul_add(a[2], a[1], t[3], carry); + (t[4], carry) = carrying_mul_add(a[2], a[2], t[4], carry); + (t[5], carry) = carrying_mul_add(a[2], a[3], t[5], carry); + t[6] = carry; + carry = 0; + (t[3], carry) = carrying_mul_add(a[3], a[0], t[3], carry); + (t[4], carry) = carrying_mul_add(a[3], a[1], t[4], carry); + (t[5], carry) = carrying_mul_add(a[3], a[2], t[5], carry); + (t[6], carry) = carrying_mul_add(a[3], a[3], t[6], carry); + t[7] = carry; + + let mut s_r1 = [0_u64; 5]; + (s_r1[0], s_r1[1]) = carrying_mul_add(t[0], U64_I3[0], 0, 0); + (s_r1[1], s_r1[2]) = carrying_mul_add(t[0], U64_I3[1], s_r1[1], 0); + (s_r1[2], s_r1[3]) = carrying_mul_add(t[0], U64_I3[2], s_r1[2], 0); + (s_r1[3], s_r1[4]) = carrying_mul_add(t[0], U64_I3[3], s_r1[3], 0); + + let mut s_r2 = [0_u64; 5]; + (s_r2[0], s_r2[1]) = carrying_mul_add(t[1], U64_I2[0], 0, 0); + (s_r2[1], s_r2[2]) = carrying_mul_add(t[1], U64_I2[1], s_r2[1], 0); + (s_r2[2], s_r2[3]) = carrying_mul_add(t[1], U64_I2[2], s_r2[2], 0); + (s_r2[3], s_r2[4]) = carrying_mul_add(t[1], U64_I2[3], s_r2[3], 0); + + let mut s_r3 = [0_u64; 5]; + (s_r3[0], s_r3[1]) = 
carrying_mul_add(t[2], U64_I1[0], 0, 0); + (s_r3[1], s_r3[2]) = carrying_mul_add(t[2], U64_I1[1], s_r3[1], 0); + (s_r3[2], s_r3[3]) = carrying_mul_add(t[2], U64_I1[2], s_r3[2], 0); + (s_r3[3], s_r3[4]) = carrying_mul_add(t[2], U64_I1[3], s_r3[3], 0); + + let s = addv(addv(subarray!(t, 3, 5), s_r1), addv(s_r2, s_r3)); + + let m = U64_MU0.wrapping_mul(s[0]); + let mut mp = [0_u64; 5]; + (mp[0], mp[1]) = carrying_mul_add(m, U64_P[0], mp[0], 0); + (mp[1], mp[2]) = carrying_mul_add(m, U64_P[1], mp[1], 0); + (mp[2], mp[3]) = carrying_mul_add(m, U64_P[2], mp[2], 0); + (mp[3], mp[4]) = carrying_mul_add(m, U64_P[3], mp[3], 0); + + let r = reduce_ct(subarray!(addv(s, mp), 1, 4)); + // --------------------------------------------------------------------------------------------- + r +} + #[inline(always)] pub fn scalar_mul( a: [u64; 4], @@ -116,6 +180,189 @@ pub fn scalar_mul( r } + +#[inline(always)] +pub fn simd_sqr( + v0_a: [u64; 4], + v1_a: [u64; 4], +) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 0 + 1] += p_hi.to_bits(); + t[0 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - 
p_hi); + t[0 + 1 + 1] += p_hi.to_bits(); + t[0 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 2 + 1] += p_hi.to_bits(); + t[0 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 3 + 1] += p_hi.to_bits(); + t[0 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 4 + 1] += p_hi.to_bits(); + t[0 + 4] += p_lo.to_bits(); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 0 + 1] += p_hi.to_bits(); + t[1 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits(); + t[1 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits(); + t[1 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits(); + t[1 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits(); + t[1 + 4] += 
p_lo.to_bits(); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 0 + 1] += p_hi.to_bits(); + t[2 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1 + 1] += p_hi.to_bits(); + t[2 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits(); + t[2 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits(); + t[2 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits(); + t[2 + 4] += p_lo.to_bits(); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 0 + 1] += p_hi.to_bits(); + t[3 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1 + 1] += p_hi.to_bits(); + t[3 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 2 + 1] += p_hi.to_bits(); + 
t[3 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits(); + t[3 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits(); + t[3 + 4] += p_lo.to_bits(); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 0 + 1] += p_hi.to_bits(); + t[4 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1 + 1] += p_hi.to_bits(); + t[4 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 2 + 1] += p_hi.to_bits(); + t[4 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 3 + 1] += p_hi.to_bits(); + t[4 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits(); + t[4 + 4] += p_lo.to_bits(); + + t[1] += t[0] >> 52; + t[2] += t[1] >> 52; + t[3] += t[2] >> 52; + t[4] += t[3] >> 52; + + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); + let r2 = 
smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let mp = smult_noinit_simd(m, U52_P); + + let reduced = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = u260_to_u256_simd(reduced); + let v = transpose_simd_to_u256(u256_result); + (v[0], v[1]) +} + #[inline(always)] pub fn simd_mul( v0_a: [u64; 4], @@ -301,6 +548,246 @@ pub fn simd_mul( (v[0], v[1]) } + +#[inline(always)] +pub fn block_sqr( + _rtz: &RTZ, // Proof that the mode has been set to RTZ + s0_a: [u64; 4], + v0_a: [u64; 4], + v1_a: [u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + // -- [SCALAR AB MULT] -------------------------------------------------------------------------------------------- + let mut s_t = [0_u64; 8]; + let mut carry = 0; + (s_t[0], carry) = carrying_mul_add(s0_a[0], s0_a[0], s_t[0], carry); + (s_t[1], carry) = carrying_mul_add(s0_a[0], s0_a[1], s_t[1], carry); + (s_t[2], carry) = carrying_mul_add(s0_a[0], s0_a[2], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[0], s0_a[3], s_t[3], carry); + s_t[4] = carry; + carry = 0; + (s_t[1], carry) = carrying_mul_add(s0_a[1], s0_a[0], s_t[1], carry); + (s_t[2], carry) = carrying_mul_add(s0_a[1], s0_a[1], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[1], s0_a[2], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[1], s0_a[3], s_t[4], carry); + s_t[5] = carry; + carry = 0; + (s_t[2], carry) = carrying_mul_add(s0_a[2], s0_a[0], s_t[2], carry); + (s_t[3], carry) = carrying_mul_add(s0_a[2], s0_a[1], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[2], s0_a[2], s_t[4], carry); + (s_t[5], carry) = 
carrying_mul_add(s0_a[2], s0_a[3], s_t[5], carry); + s_t[6] = carry; + carry = 0; + (s_t[3], carry) = carrying_mul_add(s0_a[3], s0_a[0], s_t[3], carry); + (s_t[4], carry) = carrying_mul_add(s0_a[3], s0_a[1], s_t[4], carry); + (s_t[5], carry) = carrying_mul_add(s0_a[3], s0_a[2], s_t[5], carry); + (s_t[6], carry) = carrying_mul_add(s0_a[3], s0_a[3], s_t[6], carry); + s_t[7] = carry; + // ---------------------------------------------------------------------------------------------------------------- + // -- [VECTOR AB MULT] -------------------------------------------------------------------------------------------- + let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 0 + 1] += p_hi.to_bits(); + t[0 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 1 + 1] += p_hi.to_bits(); + t[0 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 2 + 1] += p_hi.to_bits(); + t[0 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { 
vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 3 + 1] += p_hi.to_bits(); + t[0 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[0 + 4 + 1] += p_hi.to_bits(); + t[0 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 0 + 1] += p_hi.to_bits(); + t[1 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits(); + t[1 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits(); + t[1 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits(); + t[1 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits(); + t[1 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 0 + 1] += p_hi.to_bits(); + t[2 + 0] += p_lo.to_bits(); + let bvj: Simd = 
unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1 + 1] += p_hi.to_bits(); + t[2 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits(); + t[2 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits(); + t[2 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits(); + t[2 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 0 + 1] += p_hi.to_bits(); + t[3 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1 + 1] += p_hi.to_bits(); + t[3 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 2 + 1] += p_hi.to_bits(); + t[3 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits(); + t[3 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi 
= avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits(); + t[3 + 4] += p_lo.to_bits(); + let avi: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[0].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 0 + 1] += p_hi.to_bits(); + t[4 + 0] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[1].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1 + 1] += p_hi.to_bits(); + t[4 + 1] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[2].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 2 + 1] += p_hi.to_bits(); + t[4 + 2] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[3].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 3 + 1] += p_hi.to_bits(); + t[4 + 3] += p_lo.to_bits(); + let bvj: Simd = unsafe { vcvtq_f64_u64(v0_a[4].into()).into() }; + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits(); + t[4 + 4] += p_lo.to_bits(); + // ---------------------------------------------------------------------------------------------------------------- + // -- [VECTOR REDUCE] --------------------------------------------------------------------------------------------- + t[1] += t[0] >> 52; + t[2] += t[1] >> 52; + t[3] += t[2] >> 52; + t[4] += t[3] >> 52; + + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); + let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), 
RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + // ---------------------------------------------------------------------------------------------------------------- + // -- [SCALAR REDUCE] --------------------------------------------------------------------------------------------- + let mut s_r1 = [0_u64; 5]; + (s_r1[0], s_r1[1]) = carrying_mul_add(s_t[0], U64_I3[0], 0, 0); + (s_r1[1], s_r1[2]) = carrying_mul_add(s_t[0], U64_I3[1], s_r1[1], 0); + (s_r1[2], s_r1[3]) = carrying_mul_add(s_t[0], U64_I3[2], s_r1[2], 0); + (s_r1[3], s_r1[4]) = carrying_mul_add(s_t[0], U64_I3[3], s_r1[3], 0); + + let mut s_r2 = [0_u64; 5]; + (s_r2[0], s_r2[1]) = carrying_mul_add(s_t[1], U64_I2[0], 0, 0); + (s_r2[1], s_r2[2]) = carrying_mul_add(s_t[1], U64_I2[1], s_r2[1], 0); + (s_r2[2], s_r2[3]) = carrying_mul_add(s_t[1], U64_I2[2], s_r2[2], 0); + (s_r2[3], s_r2[4]) = carrying_mul_add(s_t[1], U64_I2[3], s_r2[3], 0); + + let mut s_r3 = [0_u64; 5]; + (s_r3[0], s_r3[1]) = carrying_mul_add(s_t[2], U64_I1[0], 0, 0); + (s_r3[1], s_r3[2]) = carrying_mul_add(s_t[2], U64_I1[1], s_r3[1], 0); + (s_r3[2], s_r3[3]) = carrying_mul_add(s_t[2], U64_I1[2], s_r3[2], 0); + (s_r3[3], s_r3[4]) = carrying_mul_add(s_t[2], U64_I1[3], s_r3[3], 0); + + let s_s = addv(addv(subarray!(s_t, 3, 5), s_r1), addv(s_r2, s_r3)); + // ---------------------------------------------------------------------------------------------------------------- + // -- [FINAL] ----------------------------------------------------------------------------------------------------- + let s_m = U64_MU0.wrapping_mul(s_s[0]); + let mut s_mp = [0_u64; 5]; + (s_mp[0], s_mp[1]) = carrying_mul_add(s_m, U64_P[0], 0, 0); + (s_mp[1], s_mp[2]) = carrying_mul_add(s_m, U64_P[1], s_mp[1], 0); + (s_mp[2], s_mp[3]) = carrying_mul_add(s_m, 
U64_P[2], s_mp[2], 0); + (s_mp[3], s_mp[4]) = carrying_mul_add(s_m, U64_P[3], s_mp[3], 0); + let s0 = reduce_ct(subarray!(addv(s_s, s_mp), 1, 4)); + + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let mp = smult_noinit_simd(m, U52_P); + let resolve = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = u260_to_u256_simd(resolve); + let v = transpose_simd_to_u256(u256_result); + // ---------------------------------------------------------------------------------------------------------------- + (s0, v[0], v[1]) +} + #[inline(always)] pub fn block_mul( _rtz: &RTZ, // Proof that the mode has been set to RTZ @@ -707,7 +1194,7 @@ fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { #[cfg(test)] mod tests { - use crate::{block_mul, scalar_mul, constants, rtz::RTZ}; + use crate::{block_mul, block_sqr, scalar_mul, scalar_sqr, constants, rtz::RTZ}; use primitive_types::U256; use rand::{Rng, SeedableRng, rngs}; @@ -779,6 +1266,45 @@ mod tests { } } + #[test] + fn test_block_sqr() { + let mut rng = rngs::StdRng::seed_from_u64(0); + let p = U256(constants::U64_P); + let r = U256(constants::U64_R); + let r_inv = U256(constants::U64_R_INV); + + let mut s0_a_bytes = [0u8; 32]; + let mut v0_a_bytes = [0u8; 32]; + let mut v1_a_bytes = [0u8; 32]; + + let rtz = RTZ::set().unwrap(); + + for _ in 0..100000 { + rng.fill(&mut s0_a_bytes); + rng.fill(&mut v0_a_bytes); + rng.fill(&mut v1_a_bytes); + let s0_a = U256::from_little_endian(&s0_a_bytes) % p; + let v0_a = U256::from_little_endian(&v0_a_bytes) % p; + let v1_a = U256::from_little_endian(&v1_a_bytes) % p; + let s0_a_mont = mod_mul(s0_a, r); + let v0_a_mont = mod_mul(v0_a, r); + let v1_a_mont = mod_mul(v1_a, r); + + let (s0, v0, v1) = block_sqr( + &rtz, + s0_a_mont.0, + v0_a_mont.0, + v1_a_mont.0, + ); + assert!(U256(s0) < U256(OUTPUT_MAX)); + assert!(U256(v0) < U256(OUTPUT_MAX)); + assert!(U256(v1) < U256(OUTPUT_MAX)); + assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_a)); + 
assert_eq!(mod_mul(U256(v0), r_inv), mod_mul(v0_a, v0_a)); + assert_eq!(mod_mul(U256(v1), r_inv), mod_mul(v1_a, v1_a)); + } + } + #[test] fn test_scalar_mul() { let mut rng = rngs::StdRng::seed_from_u64(0); @@ -807,4 +1333,28 @@ mod tests { assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_b)); } } + + #[test] + fn test_scalar_sqr() { + let mut rng = rngs::StdRng::seed_from_u64(0); + let p = U256(constants::U64_P); + let r = U256(constants::U64_R); + let r_inv = U256(constants::U64_R_INV); + + let mut s0_a_bytes = [0u8; 32]; + + let _rtz = RTZ::set().unwrap(); + + for _ in 0..100000 { + rng.fill(&mut s0_a_bytes); + let s0_a = U256::from_little_endian(&s0_a_bytes) % p; + let s0_a_mont = mod_mul(s0_a, r); + + let s0 = scalar_sqr( + s0_a_mont.0, + ); + assert!(U256(s0) < U256(OUTPUT_MAX)); + assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_a)); + } + } } From 094ff743756fb20bb931bb41488636a983bf4b74 Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Wed, 16 Apr 2025 06:26:29 +0100 Subject: [PATCH 5/8] add inlining for top level block-multiplier functions --- block-multiplier/src/lib.rs | 58 ++++++++++++------------------------- 1 file changed, 19 insertions(+), 39 deletions(-) diff --git a/block-multiplier/src/lib.rs b/block-multiplier/src/lib.rs index aec75a99a..f82747e28 100644 --- a/block-multiplier/src/lib.rs +++ b/block-multiplier/src/lib.rs @@ -8,12 +8,12 @@ use crate::constants::*; use rtz::RTZ; use seq_macro::seq; use std::arch::aarch64::vcvtq_f64_u64; +use std::array; use std::ops::BitAnd; -use std::simd::{Simd, StdFloat, num::SimdFloat}; use std::simd::cmp::SimdPartialEq; -use std::simd::num::SimdUint; use std::simd::num::SimdInt; -use std::array; +use std::simd::num::SimdUint; +use std::simd::{Simd, StdFloat, num::SimdFloat}; /// Macro to extract a subarray from an array. /// @@ -51,10 +51,8 @@ macro_rules! 
subarray { }; } -#[inline(always)] -pub fn scalar_sqr( - a: [u64; 4], -) -> [u64; 4] { +#[inline] +pub fn scalar_sqr(a: [u64; 4]) -> [u64; 4] { // -- [SCALAR] --------------------------------------------------------------------------------- let mut t = [0_u64; 8]; @@ -102,7 +100,7 @@ pub fn scalar_sqr( (s_r3[3], s_r3[4]) = carrying_mul_add(t[2], U64_I1[3], s_r3[3], 0); let s = addv(addv(subarray!(t, 3, 5), s_r1), addv(s_r2, s_r3)); - + let m = U64_MU0.wrapping_mul(s[0]); let mut mp = [0_u64; 5]; (mp[0], mp[1]) = carrying_mul_add(m, U64_P[0], mp[0], 0); @@ -115,11 +113,8 @@ pub fn scalar_sqr( r } -#[inline(always)] -pub fn scalar_mul( - a: [u64; 4], - b: [u64; 4] -) -> [u64; 4] { +#[inline] +pub fn scalar_mul(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { // -- [SCALAR] --------------------------------------------------------------------------------- let mut t = [0_u64; 8]; @@ -167,7 +162,7 @@ pub fn scalar_mul( (s_r3[3], s_r3[4]) = carrying_mul_add(t[2], U64_I1[3], s_r3[3], 0); let s = addv(addv(subarray!(t, 3, 5), s_r1), addv(s_r2, s_r3)); - + let m = U64_MU0.wrapping_mul(s[0]); let mut mp = [0_u64; 5]; (mp[0], mp[1]) = carrying_mul_add(m, U64_P[0], mp[0], 0); @@ -180,12 +175,8 @@ pub fn scalar_mul( r } - -#[inline(always)] -pub fn simd_sqr( - v0_a: [u64; 4], - v1_a: [u64; 4], -) -> ([u64; 4], [u64; 4]) { +#[inline] +pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); let mut t: [Simd; 10] = [Simd::splat(0); 10]; @@ -363,12 +354,12 @@ pub fn simd_sqr( (v[0], v[1]) } -#[inline(always)] +#[inline] pub fn simd_mul( v0_a: [u64; 4], v0_b: [u64; 4], v1_a: [u64; 4], - v1_b: [u64; 4] + v1_b: [u64; 4], ) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b])); @@ -548,8 +539,7 @@ pub fn simd_mul( (v[0], v[1]) } - -#[inline(always)] +#[inline] pub fn block_sqr( _rtz: &RTZ, 
// Proof that the mode has been set to RTZ s0_a: [u64; 4], @@ -788,7 +778,7 @@ pub fn block_sqr( (s0, v[0], v[1]) } -#[inline(always)] +#[inline] pub fn block_mul( _rtz: &RTZ, // Proof that the mode has been set to RTZ s0_a: [u64; 4], @@ -1194,7 +1184,7 @@ fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { #[cfg(test)] mod tests { - use crate::{block_mul, block_sqr, scalar_mul, scalar_sqr, constants, rtz::RTZ}; + use crate::{block_mul, block_sqr, constants, rtz::RTZ, scalar_mul, scalar_sqr}; use primitive_types::U256; use rand::{Rng, SeedableRng, rngs}; @@ -1290,12 +1280,7 @@ mod tests { let v0_a_mont = mod_mul(v0_a, r); let v1_a_mont = mod_mul(v1_a, r); - let (s0, v0, v1) = block_sqr( - &rtz, - s0_a_mont.0, - v0_a_mont.0, - v1_a_mont.0, - ); + let (s0, v0, v1) = block_sqr(&rtz, s0_a_mont.0, v0_a_mont.0, v1_a_mont.0); assert!(U256(s0) < U256(OUTPUT_MAX)); assert!(U256(v0) < U256(OUTPUT_MAX)); assert!(U256(v1) < U256(OUTPUT_MAX)); @@ -1325,10 +1310,7 @@ mod tests { let s0_a_mont = mod_mul(s0_a, r); let s0_b_mont = mod_mul(s0_b, r); - let s0 = scalar_mul( - s0_a_mont.0, - s0_b_mont.0, - ); + let s0 = scalar_mul(s0_a_mont.0, s0_b_mont.0); assert!(U256(s0) < U256(OUTPUT_MAX)); assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_b)); } @@ -1350,9 +1332,7 @@ mod tests { let s0_a = U256::from_little_endian(&s0_a_bytes) % p; let s0_a_mont = mod_mul(s0_a, r); - let s0 = scalar_sqr( - s0_a_mont.0, - ); + let s0 = scalar_sqr(s0_a_mont.0); assert!(U256(s0) < U256(OUTPUT_MAX)); assert_eq!(mod_mul(U256(s0), r_inv), mod_mul(s0_a, s0_a)); } From b563f887ad471abfc0c357a874452ae14604a70e Mon Sep 17 00:00:00 2001 From: Tony Wu Date: Wed, 16 Apr 2025 06:27:45 +0100 Subject: [PATCH 6/8] Initial block-skyscrapper --- Cargo.toml | 2 +- skyscraper/Cargo.toml | 21 ++ skyscraper/benches/bench.rs | 76 +++++ skyscraper/src/lib.rs | 630 ++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 1 deletion(-) create mode 100644 skyscraper/Cargo.toml create 
mode 100644 skyscraper/benches/bench.rs create mode 100644 skyscraper/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 2636de644..9dc10935c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["block-multiplier", "noir-r1cs", "delegated-spartan", "merkle-hash-bench", "prover"] +members = ["block-multiplier", "noir-r1cs", "delegated-spartan", "merkle-hash-bench", "prover", "skyscraper"] [workspace.package] edition = "2021" diff --git a/skyscraper/Cargo.toml b/skyscraper/Cargo.toml new file mode 100644 index 000000000..3dfbeb913 --- /dev/null +++ b/skyscraper/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "skyscraper" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +block-multiplier = { path = "../block-multiplier" } + +[dev-dependencies] +rand = "0.9.0" +primitive-types = "0.13.1" +criterion = "0.5.1" + +[[bench]] +name = "bench" +harness = false \ No newline at end of file diff --git a/skyscraper/benches/bench.rs b/skyscraper/benches/bench.rs new file mode 100644 index 000000000..c4de5fae2 --- /dev/null +++ b/skyscraper/benches/bench.rs @@ -0,0 +1,76 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; + +fn bench_skyscraper(c: &mut Criterion) { + let mut group = c.benchmark_group("skyscraper"); + let seed: u64 = rand::random(); + println!("Using random seed for benchmark: {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let l_0 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let l_1 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let l_2 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let r_0 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + 
]; + let r_1 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let r_2 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + + let _rtz = block_multiplier::rtz::RTZ::set().unwrap(); + + group.bench_function("compress", |bencher| { + bencher.iter(|| skyscraper::compress(black_box(l_0), black_box(r_0))) + }); + + group.bench_function("block_compress", |bencher| { + bencher.iter(|| skyscraper::block_compress( + black_box(&_rtz), + black_box(l_0), + black_box(l_1), + black_box(l_2), + black_box(r_0), + black_box(r_1), + black_box(r_2) + )) + }); +} + +criterion_group!( + name = benches; + config = Criterion::default() + .sample_size(5000) + // Warm up is warm because it literally warms up the pi + .warm_up_time(std::time::Duration::new(1,0)) + .measurement_time(std::time::Duration::new(10,0)); + targets = bench_skyscraper +); +criterion_main!(benches); diff --git a/skyscraper/src/lib.rs b/skyscraper/src/lib.rs new file mode 100644 index 000000000..067f24d80 --- /dev/null +++ b/skyscraper/src/lib.rs @@ -0,0 +1,630 @@ +#![feature(bigint_helper_methods)] +use block_multiplier::*; +use block_multiplier::rtz::RTZ; + +pub const U64_P: [u64; 4] = [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, +]; +pub const U64_2P: [u64; 4] = [ + 0x87c3eb27e0000002, + 0x5067d090f372e122, + 0x70a08b6d0302b0ba, + 0x60c89ce5c2634053, +]; +pub const U64_3P: [u64; 4] = [ + 0xcba5e0bbd0000003, + 0x789bb8d96d2c51b3, + 0x28f0d12384840917, + 0x912ceb58a394e07d, +]; +pub const U64_4P: [u64; 4] = [ + 0x0f87d64fc0000004, + 0xa0cfa121e6e5c245, + 0xe14116da06056174, + 0xc19139cb84c680a6, +]; +pub const U64_5P: [u64; 4] = [ + 0x5369cbe3b0000005, + 0xc903896a609f32d6, + 0x99915c908786b9d1, + 0xf1f5883e65f820d0, +]; + +pub const RC: [[u64; 4]; 8] = [ + [ + 0x903c4324270bd744, + 0x873125f708a7d269, + 0x081dd27906c83855, + 0x276b1823ea6d7667, + ], + [ + 0x7ac8edbb4b378d71, + 0xe29d79f3d99e2cb7, + 
0x751417914c1a5a18,
+        0x0cf02bd758a484a6,
+    ],
+    [
+        0xfa7adc6769e5bc36,
+        0x1c3f8e297cca387d,
+        0x0eb7730d63481db0,
+        0x25b0e03f18ede544,
+    ],
+    [
+        0x57847e652f03cfb7,
+        0x33440b9668873404,
+        0x955a32e849af80bc,
+        0x002882fcbe14ae70,
+    ],
+    [
+        0x979231396257d4d7,
+        0x29989c3e1b37d3c1,
+        0x12ef02b47f1277ba,
+        0x039ad8571e2b7a9c,
+    ],
+    [
+        0xb5b48465abbb7887,
+        0xa72a6bc5e6ba2d2b,
+        0x4cd48043712f7b29,
+        0x1142d5410fc1fc1a,
+    ],
+    [
+        0x7ab2c156059075d3,
+        0x17cb3594047999b2,
+        0x44f2c93598f289f7,
+        0x1d78439f69bc0bec,
+    ],
+    [
+        0x05d7a965138b8edb,
+        0x36ef35a3d55c48b1,
+        0x8ddfb8a1ac6f1628,
+        0x258588a508f4ff82,
+    ],
+];
+
+/// Companion table to `RC`, indexed by the same `rc_idx`.
+///
+/// The round helpers subtract `_1P_MINUS_RC[i]` as an alternative to adding `RC[i]`.
+/// Going by the name, each row holds `P - RC[i]` (four `u64` limbs, `[0]` least
+/// significant). NOTE(review): confirm against the script that generated the table.
+pub const _1P_MINUS_RC: [[u64; 4]; 8] = [
+    [
+        0xb3a5b26fc8f428bd,
+        0xa102c25171119e27,
+        0xb032733d7ab92007,
+        0x08f9364ef6c429c2,
+    ],
+    [
+        0xc91907d8a4c87290,
+        0x45966e54a01b43d9,
+        0x433c2e253566fe44,
+        0x2374229b888d1b83,
+    ],
+    [
+        0x4967192c861a43cb,
+        0x0bf45a1efcef3813,
+        0xa998d2a91e393aad,
+        0x0ab36e33c843bae5,
+    ],
+    [
+        0xec5d772ec0fc304a,
+        0xf4efdcb211323c8c,
+        0x22f612ce37d1d7a0,
+        0x303bcb76231cf1b9,
+    ],
+    [
+        0xac4fc45a8da82b2a,
+        0xfe9b4c0a5e819ccf,
+        0xa5614302026ee0a2,
+        0x2cc9761bc306258d,
+    ],
+    [
+        0x8e2d712e4444877a,
+        0x81097c8292ff4365,
+        0x6b7bc5731051dd33,
+        0x1f217931d16fa40f,
+    ],
+    [
+        0xc92f343dea6f8a2e,
+        0x1068b2b4753fd6de,
+        0x735d7c80e88ece66,
+        0x12ec0ad37775943d,
+    ],
+    [
+        0x3e0a4c2edc747126,
+        0xf144b2a4a45d27e0,
+        0x2a708d14d5124234,
+        0x0adec5cdd83ca0a7,
+    ],
+];
+
+/// Same layout as `_1P_MINUS_RC`; going by the name, each row holds `2P - RC[i]`.
+/// Only the `bar` round helpers read it, when the top two bits of `bar` are `10`.
+pub const _2P_MINUS_RC: [[u64; 4]; 8] = [
+    [
+        0xf787a803b8f428be,
+        0xc936aa99eacb0eb8,
+        0x6882b8f3fc3a7864,
+        0x395d84c1d7f5c9ec,
+    ],
+    [
+        0x0cfafd6c94c87291,
+        0x6dca569d19d4b46b,
+        0xfb8c73dbb6e856a1,
+        0x53d8710e69bebbac,
+    ],
+    [
+        0x8d490ec0761a43cc,
+        0x3428426776a8a8a4,
+        0x61e9185f9fba930a,
+        0x3b17bca6a9755b0f,
+    ],
+    [
+        0x303f6cc2b0fc304b,
+        0x1d23c4fa8aebad1e,
+        0xdb465884b9532ffe,
+        0x60a019e9044e91e2,
+    ],
+    [
+        0xf031b9ee7da82b2b,
+        0x26cf3452d83b0d60,
+        0x5db188b883f03900,
+        0x5d2dc48ea437c5b7,
+    ],
+    [
+        0xd20f66c23444877b,
+        0xa93d64cb0cb8b3f6,
+        0x23cc0b2991d33590,
+        0x4f85c7a4b2a14439,
+    ],
+    [
+        0x0d1129d1da6f8a2f,
+        0x389c9afceef94770,
+        0x2badc2376a1026c3,
+        0x4350594658a73467,
+    ],
+    [
+        0x81ec41c2cc747127,
+        0x19789aed1e169871,
+        0xe2c0d2cb56939a92,
+        0x3b431440b96e40d0,
+    ],
+];
+
+/// Same layout as `_1P_MINUS_RC`; going by the name, each row holds `3P - RC[i]`.
+/// Only the `bar` round helpers read it, when the top three bits of `bar` are `110`.
+pub const _3P_MINUS_RC: [[u64; 4]; 8] = [
+    [
+        0x3b699d97a8f428bf,
+        0xf16a92e264847f4a,
+        0x20d2feaa7dbbd0c1,
+        0x69c1d334b9276a16,
+    ],
+    [
+        0x50dcf30084c87292,
+        0x95fe3ee5938e24fc,
+        0xb3dcb9923869aefe,
+        0x843cbf814af05bd6,
+    ],
+    [
+        0xd12b0454661a43cd,
+        0x5c5c2aaff0621935,
+        0x1a395e16213beb67,
+        0x6b7c0b198aa6fb39,
+    ],
+    [
+        0x74216256a0fc304c,
+        0x4557ad4304a51daf,
+        0x93969e3b3ad4885b,
+        0x9104685be580320c,
+    ],
+    [
+        0x3413af826da82b2c,
+        0x4f031c9b51f47df2,
+        0x1601ce6f0571915d,
+        0x8d921301856965e1,
+    ],
+    [
+        0x15f15c562444877c,
+        0xd1714d1386722488,
+        0xdc1c50e013548ded,
+        0x7fea161793d2e462,
+    ],
+    [
+        0x50f31f65ca6f8a30,
+        0x60d0834568b2b801,
+        0xe3fe07edeb917f20,
+        0x73b4a7b939d8d490,
+    ],
+    [
+        0xc5ce3756bc747128,
+        0x41ac833597d00902,
+        0x9b111881d814f2ef,
+        0x6ba762b39a9fe0fa,
+    ],
+];
+
+/// Same layout as `_1P_MINUS_RC`; going by the name, each row holds `4P - RC[i]`.
+/// Only the `bar` round helpers read it, when the top three bits of `bar` are `111`.
+pub const _4P_MINUS_RC: [[u64; 4]; 8] = [
+    [
+        0x7f4b932b98f428c0,
+        0x199e7b2ade3defdb,
+        0xd9234460ff3d291f,
+        0x9a2621a79a590a3f,
+    ],
+    [
+        0x94bee89474c87293,
+        0xbe32272e0d47958d,
+        0x6c2cff48b9eb075b,
+        0xb4a10df42c21fc00,
+    ],
+    [
+        0x150cf9e8561a43ce,
+        0x849012f86a1b89c7,
+        0xd289a3cca2bd43c4,
+        0x9be0598c6bd89b62,
+    ],
+    [
+        0xb80357ea90fc304d,
+        0x6d8b958b7e5e8e40,
+        0x4be6e3f1bc55e0b8,
+        0xc168b6cec6b1d236,
+    ],
+    [
+        0x77f5a5165da82b2d,
+        0x773704e3cbadee83,
+        0xce52142586f2e9ba,
+        0xbdf66174669b060a,
+    ],
+    [
+        0x59d351ea1444877d,
+        0xf9a5355c002b9519,
+        0x946c969694d5e64a,
+        0xb04e648a7504848c,
+    ],
+    [
+        0x94d514f9ba6f8a31,
+        0x89046b8de26c2892,
+        0x9c4e4da46d12d77d,
+        0xa418f62c1b0a74ba,
+    ],
+    [
+        0x09b02ceaac747129,
+        0x69e06b7e11897994,
+        0x53615e3859964b4c,
+        0x9c0bb1267bd18124,
+    ],
+];
+
+/// Compresses the pair `(l, r)` of 4-limb values into a single 4-limb value.
+///
+/// Runs ten Feistel-style rounds. Each round derives a value from the current left
+/// half (`scalar_sqr` in rounds 1-2, 5-6 and 9-10; `bar_u8` in rounds 3-4 and 7-8),
+/// adds it to the right half, and swaps the halves. Rounds 2 to 9 additionally mix
+/// in round constants 0 to 7. The last round adds the original `l` instead of a round
+/// constant, and only that final left half is returned.
+#[inline]
+pub fn compress(l: [u64; 4], r: [u64; 4]) -> [u64; 4] {
+    // Keep the original left input for the feed-forward in the last round.
+    let a = l;
+    // Round 1: plain wrapping addition, no round constant and no reduction.
+    let sqr = scalar_sqr(l);
+    let (l, r) = (wrapping_add(r, sqr), l);
+    let sqr = scalar_sqr(l);
+    let (l, r) = (x0p_plus_sqr3p_plus_rc_eq0p(r, sqr, 0), l);
+    let bar = bar_u8(l);
+    let (l, r) = (x2p_plus_bar0p_plus_rc_eq0p(r, bar, 1), l);
+    let bar = bar_u8(l);
+    let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 2), l);
+    let sqr = scalar_sqr(l);
+    let (l, r) = (x0p_plus_sqr2p_plus_rc_eq0p(r, sqr, 3), l);
+    let sqr = scalar_sqr(l);
+    let (l, r) = (x0p_plus_sqr1p_plus_rc_eq0p(r, sqr, 4), l);
+    let bar = bar_u8(l);
+    let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 5), l);
+    let bar = bar_u8(l);
+    let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 6), l);
+    let sqr = scalar_sqr(l);
+    let (l, r) = (x0p_plus_sqr1p_plus_rc_eq0p(r, sqr, 7), l);
+    // Round 10: no swap; add the saved input `a` and return the reduced sum.
+    let sqr = scalar_sqr(l);
+    x0p_plus_sqr1p_plus_y0p_eq0p(r, sqr, a)
+}
+
+/// Three-lane variant of `compress`: processes `(l_0, r_0)`, `(l_1, r_1)` and
+/// `(l_2, r_2)` through the same ten-round schedule and returns one result per lane.
+///
+/// The lanes never mix. The squarings of all three lanes are done by a single
+/// `block_sqr` call per round; every other step is applied lane by lane.
+/// `_rtz` is not used here other than being forwarded to `block_sqr`.
+#[inline]
+pub fn block_compress(
+    _rtz: &RTZ,
+    l_0: [u64; 4],
+    l_1: [u64; 4],
+    l_2: [u64; 4],
+    r_0: [u64; 4],
+    r_1: [u64; 4],
+    r_2: [u64; 4]
+) -> ([u64; 4], [u64; 4], [u64; 4]) {
+    // Keep the original left inputs for the feed-forward in the last round.
+    let a_0 = l_0;
+    let a_1 = l_1;
+    let a_2 = l_2;
+    // Round 1: plain wrapping addition, no round constant and no reduction.
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let (l_0, r_0) = (wrapping_add(r_0, sqr_0), l_0);
+    let (l_1, r_1) = (wrapping_add(r_1, sqr_1), l_1);
+    let (l_2, r_2) = (wrapping_add(r_2, sqr_2), l_2);
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let (l_0, r_0) = (x0p_plus_sqr3p_plus_rc_eq0p(r_0, sqr_0, 0), l_0);
+    let (l_1, r_1) = (x0p_plus_sqr3p_plus_rc_eq0p(r_1, sqr_1, 0), l_1);
+    let (l_2, r_2) = (x0p_plus_sqr3p_plus_rc_eq0p(r_2, sqr_2, 0), l_2);
+    let bar_0 = bar_u8(l_0);
+    let bar_1 = bar_u8(l_1);
+    let bar_2 = bar_u8(l_2);
+    let (l_0, r_0) = (x2p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 1), l_0);
+    let (l_1, r_1) = (x2p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 1), l_1);
+    let (l_2, r_2) = (x2p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 1), l_2);
+    let bar_0 = bar_u8(l_0);
+    let bar_1 = bar_u8(l_1);
+    let bar_2 = bar_u8(l_2);
+    let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 2), l_0);
+    let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 2), l_1);
+    let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 2), l_2);
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let (l_0, r_0) = (x0p_plus_sqr2p_plus_rc_eq0p(r_0, sqr_0, 3), l_0);
+    let (l_1, r_1) = (x0p_plus_sqr2p_plus_rc_eq0p(r_1, sqr_1, 3), l_1);
+    let (l_2, r_2) = (x0p_plus_sqr2p_plus_rc_eq0p(r_2, sqr_2, 3), l_2);
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let (l_0, r_0) = (x0p_plus_sqr1p_plus_rc_eq0p(r_0, sqr_0, 4), l_0);
+    let (l_1, r_1) = (x0p_plus_sqr1p_plus_rc_eq0p(r_1, sqr_1, 4), l_1);
+    let (l_2, r_2) = (x0p_plus_sqr1p_plus_rc_eq0p(r_2, sqr_2, 4), l_2);
+    let bar_0 = bar_u8(l_0);
+    let bar_1 = bar_u8(l_1);
+    let bar_2 = bar_u8(l_2);
+    let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 5), l_0);
+    let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 5), l_1);
+    let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 5), l_2);
+    let bar_0 = bar_u8(l_0);
+    let bar_1 = bar_u8(l_1);
+    let bar_2 = bar_u8(l_2);
+    let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 6), l_0);
+    let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 6), l_1);
+    let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 6), l_2);
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let (l_0, r_0) = (x0p_plus_sqr1p_plus_rc_eq0p(r_0, sqr_0, 7), l_0);
+    let (l_1, r_1) = (x0p_plus_sqr1p_plus_rc_eq0p(r_1, sqr_1, 7), l_1);
+    let (l_2, r_2) = (x0p_plus_sqr1p_plus_rc_eq0p(r_2, sqr_2, 7), l_2);
+    // Round 10: no swap; add the saved inputs and return the reduced sums.
+    let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2);
+    let l_0 = x0p_plus_sqr1p_plus_y0p_eq0p(r_0, sqr_0, a_0);
+    let l_1 = x0p_plus_sqr1p_plus_y0p_eq0p(r_1, sqr_1, a_1);
+    let l_2 = x0p_plus_sqr1p_plus_y0p_eq0p(r_2, sqr_2, a_2);
+    (l_0, l_1, l_2)
+}
+
+/// Adds two 256-bit values stored as four `u64` limbs and discards the final carry,
+/// i.e. the sum wraps modulo 2^256.
+///
+/// NOTE(review): the limbs are reinterpreted as two `u128` halves, pairing `x[0]`
+/// (low) with `x[1]` (high). That pairing is only correct on little-endian targets;
+/// confirm that big-endian targets are out of scope.
+#[inline(always)]
+fn wrapping_add(x: [u64; 4], y: [u64; 4]) -> [u64; 4] {
+    // SAFETY: `[u64; 4]` and `[u128; 2]` are both 32 bytes of plain integers for which
+    // every bit pattern is valid, and `transmute` moves by value, so alignment is not
+    // a concern.
+    let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) };
+    let y_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(y) };
+    let (lo, c) = x_u128[0].overflowing_add(y_u128[0]);
+    // The carry out of the high half is dropped on purpose (wrapping semantics).
+    let (hi, _) = x_u128[1].carrying_add(y_u128[1], c);
+    unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) }
+}
+
+/// Subtracts `y` from `x` as 256-bit values (four `u64` limbs) and discards the final
+/// borrow, i.e. the difference wraps modulo 2^256.
+///
+/// NOTE(review): same little-endian assumption as `wrapping_add`.
+#[inline(always)]
+fn wrapping_sub(x: [u64; 4], y: [u64; 4]) -> [u64; 4] {
+    // SAFETY: same-size plain-integer arrays, moved by value (see `wrapping_add`).
+    let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) };
+    let y_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(y) };
+    let (lo, b) = x_u128[0].overflowing_sub(y_u128[0]);
+    let (hi, _) = x_u128[1].borrowing_sub(y_u128[1], b);
+    unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) }
+}
+
+/// Like `wrapping_sub`, but also returns the final borrow: the flag is `true` when the
+/// subtraction underflowed, i.e. when `x < y` as 256-bit values.
+///
+/// NOTE(review): same little-endian assumption as `wrapping_add`.
+#[inline(always)]
+fn overflowing_sub(x: [u64; 4], y: [u64; 4]) -> ([u64; 4], bool) {
+    // SAFETY: same-size plain-integer arrays, moved by value (see `wrapping_add`).
+    let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) };
+    let y_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(y) };
+    let (lo, b) = x_u128[0].overflowing_sub(y_u128[0]);
+    let (hi, b) = x_u128[1].borrowing_sub(y_u128[1], b);
+    (
+        unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) },
+        b,
+    )
+}
+
+/// Subtracts `U64_P` once if that does not underflow; otherwise returns `x` unchanged.
+#[inline(always)]
+fn reduce_1p(x: [u64; 4]) -> [u64; 4] {
+    let (xr, c) = overflowing_sub(x, U64_P);
+    if c {
+        // Borrow: `x` was already below `U64_P`.
+        x
+    } else {
+        xr
+    }
+}
+
+/// Reduces `x` by a multiple of the modulus chosen from the top two bits of `x[3]`.
+///
+/// Bit 255 set: subtract `U64_2P` and return the difference as is. Otherwise, bit 254
+/// set: subtract `U64_P`, then apply `reduce_1p`. Otherwise: `reduce_1p` only.
+/// NOTE(review): the bit-255 branch skips the final `reduce_1p`, so the caller
+/// presumably guarantees an upper bound on `x` — confirm the intended input range.
+#[inline(always)]
+fn reduce_2p(x: [u64; 4]) -> [u64; 4] {
+    let msb0 = (x[3] >> 63) != 0;
+    let msb1 = ((x[3] << 1) >> 63) != 0;
+    if msb0 {
+        wrapping_sub(x, U64_2P)
+    } else if msb1 {
+        reduce_1p(wrapping_sub(x, U64_P))
+    } else {
+        reduce_1p(x)
+    }
+}
+
+/// Same bit-driven selection as `reduce_2p`, except that the bit-255 branch also
+/// finishes with `reduce_1p` after subtracting `U64_2P`.
+#[inline(always)]
+fn reduce_3p(x: [u64; 4]) -> [u64; 4] {
+    let msb0 = (x[3] >> 63) != 0;
+    let msb1 = ((x[3] << 1) >> 63) != 0;
+    if msb0 {
+        reduce_1p(wrapping_sub(x, U64_2P))
+    } else if msb1 {
+        reduce_1p(wrapping_sub(x, U64_P))
+    } else {
+        reduce_1p(x)
+    }
+}
+
+/// Reads the top two bits of `x[3]` as a number 0..=3, subtracts that many multiples
+/// of the modulus (`U64_P`, `U64_2P` or `U64_3P`; nothing for 0), then applies
+/// `reduce_1p`.
+#[inline(always)]
+fn reduce_4p(x: [u64; 4]) -> [u64; 4] {
+    let msb = (x[3] >> 62) as u8;
+    if msb == 0 {
+        reduce_1p(x)
+    } else {
+        let r = if msb == 1 {
+            U64_P
+        } else if msb == 2 {
+            U64_2P
+        } else {
+            U64_3P
+        };
+        reduce_1p(wrapping_sub(x, r))
+    }
+}
+
+/// Applies a fixed bitwise substitution to each of the 32 bytes of `x` independently,
+/// then swaps the two 128-bit halves (limbs `[2], [3], [0], [1]`).
+///
+/// Per byte `v`: `v ^ (rotl(!v, 1) & rotl(v, 2) & rotl(v, 3))`, rotated left by 1.
+#[inline(always)]
+fn bar_u8(x: [u64; 4]) -> [u64; 4] {
+    // SAFETY: `[u64; 4]` and `[u8; 32]` are both 32 bytes of plain integers and the
+    // value is moved, not referenced. Every byte is replaced in place, so the byte
+    // order inside a limb does not influence the result.
+    let mut x_u8 = unsafe { std::mem::transmute::<[u64; 4], [u8; 32]>(x) };
+    for i in 0..32 {
+        let v = x_u8[i];
+        x_u8[i] = (v ^ ((!v).rotate_left(1) & v.rotate_left(2) & v.rotate_left(3))).rotate_left(1);
+    }
+    let x = unsafe { std::mem::transmute::<[u8; 32], [u64; 4]>(x_u8) };
+    [x[2], x[3], x[0], x[1]]
+}
+
+/// Round step for a squaring round: combines `x`, `sqr` and round constant `rc_idx`.
+///
+/// Computes `x + sqr`, then tries to subtract `_1P_MINUS_RC[rc_idx]`. On borrow it
+/// returns `x + sqr + RC[rc_idx]` without further reduction; otherwise it returns the
+/// difference passed through `reduce_4p`.
+/// NOTE(review): the `x0p` / `sqr3p` / `eq0p` parts of the name look like operand and
+/// result bounds in multiples of the modulus — confirm with the author.
+#[inline(always)]
+fn x0p_plus_sqr3p_plus_rc_eq0p(x: [u64; 4], sqr: [u64; 4], rc_idx: usize) -> [u64; 4] {
+    let x_plus_sqr = wrapping_add(x, sqr);
+    let (tmp, b) = overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]);
+    if b {
+        wrapping_add(x_plus_sqr, RC[rc_idx])
+    } else {
+        reduce_4p(tmp)
+    }
+}
+
+/// Round step for a `bar` round: combines `x`, `bar` and round constant `rc_idx`,
+/// finishing with `reduce_4p`.
+///
+/// The top three bits of `bar[3]` decide how the round constant is folded in:
+/// `00x` adds `RC`, `01x` subtracts `_1P_MINUS_RC`, `10x` subtracts `_2P_MINUS_RC`,
+/// `110` subtracts `_3P_MINUS_RC` and `111` subtracts `_4P_MINUS_RC`.
+#[inline(always)]
+fn x2p_plus_bar0p_plus_rc_eq0p(x: [u64; 4], bar: [u64; 4], rc_idx: usize) -> [u64; 4] {
+    // `msb0`: top two bits of `bar`; `msb1`: the third-highest bit.
+    let msb0 = (bar[3] >> 62) as u8;
+    let msb1 = ((bar[3] << 2) >> 63) != 0;
+    let bar_plus_rc;
+    if msb0 == 0 {
+        bar_plus_rc = wrapping_add(bar, RC[rc_idx]);
+    } else if msb0 == 1 {
+        bar_plus_rc = wrapping_sub(bar, _1P_MINUS_RC[rc_idx]);
+    } else if msb0 == 2 {
+        bar_plus_rc = wrapping_sub(bar, _2P_MINUS_RC[rc_idx]);
+    } else if !msb1 {
+        bar_plus_rc = wrapping_sub(bar, _3P_MINUS_RC[rc_idx]);
+    } else {
+        bar_plus_rc = wrapping_sub(bar, _4P_MINUS_RC[rc_idx]);
+    }
+    let tmp = wrapping_add(bar_plus_rc, x);
+    reduce_4p(tmp)
+}
+
+/// Identical to `x2p_plus_bar0p_plus_rc_eq0p` except that the sum is finished with
+/// `reduce_3p` instead of `reduce_4p`.
+#[inline(always)]
+fn x0p_plus_bar0p_plus_rc_eq0p(x: [u64; 4], bar: [u64; 4], rc_idx: usize) -> [u64; 4] {
+    // `msb0`: top two bits of `bar`; `msb1`: the third-highest bit.
+    let msb0 = (bar[3] >> 62) as u8;
+    let msb1 = ((bar[3] << 2) >> 63) != 0;
+    let bar_plus_rc;
+    if msb0 == 0 {
+        bar_plus_rc = wrapping_add(bar, RC[rc_idx]);
+    } else if msb0 == 1 {
+        bar_plus_rc = wrapping_sub(bar, _1P_MINUS_RC[rc_idx]);
+    } else if msb0 == 2 {
+        bar_plus_rc = wrapping_sub(bar, _2P_MINUS_RC[rc_idx]);
+    } else if !msb1 {
+        bar_plus_rc = wrapping_sub(bar, _3P_MINUS_RC[rc_idx]);
+    } else {
+        bar_plus_rc = wrapping_sub(bar, _4P_MINUS_RC[rc_idx]);
+    }
+    let tmp = wrapping_add(bar_plus_rc, x);
+    reduce_3p(tmp)
+}
+
+/// Round step for a squaring round: combines `x`, `sqr` and round constant `rc_idx`.
+///
+/// Computes `x + sqr`, then tries to subtract `_1P_MINUS_RC[rc_idx]`. On borrow it
+/// returns `x + sqr + RC[rc_idx]` without further reduction; otherwise it returns the
+/// difference passed through `reduce_3p`.
+#[inline(always)]
+fn x0p_plus_sqr2p_plus_rc_eq0p(x: [u64; 4], sqr: [u64; 4], rc_idx: usize) -> [u64; 4] {
+    let x_plus_sqr = wrapping_add(x, sqr);
+    let (tmp, b) = overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]);
+    if b {
+        wrapping_add(x_plus_sqr, RC[rc_idx])
+    } else {
+        reduce_3p(tmp)
+    }
+}
+
+/// Same as `x0p_plus_sqr2p_plus_rc_eq0p`, except that the no-borrow branch finishes
+/// with `reduce_2p`.
+// NOTE(review): unlike the sibling round helpers, forced inlining is switched off here
+// (the attribute below is commented out) — confirm that this is intentional.
+// #[inline(always)]
+fn x0p_plus_sqr1p_plus_rc_eq0p(x: [u64; 4], sqr: [u64; 4], rc_idx: usize) -> [u64; 4] {
+    let x_plus_sqr = wrapping_add(x, sqr);
+    let (tmp, b) = overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]);
+    if b {
+        wrapping_add(x_plus_sqr, RC[rc_idx])
+    } else {
+        reduce_2p(tmp)
+    }
+}
+
+/// Final round step: returns `x + sqr + y` (wrapping 256-bit additions) passed through
+/// `reduce_3p`. No round constant is involved.
+#[inline(always)]
+fn x0p_plus_sqr1p_plus_y0p_eq0p(x: [u64; 4], sqr: [u64; 4], y: [u64; 4]) -> [u64; 4] {
+    let x_plus_sqr = wrapping_add(x, sqr);
+    let tmp = wrapping_add(x_plus_sqr, y);
+    reduce_3p(tmp)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Known-answer test: `compress` on one fixed `(l, r)` pair.
+    #[test]
+    fn test_compress() {
+        let l = [
+            222647740394868259,
+            1954084163509096643,
+            7169380306955695398,
+            3443405857474191768,
+        ];
+        let r = [
+            650100192727553127,
+            2847352847332889852,
+            4016598436723263545,
+            1563325641941659433,
+        ];
+        let r = compress(l, r);
+        assert_eq!(
+            r,
+            [
+                18095061023341165257,
+                7738479748118643198,
+                13857889271559191300,
+                570841294491851342
+            ]
+        );
+    }
+
+    /// Feeds the same `(l, r)` pair into all three lanes of `block_compress` and
+    /// expects every lane to produce the value asserted in `test_compress`.
+    #[test]
+    fn test_block_compress() {
+        let l = [
+            222647740394868259,
+            1954084163509096643,
+            7169380306955695398,
+            3443405857474191768,
+        ];
+        let r = [
+            650100192727553127,
+            2847352847332889852,
+            4016598436723263545,
+            1563325641941659433,
+        ];
+
+        let _rtz = block_multiplier::rtz::RTZ::set().unwrap();
+
+        let (r_0, r_1, r_2) = block_compress(&_rtz, l, l, l, r, r, r);
+        assert_eq!(
+            r_0,
+            [
+                18095061023341165257,
+                7738479748118643198,
+                13857889271559191300,
+                570841294491851342
+            ]
+        );
+        assert_eq!(
+            r_1,
+            [
+                18095061023341165257,
+                7738479748118643198,
+                13857889271559191300,
+                570841294491851342
+            ]
+        );
+        assert_eq!(
+            r_2,
+            [
+                18095061023341165257,
+                7738479748118643198,
+                13857889271559191300,
+                570841294491851342
+            ]
+        );
+    }
+}
From ac7b1d6a3baf9004815e32c2ecbec6e1d3a51aed Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Wed, 7 May 2025 13:47:46 +0200 Subject: [PATCH 7/8] Update fp-rounding --- skyscraper/Cargo.toml | 3 +- skyscraper/benches/bench.rs | 34 +++++++++-------- skyscraper/src/lib.rs | 76 ++++++++++++++++--------------------- 3 files changed, 54 insertions(+), 59 deletions(-) diff --git a/skyscraper/Cargo.toml b/skyscraper/Cargo.toml index 3dfbeb913..24eb29dd0 100644 --- a/skyscraper/Cargo.toml +++ b/skyscraper/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true [dependencies] block-multiplier = { path = "../block-multiplier" } +fp-rounding = { path = "../fp-rounding" } [dev-dependencies] rand = "0.9.0" @@ -18,4 +19,4 @@ criterion = "0.5.1" [[bench]] name = "bench" -harness = false \ No newline at end of file +harness = false diff --git a/skyscraper/benches/bench.rs b/skyscraper/benches/bench.rs index c4de5fae2..03b62f523 100644 --- a/skyscraper/benches/bench.rs +++ b/skyscraper/benches/bench.rs @@ -1,6 +1,8 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::prelude::StdRng; -use rand::{Rng, SeedableRng}; +use { + criterion::{black_box, criterion_group, criterion_main, Criterion}, + fp_rounding::with_rounding_mode, + rand::{prelude::StdRng, Rng, SeedableRng}, +}; fn bench_skyscraper(c: &mut Criterion) { let mut group = c.benchmark_group("skyscraper"); @@ -45,22 +47,24 @@ fn bench_skyscraper(c: &mut Criterion) { rng.random::(), ]; - let _rtz = block_multiplier::rtz::RTZ::set().unwrap(); - group.bench_function("compress", |bencher| { bencher.iter(|| skyscraper::compress(black_box(l_0), black_box(r_0))) }); - group.bench_function("block_compress", |bencher| { - bencher.iter(|| skyscraper::block_compress( - black_box(&_rtz), - black_box(l_0), - black_box(l_1), - black_box(l_2), - black_box(r_0), - black_box(r_1), - black_box(r_2) - )) + group.bench_function("block_compress", |bencher| unsafe { +
with_rounding_mode((), |guard, _| { + bencher.iter(|| { + skyscraper::block_compress( + black_box(guard), + black_box(l_0), + black_box(l_1), + black_box(l_2), + black_box(r_0), + black_box(r_1), + black_box(r_2), + ) + }) + }); }); } diff --git a/skyscraper/src/lib.rs b/skyscraper/src/lib.rs index 067f24d80..262f4475c 100644 --- a/skyscraper/src/lib.rs +++ b/skyscraper/src/lib.rs @@ -1,6 +1,8 @@ #![feature(bigint_helper_methods)] -use block_multiplier::*; -use block_multiplier::rtz::RTZ; +use { + block_multiplier::{block_sqr, scalar_sqr}, + fp_rounding::{RoundingGuard, Zero}, +}; pub const U64_P: [u64; 4] = [ 0x43e1f593f0000001, @@ -315,13 +317,13 @@ pub fn compress(l: [u64; 4], r: [u64; 4]) -> [u64; 4] { #[inline] pub fn block_compress( - _rtz: &RTZ, + _rtz: &RoundingGuard, l_0: [u64; 4], l_1: [u64; 4], l_2: [u64; 4], r_0: [u64; 4], r_1: [u64; 4], - r_2: [u64; 4] + r_2: [u64; 4], ) -> ([u64; 4], [u64; 4], [u64; 4]) { let a_0 = l_0; let a_1 = l_1; @@ -553,7 +555,7 @@ fn x0p_plus_sqr1p_plus_y0p_eq0p(x: [u64; 4], sqr: [u64; 4], y: [u64; 4]) -> [u64 #[cfg(test)] mod tests { - use super::*; + use {super::*, fp_rounding::with_rounding_mode}; #[test] fn test_compress() { @@ -570,15 +572,12 @@ mod tests { 1563325641941659433, ]; let r = compress(l, r); - assert_eq!( - r, - [ - 18095061023341165257, - 7738479748118643198, - 13857889271559191300, - 570841294491851342 - ] - ); + assert_eq!(r, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); } #[test] @@ -596,35 +595,26 @@ mod tests { 1563325641941659433, ]; - let _rtz = block_multiplier::rtz::RTZ::set().unwrap(); + let (r_0, r_1, r_2) = + unsafe { with_rounding_mode((), |guard, _| block_compress(guard, l, l, l, r, r, r)) }; - let (r_0, r_1, r_2) = block_compress(&_rtz, l, l, l, r, r, r); - assert_eq!( - r_0, - [ - 18095061023341165257, - 7738479748118643198, - 13857889271559191300, - 570841294491851342 - ] - ); - assert_eq!( - r_1, - [ - 18095061023341165257, - 
7738479748118643198, - 13857889271559191300, - 570841294491851342 - ] - ); - assert_eq!( - r_2, - [ - 18095061023341165257, - 7738479748118643198, - 13857889271559191300, - 570841294491851342 - ] - ); + assert_eq!(r_0, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + assert_eq!(r_1, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + assert_eq!(r_2, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); } } From 4773dbc2ceac432e21d64d2e1f5c0ac66c9ad515 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Wed, 7 May 2025 14:19:18 +0200 Subject: [PATCH 8/8] Benchmark with_rounding_mode --- block-multiplier/benches/bench.rs | 54 ++++++++++++++++++------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/block-multiplier/benches/bench.rs b/block-multiplier/benches/bench.rs index 8e6222508..f9766cd97 100644 --- a/block-multiplier/benches/bench.rs +++ b/block-multiplier/benches/bench.rs @@ -1,5 +1,6 @@ use { criterion::{Criterion, black_box, criterion_group, criterion_main}, + fp_rounding::{Zero, with_rounding_mode}, rand::{Rng, SeedableRng, prelude::StdRng}, }; @@ -48,8 +49,6 @@ fn bench_block_multiplier(c: &mut Criterion) { rng.random::(), ]; - let rtz = rtz::RTZ::set().unwrap(); - group.bench_function("scalar_mul", |bencher| { bencher.iter(|| block_multiplier::scalar_mul(black_box(s0_a), black_box(s0_b))) }); @@ -73,36 +72,45 @@ fn bench_block_multiplier(c: &mut Criterion) { }) }); - group.bench_function("block_mul", |bencher| { - bencher.iter(|| { - block_multiplier::block_mul( - &rtz, - black_box(s0_a), - black_box(s0_b), - black_box(v0_a), - black_box(v0_b), - black_box(v1_a), - black_box(v1_b), - ) + group.bench_function("block_mul", |bencher| unsafe { + with_rounding_mode((), |guard, _| { + bencher.iter(|| { + block_multiplier::block_mul( + guard, + black_box(s0_a), + black_box(s0_b), + 
black_box(v0_a), + black_box(v0_b), + black_box(v1_a), + black_box(v1_b), + ) + }) }) }); - group.bench_function("block_sqr", |bencher| { - bencher.iter(|| { - block_multiplier::block_sqr(&rtz, black_box(s0_a), black_box(v0_a), black_box(v1_a)) - }) + group.bench_function("block_sqr", |bencher| unsafe { + with_rounding_mode((), |guard, _| { + bencher.iter(|| { + block_multiplier::block_sqr( + guard, + black_box(s0_a), + black_box(v0_a), + black_box(v1_a), + ) + }) + }); }); group.finish(); } fn bench_rtz(c: &mut Criterion) { - let mut group = c.benchmark_group("rtz"); - group.bench_function("rtz", |bencher| { - bencher.iter(|| { - let rtz = block_multiplier::rtz::RTZ::set(); - black_box(rtz.is_some()); - drop(rtz); + let mut group = c.benchmark_group("with_rounding_mode"); + group.bench_function("with_rounding_mode", |bencher| { + bencher.iter(|| unsafe { + with_rounding_mode::((), |guard, _| { + black_box(guard); + }) }) });