diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index e6fd05ac87b20..cbb514d467b4b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -45,7 +45,7 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { - private val maxSuffixesBeforeLocalProcessing: Long = 10000 + private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000 /** * Constructs a default instance with default parameters @@ -89,24 +89,19 @@ class PrefixSpan private ( val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences) val prefixSuffixPairs = getPrefixSuffixPairs( lengthOnePatternsAndCounts.map(_._1).collect(), sequences) - var patternsCount: Long = lengthOnePatternsAndCounts.count() + prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2)) var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs) - largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) - var patternLength: Int = 1 - while (patternLength < maxPatternLength && - largePrefixSuffixPairs.count() != 0) { + while (largePrefixSuffixPairs.count() != 0) { val (nextPatternAndCounts, nextPrefixSuffixPairs) = getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs) - patternsCount = nextPatternAndCounts.count() largePrefixSuffixPairs.unpersist() - val splitedPrefixSuffixPairs = splitPrefixSuffixPairs(nextPrefixSuffixPairs) - largePrefixSuffixPairs = splitedPrefixSuffixPairs._2 + val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs) + largePrefixSuffixPairs = largerPairsPart largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK) - smallPrefixSuffixPairs = 
smallPrefixSuffixPairs ++ splitedPrefixSuffixPairs._1 - allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts - patternLength = patternLength + 1 + smallPrefixSuffixPairs ++= smallerPairsPart + allPatternAndCounts ++= nextPatternAndCounts } if (smallPrefixSuffixPairs.count() > 0) { val projectedDatabase = smallPrefixSuffixPairs @@ -114,7 +109,7 @@ class PrefixSpan private ( .groupByKey() .map(x => (x._1.toArray, x._2.toArray)) val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase) - allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts + allPatternAndCounts ++= nextPatternAndCounts } allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) } } @@ -122,8 +117,8 @@ class PrefixSpan private ( /** * Split prefix suffix pairs to two parts: - * suffixes' size less than maxSuffixesBeforeLocalProcessing and - * suffixes' size more than maxSuffixesBeforeLocalProcessing + * Prefixes with projected databases smaller than maxProjectedDBSizeBeforeLocalProcessing and + * Prefixes with projected databases larger than maxProjectedDBSizeBeforeLocalProcessing * @param prefixSuffixPairs prefix (length n) and suffix pairs, * @return small size prefix suffix pairs and big size prefix suffix pairs * (RDD[prefix, suffix], RDD[prefix, suffix ]) @@ -134,7 +129,7 @@ class PrefixSpan private ( val suffixSizeMap = prefixSuffixPairs .map(x => (x._1, x._2.length)) .reduceByKey(_ + _) - .map(x => (x._2 <= maxSuffixesBeforeLocalProcessing, Set(x._1))) + .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set(x._1))) .reduceByKey(_ ++ _) .collect .toMap