Skip to content

Commit

Permalink
Fixing corner case issue with zeros in the active values of the sparse vector. Updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeDittmar committed Jun 1, 2015
1 parent b1f059f commit 3ee8711
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
14 changes: 9 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,6 @@ class SparseVector(
-1
} else {

//grab first active index and value by default
var maxIdx = indices(0)
var maxValue = values(0)

Expand All @@ -736,9 +735,14 @@ class SparseVector(
}
}

// look for inactive values incase all active node values are negative
// look for inactive values in case all active node values are negative
if(size != values.size && maxValue <= 0){
maxIdx = calcInactiveIdx(0)
val firstInactiveIdx = calcFirstInactiveIdx(0)
if(maxValue == 0){
if(firstInactiveIdx >= maxIdx) maxIdx else maxIdx = firstInactiveIdx
}else{
maxIdx = firstInactiveIdx
}
maxValue = 0
}
maxIdx
Expand All @@ -751,12 +755,12 @@ class SparseVector(
* @param idx starting index of computation
* @return index of first inactive node
*/
private[SparseVector] def calcInactiveIdx(idx: Int): Int = {
private[SparseVector] def calcFirstInactiveIdx(idx: Int): Int = {
if (idx < size) {
if (!indices.contains(idx)) {
idx
} else {
calcInactiveIdx(idx + 1)
calcFirstInactiveIdx(idx + 1)
}
} else {
-1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,24 @@ class VectorsSuite extends FunSuite {
val max2 = vec3.argmax
assert(max2 === 2)

// check for case that sparse vector is created with only negative vaues {0.0, 0.0,-1.0, -0.7, 0.0}
val vec4 = Vectors.sparse(5,Array(2, 3),Array(-1.0,-.7))
// check for case that sparse vector is created with only negative values {0.0, 0.0,-1.0, -0.7, 0.0}
val vec4 = Vectors.sparse(5,Array(0, 1, 2, 3),Array(0.0, 0.0, -1.0,-.7))
val max3 = vec4.argmax
assert(max3 === 0)

// check for case that sparse vector is created with only negative values {-1.0, 0.0, -0.7, 0.0, 0.0}
val vec5 = Vectors.sparse(5,Array(0, 3),Array(-1.0,-.7))
val vec5 = Vectors.sparse(11,Array(0, 3, 10),Array(-1.0,-.7,0.0))
val max4 = vec5.argmax
assert(max4 === 1)

val vec6 = Vectors.sparse(2,Array(0, 1),Array(-1.0, 0.0))
val vec6 = Vectors.sparse(5,Array(0, 1, 3),Array(-1.0, 0.0, -.7))
val max5 = vec6.argmax
assert(max5 === 1)

// test that converting the sparse vector to another sparse vector then calling argmax still works right
var vec8 = Vectors.sparse(5,Array(0, 1),Array(0.0, -1.0))
vec8 = vec8.toSparse
val max7 = vec8.argmax
assert(max7 === 0)
}

test("vector equals") {
Expand Down

0 comments on commit 3ee8711

Please sign in to comment.