diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 9d8c60ef0fc45..82d864b44fa6e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -43,6 +43,9 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
+  // Switch to local mining once a level yields more patterns than this threshold.
+  private val minPatternsBeforeShuffle: Int = 20
+
   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -86,16 +89,71 @@ class PrefixSpan private (
       getFreqItemAndCounts(minCount, sequences).collect()
     val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
       lengthOnePatternsAndCounts.map(_._1), sequences)
-    val groupedProjectedDatabase = prefixAndProjectedDatabase
-      .map(x => (x._1.toSeq, x._2))
-      .groupByKey()
-      .map(x => (x._1.toArray, x._2.toArray))
-    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-    val lengthOnePatternsAndCountsRdd =
-      sequences.sparkContext.parallelize(
-        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-    allPatterns
+
+    var patternsCount = lengthOnePatternsAndCounts.length
+    var allPatternAndCounts = sequences.sparkContext.parallelize(
+      lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
+    var currentProjectedDatabase = prefixAndProjectedDatabase
+    // Grow patterns one item at a time on the cluster while each level stays small.
+    while (patternsCount <= minPatternsBeforeShuffle &&
+        currentProjectedDatabase.count() != 0) {
+      val (nextPatternAndCounts, nextProjectedDatabase) =
+        getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
+      patternsCount = nextPatternAndCounts.count().toInt
+      currentProjectedDatabase = nextProjectedDatabase
+      allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
+    }
+    // If any prefixes remain, shuffle their projected databases and mine each group locally.
+    if (patternsCount > 0) {
+      val groupedProjectedDatabase = currentProjectedDatabase
+        .map(x => (x._1.toSeq, x._2))
+        .groupByKey()
+        .map(x => (x._1.toArray, x._2.toArray))
+      val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
+      allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
+    }
+    allPatternAndCounts
+  }
+
+  /**
+   * Gets the patterns with counts, and the projected database for the next expansion.
+   * @param minCount minimum count
+   * @param prefixAndProjectedDatabase prefix and projected database pairs
+   * @return patterns with counts, and the next projected database
+   *         (RDD[(pattern, count)], RDD[(prefix, projected database)])
+   */
+  private def getPatternCountsAndProjectedDatabase(
+      minCount: Long,
+      prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
+    (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
+    val prefixAndFrequentItemAndCounts = prefixAndProjectedDatabase.flatMap { x =>
+      x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
+    }.reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+    val patternAndCounts = prefixAndFrequentItemAndCounts
+      .map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
+    val prefixLength = prefixAndProjectedDatabase.take(1)(0)._1.length
+    if (prefixLength + 1 >= maxPatternLength) {
+      (patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
+    } else {
+      val frequentItemsMap = prefixAndFrequentItemAndCounts
+        .keys.map(x => (x._1, x._2))
+        .groupByKey()
+        .mapValues(_.toSet)
+        .collect()
+        .toMap
+      val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
+        .filter(x => frequentItemsMap.contains(x._1))
+        .flatMap { x =>
+          val frequentItemSet = frequentItemsMap(x._1)
+          val filteredSequence = x._2.filter(frequentItemSet.contains(_))
+          val subProjectedDatabase = frequentItemSet.map { y =>
+            (y, LocalPrefixSpan.getSuffix(y, filteredSequence))
+          }.filter(_._2.nonEmpty)
+          subProjectedDatabase.map(y => (x._1 ++ Array(y._1), y._2))
+        }
+      (patternAndCounts, nextPrefixAndProjectedDatabase)
+    }
   }
 
   /**
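
For reviewers trying the patch, here is a minimal usage sketch. It assumes the public no-arg `PrefixSpan()` constructor and the `setMinSupport`/`setMaxPatternLength` setters already present in this file, plus a live `SparkContext` named `sc`; the item ids and thresholds are illustrative only.

```scala
import org.apache.spark.mllib.fpm.PrefixSpan

// Toy input: each sequence is an Array[Int] of item ids (hypothetical data).
val sequences = sc.parallelize(Seq(
  Array(1, 3, 2, 5),
  Array(1, 2, 3),
  Array(1, 2, 5),
  Array(6)), 2).cache()

// minSupport = 0.4 keeps patterns occurring in at least 40% of sequences;
// maxPatternLength bounds how deep the projected databases are expanded.
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.4)
  .setMaxPatternLength(5)

// run returns RDD[(pattern, count)], i.e. the allPatternAndCounts built above.
prefixSpan.run(sequences)
  .collect()
  .foreach { case (pattern, count) =>
    println(pattern.mkString("[", ",", "]") + ", " + count)
  }
```

The caller-facing API is unchanged; the staged expansion only kicks in internally, shuffling projected databases to `getPatternsInLocal` once a level yields more than `minPatternsBeforeShuffle` patterns.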