Allow TextStats length distribution to be token-based and refactor for testability #464

Merged
merged 28 commits on Mar 26, 2020
Changes from 12 commits
28 commits
1c2cbc2
Refactored and added more incremental tests
Jauntbox Mar 3, 2020
83881c0
Updated test
Jauntbox Mar 3, 2020
ca2e122
Added tests and fixed a small bug
Jauntbox Mar 4, 2020
c190d97
More refactoring and updating tests
Jauntbox Mar 4, 2020
75b0770
More test refactoring
Jauntbox Mar 5, 2020
305c57e
More refactoring
Jauntbox Mar 5, 2020
ff53dc4
Small cleanups
Jauntbox Mar 5, 2020
13b6076
Merge branch 'master' of github.com:salesforce/TransmogrifAI into km/…
Jauntbox Mar 6, 2020
fc6cd07
Addressing comments
Jauntbox Mar 6, 2020
1434128
Added text length distribution to the TextStats calculated in RFF
Jauntbox Mar 9, 2020
e117fbd
Made offending methods private
Jauntbox Mar 16, 2020
87878e0
Merge branch 'master' of github.com:salesforce/TransmogrifAI into km/…
Jauntbox Mar 16, 2020
fe36ec5
Comments
Jauntbox Mar 17, 2020
cd31e32
Spelling
Jauntbox Mar 17, 2020
a322363
Added toggle for tokenization in length distribution
Jauntbox Mar 18, 2020
32ad893
Added toggle to turn tokenization on/off for length distribution coun…
Jauntbox Mar 19, 2020
ab9d2a7
Reverted changes to RFF for now and added logging to help with visibi…
Jauntbox Mar 19, 2020
4cb5c88
Updated tests to check both tokenized and non-tokenized text feature …
Jauntbox Mar 19, 2020
e463685
Better logging
Jauntbox Mar 19, 2020
e44ca68
Revert unintentional RFF changes
Jauntbox Mar 19, 2020
ce5663e
Removed unused method
Jauntbox Mar 19, 2020
b866bb3
Removed SVC models from the default models to try in BinaryClassifica…
Jauntbox Mar 19, 2020
e72af36
Added new params to vectorizer shortcuts
Jauntbox Mar 20, 2020
307a014
scalastyle issue
Jauntbox Mar 20, 2020
49127d9
Replaced boolean param with enum
Jauntbox Mar 26, 2020
95bc3e7
Added enum to json4s serialization list
Jauntbox Mar 26, 2020
aad13ba
Actually add the enum file
Jauntbox Mar 26, 2020
d78868d
Merge branch 'master' of github.com:salesforce/TransmogrifAI into km/…
Jauntbox Mar 26, 2020
@@ -33,7 +33,8 @@ package com.salesforce.op.filters
import java.util.Objects

import com.salesforce.op.features.{FeatureDistributionLike, FeatureDistributionType}
import com.salesforce.op.stages.impl.feature.{HashAlgorithm, Inclusion, NumericBucketizer, TextStats}
import com.salesforce.op.stages.impl.feature.{HashAlgorithm, Inclusion, NumericBucketizer, TextStats,
SmartTextVectorizer}
import com.salesforce.op.utils.json.EnumEntrySerializer
import com.twitter.algebird.Monoid._
import com.twitter.algebird._
@@ -282,7 +283,15 @@ object FeatureDistribution {
* @return TextStats object containing a Map from a value to its frequency (histogram)
*/
private def cardinalityValues(values: ProcessedSeq): TextStats = {
TextStats(countStringValues(values.left.getOrElse(values.right.get)), Map.empty)
implicit val testStatsMonoid: Monoid[TextStats] = TextStats.monoid(RawFeatureFilter.MaxCardinality)

val stringVals = values match {
case Left(stringSeq) => stringSeq
case Right(doubleSeq) => doubleSeq.map(_.toString)
}
stringVals.foldLeft(TextStats.empty)((acc, el) => acc + SmartTextVectorizer.computeTextStats(
Option(el), shouldCleanText = false, maxCardinality = RawFeatureFilter.MaxCardinality)
Collaborator: Should shouldCleanText = true instead?

Contributor Author: I can change it, but I don't think it matters much here. These values aren't used in SmartTextVectorizer. They're the ones that show up in the ModelInsights.

)
}
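For orientation, a minimal sketch of what this fold yields for a numeric feature. The call below is hypothetical (cardinalityValues is private), but the expected result mirrors the updated FeatureDistributionTest expectation at the bottom of this diff:

// Hypothetical: a numeric feature with values 5.0, 1.0 and 3.0
val values: ProcessedSeq = Right(Seq(5.0, 1.0, 3.0))
val stats = cardinalityValues(values)
// Each stringified value ("5.0", "1.0", "3.0") tokenizes to a single token of length 3, so:
// stats == TextStats(Map("5.0" -> 1L, "1.0" -> 1L, "3.0" -> 1L), Map(3 -> 3L))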

private def countStringValues[T](seq: Seq[T]): Map[String, Long] = {
Collaborator: Not relevant to this PR, but I think countStringValues is no longer used.

Contributor Author: OK, I can remove it then.

@@ -67,17 +67,6 @@ class SmartTextMapVectorizer[T <: OPMap[String]]

private implicit val textMapStatsSeqEnc: Encoder[Array[TextMapStats]] = ExpressionEncoder[Array[TextMapStats]]()

private def computeTextMapStats
(
textMap: T#Value, shouldCleanKeys: Boolean, shouldCleanValues: Boolean
): TextMapStats = {
val keyValueCounts = textMap.map{ case (k, v) =>
cleanTextFn(k, shouldCleanKeys) ->
TextStats(Map(cleanTextFn(v, shouldCleanValues) -> 1L), Map(cleanTextFn(v, shouldCleanValues).length -> 1L))
}
TextMapStats(keyValueCounts)
}

private def makeHashingParams() = HashingFunctionParams(
hashWithIndex = $(hashWithIndex),
prependFeatureName = $(prependFeatureName),
@@ -184,7 +173,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]]

implicit val testStatsMonoid: Monoid[TextMapStats] = TextMapStats.monoid(maxCard)
val valueStats: Dataset[Array[TextMapStats]] = dataset.map(
_.map(computeTextMapStats(_, shouldCleanKeys, shouldCleanValues)).toArray
_.map(SmartTextMapVectorizer.computeTextMapStats(_, shouldCleanKeys, shouldCleanValues, maxCard)).toArray
)
val aggregatedStats: Array[TextMapStats] = valueStats.reduce(_ + _)

@@ -203,6 +192,36 @@
}
}

object SmartTextMapVectorizer extends CleanTextFun {

/**
* Computes a TextMapStats instance from a text map entry
*
* @param textMap Text value (eg. entry in a dataframe)
* @param shouldCleanKeys Whether or not the keys (feature names) should be cleaned
* @param shouldCleanValues Whether or not the values (the actual text) should be cleaned
* @param maxCardinality Max cardinality to keep track of in maps (relevant for the text length distribution here)
* @tparam T Feature type that the text map value is coming from
* @return TextMapStats instance with value and length counts filled out appropriately for each key
*/
private[op] def computeTextMapStats[T <: OPMap[String]](
textMap: T#Value,
shouldCleanKeys: Boolean,
shouldCleanValues: Boolean,
maxCardinality: Int
)(implicit tti: TypeTag[T], ttiv: TypeTag[T#Value]): TextMapStats = {
val keyValueCounts = textMap.map{ case (k, v) =>
val cleanedText = cleanTextFn(v, shouldCleanValues)
val lengthsMap: Map[Int, Long] = TextTokenizer.tokenizeString(cleanedText).tokens.value
.foldLeft(Map.empty[Int, Long])(
(acc, el) => TextStats.additionHelper(acc, Map(el.length -> 1L), maxCardinality)
)
cleanTextFn(k, shouldCleanKeys) -> TextStats(Map(cleanedText -> 1L), lengthsMap)
}
TextMapStats(keyValueCounts)
}
}
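A rough usage sketch of the extracted helper (it is private[op], so only callable from inside the op package); the input map and the token lengths below are illustrative assumptions, not taken from the PR's tests:

// Hypothetical text map with two keys
val textMap = Map("desc" -> "hello world", "name" -> "Bob")
val stats = SmartTextMapVectorizer.computeTextMapStats[TextMap](
  textMap, shouldCleanKeys = false, shouldCleanValues = false, maxCardinality = 100
)
// Assuming default tokenization, "hello world" gives two tokens of length 5 and "Bob" one token of length 3:
// stats == TextMapStats(Map(
//   "desc" -> TextStats(Map("hello world" -> 1L), Map(5 -> 2L)),
//   "name" -> TextStats(Map("Bob" -> 1L), Map(3 -> 1L))
// ))
// TextStats.additionHelper caps each length map at maxCardinality distinct keys as entries are folded in.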

/**
* Summary statistics of a text feature
*
@@ -85,7 +85,9 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
val shouldCleanText = $(cleanText)

implicit val testStatsMonoid: Semigroup[TextStats] = TextStats.monoid(maxCard)
val valueStats: Dataset[Array[TextStats]] = dataset.map(_.map(computeTextStats(_, shouldCleanText)).toArray)
val valueStats: Dataset[Array[TextStats]] = dataset.map(
_.map(SmartTextVectorizer.computeTextStats(_, shouldCleanText, maxCard)).toArray
)
val aggregatedStats: Array[TextStats] = valueStats.reduce(_ + _)

val (vectorizationMethods, topValues) = aggregatedStats.map { stats =>
@@ -121,14 +123,6 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
.setTrackTextLen($(trackTextLen))
}

private def computeTextStats(text: T#Value, shouldCleanText: Boolean): TextStats = {
val (valueCounts, lengthCounts) = text match {
case Some(v) => (Map(cleanTextFn(v, shouldCleanText) -> 1L), Map(cleanTextFn(v, shouldCleanText).length -> 1L))
case None => (Map.empty[String, Long], Map.empty[Int, Long])
}
TextStats(valueCounts, lengthCounts)
}

private def makeVectorMetadata(smartTextParams: SmartTextVectorizerModelArgs): OpVectorMetadata = {
require(inN.length == smartTextParams.vectorizationMethods.length)

@@ -164,13 +158,40 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
}
}

object SmartTextVectorizer {
object SmartTextVectorizer extends CleanTextFun {
val MaxCardinality: Int = 100
val MinTextLengthStdDev: Double = 0

private[op] def partition[T: ClassTag](input: Array[T], condition: Array[Boolean]): (Array[T], Array[T]) = {
val all = input.zip(condition)
(all.collect { case (item, true) => item }, all.collect { case (item, false) => item })
}

/**
* Computes a TextStats instance from a Text entry
*
* @param text Text value (eg. entry in a dataframe)
* @param shouldCleanText Whether or not the text should be cleaned
* @param maxCardinality Max cardinality to keep track of in maps (relevant for the text length distribution here)
* @tparam T Feature type that the text value is coming from
* @return TextStats instance with value and length counts filled out appropriately
*/
private[op] def computeTextStats[T <: Text : TypeTag](
text: T#Value,
shouldCleanText: Boolean,
maxCardinality: Int
): TextStats = {
// Go through each token in text and start appending it to a TextStats instance
val lengthsMap: Map[Int, Long] = TextTokenizer.tokenizeStringOpt(text).tokens.value
.foldLeft(Map.empty[Int, Long])(
(acc, el) => TextStats.additionHelper(acc, Map(el.length -> 1L), maxCardinality)
)
val (valueCounts, lengthCounts) = text match {
Collaborator: When we reach RawFeatureFilter.MaxCardinality for valueCounts, will lengthCounts also stop accumulating?

Collaborator: Never mind, this is taken care of by val newLengthCounts = additionHelper(l.lengthCounts, r.lengthCounts, maxCardinality), please disregard this comment :D

case Some(v) => (Map(cleanTextFn(v, shouldCleanText) -> 1L), lengthsMap)
case None => (Map.empty[String, Long], Map.empty[Int, Long])
}
TextStats(valueCounts, lengthCounts)
}
}
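A hedged, worked example of the token-based length counts computeTextStats now produces (the method is private[op]; the input string and the default-tokenizer behavior are assumptions):

// Hypothetical call with text cleaning disabled
val stats = SmartTextVectorizer.computeTextStats[Text](
  Some("Hello brave new world"), shouldCleanText = false, maxCardinality = 100
)
// Assuming default whitespace tokenization, the tokens have lengths 5, 5, 3 and 5, so roughly:
// stats == TextStats(Map("Hello brave new world" -> 1L), Map(5 -> 3L, 3 -> 1L))
// As discussed above, additionHelper keeps both maps capped at maxCardinality distinct keys when stats are merged.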

/**
@@ -189,8 +210,8 @@ private[op] case class TextStats
val lengthMean: Double = lengthCounts.foldLeft(0.0)((acc, el) => acc + el._1 * el._2) / lengthSize
val lengthVariance: Double = lengthCounts.foldLeft(0.0)(
(acc, el) => acc + el._2 * (el._1 - lengthMean) * (el._1 - lengthMean)
)
val lengthStdDev: Double = math.sqrt(lengthVariance / lengthSize)
) / lengthSize
val lengthStdDev: Double = math.sqrt(lengthVariance)
}
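To make the reworked statistics concrete, a small worked example with illustrative numbers, taking lengthSize as the total token count: for lengthCounts = Map(5 -> 2L, 3 -> 1L), lengthSize = 3 and lengthMean = (5*2 + 3*1) / 3 ≈ 4.33. The sum of squared deviations is 2*(5 - 4.33)^2 + (3 - 4.33)^2 ≈ 2.67, so lengthVariance ≈ 2.67 / 3 ≈ 0.89 (now a population variance rather than a raw sum of squares) and lengthStdDev = sqrt(0.89) ≈ 0.94, the same standard deviation the previous formulation produced.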

private[op] object TextStats {
@@ -140,7 +140,7 @@ object TextTokenizer {
/**
* Language wise sentence tokenization
*
* @param text text to tokenize
* @param textString text to tokenize (in String form)
* @param languageDetector language detector instance
* @param analyzer text analyzer instance
* @param sentenceSplitter sentence splitter instance
@@ -151,8 +151,8 @@
* @param minTokenLength minimum token length
* @return detected language and sentence tokens
*/
def tokenize(
text: Text,
private[op] def tokenizeString(
textString: String,
Collaborator (@tovbinm, Mar 6, 2020): Now this function can explode with a NullPointerException if textString is null, while before that could not have happened.

Contributor Author (@Jauntbox, Mar 6, 2020): Hmmm, I'm not sure it's possible for the textString argument to be null in practice, though. When this function is used for tokenizing the map entries, a value that was originally null there will just not show up as an entry in the map. When it's used for tokenizing a normal Text entry, then we should have already safely converted any nulls or missing elements into an Option[String], right?

The actual tokenize call during vectorization is still tokenize(v.toText), where v is the value in a text map. I'd actually argue that that should be changed to tokenizeString(v) to save time converting it to Text and back again.

I agree it's technically less safe, but I don't think it's necessary to have null checking at this point in the flow. We should make sure the data gets created in a safe way, which I think we already do. Are there some specific edge cases I'm missing?

Collaborator: I think the simplest is to add a null check to tokenizeString and return.

Contributor Author (@Jauntbox, Mar 6, 2020): Now that I think about it more, I'm pretty sure the old tokenize function would also give an NPE if you fed it a sneaky null value. The SomeValue.unapply function explicitly calls v.isEmpty, which would also fail if v was null.

I put back the old tokenize function as oldTokenize and tried

val sneakyStringOpt: Option[String] = null
val myText = Text(sneakyStringOpt)
val res = TextTokenizer.oldTokenize(myText)

which did indeed throw an NPE.

We have tests all over the place (e.g. our vectorizer tests and FeatureTypeSparkConverterTest) that make sure we can handle null values in dataframes and safely convert them into our types. I'm not aware of any explicit null checks in our functions elsewhere, so it just feels weird to put one here.

@leahmcguire any opinions on this?

Collaborator (@tovbinm, Mar 7, 2020): SomeValue.unapply operates on value, which is Option[String]. The null check is done during the construction of Text when the values are extracted from the Dataframe / RDD. A NullPointerException is indeed unlikely to be thrown.

Collaborator: The example you provided is not currently possible, and also not a fair one :)

Collaborator: @Jauntbox is this only called from the Option[String] version below? If so, make it private and it is fine.

Collaborator: In fact, make them both private please.

languageDetector: LanguageDetector = LanguageDetector,
analyzer: TextAnalyzer = Analyzer,
sentenceSplitter: Option[SentenceSplitter] = None,
@@ -162,30 +162,84 @@
toLowercase: Boolean = ToLowercase,
minTokenLength: Int = MinTokenLength
): TextTokenizerResult = {
text match {
case SomeValue(Some(txt)) =>
val language =
if (!autoDetectLanguage) defaultLanguage
else {
languageDetector
.detectLanguages(txt)
.collectFirst { case (lang, confidence) if confidence > autoDetectThreshold => lang }
.getOrElse(defaultLanguage)
}
val lowerTxt = if (toLowercase) txt.toLowerCase else txt
val language =
if (!autoDetectLanguage) defaultLanguage
else {
languageDetector
.detectLanguages(textString)
.collectFirst { case (lang, confidence) if confidence > autoDetectThreshold => lang }
.getOrElse(defaultLanguage)
}
val lowerTxt = if (toLowercase) textString.toLowerCase else textString

val sentences = sentenceSplitter.map(_.getSentences(lowerTxt, language))
.getOrElse(Seq(lowerTxt))
.map { sentence =>
val tokens = analyzer.analyze(sentence, language)
tokens.filter(_.length >= minTokenLength).toTextList
}
TextTokenizerResult(language, sentences)
case _ =>
TextTokenizerResult(defaultLanguage, Seq(TextList.empty))
val sentences = sentenceSplitter.map(_.getSentences(lowerTxt, language))
.getOrElse(Seq(lowerTxt))
.map { sentence =>
val tokens = analyzer.analyze(sentence, language)
tokens.filter(_.length >= minTokenLength).toTextList
Collaborator: Why are we only keeping tokens with length > minTokenLength?

Contributor Author: This was existing behavior. It's a configurable parameter (defaulting to 1), so is not required.

}
TextTokenizerResult(language, sentences)
}

/**
* Language wise sentence tokenization
*
* @param textStringOpt text to tokenize (in Option[String] form)
* @param languageDetector language detector instance
* @param analyzer text analyzer instance
* @param sentenceSplitter sentence splitter instance
* @param autoDetectLanguage whether to attempt language detection
* @param defaultLanguage default language
* @param autoDetectThreshold language detection threshold
* @param toLowercase whether to convert all characters to lowercase before tokenizing
* @param minTokenLength minimum token length
* @return detected language and sentence tokens
*/
private[op] def tokenizeStringOpt(
textStringOpt: Option[String],
languageDetector: LanguageDetector = LanguageDetector,
analyzer: TextAnalyzer = Analyzer,
sentenceSplitter: Option[SentenceSplitter] = None,
autoDetectLanguage: Boolean = AutoDetectLanguage,
defaultLanguage: Language = DefaultLanguage,
autoDetectThreshold: Double = AutoDetectThreshold,
toLowercase: Boolean = ToLowercase,
minTokenLength: Int = MinTokenLength
): TextTokenizerResult = {
textStringOpt match {
case Some(txt) => tokenizeString(txt, languageDetector, analyzer, sentenceSplitter, autoDetectLanguage,
defaultLanguage, autoDetectThreshold, toLowercase, minTokenLength)
case None => TextTokenizerResult(defaultLanguage, Seq(TextList.empty))
}
}

/**
* Language wise sentence tokenization
*
* @param text text to tokenize
* @param languageDetector language detector instance
* @param analyzer text analyzer instance
* @param sentenceSplitter sentence splitter instance
* @param autoDetectLanguage whether to attempt language detection
* @param defaultLanguage default language
* @param autoDetectThreshold language detection threshold
* @param toLowercase whether to convert all characters to lowercase before tokenizing
* @param minTokenLength minimum token length
* @return detected language and sentence tokens
*/
def tokenize(
text: Text,
languageDetector: LanguageDetector = LanguageDetector,
analyzer: TextAnalyzer = Analyzer,
sentenceSplitter: Option[SentenceSplitter] = None,
autoDetectLanguage: Boolean = AutoDetectLanguage,
defaultLanguage: Language = DefaultLanguage,
autoDetectThreshold: Double = AutoDetectThreshold,
toLowercase: Boolean = ToLowercase,
minTokenLength: Int = MinTokenLength
): TextTokenizerResult = tokenizeStringOpt(text.value, languageDetector, analyzer, sentenceSplitter,
autoDetectLanguage, defaultLanguage, autoDetectThreshold, toLowercase, minTokenLength)
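A brief, hedged sketch of how the three entry points relate after this refactor (the concrete strings are illustrative, and the private[op] variants are only callable from inside the op package):

// Public API, unchanged for callers:
val result = TextTokenizer.tokenize(Text("Hello world"))
// Internally this is now equivalent to:
// TextTokenizer.tokenizeStringOpt(Some("Hello world")), which in turn delegates to
// TextTokenizer.tokenizeString("Hello world")
// Missing text short-circuits without running the analyzer:
// TextTokenizer.tokenizeStringOpt(None) returns TextTokenizerResult(defaultLanguage, Seq(TextList.empty))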

/**
* Text tokenization result
*
@@ -78,9 +78,9 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
distribs(3).distribution.sum shouldBe 0
distribs(4).distribution.sum shouldBe 3
distribs(4).summaryInfo.length shouldBe bins
distribs(2).cardEstimate.get shouldBe TextStats(Map("male" -> 1, "female" -> 1), Map.empty)
distribs(2).cardEstimate.get shouldBe TextStats(Map("male" -> 1, "female" -> 1), Map(4 -> 1L, 6 -> 1L))
distribs(2).moments.get shouldBe Moments(2, 5.0, 2.0, 0.0, 2.0)
distribs(4).cardEstimate.get shouldBe TextStats(Map("5.0" -> 1, "1.0" -> 1, "3.0" -> 1), Map.empty)
distribs(4).cardEstimate.get shouldBe TextStats(Map("5.0" -> 1, "1.0" -> 1, "3.0" -> 1), Map(3 -> 3L))
distribs(4).moments.get shouldBe Moments(3, 3.0, 8.0, 0.0, 32.0)
}
