diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead6327486cc..bf066b6c3e8b7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,55 +40,125 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
-    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
-    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
-    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
-    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
-      val newPrefixes = item :: prefixes
-      val newProjected = project(filteredDatabase, item)
-      Iterator.single((newPrefixes, count)) ++
-        run(minCount, maxPatternLength, newPrefixes, newProjected)
+      database: Array[(Array[Int], Int)]): Iterator[(List[Int], Long)] = {
+    if ((prefixes.nonEmpty && prefixes.count(_ != -1) == maxPatternLength) ||
+      database.length < minCount) { return Iterator.empty }
+    val frequentItemAndCounts = getFreqPrefixAndCounts(minCount, prefixes, database)
+    frequentItemAndCounts.iterator.flatMap { case (prefix, count) =>
+      val newProjected = project(database, prefix)
+      Iterator.single((prefix, count)) ++
+        run(minCount, maxPatternLength, prefix, newProjected)
     }
   }
 
   /**
-   * Calculate suffix sequence immediately after the first occurrence of an item.
-   * @param item item to get suffix after
-   * @param sequence sequence to extract suffix from
-   * @return suffix sequence
+   * Calculate the suffix sequence immediately after the first occurrence of an element.
+   * @param element the last itemset of the current prefix
+   * @param sequenceAndFlag (sequence, flag) pair; flag == 1 if the suffix starts mid-itemset
+   * @return a (suffix, flag) pair
    */
-  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(item)
-    if (index == -1) {
-      Array()
+  def getSuffix(
+      element: Array[Int],
+      sequenceAndFlag: (Array[Int], Int)): (Array[Int], Int) = {
+    val (originalSequence, flag) = sequenceAndFlag
+    val sequence =
+      if (element.length > 1 && flag == 1) {
+        element.take(element.length - 1) ++ originalSequence
+      } else if (element.length == 1 && flag == 1) {
+        val firstPosition = originalSequence.indexOf(-1)
+        if (firstPosition != -1) {
+          originalSequence.drop(firstPosition + 1)
+        } else {
+          return (Array(), 0)
+        }
+      } else {
+        originalSequence
+      }
+    var found = false
+    var currentIndex = -1
+    var nextIndex = 0
+    while (nextIndex != -1 && !found) {
+      nextIndex = sequence.indexOf(-1, currentIndex + 1)
+      found = element.toSet.subsetOf(
+        sequence.slice(currentIndex + 1, nextIndex).toSet)
+      if (!found) currentIndex = nextIndex
+    }
+    if (found) {
+      val itemPosition = sequence.indexOf(element.last, currentIndex)
+      if (sequence(itemPosition + 1) == -1) {
+        (sequence.drop(itemPosition + 2), 0)
+      } else {
+        (sequence.drop(itemPosition + 1), 1)
+      }
     } else {
-      sequence.drop(index + 1)
+      (Array(), 0)
     }
   }
 
-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  private def project(
+      database: Array[(Array[Int], Int)],
+      prefix: List[Int]): Array[(Array[Int], Int)] = {
+    val lastElement = prefix.toArray.drop(prefix.lastIndexOf(-1) + 1)
     database
-      .map(getSuffix(prefix, _))
-      .filter(_.nonEmpty)
+      .map(getSuffix(lastElement, _))
+      .filter(_._1.nonEmpty)
   }
 
   /**
-   * Generates frequent items by filtering the input data using minimal count level.
+   * Generates frequent prefixes by filtering the projected database using the minimal count level.
    * @param minCount the minimum count for an item to be frequent
-   * @param database database of sequences
-   * @return freq item to count map
+   * @param prefix the current prefix
+   * @param suffixes the projected database of (suffix, flag) pairs
+   * @return map of frequent prefixes to their counts
    */
-  private def getFreqItemAndCounts(
+  private def getFreqPrefixAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
-    // TODO: use PrimitiveKeyOpenHashMap
-    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
-    database.foreach { sequence =>
-      sequence.distinct.foreach { item =>
-        counts(item) += 1L
+      prefix: List[Int],
+      suffixes: Array[(Array[Int], Int)]): mutable.Map[List[Int], Long] = {
+    val counts = mutable.Map[List[Int], Long]().withDefaultValue(0L)
+    val singleItemSet = suffixes.map { case (suffix, flag) =>
+      if (flag == 0) suffix else suffix.drop(suffix.indexOf(-1) + 1)
+    }.flatMap(_.filter(_ != -1).distinct)
+      .groupBy(identity).map(x => (x._1, x._2.length.toLong))
+    singleItemSet.filter(_._2 >= minCount).foreach { case (item, count) =>
+      if (prefix.nonEmpty) counts(prefix :+ -1 :+ item) = count else counts(List(item)) = count
+    }
+    if (prefix.nonEmpty) {
+      val lastElement = prefix.drop(prefix.lastIndexOf(-1) + 1).toArray
+      val multiItemSet = mutable.Map[Int, Long]().withDefaultValue(0L)
+      suffixes.map { case (suffix, flag) =>
+        if (flag == 0) suffix else lastElement ++ suffix
+      }.foreach { suffix =>
+        singleItemSet.keys.foreach { item =>
+          if (!lastElement.contains(item)) {
+            val element = lastElement :+ item
+            if (isSubElement(suffix, element)) {
+              multiItemSet(item) += 1L
+            }
+          }
+        }
       }
+      multiItemSet.filter(_._2 >= minCount).foreach { case (item, count) =>
+        if (prefix.nonEmpty) {
+          counts(prefix :+ item) = count
+        } else {
+          counts(List(item)) = count
+        }
+      }
+    }
+    counts
+  }
+
+  private def isSubElement(sequence: Array[Int], element: Array[Int]): Boolean = {
+    var found = false
+    var currentIndex = -1
+    var nextIndex = 0
+    while (nextIndex != -1 && !found) {
+      nextIndex = sequence.indexOf(-1, currentIndex + 1)
+      found = element.toSet.subsetOf(
+        sequence.slice(currentIndex + 1, nextIndex).toSet)
+      if (!found) currentIndex = nextIndex
     }
-    counts.filter(_._2 >= minCount)
+    found
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 6f52db7b073ae..a4a0530add210 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -81,11 +83,12 @@ class PrefixSpan private (
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
+    val sortedSequences = sortSequences(sequences)
+    val minCount = getMinCount(sortedSequences)
     val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
+      getFreqItemAndCounts(minCount, sortedSequences).collect()
     val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
+      lengthOnePatternsAndCounts.map(_._1), sortedSequences)
     val groupedProjectedDatabase = prefixAndProjectedDatabase
       .map(x => (x._1.toSeq, x._2))
       .groupByKey()
@@ -98,6 +101,24 @@ class PrefixSpan private (
     allPatterns
   }
 
+  private def sortSequences(sequences: RDD[Array[Int]]): RDD[Array[Int]] = {
+    sequences.map { sequence =>
+      val sortedArray: ArrayBuffer[Int] = ArrayBuffer()
+      var currentIndex = -1
+      var nextIndex = 0
+      while (nextIndex != -1) {
+        nextIndex = sequence.indexOf(-1, currentIndex + 1)
+        if (nextIndex != -1) {
+          sortedArray ++= sequence.slice(currentIndex, nextIndex).sorted
+        } else {
+          sortedArray ++= sequence.drop(currentIndex).sorted
+        }
+        currentIndex = nextIndex
+      }
+      sortedArray.toArray
+    }
+  }
+
   /**
    * Get the minimum count (sequences count * minSupport).
    * @param sequences input data set, contains a set of sequences,
@@ -116,7 +137,7 @@ class PrefixSpan private (
   private def getFreqItemAndCounts(
       minCount: Long,
       sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
+    sequences.flatMap(_.filter(_ != -1).distinct.map((_, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
   }
@@ -129,15 +150,12 @@ class PrefixSpan private (
    */
   private def getPrefixAndProjectedDatabase(
       frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter(frequentPrefixes.contains(_))
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
+      sequences: RDD[Array[Int]]): RDD[(Array[Int], (Array[Int], Int))] = {
+    sequences.flatMap { sequence =>
+      frequentPrefixes.map { item =>
+        val sub = LocalPrefixSpan.getSuffix(Array(item), (sequence, 0))
+        (Array(item), sub)
+      }.filter(_._2._1.nonEmpty)
     }
   }
 
@@ -149,10 +167,10 @@ class PrefixSpan private (
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
+      data: RDD[(Array[Int], Array[(Array[Int], Int)])]): RDD[(Array[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+        .map { case (pattern: List[Int], count: Long) => (pattern.toArray, count) }
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6d80..3c1ae79ebe8b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -35,20 +35,20 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     */
 
     val sequences = Array(
-      Array(1, 3, 4, 5),
-      Array(2, 3, 1),
-      Array(2, 4, 1),
-      Array(3, 1, 3, 4, 5),
-      Array(3, 4, 4, 3),
-      Array(6, 5, 3))
+      Array(1, -1, 3, -1, 4, -1, 5),
+      Array(2, -1, 3, -1, 1),
+      Array(2, -1, 4, -1, 1),
+      Array(3, -1, 1, -1, 3, -1, 4, -1, 5),
+      Array(3, -1, 4, -1, 4, -1, 3),
+      Array(6, -1, 5, -1, 3))
     val rdd = sc.parallelize(sequences, 2).cache()
 
     def compareResult(
-      expectedValue: Array[(Array[Int], Long)],
-      actualValue: Array[(Array[Int], Long)]): Boolean = {
-    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-    actualValue.map(x => (x._1.toSeq, x._2)).toSet
+        expectedValue: Array[(Array[Int], Long)],
+        actualValue: Array[(Array[Int], Long)]): Boolean = {
+      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+        actualValue.map(x => (x._1.toSeq, x._2)).toSet
     }
 
     val prefixspan = new PrefixSpan()
@@ -57,23 +57,23 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val result1 = prefixspan.run(rdd)
     val expectedValue1 = Array(
       (Array(1), 4L),
-      (Array(1, 3), 2L),
-      (Array(1, 3, 4), 2L),
-      (Array(1, 3, 4, 5), 2L),
-      (Array(1, 3, 5), 2L),
-      (Array(1, 4), 2L),
-      (Array(1, 4, 5), 2L),
-      (Array(1, 5), 2L),
+      (Array(1, -1, 3), 2L),
+      (Array(1, -1, 3, -1, 4), 2L),
+      (Array(1, -1, 3, -1, 4, -1, 5), 2L),
+      (Array(1, -1, 3, -1, 5), 2L),
+      (Array(1, -1, 4), 2L),
+      (Array(1, -1, 4, -1, 5), 2L),
+      (Array(1, -1, 5), 2L),
       (Array(2), 2L),
-      (Array(2, 1), 2L),
+      (Array(2, -1, 1), 2L),
       (Array(3), 5L),
-      (Array(3, 1), 2L),
-      (Array(3, 3), 2L),
-      (Array(3, 4), 3L),
-      (Array(3, 4, 5), 2L),
-      (Array(3, 5), 2L),
+      (Array(3, -1, 1), 2L),
+      (Array(3, -1, 3), 2L),
+      (Array(3, -1, 4), 3L),
+      (Array(3, -1, 4, -1, 5), 2L),
+      (Array(3, -1, 5), 2L),
       (Array(4), 4L),
-      (Array(4, 5), 2L),
+      (Array(4, -1, 5), 2L),
       (Array(5), 3L)
     )
     assert(compareResult(expectedValue1, result1.collect()))
@@ -83,7 +83,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expectedValue2 = Array(
       (Array(1), 4L),
       (Array(3), 5L),
-      (Array(3, 4), 3L),
+      (Array(3, -1, 4), 3L),
       (Array(4), 4L),
       (Array(5), 3L)
     )
@@ -93,20 +93,89 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val result3 = prefixspan.run(rdd)
     val expectedValue3 = Array(
       (Array(1), 4L),
-      (Array(1, 3), 2L),
-      (Array(1, 4), 2L),
-      (Array(1, 5), 2L),
-      (Array(2, 1), 2L),
+      (Array(1, -1, 3), 2L),
+      (Array(1, -1, 4), 2L),
+      (Array(1, -1, 5), 2L),
+      (Array(2, -1, 1), 2L),
       (Array(2), 2L),
       (Array(3), 5L),
-      (Array(3, 1), 2L),
-      (Array(3, 3), 2L),
-      (Array(3, 4), 3L),
-      (Array(3, 5), 2L),
+      (Array(3, -1, 1), 2L),
+      (Array(3, -1, 3), 2L),
+      (Array(3, -1, 4), 3L),
+      (Array(3, -1, 5), 2L),
       (Array(4), 4L),
-      (Array(4, 5), 2L),
+      (Array(4, -1, 5), 2L),
       (Array(5), 3L)
     )
     assert(compareResult(expectedValue3, result3.collect()))
+
+    val sequences4 = Array(
+      "a,abc,ac,d,cf",
+      "ad,c,bc,ae",
+      "ef,ab,df,c,b",
+      "e,g,af,c,b,c")
+    val coder = Array('a', 'b', 'c', 'd', 'e', 'f', 'g').zip(Array(1, 2, 3, 4, 5, 6, 7)).toMap
+    val intSequences = sequences4.map(_.split(",").flatMap(-1 +: _.toArray.map(coder)).drop(1))
+    val rdd4 = sc.parallelize(intSequences, 2).cache()
+    prefixspan.setMinSupport(0.5).setMaxPatternLength(5)
+    val result4 = prefixspan.run(rdd4)
+    val expectedValue4 = Array(
+      "a:4",
+      "b:4",
+      "c:4",
+      "d:3",
+      "e:3",
+      "f:3",
+      "a,a:2",
+      "a,b:4",
+      "a,bc:2",
+      "a,bc,a:2",
+      "a,b,a:2",
+      "a,b,c:2",
+      "ab:2",
+      "ab,c:2",
+      "ab,d:2",
+      "ab,d,c:2",
+      "ab,f:2",
+      "a,c:4",
+      "a,c,a:2",
+      "a,c,b:3",
+      "a,c,c:3",
+      "a,d:2",
+      "a,d,c:2",
+      "a,f:2",
+      "b,a:2",
+      "b,c:3",
+      "bc:2",
+      "bc,a:2",
+      "b,d:2",
+      "b,d,c:2",
+      "b,f:2",
+      "c,a:2",
+      "c,b:3",
+      "c,c:3",
+      "d,b:2",
+      "d,c:3",
+      "d,c,b:2",
+      "e,a:2",
+      "e,a,b:2",
+      "e,a,c:2",
+      "e,a,c,b:2",
+      "e,b:2",
+      "e,b,c:2",
+      "e,c:2",
+      "e,c,b:2",
+      "e,f:2",
+      "e,f,b:2",
+      "e,f,c:2",
+      "e,f,c,b:2",
+      "f,b:2",
+      "f,b,c:2",
+      "f,c:2",
+      "f,c,b:2")
+    val intExpectedValue = expectedValue4
+      .map(_.split(":"))
+      .map(x => (x(0).split(",").flatMap(-1 +: _.toArray.map(coder)).drop(1), x(1).toLong))
+    assert(compareResult(intExpectedValue, result4.collect()))
   }
 }
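
A note on the new input encoding, mainly for reviewers: sequences are now flat Array[Int]s in which the sentinel -1 separates itemsets, so <a(bc)c> under the coding a=1, b=2, c=3 becomes Array(1, -1, 2, 3, -1, 3), and mined patterns come back in the same encoding. The sketch below is not part of the patch and assumes an existing SparkContext `sc`; it only illustrates how the patched PrefixSpan would be driven end to end.

    import org.apache.spark.mllib.fpm.PrefixSpan

    // Encode <a(abc)(ac)d(cf)> and <(ad)c(bc)(ae)> with a=1, ..., g=7,
    // using -1 as the itemset delimiter (same scheme as the new test case).
    val sequences = sc.parallelize(Seq(
      Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6),
      Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5)), 2).cache()

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)     // a pattern must occur in at least half of the sequences
      .setMaxPatternLength(5) // bound the number of items per pattern

    // A result like (Array(1, -1, 3), 2L) means <a c> occurred twice;
    // Array(1, 2) with no -1 would be the single itemset <(ab)>.
    prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
      println(pattern.mkString("[", ", ", s"] : $count"))
    }

Keeping sequences as primitive Int arrays with a -1 sentinel, rather than nested collections, is what lets sortSequences canonicalize a sequence by sorting only the slices between delimiters, and lets getSuffix and isSubElement scan itemsets with indexOf/slice instead of allocating per-itemset structures.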