Skip to content

Commit e8534ef

Browse files
committed
Use unsigned arithmetics in Range, instead of Longs.
Previously, `Range` used a number of intermediate operations on `Long`s to avoid overflow. We can streamline a lot of code by using unsigned `Int` arithmetic instead. In particular, there is now only 1 division in the initialization path, instead of 3. Although the fields have not changed, the content of `numRangeElements` is stricter for overfull ranges (ranges with more than `Int.MaxValue` elements). This means that deserializing an overfull range from a previous version would not be safe, which is why we bump the SerialVersionUID. This commit upstreams scala-js/scala-js@d972218 from Scala.js.
1 parent c583aa3 commit e8534ef

File tree

1 file changed

+136
-80
lines changed

1 file changed

+136
-80
lines changed

library/src/scala/collection/immutable/Range.scala

Lines changed: 136 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ import scala.util.hashing.MurmurHash3
5757
* '''Note:''' this method does not use builders to construct a new range,
5858
* and its complexity is O(1).
5959
*/
60-
@SerialVersionUID(3L)
60+
@SerialVersionUID(4L)
6161
sealed abstract class Range(
6262
val start: Int,
6363
val end: Int,
@@ -83,11 +83,6 @@ sealed abstract class Range(
8383
r.asInstanceOf[S with EfficientSplit]
8484
}
8585

86-
private[this] def gap = end.toLong - start.toLong
87-
private[this] def isExact = gap % step == 0
88-
private[this] def hasStub = isInclusive || !isExact
89-
private[this] def longLength = gap / step + ( if (hasStub) 1 else 0 )
90-
9186
def isInclusive: Boolean
9287

9388
final override val isEmpty: Boolean = (
@@ -97,27 +92,90 @@ sealed abstract class Range(
9792
(if (step >= 0) start >= end else start <= end)
9893
)
9994

95+
if (step == 0) throw new IllegalArgumentException("step cannot be 0.")
96+
97+
/** Number of elements in this range, if it is non-empty.
98+
*
99+
* If the range is empty, `numRangeElements` does not have a meaningful value.
100+
*
101+
* Otherwise, `numRangeElements` is interpreted in the range [1, 2^32],
102+
* respecting modular arithmetics wrt. the unsigned interpretation.
103+
* In other words, it is 0 if the mathematical value should be 2^32, and the
104+
* standard unsigned int encoding of the mathematical value otherwise.
105+
*
106+
* This interpretation allows us to represent all values with the correct
107+
* modular arithmetics, which streamlines the usage sites.
108+
*/
100109
private[this] val numRangeElements: Int = {
101-
if (step == 0) throw new IllegalArgumentException("step cannot be 0.")
102-
else if (isEmpty) 0
103-
else {
104-
val len = longLength
105-
if (len > scala.Int.MaxValue) -1
106-
else len.toInt
107-
}
110+
val stepSign = step >> 31 // if (step >= 0) 0 else -1
111+
val gap = ((end - start) ^ stepSign) - stepSign // if (step >= 0) (end - start) else -(end - start)
112+
val absStep = (step ^ stepSign) - stepSign // if (step >= 0) step else -step
113+
114+
/* If `absStep` is a constant 1, `div` collapses to being an alias of
115+
* `gap`. Then `absStep * div` also collapses to `gap` and therefore
116+
* `absStep * div != gap` constant-folds to `false`.
117+
*
118+
* Since most ranges are exclusive, that makes `numRangeElements` an alias
119+
* of `gap`. Moreover, for exclusive ranges with step 1 and start 0 (which
120+
* are the common case), it makes it an alias of `end` and the entire
121+
* computation goes away.
122+
*/
123+
val div = Integer.divideUnsigned(gap, absStep)
124+
if (isInclusive || (absStep * div != gap)) div + 1 else div
108125
}
109126

110-
final def length = if (numRangeElements < 0) fail() else numRangeElements
127+
final def length: Int =
128+
if (isEmpty) 0
129+
else if (numRangeElements > 0) numRangeElements
130+
else fail()
131+
132+
/** Computes the element of this range after `n` steps from `start`.
133+
*
134+
* `n` is interpreted as an unsigned integer.
135+
*
136+
* If the mathematical result is not within this Range, the result won't
137+
* make sense, but won't error out.
138+
*/
139+
@inline
140+
private[this] def locationAfterN(n: Int): Int = {
141+
/* If `step >= 0`, we interpret `step * n` as an unsigned multiplication,
142+
* and the addition as a mixed `(signed, unsigned) -> signed` operation.
143+
* With those interpretations, they do not overflow, assuming the
144+
* mathematical result is within this Range.
145+
*
146+
* If `step < 0`, we should compute `start - (-step * n)`, with the
147+
* multiplication also interpreted as unsigned, and the subtraction as
148+
* mixed. Again, using those interpretations, they do not overflow.
149+
* But then modular arithmetics allow us to cancel out the two `-` signs,
150+
* so we end up with the same formula.
151+
*/
152+
start + (step * n)
153+
}
111154

112-
// This field has a sensible value only for non-empty ranges
113-
private[this] val lastElement = step match {
114-
case 1 => if (isInclusive) end else end-1
115-
case -1 => if (isInclusive) end else end+1
116-
case _ =>
117-
val remainder = (gap % step).toInt
118-
if (remainder != 0) end - remainder
119-
else if (isInclusive) end
120-
else end - step
155+
/** Last element of this non-empty range.
156+
*
157+
* For empty ranges, this value is nonsensical.
158+
*/
159+
private[this] val lastElement: Int = {
160+
/* Since we can assume the range is non-empty, `(numRangeElements - 1)`
161+
* is a valid unsigned value in the full int range. The general formula is
162+
* therefore `locationAfterN(numRangeElements - 1)`.
163+
*
164+
* We special-case 1 and -1 so that, in the happy path where `step` is a
165+
* constant 1 or -1, and we only use `foreach`, `numRangeElements` is dead
166+
* code.
167+
*
168+
* When `step` is not constant, it is probably 1 or -1 anyway, so the
169+
* single branch should be predictably true.
170+
*
171+
* `step == 1 || step == -1`
172+
* equiv `(step + 1 == 2) || (step + 1 == 0)`
173+
* equiv `((step + 1) & ~2) == 0`
174+
*/
175+
if (((step + 1) & ~2) == 0)
176+
(if (isInclusive) end else end - step)
177+
else
178+
locationAfterN(numRangeElements - 1)
121179
}
122180

123181
/** The last element of this range. This method will return the correct value
@@ -171,18 +229,22 @@ sealed abstract class Range(
171229
// which means it will not fail fast for those cases where failing was
172230
// correct.
173231
private[this] def validateMaxLength(): Unit = {
174-
if (numRangeElements < 0)
232+
if (numRangeElements <= 0 && !isEmpty)
175233
fail()
176234
}
177235
private[this] def description = "%d %s %d by %s".format(start, if (isInclusive) "to" else "until", end, step)
178236
private[this] def fail() = throw new IllegalArgumentException(description + ": seqs cannot contain more than Int.MaxValue elements.")
179237

180238
@throws[IndexOutOfBoundsException]
181239
final def apply(idx: Int): Int = {
182-
validateMaxLength()
183-
if (idx < 0 || idx >= numRangeElements)
184-
throw CommonErrors.indexOutOfBounds(index = idx, max = numRangeElements - 1)
185-
else start + (step * idx)
240+
/* If length is not valid, numRangeElements <= 0, so the condition is always true.
241+
* We push validateMaxLength() inside the then branch, out of the happy path.
242+
*/
243+
if (idx < 0 || idx >= numRangeElements || isEmpty) {
244+
validateMaxLength()
245+
val max = if (isEmpty) -1 else numRangeElements - 1
246+
throw CommonErrors.indexOutOfBounds(index = idx, max = max)
247+
} else locationAfterN(idx)
186248
}
187249

188250
/*@`inline`*/ final override def foreach[@specialized(Unit) U](f: Int => U): Unit = {
@@ -230,19 +292,23 @@ sealed abstract class Range(
230292
case _ => super.sameElements(that)
231293
}
232294

295+
/** Is the non-negative value `n` greater than or equal to the number of elements
296+
* in this non-empty range?
297+
*
298+
* This method returns nonsensical results if `n < 0` or if `this.isEmpty`.
299+
*/
300+
@inline private[this] def greaterEqualNumRangeElements(n: Int): Boolean =
301+
(n ^ Int.MinValue) > ((numRangeElements - 1) ^ Int.MinValue) // unsigned comparison
302+
233303
/** Creates a new range containing the first `n` elements of this range.
234304
*
235305
* @param n the number of elements to take.
236306
* @return a new range consisting of `n` first elements.
237307
*/
238308
final override def take(n: Int): Range =
239309
if (n <= 0 || isEmpty) newEmptyRange(start)
240-
else if (n >= numRangeElements && numRangeElements >= 0) this
241-
else {
242-
// May have more than Int.MaxValue elements in range (numRangeElements < 0)
243-
// but the logic is the same either way: take the first n
244-
new Range.Inclusive(start, locationAfterN(n - 1), step)
245-
}
310+
else if (greaterEqualNumRangeElements(n)) this
311+
else new Range.Inclusive(start, locationAfterN(n - 1), step)
246312

247313
/** Creates a new range containing all the elements of this range except the first `n` elements.
248314
*
@@ -251,42 +317,27 @@ sealed abstract class Range(
251317
*/
252318
final override def drop(n: Int): Range =
253319
if (n <= 0 || isEmpty) this
254-
else if (n >= numRangeElements && numRangeElements >= 0) newEmptyRange(end)
255-
else {
256-
// May have more than Int.MaxValue elements (numRangeElements < 0)
257-
// but the logic is the same either way: go forwards n steps, keep the rest
258-
copy(locationAfterN(n), end, step)
259-
}
320+
else if (greaterEqualNumRangeElements(n)) newEmptyRange(end)
321+
else copy(locationAfterN(n), end, step)
260322

261323
/** Creates a new range consisting of the last `n` elements of the range.
262324
*
263325
* $doesNotUseBuilders
264326
*/
265327
final override def takeRight(n: Int): Range = {
266-
if (n <= 0) newEmptyRange(start)
267-
else if (numRangeElements >= 0) drop(numRangeElements - n)
268-
else {
269-
// Need to handle over-full range separately
270-
val y = last
271-
val x = y - step.toLong*(n-1)
272-
if ((step > 0 && x < start) || (step < 0 && x > start)) this
273-
else Range.inclusive(x.toInt, y, step)
274-
}
328+
if (n <= 0 || isEmpty) newEmptyRange(start)
329+
else if (greaterEqualNumRangeElements(n)) this
330+
else copy(locationAfterN(numRangeElements - n), end, step)
275331
}
276332

277333
/** Creates a new range consisting of the initial `length - n` elements of the range.
278334
*
279335
* $doesNotUseBuilders
280336
*/
281337
final override def dropRight(n: Int): Range = {
282-
if (n <= 0) this
283-
else if (numRangeElements >= 0) take(numRangeElements - n)
284-
else {
285-
// Need to handle over-full range separately
286-
val y = last - step.toInt*n
287-
if ((step > 0 && y < start) || (step < 0 && y > start)) newEmptyRange(start)
288-
else Range.inclusive(start, y.toInt, step)
289-
}
338+
if (n <= 0 || isEmpty) this
339+
else if (greaterEqualNumRangeElements(n)) newEmptyRange(end)
340+
else Range.inclusive(start, locationAfterN(numRangeElements - 1 - n), step)
290341
}
291342

292343
// Advance from the start while we meet the given test
@@ -340,8 +391,9 @@ sealed abstract class Range(
340391
* @return a new range consisting of a contiguous interval of values in the old range
341392
*/
342393
final override def slice(from: Int, until: Int): Range =
343-
if (from <= 0) take(until)
344-
else if (until >= numRangeElements && numRangeElements >= 0) drop(from)
394+
if (isEmpty) this
395+
else if (from <= 0) take(until)
396+
else if (greaterEqualNumRangeElements(until) && until >= 0) drop(from)
345397
else {
346398
val fromValue = locationAfterN(from)
347399
if (from >= until) newEmptyRange(fromValue)
@@ -351,10 +403,6 @@ sealed abstract class Range(
351403
// Overridden only to refine the return type
352404
final override def splitAt(n: Int): (Range, Range) = (take(n), drop(n))
353405

354-
// Methods like apply throw exceptions on invalid n, but methods like take/drop
355-
// are forgiving: therefore the checks are with the methods.
356-
private[this] def locationAfterN(n: Int) = start + (step * n)
357-
358406
// When one drops everything. Can't ever have unchecked operations
359407
// like "end + 1" or "end - 1" because ranges involving Int.{ MinValue, MaxValue }
360408
// will overflow. This creates an exclusive range where start == end
@@ -374,13 +422,13 @@ sealed abstract class Range(
374422
else new Range.Inclusive(start, end, step)
375423

376424
final def contains(x: Int): Boolean = {
377-
if (x == end && !isInclusive) false
425+
if (isEmpty) false
378426
else if (step > 0) {
379-
if (x < start || x > end) false
427+
if (x < start || x > lastElement) false
380428
else (step == 1) || (Integer.remainderUnsigned(x - start, step) == 0)
381429
}
382430
else {
383-
if (x < end || x > start) false
431+
if (x > start || x < lastElement) false
384432
else (step == -1) || (Integer.remainderUnsigned(start - x, -step) == 0)
385433
}
386434
}
@@ -483,7 +531,12 @@ sealed abstract class Range(
483531
final override def toString: String = {
484532
val preposition = if (isInclusive) "to" else "until"
485533
val stepped = if (step == 1) "" else s" by $step"
486-
val prefix = if (isEmpty) "empty " else if (!isExact) "inexact " else ""
534+
535+
def isInexact =
536+
if (isInclusive) lastElement != end
537+
else (lastElement + step) != end
538+
539+
val prefix = if (isEmpty) "empty " else if (isInexact) "inexact " else ""
487540
s"${prefix}Range $start $preposition $end$stepped"
488541
}
489542

@@ -543,16 +596,19 @@ object Range {
543596

544597
if (isEmpty) 0
545598
else {
546-
// Counts with Longs so we can recognize too-large ranges.
547-
val gap: Long = end.toLong - start.toLong
548-
val jumps: Long = gap / step
549-
// Whether the size of this range is one larger than the
550-
// number of full-sized jumps.
551-
val hasStub = isInclusive || (gap % step != 0)
552-
val result: Long = jumps + ( if (hasStub) 1 else 0 )
553-
554-
if (result > scala.Int.MaxValue) -1
555-
else result.toInt
599+
val stepSign = step >> 31 // if (step >= 0) 0 else -1
600+
val gap = ((end - start) ^ stepSign) - stepSign // if (step >= 0) (end - start) else -(end - start)
601+
val absStep = (step ^ stepSign) - stepSign // if (step >= 0) step else -step
602+
603+
val div = Integer.divideUnsigned(gap, absStep)
604+
if (isInclusive) {
605+
if (div == -1) // max unsigned int
606+
-1 // corner case: there are 2^32 elements, which would overflow to 0
607+
else
608+
div + 1
609+
} else {
610+
if (absStep * div != gap) div + 1 else div
611+
}
556612
}
557613
}
558614
def count(start: Int, end: Int, step: Int): Int =
@@ -576,12 +632,12 @@ object Range {
576632
*/
577633
def inclusive(start: Int, end: Int): Range.Inclusive = new Range.Inclusive(start, end, 1)
578634

579-
@SerialVersionUID(3L)
635+
@SerialVersionUID(4L)
580636
final class Inclusive(start: Int, end: Int, step: Int) extends Range(start, end, step) {
581637
def isInclusive: Boolean = true
582638
}
583639

584-
@SerialVersionUID(3L)
640+
@SerialVersionUID(4L)
585641
final class Exclusive(start: Int, end: Int, step: Int) extends Range(start, end, step) {
586642
def isInclusive: Boolean = false
587643
}
@@ -635,7 +691,7 @@ object Range {
635691
* @param lastElement The last element included in the Range
636692
* @param initiallyEmpty Whether the Range was initially empty or not
637693
*/
638-
@SerialVersionUID(3L)
694+
@SerialVersionUID(4L)
639695
private class RangeIterator(
640696
start: Int,
641697
step: Int,

0 commit comments

Comments
 (0)