Parallelize freqItemCounts
Feynman Liang committed Jul 30, 2015
1 parent ad23aa9 commit 4ddf479
Showing 1 changed file with 11 additions and 6 deletions.
mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala (17 changes: 11 additions & 6 deletions)
@@ -49,6 +49,7 @@ class PrefixSpan private (
    * The maximum number of items allowed in a projected database before local processing. If a
    * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
    */
+  // TODO: make configurable with a better default value, 10000 may be too small
   private val maxLocalProjDBSize: Long = 10000

   /**
@@ -61,7 +62,7 @@
    * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
    * frequent).
    */
-  def getMinSupport(): Double = this.minSupport
+  def getMinSupport: Double = this.minSupport

   /**
    * Sets the minimal support level (default: `0.1`).
@@ -75,7 +76,7 @@
   /**
    * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
    */
-  def getMaxPatternLength(): Double = this.maxPatternLength
+  def getMaxPatternLength: Double = this.maxPatternLength

   /**
    * Sets maximal pattern length (default: `10`).
@@ -96,6 +97,8 @@
    * the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
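For context, a minimal driver-side sketch of how this API is exercised. This is an illustration, not part of the patch: it assumes the no-arg PrefixSpan constructor and the chained setMinSupport/setMaxPatternLength setters documented elsewhere in this file, and the sequence data is made up.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Hypothetical driver program; dataset values are illustrative only.
val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanExample"))
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 2, 3, 1, 2),
  Array(2, 3, 1),
  Array(1, 2, 5)
)).cache()  // run() logs a warning when the input is not cached

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)       // default 0.1, see getMinSupport above
  .setMaxPatternLength(10)  // default 10, see getMaxPatternLength above
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(s"${pattern.mkString("[", ", ", "]")}: $count")
}
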
@@ -108,10 +111,11 @@
       .flatMap(seq => seq.distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
+      .collect()

     // Pairs of (length 1 prefix, suffix consisting of frequent items)
     val itemSuffixPairs = {
-      val freqItems = freqItemCounts.keys.collect().toSet
+      val freqItems = freqItemCounts.map(_._1).toSet
       sequences.flatMap { seq =>
         val filteredSeq = seq.filter(freqItems.contains(_))
         freqItems.flatMap { item =>
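The counting step above finds the frequent single items: each sequence contributes at most one count per distinct item, counts are summed with reduceByKey, items below minCount are dropped, and the small result is collected so freqItems can be used as a plain Set inside the closures that follow. A self-contained sketch of the same pattern, with made-up data and names that are not part of the patch:

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative stand-alone version of the counting step; values are hypothetical.
val sc = new SparkContext(new SparkConf().setAppName("FreqItemCountsSketch"))
val sequences = sc.parallelize(Seq(Array(1, 2, 3), Array(1, 3, 3), Array(2, 4)))
val minCount = 2L

val freqItemCounts: Array[(Int, Long)] = sequences
  .flatMap(seq => seq.distinct.map(item => (item, 1L)))  // one count per item per sequence
  .reduceByKey(_ + _)                                    // 1 -> 2, 2 -> 2, 3 -> 2, 4 -> 1
  .filter(_._2 >= minCount)                              // drops item 4
  .collect()                                             // small enough to hold on the driver

val freqItems = freqItemCounts.map(_._1).toSet           // Set(1, 2, 3)
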
@@ -141,13 +145,14 @@
       pairsForDistributed = largerPairsPart
       pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
       pairsForLocal ++= smallerPairsPart
-      resultsAccumulator ++= nextPatternAndCounts
+      resultsAccumulator ++= nextPatternAndCounts.collect()
     }

     // Process the small projected databases locally
-    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+    val remainingResults = getPatternsInLocal(minCount, pairsForLocal.groupByKey())

-    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
   }


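The final hunk changes how results are assembled: the per-iteration pattern RDDs are collected into a driver-side buffer, while the patterns mined from the small projected databases (remainingResults) come back as an RDD; the buffer is then parallelized into a one-partition RDD and unioned with that RDD before the pattern lists are converted to arrays. A sketch of that assembly pattern in isolation, with hypothetical names and simplified types:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// Hypothetical helper mirroring the return expression above; not part of the patch.
def assembleResults(
    sc: SparkContext,
    driverSidePatterns: ArrayBuffer[(List[Int], Long)],
    distributedPatterns: RDD[(List[Int], Long)]): RDD[(Array[Int], Long)] = {
  // Ship the driver-side buffer back to the cluster as a one-partition RDD,
  // union it with the already-distributed results, and convert each pattern to an Array.
  (sc.parallelize(driverSidePatterns, 1) ++ distributedPatterns)
    .map { case (pattern, count) => (pattern.toArray, count) }
}
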
