Skip to content

Commit

Permalink
Added concrete wrappers for HashingTF, NGram and StopWordsRemover (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm committed May 13, 2019
1 parent 976cd25 commit 3992cfe
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 27 deletions.
11 changes: 4 additions & 7 deletions core/src/main/scala/com/salesforce/op/dsl/RichListFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ trait RichListFeature {
numTerms: Int = TransmogrifierDefaults.DefaultNumOfFeatures,
binary: Boolean = TransmogrifierDefaults.BinaryFreq
): FeatureLike[OPVector] = {
val htf = new HashingTF().setNumFeatures(numTerms).setBinary(binary)
val tr = new OpTransformerWrapper[TextList, OPVector, HashingTF](htf, UID[HashingTF])
val tr = new OpHashingTF().setNumFeatures(numTerms).setBinary(binary)
f.transformWith(tr)
}

Expand Down Expand Up @@ -151,8 +150,7 @@ trait RichListFeature {
* @return
*/
def ngram(n: Int = 2): FeatureLike[TextList] = {
val ngrm = new NGram().setN(n)
val tr = new OpTransformerWrapper[TextList, TextList, NGram](ngrm, UID[NGram])
val tr = new OpNGram().setN(n)
f.transformWith(tr)
}

Expand All @@ -169,9 +167,8 @@ trait RichListFeature {
stopWords: Array[String] = StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive: Boolean = false
): FeatureLike[TextList] = {
val remover = new StopWordsRemover().setStopWords(stopWords).setCaseSensitive(caseSensitive)
val tr = new OpTransformerWrapper[TextList, TextList, StopWordsRemover](remover, UID[StopWordsRemover])
f.transformWith(tr)
val remover = new OpStopWordsRemover().setStopWords(stopWords).setCaseSensitive(caseSensitive)
f.transformWith(remover)
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper
import org.apache.spark.ml.feature.HashingTF

/**
* Wrapper for [[org.apache.spark.ml.feature.HashingTF]]
*
* Maps a sequence of terms to their term frequencies using the hashing trick.
* Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32)
* to calculate the hash code value for the term object.
* Since a simple modulo is used to transform the hash function to a column index,
* it is advisable to use a power of two as the numFeatures parameter;
* otherwise the features will not be mapped evenly to the columns.
*
* @see [[HashingTF]] for more info
*/
class OpHashingTF(uid: String = UID[HashingTF])
  extends OpTransformerWrapper[TextList, OPVector, HashingTF](transformer = new HashingTF(), uid = uid) {

  /**
   * Sets the number of features (hash buckets). Should be greater than 0.
   * A power of two is advisable so that terms are mapped evenly onto columns,
   * since a simple modulo converts the hash code to a column index.
   * (default = 2^18^)
   */
  def setNumFeatures(value: Int): this.type = {
    val stage = getSparkMlStage().get
    stage.setNumFeatures(value)
    this
  }

  /**
   * Toggles binary term frequency counts.
   * When true, every non-zero count is clamped to 1 — useful for discrete
   * probabilistic models of binary events rather than integer counts.
   * (default = false)
   */
  def setBinary(value: Boolean): this.type = {
    val stage = getSparkMlStage().get
    stage.setBinary(value)
    this
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import enumeratum._
import org.apache.spark.ml.feature.IndexToString

/**
* OP wrapper for [[org.apache.spark.ml.feature.IndexToString]]
* Wrapper for [[org.apache.spark.ml.feature.IndexToString]]
*
* NOTE THAT THIS CLASS EITHER FILTERS OUT OR THROWS AN ERROR IF PREVIOUSLY UNSEEN VALUES APPEAR
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper
import org.apache.spark.ml.feature.NGram

/**
* Wrapper for [[org.apache.spark.ml.feature.NGram]]
*
* A feature transformer that converts the input array of strings into an array of n-grams. Null
* values in the input array are ignored.
* It returns an array of n-grams where each n-gram is represented by a space-separated string of
* words.
*
* When the input is empty, an empty array is returned.
* When the input array length is less than n (number of elements per n-gram), no n-grams are
* returned.
*
* @see [[NGram]] for more info
*/
class OpNGram(uid: String = UID[NGram])
  extends OpTransformerWrapper[TextList, TextList, NGram](transformer = new NGram(), uid = uid) {

  /**
   * Sets the n-gram length; must be greater than or equal to 1.
   * Default: 2 (bigram features)
   */
  def setN(value: Int): this.type = {
    val stage = getSparkMlStage().get
    stage.setN(value)
    this
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper
import org.apache.spark.ml.feature.StopWordsRemover

/**
* Wrapper for [[org.apache.spark.ml.feature.StopWordsRemover]]
*
* A feature transformer that filters out stop words from input.
*
* @note null values from input array are preserved unless adding null to stopWords explicitly.
*
* @see <a href="http://en.wikipedia.org/wiki/Stop_words">Stop words (Wikipedia)</a>
* @see [[StopWordsRemover]] for more info
*/
class OpStopWordsRemover(uid: String = UID[StopWordsRemover])
  extends OpTransformerWrapper[TextList, TextList, StopWordsRemover](transformer = new StopWordsRemover(), uid = uid) {

  /**
   * Sets the words to be filtered out.
   * Default: English stop words
   *
   * @see `StopWordsRemover.loadDefaultStopWords()`
   */
  def setStopWords(value: Array[String]): this.type = {
    val stage = getSparkMlStage().get
    stage.setStopWords(value)
    this
  }

  /**
   * Sets whether the stop-word comparison is case sensitive.
   * Default: false
   */
  def setCaseSensitive(value: Boolean): this.type = {
    val stage = getSparkMlStage().get
    stage.setCaseSensitive(value)
    this
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import scala.reflect.runtime.universe.TypeTag

/**
* OP wrapper for [[org.apache.spark.ml.feature.StringIndexer]]
* Wrapper for [[org.apache.spark.ml.feature.StringIndexer]]
*
* NOTE THAT THIS CLASS EITHER FILTERS OUT OR THROWS AN ERROR IF PREVIOUSLY UNSEEN VALUES APPEAR
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ import org.apache.spark.ml.param.ParamMap

import scala.reflect.runtime.universe.TypeTag

// TODO: all the transformers that inherit traits HasInputCol and HasOutputCol should really extend
// org.apache.spark.ml.UnaryTransformer, so can add a PR to spark so we can then move this class to our namespace.

/**
* Wraps a spark ML transformer with settable input and output columns. Those transformers that fall in this case,
* include those that inherit from org.apache.spark.ml.UnaryEstimator, as well as others such as OneHotEncoder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.features.Feature
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.features.types._
import com.salesforce.op.test.{SwTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class HashingTFTest extends FlatSpec with TestSparkContext {
class OpHashingTFTest extends SwTransformerSpec[OPVector, HashingTF, OpHashingTF] {

// scalastyle:off
val testData = Seq(
Expand All @@ -55,19 +55,27 @@ class HashingTFTest extends FlatSpec with TestSparkContext {
).map(_.toLowerCase.split(" ").toSeq.toTextList)
// scalastyle:on

lazy val (ds, f1): (DataFrame, Feature[TextList]) = TestFeatureBuilder(testData)
val (inputData, f1): (DataFrame, Feature[TextList]) = TestFeatureBuilder(testData)

val hashed = f1.tf(numTerms = 5)
val transformer = hashed.originStage.asInstanceOf[OpHashingTF]

val expectedResult: Seq[OPVector] = Seq(
Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(2.0, 4.0, 2.0, 3.0, 1.0)),
Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(4.0, 1.0, 3.0, 1.0, 1.0)),
Vectors.sparse(5, Array(0, 2, 3, 4), Array(2.0, 2.0, 2.0, 2.0)),
Vectors.sparse(5, Array(0, 1, 2, 4), Array(3.0, 5.0, 1.0, 2.0))
).map(_.toOPVector)

def hash(
s: String,
numOfFeatures: Int = TransmogrifierDefaults.DefaultNumOfFeatures,
binary: Boolean = false
): Int = {
new HashingTF(numOfFeatures).setBinary(binary).indexOf(s)
}
): Int = new org.apache.spark.mllib.feature.HashingTF(numOfFeatures).setBinary(binary).indexOf(s)

Spec[HashingTF] should "hash categorical data" in {
it should "hash categorical data" in {
val hashed = f1.tf()
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds)
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData)
val results = transformedData.select(hashed.name).collect(hashed)

hashed.name shouldBe hashed.originStage.getOutputFeatureName
Expand All @@ -86,7 +94,7 @@ class HashingTFTest extends FlatSpec with TestSparkContext {
val numFeatures = 100

val hashed = f1.tf(numTerms = numFeatures)
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds)
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData)
val results = transformedData.select(hashed.name).collect(hashed)

// scalastyle:off
Expand All @@ -101,7 +109,7 @@ class HashingTFTest extends FlatSpec with TestSparkContext {
val binary = true

val hashed = f1.tf(binary = binary)
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds)
val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData)
val results = transformedData.select(hashed.name).collect(hashed)

// scalastyle:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class NGramTest extends SwTransformerSpec[TextList, NGram, OpTransformerWrapper[TextList, TextList, NGram]] {
class OpNGramTest extends SwTransformerSpec[TextList, NGram, OpNGram] {
val data = Seq("a b c d e f g").map(_.split(" ").toSeq.toTextList)
val (inputData, textListFeature) = TestFeatureBuilder(data)

val expectedResult = Seq(Seq("a b", "b c", "c d", "d e", "e f", "f g").toTextList)

val bigrams = textListFeature.ngram()
val transformer = bigrams.originStage.asInstanceOf[OpTransformerWrapper[TextList, TextList, NGram]]
val transformer = bigrams.originStage.asInstanceOf[OpNGram]

it should "generate unigrams" in {
val unigrams = textListFeature.ngram(n = 1)
Expand Down
Loading

0 comments on commit 3992cfe

Please sign in to comment.