diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead6327486cc..0ea792081086d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }

-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index cbb514d467b4b..5b8da9665366b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -45,7 +45,11 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {

-  private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  private val maxLocalProjDBSize: Long = 10000

   /**
    * Constructs a default instance with default parameters
@@ -53,22 +57,32 @@ class PrefixSpan private (
    */
   def this() = this(0.1, 10)

+  /**
+   * Gets the minimal support (i.e. the minimal fraction of sequences in which a pattern must
+   * occur to be considered frequent).
+   */
+  def getMinSupport(): Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }

+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+   */
+  def getMaxPatternLength(): Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
   */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -85,162 +99,145 @@ class PrefixSpan private (
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
-    val prefixSuffixPairs = getPrefixSuffixPairs(
-      lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
-    prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-    var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2))
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) =
-      splitPrefixSuffixPairs(prefixSuffixPairs)
-    while (largePrefixSuffixPairs.count() != 0) {
-      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
-        getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
-      largePrefixSuffixPairs.unpersist()
-      val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
-      largePrefixSuffixPairs = largerPairsPart
-      largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-      smallPrefixSuffixPairs ++= smallerPairsPart
-      allPatternAndCounts ++= nextPatternAndCounts
+
+    // Convert min support to a min number of sequences for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // (Frequent item -> number of occurrences) pairs; all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.keys.collect().toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if suffix.nonEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
     }
-    if (smallPrefixSuffixPairs.count() > 0) {
-      val projectedDatabase = smallPrefixSuffixPairs
-        .map(x => (x._1.toSeq, x._2))
-        .groupByKey()
-        .map(x => (x._1.toArray, x._2.toArray))
-      val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
-      allPatternAndCounts ++= nextPatternAndCounts
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+    // frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be processed locally and distributively, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+    // projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts
     }
-    allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
+
+    // Process the small projected databases locally
+    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+
+    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
   }

   /**
-   * Split prefix suffix pairs to two parts:
-   * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
-   * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
+   * Partitions the prefix-suffix pairs by projected database size.
    * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return small size prefix suffix pairs and big size prefix suffix pairs
-   *         (RDD[prefix, suffix], RDD[prefix, suffix ])
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val suffixSizeMap = prefixSuffixPairs
-      .map(x => (x._1, x._2.length))
-      .reduceByKey(_ + _)
-      .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set(x._1)))
-      .reduceByKey(_ ++ _)
-      .collect
-      .toMap
-    val small = if (suffixSizeMap.contains(true)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(true).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
-    val large = if (suffixSizeMap.contains(false)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(false).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
     (small, large)
   }

   /**
-   * Get the pattern and counts, and prefix suffix pairs
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
+   * and remaining work.
    * @param minCount minimum count
-   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
-   *         (RDD[pattern, count], RDD[prefix, suffix ])
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+   *         prefix, corresponding suffix) pairs.
   */
-  private def getPatternCountsAndPrefixSuffixPairs(
+  private def extendPrefixes(
       minCount: Long,
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val prefixAndFrequentItemAndCounts = prefixSuffixPairs
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // Every (prefix :+ item) extension kept here has support of at least `minCount`
+    val prefixItemPairAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-    val patternAndCounts = prefixAndFrequentItemAndCounts
-      .map { case ((prefix, item), count) => (prefix :+ item, count) }
-    val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
+
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
       .keys
       .groupByKey()
       .mapValues(_.toSet)
       .collect()
       .toMap
-    val nextPrefixSuffixPairs = prefixSuffixPairs
-      .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
-      .flatMap { case (prefix, suffix) =>
-        val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
-        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
-        frequentNextItems.flatMap { item =>
-          val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
-          if (suffix.isEmpty) None
-          else Some(prefix :+ item, suffix)
-        }
-      }
-    (patternAndCounts, nextPrefixSuffixPairs)
-  }
-  /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences,
-   * @return minimum count,
-   */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
-  }
-  /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return array of item and count pair
-   */
-  private def getFreqItemAndCounts(
-      minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
-      .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
-  }
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts

-  /**
-   * Get the frequent prefixes and suffix pairs.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and suffix pairs.
-   */
-  private def getPrefixSuffixPairs(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter (frequentPrefixes.contains(_) )
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (ArrayBuffer(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if suffix.nonEmpty => Some((item :: prefix, suffix))
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }

   /**
-   * calculate the patterns in local.
+   * Calculates the patterns locally.
    * @param minCount the absolute minimum count
-   * @param data prefixes and projected sequences data data
+   * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
-      case (prefix, projDB) =>
-        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
-          .map { case (pattern: List[Int], count: Long) =>
-            (pattern.toArray.reverse.to[ArrayBuffer], count)
-          }
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6d80..6dd2dc926acc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

     val rdd = sc.parallelize(sequences, 2).cache()

-    def compareResult(
-      expectedValue: Array[(Array[Int], Long)],
-      actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))

     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue2, result2.collect()))
+    assert(compareResults(expectedValue2, result2.collect()))

     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue3, result3.collect()))
+    assert(compareResults(expectedValue3, result3.collect()))
+  }
+
+  private def compareResults(
+      expectedValue: Array[(Array[Int], Long)],
+      actualValue: Array[(Array[Int], Long)]): Boolean = {
+    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet
   }
+
 }
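
For reviewers, a minimal end-to-end sketch of the public API this patch touches, mirroring the suite's use of setMinSupport/setMaxPatternLength/run. The SparkContext setup, app name, and the sequences and thresholds below are illustrative assumptions, not taken from the patch:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan

// Local driver setup for illustration only.
val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanExample").setMaster("local[2]"))

// Each element is one sequence of integer item IDs; the data is made up.
val sequences = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1),
  Array(3, 1, 3, 4, 5)
), 2).cache()  // cache the input to avoid run()'s "Input data is not cached" warning

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)      // a pattern must occur in at least half of the sequences
  .setMaxPatternLength(5)  // ignore patterns longer than five items

// run() returns RDD[(Array[Int], Long)]: (frequent sequential pattern, occurrence count)
prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ", ", "]") + ": " + count)
}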
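One design choice worth flagging explicitly: prefixes are now accumulated in reverse order, because `item :: prefix` is a constant-time prepend on List[Int], whereas the old ArrayBuffer's `prefix :+ item` copies on every extension. The natural pattern order is recovered with a single `reverse` when results are emitted (and LocalPrefixSpan.run is likewise handed the reversed prefix). A tiny standalone illustration:

// The pattern <1, 3> is stored as List(3, 1); extending it with item 4
// prepends in O(1) and yields List(4, 3, 1), i.e. the pattern <1, 3, 4>.
val storedPrefix = List(3, 1)
val extended = 4 :: storedPrefix
assert(extended == List(4, 3, 1))
assert(extended.reverse == List(1, 3, 4))  // order used when emitting results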