diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCO.scala b/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCO.scala
index 7bcf02f969..76241757c7 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCO.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCO.scala
@@ -64,9 +64,21 @@ trait RecordInsightsLOCOParams extends Params {
   def setTopKStrategy(strategy: TopKStrategy): this.type = set(topKStrategy, strategy.entryName)
   def getTopKStrategy: TopKStrategy = TopKStrategy.withName($(topKStrategy))
 
+  final val vectorAggregationStrategy = new Param[String](parent = this, name = "vectorAggregationStrategy",
+    doc = "Aggregate text/date vector by " +
+      "1. LeaveOutVector strategy - calculate the LOCO by leaving out the entire vector or " +
+      "2. Avg strategy - calculate the LOCO for each column of the vector and then average all the LOCOs."
+  )
+  def setVectorAggregationStrategy(strategy: VectorAggregationStrategy): this.type =
+    set(vectorAggregationStrategy, strategy.entryName)
+  def getVectorAggregationStrategy: VectorAggregationStrategy = VectorAggregationStrategy.withName(
+    $(vectorAggregationStrategy))
+
+
   setDefault(
     topK -> 20,
-    topKStrategy -> TopKStrategy.Abs.entryName
+    topKStrategy -> TopKStrategy.Abs.entryName,
+    vectorAggregationStrategy -> VectorAggregationStrategy.Avg.entryName
   )
 }
@@ -104,30 +116,33 @@ class RecordInsightsLOCO[T <: Model[T]]
   /**
    * These are the name of the types we want to perform an aggregation of the LOCO results over derived features
    */
-  private val textTypes =
-    Set(FeatureType.typeName[Text], FeatureType.typeName[TextArea], FeatureType.typeName[TextList])
-  private val textMapTypes =
-    Set(FeatureType.typeName[TextMap], FeatureType.typeName[TextAreaMap])
-  private val dateTypes =
-    Set(FeatureType.typeName[Date], FeatureType.typeName[DateTime])
-  private val dateMapTypes =
-    Set(FeatureType.typeName[DateMap], FeatureType.typeName[DateTimeMap])
-
-  // Indices of features derived from hashed Text(Map)Vectorizer
-  private lazy val textFeatureIndices: Seq[Int] = getIndicesOfFeatureType(textTypes ++ textMapTypes,
-    h => h.indicatorValue.isEmpty && h.descriptorValue.isEmpty)
-
-  // Indices of features derived from unit Date(Map)Vectorizer
-  private lazy val dateFeatureIndices = getIndicesOfFeatureType(dateTypes ++ dateMapTypes, _.descriptorValue.isDefined)
+  private val textTypes = Set(FeatureType.typeName[Text], FeatureType.typeName[TextArea],
+    FeatureType.typeName[TextList], FeatureType.typeName[TextMap], FeatureType.typeName[TextAreaMap])
+  private val dateTypes = Set(FeatureType.typeName[Date], FeatureType.typeName[DateTime],
+    FeatureType.typeName[DateMap], FeatureType.typeName[DateTimeMap])
+
+  // Map of raw feature name to the number of its derived features that need to be aggregated
+  // for the above textTypes and dateTypes.
+  private lazy val aggFeaturesSize: Map[String, Int] = histories
+    .filter(h => isTextFeature(h) || isDateFeature(h))
+    .groupBy { h => getRawFeatureName(h).get }
+    .mapValues(_.length)
+
+  /**
+   * Return whether this feature is derived from a hashed Text(Map)Vectorizer
+   * @return Boolean
+   */
+  private def isTextFeature(h: OpVectorColumnHistory): Boolean = {
+    h.parentFeatureType.exists(textTypes.contains) && h.indicatorValue.isEmpty && h.descriptorValue.isEmpty
+  }
 
   /**
-   * Return the indices of features derived from given types.
-   * @return Seq[Int]
+   * Return whether this feature is derived from a unit circle Date(Map)Vectorizer
+   * @return Boolean
    */
-  private def getIndicesOfFeatureType(types: Set[String], predicate: OpVectorColumnHistory => Boolean): Seq[Int] =
-    histories.collect {
-      case h if h.parentFeatureType.exists(types.contains) && predicate(h) => h.index
-    }.distinct.sorted
+  private def isDateFeature(h: OpVectorColumnHistory): Boolean = {
+    h.parentFeatureType.exists(dateTypes.contains) && h.descriptorValue.isDefined
+  }
 
   private def computeDiff
   (
@@ -159,7 +174,7 @@ class RecordInsightsLOCO[T <: Model[T]]
     // TODO : Filter by parentStage (DateToUnitCircleTransformer & DateToUnitCircleVectorizer) once the bug in the
     // feature history after multiple transformations has been fixed
     name.map { n =>
-      val timePeriodName = if ((dateTypes ++ dateMapTypes).exists(history.parentFeatureType.contains)) {
+      val timePeriodName = if (dateTypes.exists(history.parentFeatureType.contains)) {
         history.descriptorValue
           .flatMap(convertToTimePeriod)
           .map(p => "_" + p.entryName)
@@ -168,82 +183,72 @@ class RecordInsightsLOCO[T <: Model[T]]
     }
   }
 
+  private def aggregateDiffs(
+    featureSparse: SparseVector,
+    aggIndices: Array[(Int, Int)],
+    strategy: VectorAggregationStrategy,
+    baseScore: Array[Double],
+    featureSize: Int
+  ): Array[Double] = {
+    strategy match {
+      case VectorAggregationStrategy.Avg =>
+        aggIndices
+          .map { case (i, oldInd) => computeDiff(featureSparse.copy.updated(i, oldInd, 0.0), baseScore) }
+          .foldLeft(Array.empty[Double])(sumArrays)
+          .map(_ / featureSize)
+
+      case VectorAggregationStrategy.LeaveOutVector =>
+        val copyFeatureSparse = featureSparse.copy
+        aggIndices.foreach { case (i, oldInd) => copyFeatureSparse.updated(i, oldInd, 0.0) }
+        computeDiff(copyFeatureSparse, baseScore)
+    }
+  }
+
   private def returnTopPosNeg
   (
     featureSparse: SparseVector,
-    zeroCountByFeature: Map[String, Int],
-    featureSize: Int,
     baseScore: Array[Double],
     k: Int,
    indexToExamine: Int
   ): Seq[LOCOValue] = {
     val minMaxHeap = new MinMaxHeap(k)
-    val aggregationMap = mutable.Map.empty[String, (Array[Int], Array[Double])]
-
-    agggregateDiffs(featureSparse, indexToExamine, minMaxHeap, aggregationMap,
-      baseScore)
-
-    // Aggregation map contains aggregation of Unit Circle Dates and Hashed Text Features
-    // Adding LOCO results from aggregation map into heaps
-    for {(name, (indices, ar)) <- aggregationMap} {
-      // The index here is arbitrary
-      val (i, n) = (indices.head, indices.length)
-      val zeroCounts = zeroCountByFeature.get(name).getOrElse(0)
-      val diffToExamine = ar.map(_ / (n + zeroCounts))
-      minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
-    }
-    minMaxHeap.dequeueAll
-  }
+    // Map of raw feature name to its array of (sparse vector index, original column index) pairs to aggregate
+    val aggActiveIndices = mutable.Map.empty[String, Array[(Int, Int)]]
 
-  private def agggregateDiffs(
-    featureVec: SparseVector,
-    indexToExamine: Int,
-    minMaxHeap: MinMaxHeap,
-    aggregationMap: mutable.Map[String, (Array[Int], Array[Double])],
-    baseScore: Array[Double]
-  ): Unit = {
-    computeDiffs(featureVec, baseScore).foreach { case (i, oldInd, diffToExamine) =>
+    (0 until featureSparse.size, featureSparse.indices).zipped.foreach { case (i: Int, oldInd: Int) =>
       val history = histories(oldInd)
       history match {
-        // If indicator value and descriptor value of a derived text feature are empty, then it is
-        // a hashing tf output. We aggregate such features for each (rawFeatureName).
-        case h if (textFeatureIndices ++ dateFeatureIndices).contains(oldInd) =>
+        case h if isTextFeature(h) || isDateFeature(h) => {
           for {name <- getRawFeatureName(h)} {
-            val (indices, array) = aggregationMap.getOrElse(name, (Array.empty[Int], Array.empty[Double]))
-            aggregationMap.update(name, (indices :+ i, sumArrays(array, diffToExamine)))
+            val indices = aggActiveIndices.getOrElse(name, Array.empty[(Int, Int)])
+            aggActiveIndices.update(name, indices :+ (i, oldInd))
           }
-        case _ => minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
+        }
+        case _ => {
+          val diffToExamine = computeDiff(featureSparse.copy.updated(i, oldInd, 0.0), baseScore)
+          minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
+        }
       }
     }
-  }
 
-  private def computeDiffs(
-    featureVec: SparseVector,
-    baseScore: Array[Double]
-  ) = {
-    (0 until featureVec.size, featureVec.indices).zipped.map { case (i, oldInd) =>
-      (i, oldInd, computeDiff(featureVec.copy.updated(i, oldInd, 0.0), baseScore))
+    // Aggregate the active indices of each text and date feature based on the vector aggregation strategy.
+    aggActiveIndices.foreach {
+      case (name, aggIndices) =>
+        val diffToExamine = aggregateDiffs(featureSparse, aggIndices,
+          getVectorAggregationStrategy, baseScore, aggFeaturesSize(name))
+        minMaxHeap enqueue LOCOValue(aggIndices.head._1, diffToExamine(indexToExamine), diffToExamine)
     }
+
+    minMaxHeap.dequeueAll
   }
 
   override def transformFn: OPVector => TextMap = features => {
     val baseResult = modelApply(labelDummy, features)
     val baseScore = baseResult.score
-    val featureSize = features.value.size
 
     // TODO: sparse implementation only works if changing values to zero - use dense vector to test effect of zeros
     val featuresSparse = features.value.toSparse
-    val featureIndexSet = featuresSparse.indices.toSet
-
-    // Besides non 0 values, we want to check the text/date features as well
-    val zeroValIndices = (textFeatureIndices ++ dateFeatureIndices)
-      .filterNot(featureIndexSet.contains)
-
-    // Count zeros by feature name
-    val zeroCountByFeature = zeroValIndices
-      .groupBy(i => getRawFeatureName(histories(i)).get)
-      .mapValues(_.length).view.toMap
 
     val k = $(topK)
     // Index where to examine the difference in the prediction vector
@@ -254,14 +259,14 @@ class RecordInsightsLOCO[T <: Model[T]]
      // For MultiClassification, the value is from the predicted class(i.e. the class having the highest probability)
       case n if n > 2 => baseResult.prediction.toInt
     }
-    val topPosNeg = returnTopPosNeg(featuresSparse, zeroCountByFeature, featureSize, baseScore, k, indexToExamine)
+    val topPosNeg = returnTopPosNeg(featuresSparse, baseScore, k, indexToExamine)
     val top = getTopKStrategy match {
       case TopKStrategy.Abs => topPosNeg.sortBy { case LOCOValue(_, v, _) => -math.abs(v) }.take(k)
       // Take top K positive and top K negative LOCOs, hence 2 * K
       case TopKStrategy.PositiveNegative => topPosNeg.sortBy { case LOCOValue(_, v, _) => -v }.take(2 * k)
     }
 
-    val allIndices = featuresSparse.indices ++ zeroValIndices
+    val allIndices = featuresSparse.indices
     top.map { case LOCOValue(i, _, diffs) =>
       RecordInsightsParser.insightToText(featureInfo(allIndices(i)), diffs)
     }.toMap.toTextMap
@@ -329,3 +334,14 @@ object TopKStrategy extends Enum[TopKStrategy] {
   case object Abs extends TopKStrategy("abs")
   case object PositiveNegative extends TopKStrategy("positive and negative")
 }
+
+
+sealed abstract class VectorAggregationStrategy(val name: String) extends EnumEntry with Serializable
+
+object VectorAggregationStrategy extends Enum[VectorAggregationStrategy] {
+  val values = findValues
+  case object LeaveOutVector extends
+    VectorAggregationStrategy("calculate the LOCO by leaving out the entire vector")
+  case object Avg extends
+    VectorAggregationStrategy("calculate the LOCO for each column of the vector and then average all the LOCOs")
+}
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala
index b93eb7b7f5..c6178bb584 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala
@@ -44,8 +44,7 @@ import com.salesforce.op.testkit.{RandomIntegral, RandomMap, RandomReal, RandomT
 import com.salesforce.op.utils.spark.RichDataset._
 import com.salesforce.op.utils.spark.{OpVectorColumnHistory, OpVectorColumnMetadata, OpVectorMetadata}
 import com.salesforce.op.{FeatureHistory, OpWorkflow, _}
-import org.apache.spark.ml.PredictionModel
-import org.apache.spark.ml.classification.LogisticRegressionModel
+import org.apache.spark.ml.Model
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.regression.LinearRegressionModel
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -54,7 +53,7 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, Encoder, Row}
 import org.junit.runner.RunWith
 import org.scalatest.junit.JUnitRunner
-import org.scalatest.{FlatSpec, FunSpec, Suite}
+import org.scalatest.{FunSpec, Suite}
 
 @RunWith(classOf[JUnitRunner])
@@ -293,101 +292,60 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
     }
   }
 
-  it("should aggregate values for text and textMap derived features") {
-    val testData = generateTestTextData
-    withClue("TextArea can have two null indicator values") {
-      testData.actualRecordInsights.map(p => assert(p.size == 7 || p.size == 8))
-    }
-    withClue("SmartTextVectorizer detects country feature as a PickList, hence no " +
-      "aggregation required for LOCO on this field.") {
-      testData.actualRecordInsights.foreach { p =>
-        assert(p.keys.exists(r => r.parentFeatureOrigins == Seq(countryFeatureName) && r.indicatorValue.isDefined))
-      }
-    }
+  for {strategy <- VectorAggregationStrategy.values} {
+    it(s"aggregate values for text and textMap derived features when strategy=$strategy") {
+      val (df, featureVector, label) = generateTestTextData
+      val model = new OpLogisticRegression().setInput(label, featureVector).fit(df)
+      val actualInsights = generateRecordInsights(model, df, featureVector, strategy)
 
-    assertLOCOSum(testData.actualRecordInsights)
-    assertAggregatedText(textFeatureName)
-    assertAggregatedTextMap(textMapFeatureName, "k0")
-    assertAggregatedTextMap(textMapFeatureName, "k1")
-    assertAggregatedText(textAreaFeatureName)
-    assertAggregatedTextMap(textAreaMapFeatureName, "k0")
-    assertAggregatedTextMap(textAreaMapFeatureName, "k1")
-
-
-    /**
-     * Compare the aggregation made by RecordInsightsLOCO on a text field to one made manually
-     *
-     * @param textFeatureName Text Field Name
-     */
-    def assertAggregatedText(textFeatureName: String): Unit = {
-      withClue(s"Aggregate all the derived hashing tf features of rawFeature - $textFeatureName.") {
-        val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(textFeatureName) &&
-          history.indicatorValue.isEmpty && history.descriptorValue.isEmpty
-        assertAggregatedWithPredicate(predicate, testData)
+      withClue("TextArea can have two null indicator values") {
+        actualInsights.map(p => assert(p.size == 7 || p.size == 8))
       }
-    }
-
-    /**
-     * Compare the aggregation made by RecordInsightsLOCO to one made manually
-     *
-     * @param textMapFeatureName Text Map Field Name
-     */
-    def assertAggregatedTextMap(textMapFeatureName: String, keyName: String): Unit = {
-      withClue(s"Aggregate all the derived hashing tf of rawMapFeature - $textMapFeatureName for key - $keyName") {
-        val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(textMapFeatureName) &&
-          history.grouping == Option(keyName) && history.indicatorValue.isEmpty && history.descriptorValue.isEmpty
-        assertAggregatedWithPredicate(predicate, testData)
+      withClue("SmartTextVectorizer detects country feature as a PickList, hence no " +
+        "aggregation required for LOCO on this field.") {
+        actualInsights.foreach { p =>
+          assert(p.keys.exists(r => r.parentFeatureOrigins == Seq(countryFeatureName)
+            && r.indicatorValue.isDefined))
+        }
       }
+
+      assertLOCOSum(actualInsights)
+      assertAggregatedText(textFeatureName, strategy, model, df, featureVector, label, actualInsights)
+      assertAggregatedText(textAreaFeatureName, strategy, model, df, featureVector, label, actualInsights)
+      assertAggregatedTextMap(textMapFeatureName, "k0", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedTextMap(textMapFeatureName, "k1", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedTextMap(textAreaMapFeatureName, "k0", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedTextMap(textAreaMapFeatureName, "k1", strategy, model, df, featureVector, label,
+        actualInsights)
+    }
   }
 
-  it("should aggregate values for date, datetime, dateMap and dateTimeMap derived features") {
-    val testData = generateTestDateData
-
-    assertLOCOSum(testData.actualRecordInsights)
-    assertAggregatedDate(dateFeatureName)
-    assertAggregatedDate(dateTimeFeatureName)
-    assertAggregatedDateMap(dateMapFeatureName, "k0")
-    assertAggregatedDateMap(dateMapFeatureName, "k1")
-    assertAggregatedDateMap(dateTimeMapFeatureName, "k0")
-    assertAggregatedDateMap(dateTimeMapFeatureName, "k1")
-
-    /**
-     * Compare the aggregation made by RecordInsightsLOCO on a Date/DateTime field to one made manually
-     *
-     * @param dateFeatureName Date/DateTime Field
-     */
-    def assertAggregatedDate(dateFeatureName: String): Unit = {
-      for {timePeriod <- TransmogrifierDefaults.CircularDateRepresentations} {
-        withClue(s"Aggregate x_$timePeriod and y_$timePeriod of rawFeature - $dateFeatureName.") {
-          val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(dateFeatureName) &&
-            history.descriptorValue.isDefined &&
-            history.descriptorValue.get.split("_").last == timePeriod.entryName
-          assertAggregatedWithPredicate(predicate, testData)
-        }
-      }
-    }
-
-    /**
-     * Compare the aggregation made by RecordInsightsLOCO on a DateMap/DateTimeMap field to one made manually
-     *
-     * @param dateMapFeatureName DateMap/DateTimeMap Field
-     */
-    def assertAggregatedDateMap(dateMapFeatureName: String, keyName: String): Unit = {
-      for {timePeriod <- TransmogrifierDefaults.CircularDateRepresentations} {
-        withClue(s"Aggregate x_$timePeriod and y_$timePeriod of rawMapFeature - $dateMapFeatureName " +
-          s"with key as $keyName.") {
-          val predicate = (history: OpVectorColumnHistory) =>
-            history.parentFeatureOrigins == Seq(dateMapFeatureName) &&
-            history.grouping == Option(keyName) && history.descriptorValue.isDefined &&
-            history.descriptorValue.get.split("_").last == timePeriod.entryName
-          assertAggregatedWithPredicate(predicate, testData)
-        }
-      }
+  for {strategy <- VectorAggregationStrategy.values} {
+    it(s"aggregate values for date, datetime, dateMap and dateTimeMap derived features when strategy=$strategy") {
+      val (df, featureVector, label) = generateTestDateData
+      val model = new OpLogisticRegression().setInput(label, featureVector).fit(df)
+      val actualInsights = generateRecordInsights(model, df, featureVector, strategy, topK = 40)
+
+      assertLOCOSum(actualInsights)
+      assertAggregatedDate(dateFeatureName, strategy, model, df, featureVector, label, actualInsights)
+      assertAggregatedDate(dateTimeFeatureName, strategy, model, df, featureVector, label, actualInsights)
+      assertAggregatedDateMap(dateMapFeatureName, "k0", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedDateMap(dateMapFeatureName, "k1", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedDateMap(dateTimeMapFeatureName, "k0", strategy, model, df, featureVector, label,
+        actualInsights)
+      assertAggregatedDateMap(dateTimeMapFeatureName, "k1", strategy, model, df, featureVector, label,
+        actualInsights)
+    }
   }
 
+
   private def addMetaData(df: DataFrame, fieldName: String, size: Int): DataFrame = {
     val columns = (0 until size).map(_.toString).map(i => new OpVectorColumnMetadata(Seq(i), Seq(i), Some(i), Some(i)))
     val hist = (0 until size).map(_.toString).map(i => i -> FeatureHistory(Seq(s"a_$i"), Seq(s"b_$i"))).toMap
@@ -405,6 +363,93 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
     }
   }
 
+  /**
+   * Compare the aggregation made by RecordInsightsLOCO on a text field to one made manually
+   *
+   * @param textFeatureName Text Field Name
+   */
+  def assertAggregatedText(textFeatureName: String,
+    strategy: VectorAggregationStrategy,
+    model: OpPredictorWrapperModel[_],
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    label: FeatureLike[RealNN],
+    actualInsights: Array[Map[OpVectorColumnHistory, Insights]]
+  ): Unit = {
+    withClue(s"Aggregate all the derived hashing tf features of rawFeature - $textFeatureName.") {
+      val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(textFeatureName) &&
+        history.indicatorValue.isEmpty && history.descriptorValue.isEmpty
+      assertAggregatedWithPredicate(predicate, strategy, model, df, featureVector, label, actualInsights)
+    }
+  }
+
+  /**
+   * Compare the aggregation made by RecordInsightsLOCO to one made manually
+   *
+   * @param textMapFeatureName Text Map Field Name
+   */
+  def assertAggregatedTextMap(textMapFeatureName: String, keyName: String,
+    strategy: VectorAggregationStrategy,
+    model: OpPredictorWrapperModel[_],
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    label: FeatureLike[RealNN],
+    actualInsights: Array[Map[OpVectorColumnHistory, Insights]]
+  ): Unit = {
+    withClue(s"Aggregate all the derived hashing tf of rawMapFeature - $textMapFeatureName for key - $keyName") {
+      val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(textMapFeatureName) &&
+        history.grouping == Option(keyName) && history.indicatorValue.isEmpty && history.descriptorValue.isEmpty
+      assertAggregatedWithPredicate(predicate, strategy, model, df, featureVector, label, actualInsights)
+    }
+  }
+
+  /**
+   * Compare the aggregation made by RecordInsightsLOCO on a Date/DateTime field to one made manually
+   *
+   * @param dateFeatureName Date/DateTime Field
+   */
+  def assertAggregatedDate(dateFeatureName: String,
+    strategy: VectorAggregationStrategy,
+    model: OpPredictorWrapperModel[_],
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    label: FeatureLike[RealNN],
+    actualInsights: Array[Map[OpVectorColumnHistory, Insights]]
+  ): Unit = {
+    for {timePeriod <- TransmogrifierDefaults.CircularDateRepresentations} {
+      withClue(s"Aggregate x_$timePeriod and y_$timePeriod of rawFeature - $dateFeatureName.") {
+        val predicate = (history: OpVectorColumnHistory) => history.parentFeatureOrigins == Seq(dateFeatureName) &&
+          history.descriptorValue.isDefined &&
+          history.descriptorValue.get.split("_").last == timePeriod.entryName
+        assertAggregatedWithPredicate(predicate, strategy, model, df, featureVector, label, actualInsights)
+      }
+    }
+  }
+
+  /**
+   * Compare the aggregation made by RecordInsightsLOCO on a DateMap/DateTimeMap field to one made manually
+   *
+   * @param dateMapFeatureName DateMap/DateTimeMap Field
+   */
+  def assertAggregatedDateMap(dateMapFeatureName: String, keyName: String,
+    strategy: VectorAggregationStrategy,
+    model: OpPredictorWrapperModel[_],
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    label: FeatureLike[RealNN],
+    actualInsights: Array[Map[OpVectorColumnHistory, Insights]]
+  ): Unit = {
+    for {timePeriod <- TransmogrifierDefaults.CircularDateRepresentations} {
+      withClue(s"Aggregate x_$timePeriod and y_$timePeriod of rawMapFeature - $dateMapFeatureName " +
+        s"with key as $keyName.") {
+        val predicate = (history: OpVectorColumnHistory) =>
+          history.parentFeatureOrigins == Seq(dateMapFeatureName) &&
+          history.grouping == Option(keyName) && history.descriptorValue.isDefined &&
+          history.descriptorValue.get.split("_").last == timePeriod.entryName
+        assertAggregatedWithPredicate(predicate, strategy, model, df, featureVector, label, actualInsights)
+      }
+    }
+  }
+
   /**
    * Compare the aggregation made by RecordInsightsLOCO to one made manually
    *
@@ -412,34 +457,46 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
    */
   private def assertAggregatedWithPredicate(
     predicate: OpVectorColumnHistory => Boolean,
-    testData: RecordInsightsTestData[LogisticRegressionModel]
+    strategy: VectorAggregationStrategy,
+    model: OpPredictorWrapperModel[_],
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    label: FeatureLike[RealNN],
+    actualRecordInsights: Array[Map[OpVectorColumnHistory, Insights]]
  ): Unit = {
     implicit val enc: Encoder[(Array[Double], Long)] = ExpressionEncoder()
     implicit val enc2: Encoder[Seq[Double]] = ExpressionEncoder()
 
-    val meta = OpVectorMetadata.apply(testData.featureTransformedDF.schema(testData.featureVector.name))
+    val meta = OpVectorMetadata.apply(df.schema(featureVector.name))
     val indices = meta.getColumnHistory()
       .filter(predicate)
       .map(_.index)
 
-    val expectedLocos = testData.featureTransformedDF.select(testData.label, testData.featureVector).map {
+    val expectedLocos = df.select(label, featureVector).map {
       case Row(l: Double, v: Vector) =>
-        val featureArray = v.toArray
-        val locos = indices.map { i =>
-          val oldVal = v(i)
-          val baseScore = testData.sparkModel.transformFn(l.toRealNN, v.toOPVector).score.toSeq
-          featureArray.update(i, 0.0)
-          val newScore = testData.sparkModel.transformFn(l.toRealNN, featureArray.toOPVector).score.toSeq
-          featureArray.update(i, oldVal)
-          baseScore.zip(newScore).map { case (b, n) => b - n }
+        val featureArray = v.copy.toArray
+        val baseScore = model.transformFn(l.toRealNN, v.toOPVector).score.toSeq
+        strategy match {
+          case VectorAggregationStrategy.Avg =>
+            val locos = indices.map { i =>
+              val oldVal = v(i)
+              featureArray.update(i, 0.0)
+              val newScore = model.transformFn(l.toRealNN, featureArray.toOPVector).score.toSeq
+              featureArray.update(i, oldVal)
+              baseScore.zip(newScore).map { case (b, n) => b - n }
+            }
+            val sumLOCOs = locos.reduce((a1, a2) => a1.zip(a2).map { case (l, r) => l + r })
+            sumLOCOs.map(_ / indices.length)
+          case VectorAggregationStrategy.LeaveOutVector =>
+            indices.foreach { i => featureArray.update(i, 0.0) }
+            val newScore = model.transformFn(l.toRealNN, featureArray.toOPVector).score.toSeq
+            baseScore.zip(newScore).map { case (b, n) => b - n }
         }
-        val sumLOCOs = locos.reduce((a1, a2) => a1.zip(a2).map { case (l, r) => l + r })
-        sumLOCOs.map(_ / indices.length)
     }
 
     val expected = expectedLocos.collect().toSeq.filter(_.head != 0.0)
 
-    val actual = testData.actualRecordInsights
+    val actual = actualRecordInsights
       .flatMap(_.find { case (history, _) => predicate(history) })
       .map(_._2.map(_._2)).toSeq
     val zip = actual.zip(expected)
@@ -449,8 +506,22 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
     }
   }
+
+  private def generateRecordInsights[T <: Model[T]](
+    model: T,
+    df: DataFrame,
+    featureVector: FeatureLike[OPVector],
+    strategy: VectorAggregationStrategy,
+    topK: Int = 20
+  ): Array[Map[OpVectorColumnHistory, Insights]] = {
+    val transformer = new RecordInsightsLOCO(model).setInput(featureVector).setTopK(topK)
+      .setVectorAggregationStrategy(strategy)
+    val insights = transformer.transform(df)
+    insights.collect(transformer.getOutput()).map(i => RecordInsightsParser.parseInsights(i))
+  }
 }
 
+
 trait RecordInsightsTestDataGenerator extends TestSparkContext {
   self: Suite =>
@@ -471,7 +542,7 @@ trait RecordInsightsTestDataGenerator extends TestSparkContext {
   val textAreaFeatureName = "textArea"
   val textAreaMapFeatureName = "textAreaMap"
 
-  def generateTestDateData: RecordInsightsTestData[LogisticRegressionModel] = {
+  def generateTestDateData: (DataFrame, FeatureLike[OPVector], FeatureLike[RealNN]) = {
     val refDate = TransmogrifierDefaults.ReferenceDate.minusMillis(1)
 
     val minStep = 1000000
@@ -515,19 +586,10 @@ trait RecordInsightsTestDataGenerator extends TestSparkContext {
     val featureVector = Seq(dateVector, datetimeVector, dateMapVector, datetimeMapVector).combine()
     val featureTransformedDF = new OpWorkflow().setResultFeatures(featureVector, label).transform(rawData)
 
-    // Train a model
-    val sparkModel = new OpLogisticRegression().setInput(label, featureVector).fit(featureTransformedDF)
-
-    // RecordInsightsLOCO
-    val locoTransformer = new RecordInsightsLOCO(sparkModel).setInput(featureVector).setTopK(40)
-    val locoInsights = locoTransformer.transform(featureTransformedDF)
-    val parsedInsights = locoInsights.collect(locoTransformer.getOutput()).map(i =>
-      RecordInsightsParser.parseInsights(i))
-
-    RecordInsightsTestData(rawData, featureTransformedDF, featureVector, label, sparkModel, parsedInsights)
+    (featureTransformedDF, featureVector, label)
   }
 
-  def generateTestTextData: RecordInsightsTestData[LogisticRegressionModel] = {
+  def generateTestTextData: (DataFrame, FeatureLike[OPVector], FeatureLike[RealNN]) = {
     // Random Text Data
     val textData: Seq[Text] = RandomText.strings(5, 10).withProbabilityOfEmpty(0.3).take(numRows).toList
 
@@ -605,29 +667,10 @@ trait RecordInsightsTestDataGenerator extends TestSparkContext {
 
     // Sanity Checker
     val checker = new SanityChecker().setInput(label, featureVector)
-    val checked = checker.fit(vectorized).transform(vectorized)
+    val checkedDf = checker.fit(vectorized).transform(vectorized)
     val checkedFeatureVector = checker.getOutput()
 
-    // RecordInsightsLOCO
-    val sparkModel = new OpLogisticRegression().setInput(label, checkedFeatureVector).fit(checked)
-
-    val transformer = new RecordInsightsLOCO(sparkModel).setInput(checkedFeatureVector)
-
-    val insights = transformer.transform(checked)
-
-    val parsed = insights.collect(transformer.getOutput()).map(i => RecordInsightsParser.parseInsights(i))
-
-    RecordInsightsTestData(testData, checked, checkedFeatureVector, label, sparkModel, parsed)
+    (checkedDf, checkedFeatureVector, label)
   }
 }
-
-case class RecordInsightsTestData[M <: PredictionModel[Vector, M]]
-(
-  rawDF: DataFrame,
-  featureTransformedDF: DataFrame,
-  featureVector: FeatureLike[OPVector],
-  label: FeatureLike[RealNN],
-  sparkModel: OpPredictorWrapperModel[M],
-  actualRecordInsights: Array[Map[OpVectorColumnHistory, Insights]]
-)
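
Reviewer note (not part of the patch): a minimal usage sketch of the new parameter introduced by this diff, assuming a fitted `model`, a transformed `df`, and its `featureVector` from an existing workflow — placeholder names, mirroring the `generateRecordInsights` helper in the test above:

```scala
import com.salesforce.op.stages.impl.insights.{RecordInsightsLOCO, VectorAggregationStrategy}

// Avg stays the default via setDefault; LeaveOutVector zeroes out all derived
// columns of a text/date vector in one pass instead of averaging per-column LOCOs.
val loco = new RecordInsightsLOCO(model)
  .setInput(featureVector)
  .setTopK(20)
  .setVectorAggregationStrategy(VectorAggregationStrategy.LeaveOutVector)

val insights = loco.transform(df)
```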