Skip to content

Commit

Permalink
Modified the code according to the review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiajin committed Jul 16, 2015
1 parent baa2885 commit 095aa3a
Showing 1 changed file with 44 additions and 43 deletions.
87 changes: 44 additions & 43 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
Expand All @@ -43,7 +45,7 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {

private val minPatternsBeforeShuffle: Int = 20
private val minPatternsBeforeLocalProcessing: Int = 20

/**
* Constructs a default instance with default parameters
Expand Down Expand Up @@ -88,66 +90,65 @@ class PrefixSpan private (
val prefixSuffixPairs = getPrefixSuffixPairs(
lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
var patternsCount: Long = lengthOnePatternsAndCounts.count()
var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))
var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2))
var currentPrefixSuffixPairs = prefixSuffixPairs
while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0) {
var patternLength: Int = 1
while (patternLength < maxPatternLength &&
patternsCount <= minPatternsBeforeLocalProcessing &&
currentPrefixSuffixPairs.count() != 0) {
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
patternsCount = nextPatternAndCounts.count().toInt
patternsCount = nextPatternAndCounts.count()
currentPrefixSuffixPairs = nextPrefixSuffixPairs
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
patternLength = patternLength + 1
}
if (patternsCount > 0) {
if (patternLength < maxPatternLength && patternsCount > 0) {
val projectedDatabase = currentPrefixSuffixPairs
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
}
allPatternAndCounts
allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
}

/**
* Get the pattern and counts, and prefix suffix pairs
* @param minCount minimum count
* @param prefixSuffixPairs prefix and suffix pairs,
* @return pattern and counts, and prefix suffix pairs
* (Array[pattern, count], RDD[prefix, suffix ])
* @param prefixSuffixPairs prefix (length n) and suffix pairs
* @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
* (RDD[(pattern, count)], RDD[(prefix, suffix)])
*/
// NOTE(review): this listing is a rendered diff without +/- markers. Where two versions of
// a line or statement appear back to back, the first looks like the removed (Array-based)
// version and the second like the added (ArrayBuffer-based) version -- confirm against the
// repository before relying on either.
private def getPatternCountsAndPrefixSuffixPairs(
minCount: Long,
prefixSuffixPairs: RDD[(Array[Int], Array[Int])]):
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
val prefixAndFreqentItemAndCounts = prefixSuffixPairs
.flatMap { case (prefix, suffix) =>
suffix.distinct.map(y => ((prefix.toSeq, y), 1L))
}.reduceByKey(_ + _)
prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
(RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
// Emit ((prefix, item), 1L) once per distinct item of every suffix, sum the ones per key,
// and keep only the (prefix, item) pairs whose total is at least minCount.
val prefixAndFrequentItemAndCounts = prefixSuffixPairs
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
val patternAndCounts = prefixAndFreqentItemAndCounts
.map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
// Earlier variant: reads the prefix length from the first pair and, once that length + 1
// reaches maxPatternLength, returns an always-empty pair RDD (filter(x => false)).
val prefixlength = prefixSuffixPairs.first()._1.length
if (prefixlength + 1 >= maxPatternLength) {
(patternAndCounts, prefixSuffixPairs.filter(x => false))
} else {
val frequentItemsMap = prefixAndFreqentItemAndCounts
.keys
.groupByKey()
.mapValues(_.toSet)
.collect
.toMap
val nextPrefixSuffixPairs = prefixSuffixPairs
.filter(x => frequentItemsMap.contains(x._1))
.flatMap { case (prefix, suffix) =>
val frequentItemSet = frequentItemsMap(prefix)
val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
val nextSuffixes = frequentItemSet.map{ item =>
(item, LocalPrefixSpan.getSuffix(item, filteredSuffix))
}.filter(_._2.nonEmpty)
nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
// Every frequent (prefix, item) pair becomes the extended pattern prefix :+ item,
// carrying the count computed above.
val patternAndCounts = prefixAndFrequentItemAndCounts
.map { case ((prefix, item), count) => (prefix :+ item, count) }
// Group the frequent next items by prefix and collect() the result into a local Map
// from prefix to the Set of its frequent next items.
val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
.keys
.groupByKey()
.mapValues(_.toSet)
.collect()
.toMap
// Keep only pairs whose prefix has at least one frequent next item, then project each
// suffix on every such item to build the next (longer prefix, suffix) pairs.
val nextPrefixSuffixPairs = prefixSuffixPairs
.filter(x => prefixToFrequentNextItemsMap.contains(x._1))
.flatMap { case (prefix, suffix) =>
val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
// Drop items that are not frequent next items for this prefix before projecting.
val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
frequentNextItems.flatMap { item =>
// This inner `suffix` shadows the pattern-bound one; it is whatever
// LocalPrefixSpan.getSuffix returns for (item, filteredSuffix). Pairs whose
// result is empty are dropped.
val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
if (suffix.isEmpty) None
else Some(prefix :+ item, suffix)
}
(patternAndCounts, nextPrefixSuffixPairs)
}
// Result: (extended patterns with their counts, extended prefix/suffix pairs).
(patternAndCounts, nextPrefixSuffixPairs)
}

/**
Expand Down Expand Up @@ -181,14 +182,14 @@ class PrefixSpan private (
*/
// NOTE(review): rendered diff without +/- markers. The two `sequences:` signature lines and
// the two tuple lines (Array(y) vs ArrayBuffer(y)) look like the removed and added versions
// of the same line, in that order -- confirm against the repository.
private def getPrefixSuffixPairs(
frequentPrefixes: Array[Int],
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
// Remove from every sequence the items that are not listed in frequentPrefixes.
val filteredSequences = sequences.map { p =>
p.filter (frequentPrefixes.contains(_) )
}
// For each filtered sequence x and each frequent item y, pair the single-item prefix y
// with LocalPrefixSpan.getSuffix(y, x); pairs with an empty suffix are discarded.
filteredSequences.flatMap { x =>
frequentPrefixes.map { y =>
val sub = LocalPrefixSpan.getSuffix(y, x)
(Array(y), sub)
(ArrayBuffer(y), sub)
}.filter(_._2.nonEmpty)
}
}
Expand All @@ -201,9 +202,9 @@ class PrefixSpan private (
*/
// NOTE(review): rendered diff without +/- markers. The first `data:` signature line and the
// three-line flatMap after it look like the removed version; the second `data:` line and the
// flatMap/map chain look like the added version -- confirm against the repository.
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { x =>
LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
}
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
// Run LocalPrefixSpan on each (prefix, grouped suffixes) element, flatten the results,
// and convert each returned pattern to an ArrayBuffer to match the declared return type.
data
.flatMap { x => LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) }
.map { case (pattern, count) => (pattern.to[ArrayBuffer], count) }
}
}

0 comments on commit 095aa3a

Please sign in to comment.