From 95f21ca8a095767202e1c4d620a865c1647d7e6c Mon Sep 17 00:00:00 2001 From: Rex Kerr Date: Thu, 30 Jan 2014 12:28:14 -0800 Subject: [PATCH 1/2] SI-6736 Range.contains is wrong Removed once-used private method that was calculating ranges in error and corrected the contains method (plus improved performance). --- .../scala/collection/immutable/Range.scala | 18 +++++++++++------- .../scala/collection/NumericRangeTest.scala | 7 ++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala index 786b18cd21ce..ba695dfbdcb8 100644 --- a/src/library/scala/collection/immutable/Range.scala +++ b/src/library/scala/collection/immutable/Range.scala @@ -203,12 +203,6 @@ extends scala.collection.AbstractSeq[Int] } counted } - // Tests whether a number is within the endpoints, without testing - // whether it is a member of the sequence (i.e. when step > 1.) - private def isWithinBoundaries(elem: Int) = !isEmpty && ( - (step > 0 && start <= elem && elem <= last ) || - (step < 0 && last <= elem && elem <= start) - ) // Methods like apply throw exceptions on invalid n, but methods like take/drop // are forgiving: therefore the checks are with the methods. private def locationAfterN(n: Int) = start + (step * n) @@ -256,7 +250,17 @@ extends scala.collection.AbstractSeq[Int] if (isInclusive) this else new Range.Inclusive(start, end, step) - final def contains(x: Int) = isWithinBoundaries(x) && ((x - start) % step == 0) + final def contains(x: Int) = { + if (x==end && !isInclusive) false + else if (step > 0) { + if (x < start || x > end) false + else (step == 1) || (((x - start) % step) == 0) + } + else { + if (x < end || x > start) false + else (step == -1) || (((x - start) % step) == 0) + } + } final override def sum[B >: Int](implicit num: Numeric[B]): Int = { if (num eq scala.math.Numeric.IntIsIntegral) { diff --git a/test/junit/scala/collection/NumericRangeTest.scala b/test/junit/scala/collection/NumericRangeTest.scala index 0260723b9d2d..f03bf1c49877 100644 --- a/test/junit/scala/collection/NumericRangeTest.scala +++ b/test/junit/scala/collection/NumericRangeTest.scala @@ -6,7 +6,7 @@ import org.junit.Test import scala.math._ import scala.util._ -/* Tests various maps by making sure they all agree on the same answers. */ +/* Tests various ranges by making sure they all agree on the same answers. */ @RunWith(classOf[JUnit4]) class RangeConsistencyTest { def r2nr[T: Integral]( @@ -120,4 +120,9 @@ class RangeConsistencyTest { case _ => false } }} + + @Test + def testSI6736() { assert{ + (0 to Int.MaxValue).contains(4) && !((Int.MinValue to 0).contains(4)) + } } } From e152297090c26051b7e9a6a1740c4670d23d9d5d Mon Sep 17 00:00:00 2001 From: Rex Kerr Date: Fri, 31 Jan 2014 17:18:29 -0800 Subject: [PATCH 2/2] Reasonable Range operations consistently work when overfull. Operations are reasonable when they don't require indexing or conversion into a collection. These include head, tail, init, last, drop, take, dropWhile, takeWhile, dropRight, takeRight, span. Tests added also to verify the new behavior. --- .../scala/collection/immutable/Range.scala | 129 ++++++++++++++---- test/files/scalacheck/range.scala | 3 +- .../scala/collection/NumericRangeTest.scala | 18 ++- 3 files changed, 122 insertions(+), 28 deletions(-) diff --git a/src/library/scala/collection/immutable/Range.scala b/src/library/scala/collection/immutable/Range.scala index ba695dfbdcb8..26ccd0980345 100644 --- a/src/library/scala/collection/immutable/Range.scala +++ b/src/library/scala/collection/immutable/Range.scala @@ -23,6 +23,15 @@ import scala.collection.parallel.immutable.ParRange * println(r2.length) // = 5 * }}} * + * Ranges that contain more than `Int.MaxValue` elements can be created, but + * these overfull ranges have only limited capabilities. Any method that + * could require a collection of over `Int.MaxValue` length to be created, or + * could be asked to index beyond `Int.MaxValue` elements will throw an + * exception. Overfull ranges can safely be reduced in size by changing + * the step size (e.g. `by 3`) or taking/dropping elements. `contains`, + * `equals`, and access to the ends of the range (`head`, `last`, `tail`, + * `init`) are also permitted on overfull ranges. + * * @param start the start of this range. * @param end the exclusive end of the range. * @param step the step for the range. @@ -77,10 +86,24 @@ extends scala.collection.AbstractSeq[Int] } } @deprecated("This method will be made private, use `last` instead.", "2.11") - final val lastElement = start + (numRangeElements - 1) * step + final val lastElement = + if (isEmpty) start - step + else step match { + case 1 => if (isInclusive) end else end-1 + case -1 => if (isInclusive) end else end+1 + case _ => + val remainder = (gap % step).toInt + if (remainder != 0) end - remainder + else if (isInclusive) end + else end - step + } + @deprecated("This method will be made private.", "2.11") - final val terminalElement = start + numRangeElements * step + final val terminalElement = lastElement + step + /** The last element of this range. This method will return the correct value + * even if there are too many elements to iterate over. + */ override def last = if (isEmpty) Nil.last else lastElement override def head = if (isEmpty) Nil.head else start @@ -149,8 +172,12 @@ extends scala.collection.AbstractSeq[Int] */ final override def take(n: Int): Range = ( if (n <= 0 || isEmpty) newEmptyRange(start) - else if (n >= numRangeElements) this - else new Range.Inclusive(start, locationAfterN(n - 1), step) + else if (n >= numRangeElements && numRangeElements >= 0) this + else { + // May have more than Int.MaxValue elements in range (numRangeElements < 0) + // but the logic is the same either way: take the first n + new Range.Inclusive(start, locationAfterN(n - 1), step) + } ) /** Creates a new range containing all the elements of this range except the first `n` elements. @@ -162,8 +189,12 @@ extends scala.collection.AbstractSeq[Int] */ final override def drop(n: Int): Range = ( if (n <= 0 || isEmpty) this - else if (n >= numRangeElements) newEmptyRange(end) - else copy(locationAfterN(n), end, step) + else if (n >= numRangeElements && numRangeElements >= 0) newEmptyRange(end) + else { + // May have more than Int.MaxValue elements (numRangeElements < 0) + // but the logic is the same either way: go forwards n steps, keep the rest + copy(locationAfterN(n), end, step) + } ) /** Creates a new range containing all the elements of this range except the last one. @@ -192,16 +223,16 @@ extends scala.collection.AbstractSeq[Int] drop(1) } - // Counts how many elements from the start meet the given test. - private def skipCount(p: Int => Boolean): Int = { - var current = start - var counted = 0 - - while (counted < numRangeElements && p(current)) { - counted += 1 - current += step + // Advance from the start while we meet the given test + private def argTakeWhile(p: Int => Boolean): Long = { + if (isEmpty) start + else { + var current = start + val stop = last + while (current != stop && p(current)) current += step + if (current != stop || !p(current)) current + else current.toLong + step } - counted } // Methods like apply throw exceptions on invalid n, but methods like take/drop // are forgiving: therefore the checks are with the methods. @@ -213,9 +244,33 @@ extends scala.collection.AbstractSeq[Int] // based on the given value. private def newEmptyRange(value: Int) = new Range(value, value, step) - final override def takeWhile(p: Int => Boolean): Range = take(skipCount(p)) - final override def dropWhile(p: Int => Boolean): Range = drop(skipCount(p)) - final override def span(p: Int => Boolean): (Range, Range) = splitAt(skipCount(p)) + final override def takeWhile(p: Int => Boolean): Range = { + val stop = argTakeWhile(p) + if (stop==start) newEmptyRange(start) + else { + val x = (stop - step).toInt + if (x == last) this + else new Range.Inclusive(start, x, step) + } + } + final override def dropWhile(p: Int => Boolean): Range = { + val stop = argTakeWhile(p) + if (stop == start) this + else { + val x = (stop - step).toInt + if (x == last) newEmptyRange(last) + else new Range.Inclusive(x + step, last, step) + } + } + final override def span(p: Int => Boolean): (Range, Range) = { + val border = argTakeWhile(p) + if (border == start) (newEmptyRange(start), this) + else { + val x = (border - step).toInt + if (x == last) (this, newEmptyRange(last)) + else (new Range.Inclusive(start, x, step), new Range.Inclusive(x+step, last, step)) + } + } /** Creates a pair of new ranges, first consisting of elements before `n`, and the second * of elements after `n`. @@ -228,13 +283,32 @@ extends scala.collection.AbstractSeq[Int] * * $doesNotUseBuilders */ - final override def takeRight(n: Int): Range = drop(numRangeElements - n) + final override def takeRight(n: Int): Range = { + if (n <= 0) newEmptyRange(start) + else if (numRangeElements >= 0) drop(numRangeElements - n) + else { + // Need to handle over-full range separately + val y = last + val x = y - step.toLong*(n-1) + if ((step > 0 && x < start) || (step < 0 && x > start)) this + else new Range.Inclusive(x.toInt, y, step) + } + } /** Creates a new range consisting of the initial `length - n` elements of the range. * * $doesNotUseBuilders */ - final override def dropRight(n: Int): Range = take(numRangeElements - n) + final override def dropRight(n: Int): Range = { + if (n <= 0) this + else if (numRangeElements >= 0) take(numRangeElements - n) + else { + // Need to handle over-full range separately + val y = last - step.toInt*n + if ((step > 0 && y < start) || (step < 0 && y > start)) newEmptyRange(start) + else new Range.Inclusive(start, y.toInt, step) + } + } /** Returns the reverse of this range. * @@ -289,9 +363,15 @@ extends scala.collection.AbstractSeq[Int] override def equals(other: Any) = other match { case x: Range => - (x canEqual this) && (length == x.length) && ( - isEmpty || // all empty sequences are equal - (start == x.start && last == x.last) // same length and same endpoints implies equality + // Note: this must succeed for overfull ranges (length > Int.MaxValue) + (x canEqual this) && ( + isEmpty || // all empty sequences are equal + (start == x.start && { // Otherwise, must have same start + val l0 = last + (l0 == x.last && ( // And same end + start == l0 || step == x.step // And either the same step, or not take any steps + )) + }) ) case _ => super.equals(other) @@ -301,7 +381,8 @@ extends scala.collection.AbstractSeq[Int] */ override def toString() = { - val endStr = if (numRangeElements > Range.MAX_PRINT) ", ... )" else ")" + val endStr = + if (numRangeElements > Range.MAX_PRINT || (!isEmpty && numRangeElements < 0)) ", ... )" else ")" take(Range.MAX_PRINT).mkString("Range(", ", ", endStr) } } diff --git a/test/files/scalacheck/range.scala b/test/files/scalacheck/range.scala index 1eb186f3039d..493083a51fe1 100644 --- a/test/files/scalacheck/range.scala +++ b/test/files/scalacheck/range.scala @@ -265,7 +265,8 @@ object TooLargeRange extends Properties("Too Large Range") { property("Too large range throws exception") = forAll(genTooLargeStart) { start => try { val r = Range.inclusive(start, Int.MaxValue, 1) - println("how here? r = " + r.toString) + val l = r.length + println("how here? length = " + l + ", r = " + r.toString) false } catch { case _: IllegalArgumentException => true } diff --git a/test/junit/scala/collection/NumericRangeTest.scala b/test/junit/scala/collection/NumericRangeTest.scala index f03bf1c49877..3980c31577b5 100644 --- a/test/junit/scala/collection/NumericRangeTest.scala +++ b/test/junit/scala/collection/NumericRangeTest.scala @@ -122,7 +122,19 @@ class RangeConsistencyTest { }} @Test - def testSI6736() { assert{ - (0 to Int.MaxValue).contains(4) && !((Int.MinValue to 0).contains(4)) - } } + def testSI6736() { + // These operations on overfull ranges should all succeed. + assert( (0 to Int.MaxValue).contains(4) ) + assert( !((Int.MinValue to 0).contains(4)) ) + assert( (Int.MinValue to 0).last == 0 ) + assert( (Int.MinValue until 5).last == 4 ) + assert( (-7 to -99 by -4).last == -99 && (-7 until -99 by -4).last == -95 ) + assert( (Int.MinValue to 5) == (Int.MinValue until 6) ) + assert( (-3 to Int.MaxValue).drop(4).length == Int.MaxValue ) + assert( (-3 to Int.MaxValue).take(1234) == (-3 to 1230) ) + assert( (-3 to Int.MaxValue).dropRight(4).length == Int.MaxValue ) + assert( (-3 to Int.MaxValue).takeRight(1234).length == 1234 ) + assert( (-3 to Int.MaxValue).dropWhile(_ <= 0).length == Int.MaxValue ) + assert( (-3 to Int.MaxValue).span(_ <= 0) match { case (a,b) => a.length == 4 && b.length == Int.MaxValue } ) + } }