diff --git a/Cargo.toml b/Cargo.toml index 7c52d9b3f..28afd5380 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "block-multiplier", "block-multiplier-sys", "block-multiplier-codegen", + "skyscraper", ] [workspace.package] diff --git a/block-multiplier/benches/bench.rs b/block-multiplier/benches/bench.rs index 2769e9c80..f9766cd97 100644 --- a/block-multiplier/benches/bench.rs +++ b/block-multiplier/benches/bench.rs @@ -1,5 +1,6 @@ use { criterion::{Criterion, black_box, criterion_group, criterion_main}, + fp_rounding::{Zero, with_rounding_mode}, rand::{Rng, SeedableRng, prelude::StdRng}, }; @@ -48,8 +49,6 @@ fn bench_block_multiplier(c: &mut Criterion) { rng.random::(), ]; - let rtz = rtz::RTZ::set().unwrap(); - group.bench_function("scalar_mul", |bencher| { bencher.iter(|| block_multiplier::scalar_mul(black_box(s0_a), black_box(s0_b))) }); @@ -73,23 +72,45 @@ fn bench_block_multiplier(c: &mut Criterion) { }) }); - group.bench_function("block_mul", |bencher| { - bencher.iter(|| { - block_multiplier::block_mul( - &rtz, - black_box(s0_a), - black_box(s0_b), - black_box(v0_a), - black_box(v0_b), - black_box(v1_a), - black_box(v1_b), - ) + group.bench_function("block_mul", |bencher| unsafe { + with_rounding_mode((), |guard, _| { + bencher.iter(|| { + block_multiplier::block_mul( + guard, + black_box(s0_a), + black_box(s0_b), + black_box(v0_a), + black_box(v0_b), + black_box(v1_a), + black_box(v1_b), + ) + }) }) }); - group.bench_function("block_sqr", |bencher| { - bencher.iter(|| { - block_multiplier::block_sqr(&rtz, black_box(s0_a), black_box(v0_a), black_box(v1_a)) + group.bench_function("block_sqr", |bencher| unsafe { + with_rounding_mode((), |guard, _| { + bencher.iter(|| { + block_multiplier::block_sqr( + guard, + black_box(s0_a), + black_box(v0_a), + black_box(v1_a), + ) + }) + }); + }); + + group.finish(); +} + +fn bench_rtz(c: &mut Criterion) { + let mut group = c.benchmark_group("with_rounding_mode"); + 
group.bench_function("with_rounding_mode", |bencher| { + bencher.iter(|| unsafe { + with_rounding_mode::((), |guard, _| { + black_box(guard); + }) }) }); @@ -103,6 +124,6 @@ criterion_group!( // Warm up is warm because it literally warms up the pi .warm_up_time(std::time::Duration::new(1,0)) .measurement_time(std::time::Duration::new(10,0)); - targets = bench_block_multiplier + targets = bench_block_multiplier, bench_rtz ); criterion_main!(benches); diff --git a/block-multiplier/src/constants.rs b/block-multiplier/src/constants.rs index a8d1fa96b..87ede2743 100644 --- a/block-multiplier/src/constants.rs +++ b/block-multiplier/src/constants.rs @@ -24,10 +24,10 @@ pub const U64_R: [u64; 4] = [ // R^2 mod P pub const U64_R2: [u64; 4] = [ - 0x1BB8E645AE216DA7, - 0x53FE3AB1E35C59E3, - 0x8C49833D53BB8085, - 0x0216D0B17F4E44A5, + 0x1bb8e645ae216da7, + 0x53fe3ab1e35c59e3, + 0x8c49833d53bb8085, + 0x0216d0b17f4e44a5, ]; // R^-1 mod P @@ -38,29 +38,37 @@ pub const U64_R_INV: [u64; 4] = [ 0x15ebf95182c5551c, ]; -pub const U52_NP0: u64 = 0x1F593EFFFFFFF; +pub const U52_NP0: u64 = 0x1f593efffffff; pub const U52_R2: [u64; 5] = [ - 0x0B852D16DA6F5, - 0xC621620CDDCE3, - 0xAF1B95343FFB6, - 0xC3C15E103E7C2, - 0x00281528FA122, + 0x0b852d16da6f5, + 0xc621620cddce3, + 0xaf1b95343ffb6, + 0xc3c15e103e7c2, + 0x00281528fa122, ]; pub const U52_P: [u64; 5] = [ - 0x1F593F0000001, - 0x4879B9709143E, - 0x181585D2833E8, - 0xA029B85045B68, - 0x030644E72E131, + 0x1f593f0000001, + 0x4879b9709143e, + 0x181585d2833e8, + 0xa029b85045b68, + 0x030644e72e131, ]; pub const U52_2P: [u64; 5] = [ - 0x3EB27E0000002, - 0x90F372E12287C, - 0x302B0BA5067D0, - 0x405370A08B6D0, - 0x060C89CE5C263, + 0x3eb27e0000002, + 0x90f372e12287c, + 0x302b0ba5067d0, + 0x405370a08b6d0, + 0x060c89ce5c263, +]; + +pub const F52_P: [f64; 5] = [ + 0x1f593f0000001_u64 as f64, + 0x4879b9709143e_u64 as f64, + 0x181585d2833e8_u64 as f64, + 0xa029b85045b68_u64 as f64, + 0x030644e72e131_u64 as f64, ]; pub const MASK52: u64 = 2_u64.pow(52) - 
1; @@ -80,14 +88,15 @@ pub const U64_I2: [u64; 4] = [ ]; pub const U64_I3: [u64; 4] = [ - 0x9BACB016127CBE4E, - 0x0B2051FA31944124, - 0xB064EEA46091C76C, - 0x2B062AAA49F80C7D, + 0x9bacb016127cbe4e, + 0x0b2051fa31944124, + 0xb064eea46091c76c, + 0x2b062aaa49f80c7d, ]; pub const U64_MU0: u64 = 0xc2e1f593efffffff; -// -- [FP SIMD CONSTANTS] -------------------------------------------------------------------------- +// -- [FP SIMD CONSTANTS] +// -------------------------------------------------------------------------- pub const RHO_1: [u64; 5] = [ 0x82e644ee4c3d2, 0xf93893c98b1de, @@ -105,19 +114,19 @@ pub const RHO_2: [u64; 5] = [ ]; pub const RHO_3: [u64; 5] = [ - 0x0E8C656567D77, - 0x430D05713AE61, - 0xEA3BA6B167128, - 0xA7DAE55C5A296, - 0x01B4AFD513572, + 0x0e8c656567d77, + 0x430d05713ae61, + 0xea3ba6b167128, + 0xa7dae55c5a296, + 0x01b4afd513572, ]; pub const RHO_4: [u64; 5] = [ - 0x22E2400E2F27D, - 0x323B46EA19686, - 0xE6C43F0DF672D, - 0x7824014C39E8B, - 0x00C6B48AFE1B8, + 0x22e2400e2f27d, + 0x323b46ea19686, + 0xe6c43f0df672d, + 0x7824014c39e8b, + 0x00c6b48afe1b8, ]; pub const C1: f64 = pow_2(104); // 2.0^104 @@ -128,6 +137,6 @@ pub const C2: f64 = pow_2(104) + pow_2(52); // 2.0^104 + 2.0^52 const fn pow_2(n: u32) -> f64 { // Unfortunately we can't use f64::powi in const fn yet // This is a workaround that creates the bit pattern directly - let exp = ((n as u64 + 1023) & 0x7FF) << 52; + let exp = ((n as u64 + 1023) & 0x7ff) << 52; f64::from_bits(exp) } diff --git a/block-multiplier/src/rtz.rs b/block-multiplier/src/rtz.rs new file mode 100644 index 000000000..1f8050d9a --- /dev/null +++ b/block-multiplier/src/rtz.rs @@ -0,0 +1,208 @@ +use std::marker::PhantomData; + +/// round-toward-zero mode (bits 22-23 to 0b11) +const FPCR_RMODE_BITS: u64 = 0b11 << 22; + +#[derive(Debug, Clone, Copy, PartialEq)] +enum FPCRState { + /// FPCR is idle and available for modification + Idle, + /// FPCR is actively being used for RTZ operations + Active, +} + +/// Proof that Round 
Toward Zero (RTZ) has been set +/// +/// This struct must be passed as an (unused) reference to any function that requires round toward zero +/// for correct operation. The struct serves as a proof that RTZ is set and we rely on the lifetime introduced +/// by the reference to enforce the ordering of the FPCR operations relative to the multiplication. This +/// way we can prevent the reset of FPCR from bubbling up in front of the multiplication. +/// +/// This type provides RAII-style management of the AArch64 FPCR (Floating-point Control Register), +/// specifically for controlling the rounding mode. When created, it sets the rounding mode to +/// "round toward zero" and restores the previous mode when dropped. +/// +/// # Safety +/// +/// This type is not Send because FPCR is a per-core / per OS-thread register. The PhantomData<*mut ()> ensures this. +/// Only one instance can exist per thread at a time, enforced by the FPCR_OWNED thread-local. +/// +/// # Design Notes +/// +/// This type +#[derive(Debug)] +pub struct RTZ { + prev_fpcr: u64, + _no_send: PhantomData<*mut ()>, +} + +thread_local! { + /// Thread-local flag to ensure only one RTZ instance exists per thread. + /// This prevents multiple concurrent modifications to the FPCR register. + static FPCR_OWNED: std::cell::Cell = std::cell::Cell::new(FPCRState::Idle); +} + +impl RTZ { + /// Attempts to create a new RTZ instance, setting the FPCR to round-toward-zero mode. + /// + /// Returns None if another RTZ instance already exists in this thread. + /// + /// # Safety + /// + /// This function uses inline assembly to modify the FPCR register. 
The operations are safe + /// when used as intended through this API, as it maintains the following invariants: + /// - Only one instance can exist per thread + /// - The previous FPCR value is always restored on drop + /// - The type cannot be sent between threads + #[cfg(target_arch = "aarch64")] + #[inline] + pub fn set() -> Option { + // Try to acquire ownership of FPCR + let state = FPCR_OWNED.with(|owned| { + let observed_state = owned.get(); + if observed_state == FPCRState::Idle { + owned.set(FPCRState::Active); + } + observed_state + }); + + match state { + FPCRState::Idle => { + let mut prev_fpcr: u64; + unsafe { + // Read current FPCR value and set round-toward-zero mode + core::arch::asm!( + // Read current FPCR value + "mrs {prev_fpcr}, fpcr", + "orr {tmp}, {prev_fpcr}, {rmode}", + "msr fpcr, {tmp}", + prev_fpcr = out(reg) prev_fpcr, + tmp = out(reg) _, + rmode = const FPCR_RMODE_BITS, + ); + } + + Some(Self { + prev_fpcr, + _no_send: PhantomData, + }) + } + FPCRState::Active => None, + } + } + + #[cfg(not(target_arch = "aarch64"))] + #[inline] + fn set() -> Option { + todo!() + } + + /// Reads the current value of the FPCR register. + /// + /// This method is primarily intended for debugging and verification purposes. + #[cfg(target_arch = "aarch64")] + #[inline] + pub fn read(&self) -> u64 { + let mut value: u64; + unsafe { + core::arch::asm!( + "mrs {}, fpcr", + out(reg) value, + options(nomem, nostack, preserves_flags) + ); + } + value + } + + #[cfg(not(target_arch = "aarch64"))] + #[inline] + pub fn read(&self) -> u64 { + todo!() + } + + /// Writes a new value to the FPCR register. + /// + /// This is a low-level operation that directly modifies the FPCR register. + /// It should be used with caution as improper values can affect floating-point behavior. 
+ /// + /// # Safety + /// + /// This operation is safe because: + /// - The RTZ instance proves we have exclusive access to FPCR + /// - The write operation is atomic + #[cfg(target_arch = "aarch64")] + #[inline] + pub fn write(&self, value: u64) { + unsafe { + core::arch::asm!( + "msr fpcr, {}", + in(reg) value, + options(nomem, nostack, preserves_flags) + ); + } + } + + #[cfg(not(target_arch = "aarch64"))] + #[inline] + pub fn write(&self, _value: u64) { + todo!() + } +} + +impl Drop for RTZ { + /// Restores the original FPCR value and releases the thread-local lock. + /// + /// This ensures that the floating-point environment is restored to its + /// previous state when the RTZ instance is dropped, maintaining the + /// RAII pattern. + fn drop(&mut self) { + // Restore the original FPCR value + self.write(self.prev_fpcr); + // Release the thread-local lock + FPCR_OWNED.with(|owned| owned.set(FPCRState::Idle)); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg(target_arch = "aarch64")] + fn test_rtz_single_instance() { + // First instance should succeed + let rtz1 = RTZ::set(); + assert!(rtz1.is_some()); + let rtz1 = rtz1.unwrap(); + let beginning_state = rtz1.prev_fpcr; + + // Second instance should fail + let rtz2 = RTZ::set(); + assert!(rtz2.is_none()); + + // Drop first instance + drop(rtz1); + + // Now we should be able to create a new instance + let rtz3 = RTZ::set(); + assert!(rtz3.is_some()); + let rtz3 = rtz3.unwrap(); + + assert_eq!(beginning_state, rtz3.prev_fpcr); + } + + #[test] + #[cfg(target_arch = "aarch64")] + fn test_rtz_read_write() { + let rtz = RTZ::set().unwrap(); + let initial = rtz.read(); + + // Verify that the rounding mode bits are set + assert_eq!(initial & FPCR_RMODE_BITS, FPCR_RMODE_BITS); + + // Test write and read back + let test_value = initial & !FPCR_RMODE_BITS; // Clear rounding mode bits + rtz.write(test_value); + assert_eq!(rtz.read(), test_value); + } +} diff --git a/skyscraper/Cargo.toml 
b/skyscraper/Cargo.toml new file mode 100644 index 000000000..24eb29dd0 --- /dev/null +++ b/skyscraper/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "skyscraper" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +block-multiplier = { path = "../block-multiplier" } +fp-rounding = { path = "../fp-rounding" } + +[dev-dependencies] +rand = "0.9.0" +primitive-types = "0.13.1" +criterion = "0.5.1" + +[[bench]] +name = "bench" +harness = false diff --git a/skyscraper/benches/bench.rs b/skyscraper/benches/bench.rs new file mode 100644 index 000000000..03b62f523 --- /dev/null +++ b/skyscraper/benches/bench.rs @@ -0,0 +1,80 @@ +use { + criterion::{black_box, criterion_group, criterion_main, Criterion}, + fp_rounding::with_rounding_mode, + rand::{prelude::StdRng, Rng, SeedableRng}, +}; + +fn bench_skyscraper(c: &mut Criterion) { + let mut group = c.benchmark_group("skyscraper"); + let seed: u64 = rand::random(); + println!("Using random seed for benchmark: {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let l_0 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let l_1 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let l_2 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let r_0 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let r_1 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + let r_2 = [ + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::(), + ]; + + group.bench_function("compress", |bencher| { + bencher.iter(|| skyscraper::compress(black_box(l_0), black_box(r_0))) + }); + + group.bench_function("block_compress", |bencher| unsafe { + with_rounding_mode((), |guard, _| { + bencher.iter(|| { + skyscraper::block_compress( + 
black_box(guard), + black_box(l_0), + black_box(l_1), + black_box(l_2), + black_box(r_0), + black_box(r_1), + black_box(r_2), + ) + }) + }); + }); +} + +criterion_group!( + name = benches; + config = Criterion::default() + .sample_size(5000) + // Warm up is warm because it literally warms up the pi + .warm_up_time(std::time::Duration::new(1,0)) + .measurement_time(std::time::Duration::new(10,0)); + targets = bench_skyscraper +); +criterion_main!(benches); diff --git a/skyscraper/src/lib.rs b/skyscraper/src/lib.rs new file mode 100644 index 000000000..262f4475c --- /dev/null +++ b/skyscraper/src/lib.rs @@ -0,0 +1,620 @@ +#![feature(bigint_helper_methods)] +use { + block_multiplier::{block_sqr, scalar_sqr}, + fp_rounding::{RoundingGuard, Zero}, +}; + +pub const U64_P: [u64; 4] = [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, +]; +pub const U64_2P: [u64; 4] = [ + 0x87c3eb27e0000002, + 0x5067d090f372e122, + 0x70a08b6d0302b0ba, + 0x60c89ce5c2634053, +]; +pub const U64_3P: [u64; 4] = [ + 0xcba5e0bbd0000003, + 0x789bb8d96d2c51b3, + 0x28f0d12384840917, + 0x912ceb58a394e07d, +]; +pub const U64_4P: [u64; 4] = [ + 0x0f87d64fc0000004, + 0xa0cfa121e6e5c245, + 0xe14116da06056174, + 0xc19139cb84c680a6, +]; +pub const U64_5P: [u64; 4] = [ + 0x5369cbe3b0000005, + 0xc903896a609f32d6, + 0x99915c908786b9d1, + 0xf1f5883e65f820d0, +]; + +pub const RC: [[u64; 4]; 8] = [ + [ + 0x903c4324270bd744, + 0x873125f708a7d269, + 0x081dd27906c83855, + 0x276b1823ea6d7667, + ], + [ + 0x7ac8edbb4b378d71, + 0xe29d79f3d99e2cb7, + 0x751417914c1a5a18, + 0x0cf02bd758a484a6, + ], + [ + 0xfa7adc6769e5bc36, + 0x1c3f8e297cca387d, + 0x0eb7730d63481db0, + 0x25b0e03f18ede544, + ], + [ + 0x57847e652f03cfb7, + 0x33440b9668873404, + 0x955a32e849af80bc, + 0x002882fcbe14ae70, + ], + [ + 0x979231396257d4d7, + 0x29989c3e1b37d3c1, + 0x12ef02b47f1277ba, + 0x039ad8571e2b7a9c, + ], + [ + 0xb5b48465abbb7887, + 0xa72a6bc5e6ba2d2b, + 0x4cd48043712f7b29, + 0x1142d5410fc1fc1a, + ], + [ 
+ 0x7ab2c156059075d3, + 0x17cb3594047999b2, + 0x44f2c93598f289f7, + 0x1d78439f69bc0bec, + ], + [ + 0x05d7a965138b8edb, + 0x36ef35a3d55c48b1, + 0x8ddfb8a1ac6f1628, + 0x258588a508f4ff82, + ], +]; + +pub const _1P_MINUS_RC: [[u64; 4]; 8] = [ + [ + 0xb3a5b26fc8f428bd, + 0xa102c25171119e27, + 0xb032733d7ab92007, + 0x08f9364ef6c429c2, + ], + [ + 0xc91907d8a4c87290, + 0x45966e54a01b43d9, + 0x433c2e253566fe44, + 0x2374229b888d1b83, + ], + [ + 0x4967192c861a43cb, + 0x0bf45a1efcef3813, + 0xa998d2a91e393aad, + 0x0ab36e33c843bae5, + ], + [ + 0xec5d772ec0fc304a, + 0xf4efdcb211323c8c, + 0x22f612ce37d1d7a0, + 0x303bcb76231cf1b9, + ], + [ + 0xac4fc45a8da82b2a, + 0xfe9b4c0a5e819ccf, + 0xa5614302026ee0a2, + 0x2cc9761bc306258d, + ], + [ + 0x8e2d712e4444877a, + 0x81097c8292ff4365, + 0x6b7bc5731051dd33, + 0x1f217931d16fa40f, + ], + [ + 0xc92f343dea6f8a2e, + 0x1068b2b4753fd6de, + 0x735d7c80e88ece66, + 0x12ec0ad37775943d, + ], + [ + 0x3e0a4c2edc747126, + 0xf144b2a4a45d27e0, + 0x2a708d14d5124234, + 0x0adec5cdd83ca0a7, + ], +]; + +pub const _2P_MINUS_RC: [[u64; 4]; 8] = [ + [ + 0xf787a803b8f428be, + 0xc936aa99eacb0eb8, + 0x6882b8f3fc3a7864, + 0x395d84c1d7f5c9ec, + ], + [ + 0x0cfafd6c94c87291, + 0x6dca569d19d4b46b, + 0xfb8c73dbb6e856a1, + 0x53d8710e69bebbac, + ], + [ + 0x8d490ec0761a43cc, + 0x3428426776a8a8a4, + 0x61e9185f9fba930a, + 0x3b17bca6a9755b0f, + ], + [ + 0x303f6cc2b0fc304b, + 0x1d23c4fa8aebad1e, + 0xdb465884b9532ffe, + 0x60a019e9044e91e2, + ], + [ + 0xf031b9ee7da82b2b, + 0x26cf3452d83b0d60, + 0x5db188b883f03900, + 0x5d2dc48ea437c5b7, + ], + [ + 0xd20f66c23444877b, + 0xa93d64cb0cb8b3f6, + 0x23cc0b2991d33590, + 0x4f85c7a4b2a14439, + ], + [ + 0x0d1129d1da6f8a2f, + 0x389c9afceef94770, + 0x2badc2376a1026c3, + 0x4350594658a73467, + ], + [ + 0x81ec41c2cc747127, + 0x19789aed1e169871, + 0xe2c0d2cb56939a92, + 0x3b431440b96e40d0, + ], +]; + +pub const _3P_MINUS_RC: [[u64; 4]; 8] = [ + [ + 0x3b699d97a8f428bf, + 0xf16a92e264847f4a, + 0x20d2feaa7dbbd0c1, + 0x69c1d334b9276a16, + ], + [ + 
0x50dcf30084c87292, + 0x95fe3ee5938e24fc, + 0xb3dcb9923869aefe, + 0x843cbf814af05bd6, + ], + [ + 0xd12b0454661a43cd, + 0x5c5c2aaff0621935, + 0x1a395e16213beb67, + 0x6b7c0b198aa6fb39, + ], + [ + 0x74216256a0fc304c, + 0x4557ad4304a51daf, + 0x93969e3b3ad4885b, + 0x9104685be580320c, + ], + [ + 0x3413af826da82b2c, + 0x4f031c9b51f47df2, + 0x1601ce6f0571915d, + 0x8d921301856965e1, + ], + [ + 0x15f15c562444877c, + 0xd1714d1386722488, + 0xdc1c50e013548ded, + 0x7fea161793d2e462, + ], + [ + 0x50f31f65ca6f8a30, + 0x60d0834568b2b801, + 0xe3fe07edeb917f20, + 0x73b4a7b939d8d490, + ], + [ + 0xc5ce3756bc747128, + 0x41ac833597d00902, + 0x9b111881d814f2ef, + 0x6ba762b39a9fe0fa, + ], +]; + +pub const _4P_MINUS_RC: [[u64; 4]; 8] = [ + [ + 0x7f4b932b98f428c0, + 0x199e7b2ade3defdb, + 0xd9234460ff3d291f, + 0x9a2621a79a590a3f, + ], + [ + 0x94bee89474c87293, + 0xbe32272e0d47958d, + 0x6c2cff48b9eb075b, + 0xb4a10df42c21fc00, + ], + [ + 0x150cf9e8561a43ce, + 0x849012f86a1b89c7, + 0xd289a3cca2bd43c4, + 0x9be0598c6bd89b62, + ], + [ + 0xb80357ea90fc304d, + 0x6d8b958b7e5e8e40, + 0x4be6e3f1bc55e0b8, + 0xc168b6cec6b1d236, + ], + [ + 0x77f5a5165da82b2d, + 0x773704e3cbadee83, + 0xce52142586f2e9ba, + 0xbdf66174669b060a, + ], + [ + 0x59d351ea1444877d, + 0xf9a5355c002b9519, + 0x946c969694d5e64a, + 0xb04e648a7504848c, + ], + [ + 0x94d514f9ba6f8a31, + 0x89046b8de26c2892, + 0x9c4e4da46d12d77d, + 0xa418f62c1b0a74ba, + ], + [ + 0x09b02ceaac747129, + 0x69e06b7e11897994, + 0x53615e3859964b4c, + 0x9c0bb1267bd18124, + ], +]; + +#[inline] +pub fn compress(l: [u64; 4], r: [u64; 4]) -> [u64; 4] { + let a = l; + let sqr = scalar_sqr(l); + let (l, r) = (wrapping_add(r, sqr), l); + let sqr = scalar_sqr(l); + let (l, r) = (x0p_plus_sqr3p_plus_rc_eq0p(r, sqr, 0), l); + let bar = bar_u8(l); + let (l, r) = (x2p_plus_bar0p_plus_rc_eq0p(r, bar, 1), l); + let bar = bar_u8(l); + let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 2), l); + let sqr = scalar_sqr(l); + let (l, r) = (x0p_plus_sqr2p_plus_rc_eq0p(r, sqr, 3), l); + let 
sqr = scalar_sqr(l); + let (l, r) = (x0p_plus_sqr1p_plus_rc_eq0p(r, sqr, 4), l); + let bar = bar_u8(l); + let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 5), l); + let bar = bar_u8(l); + let (l, r) = (x0p_plus_bar0p_plus_rc_eq0p(r, bar, 6), l); + let sqr = scalar_sqr(l); + let (l, r) = (x0p_plus_sqr1p_plus_rc_eq0p(r, sqr, 7), l); + let sqr = scalar_sqr(l); + x0p_plus_sqr1p_plus_y0p_eq0p(r, sqr, a) +} + +#[inline] +pub fn block_compress( + _rtz: &RoundingGuard, + l_0: [u64; 4], + l_1: [u64; 4], + l_2: [u64; 4], + r_0: [u64; 4], + r_1: [u64; 4], + r_2: [u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + let a_0 = l_0; + let a_1 = l_1; + let a_2 = l_2; + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let (l_0, r_0) = (wrapping_add(r_0, sqr_0), l_0); + let (l_1, r_1) = (wrapping_add(r_1, sqr_1), l_1); + let (l_2, r_2) = (wrapping_add(r_2, sqr_2), l_2); + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let (l_0, r_0) = (x0p_plus_sqr3p_plus_rc_eq0p(r_0, sqr_0, 0), l_0); + let (l_1, r_1) = (x0p_plus_sqr3p_plus_rc_eq0p(r_1, sqr_1, 0), l_1); + let (l_2, r_2) = (x0p_plus_sqr3p_plus_rc_eq0p(r_2, sqr_2, 0), l_2); + let bar_0 = bar_u8(l_0); + let bar_1 = bar_u8(l_1); + let bar_2 = bar_u8(l_2); + let (l_0, r_0) = (x2p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 1), l_0); + let (l_1, r_1) = (x2p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 1), l_1); + let (l_2, r_2) = (x2p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 1), l_2); + let bar_0 = bar_u8(l_0); + let bar_1 = bar_u8(l_1); + let bar_2 = bar_u8(l_2); + let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 2), l_0); + let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 2), l_1); + let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 2), l_2); + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let (l_0, r_0) = (x0p_plus_sqr2p_plus_rc_eq0p(r_0, sqr_0, 3), l_0); + let (l_1, r_1) = (x0p_plus_sqr2p_plus_rc_eq0p(r_1, sqr_1, 3), l_1); + let (l_2, r_2) = (x0p_plus_sqr2p_plus_rc_eq0p(r_2, sqr_2, 3), 
l_2); + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let (l_0, r_0) = (x0p_plus_sqr1p_plus_rc_eq0p(r_0, sqr_0, 4), l_0); + let (l_1, r_1) = (x0p_plus_sqr1p_plus_rc_eq0p(r_1, sqr_1, 4), l_1); + let (l_2, r_2) = (x0p_plus_sqr1p_plus_rc_eq0p(r_2, sqr_2, 4), l_2); + let bar_0 = bar_u8(l_0); + let bar_1 = bar_u8(l_1); + let bar_2 = bar_u8(l_2); + let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 5), l_0); + let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 5), l_1); + let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 5), l_2); + let bar_0 = bar_u8(l_0); + let bar_1 = bar_u8(l_1); + let bar_2 = bar_u8(l_2); + let (l_0, r_0) = (x0p_plus_bar0p_plus_rc_eq0p(r_0, bar_0, 6), l_0); + let (l_1, r_1) = (x0p_plus_bar0p_plus_rc_eq0p(r_1, bar_1, 6), l_1); + let (l_2, r_2) = (x0p_plus_bar0p_plus_rc_eq0p(r_2, bar_2, 6), l_2); + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let (l_0, r_0) = (x0p_plus_sqr1p_plus_rc_eq0p(r_0, sqr_0, 7), l_0); + let (l_1, r_1) = (x0p_plus_sqr1p_plus_rc_eq0p(r_1, sqr_1, 7), l_1); + let (l_2, r_2) = (x0p_plus_sqr1p_plus_rc_eq0p(r_2, sqr_2, 7), l_2); + let (sqr_0, sqr_1, sqr_2) = block_sqr(_rtz, l_0, l_1, l_2); + let l_0 = x0p_plus_sqr1p_plus_y0p_eq0p(r_0, sqr_0, a_0); + let l_1 = x0p_plus_sqr1p_plus_y0p_eq0p(r_1, sqr_1, a_1); + let l_2 = x0p_plus_sqr1p_plus_y0p_eq0p(r_2, sqr_2, a_2); + (l_0, l_1, l_2) +} + +#[inline(always)] +fn wrapping_add(x: [u64; 4], y: [u64; 4]) -> [u64; 4] { + let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) }; + let y_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(y) }; + let (lo, c) = x_u128[0].overflowing_add(y_u128[0]); + let (hi, _) = x_u128[1].carrying_add(y_u128[1], c); + unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) } +} + +#[inline(always)] +fn wrapping_sub(x: [u64; 4], y: [u64; 4]) -> [u64; 4] { + let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) }; + let y_u128 = unsafe { std::mem::transmute::<[u64; 4], 
[u128; 2]>(y) }; + let (lo, b) = x_u128[0].overflowing_sub(y_u128[0]); + let (hi, _) = x_u128[1].borrowing_sub(y_u128[1], b); + unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) } +} + +#[inline(always)] +fn overflowing_sub(x: [u64; 4], y: [u64; 4]) -> ([u64; 4], bool) { + let x_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(x) }; + let y_u128 = unsafe { std::mem::transmute::<[u64; 4], [u128; 2]>(y) }; + let (lo, b) = x_u128[0].overflowing_sub(y_u128[0]); + let (hi, b) = x_u128[1].borrowing_sub(y_u128[1], b); + ( + unsafe { std::mem::transmute::<[u128; 2], [u64; 4]>([lo, hi]) }, + b, + ) +} + +#[inline(always)] +fn reduce_1p(x: [u64; 4]) -> [u64; 4] { + let (xr, c) = overflowing_sub(x, U64_P); + if c { + x + } else { + xr + } +} + +#[inline(always)] +fn reduce_2p(x: [u64; 4]) -> [u64; 4] { + let msb0 = (x[3] >> 63) != 0; + let msb1 = ((x[3] << 1) >> 63) != 0; + if msb0 { + wrapping_sub(x, U64_2P) + } else if msb1 { + reduce_1p(wrapping_sub(x, U64_P)) + } else { + reduce_1p(x) + } +} + +#[inline(always)] +fn reduce_3p(x: [u64; 4]) -> [u64; 4] { + let msb0 = (x[3] >> 63) != 0; + let msb1 = ((x[3] << 1) >> 63) != 0; + if msb0 { + reduce_1p(wrapping_sub(x, U64_2P)) + } else if msb1 { + reduce_1p(wrapping_sub(x, U64_P)) + } else { + reduce_1p(x) + } +} + +#[inline(always)] +fn reduce_4p(x: [u64; 4]) -> [u64; 4] { + let msb = (x[3] >> 62) as u8; + if msb == 0 { + reduce_1p(x) + } else { + let r = if msb == 1 { + U64_P + } else if msb == 2 { + U64_2P + } else { + U64_3P + }; + reduce_1p(wrapping_sub(x, r)) + } +} + +#[inline(always)] +fn bar_u8(x: [u64; 4]) -> [u64; 4] { + let mut x_u8 = unsafe { std::mem::transmute::<[u64; 4], [u8; 32]>(x) }; + for i in 0..32 { + let v = x_u8[i]; + x_u8[i] = (v ^ ((!v).rotate_left(1) & v.rotate_left(2) & v.rotate_left(3))).rotate_left(1); + } + let x = unsafe { std::mem::transmute::<[u8; 32], [u64; 4]>(x_u8) }; + [x[2], x[3], x[0], x[1]] +} + +#[inline(always)] +fn x0p_plus_sqr3p_plus_rc_eq0p(x: [u64; 4], sqr: 
[u64; 4], rc_idx: usize) -> [u64; 4] { + let x_plus_sqr = wrapping_add(x, sqr); + let (tmp, b) = overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]); + if b { + wrapping_add(x_plus_sqr, RC[rc_idx]) + } else { + reduce_4p(tmp) + } +} + +#[inline(always)] +fn x2p_plus_bar0p_plus_rc_eq0p(x: [u64; 4], bar: [u64; 4], rc_idx: usize) -> [u64; 4] { + let msb0 = (bar[3] >> 62) as u8; + let msb1 = ((bar[3] << 2) >> 63) != 0; + let bar_plus_rc; + if msb0 == 0 { + bar_plus_rc = wrapping_add(bar, RC[rc_idx]); + } else if msb0 == 1 { + bar_plus_rc = wrapping_sub(bar, _1P_MINUS_RC[rc_idx]); + } else if msb0 == 2 { + bar_plus_rc = wrapping_sub(bar, _2P_MINUS_RC[rc_idx]); + } else if !msb1 { + bar_plus_rc = wrapping_sub(bar, _3P_MINUS_RC[rc_idx]); + } else { + bar_plus_rc = wrapping_sub(bar, _4P_MINUS_RC[rc_idx]); + } + let tmp = wrapping_add(bar_plus_rc, x); + reduce_4p(tmp) +} + +#[inline(always)] +fn x0p_plus_bar0p_plus_rc_eq0p(x: [u64; 4], bar: [u64; 4], rc_idx: usize) -> [u64; 4] { + let msb0 = (bar[3] >> 62) as u8; + let msb1 = ((bar[3] << 2) >> 63) != 0; + let bar_plus_rc; + if msb0 == 0 { + bar_plus_rc = wrapping_add(bar, RC[rc_idx]); + } else if msb0 == 1 { + bar_plus_rc = wrapping_sub(bar, _1P_MINUS_RC[rc_idx]); + } else if msb0 == 2 { + bar_plus_rc = wrapping_sub(bar, _2P_MINUS_RC[rc_idx]); + } else if !msb1 { + bar_plus_rc = wrapping_sub(bar, _3P_MINUS_RC[rc_idx]); + } else { + bar_plus_rc = wrapping_sub(bar, _4P_MINUS_RC[rc_idx]); + } + let tmp = wrapping_add(bar_plus_rc, x); + reduce_3p(tmp) +} + +#[inline(always)] +fn x0p_plus_sqr2p_plus_rc_eq0p(x: [u64; 4], sqr: [u64; 4], rc_idx: usize) -> [u64; 4] { + let x_plus_sqr = wrapping_add(x, sqr); + let (tmp, b) = overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]); + if b { + wrapping_add(x_plus_sqr, RC[rc_idx]) + } else { + reduce_3p(tmp) + } +} + +// #[inline(always)] +fn x0p_plus_sqr1p_plus_rc_eq0p(x: [u64; 4], sqr: [u64; 4], rc_idx: usize) -> [u64; 4] { + let x_plus_sqr = wrapping_add(x, sqr); + let (tmp, b) = 
overflowing_sub(x_plus_sqr, _1P_MINUS_RC[rc_idx]); + if b { + wrapping_add(x_plus_sqr, RC[rc_idx]) + } else { + reduce_2p(tmp) + } +} + +#[inline(always)] +fn x0p_plus_sqr1p_plus_y0p_eq0p(x: [u64; 4], sqr: [u64; 4], y: [u64; 4]) -> [u64; 4] { + let x_plus_sqr = wrapping_add(x, sqr); + let tmp = wrapping_add(x_plus_sqr, y); + reduce_3p(tmp) +} + +#[cfg(test)] +mod tests { + use {super::*, fp_rounding::with_rounding_mode}; + + #[test] + fn test_compress() { + let l = [ + 222647740394868259, + 1954084163509096643, + 7169380306955695398, + 3443405857474191768, + ]; + let r = [ + 650100192727553127, + 2847352847332889852, + 4016598436723263545, + 1563325641941659433, + ]; + let r = compress(l, r); + assert_eq!(r, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + } + + #[test] + fn test_block_compress() { + let l = [ + 222647740394868259, + 1954084163509096643, + 7169380306955695398, + 3443405857474191768, + ]; + let r = [ + 650100192727553127, + 2847352847332889852, + 4016598436723263545, + 1563325641941659433, + ]; + + let (r_0, r_1, r_2) = + unsafe { with_rounding_mode((), |guard, _| block_compress(guard, l, l, l, r, r, r)) }; + + assert_eq!(r_0, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + assert_eq!(r_1, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + assert_eq!(r_2, [ + 18095061023341165257, + 7738479748118643198, + 13857889271559191300, + 570841294491851342 + ]); + } +}