Modified the code according to the review comments.
zhangjiajin committed Jul 8, 2015
1 parent 91fd7e6 commit 575995f
Showing 2 changed files with 278 additions and 0 deletions.
209 changes: 209 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.fpm

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD

/**
*
* :: Experimental ::
*
* A parallel PrefixSpan algorithm to mine sequential patterns.
* The PrefixSpan algorithm is described in
* [[http://doi.org/10.1109/ICDE.2001.914830]].
*
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
* more than (minSupport * size-of-the-dataset) times will be output
* @param maxPatternLength the maximal length of the sequential pattern, any pattern whose length
* is no greater than maxPatternLength will be output
*
* @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
* (Wikipedia)]]
*/
@Experimental
class PrefixSpan(
private var minSupport: Double,
private var maxPatternLength: Int) extends java.io.Serializable {

private var absMinSupport: Int = 0

/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: 10}.
*/
def this() = this(0.1, 10)

/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
this.minSupport = minSupport
this
}

/**
* Sets maximal pattern length.
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
this.maxPatternLength = maxPatternLength
this
}

/**
* Calculate sequential patterns:
* a) find and collect length-one patterns
* b) for each length-one pattern and each sequence,
* emit (pattern (prefix), suffix sequence) as key-value pairs
* c) group by key and then map value iterator to array
* d) local PrefixSpan on each prefix
* @return sequential patterns
*/
def run(sequences: RDD[Array[Int]]): RDD[(Seq[Int], Int)] = {
absMinSupport = getAbsoluteMinSupport(sequences)
val (lengthOnePatternsAndCounts, prefixAndCandidates) =
findLengthOnePatterns(sequences)
val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
val nextPatterns = getPatternsInLocal(repartitionedRdd)
val allPatterns = lengthOnePatternsAndCounts.map(x => (Seq(x._1), x._2)) ++ nextPatterns
allPatterns
}

private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Int = {
val result = if (minSupport <= 0) {
0
} else {
val count = sequences.count()
val support = if (minSupport <= 1) minSupport else 1
(support * count).toInt
}
result
}
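// Worked example (values from the test suite below): with 6 input sequences
// and minSupport = 0.34, the absolute threshold is (0.34 * 6).toInt = 2, so
// a pattern must appear in at least 2 sequences to be kept.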

/**
* Find the patterns whose length is one
* @param sequences original sequences data
* @return length-one patterns and projection table
*/
private def findLengthOnePatterns(
sequences: RDD[Array[Int]]): (RDD[(Int, Int)], RDD[(Seq[Int], Array[Int])]) = {
val lengthOnePatternAndCounts = sequences
.flatMap(_.distinct.map((_, 1)))
.reduceByKey(_ + _)
val infrequentLengthOnePatterns: Array[Int] = lengthOnePatternAndCounts
.filter(_._2 < absMinSupport)
.map(_._1)
.collect()
val frequentLengthOnePatterns = lengthOnePatternAndCounts
.filter(_._2 >= absMinSupport)
val frequentLengthOnePatternsArray = frequentLengthOnePatterns
.map(_._1)
.collect()
val filteredSequences =
if (infrequentLengthOnePatterns.isEmpty) {
sequences
} else {
sequences.map { p =>
p.filter { x => !infrequentLengthOnePatterns.contains(x) }
}
}
val prefixAndCandidates = filteredSequences.flatMap { x =>
frequentLengthOnePatternsArray.map { y =>
val sub = getSuffix(y, x)
(Seq(y), sub)
}
}.filter(x => x._2.nonEmpty)
(frequentLengthOnePatterns, prefixAndCandidates)
}
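// Example: if 3 is a frequent length-one pattern and one input sequence is
// Array(1, 3, 4, 5), the flatMap above emits (Seq(3), Array(4, 5)): the
// suffix that follows the first occurrence of the prefix 3.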

/**
* Group the projected suffixes by prefix (re-partitioning the RDD) to get better balance
* and performance.
* @param data patterns and projected sequences data before re-partition
* @return patterns and projected sequences data after re-partition
*/
private def makePrefixProjectedDatabases(
data: RDD[(Seq[Int], Array[Int])]): RDD[(Seq[Int], Array[Array[Int]])] = {
val dataMerged = data
.groupByKey()
.mapValues(_.toArray)
dataMerged
}
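// Example: after grouping, the key Seq(3) maps to every suffix projected onto
// it; for the test data below that is
// Array(Array(1, 3, 4, 5), Array(1), Array(4, 4, 3), Array(4, 5)).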

/**
* Calculate the patterns locally.
* @param data patterns and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
data: RDD[(Seq[Int], Array[Array[Int]])]): RDD[(Seq[Int], Int)] = {
val result = data.flatMap { x =>
getPatternsWithPrefix(x._1, x._2)
}
result
}

/**
* Calculate the patterns with one prefix locally.
* @param prefix prefix
* @param projectedDatabase patterns and projected sequences data
* @return patterns
*/
private def getPatternsWithPrefix(
prefix: Seq[Int],
projectedDatabase: Array[Array[Int]]): Array[(Seq[Int], Int)] = {
val prefixAndCounts = projectedDatabase
.flatMap(_.distinct)
.groupBy(x => x)
.mapValues(_.length)
val frequentPrefixExtensions = prefixAndCounts.filter(x => x._2 >= absMinSupport)
val frequentPrefixesAndCounts = frequentPrefixExtensions
.map(x => (prefix ++ Seq(x._1), x._2))
.toArray
val cleanedSearchSpace = projectedDatabase
.map(x => x.filter(y => frequentPrefixExtensions.contains(y)))
val prefixProjectedDatabases = frequentPrefixExtensions.map { x =>
val sub = cleanedSearchSpace.map(y => getSuffix(x._1, y)).filter(_.nonEmpty)
(prefix ++ Seq(x._1), sub)
}.filter(x => x._2.nonEmpty)
.toArray
val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
if (continueProcess) {
val nextPatterns = prefixProjectedDatabases
.map(x => getPatternsWithPrefix(x._1, x._2))
.reduce(_ ++ _)
frequentPrefixesAndCounts ++ nextPatterns
} else {
frequentPrefixesAndCounts
}
}
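// Example trace (test-suite data, absMinSupport = 2): with prefix Seq(3) and
// suffixes Array(1, 3, 4, 5), Array(1), Array(4, 4, 3), Array(4, 5), the item
// 4 occurs in three distinct suffixes, so (Seq(3, 4), 3) is emitted and the
// recursion continues with prefix Seq(3, 4).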

/**
* Calculate the suffix sequence following a prefix in a sequence
* @param prefix prefix
* @param sequence original sequence
* @return suffix sequence
*/
private def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(prefix)
if (index == -1) {
Array()
} else {
sequence.drop(index + 1)
}
}
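// Example: getSuffix(3, Array(1, 3, 4, 5)) returns Array(4, 5);
// getSuffix(9, Array(1, 3, 4, 5)) returns Array() because 9 never occurs.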
}
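
For context, a minimal usage sketch (not part of the commit itself): it assumes a live SparkContext named `sc` and reuses the toy data from the test suite below.

import org.apache.spark.mllib.fpm.PrefixSpan

// Toy input: each Array[Int] is one sequence of item ids.
val sequences = sc.parallelize(Seq(
  Array(3, 1, 3, 4, 5),
  Array(2, 3, 1),
  Array(3, 4, 4, 3),
  Array(1, 3, 4, 5),
  Array(2, 4, 1),
  Array(6, 5, 3)), 2).cache()

// Require each pattern in at least 50% of the sequences, capped at length 5.
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)

val patterns = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString(",") + ": " + count)
}
// Prints, in some order: 1: 4; 3,4: 3; 3: 5; 4: 4; 5: 3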
69 changes: 69 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixspanSuite.scala
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD

class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {

test("Prefixspan sequences mining using Integer type") {
val sequences = Array(
Array(3, 1, 3, 4, 5),
Array(2, 3, 1),
Array(3, 4, 4, 3),
Array(1, 3, 4, 5),
Array(2, 4, 1),
Array(6, 5, 3))

val rdd = sc.parallelize(sequences, 2).cache()

def formatResultString(data: RDD[(Seq[Int], Int)]): String = {
data.map(x => x._1.mkString(",") + ": " + x._2)
.collect()
.sortWith(_ < _)
.mkString("; ")
}
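// Example: (Seq(1, 3), 2) and (Seq(1), 4) format to "1,3: 2; 1: 4"
// ("1,3: 2" sorts first because ',' precedes ':' lexicographically).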

val prefixspan = new PrefixSpan()
.setMinSupport(0.34)
.setMaxPatternLength(50)
val result1 = prefixspan.run(rdd)
val len1 = result1.count().toInt
val actualValue1 = formatResultString(result1)
val expectedValue1 =
"1,3,4,5: 2; 1,3,4: 2; 1,3,5: 2; 1,3: 2; 1,4,5: 2;" +
" 1,4: 2; 1,5: 2; 1: 4; 2,1: 2; 2: 2; 3,1: 2; 3,3: 2;" +
" 3,4,5: 2; 3,4: 3; 3,5: 2; 3: 5; 4,5: 2; 4: 4; 5: 3"
assert(expectedValue1 == actualValue1)

prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
val expectedValue2 = "1: 4; 3,4: 3; 3: 5; 4: 4; 5: 3"
val actualValue2 = formatResultString(result2)
assert(expectedValue2 == actualValue2)

prefixspan.setMinSupport(0.34).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
val actualValue3 = formatResultString(result3)
val expectedValue3 =
"1,3: 2; 1,4: 2; 1,5: 2; 1: 4; 2,1: 2; 2: 2; 3,1: 2;" +
" 3,3: 2; 3,4: 3; 3,5: 2; 3: 5; 4,5: 2; 4: 4; 5: 3"
assert(expectedValue3 == actualValue3)
}
}
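
With six input sequences, the absolute support thresholds in the three assertions work out to (0.34 * 6).toInt = 2 for minSupport = 0.34 and (0.5 * 6).toInt = 3 for minSupport = 0.5, which is why, for example, Seq(3, 4) (count 3) survives the second run while Seq(1, 3) (count 2) does not; the third run keeps the 0.34 threshold but truncates output at patterns of length 2.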
