Support non-temporal sequence in PrefixSpan
zhangjiajin committed Jul 24, 2015
1 parent b572f54 commit 216ab0c
Showing 3 changed files with 236 additions and 79 deletions.
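
Throughout this diff, a sequence is a flat Array[Int] in which -1 delimits itemsets (elements), so "non-temporal" items can occur together inside one element. Projected databases pair each suffix with an Int flag: 0 when the suffix starts at an element boundary, 1 when it starts inside a partially matched element. A minimal sketch of the encoding (illustrative values, not taken from the commit):

    // The sequence <(1 2) 3>: items 1 and 2 occur together, then item 3.
    val seq: Array[Int] = Array(1, 2, -1, 3, -1)
    // A projected entry: the remaining items plus the boundary flag.
    val projected: (Array[Int], Int) = (Array(3, -1), 0)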
130 changes: 100 additions & 30 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,55 +40,125 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
-    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
-    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
-    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
-    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
-      val newPrefixes = item :: prefixes
-      val newProjected = project(filteredDatabase, item)
-      Iterator.single((newPrefixes, count)) ++
-        run(minCount, maxPatternLength, newPrefixes, newProjected)
+      database: Array[(Array[Int], Int)]): Iterator[(List[Int], Long)] = {
+    if ((prefixes.nonEmpty && prefixes.filter(_ != -1).length == maxPatternLength) ||
+      database.length < minCount) { return Iterator.empty }
+    val frequentItemAndCounts = getFreqPrefixAndCounts(minCount, prefixes, database)
+    frequentItemAndCounts.iterator.flatMap { case (prefix, count) =>
+      val newProjected = project(database, prefix)
+      Iterator.single((prefix, count)) ++
+        run(minCount, maxPatternLength, prefix, newProjected)
     }
   }

   /**
    * Calculate suffix sequence immediately after the first occurrence of an item.
-   * @param item item to get suffix after
-   * @param sequence sequence to extract suffix from
+   * @param element the last element of prefix
+   * @param sequenceAndFlag sequence to extract suffix from
    * @return suffix sequence
    */
-  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(item)
-    if (index == -1) {
-      Array()
+  def getSuffix(
+      element: Array[Int],
+      sequenceAndFlag: (Array[Int], Int)): (Array[Int], Int) = {
+    val (originalSequence, flag) = sequenceAndFlag
+    val sequence =
+      if (element.length > 1 && flag == 1) {
+        element.take(element.length - 1) ++ originalSequence
+      } else if (element.length == 1 && flag == 1) {
+        val firstPosition = originalSequence.indexOf(-1)
+        if (firstPosition != -1) {
+          originalSequence.drop(firstPosition + 1)
+        } else {
+          return (Array(), 0)
+        }
+      } else {
+        originalSequence
+      }
+    var found = false
+    var currentIndex = -1
+    var nextIndex = 0
+    while (nextIndex != -1 && !found) {
+      nextIndex = sequence.indexOf(-1, currentIndex + 1)
+      found = element.toSet.subsetOf(
+        sequence.slice(currentIndex + 1, nextIndex).toSet)
+      if (!found) currentIndex = nextIndex
+    }
+    if (found) {
+      val itemPosition = sequence.indexOf(element.last, currentIndex)
+      if (sequence.apply(itemPosition + 1) == -1) {
+        (sequence.drop(itemPosition + 2), 0)
+      } else {
+        (sequence.drop(itemPosition + 1), 1)
+      }
     } else {
-      sequence.drop(index + 1)
+      (Array(), 0)
    }
   }
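
How the new getSuffix behaves under that encoding, as a sketch (inputs are illustrative):

    // The prefix's last element is {2}; in the first itemset {1, 2} the match
    // ends right before a -1 delimiter, so the suffix starts at the next
    // element and the flag is 0.
    LocalPrefixSpan.getSuffix(Array(2), (Array(1, 2, -1, 1, 2, 3, -1, 5), 0))
    // == (Array(1, 2, 3, -1, 5), 0)

    // Here the match at item 2 is followed by 3 inside the same itemset, so
    // the suffix begins mid-element and the flag is 1.
    LocalPrefixSpan.getSuffix(Array(2), (Array(1, 2, 3, -1, 5), 0))
    // == (Array(3, -1, 5), 1)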

-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  private def project(
+      database: Array[(Array[Int], Int)],
+      prefix: List[Int]): Array[(Array[Int], Int)] = {
+    val lastElement = prefix.toArray.drop(prefix.lastIndexOf(-1) + 1)
     database
-      .map(getSuffix(prefix, _))
-      .filter(_.nonEmpty)
+      .map(getSuffix(lastElement, _))
+      .filter(_._1.nonEmpty)
   }
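
Projection in action, conceptually (project is private, so this is a sketch rather than callable user code):

    // Project on the prefix <(1) (2)>, encoded as List(1, -1, 2); only the
    // last element, Array(2), drives the suffix computation.
    project(
      Array((Array(1, 2, -1, 1, 2, 3, -1, 5), 0),
            (Array(1, -1, 3, -1), 0)),           // no element contains 2
      List(1, -1, 2))
    // == Array((Array(1, 2, 3, -1, 5), 0))      // second sequence dropped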

   /**
-   * Generates frequent items by filtering the input data using minimal count level.
+   * Generates frequent prefixes by filtering the input data using minimal count level.
    * @param minCount the minimum count for an item to be frequent
-   * @param database database of sequences
-   * @return freq item to count map
+   * @param prefix prefix
+   * @param suffixes suffixes
+   * @return freq prefix to count map
    */
-  private def getFreqItemAndCounts(
+  private def getFreqPrefixAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
-    // TODO: use PrimitiveKeyOpenHashMap
-    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
-    database.foreach { sequence =>
-      sequence.distinct.foreach { item =>
-        counts(item) += 1L
-      }
-    }
-    counts.filter(_._2 >= minCount)
+      prefix: List[Int],
+      suffixes: Array[(Array[Int], Int)]): mutable.Map[List[Int], Long] = {
+    val counts = mutable.Map[List[Int], Long]().withDefaultValue(0L)
+    val singleItemSet = suffixes.map { case (suffix, flag) =>
+      if (flag == 0) suffix else suffix.drop(suffix.indexOf(-1) + 1)
+    }.flatMap(_.filter(_ != -1).distinct)
+      .groupBy(item => item).map(x => (x._1, x._2.length.toLong))
+    singleItemSet.filter(_._2 >= minCount).foreach { case (item, count) =>
+      if (prefix.nonEmpty) counts(prefix :+ -1 :+ item) = count else counts(List(item)) = count
+    }
+    if (prefix.nonEmpty) {
+      val lastElement = prefix.drop(prefix.lastIndexOf(-1) + 1).toArray
+      val multiItemSet = mutable.Map[Int, Long]().withDefaultValue(0L)
+      suffixes.map { case (suffix, flag) =>
+        if (flag == 0) suffix else lastElement ++ suffix
+      }.foreach { suffix =>
+        singleItemSet.keys.foreach { item =>
+          if (!lastElement.contains(item)) {
+            val element = lastElement :+ item
+            if (isSubElement(suffix, element)) {
+              multiItemSet(item) += 1L
+            }
+          }
+        }
+      }
+      multiItemSet.filter(_._2 >= minCount).foreach { case (item, count) =>
+        if (prefix.nonEmpty) {
+          counts(prefix :+ item) = count
+        } else {
+          counts(List(item)) = count
+        }
+      }
+    }
+    counts
   }
+
+  private def isSubElement(sequence: Array[Int], element: Array[Int]): Boolean = {
+    var found = false
+    var currentIndex = -1
+    var nextIndex = 0
+    while (nextIndex != -1 && !found) {
+      nextIndex = sequence.indexOf(-1, currentIndex + 1)
+      found = element.toSet.subsetOf(
+        sequence.slice(currentIndex + 1, nextIndex).toSet)
+      if (!found) currentIndex = nextIndex
+    }
+    found
+  }
 }
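
A sanity check of isSubElement's semantics (again a sketch; the helper is private). It reports whether all of element's items co-occur inside a single -1-terminated itemset of sequence, so it assumes every itemset in the scanned suffix is closed by a delimiter:

    isSubElement(Array(1, 2, 3, -1, 4, -1), Array(1, 3))  // true: both in {1, 2, 3}
    isSubElement(Array(1, 2, -1, 3, -1), Array(2, 3))     // false: split across itemsets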
48 changes: 33 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -81,11 +83,12 @@ class PrefixSpan private (
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
+    val sortedSequences = sortSequences(sequences)
+    val minCount = getMinCount(sortedSequences)
     val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
+      getFreqItemAndCounts(minCount, sortedSequences).collect()
     val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
+      lengthOnePatternsAndCounts.map(_._1), sortedSequences)
     val groupedProjectedDatabase = prefixAndProjectedDatabase
       .map(x => (x._1.toSeq, x._2))
       .groupByKey()
@@ -98,6 +101,24 @@
     allPatterns
   }
 
+  private def sortSequences(sequences: RDD[Array[Int]]): RDD[Array[Int]] = {
+    sequences.map { sequence =>
+      val sortedArray: ArrayBuffer[Int] = ArrayBuffer()
+      var currentIndex = -1
+      var nextIndex = 0
+      while (nextIndex != -1) {
+        nextIndex = sequence.indexOf(-1, currentIndex + 1)
+        if (nextIndex != -1) {
+          sortedArray ++= sequence.slice(currentIndex, nextIndex).sorted
+        } else {
+          sortedArray ++= sequence.drop(currentIndex).sorted
+        }
+        currentIndex = nextIndex
+      }
+      sortedArray.toArray
+    }
+  }
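
sortSequences canonicalizes item order inside every element so that equal itemsets compare equal during counting and projection; the trick works because item IDs are assumed nonnegative, so a chunk's leading -1 delimiter sorts back to the front. A local re-implementation of the same idea (hypothetical helper, not part of the commit):

    def sortItemsets(sequence: Array[Int]): Array[Int] =
      sequence.foldLeft(Vector(Vector.empty[Int])) { (acc, x) =>
        if (x == -1) acc :+ Vector.empty[Int] else acc.init :+ (acc.last :+ x)
      }.map(_.sorted).reduceLeft(_ ++ Vector(-1) ++ _).toArray

    sortItemsets(Array(2, 1, -1, 3, 2))  // Array(1, 2, -1, 2, 3)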

   /**
    * Get the minimum count (sequences count * minSupport).
    * @param sequences input data set, contains a set of sequences,
@@ -116,7 +137,7 @@
   private def getFreqItemAndCounts(
       minCount: Long,
       sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
+    sequences.flatMap(_.filter(_ != -1).distinct.map((_, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
   }
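
Delimiter-aware counting, sketched with a live SparkContext sc (assumed; values illustrative). Each item counts at most once per sequence, and the -1 delimiter never enters the counts:

    val seqs = sc.parallelize(Seq(
      Array(1, 2, -1, 3, -1),
      Array(2, -1, 3, -1)))
    // With minCount = 2: item 1 occurs in only one sequence and is dropped;
    // items 2 and 3 occur in both, so the result is (2, 2) and (3, 2).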
@@ -129,15 +150,12 @@
    */
   private def getPrefixAndProjectedDatabase(
       frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter (frequentPrefixes.contains(_) )
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
+      sequences: RDD[Array[Int]]): RDD[(Array[Int], (Array[Int], Int))] = {
+    sequences.flatMap { sequence =>
+      frequentPrefixes.map { item =>
+        val sub = LocalPrefixSpan.getSuffix(Array(item), (sequence, 0))
+        (Array(item), sub)
+      }.filter(_._2._1.nonEmpty)
     }
   }

@@ -149,10 +167,10 @@
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
+      data: RDD[(Array[Int], Array[(Array[Int], Int)])]): RDD[(Array[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+        .map { case (pattern: List[Int], count: Long) => (pattern.toArray, count) }
     }
   }
 }
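
End to end, the new encoding lets callers mine sequences whose elements hold several items. A usage sketch against the 1.5-era API (setMinSupport, setMaxPatternLength, run; a live SparkContext sc is assumed, and the input literals are illustrative):

    import org.apache.spark.mllib.fpm.PrefixSpan

    // <(1 2) 3> and <(1 2) (3 4)>, with every itemset closed by -1.
    val sequences = sc.parallelize(Seq(
      Array(1, 2, -1, 3, -1),
      Array(1, 2, -1, 3, 4, -1)), 2).cache()

    val patterns = new PrefixSpan()
      .setMinSupport(1.0)       // keep patterns present in every sequence
      .setMaxPatternLength(5)
      .run(sequences)           // RDD[(Array[Int], Long)]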