diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index e056f2146c3f1..aed7e30033b8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -43,7 +45,7 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { - private val minPatternsBeforeShuffle: Int = 20 + private val minPatternsBeforeLocalProcessing: Int = 20 /** * Constructs a default instance with default parameters @@ -88,16 +90,20 @@ class PrefixSpan private ( val prefixSuffixPairs = getPrefixSuffixPairs( lengthOnePatternsAndCounts.map(_._1).collect(), sequences) var patternsCount: Long = lengthOnePatternsAndCounts.count() - var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) + var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2)) var currentPrefixSuffixPairs = prefixSuffixPairs - while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) { + var patternLength: Int = 1 + while (patternLength < maxPatternLength && + patternsCount <= minPatternsBeforeLocalProcessing && + currentPrefixSuffixPairs.count() != 0) { val (nextPatternAndCounts, nextPrefixSuffixPairs) = getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs) - patternsCount = nextPatternAndCounts.count().toInt + patternsCount = nextPatternAndCounts.count() currentPrefixSuffixPairs = nextPrefixSuffixPairs allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts + patternLength = patternLength + 1 } - if (patternsCount > 0) { + if (patternLength < maxPatternLength && patternsCount > 0) { val projectedDatabase = currentPrefixSuffixPairs .map(x => (x._1.toSeq, x._2)) .groupByKey() @@ -105,49 +111,44 @@ class PrefixSpan private ( val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase) allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts } - allPatternAndCounts + allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) } } /** * Get the pattern and counts, and prefix suffix pairs * @param minCount minimum count - * @param prefixSuffixPairs prefix and suffix pairs, - * @return pattern and counts, and prefix suffix pairs - * (Array[pattern, count], RDD[prefix, suffix ]) + * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs + * (RDD[pattern, count], RDD[prefix, suffix ]) */ private def getPatternCountsAndPrefixSuffixPairs( minCount: Long, - prefixSuffixPairs: RDD[(Array[Int], Array[Int])]): - (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = { - val prefixAndFreqentItemAndCounts = prefixSuffixPairs - .flatMap { case (prefix, suffix) => - suffix.distinct.map(y => ((prefix.toSeq, y), 1L)) - }.reduceByKey(_ + _) + prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]): + (RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = { + val prefixAndFrequentItemAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } + .reduceByKey(_ + _) .filter(_._2 >= minCount) - val patternAndCounts = prefixAndFreqentItemAndCounts - .map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) } - val prefixlength = prefixSuffixPairs.first()._1.length - if (prefixlength + 1 >= maxPatternLength) { - (patternAndCounts, prefixSuffixPairs.filter(x => false)) - } else { - val frequentItemsMap = prefixAndFreqentItemAndCounts - .keys - .groupByKey() - .mapValues(_.toSet) - .collect - .toMap - val nextPrefixSuffixPairs = prefixSuffixPairs - .filter(x => frequentItemsMap.contains(x._1)) - .flatMap { case (prefix, suffix) => - val frequentItemSet = frequentItemsMap(prefix) - val filteredSuffix = suffix.filter(frequentItemSet.contains(_)) - val nextSuffixes = frequentItemSet.map{ item => - (item, LocalPrefixSpan.getSuffix(item, filteredSuffix)) - }.filter(_._2.nonEmpty) - nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) } + val patternAndCounts = prefixAndFrequentItemAndCounts + .map { case ((prefix, item), count) => (prefix :+ item, count) } + val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts + .keys + .groupByKey() + .mapValues(_.toSet) + .collect() + .toMap + val nextPrefixSuffixPairs = prefixSuffixPairs + .filter(x => prefixToFrequentNextItemsMap.contains(x._1)) + .flatMap { case (prefix, suffix) => + val frequentNextItems = prefixToFrequentNextItemsMap(prefix) + val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) + frequentNextItems.flatMap { item => + val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix) + if (suffix.isEmpty) None + else Some(prefix :+ item, suffix) } - (patternAndCounts, nextPrefixSuffixPairs) } + (patternAndCounts, nextPrefixSuffixPairs) } /** @@ -181,14 +182,14 @@ class PrefixSpan private ( */ private def getPrefixSuffixPairs( frequentPrefixes: Array[Int], - sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { + sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = { val filteredSequences = sequences.map { p => p.filter (frequentPrefixes.contains(_) ) } filteredSequences.flatMap { x => frequentPrefixes.map { y => val sub = LocalPrefixSpan.getSuffix(y, x) - (Array(y), sub) + (ArrayBuffer(y), sub) }.filter(_._2.nonEmpty) } } @@ -201,9 +202,9 @@ class PrefixSpan private ( */ private def getPatternsInLocal( minCount: Long, - data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { x => - LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) - } + data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = { + data + .flatMap { x => LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) } + .map { case (pattern, count) => (pattern.to[ArrayBuffer], count) } } }