From 6e149fa3bd88a2347e635f03ab9ae5913e03beee Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 28 Jul 2015 14:36:36 -0700 Subject: [PATCH] Fix splitPrefixSuffixPairs --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 30 ++++++++----------- .../spark/mllib/fpm/PrefixSpanSuite.scala | 21 ++++++------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index cbb514d467b4b..b70ff9815adc8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -58,7 +58,7 @@ class PrefixSpan private ( */ def setMinSupport(minSupport: Double): this.type = { require(minSupport >= 0 && minSupport <= 1, - "The minimum support value must be between 0 and 1, including 0 and 1.") + "The minimum support value must be in [0, 1].") this.minSupport = minSupport this } @@ -126,23 +126,17 @@ class PrefixSpan private ( private def splitPrefixSuffixPairs( prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]): (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = { - val suffixSizeMap = prefixSuffixPairs - .map(x => (x._1, x._2.length)) - .reduceByKey(_ + _) - .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set(x._1))) - .reduceByKey(_ ++ _) - .collect - .toMap - val small = if (suffixSizeMap.contains(true)) { - prefixSuffixPairs.filter(x => suffixSizeMap(true).contains(x._1)) - } else { - prefixSuffixPairs.filter(x => false) - } - val large = if (suffixSizeMap.contains(false)) { - prefixSuffixPairs.filter(x => suffixSizeMap(false).contains(x._1)) - } else { - prefixSuffixPairs.filter(x => false) - } + val prefixToSuffixSize = prefixSuffixPairs + .aggregateByKey(0)( + seqOp = { case (count, suffix) => count + suffix.length }, + combOp = { _ + _ }) + val smallPrefixes = prefixToSuffixSize + .filter(_._2 <= maxProjectedDBSizeBeforeLocalProcessing) + .map(_._1) + .collect() + .toSet + val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } + val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } (small, large) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 9f107c89f6d80..6dd2dc926acc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(sequences, 2).cache() - def compareResult( - expectedValue: Array[(Array[Int], Long)], - actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toSeq, x._2)).toSet == - actualValue.map(x => (x._1.toSeq, x._2)).toSet - } - val prefixspan = new PrefixSpan() .setMinSupport(0.33) .setMaxPatternLength(50) @@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue1, result1.collect())) + assert(compareResults(expectedValue1, result1.collect())) prefixspan.setMinSupport(0.5).setMaxPatternLength(50) val result2 = prefixspan.run(rdd) @@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4), 4L), (Array(5), 3L) ) - assert(compareResult(expectedValue2, result2.collect())) + assert(compareResults(expectedValue2, result2.collect())) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue3, result3.collect())) + assert(compareResults(expectedValue3, result3.collect())) + } + + private def compareResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } + }