Skip to content

Commit

Permalink
add stopWordsRemover
Browse files Browse the repository at this point in the history
  • Loading branch information
hhbyyh committed Jun 10, 2015
1 parent 6e4fb0c commit b3aa957
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{StructField, ArrayType, StructType}
import org.apache.spark.sql.functions.{col, udf}

/**
* :: Experimental ::
* stop words list
*/
@Experimental
object StopWords{
val EnglishSet = ("a an and are as at be by for from has he in is it its of on that the to " +
"was were will with").split("\\s").toSet
}

/**
* :: Experimental ::
* A feature transformer that filters out stop words from input
* [[http://en.wikipedia.org/wiki/Stop_words]]
*/
@Experimental
class StopWordsRemover(override val uid: String)
extends Transformer with HasInputCol with HasOutputCol {

def this() = this(Identifiable.randomUID("stopWords"))

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* the stop words set to be filtered out
* @group param
*/
val stopWords: Param[Set[String]] = new Param(this, "stopWords", "stop words")

/** @group setParam */
def setStopWords(value: Set[String]): this.type = set(stopWords, value)

/** @group getParam */
def getStopWords: Set[String] = getOrDefault(stopWords)

/**
* whether to do a case sensitive comparison over the stop words
* @group param
*/
val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
"whether to do case-sensitive filter")

/** @group setParam */
def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)

/** @group getParam */
def getCaseSensitive: Boolean = getOrDefault(caseSensitive)

setDefault(stopWords -> StopWords.EnglishSet, caseSensitive -> false)

override def transform(dataset: DataFrame): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
val t = udf { terms: Seq[String] =>
if ($(caseSensitive)) {
terms.filterNot(s => s != null && $(stopWords).contains(s))
}
else {
val lowerStopWords = $(stopWords).map(_.toLowerCase)
terms.filterNot(s => s != null && lowerStopWords.contains(s.toLowerCase))
}
}
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
}

override def transformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
val outputFields = schema.fields :+
StructField($(outputCol), inputType, schema($(inputCol)).nullable)
StructType(outputFields)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

import scala.beans.BeanInfo

@BeanInfo
case class StopWordsTestData(raw: Array[String], wanted: Array[String])

object StopWordsRemoverSuite extends SparkFunSuite {
def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
t.transform(dataset)
.select("filtered", "wanted")
.collect()
.foreach { case Row(tokens, wantedTokens) =>
assert(tokens === wantedTokens)
}
}
}

class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.StopWordsRemoverSuite._

test("StopWordsRemover default") {
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
val dataset = sqlContext.createDataFrame(Seq(
StopWordsTestData(Array("test", "test"), Array("test", "test")),
StopWordsTestData(Array("a", "b", "c", "d"), Array("b", "c", "d")),
StopWordsTestData(Array("a", "the", "an"), Array()),
StopWordsTestData(Array("A", "The", "AN"), Array()),
StopWordsTestData(Array(null), Array(null)),
StopWordsTestData(Array(), Array())
))
testStopWordsRemover(remover, dataset)
}

test("StopWordsRemover case sensitive") {
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
.setCaseSensitive(true)

val dataset = sqlContext.createDataFrame(Seq(
StopWordsTestData(Array("A"), Array("A")),
StopWordsTestData(Array("The", "the"), Array("The"))
))
testStopWordsRemover(remover, dataset)
}

test("StopWordsRemover with additional words") {
val stopWords = StopWords.EnglishSet + "python" + "scala"
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)

val dataset = sqlContext.createDataFrame(Seq(
StopWordsTestData(Array("python", "scala", "a"), Array()),
StopWordsTestData(Array("Python", "Scala", "swift"), Array("swift"))
))
testStopWordsRemover(remover, dataset)
}
}

0 comments on commit b3aa957

Please sign in to comment.