diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringNoFilterTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringNoFilterTest.scala index d7b9f1da5f..c04937af45 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringNoFilterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringNoFilterTest.scala @@ -30,53 +30,26 @@ package com.salesforce.op.stages.impl.feature -import com.salesforce.op._ import com.salesforce.op.features.types._ -import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner -import org.scalatest.{Assertions, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) -class OpIndexToStringNoFilterTest extends FlatSpec with TestSparkContext { +class OpIndexToStringNoFilterTest extends OpTransformerSpec[Text, OpIndexToStringNoFilter] { + val (inputData, indF) = TestFeatureBuilder(Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)) + val labels = Array("a", "c") - val (ds, indF) = TestFeatureBuilder(Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)) - val labels = Array("a", "c", "b") - val expected = Array("a", "b", "c", "a", "a", "c").map(_.toText) + override val transformer: OpIndexToStringNoFilter = new OpIndexToStringNoFilter().setInput(indF).setLabels(labels) - val labelsNew = Array("a", "c") - val expectedNew = Array("a", OpIndexToStringNoFilter.unseenDefault, "c", "a", "a", "c").map(_.toText) + override val expectedResult: Seq[Text] = + Array("a", OpIndexToStringNoFilter.unseenDefault, "c", "a", "a", "c").map(_.toText) - Spec[OpIndexToStringNoFilter] should "correctly deindex a numeric column" in { - val indexToStr = new OpIndexToStringNoFilter().setInput(indF).setLabels(labels) - val strs = indexToStr.transform(ds).collect(indexToStr.getOutput()) - - strs shouldBe expected - } - - it should "correctly deindex a numeric column (shortcut)" in { - val str = indF.deindexed(labels) - val strs = str.originStage.asInstanceOf[OpIndexToStringNoFilter].transform(ds).collect(str) - strs shouldBe expected - - val str2 = indF.deindexed(labels, handleInvalid = IndexToStringHandleInvalid.Error) - val strs2 = str2.originStage.asInstanceOf[OpIndexToString].transform(ds).collect(str2) - strs2 shouldBe expected - } - - it should "correctly deindex even if the lables list does not match the number of indicies" in { - val indexToStr = new OpIndexToStringNoFilter().setInput(indF).setLabels(labelsNew) - val strs = indexToStr.transform(ds).collect(indexToStr.getOutput()) - - strs shouldBe expectedNew - } - - Spec[OpIndexToString] should "correctly deindex a numeric column" in { - val indexToStr = new OpIndexToString().setInput(indF).setLabels(labels) - val strs = indexToStr.transform(ds).collect(indexToStr.getOutput()) - - strs shouldBe expected + it should "correctly deindex a numeric column using shortcut" in { + val str2 = indF.deindexed(labels, handleInvalid = IndexToStringHandleInvalid.NoFilter) + val strs2 = str2.originStage.asInstanceOf[OpIndexToStringNoFilter].transform(inputData).collect(str2) + strs2 shouldBe expectedResult } } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringTest.scala new file mode 100644 index 0000000000..e1637ebde6 --- /dev/null +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpIndexToStringTest.scala @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op.features.types._ +import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.utils.spark.RichDataset._ +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class OpIndexToStringTest extends FlatSpec with TestSparkContext { + + val (inputData, indF) = TestFeatureBuilder(Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)) + val labels = Array("a", "c", "b") + + val expectedResult: Seq[Text] = Array("a", "b", "c", "a", "a", "c").map(_.toText) + + val transformer: OpIndexToString = new OpIndexToString().setInput(indF).setLabels(labels) + + Spec[OpIndexToString] should "correctly deindex a numeric column" in { + val strs = transformer.transform(inputData).collect(transformer.getOutput()) + strs shouldBe expectedResult + } + + it should "correctly deindex a numeric column (shortcut)" in { + val str = indF.deindexed(labels, handleInvalid = IndexToStringHandleInvalid.Error) + val strs = str.originStage.asInstanceOf[OpIndexToString].transform(inputData).collect(str) + strs shouldBe expectedResult + } + + it should "getLabels" in { + transformer.getLabels shouldBe labels + } +} diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala index b7590e02ce..f2e01d12d8 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala @@ -32,52 +32,48 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.types._ +import com.salesforce.op.stages.base.unary.UnaryModel import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid.Skip import com.salesforce.op.stages.sparkwrappers.generic.SwUnaryModel -import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.feature.StringIndexerModel import org.junit.runner.RunWith -import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class OpStringIndexerNoFilterTest extends FlatSpec with TestSparkContext { +class OpStringIndexerNoFilterTest extends + OpEstimatorSpec[RealNN, UnaryModel[Text, RealNN], OpStringIndexerNoFilter[Text]] { val txtData = Seq("a", "b", "c", "a", "a", "c").map(_.toText) - val (ds, txtF) = TestFeatureBuilder(txtData) - val expected = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN) + val (inputData, txtF) = TestFeatureBuilder(txtData) + override val expectedResult: Seq[RealNN] = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN) + + override val estimator: OpStringIndexerNoFilter[Text] = new OpStringIndexerNoFilter[Text]().setInput(txtF) val txtDataNew = Seq("a", "b", "c", "a", "a", "c", "d", "e").map(_.toText) - val (dsNew, txtFNew ) = TestFeatureBuilder(txtDataNew) + val (dsNew, txtFNew) = TestFeatureBuilder(txtDataNew) val expectedNew = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 3.0, 3.0).map(_.toRealNN) - - Spec[OpStringIndexerNoFilter[_]] should "correctly index a text column" in { - val stringIndexer = new OpStringIndexerNoFilter[Text]().setInput(txtF) - val indices = stringIndexer.fit(ds).transform(ds).collect(stringIndexer.getOutput()) - - indices shouldBe expected - } - it should "correctly index a text column (shortcut)" in { val indexed = txtF.indexed() - val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(ds).transform(ds).collect(indexed) - indices shouldBe expected + val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]] + .fit(inputData).transform(inputData).collect(indexed) + indices shouldBe expectedResult val indexed2 = txtF.indexed(handleInvalid = Skip) - val indicesfit = indexed2.originStage.asInstanceOf[OpStringIndexer[_]].fit(ds) - val indices2 = indicesfit.transform(ds).collect(indexed2) + val indicesfit = indexed2.originStage.asInstanceOf[OpStringIndexer[_]].fit(inputData) + val indices2 = indicesfit.transform(inputData).collect(indexed2) val indices3 = indicesfit.asInstanceOf[SwUnaryModel[Text, RealNN, StringIndexerModel]] .setInput(txtFNew).transform(dsNew).collect(indexed2) - indices2 shouldBe expected - indices3 shouldBe expected + indices2 shouldBe expectedResult + indices3 shouldBe expectedResult } it should "correctly deinxed a numeric column" in { val indexed = txtF.indexed() - val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(ds).transform(ds) + val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(inputData).transform(inputData) val deindexed = indexed.deindexed() val deindexedData = deindexed.originStage.asInstanceOf[OpIndexToStringNoFilter] .transform(indices).collect(deindexed) @@ -85,9 +81,7 @@ class OpStringIndexerNoFilterTest extends FlatSpec with TestSparkContext { } it should "assign new strings to the unseen string category" in { - val stringIndexer = new OpStringIndexerNoFilter[Text]().setInput(txtF) - val indices = stringIndexer.fit(ds).setInput(txtFNew).transform(dsNew).collect(stringIndexer.getOutput()) - + val indices = estimator.fit(inputData).setInput(txtFNew).transform(dsNew).collect(estimator.getOutput()) indices shouldBe expectedNew } }