diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 5e6322f2a05a1..5c563262e184d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -97,14 +97,31 @@ class PrefixSpan private ( if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - val minCount = getMinCount(sequences) - val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences) - val prefixSuffixPairs = getPrefixSuffixPairs( - lengthOnePatternsAndCounts.map(_._1).collect(), sequences) + + // Convert min support to a min number of transactions for this dataset + val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + + val itemCounts = sequences + .flatMap(_.distinct.map((_, 1L))) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + + val prefixSuffixPairs = { + val frequentItems = itemCounts.map(_._1).collect() + val candidates = sequences.map { p => + p.filter (frequentItems.contains(_) ) + } + candidates.flatMap { x => + frequentItems.map { y => + val sub = LocalPrefixSpan.getSuffix(y, x) + (ArrayBuffer(y), sub) + }.filter(_._2.nonEmpty) + } + } prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) - var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2)) - var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = - splitPrefixSuffixPairs(prefixSuffixPairs) + + var allPatternAndCounts = itemCounts.map(x => (ArrayBuffer(x._1), x._2)) + var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs) while (largePrefixSuffixPairs.count() != 0) { val (nextPatternAndCounts, nextPrefixSuffixPairs) = getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs) @@ -115,6 +132,7 @@ class PrefixSpan private ( smallPrefixSuffixPairs ++= smallerPairsPart allPatternAndCounts ++= nextPatternAndCounts } + if (smallPrefixSuffixPairs.count() > 0) { val projectedDatabase = smallPrefixSuffixPairs .map(x => (x._1.toSeq, x._2)) @@ -189,29 +207,6 @@ class PrefixSpan private ( (patternAndCounts, nextPrefixSuffixPairs) } - /** - * Get the minimum count (sequences count * minSupport). - * @param sequences input data set, contains a set of sequences, - * @return minimum count, - */ - private def getMinCount(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong - } - - /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences original sequences data - * @return array of item and count pair - */ - private def getFreqItemAndCounts( - minCount: Long, - sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { - sequences.flatMap(_.distinct.map((_, 1L))) - .reduceByKey(_ + _) - .filter(_._2 >= minCount) - } - /** * Get the frequent prefixes and suffix pairs. * @param frequentPrefixes frequent prefixes @@ -221,15 +216,6 @@ class PrefixSpan private ( private def getPrefixSuffixPairs( frequentPrefixes: Array[Int], sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = { - val filteredSequences = sequences.map { p => - p.filter (frequentPrefixes.contains(_) ) - } - filteredSequences.flatMap { x => - frequentPrefixes.map { y => - val sub = LocalPrefixSpan.getSuffix(y, x) - (ArrayBuffer(y), sub) - }.filter(_._2.nonEmpty) - } } /**