diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c9a7dd0b7ecc7..7c26779536a91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -724,11 +724,12 @@ class SparseVector( if (size == 0) { -1 } else { - var maxIdx = indices(0) - var maxValue = values(0) + + var maxIdx = 0 + var maxValue = if(indices(0) != 0) 0 else values(0) foreachActive { (i, v) => - if(v != 0.0 && v > maxValue){ + if(v > maxValue){ maxIdx = i maxValue = v } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24d614b0eb973..d1c75e4edb70c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -91,10 +91,10 @@ class VectorsSuite extends FunSuite { val max = vec2.argmax assert(max === 3) - // check for case that sparse vector is created with a zero value in it by mistake - val vec3 = Vectors.sparse(5,Array(0, 2, 4),Array(-1.0, 0.0, -.7)) + // check for case that sparse vector is created with only negative vaues {0.0,0.0,-1.0,0.0,-0.7} + val vec3 = Vectors.sparse(5,Array(2, 4),Array(-1.0,-.7)) val max2 = vec3.argmax - assert(max2 === 4) + assert(max2 === 0) } test("vector equals") {