diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichListFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichListFeature.scala index 02962865f6..3bcfa17853 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichListFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichListFeature.scala @@ -60,8 +60,7 @@ trait RichListFeature { numTerms: Int = TransmogrifierDefaults.DefaultNumOfFeatures, binary: Boolean = TransmogrifierDefaults.BinaryFreq ): FeatureLike[OPVector] = { - val htf = new HashingTF().setNumFeatures(numTerms).setBinary(binary) - val tr = new OpTransformerWrapper[TextList, OPVector, HashingTF](htf, UID[HashingTF]) + val tr = new OpHashingTF().setNumFeatures(numTerms).setBinary(binary) f.transformWith(tr) } @@ -151,8 +150,7 @@ trait RichListFeature { * @return */ def ngram(n: Int = 2): FeatureLike[TextList] = { - val ngrm = new NGram().setN(n) - val tr = new OpTransformerWrapper[TextList, TextList, NGram](ngrm, UID[NGram]) + val tr = new OpNGram().setN(n) f.transformWith(tr) } @@ -169,9 +167,8 @@ trait RichListFeature { stopWords: Array[String] = StopWordsRemover.loadDefaultStopWords("english"), caseSensitive: Boolean = false ): FeatureLike[TextList] = { - val remover = new StopWordsRemover().setStopWords(stopWords).setCaseSensitive(caseSensitive) - val tr = new OpTransformerWrapper[TextList, TextList, StopWordsRemover](remover, UID[StopWordsRemover]) - f.transformWith(tr) + val remover = new OpStopWordsRemover().setStopWords(stopWords).setCaseSensitive(caseSensitive) + f.transformWith(remover) } diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpHashingTF.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpHashingTF.scala new file mode 100644 index 0000000000..1647a475f9 --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpHashingTF.scala @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op.UID +import com.salesforce.op.features.types._ +import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper +import org.apache.spark.ml.feature.HashingTF + +/** + * Wrapper for [[org.apache.spark.ml.feature.HashingTF]] + * + * Maps a sequence of terms to their term frequencies using the hashing trick. + * Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32) + * to calculate the hash code value for the term object. + * Since a simple modulo is used to transform the hash function to a column index, + * it is advisable to use a power of two as the numFeatures parameter; + * otherwise the features will not be mapped evenly to the columns. + * + * @see [[HashingTF]] for more info + */ +class OpHashingTF(uid: String = UID[HashingTF]) + extends OpTransformerWrapper[TextList, OPVector, HashingTF](transformer = new HashingTF(), uid = uid) { + + /** + * Number of features. Should be greater than 0. + * (default = 2^18^) + */ + def setNumFeatures(value: Int): this.type = { + getSparkMlStage().get.setNumFeatures(value) + this + } + + /** + * Binary toggle to control term frequency counts. + * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic + * models that model binary events rather than integer counts. + * (default = false) + */ + def setBinary(value: Boolean): this.type = { + getSparkMlStage().get.setBinary(value) + this + } +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpIndexToString.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpIndexToString.scala index a0f50ad926..dddaad97a0 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpIndexToString.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpIndexToString.scala @@ -37,7 +37,7 @@ import enumeratum._ import org.apache.spark.ml.feature.IndexToString /** - * OP wrapper for [[org.apache.spark.ml.feature.IndexToString]] + * Wrapper for [[org.apache.spark.ml.feature.IndexToString]] * * NOTE THAT THIS CLASS EITHER FILTERS OUT OR THROWS AN ERROR IF PREVIOUSLY UNSEEN VALUES APPEAR * diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpNGram.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpNGram.scala new file mode 100644 index 0000000000..b07ac93540 --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpNGram.scala @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op.UID +import com.salesforce.op.features.types._ +import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper +import org.apache.spark.ml.feature.NGram + +/** + * Wrapper for [[org.apache.spark.ml.feature.NGram]] + * + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + * + * @see [[NGram]] for more info + */ +class OpNGram(uid: String = UID[NGram]) + extends OpTransformerWrapper[TextList, TextList, NGram](transformer = new NGram(), uid = uid) { + + /** + * Minimum n-gram length, greater than or equal to 1. + * Default: 2, bigram features + */ + def setN(value: Int): this.type = { + getSparkMlStage().get.setN(value) + this + } + +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemover.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemover.scala new file mode 100644 index 0000000000..e58ea6f66f --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemover.scala @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op.UID +import com.salesforce.op.features.types._ +import com.salesforce.op.stages.sparkwrappers.specific.OpTransformerWrapper +import org.apache.spark.ml.feature.StopWordsRemover + +/** + * Wrapper for [[org.apache.spark.ml.feature.StopWordsRemover]] + * + * A feature transformer that filters out stop words from input. + * + * @note null values from input array are preserved unless adding null to stopWords explicitly. + * + * @see Stop words (Wikipedia) + * @see [[StopWordsRemover]] for more info + */ +class OpStopWordsRemover(uid: String = UID[StopWordsRemover]) + extends OpTransformerWrapper[TextList, TextList, StopWordsRemover](transformer = new StopWordsRemover(), uid = uid) { + + /** + * The words to be filtered out. + * Default: English stop words + * + * @see `StopWordsRemover.loadDefaultStopWords()` + */ + def setStopWords(value: Array[String]): this.type = { + getSparkMlStage().get.setStopWords(value) + this + } + + /** + * Whether to do a case sensitive comparison over the stop words. + * Default: false + */ + def setCaseSensitive(value: Boolean): this.type = { + getSparkMlStage().get.setCaseSensitive(value) + this + } +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala index a0359200a3..f7318187e9 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala @@ -40,7 +40,7 @@ import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel} import scala.reflect.runtime.universe.TypeTag /** - * OP wrapper for [[org.apache.spark.ml.feature.StringIndexer]] + * Wrapper for [[org.apache.spark.ml.feature.StringIndexer]] * * NOTE THAT THIS CLASS EITHER FILTERS OUT OR THROWS AN ERROR IF PREVIOUSLY UNSEEN VALUES APPEAR * diff --git a/core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpTransformerWrapper.scala b/core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpTransformerWrapper.scala index ed3b379ebf..92a34ed375 100644 --- a/core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpTransformerWrapper.scala +++ b/core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpTransformerWrapper.scala @@ -39,9 +39,6 @@ import org.apache.spark.ml.param.ParamMap import scala.reflect.runtime.universe.TypeTag -// TODO: all the transformers that inherit traits HasInputCol and HasOutputCol should really extend -// org.apache.spark.ml.UnaryTransformer, so can add a PR to spark so we can then move this class to our namespace. - /** * Wraps a spark ML transformer with setable input and output columns. 
Those transformers that fall in this case, * include those that inherit from org.apache.spark.ml.UnaryEstimator, as well as others such as OneHotEncoder, diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/HashingTFTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpHashingTFTest.scala similarity index 81% rename from core/src/test/scala/com/salesforce/op/stages/impl/feature/HashingTFTest.scala rename to core/src/test/scala/com/salesforce/op/stages/impl/feature/OpHashingTFTest.scala index bf26881b41..cdfd46ccbc 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/HashingTFTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpHashingTFTest.scala @@ -31,20 +31,20 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op._ -import com.salesforce.op.features.types._ import com.salesforce.op.features.Feature -import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.features.types._ +import com.salesforce.op.test.{SwTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.feature.HashingTF +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.DataFrame import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner -import org.scalatest.{Assertions, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) -class HashingTFTest extends FlatSpec with TestSparkContext { +class OpHashingTFTest extends SwTransformerSpec[OPVector, HashingTF, OpHashingTF] { // scalastyle:off val testData = Seq( @@ -55,19 +55,27 @@ class HashingTFTest extends FlatSpec with TestSparkContext { ).map(_.toLowerCase.split(" ").toSeq.toTextList) // scalastyle:on - lazy val (ds, f1): (DataFrame, Feature[TextList]) = TestFeatureBuilder(testData) + val (inputData, f1): (DataFrame, Feature[TextList]) = TestFeatureBuilder(testData) + + val hashed = f1.tf(numTerms = 5) + val transformer = hashed.originStage.asInstanceOf[OpHashingTF] + + val expectedResult: Seq[OPVector] = Seq( + Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(2.0, 4.0, 2.0, 3.0, 1.0)), + Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(4.0, 1.0, 3.0, 1.0, 1.0)), + Vectors.sparse(5, Array(0, 2, 3, 4), Array(2.0, 2.0, 2.0, 2.0)), + Vectors.sparse(5, Array(0, 1, 2, 4), Array(3.0, 5.0, 1.0, 2.0)) + ).map(_.toOPVector) def hash( s: String, numOfFeatures: Int = TransmogrifierDefaults.DefaultNumOfFeatures, binary: Boolean = false - ): Int = { - new HashingTF(numOfFeatures).setBinary(binary).indexOf(s) - } + ): Int = new org.apache.spark.mllib.feature.HashingTF(numOfFeatures).setBinary(binary).indexOf(s) - Spec[HashingTF] should "hash categorical data" in { + it should "hash categorical data" in { val hashed = f1.tf() - val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds) + val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData) val results = transformedData.select(hashed.name).collect(hashed) hashed.name shouldBe hashed.originStage.getOutputFeatureName @@ -86,7 +94,7 @@ class HashingTFTest extends FlatSpec with TestSparkContext { val numFeatures = 100 val hashed = f1.tf(numTerms = numFeatures) - val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds) + val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData) val results = 
transformedData.select(hashed.name).collect(hashed) // scalastyle:off @@ -101,7 +109,7 @@ class HashingTFTest extends FlatSpec with TestSparkContext { val binary = true val hashed = f1.tf(binary = binary) - val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(ds) + val transformedData = hashed.originStage.asInstanceOf[Transformer].transform(inputData) val results = transformedData.select(hashed.name).collect(hashed) // scalastyle:off diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/NGramTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpNGramTest.scala similarity index 93% rename from core/src/test/scala/com/salesforce/op/stages/impl/feature/NGramTest.scala rename to core/src/test/scala/com/salesforce/op/stages/impl/feature/OpNGramTest.scala index 5e51ac75ad..4937113eb4 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/NGramTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpNGramTest.scala @@ -42,14 +42,14 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class NGramTest extends SwTransformerSpec[TextList, NGram, OpTransformerWrapper[TextList, TextList, NGram]] { +class OpNGramTest extends SwTransformerSpec[TextList, NGram, OpNGram] { val data = Seq("a b c d e f g").map(_.split(" ").toSeq.toTextList) val (inputData, textListFeature) = TestFeatureBuilder(data) val expectedResult = Seq(Seq("a b", "b c", "c d", "d e", "e f", "f g").toTextList) val bigrams = textListFeature.ngram() - val transformer = bigrams.originStage.asInstanceOf[OpTransformerWrapper[TextList, TextList, NGram]] + val transformer = bigrams.originStage.asInstanceOf[OpNGram] it should "generate unigrams" in { val unigrams = textListFeature.ngram(n = 1) diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemoverTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemoverTest.scala new file mode 100644 index 0000000000..5169741039 --- /dev/null +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStopWordsRemoverTest.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op._ +import com.salesforce.op.features.types._ +import com.salesforce.op.utils.spark.RichDataset._ +import com.salesforce.op.test.{SwTransformerSpec, TestFeatureBuilder} +import org.apache.spark.ml.feature.StopWordsRemover +import org.junit.runner.RunWith +import org.scalatest.junit.JUnitRunner + + +@RunWith(classOf[JUnitRunner]) +class OpStopWordsRemoverTest extends SwTransformerSpec[TextList, StopWordsRemover, OpStopWordsRemover] { + val data = Seq( + "I AM groot", "Groot call me human", "or I will crush you" + ).map(_.split(" ").toSeq.toTextList) + + val (inputData, textListFeature) = TestFeatureBuilder(data) + + val bigrams = textListFeature.removeStopWords() + val transformer = bigrams.originStage.asInstanceOf[OpStopWordsRemover] + + val expectedResult = Seq(Seq("groot"), Seq("Groot", "call", "human"), Seq("crush")).map(_.toTextList) + + it should "allow case sensitivity" in { + val noStopWords = textListFeature.removeStopWords(caseSensitive = true) + val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData) + res.collect(noStopWords) shouldBe Seq( + Seq("I", "AM", "groot"), Seq("Groot", "call", "human"), Seq("I", "crush")).map(_.toTextList) + } + + it should "set custom stop words" in { + val noStopWords = textListFeature.removeStopWords(stopWords = Array("Groot", "I")) + val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData) + res.collect(noStopWords) shouldBe Seq( + Seq("AM"), Seq("call", "me", "human"), Seq("or", "will", "crush", "you")).map(_.toTextList) + } +}
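
---

Usage sketch (not part of the patch itself): how the new wrapper stages compose through the `RichListFeature` DSL changed above. The `comments` feature and the parameter values are hypothetical; only the DSL methods (`removeStopWords`, `ngram`, `tf`, `transformWith`) and the `OpHashingTF` wrapper are taken from this diff.

```scala
import com.salesforce.op._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature.OpHashingTF

// Compose the three new wrappers via the DSL on a TextList feature.
def termFrequencies(comments: FeatureLike[TextList]): FeatureLike[OPVector] = {
  val cleaned = comments.removeStopWords(caseSensitive = false) // OpStopWordsRemover stage
  val bigrams = cleaned.ngram(n = 2)                            // OpNGram stage
  bigrams.tf(numTerms = 512, binary = false)                    // OpHashingTF stage
}

// Equivalent to the tf(...) shortcut above: wire the wrapper stage explicitly,
// mirroring what the updated RichListFeature.tf now does internally.
def termFrequenciesExplicit(bigrams: FeatureLike[TextList]): FeatureLike[OPVector] =
  bigrams.transformWith(new OpHashingTF().setNumFeatures(512).setBinary(false))
```

Either form yields the same `OpHashingTF` origin stage on the resulting feature, which is what the renamed `OpHashingTFTest` asserts via `hashed.originStage.asInstanceOf[OpHashingTF]`.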