diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index bd9b92ce8bc66..8bbfe43a00a54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -20,26 +20,68 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{BooleanParam, Param}
+import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.{StructField, ArrayType, StructType}
+import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType}
 import org.apache.spark.sql.functions.{col, udf}
 
 /**
- * :: Experimental ::
  * stop words list
  */
-@Experimental
-object StopWords{
-  val EnglishSet = ("a an and are as at be by for from has he in is it its of on that the to " +
-    "was were will with").split("\\s").toSet
+private object StopWords {
+
+  /**
+   * Use the same default stop words list as scikit-learn.
+   * The original list can be found at the "Glasgow Information Retrieval Group" resource page:
+   * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
+   */
+  val EnglishStopWords = Array("a", "about", "above", "across", "after", "afterwards", "again",
+    "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
+    "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
+    "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
+    "around", "as", "at", "back", "be", "became", "because", "become",
+    "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
+    "below", "beside", "besides", "between", "beyond", "bill", "both",
+    "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
+    "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
+    "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
+    "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
+    "everything", "everywhere", "except", "few", "fifteen", "fify", "fill",
+    "find", "fire", "first", "five", "for", "former", "formerly", "forty",
+    "found", "four", "from", "front", "full", "further", "get", "give", "go",
+    "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
+    "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
+    "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed",
+    "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
+    "latterly", "least", "less", "ltd", "made", "many", "may", "me",
+    "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
+    "move", "much", "must", "my", "myself", "name", "namely", "neither",
+    "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
+    "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
+    "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
+    "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
+    "please", "put", "rather", "re", "same", "see", "seem", "seemed",
+    "seeming", "seems", "serious", "several", "she", "should", "show", "side",
+    "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
"something", "sometime", "sometimes", "somewhere", "still", "such", + "system", "take", "ten", "than", "that", "the", "their", "them", + "themselves", "then", "thence", "there", "thereafter", "thereby", + "therefore", "therein", "thereupon", "these", "they", "thick", "thin", + "third", "this", "those", "though", "three", "through", "throughout", + "thru", "thus", "to", "together", "too", "top", "toward", "towards", + "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us", + "very", "via", "was", "we", "well", "were", "what", "whatever", "when", + "whence", "whenever", "where", "whereafter", "whereas", "whereby", + "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", + "who", "whoever", "whole", "whom", "whose", "why", "will", "with", + "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves") } /** * :: Experimental :: * A feature transformer that filters out stop words from input - * [[http://en.wikipedia.org/wiki/Stop_words]] + * @see [[http://en.wikipedia.org/wiki/Stop_words]] */ @Experimental class StopWordsRemover(override val uid: String) @@ -57,38 +99,38 @@ class StopWordsRemover(override val uid: String) * the stop words set to be filtered out * @group param */ - val stopWords: Param[Set[String]] = new Param(this, "stopWords", "stop words") + val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") /** @group setParam */ - def setStopWords(value: Set[String]): this.type = set(stopWords, value) + def setStopWords(value: Array[String]): this.type = set(stopWords, value) /** @group getParam */ - def getStopWords: Set[String] = getOrDefault(stopWords) + def getStopWords: Array[String] = $(stopWords) /** * whether to do a case sensitive comparison over the stop words * @group param */ val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", - "whether to do case-sensitive filter") + "whether to do case-sensitive comparison during filtering") /** @group setParam */ def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value) /** @group getParam */ - def getCaseSensitive: Boolean = getOrDefault(caseSensitive) + def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.EnglishSet, caseSensitive -> false) + setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) + val stopwordsSet = $(stopWords).toSet + val lowerStopWords = stopwordsSet.map(_.toLowerCase) val t = udf { terms: Seq[String] => if ($(caseSensitive)) { - terms.filterNot(s => s != null && $(stopWords).contains(s)) - } - else { - val lowerStopWords = $(stopWords).map(_.toLowerCase) - terms.filterNot(s => s != null && lowerStopWords.contains(s.toLowerCase)) + terms.filter(s => s == null || !stopwordsSet.contains(s)) + } else { + terms.filter(s => s == null || !lowerStopWords.contains(s.toLowerCase)) } } val metadata = outputSchema($(outputCol)).metadata @@ -97,10 +139,12 @@ class StopWordsRemover(override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") val outputFields = schema.fields :+ StructField($(outputCol), inputType, schema($(inputCol)).nullable) 
     StructType(outputFields)
   }
+
+  override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 33f036f3d59d3..b8e666cb916e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -21,18 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
-import scala.beans.BeanInfo
-
-@BeanInfo
-case class StopWordsTestData(raw: Array[String], wanted: Array[String])
-
 object StopWordsRemoverSuite extends SparkFunSuite {
   def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
     t.transform(dataset)
-      .select("filtered", "wanted")
+      .select("filtered", "expected")
       .collect()
       .foreach { case Row(tokens, wantedTokens) =>
-      assert(tokens === wantedTokens)
+        assert(tokens === wantedTokens)
       }
   }
 }
@@ -44,15 +39,16 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
-    val dataset = sqlContext.createDataFrame(Seq(
-      StopWordsTestData(Array("test", "test"), Array("test", "test")),
-      StopWordsTestData(Array("a", "b", "c", "d"), Array("b", "c", "d")),
-      StopWordsTestData(Array("a", "the", "an"), Array()),
-      StopWordsTestData(Array("A", "The", "AN"), Array()),
-      StopWordsTestData(Array(null), Array(null)),
-      StopWordsTestData(Array(), Array())
-    ))
-    testStopWordsRemover(remover, dataset)
+    val dataSet = sqlContext.createDataFrame(Seq(
+      (Seq("test", "test"), Seq("test", "test")),
+      (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
+      (Seq("a", "the", "an"), Seq()),
+      (Seq("A", "The", "AN"), Seq()),
+      (Seq(null), Seq(null)),
+      (Seq(), Seq())
+    )).toDF("raw", "expected")
+
+    testStopWordsRemover(remover, dataSet)
   }
 
   test("StopWordsRemover case sensitive") {
@@ -60,25 +56,25 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setCaseSensitive(true)
+    val dataSet = sqlContext.createDataFrame(Seq(
+      (Seq("A"), Seq("A")),
+      (Seq("The", "the"), Seq("The"))
+    )).toDF("raw", "expected")
 
-    val dataset = sqlContext.createDataFrame(Seq(
-      StopWordsTestData(Array("A"), Array("A")),
-      StopWordsTestData(Array("The", "the"), Array("The"))
-    ))
-    testStopWordsRemover(remover, dataset)
+    testStopWordsRemover(remover, dataSet)
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.EnglishSet + "python" + "scala"
+    val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
     val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setStopWords(stopWords)
+    val dataSet = sqlContext.createDataFrame(Seq(
+      (Seq("python", "scala", "a"), Seq()),
+      (Seq("Python", "Scala", "swift"), Seq("swift"))
+    )).toDF("raw", "expected")
 
-    val dataset = sqlContext.createDataFrame(Seq(
-      StopWordsTestData(Array("python", "scala", "a"), Array()),
-      StopWordsTestData(Array("Python", "Scala", "swift"), Array("swift"))
-    ))
-    testStopWordsRemover(remover, dataset)
+    testStopWordsRemover(remover, dataSet)
   }
 }
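
For reviewers, a minimal usage sketch of the transformer after this change. It assumes an existing SQLContext named sqlContext (as in the test suite); the custom stop word list, column names, and example sentences are illustrative only, not part of the patch:

    import org.apache.spark.ml.feature.StopWordsRemover

    // Configure the transformer; setStopWords replaces the default English list.
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setCaseSensitive(false)                // default: both sides are lowercased before matching
      .setStopWords(Array("a", "an", "the"))  // hypothetical custom list

    // Example input; the "raw" column must be ArrayType(StringType).
    val df = sqlContext.createDataFrame(Seq(
      (Seq("The", "quick", "brown", "fox"), 1),
      (Seq("jumps", "over", "a", "lazy", "dog"), 2)
    )).toDF("raw", "id")

    remover.transform(df).select("filtered").show()
    // With the custom list above, expected rows are roughly:
    //   [quick, brown, fox]
    //   [jumps, over, lazy, dog]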