diff --git a/src/library/scala/math/BigInt.scala b/src/library/scala/math/BigInt.scala
index a88b1371ccc6..d0018fc8e970 100644
--- a/src/library/scala/math/BigInt.scala
+++ b/src/library/scala/math/BigInt.scala
@@ -124,6 +124,46 @@ object BigInt {
/** Implicit conversion from `java.math.BigInteger` to `scala.BigInt`.
*/
implicit def javaBigInteger2bigInt(x: BigInteger): BigInt = apply(x)
+
+ /**
+ * Returns the greatest common divisor of a and b. Returns 0 if a == 0 && b == 0.
+ */
+ private def longGcd(a: Long, b: Long): Long = {
+ // code adapted from Google Guava LongMath.java / gcd
+ if (a == 0) { // 0 % b == 0, so b divides a, but the converse doesn't hold.
+ // BigInteger.gcd is consistent with this decision.
+ return b
+ }
+ else if (b == 0) return a // similar logic
+ /*
+ * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm. This is
+ * >60% faster than the Euclidean algorithm in benchmarks.
+ */
+ val aTwos = java.lang.Long.numberOfTrailingZeros(a)
+ var a1 = a >> aTwos // divide out all 2s
+
+ val bTwos = java.lang.Long.numberOfTrailingZeros(b)
+ var b1 = b >> bTwos
+ while (a1 != b1) { // both a, b are odd
+ // The key to the binary GCD algorithm is as follows:
+ // Both a1 and b1 are odd. Assume a1 > b1; then gcd(a1 - b1, b1) = gcd(a1, b1).
+ // But in gcd(a1 - b1, b1), a1 - b1 is even and b1 is odd, so we can divide out powers of two.
+ // We bend over backwards to avoid branching, adapting a technique from
+ // http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
+ val delta = a1 - b1 // can't overflow, since a1 and b1 are nonnegative
+ val minDeltaOrZero = delta & (delta >> (java.lang.Long.SIZE - 1))
+ // equivalent to Math.min(delta, 0)
+ a1 = delta - minDeltaOrZero - minDeltaOrZero // sets a to Math.abs(a - b)
+
+ // a is now nonnegative and even
+ b1 += minDeltaOrZero // sets b to min(old a, b)
+
+ a1 >>= java.lang.Long.numberOfTrailingZeros(a1) // divide out all 2s, since 2 doesn't divide b
+
+ }
+ a1 << scala.math.min(aTwos, bTwos)
+ }
+
}
/** A type with efficient encoding of arbitrary integers.
@@ -162,6 +202,10 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
else Long.MinValue
)
+ /** Returns whether the integer is encoded in the Long. Returns true for all values fitting in a Long except
+ * Long.MinValue. */
+ private def longEncoding: Boolean = _long != Long.MinValue
+
def bigInteger: BigInteger = {
val read = _bigInteger
if (read ne null) read else {
@@ -185,11 +229,13 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
case that: Float => isValidFloat && toFloat == that
case x => isValidLong && unifiedPrimitiveEquals(x)
}
- override def isValidByte: Boolean = this >= Byte.MinValue && this <= Byte.MaxValue
- override def isValidShort: Boolean = this >= Short.MinValue && this <= Short.MaxValue
- override def isValidChar: Boolean = this >= Char.MinValue && this <= Char.MaxValue
- override def isValidInt: Boolean = this >= Int.MinValue && this <= Int.MaxValue
- def isValidLong: Boolean = this >= Long.MinValue && this <= Long.MaxValue
+
+ override def isValidByte: Boolean = _long >= Byte.MinValue && _long <= Byte.MaxValue /* && longEncoding */
+ override def isValidShort: Boolean = _long >= Short.MinValue && _long <= Short.MaxValue /* && longEncoding */
+ override def isValidChar: Boolean = _long >= Char.MinValue && _long <= Char.MaxValue /* && longEncoding */
+ override def isValidInt: Boolean = _long >= Int.MinValue && _long <= Int.MaxValue /* && longEncoding */
+ def isValidLong: Boolean = longEncoding || _bigInteger == BigInt.longMinValueBigInteger // rhs of || tests == Long.MinValue
+
/** Returns `true` iff this can be represented exactly by [[scala.Float]]; otherwise returns `false`.
*/
def isValidFloat: Boolean = {
@@ -231,151 +277,266 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
/** Compares this BigInt with the specified BigInt for equality.
*/
- def equals (that: BigInt): Boolean = compare(that) == 0
+ def equals(that: BigInt): Boolean =
+ if (this.longEncoding)
+ that.longEncoding && (this._long == that._long)
+ else
+ !that.longEncoding && (this._bigInteger == that._bigInteger)
/** Compares this BigInt with the specified BigInt
*/
- def compare (that: BigInt): Int = this.bigInteger.compareTo(that.bigInteger)
+ def compare(that: BigInt): Int =
+ if (this.longEncoding) {
+ if (that.longEncoding) java.lang.Long.compare(this._long, that._long) else -that._bigInteger.signum()
+ } else {
+ if (that.longEncoding) _bigInteger.signum() else this._bigInteger.compareTo(that._bigInteger)
+ }
/** Addition of BigInts
*/
- def + (that: BigInt): BigInt = BigInt(this.bigInteger.add(that.bigInteger))
+ def +(that: BigInt): BigInt = {
+ if (this.longEncoding && that.longEncoding) { // fast path
+ val x = this._long
+ val y = that._long
+ val z = x + y
+ if ((~(x ^ y) & (x ^ z)) >= 0L) return BigInt(z)
+ }
+ BigInt(this.bigInteger.add(that.bigInteger))
+ }
/** Subtraction of BigInts
*/
- def - (that: BigInt): BigInt = BigInt(this.bigInteger.subtract(that.bigInteger))
+ def -(that: BigInt): BigInt = {
+ if (this.longEncoding && that.longEncoding) { // fast path
+ val x = this._long
+ val y = that._long
+ val z = x - y
+ if (((x ^ y) & (x ^ z)) >= 0L) return BigInt(z)
+ }
+ BigInt(this.bigInteger.subtract(that.bigInteger))
+ }
/** Multiplication of BigInts
*/
- def * (that: BigInt): BigInt = BigInt(this.bigInteger.multiply(that.bigInteger))
+ def *(that: BigInt): BigInt = {
+ if (this.longEncoding && that.longEncoding) { // fast path
+ val x = this._long
+ val y = that._long
+ val z = x * y
+ // original code checks the y != Long.MinValue, but when longEncoding is true, that is never the case
+ // if (x == 0 || (y == z / x && !(x == -1 && y == Long.MinValue))) return BigInt(z)
+ if (x == 0 || y == z / x) return BigInt(z)
+ }
+ BigInt(this.bigInteger.multiply(that.bigInteger))
+ }
/** Division of BigInts
*/
- def / (that: BigInt): BigInt = BigInt(this.bigInteger.divide(that.bigInteger))
+ def /(that: BigInt): BigInt =
+ // in the fast path, note that the original code avoided storing -Long.MinValue in a long:
+ // if (this._long != Long.MinValue || that._long != -1) return BigInt(this._long / that._long)
+ // but we know this._long cannot be Long.MinValue, because Long.MinValue is the tag for bigger integers
+ if (this.longEncoding && that.longEncoding) BigInt(this._long / that._long)
+ else BigInt(this.bigInteger.divide(that.bigInteger))
/** Remainder of BigInts
*/
- def % (that: BigInt): BigInt = BigInt(this.bigInteger.remainder(that.bigInteger))
+ def %(that: BigInt): BigInt =
+ // see / for the original logic regarding Long.MinValue
+ if (this.longEncoding && that.longEncoding) BigInt(this._long % that._long)
+ else BigInt(this.bigInteger.remainder(that.bigInteger))
/** Returns a pair of two BigInts containing (this / that) and (this % that).
*/
- def /% (that: BigInt): (BigInt, BigInt) = {
- val dr = this.bigInteger.divideAndRemainder(that.bigInteger)
- (BigInt(dr(0)), BigInt(dr(1)))
- }
+ def /%(that: BigInt): (BigInt, BigInt) =
+ if (this.longEncoding && that.longEncoding) {
+ val x = this._long
+ val y = that._long
+ // original line: if (x != Long.MinValue || y != -1) return (BigInt(x / y), BigInt(x % y))
+ (BigInt(x / y), BigInt(x % y))
+ } else {
+ val dr = this.bigInteger.divideAndRemainder(that.bigInteger)
+ (BigInt(dr(0)), BigInt(dr(1)))
+ }
/** Leftshift of BigInt
*/
- def << (n: Int): BigInt = BigInt(this.bigInteger.shiftLeft(n))
+ def <<(n: Int): BigInt =
+ if (longEncoding && n <= 0) (this >> (-n)) else BigInt(this.bigInteger.shiftLeft(n))
/** (Signed) rightshift of BigInt
*/
- def >> (n: Int): BigInt = BigInt(this.bigInteger.shiftRight(n))
-
+ def >>(n: Int): BigInt =
+ if (longEncoding && n >= 0) {
+ if (n < 64) BigInt(_long >> n)
+ else if (_long < 0) BigInt(-1)
+ else BigInt(0) // for _long >= 0
+ } else BigInt(this.bigInteger.shiftRight(n))
+
/** Bitwise and of BigInts
*/
- def & (that: BigInt): BigInt = BigInt(this.bigInteger.and(that.bigInteger))
+ def &(that: BigInt): BigInt =
+ if (this.longEncoding && that.longEncoding)
+ BigInt(this._long & that._long)
+ else BigInt(this.bigInteger.and(that.bigInteger))
/** Bitwise or of BigInts
*/
- def | (that: BigInt): BigInt = BigInt(this.bigInteger.or (that.bigInteger))
+ def |(that: BigInt): BigInt =
+ if (this.longEncoding && that.longEncoding)
+ BigInt(this._long | that._long)
+ else BigInt(this.bigInteger.or(that.bigInteger))
/** Bitwise exclusive-or of BigInts
*/
- def ^ (that: BigInt): BigInt = BigInt(this.bigInteger.xor(that.bigInteger))
+ def ^(that: BigInt): BigInt =
+ if (this.longEncoding && that.longEncoding)
+ BigInt(this._long ^ that._long)
+ else BigInt(this.bigInteger.xor(that.bigInteger))
/** Bitwise and-not of BigInts. Returns a BigInt whose value is (this & ~that).
*/
- def &~ (that: BigInt): BigInt = BigInt(this.bigInteger.andNot(that.bigInteger))
+ def &~(that: BigInt): BigInt =
+ if (this.longEncoding && that.longEncoding)
+ BigInt(this._long & ~that._long)
+ else BigInt(this.bigInteger.andNot(that.bigInteger))
/** Returns the greatest common divisor of abs(this) and abs(that)
*/
- def gcd (that: BigInt): BigInt = BigInt(this.bigInteger.gcd(that.bigInteger))
+ def gcd(that: BigInt): BigInt =
+ if (this.longEncoding) {
+ if (this._long == 0) return that.abs
+ // if (this._long == Long.MinValue) return (-this) gcd that
+ // this != 0 && this != Long.MinValue
+ if (that.longEncoding) {
+ if (that._long == 0) return this.abs
+ // if (that._long == Long.MinValue) return this gcd (-that)
+ BigInt(BigInt.longGcd(this._long.abs, that._long.abs))
+ } else that gcd this // force the BigInteger on the left
+ } else {
+ // this is not a valid long
+ if (that.longEncoding) {
+ if (that._long == 0) return this.abs
+ // if (that._long == Long.MinValue) return this gcd (-that)
+ val red = (this._bigInteger mod BigInteger.valueOf(that._long.abs)).longValue()
+ if (red == 0) return that.abs
+ BigInt(BigInt.longGcd(that._long.abs, red))
+ } else BigInt(this.bigInteger.gcd(that.bigInteger))
+ }
+
/** Returns a BigInt whose value is (this mod that).
* This method differs from `%` in that it always returns a non-negative BigInt.
* @param that A positive number
*/
- def mod (that: BigInt): BigInt = BigInt(this.bigInteger.mod(that.bigInteger))
+ def mod(that: BigInt): BigInt =
+ if (this.longEncoding && that.longEncoding) {
+ val res = this._long % that._long
+ if (res >= 0) BigInt(res) else BigInt(res + that._long)
+ } else BigInt(this.bigInteger.mod(that.bigInteger))
/** Returns the minimum of this and that
*/
- def min (that: BigInt): BigInt = BigInt(this.bigInteger.min(that.bigInteger))
+ def min(that: BigInt): BigInt =
+ if (this <= that) this else that
/** Returns the maximum of this and that
*/
- def max (that: BigInt): BigInt = BigInt(this.bigInteger.max(that.bigInteger))
+ def max(that: BigInt): BigInt =
+ if (this >= that) this else that
/** Returns a BigInt whose value is (this raised to the power of exp).
*/
- def pow (exp: Int): BigInt = BigInt(this.bigInteger.pow(exp))
+ def pow(exp: Int): BigInt = BigInt(this.bigInteger.pow(exp))
/** Returns a BigInt whose value is
* (this raised to the power of exp modulo m).
*/
- def modPow (exp: BigInt, m: BigInt): BigInt =
- BigInt(this.bigInteger.modPow(exp.bigInteger, m.bigInteger))
+ def modPow(exp: BigInt, m: BigInt): BigInt = BigInt(this.bigInteger.modPow(exp.bigInteger, m.bigInteger))
/** Returns a BigInt whose value is (the inverse of this modulo m).
*/
- def modInverse (m: BigInt): BigInt = BigInt(this.bigInteger.modInverse(m.bigInteger))
+ def modInverse(m: BigInt): BigInt = BigInt(this.bigInteger.modInverse(m.bigInteger))
/** Returns a BigInt whose value is the negation of this BigInt
*/
- def unary_- : BigInt = BigInt(this.bigInteger.negate())
+ def unary_- : BigInt = if (longEncoding) BigInt(-_long) else BigInt(this.bigInteger.negate())
/** Returns the absolute value of this BigInt
*/
- def abs: BigInt = BigInt(this.bigInteger.abs())
+ def abs: BigInt = if (signum < 0) -this else this
/** Returns the sign of this BigInt;
* -1 if it is less than 0,
* +1 if it is greater than 0,
* 0 if it is equal to 0.
*/
- def signum: Int = this.bigInteger.signum()
+ def signum: Int = if (longEncoding) java.lang.Long.signum(_long) else _bigInteger.signum()
/** Returns the sign of this BigInt;
* -1 if it is less than 0,
* +1 if it is greater than 0,
* 0 if it is equal to 0.
*/
- def sign: BigInt = signum
+ def sign: BigInt = BigInt(signum)
/** Returns the bitwise complement of this BigInt
*/
- def unary_~ : BigInt = BigInt(this.bigInteger.not())
+ def unary_~ : BigInt =
+ // it is equal to -(this + 1)
+ if (longEncoding && _long != Long.MaxValue) BigInt(-(_long + 1)) else BigInt(this.bigInteger.not())
/** Returns true if and only if the designated bit is set.
*/
- def testBit (n: Int): Boolean = this.bigInteger.testBit(n)
+ def testBit(n: Int): Boolean =
+ if (longEncoding) {
+ if (n <= 63)
+ (_long & (1L << n)) != 0
+ else
+ _long < 0 // give the sign bit
+ } else _bigInteger.testBit(n)
/** Returns a BigInt whose value is equivalent to this BigInt with the designated bit set.
*/
- def setBit (n: Int): BigInt = BigInt(this.bigInteger.setBit(n))
+ def setBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63
+ if (longEncoding && n <= 62) BigInt(_long | (1L << n)) else BigInt(this.bigInteger.setBit(n))
/** Returns a BigInt whose value is equivalent to this BigInt with the designated bit cleared.
*/
- def clearBit(n: Int): BigInt = BigInt(this.bigInteger.clearBit(n))
+ def clearBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63
+ if (longEncoding && n <= 62) BigInt(_long & ~(1L << n)) else BigInt(this.bigInteger.clearBit(n))
/** Returns a BigInt whose value is equivalent to this BigInt with the designated bit flipped.
*/
- def flipBit (n: Int): BigInt = BigInt(this.bigInteger.flipBit(n))
+ def flipBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63
+ if (longEncoding && n <= 62) BigInt(_long ^ (1L << n)) else BigInt(this.bigInteger.flipBit(n))
/** Returns the index of the rightmost (lowest-order) one bit in this BigInt
* (the number of zero bits to the right of the rightmost one bit).
*/
- def lowestSetBit: Int = this.bigInteger.getLowestSetBit()
+ def lowestSetBit: Int =
+ if (longEncoding) {
+ if (_long == 0) -1 else java.lang.Long.numberOfTrailingZeros(_long)
+ } else this.bigInteger.getLowestSetBit()
/** Returns the number of bits in the minimal two's-complement representation of this BigInt,
* excluding a sign bit.
*/
- def bitLength: Int = this.bigInteger.bitLength()
+ def bitLength: Int =
+ // bitLength is defined as ceil(log2(this < 0 ? -this : this + 1)))
+ // where ceil(log2(x)) = 64 - numberOfLeadingZeros(x - 1)
+ if (longEncoding) {
+ if (_long < 0) 64 - java.lang.Long.numberOfLeadingZeros(-(_long + 1)) // takes care of Long.MinValue
+ else 64 - java.lang.Long.numberOfLeadingZeros(_long)
+ } else _bigInteger.bitLength()
/** Returns the number of bits in the two's complement representation of this BigInt
* that differ from its sign bit.
*/
- def bitCount: Int = this.bigInteger.bitCount()
+ def bitCount: Int =
+ if (longEncoding) {
+ if (_long < 0) java.lang.Long.bitCount(-(_long + 1)) else java.lang.Long.bitCount(_long)
+ } else this.bigInteger.bitCount()
/** Returns true if this BigInt is probably prime, false if it's definitely composite.
* @param certainty a measure of the uncertainty that the caller is willing to tolerate:
@@ -413,7 +574,7 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
* overall magnitude of the BigInt value as well as return a result with
* the opposite sign.
*/
- def intValue: Int = this.bigInteger.intValue
+ def intValue: Int = if (longEncoding) _long.toInt else this.bigInteger.intValue
/** Converts this BigInt to a long.
* If the BigInt is too big to fit in a long, only the low-order 64 bits
@@ -421,7 +582,7 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
* overall magnitude of the BigInt value as well as return a result with
* the opposite sign.
*/
- def longValue: Long = this.bigInteger.longValue
+ def longValue: Long = if (longEncoding) _long else _bigInteger.longValue
/** Converts this `BigInt` to a `float`.
* If this `BigInt` has too great a magnitude to represent as a float,
@@ -435,7 +596,9 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
* it will be converted to `Double.NEGATIVE_INFINITY` or
* `Double.POSITIVE_INFINITY` as appropriate.
*/
- def doubleValue: Double = this.bigInteger.doubleValue
+ def doubleValue: Double =
+ if (isValidLong && (-(1L << 53) <= _long && _long <= (1L << 53))) _long.toDouble
+ else this.bigInteger.doubleValue
/** Create a `NumericRange[BigInt]` in range `[start;end)`
* with the specified step, where start is the target BigInt.
@@ -452,7 +615,7 @@ final class BigInt private (private var _bigInteger: BigInteger, private val _lo
/** Returns the decimal String representation of this BigInt.
*/
- override def toString(): String = this.bigInteger.toString()
+ override def toString(): String = if (longEncoding) _long.toString() else _bigInteger.toString()
/** Returns the String representation in the specified radix of this BigInt.
*/