Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved test SmartTextMapVectorizerTest #296

Merged
merged 7 commits into from
Apr 30, 2019
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,21 @@
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.stages.base.sequence.SequenceModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichMetadata._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import com.salesforce.op.features.types._

@RunWith(classOf[JUnitRunner])
class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with AttributeAsserts {
lazy val (data, m1, m2, f1, f2) = TestFeatureBuilder("textMap1", "textMap2", "text1", "text2",
class SmartTextMapVectorizerTest
extends OpEstimatorSpec[OPVector, SequenceModel[TextMap, OPVector], SmartTextMapVectorizer[TextMap]]
with AttributeAsserts {

lazy val (inputData, m1, m2, f1, f2) = TestFeatureBuilder("textMap1", "textMap2", "text1", "text2",
Seq[(TextMap, TextMap, Text, Text)](
(TextMap(Map("text1" -> "hello world", "text2" -> "Hello world!")), TextMap.empty,
"hello world".toText, "Hello world!".toText),
Expand Down Expand Up @@ -71,6 +73,26 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
)
)

/**
* Estimator instance to be tested
*/
override val estimator: SmartTextMapVectorizer[TextMap] = new SmartTextMapVectorizer[TextMap]()
.setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's perhaps test the stage with all the default settings, i.e. just `val estimator: SmartTextMapVectorizer[TextMap] = new SmartTextMapVectorizer[TextMap]()`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsuchy, can you please address my comment so we can merge this?

.setCleanKeys(false)
.setInput(m1, m2)

/**
 * Expected output of the fitted estimator on the input dataset:
 * five sparse vectors, each of size 9, listed in the order they are compared.
 */
override val expectedResult: Seq[OPVector] = {
  // Every expected vector shares the same dimension; only the non-zero entries differ.
  val vectorSize = 9
  // (active indices, values at those indices) for each expected vector
  val nonZeroEntries = Seq(
    Array(0, 5, 7) -> Array(1.0, 1.0, 1.0),
    Array(0, 8) -> Array(1.0, 1.0),
    Array(1, 4) -> Array(1.0, 1.0),
    Array(0, 4) -> Array(1.0, 2.0),
    Array(3, 8) -> Array(1.0, 1.0)
  )
  nonZeroEntries.map { case (indices, values) =>
    Vectors.sparse(vectorSize, indices, values).toOPVector
  }
}


Spec[TextMapStats] should "provide a proper semigroup" in {
val data = Seq(
TextMapStats(Map(
Expand All @@ -93,17 +115,14 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
)))
}

Spec[SmartTextMapVectorizer[_]] should "detect one categorical and one non-categorical text feature" in {
val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
.setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
.setCleanKeys(false)
.setInput(m1, m2).getOutput()
it should "detect one categorical and one non-categorical text feature" in {
val smartMapVectorized = estimator.getOutput()

val smartVectorized = new SmartTextVectorizer()
.setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)
val field = transformed.schema(smartVectorized.name)
assertNominal(field, Array.fill(4)(true) ++ Array.fill(4)(false) :+ true, transformed.collect(smartVectorized))
Expand Down Expand Up @@ -136,7 +155,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setMaxCardinality(10).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)
val field = transformed.schema(smartVectorized.name)
val rSmart = transformed.collect(smartVectorized)
Expand Down Expand Up @@ -171,7 +190,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setHashSpaceStrategy(HashSpaceStrategy.Separate)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)
val field = transformed.schema(smartVectorized.name)
assertNominal(field, Array.fill(8)(false) ++ Array(true, true), transformed.collect(smartVectorized))
Expand Down Expand Up @@ -205,7 +224,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setNumFeatures(4).setHashSpaceStrategy(HashSpaceStrategy.Shared)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)
val field = transformed.schema(smartVectorized.name)
assertNominal(field, Array.fill(4)(false) ++ Array(true, true), transformed.collect(smartVectorized))
Expand Down Expand Up @@ -242,7 +261,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setNumFeatures(TransmogrifierDefaults.MaxNumOfFeatures).setHashSpaceStrategy(HashSpaceStrategy.Auto)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)
val field = transformed.schema(smartVectorized.name)
val rSmart = transformed.collect(smartVectorized)
Expand Down Expand Up @@ -282,7 +301,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
others = Array(m2)
)

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, shortcutMapVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, shortcutMapVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, shortcutMapVectorized)
val field = transformed.schema(shortcutMapVectorized.name)
assertNominal(field, Array.fill(4)(true) ++ Array.fill(4)(false) :+ true,
Expand Down Expand Up @@ -316,7 +335,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setCleanKeys(false)
.setInput(m3, m4).getOutput()

val transformed = new OpWorkflow().setResultFeatures(textMapVectorized, textAreaMapVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(textMapVectorized, textAreaMapVectorized).transform(inputData)
val result = transformed.collect(textMapVectorized, textAreaMapVectorized)
val field = transformed.schema(textMapVectorized.name)
assertNominal(field, Array.fill(4)(true) ++ Array.fill(4)(false) :+ true,
Expand Down Expand Up @@ -352,7 +371,7 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att
.setTrackTextLen(true)
.setInput(f1, f2).getOutput()

val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(data)
val transformed = new OpWorkflow().setResultFeatures(smartMapVectorized, smartVectorized).transform(inputData)
val result = transformed.collect(smartMapVectorized, smartVectorized)

val field = transformed.schema(smartVectorized.name)
Expand All @@ -376,4 +395,5 @@ class SmartTextMapVectorizerTest extends FlatSpec with TestSparkContext with Att

result.foreach { case (vec1, vec2) => vec1 shouldBe vec2 }
}

}