Skip to content

Commit

Permalink
Fix splitPrefixSuffixPairs
Browse files Browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jul 28, 2015
1 parent 64271b3 commit 6e149fa
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 28 deletions.
30 changes: 12 additions & 18 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class PrefixSpan private (
*/
def setMinSupport(minSupport: Double): this.type = {
  // Support is a fraction of the total number of sequences, so it must lie in [0, 1].
  require(minSupport >= 0 && minSupport <= 1,
    "The minimum support value must be in [0, 1].")
  this.minSupport = minSupport
  this
}
Expand Down Expand Up @@ -126,23 +126,17 @@ class PrefixSpan private (
/**
 * Partitions prefix-suffix pairs by the total size of each prefix's projected database.
 *
 * @param prefixSuffixPairs RDD of (prefix, suffix) pairs
 * @return a pair of RDDs: (pairs whose prefix's projected database is small enough to be
 *         processed locally, pairs that must remain distributed)
 */
private def splitPrefixSuffixPairs(
    prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
  (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = {
  // Total projected database size per prefix: sum of suffix lengths for that prefix.
  // aggregateByKey avoids building intermediate (prefix, length) pairs with a map + reduceByKey.
  val prefixToSuffixSize = prefixSuffixPairs
    .aggregateByKey(0)(
      seqOp = { case (count, suffix) => count + suffix.length },
      combOp = { _ + _ })
  // Prefixes whose projected databases fit under the local-processing threshold.
  // Collected to the driver as a Set for O(1) membership tests in the filters below.
  val smallPrefixes = prefixToSuffixSize
    .filter(_._2 <= maxProjectedDBSizeBeforeLocalProcessing)
    .map(_._1)
    .collect()
    .toSet
  // Complementary filters: every pair lands in exactly one of the two result RDDs.
  val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
  val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
  (small, large)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

val rdd = sc.parallelize(sequences, 2).cache()

def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
actualValue.map(x => (x._1.toSeq, x._2)).toSet
}

val prefixspan = new PrefixSpan()
.setMinSupport(0.33)
.setMaxPatternLength(50)
Expand All @@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue1, result1.collect()))
assert(compareResults(expectedValue1, result1.collect()))

prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
Expand All @@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4), 4L),
(Array(5), 3L)
)
assert(compareResult(expectedValue2, result2.collect()))
assert(compareResults(expectedValue2, result2.collect()))

prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
Expand All @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
assert(compareResult(expectedValue3, result3.collect()))
assert(compareResults(expectedValue3, result3.collect()))
}

/**
 * Order-insensitive comparison of (pattern, count) result sets.
 *
 * Arrays do not have structural equality, so each pattern is converted to a Seq
 * before the two collections are compared as sets.
 */
private def compareResults(
    expectedValue: Array[(Array[Int], Long)],
    actualValue: Array[(Array[Int], Long)]): Boolean = {
  def normalize(results: Array[(Array[Int], Long)]): Set[(Seq[Int], Long)] =
    results.map { case (pattern, count) => (pattern.toSeq, count) }.toSet
  normalize(expectedValue) == normalize(actualValue)
}

}

0 comments on commit 6e149fa

Please sign in to comment.