From d2250b7871035c8096d377805ca9f9a9cf90fdd3 Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Sat, 18 Jul 2015 18:03:37 +0800 Subject: [PATCH] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 55 ++++++++++++++++--- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 139b2f6952fb8..e6fd05ac87b20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -45,7 +45,7 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { - private val minPatternsBeforeLocalProcessing: Int = 20 + private val maxSuffixesBeforeLocalProcessing: Long = 10000 /** * Constructs a default instance with default parameters @@ -91,20 +91,25 @@ class PrefixSpan private ( lengthOnePatternsAndCounts.map(_._1).collect(), sequences) var patternsCount: Long = lengthOnePatternsAndCounts.count() var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2)) - var currentPrefixSuffixPairs = prefixSuffixPairs + var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = + splitPrefixSuffixPairs(prefixSuffixPairs) + largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) var patternLength: Int = 1 while (patternLength < maxPatternLength && - patternsCount <= minPatternsBeforeLocalProcessing && - currentPrefixSuffixPairs.count() != 0) { + largePrefixSuffixPairs.count() != 0) { val (nextPatternAndCounts, nextPrefixSuffixPairs) = - getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs) + getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs) patternsCount = nextPatternAndCounts.count() - currentPrefixSuffixPairs = nextPrefixSuffixPairs + largePrefixSuffixPairs.unpersist() + val splitedPrefixSuffixPairs = splitPrefixSuffixPairs(nextPrefixSuffixPairs) + largePrefixSuffixPairs = splitedPrefixSuffixPairs._2 + largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) + smallPrefixSuffixPairs = smallPrefixSuffixPairs ++ splitedPrefixSuffixPairs._1 allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts patternLength = patternLength + 1 } - if (patternLength < maxPatternLength && patternsCount > 0) { - val projectedDatabase = currentPrefixSuffixPairs + if (smallPrefixSuffixPairs.count() > 0) { + val projectedDatabase = smallPrefixSuffixPairs .map(x => (x._1.toSeq, x._2)) .groupByKey() .map(x => (x._1.toArray, x._2.toArray)) @@ -114,6 +119,38 @@ class PrefixSpan private ( allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) } } + + /** + * Split prefix suffix pairs to two parts: + * suffixes' size less than maxSuffixesBeforeLocalProcessing and + * suffixes' size more than maxSuffixesBeforeLocalProcessing + * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return small size prefix suffix pairs and big size prefix suffix pairs + * (RDD[prefix, suffix], RDD[prefix, suffix ]) + */ + private def splitPrefixSuffixPairs( + prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]): + (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = { + val suffixSizeMap = prefixSuffixPairs + .map(x => (x._1, x._2.length)) + .reduceByKey(_ + _) + .map(x => (x._2 <= maxSuffixesBeforeLocalProcessing, Set(x._1))) + .reduceByKey(_ ++ _) + .collect + .toMap + val small = if (suffixSizeMap.contains(true)) { + prefixSuffixPairs.filter(x => suffixSizeMap(true).contains(x._1)) + } else { + prefixSuffixPairs.filter(x => false) + } + val large = if (suffixSizeMap.contains(false)) { + prefixSuffixPairs.filter(x => suffixSizeMap(false).contains(x._1)) + } else { + prefixSuffixPairs.filter(x => false) + } + (small, large) + } + /** * Get the pattern and counts, and prefix suffix pairs * @param minCount minimum count @@ -205,7 +242,7 @@ class PrefixSpan private ( data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = { data.flatMap { case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse.to[ArrayBuffer], count) }