Skip to content

Commit

Permalink
Modified the code according to the review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangjiajin committed Jul 9, 2015
1 parent 89bc368 commit 1dd33ad
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 108 deletions.
211 changes: 127 additions & 84 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.mllib.fpm

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
*
Expand All @@ -37,165 +39,206 @@ import org.apache.spark.rdd.RDD
* (Wikipedia)]]
*/
@Experimental
class PrefixSpan(
class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends java.io.Serializable {

private var absMinSupport: Int = 0
private var maxPatternLength: Int) extends Logging with Serializable {

/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: 10}.
* {minSupport: `0.1`, maxPatternLength: `10`}.
*/
def this() = this(0.1, 10)

/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
require(minSupport >= 0 && minSupport <= 1)
this.minSupport = minSupport
this
}

/**
* Sets maximal pattern length.
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
require(maxPatternLength >= 1)
this.maxPatternLength = maxPatternLength
this
}

/**
* Calculate sequential patterns:
* a) find and collect length-one patterns
* b) for each length-one patterns and each sequence,
* emit (pattern (prefix), suffix sequence) as key-value pairs
* c) group by key and then map value iterator to array
* d) local PrefixSpan on each prefix
* @return sequential patterns
* Find the complete set of sequential patterns in the input sequences.
* @param sequences input data set, contains a set of sequences,
* a sequence is an ordered list of elements.
* @return a set of sequential pattern pairs,
* the key of pair is pattern (a list of elements),
* the value of pair is the pattern's support value.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Seq[Int], Int)] = {
absMinSupport = getAbsoluteMinSupport(sequences)
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
val minCount = getAbsoluteMinSupport(sequences)
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
findLengthOnePatterns(sequences)
findLengthOnePatterns(minCount, sequences)
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
val nextPatterns = getPatternsInLocal(repartitionedRdd)
val allPatterns = lengthOnePatternsAndCounts.map(x => (Seq(x._1), x._2)) ++ nextPatterns
val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
val allPatterns = lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)) ++ nextPatterns
allPatterns
}

private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Int = {
val result = if (minSupport <= 0) {
0
} else {
val count = sequences.count()
val support = if (minSupport <= 1) minSupport else 1
(support * count).toInt
}
result
/**
 * Converts the relative `minSupport` ratio into an absolute number of sequences.
 * @param sequences input data set, a collection of sequences
 * @return `sequences.count() * minSupport` converted to a Long, or `0L` when
 *         `minSupport` is exactly 0 (the input is not counted in that case)
 */
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = {
  if (minSupport != 0) {
    val totalSequences = sequences.count()
    // NOTE(review): toLong truncates toward zero, so a fractional threshold such as
    // 1.5 becomes 1 -- confirm that rounding down (rather than up) is intended.
    (totalSequences * minSupport).toLong
  } else {
    0L
  }
}

/**
* Find the patterns whose length is one
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount the absolute minimum support
* @param sequences original sequences data
* @return length-one patterns and projection table
* @return RDD of frequent items paired with their counts
*/
private def findLengthOnePatterns(
sequences: RDD[Array[Int]]): (RDD[(Int, Int)], RDD[(Seq[Int], Array[Int])]) = {
val LengthOnePatternAndCounts = sequences
.flatMap(_.distinct.map((_, 1)))
private def getFreqItemAndCounts(
minCount: Long,
sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
sequences.flatMap(_.distinct.map((_, 1L)))
.reduceByKey(_ + _)
val infrequentLengthOnePatterns: Array[Int] = LengthOnePatternAndCounts
.filter(_._2 < absMinSupport)
.map(_._1)
.collect()
val frequentLengthOnePatterns = LengthOnePatternAndCounts
.filter(_._2 >= absMinSupport)
val frequentLengthOnePatternsArray = frequentLengthOnePatterns
.map(_._1)
.collect()
val filteredSequences =
if (infrequentLengthOnePatterns.isEmpty) {
sequences
} else {
sequences.map { p =>
p.filter { x => !infrequentLengthOnePatterns.contains(x) }
}
}
val prefixAndCandidates = filteredSequences.flatMap { x =>
frequentLengthOnePatternsArray.map { y =>
.filter(_._2 >= minCount)
}

/**
 * Counts, for a local collection of sequences, how many sequences contain each item
 * and keeps only the items whose count reaches the minimal support level.
 * @param minCount the absolute minimum support
 * @param sequences sequences data
 * @return (item, count) pairs of the frequent items; the pairs come out in the
 *         iteration order of the intermediate map and are not sorted
 */
private def getFreqItemAndCounts(
    minCount: Long,
    sequences: Array[Array[Int]]): Array[(Int, Long)] = {
  // De-duplicate inside each sequence first, so an item counts at most once per sequence.
  val itemOccurrences = sequences.flatMap(sequence => sequence.distinct)
  val occurrencesByItem = itemOccurrences.groupBy(identity)
  occurrencesByItem.iterator
    .map { case (item, occurrences) => (item, occurrences.length.toLong) }
    .filter { case (_, count) => count >= minCount }
    .toArray
}

/**
* Get the frequent prefixes' projected database.
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
frequentPrefixes: Array[Int],
sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
val filteredSequences = sequences.map { p =>
p.filter (frequentPrefixes.contains(_) )
}
filteredSequences.flatMap { x =>
frequentPrefixes.map { y =>
val sub = getSuffix(y, x)
(Seq(y), sub)
(Array(y), sub)
}
}.filter(x => x._2.nonEmpty)
(frequentLengthOnePatterns, prefixAndCandidates)
}

/**
* Re-partition the RDD data, to get better balance and performance.
* Get the frequent prefixes' projected database.
* @param prePrefix the frequent prefixes' prefix
* @param frequentPrefixes frequent prefixes
* @param sequences sequences data
* @return prefixes and projected database
*/
private def getPatternAndProjectedDatabase(
    prePrefix: Array[Int],
    frequentPrefixes: Array[Int],
    sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
  // Build the membership set once: Array.contains is a linear scan, and it used to run
  // for every element of every sequence in the projected database.
  val frequentItems = frequentPrefixes.toSet
  // Drop the infrequent items from every sequence before computing suffixes.
  val filteredProjectedDatabase = sequences.map(_.filter(frequentItems.contains))
  frequentPrefixes.map { item =>
    // Suffix of each filtered sequence with respect to `item`; empty suffixes are dropped.
    val suffixes = filteredProjectedDatabase.map(getSuffix(item, _)).filter(_.nonEmpty)
    // The new prefix is the parent prefix extended by this frequent item.
    (prePrefix :+ item, suffixes)
  }.filter { case (_, suffixes) => suffixes.nonEmpty }
}

/**
 * Finds the frequent patterns whose length is one and projects the input on each of them.
 * @param minCount the absolute minimum support
 * @param sequences original sequences data
 * @return a pair made of the frequent items with their counts and the
 *         (length-one prefix, suffix) projection of the input sequences
 */
private def findLengthOnePatterns(
    minCount: Long,
    sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
  val freqItemCounts = getFreqItemAndCounts(minCount, sequences)
  // Only the items themselves (not their counts) are needed to build the projection.
  val freqItems = freqItemCounts.map(_._1).collect()
  val projected = getPatternAndProjectedDatabase(freqItems, sequences)
  (freqItemCounts, projected)
}

/**
* Constructs prefix-projected databases from (prefix, suffix) pairs.
* @param data patterns and projected sequences data before re-partition
* @return patterns and projected sequences data after re-partition
*/
private def makePrefixProjectedDatabases(
data: RDD[(Seq[Int], Array[Int])]): RDD[(Seq[Int], Array[Array[Int]])] = {
val dataMerged = data
data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = {
data.map(x => (x._1.toSeq, x._2))
.groupByKey()
.mapValues(_.toArray)
dataMerged
.map(x => (x._1.toArray, x._2.toArray))
}

/**
* calculate the patterns in local.
* @param minCount the absolute minimum support
* @param data patterns and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
data: RDD[(Seq[Int], Array[Array[Int]])]): RDD[(Seq[Int], Int)] = {
val result = data.flatMap { x =>
getPatternsWithPrefix(x._1, x._2)
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
data.flatMap { x =>
getPatternsWithPrefix(minCount, x._1, x._2)
}
result
}

/**
* calculate the patterns with one prefix in local.
* @param minCount the absolute minimum support
* @param prefix prefix
* @param projectedDatabase patterns and projected sequences data
* @return patterns
*/
private def getPatternsWithPrefix(
prefix: Seq[Int],
projectedDatabase: Array[Array[Int]]): Array[(Seq[Int], Int)] = {
val prefixAndCounts = projectedDatabase
.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length)
val frequentPrefixExtensions = prefixAndCounts.filter(x => x._2 >= absMinSupport)
val frequentPrefixesAndCounts = frequentPrefixExtensions
.map(x => (prefix ++ Seq(x._1), x._2))
.toArray
val cleanedSearchSpace = projectedDatabase
.map(x => x.filter(y => frequentPrefixExtensions.contains(y)))
val prefixProjectedDatabases = frequentPrefixExtensions.map { x =>
val sub = cleanedSearchSpace.map(y => getSuffix(x._1, y)).filter(_.nonEmpty)
(prefix ++ Seq(x._1), sub)
}.filter(x => x._2.nonEmpty)
.toArray
minCount: Long,
prefix: Array[Int],
projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
val frequentPatternAndCounts = frequentPrefixAndCounts
.map(x => (prefix ++ Array(x._1), x._2))
val prefixProjectedDatabases = getPatternAndProjectedDatabase(
prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)

val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => getPatternsWithPrefix(x._1, x._2))
.map(x => getPatternsWithPrefix(minCount, x._1, x._2))
.reduce(_ ++ _)
frequentPrefixesAndCounts ++ nextPatterns
frequentPatternAndCounts ++ nextPatterns
} else {
frequentPrefixesAndCounts
frequentPatternAndCounts
}
}

/**
* calculate suffix sequence following a prefix in a sequence
* @param prefix prefix
* @param sequence original sequence
* @param sequence sequence
* @return suffix sequence
*/
private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
Expand Down
Loading

0 comments on commit 1dd33ad

Please sign in to comment.