diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
index dbb4b43173..3990f01eb7 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
@@ -35,11 +35,10 @@ import com.salesforce.op.features.types._
 import com.salesforce.op.stages.base.sequence.{SequenceEstimator, SequenceModel}
 import com.salesforce.op.stages.impl.feature.VectorizerUtils._
 import com.salesforce.op.utils.json.JsonLike
-import com.salesforce.op.utils.spark.RichDataset._
 import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
 import com.twitter.algebird.Monoid._
 import com.twitter.algebird.Operators._
-import com.twitter.algebird.{Monoid, Semigroup}
+import com.twitter.algebird.Monoid
 import com.twitter.algebird.macros.caseclass
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.{Dataset, Encoder}
@@ -138,9 +137,28 @@ class SmartTextMapVectorizer[T <: OPMap[String]]
     val shouldCleanValues = $(cleanText)
     val shouldTrackNulls = $(trackNulls)
+    val topKValue = $(topK)
 
     val allFeatureInfo = aggregatedStats.toSeq.map { textMapStats =>
       textMapStats.keyValueCounts.toSeq.map { case (k, textStats) =>
+        // Estimate the coverage of the top K values
+        val totalCount = textStats.valueCounts.values.sum
+        // Filter by minimum support
+        val filteredStats = textStats.valueCounts.filter { case (_, count) => count >= $(minSupport) }
+        val sorted = filteredStats.toSeq.sortBy(- _._2)
+        val sortedValues = sorted.map(_._2)
+        // Cumulative count
+        val cumCount = sortedValues.headOption.map(_ => sortedValues.tail.scanLeft(sortedValues.head)(_ + _))
+          .getOrElse(Seq.empty)
+        val coverage = cumCount.lift(math.min(topKValue, cumCount.length) - 1).getOrElse(0L) * 1.0 / totalCount
         val vecMethod: TextVectorizationMethod = textStats match {
+          // If the cardinality limit is exceeded but the coverage threshold is met, pivot the feature.
+          // All of the following checks need to pass:
+          //
+          // - Cardinality must be greater than maxCard (otherwise the case below already pivots).
+          // - Cardinality must also be greater than topK.
+          // - The computed coverage of the topK values with minimum support must reach coveragePct.
+          case _ if textStats.valueCounts.size > maxCard && textStats.valueCounts.size > topKValue &&
+            coverage >= $(coveragePct) => TextVectorizationMethod.Pivot
           case _ if textStats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot
           case _ if textStats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore
           case _ => TextVectorizationMethod.Hash
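For reviewers skimming the hunk above: the coverage estimate is just a cumulative sum over the value counts that survive minSupport, taken in descending order of frequency, and compared against coveragePct. A minimal standalone sketch of the same computation on toy data (object name and numbers are illustrative; scanLeft(0L) stands in for the headOption/tail construction used in the PR):

object CoverageSketch extends App {
  // Toy value counts: three dominant values plus a long tail of singletons.
  val valueCounts = Map("US" -> 600L, "CA" -> 250L, "MX" -> 100L) ++
    (1 to 50).map(i => s"rare$i" -> 1L)
  val topK = 3
  val minSupport = 2
  val coveragePct = 0.90

  val totalCount = valueCounts.values.sum
  // Keep only values seen at least minSupport times, most frequent first.
  val sortedCounts = valueCounts.collect { case (_, c) if c >= minSupport => c }.toSeq.sortBy(-_)
  // Running totals of the sorted counts: Seq(600, 850, 950).
  val cumCount = sortedCounts.scanLeft(0L)(_ + _).tail
  // Fraction of all entries covered by the topK surviving values: 950 / 1000.
  val coverage = cumCount.lift(math.min(topK, cumCount.length) - 1).getOrElse(0L).toDouble / totalCount
  println(f"coverage = $coverage%.3f, pivot = ${coverage >= coveragePct}") // coverage = 0.950, pivot = true
}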
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
index 37347b24c1..2874aef29e 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
@@ -32,11 +32,10 @@ package com.salesforce.op.stages.impl.feature
 
 import com.salesforce.op.UID
 import com.salesforce.op.features.TransientFeature
-import com.salesforce.op.features.types.{OPVector, Text, TextList, VectorConversions, SeqDoubleConversions}
+import com.salesforce.op.features.types._
 import com.salesforce.op.stages.base.sequence.{SequenceEstimator, SequenceModel}
 import com.salesforce.op.stages.impl.feature.VectorizerUtils._
 import com.salesforce.op.utils.json.JsonLike
-import com.salesforce.op.utils.spark.RichDataset._
 import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
 import com.twitter.algebird.Monoid
 import com.twitter.algebird.Monoid._
@@ -91,16 +90,34 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
     )
 
     val aggregatedStats: Array[TextStats] = valueStats.reduce(_ + _)
+    val topKValue = $(topK)
 
     val (vectorizationMethods, topValues) = aggregatedStats.map { stats =>
+      // Estimate the coverage of the top K values
+      val totalCount = stats.valueCounts.values.sum
+      // Filter by minimum support
+      val filteredStats = stats.valueCounts.filter { case (_, count) => count >= $(minSupport) }
+      val sorted = filteredStats.toSeq.sortBy(- _._2)
+      val sortedValues = sorted.map(_._2)
+      // Cumulative count
+      val cumCount = sortedValues.headOption.map(_ => sortedValues.tail.scanLeft(sortedValues.head)(_ + _))
+        .getOrElse(Seq.empty)
+      val coverage = cumCount.lift(math.min(topKValue, cumCount.length) - 1).getOrElse(0L) * 1.0 / totalCount
       val vecMethod: TextVectorizationMethod = stats match {
+        // If the cardinality limit is exceeded but the coverage threshold is met, pivot the feature.
+        // All of the following checks need to pass:
+        //
+        // - Cardinality must be greater than maxCard (otherwise the case below already pivots).
+        // - Cardinality must also be greater than topK.
+        // - The computed coverage of the topK values with minimum support must reach coveragePct.
+        case _ if stats.valueCounts.size > maxCard && stats.valueCounts.size > topKValue &&
+          coverage >= $(coveragePct) => TextVectorizationMethod.Pivot
         case _ if stats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot
         case _ if stats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore
         case _ => TextVectorizationMethod.Hash
       }
-      val topValues = stats.valueCounts
-        .filter { case (_, count) => count >= $(minSupport) }
+      val topValues = filteredStats
         .toSeq.sortBy(v => -v._2 -> v._1)
-        .take($(topK)).map(_._1)
+        .take(topKValue).map(_._1)
       (vecMethod, topValues)
     }.unzip
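One subtlety worth noting in the refactored topValues above: ties on count are broken by the value itself (sortBy(v => -v._2 -> v._1) orders by descending count, then ascending value), so the pivoted columns are deterministic run to run. A tiny self-contained illustration with made-up counts:

object TopValuesTieBreak extends App {
  val counts = Map("b" -> 3L, "a" -> 3L, "c" -> 5L)
  // Descending count first, ascending value as the tie-break.
  val topValues = counts.toSeq.sortBy(v => -v._2 -> v._1).take(2).map(_._1)
  println(topValues.mkString(", ")) // prints "c, a": "a" beats "b" alphabetically
}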
@@ -169,9 +186,10 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
 }
 
 object SmartTextVectorizer {
-  val MaxCardinality: Int = 100
+  val MaxCardinality: Int = 1000
   val MinTextLengthStdDev: Double = 0
   val LengthType: TextLengthType = TextLengthType.FullEntry
+  val CoveragePct: Double = 0.90
 
   private[op] def partition[T: ClassTag](input: Array[T], condition: Array[Boolean]): (Array[T], Array[T]) = {
     val all = input.zip(condition)
@@ -365,6 +383,16 @@ trait MaxCardinalityParams extends Params {
   final def setMaxCardinality(v: Int): this.type = set(maxCardinality, v)
   final def getMaxCardinality: Int = $(maxCardinality)
   setDefault(maxCardinality -> SmartTextVectorizer.MaxCardinality)
+
+  final val coveragePct = new DoubleParam(
+    parent = this, name = "coveragePct",
+    doc = "Threshold on the percentage of entries. If the topK values make up more than this percentage" +
+      " of all entries, the feature is treated as a categorical",
+    isValid = ParamValidators.inRange(lowerBound = 0, upperBound = 1, lowerInclusive = false, upperInclusive = true)
+  )
+  final def setCoveragePct(v: Double): this.type = set(coveragePct, v)
+  final def getCoveragePct: Double = $(coveragePct)
+  setDefault(coveragePct -> SmartTextVectorizer.CoveragePct)
 }
 
 trait MinLengthStdDevParams extends Params {
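Putting the new pieces together, a hedged sketch of how the knob might be configured (parameter values are illustrative and the input feature is omitted; all setters shown appear in this PR or already exist on the estimator):

import com.salesforce.op.stages.impl.feature.SmartTextVectorizer

// Illustrative configuration: a high-cardinality field whose top 20 values still
// cover 95% of entries would now be pivoted instead of hashed.
val vectorizer = new SmartTextVectorizer()
  .setMaxCardinality(1000) // raw cardinality alone no longer forces hashing
  .setTopK(20)             // number of values pivoted into indicator columns
  .setMinSupport(10)       // values rarer than this do not count towards coverage
  .setCoveragePct(0.95)    // pivot when the topK values cover at least 95% of entries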
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala
index ca1f62f98c..190826675a 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala
@@ -31,6 +31,7 @@ package com.salesforce.op.stages.impl.feature
 
 import com.salesforce.op._
+import com.salesforce.op.features.FeatureLike
 import com.salesforce.op.stages.base.sequence.SequenceModel
 import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
 import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
@@ -39,8 +40,11 @@ import org.apache.spark.ml.linalg.Vectors
 import org.junit.runner.RunWith
 import org.scalatest.junit.JUnitRunner
 import com.salesforce.op.features.types._
+import com.salesforce.op.stages.impl.feature.TextVectorizationMethod.{Hash, Pivot}
 import com.salesforce.op.testkit.RandomText
+import org.apache.spark.sql.{DataFrame, Encoder, Encoders}
 import org.scalatest.Assertion
+import scala.util.Random
 
 @RunWith(classOf[JUnitRunner])
 class SmartTextMapVectorizerTest
@@ -120,6 +124,11 @@ class SmartTextMapVectorizerTest
   val tokensMap = textMapData.mapValues(s => TextTokenizer.tokenizeString(s).tokens)
   val tol = 1e-12 // Tolerance for comparing real numbers
 
+  val oneCountryData: Seq[Text] = Seq.fill(1000)(Text("United States"))
+  val categoricalCountryData = Random.shuffle(oneCountryData ++ countryData)
+  val countryMapData = categoricalCountryData.map { case country => mapifyText(Seq(country)) }
+  val (countryMapDF, rawCatCountryMap) = TestFeatureBuilder("rawCatCountryMap", countryMapData)
+
   /**
    * Estimator instance to be tested
    */
@@ -137,6 +146,7 @@ class SmartTextMapVectorizerTest
     Vectors.dense(Array(0.0, 1.0, 0.0, 1.0))
   ).map(_.toOPVector)
 
+  import spark.sqlContext.implicits._
 
   Spec[TextMapStats] should "provide a proper semigroup" in {
     val data = Seq(
@@ -589,6 +599,100 @@ class SmartTextMapVectorizerTest
     meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString))
   }
 
+  it should "treat the edge case of coverage being near 0" in {
+    val maxCard = 100
+    val vectorizer = new SmartTextMapVectorizer().setCoveragePct(1e-10).setMaxCardinality(maxCard).setMinSupport(1)
+      .setTrackTextLen(true).setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, TransmogrifierDefaults.TopK + 2, Pivot)
+  }
+
+  it should "treat the edge case of coverage being near 1" in {
+    val maxCard = 100
+    val vectorizer = new SmartTextMapVectorizer().setCoveragePct(1.0 - 1e-10).setMaxCardinality(maxCard)
+      .setMinSupport(1)
+      .setTrackTextLen(true).setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, TransmogrifierDefaults.DefaultNumOfFeatures + 2, Hash)
+  }
+
+  it should "detect one categorical with high cardinality using the coverage" in {
+    val maxCard = 100
+    val topK = 10
+    val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextMapVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
+      .setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, topK + 2, Pivot)
+  }
+
+  it should "not pivot using the coverage because of a high minimum support" in {
+    val maxCard = 100
+    val topK = 10
+    val minSupport = 99999
+    val numHashes = 5
+    val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextMapVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, numHashes + 2, Hash)
+  }
+
+  it should "still pivot using the coverage despite a high minimum support" in {
+    val maxCard = 100
+    val topK = 10
+    val minSupport = 100
+    val numHashes = 5
+    val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextMapVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, 3, Pivot)
+  }
+
+  it should "not pivot using the coverage because top K is too high" in {
+    val maxCard = 100
+    val topK = 1000000
+    val numHashes = 5
+    val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be <= topK
+    val vectorizer = new SmartTextMapVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
+    assertVectorLength(transformed, output, numHashes + 2, Hash)
+  }
+
+  it should "still transform country into text, despite the coverage" in {
+    val maxCard = 100
+    val topK = 10
+    val numHashes = 5
+    val cardinality = rawDFSeparateMaps.select(rawTextMap1).as[TextMap#Value].map(_.get("f0")).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val coverageHashed = new SmartTextMapVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
+      .setTrackTextLen(true).setNumFeatures(numHashes).setInput(rawTextMap1).getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(coverageHashed).transform(rawDFSeparateMaps)
+    val expectedLength = numHashes + 2 + categoricalTextData.toSet.filter(!_.isEmpty).toSeq.length + 2
+    assertVectorLength(transformed, coverageHashed, expectedLength, Hash)
+  }
+
   it should "create a TextStats object from text that makes sense" in {
     val res = TextMapStats.computeTextMapStats[TextMap](
       textMapData,
@@ -668,6 +772,24 @@ class SmartTextMapVectorizerTest
     checkDerivedQuantities(res, "f2", Seq(4, 5, 5, 5, 3).map(_.toLong))
   }
 
+  private[op] def assertVectorLength(df: DataFrame, output: FeatureLike[OPVector],
+    expectedLength: Int, textVectorizationMethod: TextVectorizationMethod): Unit = {
+    val result = df.collect(output)
+    val firstRes = result.head
+    val metaColumns = OpVectorMetadata(df.schema(output.name)).columns
+
+    firstRes.v.size shouldBe expectedLength
+    metaColumns.length shouldBe expectedLength
+
+    textVectorizationMethod match {
+      case Pivot => assert(metaColumns(expectedLength - 2).indicatorValue.contains(OpVectorColumnMetadata.OtherString))
+      case Hash => assert(metaColumns(expectedLength - 2).descriptorValue
+        .contains(OpVectorColumnMetadata.TextLenString))
+      case other => throw new Error(s"Only Pivoting or Hashing possible, got ${other} instead")
+    }
+    assert(metaColumns(expectedLength - 1).indicatorValue.contains(OpVectorColumnMetadata.NullString))
+  }
+
   /**
    * Set of tests to check that the derived quantities calculated on the length distribution in TextMapStats (for
    * a single key) match the actual length distributions of the tokens.
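The expectedLength arithmetic in these tests follows from the vector layout that assertVectorLength checks: a pivoted feature emits topK indicator columns plus an "other" bucket plus a null indicator, while a hashed feature with trackTextLen enabled emits numFeatures hash buckets plus a text-length column plus a null indicator. The lone expectedLength of 3 above is the same rule: minSupport = 100 leaves only "United States" as a top value, so 1 + 2 = 3; and in the map case the per-key blocks concatenate, hence numHashes + 2 plus the pivoted key's distinct values + 2 in the last test. A self-contained sketch of that bookkeeping, with a local enum standing in for the library's TextVectorizationMethod:

object ExpectedVectorLength extends App {
  sealed trait Method
  case object Pivot extends Method // topK indicators + "other" bucket + null indicator
  case object Hash extends Method  // hash buckets + text-length column + null indicator

  def expectedLength(method: Method, topK: Int, numHashes: Int): Int = method match {
    case Pivot => topK + 2
    case Hash => numHashes + 2
  }

  println(expectedLength(Pivot, topK = 10, numHashes = 5)) // 12, as in the pivot tests
  println(expectedLength(Hash, topK = 10, numHashes = 5))  // 7, as in the hash tests
}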
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
index 0cad973908..3f12702c1b 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
@@ -31,17 +31,22 @@ package com.salesforce.op.stages.impl.feature
 
 import com.salesforce.op._
+import com.salesforce.op.features.FeatureLike
 import com.salesforce.op.features.types._
 import com.salesforce.op.stages.base.sequence.SequenceModel
+import com.salesforce.op.stages.impl.feature.TextVectorizationMethod.{Hash, Pivot}
 import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
 import com.salesforce.op.testkit.RandomText
 import com.salesforce.op.utils.spark.RichDataset._
 import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
 import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.sql.DataFrame
 import org.junit.runner.RunWith
 import org.scalatest.Assertion
 import org.scalatest.junit.JUnitRunner
 
+import scala.util.Random
+
 @RunWith(classOf[JUnitRunner])
 class SmartTextVectorizerTest
@@ -98,6 +103,10 @@ class SmartTextVectorizerTest
     "Here they are ALL standing in a row."
   val tol = 1e-12 // Tolerance for comparing real numbers
 
+  val oneCountryData: Seq[Text] = Seq.fill(1000)(Text("United States"))
+  val categoricalCountryData = Random.shuffle(oneCountryData ++ countryData)
+  val (countryDF, rawCatCountry) = TestFeatureBuilder("rawCatCountry", categoricalCountryData)
+
   it should "detect one categorical and one non-categorical text feature" in {
     val smartVectorized = new SmartTextVectorizer()
       .setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(false)
@@ -235,6 +244,7 @@ class SmartTextVectorizerTest
     meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString))
   }
 
+
   it should "detect and ignore fields that looks like machine-generated IDs by having a low value length variance" in {
     val topKCategorial = 3
     val hashSize = 5
@@ -270,6 +280,98 @@ class SmartTextVectorizerTest
     meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString))
   }
 
+  it should "treat the edge case of coverage being near 0" in {
+    val maxCard = 100
+    val vectorizer = new SmartTextVectorizer().setCoveragePct(1e-10).setMaxCardinality(maxCard).setMinSupport(1)
+      .setTrackTextLen(true).setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, TransmogrifierDefaults.TopK + 2, Pivot)
+  }
+
+  it should "treat the edge case of coverage being near 1" in {
+    val maxCard = 100
+    val vectorizer = new SmartTextVectorizer().setCoveragePct(1.0 - 1e-10).setMaxCardinality(maxCard).setMinSupport(1)
+      .setTrackTextLen(true).setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, TransmogrifierDefaults.DefaultNumOfFeatures + 2, Hash)
+  }
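The two edge-case tests above pin down the boundaries of the new parameter: with coveragePct near 0, any feature whose filtered topK covers even a single entry is pivoted, and with coveragePct near 1 essentially nothing is, so the legacy cardinality rules decide. A toy sketch of the two comparisons (numbers are made up):

object CoverageEdgeCases extends App {
  val coverage = 0.55 // say the topK values cover 55% of entries
  println(coverage >= 1e-10)       // true: pivots under the near-0 threshold
  println(coverage >= 1.0 - 1e-10) // false: falls through to the legacy rules
}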
+  it should "detect one categorical with high cardinality using the coverage" in {
+    val maxCard = 100
+    val topK = 10
+    val cardinality = countryDF.select(rawCatCountry).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
+      .setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, topK + 2, Pivot)
+  }
+
+  it should "not pivot using the coverage because of a high minimum support" in {
+    val maxCard = 100
+    val topK = 10
+    val minSupport = 99999
+    val numHashes = 5
+    val cardinality = countryDF.select(rawCatCountry).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, numHashes + 2, Hash)
+  }
+
+  it should "still pivot using the coverage despite a high minimum support" in {
+    val maxCard = 100
+    val topK = 10
+    val minSupport = 100
+    val numHashes = 5
+    val cardinality = countryDF.select(rawCatCountry).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val vectorizer = new SmartTextVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, 3, Pivot)
+  }
+
+  it should "not pivot using the coverage because top K is too high" in {
+    val maxCard = 100
+    val topK = 1000000
+    val numHashes = 5
+    val cardinality = countryDF.select(rawCatCountry).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be <= topK
+    val vectorizer = new SmartTextVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setNumFeatures(numHashes).setCoveragePct(0.5)
+      .setTrackTextLen(true).setCleanText(false).setInput(rawCatCountry)
+    val output = vectorizer.getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(output).transform(countryDF)
+    assertVectorLength(transformed, output, numHashes + 2, Hash)
+  }
+
+  it should "still transform country into text, despite the coverage" in {
+    val maxCard = 100
+    val topK = 10
+    val numHashes = 5
+    val cardinality = rawDF.select(rawCountry).distinct().count().toInt
+    cardinality should be > maxCard
+    cardinality should be > topK
+    val coverageHashed = new SmartTextVectorizer()
+      .setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
+      .setTrackTextLen(true).setNumFeatures(numHashes).setInput(rawCountry).getOutput()
+    val transformed = new OpWorkflow().setResultFeatures(coverageHashed).transform(rawDF)
+    assertVectorLength(transformed, coverageHashed, numHashes + 2, Hash)
+  }
+
   it should "fail with an error" in {
     val emptyDF = inputData.filter(inputData("text1") === "").toDF()
@@ -544,6 +646,23 @@ class SmartTextVectorizerTest
     checkDerivedQuantities(res, Seq(58).map(_.toLong))
   }
 
+  private[op] def assertVectorLength(df: DataFrame, output: FeatureLike[OPVector],
+    expectedLength: Int, textVectorizationMethod: TextVectorizationMethod): Unit = {
+    val result = df.collect(output)
+    val firstRes = result.head
+    firstRes.v.size shouldBe expectedLength
+    val metaColumns = OpVectorMetadata(df.schema(output.name)).columns
+    metaColumns.length shouldBe expectedLength
+    textVectorizationMethod match {
+      case Pivot => assert(metaColumns(expectedLength - 2).indicatorValue.contains(OpVectorColumnMetadata.OtherString))
+      case Hash => assert(metaColumns(expectedLength - 2).descriptorValue
+        .contains(OpVectorColumnMetadata.TextLenString))
+      case other => throw new Error(s"Only Pivoting or Hashing possible, got ${other} instead")
+    }
+    assert(metaColumns(expectedLength - 1).indicatorValue.contains(OpVectorColumnMetadata.NullString))
+  }
+
   /**
    * Set of tests to check that the derived quantities calculated on the length distribution in TextStats
    * match the actual length distributions of the tokens.
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerTest.scala
index 70d7d420a2..cfc0c9168c 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerTest.scala
@@ -357,6 +357,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
   it should "remove individual text hash features independently" in {
     val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
       .setMaxCardinality(2).setNumFeatures(8).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
+      .setCoveragePct(1.0)
       .setHashSpaceStrategy(HashSpaceStrategy.Shared)
       .setInput(textMap).getOutput()
 
@@ -391,6 +392,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
     val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
       .setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
       .setHashSpaceStrategy(HashSpaceStrategy.Separate)
+      .setCoveragePct(1.0)
       .setInput(textMap).getOutput()
 
     val checkedFeatures = new SanityChecker()
@@ -432,6 +434,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
     val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
       .setMaxCardinality(2).setNumFeatures(8).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
       .setHashSpaceStrategy(HashSpaceStrategy.Shared)
+      .setCoveragePct(1.0)
       .setInput(textMap).getOutput()
 
     val checkedFeatures = new SanityChecker()
@@ -502,6 +505,7 @@ class SanityCheckerTest extends OpEstimatorSpec[OPVector, BinaryModel[RealNN, OP
     val smartMapVectorized = new SmartTextMapVectorizer[TextMap]()
       .setMaxCardinality(2).setNumFeatures(8).setMinSupport(1).setTopK(2).setPrependFeatureName(true)
       .setHashSpaceStrategy(HashSpaceStrategy.Shared)
+      .setCoveragePct(1.0)
       .setInput(textMap).getOutput()
 
     val checkedFeatures = new SanityChecker()
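These SanityChecker fixtures pin coveragePct to 1.0, presumably so the new coverage branch cannot flip their deliberately hashed features into pivots: coverage only reaches 1.0 when the filtered topK values account for every single entry. A toy illustration of that boundary (numbers are made up):

object CoverageOneSketch extends App {
  val totalCount = 1000L
  val topKCumulativeCount = 999L // one entry falls outside the topK values
  println(topKCumulativeCount.toDouble / totalCount >= 1.0) // false: legacy behaviour kept
}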