Commit

Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
[SPARK-8998] Collect Enough Prefixes Improvements
zhangjiajin committed Jul 29, 2015
2 parents 64271b3 + 87fa021 commit ad23aa9
Showing 3 changed files with 136 additions and 138 deletions.
6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }

-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
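The switch from Array[Array[Int]] to Iterable[Array[Int]] lets the distributed driver hand LocalPrefixSpan the Iterable produced by groupByKey directly, without first copying each projected database into arrays. A minimal sketch of the same map/filter pipeline over an Iterable-backed projected database (plain Scala with toy data; the frequentItems set is a hypothetical stand-in for items that met minCount):

    // Toy illustration, not part of the patch: an Iterable-backed projected
    // database supports the operations LocalPrefixSpan relies on.
    val projectedDB: Iterable[Array[Int]] = Seq(Array(1, 2, 3), Array(2, 3), Array(3))
    val frequentItems = Set(2, 3)  // hypothetical items that met the minCount threshold
    val filtered = projectedDB.map(_.filter(frequentItems.contains)).filter(_.nonEmpty)
    println(filtered.map(_.mkString(" ")).mkString(" | "))  // prints: 2 3 | 2 3 | 3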
247 changes: 122 additions & 125 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -45,30 +45,44 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {

-  private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  private val maxLocalProjDBSize: Long = 10000

   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
    */
   def this() = this(0.1, 10)

+  /**
+   * Gets the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport(): Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }

+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+   */
+  def getMaxPatternLength(): Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
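As an aside, the new accessors pair with the existing setters; a quick hypothetical usage (the private primary constructor is reached through the public no-arg constructor):

    // Hypothetical usage of the getters added in this hunk:
    val ps = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
    assert(ps.getMinSupport() == 0.5)
    assert(ps.getMaxPatternLength() == 5)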
@@ -85,162 +85,145 @@ class PrefixSpan private (
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts = getFreqItemAndCounts(minCount, sequences)
-    val prefixSuffixPairs = getPrefixSuffixPairs(
-      lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
-    prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-    var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer(x._1), x._2))
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) =
-      splitPrefixSuffixPairs(prefixSuffixPairs)
-    while (largePrefixSuffixPairs.count() != 0) {
-      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
-        getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
-      largePrefixSuffixPairs.unpersist()
-      val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
-      largePrefixSuffixPairs = largerPairsPart
-      largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-      smallPrefixSuffixPairs ++= smallerPairsPart
-      allPatternAndCounts ++= nextPatternAndCounts
-    }
-    if (smallPrefixSuffixPairs.count() > 0) {
-      val projectedDatabase = smallPrefixSuffixPairs
-        .map(x => (x._1.toSeq, x._2))
-        .groupByKey()
-        .map(x => (x._1.toArray, x._2.toArray))
-      val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
-      allPatternAndCounts ++= nextPatternAndCounts
-    }
-    allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.keys.collect().toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
+    }
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items
+    // (i.e. frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be processed locally and distributed, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes
+    // have projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts
+    }
+
+    // Process the small projected databases locally
+    resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
+
+    resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
   }

   /**
-   * Split prefix suffix pairs to two parts:
-   * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
-   * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
+   * Partitions the prefix-suffix pairs by projected database size.
    * @param prefixSuffixPairs prefix (length n) and suffix pairs
-   * @return small size prefix suffix pairs and big size prefix suffix pairs
-   *         (RDD[prefix, suffix], RDD[prefix, suffix])
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val suffixSizeMap = prefixSuffixPairs
-      .map(x => (x._1, x._2.length))
-      .reduceByKey(_ + _)
-      .map(x => (x._2 <= maxProjectedDBSizeBeforeLocalProcessing, Set(x._1)))
-      .reduceByKey(_ ++ _)
-      .collect()
-      .toMap
-    val small = if (suffixSizeMap.contains(true)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(true).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
-    val large = if (suffixSizeMap.contains(false)) {
-      prefixSuffixPairs.filter(x => suffixSizeMap(false).contains(x._1))
-    } else {
-      prefixSuffixPairs.filter(x => false)
-    }
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
     (small, large)
   }

   /**
-   * Get the pattern and counts, and prefix suffix pairs
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
    * @param minCount minimum count
-   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
-   *         (RDD[pattern, count], RDD[prefix, suffix])
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+   *         prefix, corresponding suffix) pairs
    */
-  private def getPatternCountsAndPrefixSuffixPairs(
+  private def extendPrefixes(
       minCount: Long,
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
-    val prefixAndFrequentItemAndCounts = prefixSuffixPairs
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences;
+    // every extended prefix (prefix :+ item) kept here satisfies the `minCount` threshold
+    val prefixItemPairAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-    val patternAndCounts = prefixAndFrequentItemAndCounts
-      .map { case ((prefix, item), count) => (prefix :+ item, count) }
-    val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
+
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
       .keys
       .groupByKey()
       .mapValues(_.toSet)
       .collect()
       .toMap
-    val nextPrefixSuffixPairs = prefixSuffixPairs
-      .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
       .flatMap { case (prefix, suffix) =>
-        val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
+        val frequentNextItems = prefixToNextItems(prefix)
         val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
         frequentNextItems.flatMap { item =>
-          val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
-          if (suffix.isEmpty) None
-          else Some(prefix :+ item, suffix)
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
+            case _ => None
+          }
         }
       }
-    (patternAndCounts, nextPrefixSuffixPairs)
-  }
-
-  /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences
-   * @return minimum count
-   */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
-  }
-
-  /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return RDD of item and count pairs
-   */
-  private def getFreqItemAndCounts(
-      minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
-      .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
-  }
-
-  /**
-   * Get the frequent prefixes and suffix pairs.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and suffix pairs
-   */
-  private def getPrefixSuffixPairs(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter(frequentPrefixes.contains(_))
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (ArrayBuffer(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
-  }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
+  }

   /**
-   * calculate the patterns in local.
+   * Calculate the patterns in local.
    * @param minCount the absolute minimum count
    * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
       case (prefix, projDB) =>
         LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
           .map { case (pattern: List[Int], count: Long) =>
-            (pattern.toArray.reverse.to[ArrayBuffer], count)
+            (pattern.reverse, count)
           }
     }
   }
 }
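Two pieces of the new code above are worth a worked example. First, the min-count conversion: with 5 input sequences and minSupport = 0.33, minCount = ceil(5 * 0.33) = 2, so a pattern must occur in at least two sequences. Second, partitionByProjDBSize sizes each projected database with a single aggregateByKey pass; below is a self-contained sketch of that fold on toy data (app name and values are hypothetical):

    // Toy illustration of the aggregateByKey fold used by partitionByProjDBSize:
    // per prefix, suffix lengths are summed within partitions (seqOp), then the
    // partial sums are merged across partitions (combOp).
    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("projdb-demo"))
    val pairs = sc.parallelize(Seq(
      (List(1), Array(2, 3)),      // suffix of length 2 under prefix [1]
      (List(1), Array(2)),         // suffix of length 1 under prefix [1]
      (List(2), Array(3, 4, 5))))  // suffix of length 3 under prefix [2]
    val projDBSizes = pairs.aggregateByKey(0)(
      seqOp = (count, suffix) => count + suffix.length,
      combOp = _ + _)
    projDBSizes.collect().foreach(println)  // (List(1),3) and (List(2),3)
    sc.stop()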
21 changes: 11 additions & 10 deletions mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {

     val rdd = sc.parallelize(sequences, 2).cache()

-    def compareResult(
-        expectedValue: Array[(Array[Int], Long)],
-        actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))

     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue2, result2.collect()))
+    assert(compareResults(expectedValue2, result2.collect()))

     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue3, result3.collect()))
+    assert(compareResults(expectedValue3, result3.collect()))
   }
+
+  private def compareResults(
+      expectedValue: Array[(Array[Int], Long)],
+      actualValue: Array[(Array[Int], Long)]): Boolean = {
+    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet
+  }
+
 }
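For context, a hedged end-to-end sketch mirroring the suite's setup (toy sequences loosely based on the test data; exact output depends on the input and the support threshold):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.fpm.PrefixSpan

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("prefixspan-demo"))
    val sequences = sc.parallelize(Seq(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1),
      Array(3, 1, 3, 4, 5),
      Array(1, 3, 4, 5)), 2).cache()
    val prefixSpan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
    prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
      println(s"[${pattern.mkString(", ")}]: $count")
    }
    sc.stop()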
