diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7c26779536a91..d5990413173b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -594,7 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector { new SparseVector(size, ii, vv) } - def argmax: Int = { + override def argmax: Int = { if (size == 0) { -1 } else { @@ -726,17 +726,42 @@ class SparseVector( } else { var maxIdx = 0 - var maxValue = if(indices(0) != 0) 0 else values(0) + var maxValue = if(indices(0) != 0) 0.0 else values(0) foreachActive { (i, v) => - if(v > maxValue){ + if (v > maxValue) { maxIdx = i maxValue = v } } + + // look for inactive values incase all active node values are negative + if(size != values.size && maxValue < 0){ + maxIdx = calcInactiveIdx(indices(0)) + maxValue = 0 + } maxIdx } } + + /** + * Calculates the first instance of an inactive node in a sparse vector and returns the Idx + * of the element. + * @param idx starting index of computation + * @return index of first inactive node or -1 if it cannot find one + */ + private[SparseVector] def calcInactiveIdx(idx: Int): Int ={ + if(idx < size){ + if(!indices.contains(idx)){ + idx + }else{ + calcInactiveIdx(idx+1) + } + }else{ + -1 + } + } + } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index d1c75e4edb70c..7a86c670b9c71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -91,10 +91,23 @@ class VectorsSuite extends FunSuite { val max = vec2.argmax assert(max === 3) - // check for case that sparse vector is created with only negative vaues {0.0,0.0,-1.0,0.0,-0.7} - val vec3 = Vectors.sparse(5,Array(2, 4),Array(-1.0,-.7)) + val vec3 = Vectors.sparse(5,Array(2, 4),Array(1.0,-.7)) val max2 = vec3.argmax - assert(max2 === 0) + assert(max2 === 2) + + // check for case that sparse vector is created with only negative vaues {0.0, 0.0,-1.0, -0.7, 0.0} + val vec4 = Vectors.sparse(5,Array(2, 3),Array(-1.0,-.7)) + val max3 = vec4.argmax + assert(max3 === 0) + + // check for case that sparse vector is created with only negative vaues {-1.0, 0.0, -0.7, 0.0, 0.0} + val vec5 = Vectors.sparse(5,Array(0, 3),Array(-1.0,-.7)) + val max4 = vec5.argmax + assert(max4 === 1) + + val vec6 = Vectors.sparse(5,Array(0, 1, 2),Array(-1.0, -.025, -.7)) + val max5 = vec6.argmax + assert(max5 === 3) } test("vector equals") {