Add categorical detection to be coverage based in addition to unique count based (#473)

Related issues
Currently, SmartTextVectorizer and SmartTextMapVectorizer count the number of unique entries in a text field (up to a threshold, currently 50) and treat the feature as categorical if it has fewer than 50 unique entries.
You can still run into features that are effectively categorical but have a long tail of low-frequency entries. We would get better signal extraction if we treated these as categorical instead of hashing them.

Describe the proposed solution
Add an extra check that allows Text(Map) features to be treated as categorical. This only applies to features whose cardinality is higher than the threshold and which would therefore be hashed.

A better approach to detecting text features that are really categorical is to use a coverage criterion: if, for example, the topK entries with minimum support cover at least 90% of the entries, then the feature is a good candidate to pivot by entry instead of hashing by token. The 90% value can be tuned by the user via a parameter.

Extra checks need to be passed (see the sketch after this list):

Cardinality must be greater than maxCard (already mentioned above).
Cardinality must also be greater than topK.
Finally, the computed coverage of the topK entries with minimum support must reach the coverage threshold (coveragePct), which is required to be strictly greater than 0.
If there are m < topK elements with the required minimum support, the coverage is computed over those m elements.
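
A minimal standalone sketch of this coverage check (the real implementation lives in SmartTextVectorizer and SmartTextMapVectorizer in the diff below; the helper name and its signature here are illustrative only, while the parameter names mirror the estimator params):

def shouldPivotByCoverage(
  valueCounts: Map[String, Long],
  maxCard: Int,
  topK: Int,
  minSupport: Int,
  coveragePct: Double
): Boolean = {
  val totalCount = valueCounts.values.sum
  // Keep only entries meeting the minimum support, most frequent first
  val supportedCounts = valueCounts.collect { case (_, count) if count >= minSupport => count }
    .toSeq.sortBy(-_)
  // Count covered by the top K supported entries (or by all of them if there are fewer than K)
  val coveredCount = supportedCounts.take(topK).sum
  val coverage = if (totalCount > 0) coveredCount.toDouble / totalCount else 0.0
  // Pivot only when the plain cardinality check fails but the coverage is high enough
  valueCounts.size > maxCard && valueCounts.size > topK && coverage >= coveragePct
}

// e.g. a hypothetical "country" column with 200 distinct values whose top 10 cover 95% of rows:
// shouldPivotByCoverage(countryCounts, maxCard = 100, topK = 10, minSupport = 1, coveragePct = 0.9) == true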

Describe alternatives you've considered
I considered using Algebird's Count-Min Sketch to compute the current TextStats.
However, I ran into multiple issues (a rough sketch of the attempt follows the list):

Lack of transparency: TopNCMS only returns the "heavy hitters", but the coverage method needs much more than that (e.g. the exact cardinality).
Serialization issues when writing to JSON.
A branch (mw/coverage) still exists, but it is in shambles.
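
For reference, roughly what the attempt looked like; the constructor arguments are placeholder values and the API usage is recalled from Algebird's TopNCMS, so treat this as a sketch rather than the abandoned branch's code:

import com.twitter.algebird.TopNCMS

// eps, delta, seed, heavyHittersN - placeholder values
val cmsMonoid = TopNCMS.monoid[String](0.001, 1e-8, 1, 20)
val sketch = cmsMonoid.create(Seq("US", "US", "FR", "DE", "US"))

sketch.heavyHitters    // the top-N entries: essentially all the sketch exposes
sketch.frequency("US") // approximate count for a key you already know about
sketch.totalCount      // total number of items added
// There is no way to recover the exact number of distinct entries (the cardinality
// the coverage check needs), and the sketch did not serialize cleanly to JSON.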

Additional context
Some criticism regarding TextStats: it does not appear to be a semigroup, since its addition is not associative. Was this intended? (See the toy illustration below.)
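
To make the associativity concern concrete, here is a toy illustration of how a size-capped merge (which, if I read it correctly, is what the maxCardinality-based truncation in TextStats amounts to) can break associativity; the cap and merge rule below are simplified stand-ins, not the real TextStats code:

val cap = 2
// Stop accepting new keys once the left side has reached the cap (simplified)
def merge(l: Map[String, Long], r: Map[String, Long]): Map[String, Long] =
  if (l.size >= cap) l
  else r.foldLeft(l) { case (acc, (k, v)) => acc + (k -> (acc.getOrElse(k, 0L) + v)) }

val a = Map("x" -> 1L)
val b = Map("y" -> 1L)
val c = Map("z" -> 1L)

merge(merge(a, b), c) // Map(x -> 1, y -> 1): the cap is hit, so "z" is dropped
merge(a, merge(b, c)) // Map(x -> 1, y -> 1, z -> 1): all three keys survive
// (a + b) + c != a + (b + c), so the merge is not associative, and the aggregated
// counts can depend on the order in which the partial results are combined.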


* First Logic

* First tests

* Second Tests

* Removing prints

* fix test

* Fix 2

* Adding comments

* Line change

* Adding more comments

* Removing useless condition

* Fixing tests

Co-authored-by: Michael Weil <mweil@salesforce.com>
michaelweilsalesforce and mweilsalesforce committed May 14, 2020
1 parent 7d0c33e commit 24cdbc4
Showing 5 changed files with 299 additions and 8 deletions.
@@ -35,11 +35,10 @@ import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.{SequenceEstimator, SequenceModel}
import com.salesforce.op.stages.impl.feature.VectorizerUtils._
import com.salesforce.op.utils.json.JsonLike
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.{Monoid, Semigroup}
import com.twitter.algebird.Monoid
import com.twitter.algebird.macros.caseclass
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{Dataset, Encoder}
@@ -138,9 +137,28 @@ class SmartTextMapVectorizer[T <: OPMap[String]]
val shouldCleanValues = $(cleanText)
val shouldTrackNulls = $(trackNulls)

val topKValue = $(topK)
val allFeatureInfo = aggregatedStats.toSeq.map { textMapStats =>
textMapStats.keyValueCounts.toSeq.map { case (k, textStats) =>
// Estimate the coverage of the top K values
val totalCount = textStats.valueCounts.values.sum // total count
// Filter by minimum support
val filteredStats = textStats.valueCounts.filter { case (_, count) => count >= $(minSupport) }
val sorted = filteredStats.toSeq.sortBy(- _._2)
val sortedValues = sorted.map(_._2)
// Cumulative Count
val cumCount = sortedValues.headOption.map(_ => sortedValues.tail.scanLeft(sortedValues.head)(_ + _))
.getOrElse(Seq.empty)
val coverage = cumCount.lift(math.min(topKValue, cumCount.length) - 1).getOrElse(0L) * 1.0 / totalCount
val vecMethod: TextVectorizationMethod = textStats match {
// If cardinality not respected, but coverage is, then pivot the feature
// Extra checks need to be passed:
//
// - Cardinality must be greater than maxCard (already mentioned above).
// - Cardinality must also be greater than topK.
// - Finally, the computed coverage of the topK with minimum support must be > 0.
case _ if textStats.valueCounts.size > maxCard && textStats.valueCounts.size > topKValue &&
coverage >= $(coveragePct) => TextVectorizationMethod.Pivot
case _ if textStats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot
case _ if textStats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore
case _ => TextVectorizationMethod.Hash
@@ -32,11 +32,10 @@ package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.TransientFeature
import com.salesforce.op.features.types.{OPVector, Text, TextList, VectorConversions, SeqDoubleConversions}
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.{SequenceEstimator, SequenceModel}
import com.salesforce.op.stages.impl.feature.VectorizerUtils._
import com.salesforce.op.utils.json.JsonLike
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.twitter.algebird.Monoid
import com.twitter.algebird.Monoid._
@@ -91,16 +90,34 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
)
val aggregatedStats: Array[TextStats] = valueStats.reduce(_ + _)

val topKValue = $(topK)
val (vectorizationMethods, topValues) = aggregatedStats.map { stats =>
// Estimate the coverage of the top K values
val totalCount = stats.valueCounts.values.sum // total count
// Filter by minimum support
val filteredStats = stats.valueCounts.filter { case (_, count) => count >= $(minSupport) }
val sorted = filteredStats.toSeq.sortBy(- _._2)
val sortedValues = sorted.map(_._2)
// Cumulative Count
val cumCount = sortedValues.headOption.map(_ => sortedValues.tail.scanLeft(sortedValues.head)(_ + _))
.getOrElse(Seq.empty)
val coverage = cumCount.lift(math.min(topKValue, cumCount.length) - 1).getOrElse(0L) * 1.0 / totalCount
val vecMethod: TextVectorizationMethod = stats match {
// If cardinality not respected, but coverage is, then pivot the feature
// Extra checks need to be passed:
//
// - Cardinality must be greater than maxCard (already mentioned above).
// - Cardinality must also be greater than topK.
// - Finally, the computed coverage of the topK with minimum support must be > 0.
case _ if stats.valueCounts.size > maxCard && stats.valueCounts.size > topKValue && coverage >= $(coveragePct)
=> TextVectorizationMethod.Pivot
case _ if stats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot
case _ if stats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore
case _ => TextVectorizationMethod.Hash
}
val topValues = stats.valueCounts
.filter { case (_, count) => count >= $(minSupport) }
val topValues = filteredStats
.toSeq.sortBy(v => -v._2 -> v._1)
.take($(topK)).map(_._1)
.take(topKValue).map(_._1)
(vecMethod, topValues)
}.unzip

@@ -169,9 +186,10 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
}

object SmartTextVectorizer {
val MaxCardinality: Int = 100
val MaxCardinality: Int = 1000
val MinTextLengthStdDev: Double = 0
val LengthType: TextLengthType = TextLengthType.FullEntry
val CoveragePct: Double = 0.90

private[op] def partition[T: ClassTag](input: Array[T], condition: Array[Boolean]): (Array[T], Array[T]) = {
val all = input.zip(condition)
@@ -365,6 +383,16 @@ trait MaxCardinalityParams extends Params {
final def setMaxCardinality(v: Int): this.type = set(maxCardinality, v)
final def getMaxCardinality: Int = $(maxCardinality)
setDefault(maxCardinality -> SmartTextVectorizer.MaxCardinality)

final val coveragePct = new DoubleParam(
parent = this, name = "coveragePct",
doc = "Threshold of percentage of the entries. If the topK entries make up for more than this pecentage," +
" the feature is treated as a categorical",
isValid = ParamValidators.inRange(lowerBound = 0, upperBound = 1, lowerInclusive = false, upperInclusive = true)
)
final def setCoveragePct(v: Double): this.type = set(coveragePct, v)
final def getCoveragePct: Double = $(coveragePct)
setDefault(coveragePct -> SmartTextVectorizer.CoveragePct)
}

trait MinLengthStdDevParams extends Params {
@@ -31,6 +31,7 @@
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.stages.base.sequence.SequenceModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
@@ -39,8 +40,11 @@ import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature.TextVectorizationMethod.{Hash, Pivot}
import com.salesforce.op.testkit.RandomText
import org.apache.spark.sql.{DataFrame, Encoder, Encoders}
import org.scalatest.Assertion
import scala.util.Random

@RunWith(classOf[JUnitRunner])
class SmartTextMapVectorizerTest
@@ -120,6 +124,11 @@ class SmartTextMapVectorizerTest
val tokensMap = textMapData.mapValues(s => TextTokenizer.tokenizeString(s).tokens)
val tol = 1e-12 // Tolerance for comparing real numbers

val oneCountryData: Seq[Text] = Seq.fill(1000)(Text("United States"))
val categoricalCountryData = Random.shuffle(oneCountryData ++ countryData)
val countryMapData = categoricalCountryData.map { case country => mapifyText(Seq(country)) }
val (countryMapDF, rawCatCountryMap) = TestFeatureBuilder("rawCatCountryMap", countryMapData)

/**
* Estimator instance to be tested
*/
@@ -137,6 +146,7 @@
Vectors.dense(Array(0.0, 1.0, 0.0, 1.0))
).map(_.toOPVector)

import spark.sqlContext.implicits._

Spec[TextMapStats] should "provide a proper semigroup" in {
val data = Seq(
@@ -589,6 +599,100 @@ class SmartTextMapVectorizerTest
meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString))
}

it should "treat the edge case of coverage being near 0" in {
val maxCard = 100
val vectorizer = new SmartTextMapVectorizer().setCoveragePct(1e-10).setMaxCardinality(maxCard).setMinSupport(1)
.setTrackTextLen(true).setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, TransmogrifierDefaults.TopK + 2, Pivot)
}
it should "treat the edge case of coverage being near 1" in {
val maxCard = 100
val vectorizer = new SmartTextMapVectorizer().setCoveragePct(1.0 - 1e-10).setMaxCardinality(maxCard)
.setMinSupport(1)
.setTrackTextLen(true).setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, TransmogrifierDefaults.DefaultNumOfFeatures + 2, Hash)
}

it should "detect one categorical with high cardinality using the coverage" in {
val maxCard = 100
val topK = 10
val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
cardinality should be > maxCard
cardinality should be > topK
val vectorizer = new SmartTextMapVectorizer()
.setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
.setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, topK + 2, Pivot)
}

it should "not pivot using the coverage because of a high minimum support" in {
val maxCard = 100
val topK = 10
val minSupport = 99999
val numHashes = 5
val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
cardinality should be > maxCard
cardinality should be > topK
val vectorizer = new SmartTextMapVectorizer()
.setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
.setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, numHashes + 2, Hash)
}

it should "still pivot using the coverage despite a high minimum support" in {
val maxCard = 100
val topK = 10
val minSupport = 100
val numHashes = 5
val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
cardinality should be > maxCard
cardinality should be > topK
val vectorizer = new SmartTextMapVectorizer()
.setMaxCardinality(maxCard).setTopK(topK).setMinSupport(minSupport).setNumFeatures(numHashes).setCoveragePct(0.5)
.setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, 3, Pivot)
}

it should "not pivot using the coverage because top K is too high" in {
val maxCard = 100
val topK = 1000000
val numHashes = 5
val cardinality = countryMapDF.select(rawCatCountryMap).as[TextMap#Value].map(_("f0")).distinct().count().toInt
cardinality should be > maxCard
cardinality should be <= topK
val vectorizer = new SmartTextMapVectorizer()
.setMaxCardinality(maxCard).setTopK(topK).setNumFeatures(numHashes).setCoveragePct(0.5)
.setTrackTextLen(true).setCleanText(false).setInput(rawCatCountryMap)
val output = vectorizer.getOutput()
val transformed = new OpWorkflow().setResultFeatures(output).transform(countryMapDF)
assertVectorLength(transformed, output, numHashes + 2, Hash)
}

it should "still transform country into text, despite the coverage" in {
val maxCard = 100
val topK = 10
val numHashes = 5
val cardinality = rawDFSeparateMaps.select(rawTextMap1).as[TextMap#Value].map(_.get("f0")).distinct().count().toInt
cardinality should be > maxCard
cardinality should be > topK
val coverageHashed = new SmartTextMapVectorizer()
.setMaxCardinality(maxCard).setTopK(topK).setMinSupport(1).setCoveragePct(0.5).setCleanText(false)
.setTrackTextLen(true).setNumFeatures(numHashes).setInput(rawTextMap1).getOutput()
val transformed = new OpWorkflow().setResultFeatures(coverageHashed).transform(rawDFSeparateMaps)
val expectedLength = numHashes + 2 + categoricalTextData.toSet.filter(!_.isEmpty).toSeq.length + 2
assertVectorLength(transformed, coverageHashed, expectedLength, Hash)
}

it should "create a TextStats object from text that makes sense" in {
val res = TextMapStats.computeTextMapStats[TextMap](
textMapData,
@@ -668,6 +772,24 @@ class SmartTextMapVectorizerTest
checkDerivedQuantities(res, "f2", Seq(4, 5, 5, 5, 3).map(_.toLong))
}

private[op] def assertVectorLength(df: DataFrame, output: FeatureLike[OPVector],
expectedLength: Int, textVectorizationMethod: TextVectorizationMethod): Unit = {
val result = df.collect(output)
val firstRes = result.head
val metaColumns = OpVectorMetadata(df.schema(output.name)).columns

firstRes.v.size shouldBe expectedLength
metaColumns.length shouldBe expectedLength

textVectorizationMethod match {
case Pivot => assert(metaColumns(expectedLength - 2).indicatorValue.contains(OpVectorColumnMetadata.OtherString))
case Hash => assert(metaColumns(expectedLength - 2).descriptorValue
.contains(OpVectorColumnMetadata.TextLenString))
case other => throw new Error(s"Only Pivoting or Hashing possible, got ${other} instead")
}
assert(metaColumns(expectedLength - 1).indicatorValue.contains(OpVectorColumnMetadata.NullString))
}

/**
* Set of tests to check that the derived quantities calculated on the length distribution in TextMapStats (for
* a single key) match the actual length distributions of the tokens.