diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c46c6f886e7c9..f5cf732506181 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -725,7 +725,6 @@ class SparseVector( -1 } else { - //grab first active index and value by default var maxIdx = indices(0) var maxValue = values(0) @@ -736,9 +735,14 @@ class SparseVector( } } - // look for inactive values incase all active node values are negative + // look for inactive values in case all active node values are negative if(size != values.size && maxValue <= 0){ - maxIdx = calcInactiveIdx(0) + val firstInactiveIdx = calcFirstInactiveIdx(0) + if(maxValue == 0){ + if(firstInactiveIdx >= maxIdx) maxIdx else maxIdx = firstInactiveIdx + }else{ + maxIdx = firstInactiveIdx + } maxValue = 0 } maxIdx @@ -751,12 +755,12 @@ class SparseVector( * @param idx starting index of computation * @return index of first inactive node */ - private[SparseVector] def calcInactiveIdx(idx: Int): Int = { + private[SparseVector] def calcFirstInactiveIdx(idx: Int): Int = { if (idx < size) { if (!indices.contains(idx)) { idx } else { - calcInactiveIdx(idx + 1) + calcFirstInactiveIdx(idx + 1) } } else { -1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 118855baecae0..d4dd7f2e0d3d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -95,19 +95,24 @@ class VectorsSuite extends FunSuite { val max2 = vec3.argmax assert(max2 === 2) - // check for case that sparse vector is created with only negative vaues {0.0, 0.0,-1.0, -0.7, 0.0} - val vec4 = Vectors.sparse(5,Array(2, 3),Array(-1.0,-.7)) + // check for case that sparse vector is created with only negative values {0.0, 0.0,-1.0, -0.7, 0.0} + val vec4 = Vectors.sparse(5,Array(0, 1, 2, 3),Array(0.0, 0.0, -1.0,-.7)) val max3 = vec4.argmax assert(max3 === 0) - // check for case that sparse vector is created with only negative vaues {-1.0, 0.0, -0.7, 0.0, 0.0} - val vec5 = Vectors.sparse(5,Array(0, 3),Array(-1.0,-.7)) + val vec5 = Vectors.sparse(11,Array(0, 3, 10),Array(-1.0,-.7,0.0)) val max4 = vec5.argmax assert(max4 === 1) - val vec6 = Vectors.sparse(2,Array(0, 1),Array(-1.0, 0.0)) + val vec6 = Vectors.sparse(5,Array(0, 1, 3),Array(-1.0, 0.0, -.7)) val max5 = vec6.argmax assert(max5 === 1) + + // test that converting the sparse vector to another sparse vector then calling argmax still works right + var vec8 = Vectors.sparse(5,Array(0, 1),Array(0.0, -1.0)) + vec8 = vec8.toSparse + val max7 = vec8.argmax + assert(max7 === 0) } test("vector equals") {