Skip to content

Commit

Permalink
Modified the code according to the review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiajin committed Jul 11, 2015
1 parent 574e56c commit ca9c4c8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,35 @@ import org.apache.spark.annotation.Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {

/**
 * Recursively mines all sequential patterns extending the given prefix
 * within a projected database.
 * @param minCount minimum count a pattern needs to be kept
 * @param maxPatternLength maximum allowed pattern length
 * @param prefix the pattern prefix being extended
 * @param projectedDatabase the database projected with respect to `prefix`
 * @return a set of sequential pattern pairs,
 *         the key of pair is sequential pattern (a list of items),
 *         the value of pair is the pattern's count.
 */
def run(
    minCount: Long,
    maxPatternLength: Int,
    prefix: Array[Int],
    projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
  val freqItemsWithCounts = getFreqItemAndCounts(minCount, projectedDatabase)
  // Each frequent item, appended to the current prefix, is a frequent pattern.
  val patternsHere = freqItemsWithCounts.map { case (item, count) =>
    (prefix ++ Array(item), count)
  }
  val projectedByPrefix = getPatternAndProjectedDatabase(
    prefix, freqItemsWithCounts.map(_._1), projectedDatabase)

  // Stop recursing once there is nothing left to project, or the next
  // extension would reach the maximum pattern length.
  if (projectedByPrefix.isEmpty || prefix.length + 1 >= maxPatternLength) {
    patternsHere
  } else {
    val deeperPatterns = projectedByPrefix
      .map { case (newPrefix, projected) =>
        run(minCount, maxPatternLength, newPrefix, projected)
      }
      .reduce(_ ++ _)
    patternsHere ++ deeperPatterns
  }
}

/**
Expand Down Expand Up @@ -96,34 +110,4 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
(prePrefix ++ Array(x), sub)
}.filter(x => x._2.nonEmpty)
}

/**
 * Mines all patterns of a projected database that extend the given prefix.
 * @param minCount the minimum count
 * @param maxPatternLength maximum pattern length
 * @param prefix prefix
 * @param projectedDatabase projected database
 * @return patterns paired with their counts
 */
private def getPatternsWithPrefix(
    minCount: Long,
    maxPatternLength: Int,
    prefix: Array[Int],
    projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
  val itemCounts = getFreqItemAndCounts(minCount, projectedDatabase)
  // Appending each frequent item to the prefix yields a frequent pattern.
  val localPatterns = itemCounts.map(ic => (prefix ++ Array(ic._1), ic._2))
  val subDatabases = getPatternAndProjectedDatabase(
    prefix, itemCounts.map(_._1), projectedDatabase)

  // Recurse only while there is something to project and the next pattern
  // length would still be below the configured maximum.
  val shouldRecurse = subDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
  if (!shouldRecurse) {
    localPatterns
  } else {
    localPatterns ++ subDatabases.flatMap(sd =>
      getPatternsWithPrefix(minCount, maxPatternLength, sd._1, sd._2))
  }
}
}
42 changes: 10 additions & 32 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ class PrefixSpan private (
logWarning("Input data is not cached.")
}
val minCount = getMinCount(sequences)
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
findLengthOnePatterns(minCount, sequences)
val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates)
val nextPatterns = getPatternsInLocal(minCount, projectedDatabase)
val lengthOnePatternsAndCounts =
getFreqItemAndCounts(minCount, sequences).collect()
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
lengthOnePatternsAndCounts.map(_._1), sequences)
val groupedProjectedDatabase = prefixAndProjectedDatabase
.map(x => (x._1.toSeq, x._2))
.groupByKey()
.map(x => (x._1.toArray, x._2.toArray))
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
val lengthOnePatternsAndCountsRdd =
sequences.sparkContext.parallelize(
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
Expand Down Expand Up @@ -122,7 +127,7 @@ class PrefixSpan private (
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
private def getPrefixAndProjectedDatabase(
frequentPrefixes: Array[Int],
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
val filteredSequences = sequences.map { p =>
Expand All @@ -136,33 +141,6 @@ class PrefixSpan private (
}
}

/**
 * Finds the frequent patterns whose length is one.
 * @param minCount the minimum count
 * @param sequences original sequences data
 * @return length-one patterns with their counts, and the projection table
 */
private def findLengthOnePatterns(
    minCount: Long,
    sequences: RDD[Array[Int]]): (Array[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
  // Collect once. The previous code ran two separate Spark actions
  // (`keys.collect()` and then `collect()`) on the same uncached RDD,
  // which recomputed the frequency-counting job twice.
  val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences).collect()
  val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
    frequentLengthOnePatternAndCounts.map(_._1), sequences)
  (frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
}

/**
 * Constructs prefix-projected databases from (prefix, suffix) pairs.
 * @param data patterns and projected sequences data before re-partition
 * @return patterns and projected sequences data after re-partition
 */
private def makePrefixProjectedDatabases(
    data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = {
  // Arrays compare by reference, so key by an equality-friendly Seq while
  // grouping, then convert the keys and grouped suffixes back to arrays.
  val keyedBySeq = data.map { case (prefix, suffix) => (prefix.toSeq, suffix) }
  keyedBySeq
    .groupByKey()
    .map { case (prefixSeq, suffixes) => (prefixSeq.toArray, suffixes.toArray) }
}

/**
* calculate the patterns in local.
* @param minCount the absolute minimum count
Expand Down

0 comments on commit ca9c4c8

Please sign in to comment.