From b3aa957a2abf92fdc5b0389d79bfec9389dcbaf8 Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 10 Jun 2015 18:50:05 +0800
Subject: [PATCH] Add StopWordsRemover

---
 .../spark/ml/feature/StopWordsRemover.scala        | 110 ++++++++++++++++++
 .../ml/feature/StopWordsRemoverSuite.scala         |  84 ++++++++++++++
 2 files changed, 194 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
new file mode 100644
index 0000000000000..bd9b92ce8bc66
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{BooleanParam, Param}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
+
+/**
+ * :: Experimental ::
+ * A default list of English stop words.
+ */
+@Experimental
+object StopWords {
+  val EnglishSet = ("a an and are as at be by for from has he in is it its of on that the to " +
+    "was were will with").split("\\s").toSet
+}
+
+/**
+ * :: Experimental ::
+ * A feature transformer that filters out stop words from the input.
+ * [[http://en.wikipedia.org/wiki/Stop_words]]
+ */
+@Experimental
+class StopWordsRemover(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol {
+
+  def this() = this(Identifiable.randomUID("stopWords"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * The set of stop words to be filtered out.
+   * @group param
+   */
+  val stopWords: Param[Set[String]] = new Param(this, "stopWords", "stop words to filter out")
+
+  /** @group setParam */
+  def setStopWords(value: Set[String]): this.type = set(stopWords, value)
+
+  /** @group getParam */
+  def getStopWords: Set[String] = getOrDefault(stopWords)
+
+  /**
+   * Whether to do a case-sensitive comparison over the stop words.
+   * @group param
+   */
+  val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
+    "whether to do a case-sensitive comparison over the stop words")
+
+  /** @group setParam */
+  def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)
+
+  /** @group getParam */
+  def getCaseSensitive: Boolean = getOrDefault(caseSensitive)
+
+  setDefault(stopWords -> StopWords.EnglishSet, caseSensitive -> false)
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema)
+    val t = if ($(caseSensitive)) {
+      val stopWordsSet = $(stopWords)
+      udf { terms: Seq[String] =>
+        // Keep null tokens; drop only non-null tokens found in the stop-word set.
+        terms.filterNot(s => s != null && stopWordsSet.contains(s))
+      }
+    } else {
+      // Lowercase the stop words once here rather than once per row inside the udf.
+      val lowerStopWords = $(stopWords).map(_.toLowerCase)
+      udf { terms: Seq[String] =>
+        terms.filterNot(s => s != null && lowerStopWords.contains(s.toLowerCase))
+      }
+    }
+    val metadata = outputSchema($(outputCol)).metadata
+    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.sameType(ArrayType(StringType)),
+      s"Input type must be ArrayType(StringType) but got $inputType.")
+    val outputFields = schema.fields :+
+      StructField($(outputCol), inputType, schema($(inputCol)).nullable)
+    StructType(outputFields)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
new file mode 100644
index 0000000000000..33f036f3d59d3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+@BeanInfo
+case class StopWordsTestData(raw: Array[String], wanted: Array[String])
+
+object StopWordsRemoverSuite extends SparkFunSuite {
+  def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+    t.transform(dataset)
+      .select("filtered", "wanted")
+      .collect()
+      .foreach { case Row(tokens, wantedTokens) =>
+        assert(tokens === wantedTokens)
+      }
+  }
+}
+
+class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
+  import org.apache.spark.ml.feature.StopWordsRemoverSuite._
+
+  test("StopWordsRemover default") {
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol("filtered")
+    val dataset = sqlContext.createDataFrame(Seq(
+      StopWordsTestData(Array("test", "test"), Array("test", "test")),
+      StopWordsTestData(Array("a", "b", "c", "d"), Array("b", "c", "d")),
+      StopWordsTestData(Array("a", "the", "an"), Array()),
+      StopWordsTestData(Array("A", "The", "AN"), Array()),
+      StopWordsTestData(Array(null), Array(null)),
+      StopWordsTestData(Array(), Array())
+    ))
+    testStopWordsRemover(remover, dataset)
+  }
+
+  test("StopWordsRemover case sensitive") {
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol("filtered")
+      .setCaseSensitive(true)
+
+    val dataset = sqlContext.createDataFrame(Seq(
+      StopWordsTestData(Array("A"), Array("A")),
+      StopWordsTestData(Array("The", "the"), Array("The"))
+    ))
+    testStopWordsRemover(remover, dataset)
+  }
+
+  test("StopWordsRemover with additional words") {
+    val stopWords = StopWords.EnglishSet + "python" + "scala"
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol("filtered")
+      .setStopWords(stopWords)
+
+    val dataset = sqlContext.createDataFrame(Seq(
+      StopWordsTestData(Array("python", "scala", "a"), Array()),
+      StopWordsTestData(Array("Python", "Scala", "swift"), Array("swift"))
+    ))
+    testStopWordsRemover(remover, dataset)
+  }
+}
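
For reviewers, a minimal usage sketch of the new transformer. This is not part of the
patch: it assumes an existing SQLContext (as MLlibTestSparkContext provides in the
tests above), and the column names and sample data are illustrative.

  import org.apache.spark.ml.feature.StopWordsRemover

  // Input column must be ArrayType(StringType), e.g. the output of a Tokenizer.
  val remover = new StopWordsRemover()
    .setInputCol("raw")
    .setOutputCol("filtered")

  val df = sqlContext.createDataFrame(Seq(
    (0, Seq("saw", "the", "red", "balloon")),
    (1, Seq("mary", "had", "a", "little", "lamb"))
  )).toDF("id", "raw")

  // Matching is case-insensitive by default, so "the" and "a" are removed,
  // yielding [saw, red, balloon] and [mary, had, little, lamb].
  remover.transform(df).select("filtered").show()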