diff --git a/project/MimaFilters.scala b/project/MimaFilters.scala index 0b35213fffec..0cde580c4f63 100644 --- a/project/MimaFilters.scala +++ b/project/MimaFilters.scala @@ -25,6 +25,11 @@ object MimaFilters extends AutoPlugin { // don't publish the artifact built with JDK 11 anyways ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.convert.JavaCollectionWrappers#IteratorWrapper.asIterator"), + // for the method this(Long)Unit in class scala.math.BigInt does not have a correspondent in other versions + // this new constructor is nevertheless private, and can only be called from the BigInt class and its companion + // object + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.math.BigInt.this"), + // PR: https://github.com/scala/scala/pull/9336; remove after re-STARR ProblemFilters.exclude[MissingTypesProblem]("scala.deprecatedOverriding"), ProblemFilters.exclude[MissingTypesProblem]("scala.deprecatedInheritance"), diff --git a/src/library/scala/math/BigInt.scala b/src/library/scala/math/BigInt.scala index 20cec9742ed2..6ea371328d9e 100644 --- a/src/library/scala/math/BigInt.scala +++ b/src/library/scala/math/BigInt.scala @@ -21,9 +21,23 @@ import scala.collection.immutable.NumericRange object BigInt { + private val longMinValueBigInteger = BigInteger.valueOf(Long.MinValue) + private val longMinValue = new BigInt(longMinValueBigInteger, Long.MinValue) + private[this] val minCached = -1024 private[this] val maxCached = 1024 private[this] val cache = new Array[BigInt](maxCached - minCached + 1) + + private[this] def getCached(i: Int): BigInt = { + val offset = i - minCached + var n = cache(offset) + if (n eq null) { + n = new BigInt(null, i.toLong) + cache(offset) = n + } + n + } + private val minusOne = BigInteger.valueOf(-1) /** Constructs a `BigInt` whose value is equal to that of the @@ -33,12 +47,7 @@ object BigInt { * @return the constructed `BigInt` */ def apply(i: Int): BigInt = - if (minCached <= i && i <= maxCached) { - val offset = i - minCached - var n = cache(offset) - if (n eq null) { n = new BigInt(BigInteger.valueOf(i.toLong)); cache(offset) = n } - n - } else new BigInt(BigInteger.valueOf(i.toLong)) + if (minCached <= i && i <= maxCached) getCached(i) else apply(i: Long) /** Constructs a `BigInt` whose value is equal to that of the * specified long value. @@ -47,14 +56,15 @@ object BigInt { * @return the constructed `BigInt` */ def apply(l: Long): BigInt = - if (minCached <= l && l <= maxCached) apply(l.toInt) - else new BigInt(BigInteger.valueOf(l)) + if (minCached <= l && l <= maxCached) getCached(l.toInt) else { + if (l == Long.MinValue) longMinValue else new BigInt(null, l) + } /** Translates a byte array containing the two's-complement binary * representation of a BigInt into a BigInt. */ def apply(x: Array[Byte]): BigInt = - new BigInt(new BigInteger(x)) + apply(new BigInteger(x)) /** Translates the sign-magnitude representation of a BigInt into a BigInt. * @@ -64,40 +74,44 @@ object BigInt { * the number. */ def apply(signum: Int, magnitude: Array[Byte]): BigInt = - new BigInt(new BigInteger(signum, magnitude)) + apply(new BigInteger(signum, magnitude)) /** Constructs a randomly generated positive BigInt that is probably prime, * with the specified bitLength. */ def apply(bitlength: Int, certainty: Int, rnd: scala.util.Random): BigInt = - new BigInt(new BigInteger(bitlength, certainty, rnd.self)) + apply(new BigInteger(bitlength, certainty, rnd.self)) /** Constructs a randomly generated BigInt, uniformly distributed over the * range `0` to `(2 ^ numBits - 1)`, inclusive. */ def apply(numbits: Int, rnd: scala.util.Random): BigInt = - new BigInt(new BigInteger(numbits, rnd.self)) + apply(new BigInteger(numbits, rnd.self)) /** Translates the decimal String representation of a BigInt into a BigInt. */ def apply(x: String): BigInt = - new BigInt(new BigInteger(x)) + apply(new BigInteger(x)) /** Translates the string representation of a `BigInt` in the * specified `radix` into a BigInt. */ def apply(x: String, radix: Int): BigInt = - new BigInt(new BigInteger(x, radix)) + apply(new BigInteger(x, radix)) /** Translates a `java.math.BigInteger` into a BigInt. */ - def apply(x: BigInteger): BigInt = - new BigInt(x) + def apply(x: BigInteger): BigInt = { + if (x.bitLength <= 63) { + val l = x.longValue + if (minCached <= l && l <= maxCached) getCached(l.toInt) else new BigInt(x, l) + } else new BigInt(x, Long.MinValue) + } /** Returns a positive BigInt that is probably prime, with the specified bitLength. */ def probablePrime(bitLength: Int, rnd: scala.util.Random): BigInt = - new BigInt(BigInteger.probablePrime(bitLength, rnd.self)) + apply(BigInteger.probablePrime(bitLength, rnd.self)) /** Implicit conversion from `Int` to `BigInt`. */ @@ -110,14 +124,103 @@ object BigInt { /** Implicit conversion from `java.math.BigInteger` to `scala.BigInt`. */ implicit def javaBigInteger2bigInt(x: BigInteger): BigInt = apply(x) + + // this method is adapted from Google Guava's version at + // https://github.com/google/guava/blob/master/guava/src/com/google/common/math/LongMath.java + // that code carries the following notice: + // * Copyright (C) 2011 The Guava Authors + // * + // * Licensed under the Apache License, Version 2.0 (the "License") + /** + * Returns the greatest common divisor of a and b. Returns 0 if a == 0 && b == 0. + */ + private def longGcd(a: Long, b: Long): Long = { + // both a and b must be >= 0 + if (a == 0) { // 0 % b == 0, so b divides a, but the converse doesn't hold. + // BigInteger.gcd is consistent with this decision. + return b + } + else if (b == 0) return a // similar logic + /* + * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm. This is + * >60% faster than the Euclidean algorithm in benchmarks. + */ + val aTwos = java.lang.Long.numberOfTrailingZeros(a) + var a1 = a >> aTwos // divide out all 2s + + val bTwos = java.lang.Long.numberOfTrailingZeros(b) + var b1 = b >> bTwos + while (a1 != b1) { // both a, b are odd + // The key to the binary GCD algorithm is as follows: + // Both a1 and b1 are odd. Assume a1 > b1; then gcd(a1 - b1, b1) = gcd(a1, b1). + // But in gcd(a1 - b1, b1), a1 - b1 is even and b1 is odd, so we can divide out powers of two. + // We bend over backwards to avoid branching, adapting a technique from + // http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax + val delta = a1 - b1 // can't overflow, since a1 and b1 are nonnegative + val minDeltaOrZero = delta & (delta >> (java.lang.Long.SIZE - 1)) + // equivalent to Math.min(delta, 0) + a1 = delta - minDeltaOrZero - minDeltaOrZero // sets a to Math.abs(a - b) + + // a is now nonnegative and even + b1 += minDeltaOrZero // sets b to min(old a, b) + + a1 >>= java.lang.Long.numberOfTrailingZeros(a1) // divide out all 2s, since 2 doesn't divide b + + } + a1 << scala.math.min(aTwos, bTwos) + } + } -final class BigInt(val bigInteger: BigInteger) +/** A type with efficient encoding of arbitrary integers. + * + * It wraps `java.math.BigInteger`, with optimization for small values that can be encoded in a `Long`. + */ +final class BigInt private (private var _bigInteger: BigInteger, private val _long: Long) extends ScalaNumber with ScalaNumericConversions with Serializable with Ordered[BigInt] { + // The class has a special encoding for integer that fit in a Long *and* are not equal to Long.MinValue. + // + // The Long value Long.MinValue is a tag specifying that the integer is encoded in the BigInteger field. + // + // There are three possible states for the class fields (_bigInteger, _long) + // 1. (null, l) where l != Long.MinValue, encodes the integer "l" + // 2. (b, l) where l != Long.MinValue; then b is a BigInteger with value l, encodes "l" == "b" + // 3a. (b, Long.MinValue) where b == Long.MinValue, encodes Long.MinValue + // 3b. (b, Long.MinValue) where b does not fit in a Long, encodes "b" + // + // There is only one possible transition 1. -> 2., when the method .bigInteger is called, then the field + // _bigInteger caches the result. + // + // The case 3a. is the only one where the BigInteger could actually fit in a Long, but as its value is used as a + // tag, we'll take the slow path instead. + // + // Additionally, we know that if this.isValidLong is true, then _long is the encoded value. + + /** Public constructor present for compatibility. Use the BigInt.apply companion object method instead. */ + def this(bigInteger: BigInteger) = this( + bigInteger, // even if it is a short BigInteger, we cache the instance + if (bigInteger.bitLength <= 63) + bigInteger.longValue // if _bigInteger is actually equal to Long.MinValue, no big deal, its value acts as a tag + else Long.MinValue + ) + + /** Returns whether the integer is encoded in the Long. Returns true for all values fitting in a Long except + * Long.MinValue. */ + private def longEncoding: Boolean = _long != Long.MinValue + + def bigInteger: BigInteger = { + val read = _bigInteger + if (read ne null) read else { + val write = BigInteger.valueOf(_long) + _bigInteger = write // reference assignment is atomic; this is multi-thread safe (if possibly wasteful) + write + } + } + /** Returns the hash code for this BigInt. */ override def hashCode(): Int = if (isValidLong) unifiedPrimitiveHashcode @@ -132,11 +235,13 @@ final class BigInt(val bigInteger: BigInteger) case that: Float => isValidFloat && toFloat == that case x => isValidLong && unifiedPrimitiveEquals(x) } - override def isValidByte: Boolean = this >= Byte.MinValue && this <= Byte.MaxValue - override def isValidShort: Boolean = this >= Short.MinValue && this <= Short.MaxValue - override def isValidChar: Boolean = this >= Char.MinValue && this <= Char.MaxValue - override def isValidInt: Boolean = this >= Int.MinValue && this <= Int.MaxValue - def isValidLong: Boolean = this >= Long.MinValue && this <= Long.MaxValue + + override def isValidByte: Boolean = _long >= Byte.MinValue && _long <= Byte.MaxValue /* && longEncoding */ + override def isValidShort: Boolean = _long >= Short.MinValue && _long <= Short.MaxValue /* && longEncoding */ + override def isValidChar: Boolean = _long >= Char.MinValue && _long <= Char.MaxValue /* && longEncoding */ + override def isValidInt: Boolean = _long >= Int.MinValue && _long <= Int.MaxValue /* && longEncoding */ + def isValidLong: Boolean = longEncoding || _bigInteger == BigInt.longMinValueBigInteger // rhs of || tests == Long.MinValue + /** Returns `true` iff this can be represented exactly by [[scala.Float]]; otherwise returns `false`. */ def isValidFloat: Boolean = { @@ -178,151 +283,266 @@ final class BigInt(val bigInteger: BigInteger) /** Compares this BigInt with the specified BigInt for equality. */ - def equals (that: BigInt): Boolean = compare(that) == 0 + def equals(that: BigInt): Boolean = + if (this.longEncoding) + that.longEncoding && (this._long == that._long) + else + !that.longEncoding && (this._bigInteger == that._bigInteger) /** Compares this BigInt with the specified BigInt */ - def compare (that: BigInt): Int = this.bigInteger.compareTo(that.bigInteger) + def compare(that: BigInt): Int = + if (this.longEncoding) { + if (that.longEncoding) java.lang.Long.compare(this._long, that._long) else -that._bigInteger.signum() + } else { + if (that.longEncoding) _bigInteger.signum() else this._bigInteger.compareTo(that._bigInteger) + } /** Addition of BigInts */ - def + (that: BigInt): BigInt = new BigInt(this.bigInteger.add(that.bigInteger)) + def +(that: BigInt): BigInt = { + if (this.longEncoding && that.longEncoding) { // fast path + val x = this._long + val y = that._long + val z = x + y + if ((~(x ^ y) & (x ^ z)) >= 0L) return BigInt(z) + } + BigInt(this.bigInteger.add(that.bigInteger)) + } /** Subtraction of BigInts */ - def - (that: BigInt): BigInt = new BigInt(this.bigInteger.subtract(that.bigInteger)) + def -(that: BigInt): BigInt = { + if (this.longEncoding && that.longEncoding) { // fast path + val x = this._long + val y = that._long + val z = x - y + if (((x ^ y) & (x ^ z)) >= 0L) return BigInt(z) + } + BigInt(this.bigInteger.subtract(that.bigInteger)) + } /** Multiplication of BigInts */ - def * (that: BigInt): BigInt = new BigInt(this.bigInteger.multiply(that.bigInteger)) + def *(that: BigInt): BigInt = { + if (this.longEncoding && that.longEncoding) { // fast path + val x = this._long + val y = that._long + val z = x * y + // original code checks the y != Long.MinValue, but when longEncoding is true, that is never the case + // if (x == 0 || (y == z / x && !(x == -1 && y == Long.MinValue))) return BigInt(z) + if (x == 0 || y == z / x) return BigInt(z) + } + BigInt(this.bigInteger.multiply(that.bigInteger)) + } /** Division of BigInts */ - def / (that: BigInt): BigInt = new BigInt(this.bigInteger.divide(that.bigInteger)) + def /(that: BigInt): BigInt = + // in the fast path, note that the original code avoided storing -Long.MinValue in a long: + // if (this._long != Long.MinValue || that._long != -1) return BigInt(this._long / that._long) + // but we know this._long cannot be Long.MinValue, because Long.MinValue is the tag for bigger integers + if (this.longEncoding && that.longEncoding) BigInt(this._long / that._long) + else BigInt(this.bigInteger.divide(that.bigInteger)) /** Remainder of BigInts */ - def % (that: BigInt): BigInt = new BigInt(this.bigInteger.remainder(that.bigInteger)) + def %(that: BigInt): BigInt = + // see / for the original logic regarding Long.MinValue + if (this.longEncoding && that.longEncoding) BigInt(this._long % that._long) + else BigInt(this.bigInteger.remainder(that.bigInteger)) /** Returns a pair of two BigInts containing (this / that) and (this % that). */ - def /% (that: BigInt): (BigInt, BigInt) = { - val dr = this.bigInteger.divideAndRemainder(that.bigInteger) - (new BigInt(dr(0)), new BigInt(dr(1))) - } + def /%(that: BigInt): (BigInt, BigInt) = + if (this.longEncoding && that.longEncoding) { + val x = this._long + val y = that._long + // original line: if (x != Long.MinValue || y != -1) return (BigInt(x / y), BigInt(x % y)) + (BigInt(x / y), BigInt(x % y)) + } else { + val dr = this.bigInteger.divideAndRemainder(that.bigInteger) + (BigInt(dr(0)), BigInt(dr(1))) + } /** Leftshift of BigInt */ - def << (n: Int): BigInt = new BigInt(this.bigInteger.shiftLeft(n)) + def <<(n: Int): BigInt = + if (longEncoding && n <= 0) (this >> (-n)) else BigInt(this.bigInteger.shiftLeft(n)) /** (Signed) rightshift of BigInt */ - def >> (n: Int): BigInt = new BigInt(this.bigInteger.shiftRight(n)) - + def >>(n: Int): BigInt = + if (longEncoding && n >= 0) { + if (n < 64) BigInt(_long >> n) + else if (_long < 0) BigInt(-1) + else BigInt(0) // for _long >= 0 + } else BigInt(this.bigInteger.shiftRight(n)) + /** Bitwise and of BigInts */ - def & (that: BigInt): BigInt = new BigInt(this.bigInteger.and(that.bigInteger)) + def &(that: BigInt): BigInt = + if (this.longEncoding && that.longEncoding) + BigInt(this._long & that._long) + else BigInt(this.bigInteger.and(that.bigInteger)) /** Bitwise or of BigInts */ - def | (that: BigInt): BigInt = new BigInt(this.bigInteger.or (that.bigInteger)) + def |(that: BigInt): BigInt = + if (this.longEncoding && that.longEncoding) + BigInt(this._long | that._long) + else BigInt(this.bigInteger.or(that.bigInteger)) /** Bitwise exclusive-or of BigInts */ - def ^ (that: BigInt): BigInt = new BigInt(this.bigInteger.xor(that.bigInteger)) + def ^(that: BigInt): BigInt = + if (this.longEncoding && that.longEncoding) + BigInt(this._long ^ that._long) + else BigInt(this.bigInteger.xor(that.bigInteger)) /** Bitwise and-not of BigInts. Returns a BigInt whose value is (this & ~that). */ - def &~ (that: BigInt): BigInt = new BigInt(this.bigInteger.andNot(that.bigInteger)) + def &~(that: BigInt): BigInt = + if (this.longEncoding && that.longEncoding) + BigInt(this._long & ~that._long) + else BigInt(this.bigInteger.andNot(that.bigInteger)) /** Returns the greatest common divisor of abs(this) and abs(that) */ - def gcd (that: BigInt): BigInt = new BigInt(this.bigInteger.gcd(that.bigInteger)) + def gcd(that: BigInt): BigInt = + if (this.longEncoding) { + if (this._long == 0) return that.abs + // if (this._long == Long.MinValue) return (-this) gcd that + // this != 0 && this != Long.MinValue + if (that.longEncoding) { + if (that._long == 0) return this.abs + // if (that._long == Long.MinValue) return this gcd (-that) + BigInt(BigInt.longGcd(this._long.abs, that._long.abs)) + } else that gcd this // force the BigInteger on the left + } else { + // this is not a valid long + if (that.longEncoding) { + if (that._long == 0) return this.abs + // if (that._long == Long.MinValue) return this gcd (-that) + val red = (this._bigInteger mod BigInteger.valueOf(that._long.abs)).longValue() + if (red == 0) return that.abs + BigInt(BigInt.longGcd(that._long.abs, red)) + } else BigInt(this.bigInteger.gcd(that.bigInteger)) + } + /** Returns a BigInt whose value is (this mod that). * This method differs from `%` in that it always returns a non-negative BigInt. * @param that A positive number */ - def mod (that: BigInt): BigInt = new BigInt(this.bigInteger.mod(that.bigInteger)) + def mod(that: BigInt): BigInt = + if (this.longEncoding && that.longEncoding) { + val res = this._long % that._long + if (res >= 0) BigInt(res) else BigInt(res + that._long) + } else BigInt(this.bigInteger.mod(that.bigInteger)) /** Returns the minimum of this and that */ - def min (that: BigInt): BigInt = new BigInt(this.bigInteger.min(that.bigInteger)) + def min(that: BigInt): BigInt = + if (this <= that) this else that /** Returns the maximum of this and that */ - def max (that: BigInt): BigInt = new BigInt(this.bigInteger.max(that.bigInteger)) + def max(that: BigInt): BigInt = + if (this >= that) this else that /** Returns a BigInt whose value is (this raised to the power of exp). */ - def pow (exp: Int): BigInt = new BigInt(this.bigInteger.pow(exp)) + def pow(exp: Int): BigInt = BigInt(this.bigInteger.pow(exp)) /** Returns a BigInt whose value is * (this raised to the power of exp modulo m). */ - def modPow (exp: BigInt, m: BigInt): BigInt = - new BigInt(this.bigInteger.modPow(exp.bigInteger, m.bigInteger)) + def modPow(exp: BigInt, m: BigInt): BigInt = BigInt(this.bigInteger.modPow(exp.bigInteger, m.bigInteger)) /** Returns a BigInt whose value is (the inverse of this modulo m). */ - def modInverse (m: BigInt): BigInt = new BigInt(this.bigInteger.modInverse(m.bigInteger)) + def modInverse(m: BigInt): BigInt = BigInt(this.bigInteger.modInverse(m.bigInteger)) /** Returns a BigInt whose value is the negation of this BigInt */ - def unary_- : BigInt = new BigInt(this.bigInteger.negate()) + def unary_- : BigInt = if (longEncoding) BigInt(-_long) else BigInt(this.bigInteger.negate()) /** Returns the absolute value of this BigInt */ - def abs: BigInt = new BigInt(this.bigInteger.abs()) + def abs: BigInt = if (signum < 0) -this else this /** Returns the sign of this BigInt; * -1 if it is less than 0, * +1 if it is greater than 0, * 0 if it is equal to 0. */ - def signum: Int = this.bigInteger.signum() + def signum: Int = if (longEncoding) java.lang.Long.signum(_long) else _bigInteger.signum() /** Returns the sign of this BigInt; * -1 if it is less than 0, * +1 if it is greater than 0, * 0 if it is equal to 0. */ - def sign: BigInt = signum + def sign: BigInt = BigInt(signum) /** Returns the bitwise complement of this BigInt */ - def unary_~ : BigInt = new BigInt(this.bigInteger.not()) + def unary_~ : BigInt = + // it is equal to -(this + 1) + if (longEncoding && _long != Long.MaxValue) BigInt(-(_long + 1)) else BigInt(this.bigInteger.not()) /** Returns true if and only if the designated bit is set. */ - def testBit (n: Int): Boolean = this.bigInteger.testBit(n) + def testBit(n: Int): Boolean = + if (longEncoding) { + if (n <= 63) + (_long & (1L << n)) != 0 + else + _long < 0 // give the sign bit + } else _bigInteger.testBit(n) /** Returns a BigInt whose value is equivalent to this BigInt with the designated bit set. */ - def setBit (n: Int): BigInt = new BigInt(this.bigInteger.setBit(n)) + def setBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63 + if (longEncoding && n <= 62) BigInt(_long | (1L << n)) else BigInt(this.bigInteger.setBit(n)) /** Returns a BigInt whose value is equivalent to this BigInt with the designated bit cleared. */ - def clearBit(n: Int): BigInt = new BigInt(this.bigInteger.clearBit(n)) + def clearBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63 + if (longEncoding && n <= 62) BigInt(_long & ~(1L << n)) else BigInt(this.bigInteger.clearBit(n)) /** Returns a BigInt whose value is equivalent to this BigInt with the designated bit flipped. */ - def flipBit (n: Int): BigInt = new BigInt(this.bigInteger.flipBit(n)) + def flipBit(n: Int): BigInt = // note that we do not operate on the Long sign bit #63 + if (longEncoding && n <= 62) BigInt(_long ^ (1L << n)) else BigInt(this.bigInteger.flipBit(n)) /** Returns the index of the rightmost (lowest-order) one bit in this BigInt * (the number of zero bits to the right of the rightmost one bit). */ - def lowestSetBit: Int = this.bigInteger.getLowestSetBit() + def lowestSetBit: Int = + if (longEncoding) { + if (_long == 0) -1 else java.lang.Long.numberOfTrailingZeros(_long) + } else this.bigInteger.getLowestSetBit() /** Returns the number of bits in the minimal two's-complement representation of this BigInt, * excluding a sign bit. */ - def bitLength: Int = this.bigInteger.bitLength() + def bitLength: Int = + // bitLength is defined as ceil(log2(this < 0 ? -this : this + 1))) + // where ceil(log2(x)) = 64 - numberOfLeadingZeros(x - 1) + if (longEncoding) { + if (_long < 0) 64 - java.lang.Long.numberOfLeadingZeros(-(_long + 1)) // takes care of Long.MinValue + else 64 - java.lang.Long.numberOfLeadingZeros(_long) + } else _bigInteger.bitLength() /** Returns the number of bits in the two's complement representation of this BigInt * that differ from its sign bit. */ - def bitCount: Int = this.bigInteger.bitCount() + def bitCount: Int = + if (longEncoding) { + if (_long < 0) java.lang.Long.bitCount(-(_long + 1)) else java.lang.Long.bitCount(_long) + } else this.bigInteger.bitCount() /** Returns true if this BigInt is probably prime, false if it's definitely composite. * @param certainty a measure of the uncertainty that the caller is willing to tolerate: @@ -360,7 +580,7 @@ final class BigInt(val bigInteger: BigInteger) * overall magnitude of the BigInt value as well as return a result with * the opposite sign. */ - def intValue: Int = this.bigInteger.intValue + def intValue: Int = if (longEncoding) _long.toInt else this.bigInteger.intValue /** Converts this BigInt to a long. * If the BigInt is too big to fit in a long, only the low-order 64 bits @@ -368,7 +588,7 @@ final class BigInt(val bigInteger: BigInteger) * overall magnitude of the BigInt value as well as return a result with * the opposite sign. */ - def longValue: Long = this.bigInteger.longValue + def longValue: Long = if (longEncoding) _long else _bigInteger.longValue /** Converts this `BigInt` to a `float`. * If this `BigInt` has too great a magnitude to represent as a float, @@ -382,7 +602,9 @@ final class BigInt(val bigInteger: BigInteger) * it will be converted to `Double.NEGATIVE_INFINITY` or * `Double.POSITIVE_INFINITY` as appropriate. */ - def doubleValue: Double = this.bigInteger.doubleValue + def doubleValue: Double = + if (isValidLong && (-(1L << 53) <= _long && _long <= (1L << 53))) _long.toDouble + else this.bigInteger.doubleValue /** Create a `NumericRange[BigInt]` in range `[start;end)` * with the specified step, where start is the target BigInt. @@ -399,7 +621,7 @@ final class BigInt(val bigInteger: BigInteger) /** Returns the decimal String representation of this BigInt. */ - override def toString(): String = this.bigInteger.toString() + override def toString(): String = if (longEncoding) _long.toString() else _bigInteger.toString() /** Returns the String representation in the specified radix of this BigInt. */ diff --git a/test/benchmarks/src/main/scala/scala/math/BigIntEulerProblem15Benchmark.scala b/test/benchmarks/src/main/scala/scala/math/BigIntEulerProblem15Benchmark.scala new file mode 100644 index 000000000000..690c078ec2f7 --- /dev/null +++ b/test/benchmarks/src/main/scala/scala/math/BigIntEulerProblem15Benchmark.scala @@ -0,0 +1,29 @@ +package scala.math + +import java.util.concurrent.TimeUnit + +import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.infra.Blackhole + +@BenchmarkMode(Array(Mode.AverageTime)) +@Fork(2) +@Threads(1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +class BigIntEulerProblem15Benchmark { + + @Param(Array("5", "10", "15", "20", "25", "30", "35", "40", "45", "50", "55", + "60", "65", "70", "75", "80", "85", "90", "95", "100")) + var size: Int = _ + + @Benchmark + def eulerProblem15(bh: Blackhole): Unit = { + def f(row: Array[BigInt], c: Int): BigInt = + if (c == 0) row.last else f(row.scan(BigInt(0))(_ + _), c - 1) + def computeAnswer(n: Int): BigInt = f(Array.fill(n + 1)(BigInt(1)), n) + bh.consume(computeAnswer(size)) + } + +} diff --git a/test/benchmarks/src/main/scala/scala/math/BigIntFactorialBenchmark.scala b/test/benchmarks/src/main/scala/scala/math/BigIntFactorialBenchmark.scala new file mode 100644 index 000000000000..0aaa18c029e1 --- /dev/null +++ b/test/benchmarks/src/main/scala/scala/math/BigIntFactorialBenchmark.scala @@ -0,0 +1,30 @@ +package scala.math + +import java.util.concurrent.TimeUnit + +import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.infra.Blackhole + +import scala.annotation.tailrec + +@BenchmarkMode(Array(Mode.AverageTime)) +@Fork(2) +@Threads(1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +class BigIntFactorialBenchmark { + + @Param(Array("5", "10", "15", "20", "25", "30", "35", "40", "45", "50", "55", + "60", "65", "70", "75", "80", "85", "90", "95", "100")) + var size: Int = _ + + @Benchmark + def factorial(bh: Blackhole): Unit = { + @tailrec def fact(i: Int, n: Int, prev: BigInt): BigInt = + if (i > n) prev else fact(i + 1, n, prev * i) + bh.consume(fact(1, size, BigInt(1))) + } + +} diff --git a/test/benchmarks/src/main/scala/scala/math/BigIntRSABenchmark.scala b/test/benchmarks/src/main/scala/scala/math/BigIntRSABenchmark.scala new file mode 100644 index 000000000000..4c93f324e0bd --- /dev/null +++ b/test/benchmarks/src/main/scala/scala/math/BigIntRSABenchmark.scala @@ -0,0 +1,32 @@ +package scala.math + +import java.util.concurrent.TimeUnit + +import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.infra._ + +@BenchmarkMode(Array(Mode.AverageTime)) +@Fork(2) +@Threads(1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +class BigIntRSABenchmark { + + @Benchmark + def encodeDecode(bh: Blackhole): Unit = { + // private key + val d = BigInt("5617843187844953170308463622230283376298685") + // public key + val n = BigInt("9516311845790656153499716760847001433441357") + val e = 65537 + + // concatenation of "Scala is great" + val plaintext = BigInt("83099097108097032105115032103114101097116") + val ciphertext = plaintext.modPow(e, n) + val recoveredtext = ciphertext.modPow(d, n) + bh.consume(plaintext == recoveredtext) + } + +} diff --git a/test/scalacheck/scala/math/BigIntProperties.scala b/test/scalacheck/scala/math/BigIntProperties.scala index c4c0295dc50a..d036719b368f 100644 --- a/test/scalacheck/scala/math/BigIntProperties.scala +++ b/test/scalacheck/scala/math/BigIntProperties.scala @@ -61,6 +61,7 @@ object BigIntProperties extends Properties("BigInt") { property("longValue") = forAll { (l: Long) => BigInt(l).longValue ?= l } property("toLong") = forAll { (l: Long) => BigInt(l).toLong ?= l } + property("new BigInt(bigInteger = BigInteger.ZERO)") = (new BigInt(bigInteger = BigInteger.ZERO)) == 0 property("BigInt.apply(i: Int)") = forAll { (i: Int) => BigInt(i) ?= BigInt(BigInteger.valueOf(i)) } property("BigInt.apply(l: Long)") = forAll { (l: Long) => BigInt(l) ?= BigInt(BigInteger.valueOf(l)) } property("BigInt.apply(x: Array[Byte])") = forAll(bigInteger) { bi => BigInt(bi) ?= BigInt(bi.toByteArray) }