diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 436ab63ed4669..4e76b3b1db19d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -724,15 +724,22 @@ class SparseVector( if (size == 0) { -1 } else { - var maxIdx = 0 + var maxIdx = indices(0) var maxValue = values(0) - var i = 1 - foreachActive{ (i, v) => - if(v > maxValue) { + + foreachActive { (i, v) => + if(values(i) > maxValue){ maxIdx = i maxValue = v } } +// while(i < this.indices.size){ +// if(values(i) > maxValue){ +// maxIdx = indices(i) +// maxValue = values(i) +// } +// i += 1 +// } maxIdx } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 1bc87892e886d..b118856176a72 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -42,7 +42,6 @@ class VectorsSuite extends FunSuite { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.size === arr.length) assert(vec.values.eq(arr)) - vec.argmax } test("sparse vector construction") { @@ -57,8 +56,6 @@ class VectorsSuite extends FunSuite { assert(vec.size === n) assert(vec.indices === indices) assert(vec.values === values) - val vec2 = Vectors.sparse(5,Array(0,3),values).asInstanceOf[SparseVector] - vec2.foreachActive( (i, v) => println(i,v)) } test("dense to array") { @@ -66,11 +63,31 @@ class VectorsSuite extends FunSuite { assert(vec.toArray.eq(arr)) } + test("dense argmax"){ + val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector] + val noMax = vec.argmax + assert(noMax === -1) + + val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector] + val max = vec2.argmax + assert(max === 3) + } + test("sparse to array") { val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] assert(vec.toArray === arr) } + test("sparse argmax"){ + val vec = Vectors.sparse(0,Array.empty[Int],Array.empty[Double]).asInstanceOf[SparseVector] + val noMax = vec.argmax + assert(noMax === -1) + + val vec2 = Vectors.sparse(n,indices,values).asInstanceOf[SparseVector] + val max = vec2.argmax + assert(max === 3) + } + test("vector equals") { val dv1 = Vectors.dense(arr.clone()) val dv2 = Vectors.dense(arr.clone())