Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SI-6736 Range.contains is wrong #3437

Merged
merged 2 commits into from Feb 12, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
147 changes: 116 additions & 31 deletions src/library/scala/collection/immutable/Range.scala
Expand Up @@ -23,6 +23,15 @@ import scala.collection.parallel.immutable.ParRange
* println(r2.length) // = 5
* }}}
*
* Ranges that contain more than `Int.MaxValue` elements can be created, but
* these overfull ranges have only limited capabilities. Any method that
* could require a collection of over `Int.MaxValue` length to be created, or
* could be asked to index beyond `Int.MaxValue` elements will throw an
* exception. Overfull ranges can safely be reduced in size by changing
* the step size (e.g. `by 3`) or taking/dropping elements. `contains`,
* `equals`, and access to the ends of the range (`head`, `last`, `tail`,
* `init`) are also permitted on overfull ranges.
*
* @param start the start of this range.
* @param end the exclusive end of the range.
* @param step the step for the range.
Expand Down Expand Up @@ -77,10 +86,24 @@ extends scala.collection.AbstractSeq[Int]
}
}
@deprecated("This method will be made private, use `last` instead.", "2.11")
final val lastElement = start + (numRangeElements - 1) * step
final val lastElement =
if (isEmpty) start - step
else step match {
case 1 => if (isInclusive) end else end-1
case -1 => if (isInclusive) end else end+1
case _ =>
val remainder = (gap % step).toInt
if (remainder != 0) end - remainder
else if (isInclusive) end
else end - step
}

@deprecated("This method will be made private.", "2.11")
final val terminalElement = start + numRangeElements * step
final val terminalElement = lastElement + step

/** The last element of this range. This method will return the correct value
* even if there are too many elements to iterate over.
*/
override def last = if (isEmpty) Nil.last else lastElement
override def head = if (isEmpty) Nil.head else start

Expand Down Expand Up @@ -149,8 +172,12 @@ extends scala.collection.AbstractSeq[Int]
*/
final override def take(n: Int): Range = (
if (n <= 0 || isEmpty) newEmptyRange(start)
else if (n >= numRangeElements) this
else new Range.Inclusive(start, locationAfterN(n - 1), step)
else if (n >= numRangeElements && numRangeElements >= 0) this
else {
// May have more than Int.MaxValue elements in range (numRangeElements < 0)
// but the logic is the same either way: take the first n
new Range.Inclusive(start, locationAfterN(n - 1), step)
}
)

/** Creates a new range containing all the elements of this range except the first `n` elements.
Expand All @@ -162,8 +189,12 @@ extends scala.collection.AbstractSeq[Int]
*/
final override def drop(n: Int): Range = (
if (n <= 0 || isEmpty) this
else if (n >= numRangeElements) newEmptyRange(end)
else copy(locationAfterN(n), end, step)
else if (n >= numRangeElements && numRangeElements >= 0) newEmptyRange(end)
else {
// May have more than Int.MaxValue elements (numRangeElements < 0)
// but the logic is the same either way: go forwards n steps, keep the rest
copy(locationAfterN(n), end, step)
}
)

/** Creates a new range containing all the elements of this range except the last one.
Expand Down Expand Up @@ -192,23 +223,17 @@ extends scala.collection.AbstractSeq[Int]
drop(1)
}

// Counts how many elements from the start meet the given test.
private def skipCount(p: Int => Boolean): Int = {
var current = start
var counted = 0

while (counted < numRangeElements && p(current)) {
counted += 1
current += step
// Advance from the start while we meet the given test
private def argTakeWhile(p: Int => Boolean): Long = {
if (isEmpty) start
else {
var current = start
val stop = last
while (current != stop && p(current)) current += step
if (current != stop || !p(current)) current
else current.toLong + step
}
counted
}
// Tests whether a number is within the endpoints, without testing
// whether it is a member of the sequence (i.e. when step > 1.)
private def isWithinBoundaries(elem: Int) = !isEmpty && (
(step > 0 && start <= elem && elem <= last ) ||
(step < 0 && last <= elem && elem <= start)
)
// Methods like apply throw exceptions on invalid n, but methods like take/drop
// are forgiving: therefore the checks are with the methods.
private def locationAfterN(n: Int) = start + (step * n)
Expand All @@ -219,9 +244,33 @@ extends scala.collection.AbstractSeq[Int]
// based on the given value.
private def newEmptyRange(value: Int) = new Range(value, value, step)

final override def takeWhile(p: Int => Boolean): Range = take(skipCount(p))
final override def dropWhile(p: Int => Boolean): Range = drop(skipCount(p))
final override def span(p: Int => Boolean): (Range, Range) = splitAt(skipCount(p))
final override def takeWhile(p: Int => Boolean): Range = {
val stop = argTakeWhile(p)
if (stop==start) newEmptyRange(start)
else {
val x = (stop - step).toInt
if (x == last) this
else new Range.Inclusive(start, x, step)
}
}
final override def dropWhile(p: Int => Boolean): Range = {
val stop = argTakeWhile(p)
if (stop == start) this
else {
val x = (stop - step).toInt
if (x == last) newEmptyRange(last)
else new Range.Inclusive(x + step, last, step)
}
}
final override def span(p: Int => Boolean): (Range, Range) = {
val border = argTakeWhile(p)
if (border == start) (newEmptyRange(start), this)
else {
val x = (border - step).toInt
if (x == last) (this, newEmptyRange(last))
else (new Range.Inclusive(start, x, step), new Range.Inclusive(x+step, last, step))
}
}

/** Creates a pair of new ranges, first consisting of elements before `n`, and the second
* of elements after `n`.
Expand All @@ -234,13 +283,32 @@ extends scala.collection.AbstractSeq[Int]
*
* $doesNotUseBuilders
*/
final override def takeRight(n: Int): Range = drop(numRangeElements - n)
final override def takeRight(n: Int): Range = {
if (n <= 0) newEmptyRange(start)
else if (numRangeElements >= 0) drop(numRangeElements - n)
else {
// Need to handle over-full range separately
val y = last
val x = y - step.toLong*(n-1)
if ((step > 0 && x < start) || (step < 0 && x > start)) this
else new Range.Inclusive(x.toInt, y, step)
}
}

/** Creates a new range consisting of the initial `length - n` elements of the range.
*
* $doesNotUseBuilders
*/
final override def dropRight(n: Int): Range = take(numRangeElements - n)
final override def dropRight(n: Int): Range = {
if (n <= 0) this
else if (numRangeElements >= 0) take(numRangeElements - n)
else {
// Need to handle over-full range separately
val y = last - step.toInt*n
if ((step > 0 && y < start) || (step < 0 && y > start)) newEmptyRange(start)
else new Range.Inclusive(start, y.toInt, step)
}
}

/** Returns the reverse of this range.
*
Expand All @@ -256,7 +324,17 @@ extends scala.collection.AbstractSeq[Int]
if (isInclusive) this
else new Range.Inclusive(start, end, step)

final def contains(x: Int) = isWithinBoundaries(x) && ((x - start) % step == 0)
final def contains(x: Int) = {
if (x==end && !isInclusive) false
else if (step > 0) {
if (x < start || x > end) false
else (step == 1) || (((x - start) % step) == 0)
}
else {
if (x < end || x > start) false
else (step == -1) || (((x - start) % step) == 0)
}
}

final override def sum[B >: Int](implicit num: Numeric[B]): Int = {
if (num eq scala.math.Numeric.IntIsIntegral) {
Expand Down Expand Up @@ -285,9 +363,15 @@ extends scala.collection.AbstractSeq[Int]

override def equals(other: Any) = other match {
case x: Range =>
(x canEqual this) && (length == x.length) && (
isEmpty || // all empty sequences are equal
(start == x.start && last == x.last) // same length and same endpoints implies equality
// Note: this must succeed for overfull ranges (length > Int.MaxValue)
(x canEqual this) && (
isEmpty || // all empty sequences are equal
(start == x.start && { // Otherwise, must have same start
val l0 = last
(l0 == x.last && ( // And same end
start == l0 || step == x.step // And either the same step, or not take any steps
))
})
)
case _ =>
super.equals(other)
Expand All @@ -297,7 +381,8 @@ extends scala.collection.AbstractSeq[Int]
*/

override def toString() = {
val endStr = if (numRangeElements > Range.MAX_PRINT) ", ... )" else ")"
val endStr =
if (numRangeElements > Range.MAX_PRINT || (!isEmpty && numRangeElements < 0)) ", ... )" else ")"
take(Range.MAX_PRINT).mkString("Range(", ", ", endStr)
}
}
Expand Down
3 changes: 2 additions & 1 deletion test/files/scalacheck/range.scala
Expand Up @@ -265,7 +265,8 @@ object TooLargeRange extends Properties("Too Large Range") {
property("Too large range throws exception") = forAll(genTooLargeStart) { start =>
try {
val r = Range.inclusive(start, Int.MaxValue, 1)
println("how here? r = " + r.toString)
val l = r.length
println("how here? length = " + l + ", r = " + r.toString)
false
}
catch { case _: IllegalArgumentException => true }
Expand Down
19 changes: 18 additions & 1 deletion test/junit/scala/collection/NumericRangeTest.scala
Expand Up @@ -6,7 +6,7 @@ import org.junit.Test
import scala.math._
import scala.util._

/* Tests various maps by making sure they all agree on the same answers. */
/* Tests various ranges by making sure they all agree on the same answers. */
@RunWith(classOf[JUnit4])
class RangeConsistencyTest {
def r2nr[T: Integral](
Expand Down Expand Up @@ -120,4 +120,21 @@ class RangeConsistencyTest {
case _ => false
}
}}

@Test
def testSI6736() {
// These operations on overfull ranges should all succeed.
assert( (0 to Int.MaxValue).contains(4) )
assert( !((Int.MinValue to 0).contains(4)) )
assert( (Int.MinValue to 0).last == 0 )
assert( (Int.MinValue until 5).last == 4 )
assert( (-7 to -99 by -4).last == -99 && (-7 until -99 by -4).last == -95 )
assert( (Int.MinValue to 5) == (Int.MinValue until 6) )
assert( (-3 to Int.MaxValue).drop(4).length == Int.MaxValue )
assert( (-3 to Int.MaxValue).take(1234) == (-3 to 1230) )
assert( (-3 to Int.MaxValue).dropRight(4).length == Int.MaxValue )
assert( (-3 to Int.MaxValue).takeRight(1234).length == 1234 )
assert( (-3 to Int.MaxValue).dropWhile(_ <= 0).length == Int.MaxValue )
assert( (-3 to Int.MaxValue).span(_ <= 0) match { case (a,b) => a.length == 4 && b.length == Int.MaxValue } )
}
}