Skip to content

Commit

Permalink
Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiajin committed Jul 14, 2015
1 parent ca9c4c8 commit 22b0ef4
Showing 1 changed file with 65 additions and 10 deletions.
75 changes: 65 additions & 10 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {

private val minPatternsBeforeShuffle: Int = 20

/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: `10`}.
Expand Down Expand Up @@ -86,16 +88,69 @@ class PrefixSpan private (
getFreqItemAndCounts(minCount, sequences).collect()
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
lengthOnePatternsAndCounts.map(_._1), sequences)
val groupedProjectedDatabase = prefixAndProjectedDatabase
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
val lengthOnePatternsAndCountsRdd =
sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
allPatterns

var patternsCount = lengthOnePatternsAndCounts.length
var allPatternAndCounts = sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
var currentProjectedDatabase = prefixAndProjectedDatabase
while (patternsCount <= minPatternsBeforeShuffle &&
currentProjectedDatabase.count() != 0) {
val (nextPatternAndCounts, nextProjectedDatabase) =
getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
patternsCount = nextPatternAndCounts.count().toInt
currentProjectedDatabase = nextProjectedDatabase
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
}
if (patternsCount > 0) {
val groupedProjectedDatabase = currentProjectedDatabase
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
}
allPatternAndCounts
}

/**
* Get the pattern and counts, and projected database
* @param minCount minimum count
* @param prefixAndProjectedDatabase prefix and projected database,
* @return pattern and counts, and projected database
* (Array[pattern, count], RDD[prefix, projected database ])
*/
private def getPatternCountsAndProjectedDatabase(
minCount: Long,
prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
}.reduceByKey(_+_)
.filter(_._2 >= minCount)
val patternAndCounts = prefixAndFreqentItemAndCounts
.map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length
if (prefixlength + 1 >= maxPatternLength) {
(patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
} else {
val frequentItemsMap = prefixAndFreqentItemAndCounts
.keys.map(x => (x._1, x._2))
.groupByKey()
.mapValues(_.toSet)
.collect
.toMap
val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
.filter(x => frequentItemsMap.contains(x._1))
.flatMap { x =>
val frequentItemSet = frequentItemsMap(x._1)
val filteredSequence = x._2.filter(frequentItemSet.contains(_))
val subProjectedDabase = frequentItemSet.map{ y =>
(y, LocalPrefixSpan.getSuffix(y, filteredSequence))
}.filter(_._2.nonEmpty)
subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2))
}
(patternAndCounts, nextPrefixAndProjectedDatabase)
}
}

/**
Expand Down

0 comments on commit 22b0ef4

Please sign in to comment.