Skip serializing cardinality estimates in FeatureDistributions #447

Merged 9 commits on Dec 18, 2019
Changes from all commits
@@ -40,7 +40,7 @@ import com.twitter.algebird._
import com.twitter.algebird.Operators._
import org.apache.spark.mllib.feature.HashingTF
import org.json4s.jackson.Serialization
-import org.json4s.{DefaultFormats, Formats}
+import org.json4s.{DefaultFormats, FieldSerializer, Formats}

import scala.util.Try

@@ -64,6 +64,7 @@ case class FeatureDistribution
distribution: Array[Double],
summaryInfo: Array[Double],
moments: Option[Moments] = None,
+  cardEstimate: Option[TextStats] = None,
`type`: FeatureDistributionType = FeatureDistributionType.Training
) extends FeatureDistributionLike {

@@ -109,9 +110,10 @@ case class FeatureDistribution
val combinedSummaryInfo = if (summaryInfo.length > fd.summaryInfo.length) summaryInfo else fd.summaryInfo

val combinedMoments = moments + fd.moments
+    val combinedCard = cardEstimate + fd.cardEstimate

FeatureDistribution(name, key, count + fd.count, nulls + fd.nulls, combinedDist,
-      combinedSummaryInfo, combinedMoments, `type`)
+      combinedSummaryInfo, combinedMoments, combinedCard, `type`)
}

/**
@@ -172,14 +174,14 @@ case class FeatureDistribution
}

override def equals(that: Any): Boolean = that match {
-    case FeatureDistribution(`name`, `key`, `count`, `nulls`, d, s, m, `type`) =>
+    case FeatureDistribution(`name`, `key`, `count`, `nulls`, d, s, m, c, `type`) =>
distribution.deep == d.deep && summaryInfo.deep == s.deep &&
-      moments == m
+      moments == m && cardEstimate == c
case _ => false
}

override def hashCode(): Int = Objects.hashCode(name, key, count, nulls, distribution,
-    summaryInfo, moments, `type`)
+    summaryInfo, moments, cardEstimate, `type`)
}

object FeatureDistribution {
@@ -190,8 +192,13 @@ object FeatureDistribution {
override def plus(l: FeatureDistribution, r: FeatureDistribution): FeatureDistribution = l.reduce(r)
}

+  val FeatureDistributionSerializer = FieldSerializer[FeatureDistribution](
+    FieldSerializer.ignore("cardEstimate")
+  )
+
implicit val formats: Formats = DefaultFormats +
-    EnumEntrySerializer.json4s[FeatureDistributionType](FeatureDistributionType)
+    EnumEntrySerializer.json4s[FeatureDistributionType](FeatureDistributionType) +
+    FeatureDistributionSerializer

/**
* Feature distributions to json
@@ -238,6 +245,7 @@ object FeatureDistribution {
.getOrElse(1L -> (Array(summary.min, summary.max, summary.sum, summary.count) -> new Array[Double](bins)))

val moments = value.map(momentsValues)
+    val cardEstimate = value.map(cardinalityValues)

FeatureDistribution(
name = name,
@@ -247,6 +255,7 @@ object FeatureDistribution {
summaryInfo = summaryInfo,
distribution = distribution,
moments = moments,
+      cardEstimate = cardEstimate,
`type` = `type`
)
}
@@ -265,6 +274,21 @@ object FeatureDistribution {
MomentsGroup.sum(population.map(x => Moments(x)))
}

+  /**
+   * Function to track frequency of the first $(MaxCardinality) unique values
+   * (number for numeric features, token for text features)
+   *
+   * @param values values to track distribution / frequency
+   * @return TextStats object containing a Map from a value to its frequency (histogram)
+   */
+  private def cardinalityValues(values: ProcessedSeq): TextStats = {
+    TextStats(countStringValues(values.left.getOrElse(values.right.get)))
+  }
+
+  private def countStringValues[T](seq: Seq[T]): Map[String, Int] = {
+    seq.groupBy(identity).map { case (k, valSeq) => k.toString -> valSeq.size }
+  }
+
/**
* Function to put data into histogram of counts
*
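The counting and combining logic added in this file is compact enough to sketch in isolation. The snippet below is a minimal, self-contained approximation rather than the library code: SimpleTextStats only mirrors the real TextStats class's valueCounts map (and omits the MaxCardinality cap the Scaladoc above mentions), and merge stands in for the Algebird `+` that reduce applies to cardEstimate.

// A sketch, not the real com.salesforce.op.stages.impl.feature.TextStats.
case class SimpleTextStats(valueCounts: Map[String, Int])

object SimpleTextStats {
  // Same shape as countStringValues above: group raw values, count occurrences,
  // and key the histogram by each value's string form.
  def fromValues[T](values: Seq[T]): SimpleTextStats =
    SimpleTextStats(values.groupBy(identity).map { case (k, vs) => k.toString -> vs.size })

  // Key-wise histogram merge; in reduce() this is what
  // `cardEstimate + fd.cardEstimate` does via Algebird's Operators and the
  // Option/Map semigroups (None acts as the identity).
  def merge(l: SimpleTextStats, r: SimpleTextStats): SimpleTextStats =
    SimpleTextStats(r.valueCounts.foldLeft(l.valueCounts) {
      case (acc, (k, n)) => acc.updated(k, acc.getOrElse(k, 0) + n)
    })
}

object CardinalitySketch extends App {
  val text = SimpleTextStats.fromValues(Seq("male", "female", "male"))
  println(text) // SimpleTextStats(Map(male -> 2, female -> 1))

  val numeric = SimpleTextStats.fromValues(Seq(1.0, 3.0, 5.0))
  println(numeric) // SimpleTextStats(Map(1.0 -> 1, 3.0 -> 1, 5.0 -> 1))

  println(SimpleTextStats.merge(text, text)) // SimpleTextStats(Map(male -> 4, female -> 2))
}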
core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala (28 changes: 16 additions & 12 deletions)
@@ -31,30 +31,28 @@
package com.salesforce.op

import com.salesforce.op.evaluators._
-import com.salesforce.op.features.types._
+import com.salesforce.op.features.types.{Real, _}
import com.salesforce.op.features.{Feature, FeatureDistributionType, FeatureLike}
import com.salesforce.op.filters._
import com.salesforce.op.stages.impl.classification._
+import com.salesforce.op.stages.impl.feature.{CombinationStrategy, TextStats}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, OpXGBoostRegressor, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
-import com.salesforce.op.stages.impl.selector.{SelectedModelCombiner, SelectedCombinerModel, SelectedModel}
import com.salesforce.op.stages.impl.selector.ValidationType._
+import com.salesforce.op.stages.impl.selector.{SelectedCombinerModel, SelectedModel, SelectedModelCombiner}
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter}
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestFeatureBuilder}
import com.salesforce.op.testkit.RandomReal
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
+import com.twitter.algebird.Moments
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
-import com.salesforce.op.features.types.Real
-import com.salesforce.op.stages.impl.feature.{CombinationStrategy, TextStats}
-import com.twitter.algebird.Moments
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.scalactic.Equality
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
+import org.apache.spark.sql.functions._

import scala.util.{Failure, Success}

@@ -166,15 +164,16 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
return Array(descaledsmallCoeff, originalsmallCoeff, descaledbigCoeff, orginalbigCoeff)
}

-  def getFeatureMoments(inputModel: FeatureLike[Prediction],
-    DF: DataFrame): Map[String, Moments] = {
+  def getFeatureMomentsAndCard(inputModel: FeatureLike[Prediction],
+    DF: DataFrame): (Map[String, Moments], Map[String, TextStats]) = {
lazy val workFlow = new OpWorkflow().setResultFeatures(inputModel).setInputDataset(DF)
lazy val dummyReader = workFlow.getReader()
lazy val workFlowRFF = workFlow.withRawFeatureFilter(Some(dummyReader), None)
lazy val model = workFlowRFF.train()
val insights = model.modelInsights(inputModel)
val featureMoments = insights.features.map(f => f.featureName -> f.distributions.head.moments.get).toMap
-    return featureMoments
+    val featureCardinality = insights.features.map(f => f.featureName -> f.distributions.head.cardEstimate.get).toMap
+    featureMoments -> featureCardinality
}

val params = new OpParams()
@@ -782,7 +781,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
val df = linRegDF._3
val meanTol = 0.01
val varTol = 0.01
-    val moments = getFeatureMoments(standardizedLinpred, linRegDF._3)
+    val (moments, cardinality) = getFeatureMomentsAndCard(standardizedLinpred, linRegDF._3)

// Go through each feature and check that the mean, variance, and unique counts match the data
moments.foreach { case (featureName, value) => {
@@ -793,6 +792,11 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
math.abs((value.variance - expectedVariance) / expectedVariance) < varTol shouldBe true
}
}

+    cardinality.foreach { case (featureName, value) =>
+      val actualUniques = df.select(featureName).as[Double].distinct.collect.toSet
+      actualUniques should contain allElementsOf value.valueCounts.keySet.map(_.toDouble)
+    }
}

it should "return correct insights when a model combiner equal is used as the final feature" in {
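The direction of the new containment assertion matters: every value the cardinality estimate tracked must occur among the column's actual distinct values, while the estimate itself may hold fewer keys than the data has uniques, since tracking is capped. A Spark-free sketch of the same check, with made-up values standing in for the DataFrame and the estimate:

object ContainmentSketch extends App {
  // Hypothetical stand-ins for df.select(featureName).as[Double].distinct
  // and for a cardEstimate histogram; only the subset relation is shown.
  val actualUniques: Set[Double] = Set(1.0, 3.0, 5.0, 7.0)
  val estimatedCounts: Map[String, Int] = Map("1.0" -> 2, "5.0" -> 1)
  assert(estimatedCounts.keySet.map(_.toDouble).subsetOf(actualUniques))
}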
FeatureDistributionTest.scala
@@ -34,7 +34,10 @@ import com.salesforce.op.features.{FeatureDistributionType, TransientFeature}
import com.salesforce.op.stages.impl.feature.TextStats
import com.salesforce.op.test.PassengerSparkFixtureTest
import com.salesforce.op.testkit.RandomText
+import com.salesforce.op.utils.json.EnumEntrySerializer
import com.twitter.algebird.Moments
+import org.json4s.DefaultFormats
+import org.json4s.jackson.Serialization
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
@@ -75,7 +78,9 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
distribs(3).distribution.sum shouldBe 0
distribs(4).distribution.sum shouldBe 3
distribs(4).summaryInfo.length shouldBe bins
+    distribs(2).cardEstimate.get shouldBe TextStats(Map("male" -> 1, "female" -> 1))
distribs(2).moments.get shouldBe Moments(2, 5.0, 2.0, 0.0, 2.0)
+    distribs(4).cardEstimate.get shouldBe TextStats(Map("5.0" -> 1, "1.0" -> 1, "3.0" -> 1))
distribs(4).moments.get shouldBe Moments(3, 3.0, 8.0, 0.0, 32.0)
}

@@ -200,7 +205,8 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
it should "marshall to/from json" in {
val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6), Array.empty)
val fd2 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6),
-      Array.empty, Some(Moments(1.0)), FeatureDistributionType.Scoring)
+      Array.empty, Some(Moments(1.0)), Option.empty,
+      FeatureDistributionType.Scoring)
val json = FeatureDistribution.toJson(Array(fd1, fd2))
FeatureDistribution.fromJson(json) match {
case Success(r) => r shouldBe Seq(fd1, fd2)
@@ -210,7 +216,7 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi

it should "marshall to/from json with default vector args" in {
val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6),
-      Array.empty, None, FeatureDistributionType.Scoring)
+      Array.empty, None, None, FeatureDistributionType.Scoring)
val fd2 = FeatureDistribution("A", Some("X"), 20, 20, Array(2, 8, 0, 0, 12), Array.empty)
val json =
"""[{"name":"A","count":10,"nulls":1,"distribution":[1.0,4.0,0.0,0.0,6.0],"type":"Scoring"},
@@ -238,4 +244,22 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
intercept[IllegalArgumentException](fd1.jsDivergence(fd1.copy(name = "boo"))) should have message
"requirement failed: Name must match to compare or combine feature distributions: A != boo"
}

+  it should "not serialize cardEstimate field" in {
+    val cardEstimate = "cardEstimate"
+    val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6),
+      Array.empty, Some(Moments(1.0)), Some(TextStats(Map("foo" -> 1, "bar" -> 2))),
+      FeatureDistributionType.Scoring)
+    val featureDistributions = Seq(fd1, fd1.copy(cardEstimate = None))
+
+    FeatureDistribution.toJson(featureDistributions) shouldNot include (cardEstimate)
+
+    // deserialization from json with and without cardEstimate works
+    val jsonWithCardEstimate = Serialization.write(featureDistributions)(DefaultFormats +
+      EnumEntrySerializer.json4s[FeatureDistributionType](FeatureDistributionType))
+    jsonWithCardEstimate should fullyMatch regex Seq(cardEstimate).mkString(".*", ".*", ".*")
+    jsonWithCardEstimate shouldNot fullyMatch regex Seq.fill(2)(cardEstimate).mkString(".*", ".*", ".*")
+
+    FeatureDistribution.fromJson(jsonWithCardEstimate) shouldBe Success(featureDistributions)
+  }
}
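The behavior this new test pins down comes straight from json4s's FieldSerializer: registering FieldSerializer.ignore for a field drops it on write, and reads succeed with or without the field because the constructor default fills the gap. Below is a minimal standalone sketch of the same pattern, using a hypothetical Stats class in place of FeatureDistribution:

import org.json4s.{DefaultFormats, FieldSerializer, Formats}
import org.json4s.jackson.Serialization

// Hypothetical stand-in for FeatureDistribution; `estimate` plays the role
// of cardEstimate and should never reach the JSON output.
case class Stats(name: String, count: Long, estimate: Option[Map[String, Int]] = None)

object IgnoreFieldSketch extends App {
  implicit val formats: Formats = DefaultFormats +
    FieldSerializer[Stats](FieldSerializer.ignore("estimate"))

  val json = Serialization.write(Stats("A", 10, Some(Map("foo" -> 1))))
  println(json) // {"name":"A","count":10} -- the ignored field is dropped

  // Reading back works on payloads with or without the field, since the
  // constructor supplies the default None.
  println(Serialization.read[Stats](json)) // Stats(A,10,None)
}

Ignoring the field at the Formats level keeps the case class API unchanged while keeping the potentially large histogram out of any persisted JSON, which is the point of the change.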
FiltersTestData.scala
@@ -47,16 +47,16 @@ trait FiltersTestData {

protected val scoreSummaries = Seq(
FeatureDistribution("A", None, 10, 8, Array(1, 4, 0, 0, 6), Array.empty,
-      None, FeatureDistributionType.Scoring),
+      None, None, FeatureDistributionType.Scoring),
FeatureDistribution("B", None, 20, 20, Array(2, 8, 0, 0, 12), Array.empty,
-      None, FeatureDistributionType.Scoring),
+      None, None, FeatureDistributionType.Scoring),
FeatureDistribution("C", Some("1"), 10, 1, Array(0, 0, 10, 10, 0),
-      Array.empty, None, FeatureDistributionType.Scoring),
+      Array.empty, None, None, FeatureDistributionType.Scoring),
FeatureDistribution("C", Some("2"), 20, 19, Array(2, 8, 0, 0, 12),
-      Array.empty, None, FeatureDistributionType.Scoring),
+      Array.empty, None, None, FeatureDistributionType.Scoring),
FeatureDistribution("D", Some("1"), 0, 0, Array(0, 0, 0, 0, 0), Array.empty,
-      None, FeatureDistributionType.Scoring),
+      None, None, FeatureDistributionType.Scoring),
FeatureDistribution("D", Some("2"), 0, 0, Array(0, 0, 0, 0, 0), Array.empty,
-      None, FeatureDistributionType.Scoring)
+      None, None, FeatureDistributionType.Scoring)
)
}