Skip to content

Commit 994de8f

Browse files
committed
SI-4370 Range bug: Wrong result for Long.MinValue to Long.MaxValue by Int.MaxValue
Fixed by rewriting the entire logic for the count method. This is necessary because the old code was making all kinds of assumptions about what numbers were, but the interface is completely generic. Those assumptions still made have been explicitly specified. Note that you have to make some or you end up doing a binary search, which is not exactly fast. The existing routine is 10-20% slower than the old (broken) one in the worst cases. This seems close enough to me to not bother special-casing Long and BigInt, though I note that this could be done for improved performance. Note that ranges that end up in Int ranges defer to Range for count. We can't assume that one is the smallest increment, so both endpoints and the step need to be Int. A new JUnit test has been added to verify that the test works. It secretly contains an alternate BigInt implementation, but that is a lot slower (>5x) than Long.
1 parent 681308a commit 994de8f

File tree

2 files changed

+194
-20
lines changed

2 files changed

+194
-20
lines changed

src/library/scala/collection/immutable/NumericRange.scala

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -241,28 +241,79 @@ object NumericRange {
241241
else if (start == end) if (isInclusive) 1 else 0
242242
else if (upward != posStep) 0
243243
else {
244-
val diff = num.minus(end, start)
245-
val jumps = num.toLong(num.quot(diff, step))
246-
val remainder = num.rem(diff, step)
247-
val longCount = jumps + (
248-
if (!isInclusive && zero == remainder) 0 else 1
249-
)
250-
251-
/* The edge cases keep coming. Since e.g.
252-
* Long.MaxValue + 1 == Long.MinValue
253-
* we do some more improbable seeming checks lest
254-
* overflow turn up as an empty range.
244+
/* We have to be frightfully paranoid about running out of range.
245+
* We also can't assume that the numbers will fit in a Long.
246+
* We will assume that if a > 0, -a can be represented, and if
247+
* a < 0, -a+1 can be represented. We also assume that if we
248+
* can't fit in Int, we can represent 2*Int.MaxValue+3 (at least).
249+
* And we assume that numbers wrap rather than cap when they overflow.
255250
*/
256-
// The second condition contradicts an empty result.
257-
val isOverflow = longCount == 0 && num.lt(num.plus(start, step), end) == upward
258-
259-
if (longCount > scala.Int.MaxValue || longCount < 0L || isOverflow) {
260-
val word = if (isInclusive) "to" else "until"
261-
val descr = List(start, word, end, "by", step) mkString " "
262-
263-
throw new IllegalArgumentException(descr + ": seqs cannot contain more than Int.MaxValue elements.")
251+
// Check whether we can short-circuit by deferring to Int range.
252+
val startint = num.toInt(start)
253+
if (start == num.fromInt(startint)) {
254+
val endint = num.toInt(end)
255+
if (end == num.fromInt(endint)) {
256+
val stepint = num.toInt(step)
257+
if (step == num.fromInt(stepint)) {
258+
return {
259+
if (isInclusive) Range.inclusive(startint, endint, stepint).length
260+
else Range (startint, endint, stepint).length
261+
}
262+
}
263+
}
264+
}
265+
// If we reach this point, deferring to Int failed.
266+
// Numbers may be big.
267+
val one = num.one
268+
val limit = num.fromInt(Int.MaxValue)
269+
def check(t: T): T =
270+
if (num.gt(t, limit)) throw new IllegalArgumentException("More than Int.MaxValue elements.")
271+
else t
272+
// If the range crosses zero, it might overflow when subtracted
273+
val startside = num.signum(start)
274+
val endside = num.signum(end)
275+
num.toInt{
276+
if (startside*endside >= 0) {
277+
// We're sure we can subtract these numbers.
278+
// Note that we do not use .rem because of different conventions for Long and BigInt
279+
val diff = num.minus(end, start)
280+
val quotient = check(num.quot(diff, step))
281+
val remainder = num.minus(diff, num.times(quotient, step))
282+
if (!isInclusive && zero == remainder) quotient else check(num.plus(quotient, one))
283+
}
284+
else {
285+
// We might not even be able to subtract these numbers.
286+
// Jump in three pieces:
287+
// * start to -1 or 1, whichever is closer (waypointA)
288+
// * one step, which will take us at least to 0 (ends at waypointB)
289+
// * there to the end
290+
val negone = num.fromInt(-1)
291+
val startlim = if (posStep) negone else one
292+
val startdiff = num.minus(startlim, start)
293+
val startq = check(num.quot(startdiff, step))
294+
val waypointA = if (startq == zero) start else num.plus(start, num.times(startq, step))
295+
val waypointB = num.plus(waypointA, step)
296+
check {
297+
if (num.lt(waypointB, end) != upward) {
298+
// No last piece
299+
if (isInclusive && waypointB == end) num.plus(startq, num.fromInt(2))
300+
else num.plus(startq, one)
301+
}
302+
else {
303+
// There is a last piece
304+
val enddiff = num.minus(end,waypointB)
305+
val endq = check(num.quot(enddiff, step))
306+
val last = if (endq == zero) waypointB else num.plus(waypointB, num.times(endq, step))
307+
// Now we have to tally up all the pieces
308+
// 1 for the initial value
309+
// startq steps to waypointA
310+
// 1 step to waypointB
311+
// endq steps to the end (one less if !isInclusive and last==end)
312+
num.plus(startq, num.plus(endq, if (!isInclusive && last==end) one else num.fromInt(2)))
313+
}
314+
}
315+
}
264316
}
265-
longCount.toInt
266317
}
267318
}
268319

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package scala.collection.immutable
2+
3+
import org.junit.runner.RunWith
4+
import org.junit.runners.JUnit4
5+
import org.junit.Test
6+
import scala.math._
7+
import scala.util._
8+
9+
/* Tests various maps by making sure they all agree on the same answers. */
10+
@RunWith(classOf[JUnit4])
11+
class RangeConsistencyTest {
12+
def r2nr[T: Integral](
13+
r: Range, puff: T, stride: T, check: (T,T) => Boolean, bi: T => BigInt
14+
): List[(BigInt,Try[Int])] = {
15+
val num = implicitly[Integral[T]]
16+
import num._
17+
val one = num.one
18+
19+
if (!check(puff, fromInt(r.start))) return Nil
20+
val start = puff * fromInt(r.start)
21+
val sp1 = start + one
22+
val sn1 = start - one
23+
24+
if (!check(puff, fromInt(r.end))) return Nil
25+
val end = puff * fromInt(r.end)
26+
val ep1 = end + one
27+
val en1 = end - one
28+
29+
if (!check(stride, fromInt(r.step))) return Nil
30+
val step = stride * fromInt(r.step)
31+
32+
def NR(s: T, e: T, i: T) = {
33+
val delta = (bi(e) - bi(s)).abs - (if (r.isInclusive) 0 else 1)
34+
val n = if (r.length == 0) BigInt(0) else delta / bi(i).abs + 1
35+
if (r.isInclusive) {
36+
(n, Try(NumericRange.inclusive(s,e,i).length))
37+
}
38+
else {
39+
(n, Try(NumericRange(s,e,i).length))
40+
}
41+
}
42+
43+
List(NR(start, end, step)) :::
44+
(if (sn1 < start) List(NR(sn1, end, step)) else Nil) :::
45+
(if (start < sp1) List(NR(sp1, end, step)) else Nil) :::
46+
(if (en1 < end) List(NR(start, en1, step)) else Nil) :::
47+
(if (end < ep1) List(NR(start, ep1, step)) else Nil)
48+
}
49+
50+
// Motivated by SI-4370: Wrong result for Long.MinValue to Long.MaxValue by Int.MaxValue
51+
@Test
52+
def rangeChurnTest() {
53+
val rn = new Random(4370)
54+
for (i <- 0 to 10000) { control.Breaks.breakable {
55+
val start = rn.nextInt
56+
val end = rn.nextInt
57+
val step = rn.nextInt(4) match {
58+
case 0 => 1
59+
case 1 => -1
60+
case 2 => (rn.nextInt(11)+2)*(2*rn.nextInt(2)+1)
61+
case 3 => var x = rn.nextInt; while (x==0) x = rn.nextInt; x
62+
}
63+
val r = if (rn.nextBoolean) Range.inclusive(start, end, step) else Range(start, end, step)
64+
65+
try { r.length }
66+
catch { case iae: IllegalArgumentException => control.Breaks.break }
67+
68+
val lpuff = rn.nextInt(4) match {
69+
case 0 => 1L
70+
case 1 => rn.nextInt(11)+2L
71+
case 2 => 1L << rn.nextInt(60)
72+
case 3 => math.max(1L, math.abs(rn.nextLong))
73+
}
74+
val lstride = rn.nextInt(4) match {
75+
case 0 => lpuff
76+
case 1 => 1L
77+
case 2 => 1L << rn.nextInt(60)
78+
case 3 => math.max(1L, math.abs(rn.nextLong))
79+
}
80+
val lr = r2nr[Long](
81+
r, lpuff, lstride,
82+
(a,b) => { val x = BigInt(a)*BigInt(b); x.isValidLong },
83+
x => BigInt(x)
84+
)
85+
86+
lr.foreach{ case (n,t) => assert(
87+
t match {
88+
case Failure(_) => n > Int.MaxValue
89+
case Success(m) => n == m
90+
},
91+
(r.start, r.end, r.step, r.isInclusive, lpuff, lstride, n, t)
92+
)}
93+
94+
val bipuff = rn.nextInt(3) match {
95+
case 0 => BigInt(1)
96+
case 1 => BigInt(rn.nextLong) + Long.MaxValue + 2
97+
case 2 => BigInt("1" + "0"*(rn.nextInt(100)+1))
98+
}
99+
val bistride = rn.nextInt(3) match {
100+
case 0 => bipuff
101+
case 1 => BigInt(1)
102+
case 2 => BigInt("1" + "0"*(rn.nextInt(100)+1))
103+
}
104+
val bir = r2nr[BigInt](r, bipuff, bistride, (a,b) => true, identity)
105+
106+
bir.foreach{ case (n,t) => assert(
107+
t match {
108+
case Failure(_) => n > Int.MaxValue
109+
case Success(m) => n == m
110+
},
111+
(r.start, r.end, r.step, r.isInclusive, bipuff, bistride, n, t)
112+
)}
113+
}}
114+
}
115+
116+
@Test
117+
def testSI4370() { assert{
118+
Try((Long.MinValue to Long.MaxValue by Int.MaxValue).length) match {
119+
case Failure(iae: IllegalArgumentException) => true
120+
case _ => false
121+
}
122+
}}
123+
}

0 commit comments

Comments
 (0)