diff --git a/ci/big_quickcheck/src/lib.rs b/ci/big_quickcheck/src/lib.rs index 47ba2e14..57dd8b73 100644 --- a/ci/big_quickcheck/src/lib.rs +++ b/ci/big_quickcheck/src/lib.rs @@ -358,6 +358,63 @@ fn quickcheck_modpow() { qc.quickcheck(test_modpow as fn(i128, u128, i128) -> TestResult); } +#[test] +fn quickcheck_modinv() { + let gen = Gen::new(usize::max_value()); + let mut qc = QuickCheck::new().gen(gen); + + fn test_modinv(value: i128, modulus: i128) -> TestResult { + if modulus.is_zero() { + TestResult::discard() + } else { + let value = BigInt::from(value); + let modulus = BigInt::from(modulus); + match (value.modinv(&modulus), value.gcd(&modulus).is_one()) { + (None, false) => TestResult::passed(), + (None, true) => { + eprintln!("{}.modinv({}) -> None, expected Some(_)", value, modulus); + TestResult::failed() + } + (Some(inverse), false) => { + eprintln!( + "{}.modinv({}) -> Some({}), expected None", + value, modulus, inverse + ); + TestResult::failed() + } + (Some(inverse), true) => { + // The inverse should either be in [0,m) or (m,0] + let zero = BigInt::zero(); + if (modulus.is_positive() && !(zero <= inverse && inverse < modulus)) + || (modulus.is_negative() && !(modulus < inverse && inverse <= zero)) + { + eprintln!( + "{}.modinv({}) -> Some({}) is out of range", + value, modulus, inverse + ); + return TestResult::failed(); + } + + // We don't know the expected inverse, but we can verify the product ≡ 1 + let product = (&value * &inverse).mod_floor(&modulus); + let mod_one = BigInt::one().mod_floor(&modulus); + if product != mod_one { + eprintln!("{}.modinv({}) -> Some({})", value, modulus, inverse); + eprintln!( + "{} * {} ≡ {}, expected {}", + value, inverse, product, mod_one + ); + return TestResult::failed(); + } + TestResult::passed() + } + } + } + } + + qc.quickcheck(test_modinv as fn(i128, i128) -> TestResult); +} + #[test] fn quickcheck_to_float_equals_i128_cast() { let gen = Gen::new(usize::max_value()); diff --git a/src/bigint.rs b/src/bigint.rs index 97faa834..1716b585 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -1025,6 +1025,64 @@ impl BigInt { power::modpow(self, exponent, modulus) } + /// Returns the modular multiplicative inverse if it exists, otherwise `None`. + /// + /// This solves for `x` such that `self * x ≡ 1 (mod modulus)`. + /// Note that this rounds like `mod_floor`, not like the `%` operator, + /// which makes a difference when given a negative `self` or `modulus`. + /// The solution will be in the interval `[0, modulus)` for `modulus > 0`, + /// or in the interval `(modulus, 0]` for `modulus < 0`, + /// and it exists if and only if `gcd(self, modulus) == 1`. + /// + /// ``` + /// use num_bigint::BigInt; + /// use num_integer::Integer; + /// use num_traits::{One, Zero}; + /// + /// let m = BigInt::from(383); + /// + /// // Trivial cases + /// assert_eq!(BigInt::zero().modinv(&m), None); + /// assert_eq!(BigInt::one().modinv(&m), Some(BigInt::one())); + /// let neg1 = &m - 1u32; + /// assert_eq!(neg1.modinv(&m), Some(neg1)); + /// + /// // Positive self and modulus + /// let a = BigInt::from(271); + /// let x = a.modinv(&m).unwrap(); + /// assert_eq!(x, BigInt::from(106)); + /// assert_eq!(x.modinv(&m).unwrap(), a); + /// assert_eq!((&a * x).mod_floor(&m), BigInt::one()); + /// + /// // Negative self and positive modulus + /// let b = -&a; + /// let x = b.modinv(&m).unwrap(); + /// assert_eq!(x, BigInt::from(277)); + /// assert_eq!((&b * x).mod_floor(&m), BigInt::one()); + /// + /// // Positive self and negative modulus + /// let n = -&m; + /// let x = a.modinv(&n).unwrap(); + /// assert_eq!(x, BigInt::from(-277)); + /// assert_eq!((&a * x).mod_floor(&n), &n + 1); + /// + /// // Negative self and modulus + /// let x = b.modinv(&n).unwrap(); + /// assert_eq!(x, BigInt::from(-106)); + /// assert_eq!((&b * x).mod_floor(&n), &n + 1); + /// ``` + pub fn modinv(&self, modulus: &Self) -> Option { + let result = self.data.modinv(&modulus.data)?; + // The sign of the result follows the modulus, like `mod_floor`. + let (sign, mag) = match (self.is_negative(), modulus.is_negative()) { + (false, false) => (Plus, result), + (true, false) => (Plus, &modulus.data - result), + (false, true) => (Minus, &modulus.data - result), + (true, true) => (Minus, result), + }; + Some(BigInt::from_biguint(sign, mag)) + } + /// Returns the truncated principal square root of `self` -- /// see [`num_integer::Roots::sqrt()`]. pub fn sqrt(&self) -> Self { diff --git a/src/biguint.rs b/src/biguint.rs index 1554eb0f..3963a5cb 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -877,6 +877,86 @@ impl BigUint { power::modpow(self, exponent, modulus) } + /// Returns the modular multiplicative inverse if it exists, otherwise `None`. + /// + /// This solves for `x` in the interval `[0, modulus)` such that `self * x ≡ 1 (mod modulus)`. + /// The solution exists if and only if `gcd(self, modulus) == 1`. + /// + /// ``` + /// use num_bigint::BigUint; + /// use num_traits::{One, Zero}; + /// + /// let m = BigUint::from(383_u32); + /// + /// // Trivial cases + /// assert_eq!(BigUint::zero().modinv(&m), None); + /// assert_eq!(BigUint::one().modinv(&m), Some(BigUint::one())); + /// let neg1 = &m - 1u32; + /// assert_eq!(neg1.modinv(&m), Some(neg1)); + /// + /// let a = BigUint::from(271_u32); + /// let x = a.modinv(&m).unwrap(); + /// assert_eq!(x, BigUint::from(106_u32)); + /// assert_eq!(x.modinv(&m).unwrap(), a); + /// assert!((a * x % m).is_one()); + /// ``` + pub fn modinv(&self, modulus: &Self) -> Option { + // Based on the inverse pseudocode listed here: + // https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers + // TODO: convert to extended *binary* GCD so we're shifting instead of dividing. + + assert!( + !modulus.is_zero(), + "attempt to calculate with zero modulus!" + ); + if modulus.is_one() { + return Some(Self::zero()); + } + + let mut r0; // = modulus.clone(); + let mut r1 = self % modulus; + let mut t0; // = Self::zero(); + let mut t1; // = Self::one(); + + // Lift and simplify the first iteration to avoid some initial allocations. + if r1.is_zero() { + return None; + } else if r1.is_one() { + return Some(r1); + } else { + let (q, r2) = modulus.div_rem(&r1); + if r2.is_zero() { + return None; + } + r0 = r1; + r1 = r2; + t0 = Self::one(); + t1 = modulus - q; + } + + while !r1.is_zero() { + let (q, r2) = r0.div_rem(&r1); + r0 = r1; + r1 = r2; + + // let t2 = (t0 - q * t1) % modulus; + let qt1 = q * &t1 % modulus; + let t2 = if t0 < qt1 { + t0 + (modulus - qt1) + } else { + t0 - qt1 + }; + t0 = t1; + t1 = t2; + } + + if r0.is_one() { + Some(t0) + } else { + None + } + } + /// Returns the truncated principal square root of `self` -- /// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt) pub fn sqrt(&self) -> Self {