Skip to content

Commit

Permalink
Aggregate LOCOs of DateToUnitCircleTransformer. (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanmitra authored and leahmcguire committed Jul 11, 2019
1 parent 28eac0c commit 87aca8d
Show file tree
Hide file tree
Showing 2 changed files with 317 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,48 @@

package com.salesforce.op.stages.impl.insights

import com.salesforce.op.UID
import com.salesforce.op.{FeatureInsights, UID}
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.stages.impl.feature.{DateToUnitCircle, TimePeriod}
import com.salesforce.op.stages.impl.selector.SelectedModel
import com.salesforce.op.stages.sparkwrappers.specific.OpPredictorWrapperModel
import com.salesforce.op.stages.sparkwrappers.specific.SparkModelConverter._
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.{OpVectorColumnHistory, OpVectorMetadata}
import enumeratum.{Enum, EnumEntry}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Model
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.ml.param.{IntParam, Param, Params}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import scala.reflect.runtime.universe._


/**
 * Shared Spark [[Params]] for record-level LOCO insights: how many insights to keep
 * for each record (topK) and the strategy used to select them (topKStrategy).
 * Mixed into [[RecordInsightsLOCO]] so the params can be reused/configured uniformly.
 */
trait RecordInsightsLOCOParams extends Params {

// Number of top insights retained for each scored record.
final val topK = new IntParam(
parent = this, name = "topK",
doc = "Number of insights to keep for each record"
)
def setTopK(value: Int): this.type = set(topK, value)
def getTopK: Int = $(topK)

// Stored as the TopKStrategy entryName string (Spark Params require simple types);
// use the typed setter/getter below rather than setting the raw string directly.
final val topKStrategy = new Param[String](parent = this, name = "topKStrategy",
doc = "Whether returning topK based on absolute value or topK positives and negatives. For MultiClassification," +
" the value is from the predicted class (i.e. the class having the highest probability)"
)
def setTopKStrategy(strategy: TopKStrategy): this.type = set(topKStrategy, strategy.entryName)
def getTopKStrategy: TopKStrategy = TopKStrategy.withName($(topKStrategy))

// Defaults: keep 20 insights, ranked by absolute value.
setDefault(
topK -> 20,
topKStrategy -> TopKStrategy.Abs.entryName
)
}

/**
* Creates record level insights for model predictions. Takes the model to explain as a constructor argument.
* The input feature is the feature vector fed into the model.
Expand All @@ -65,24 +91,8 @@ class RecordInsightsLOCO[T <: Model[T]]
(
val model: T,
uid: String = UID[RecordInsightsLOCO[_]]
) extends UnaryTransformer[OPVector, TextMap](operationName = "recordInsightsLOCO", uid = uid) {

final val topK = new IntParam(
parent = this, name = "topK",
doc = "Number of insights to keep for each record"
)

def setTopK(value: Int): this.type = set(topK, value)
def getTopK: Int = $(topK)
setDefault(topK -> 20)

final val topKStrategy = new Param[String](parent = this, name = "topKStrategy",
doc = "Whether returning topK based on absolute value or topK positives and negatives. For MultiClassification," +
" the value is from the predicted class (i.e. the class having the highest probability)"
)
def setTopKStrategy(strategy: TopKStrategy): this.type = set(topKStrategy, strategy.entryName)
def getTopKStrategy: TopKStrategy = TopKStrategy.withName($(topKStrategy))
setDefault(topKStrategy, TopKStrategy.Abs.entryName)
) extends UnaryTransformer[OPVector, TextMap](operationName = "recordInsightsLOCO", uid = uid)
with RecordInsightsLOCOParams {

private val modelApply = model match {
case m: SelectedModel => m.transformFn
Expand All @@ -100,10 +110,23 @@ class RecordInsightsLOCO[T <: Model[T]]
Set(FeatureType.typeName[Text], FeatureType.typeName[TextArea], FeatureType.typeName[TextList])
private val textMapTypes =
Set(FeatureType.typeName[TextMap], FeatureType.typeName[TextAreaMap])
private val dateTypes =
Set(FeatureType.typeName[Date], FeatureType.typeName[DateTime])
private val dateMapTypes =
Set(FeatureType.typeName[DateMap], FeatureType.typeName[DateTimeMap])

// Indices of features derived from Text(Map)Vectorizer
private lazy val textFeatureIndices = histories
.filter(_.parentFeatureType.exists((textTypes ++ textMapTypes).contains))
private lazy val textFeatureIndices = getIndicesOfFeatureType(textTypes ++ textMapTypes)

// Indices of features derived from Date(Map)Vectorizer
private lazy val dateFeatureIndices = getIndicesOfFeatureType(dateTypes ++ dateMapTypes)

/**
 * Finds the vector column indices whose parent feature type matches one of the
 * given feature type names.
 *
 * @param types fully qualified feature type names to match against
 * @return distinct, ascending column indices
 */
private def getIndicesOfFeatureType(types: Set[String]): Seq[Int] =
  histories
    .collect { case h if h.parentFeatureType.exists(types.contains) => h.index }
    .distinct
    .sorted

Expand All @@ -127,6 +150,19 @@ class RecordInsightsLOCO[T <: Model[T]]
left.zipAll(right, 0.0, 0.0).map { case (l, r) => l + r }
}

/**
 * Attempts to parse a column metadata descriptorValue such as "y_DayOfWeek" or
 * "x_DayOfWeek" into the [[TimePeriod]] named by its final underscore-delimited
 * segment (e.g. DayOfWeek), matching case-insensitively.
 *
 * @param descriptorValue underscore-delimited descriptor string
 * @return the matching time period, if the trailing segment names one
 */
private def convertToTimePeriod(descriptorValue: String): Option[TimePeriod] = {
  val segments = descriptorValue.split("_")
  segments.lastOption.flatMap(TimePeriod.withNameInsensitiveOption)
}

/**
 * Derives the raw feature name for a vector column: the first parent feature origin,
 * suffixed with "_&lt;grouping&gt;" when the column history carries a grouping.
 *
 * @param history metadata history of the vector column
 * @return the raw feature name, if a parent feature origin exists
 */
private def getRawFeatureName(history: OpVectorColumnHistory): Option[String] =
  history.parentFeatureOrigins.headOption.map { origin =>
    history.grouping.fold(origin)(g => origin + "_" + g)
  }

private def returnTopPosNeg
(
featureArray: Array[(Int, Double)],
Expand All @@ -143,26 +179,23 @@ class RecordInsightsLOCO[T <: Model[T]]
val diffToExamine = computeDiffs(i, oldInd, oldVal, featureArray, featureSize, baseScore)
val history = histories(oldInd)

// Let's check the indicator value and descriptor value
// If those values are empty, the field is likely to be a derived text feature (hashing tf output)
if (textFeatureIndices.contains(oldInd) && history.indicatorValue.isEmpty && history.descriptorValue.isEmpty) {
// Name of the field
val rawName = history.parentFeatureType match {
case s if s.exists(textMapTypes.contains) => {
val grouping = history.grouping
history.parentFeatureOrigins.headOption.map(_ + "_" + grouping.getOrElse(""))
history match {
// If indicator value and descriptor value of a derived text feature are empty, then it is likely
// to be a hashing tf output. We aggregate such features for each (rawFeatureName).
case h if h.indicatorValue.isEmpty && h.descriptorValue.isEmpty && textFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val (indices, array) = aggregationMap.getOrElse(name, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(name, (indices :+ i, sumArrays(array, diffToExamine)))
}
// If the descriptor value of a derived date feature exists, then it is likely to be
// from unit circle transformer. We aggregate such features for each (rawFeatureName, timePeriod).
case h if h.descriptorValue.isDefined && dateFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val key = name + h.descriptorValue.flatMap(convertToTimePeriod).map(p => "_" + p.entryName).getOrElse("")
val (indices, array) = aggregationMap.getOrElse(key, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(key, (indices :+ i, sumArrays(array, diffToExamine)))
}
case s if s.exists(textTypes.contains) => history.parentFeatureOrigins.headOption
case s => throw new Error(s"type should be Text or TextMap, here ${s.mkString(",")}")
}
// Update the aggregation map
for {name <- rawName} {
val key = name
val (indices, array) = aggregationMap.getOrElse(key, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(key, (indices :+ i, sumArrays(array, diffToExamine)))
}
} else {
minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
case _ => minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
}
}

Expand All @@ -185,8 +218,8 @@ class RecordInsightsLOCO[T <: Model[T]]
val featuresSparse = features.value.toSparse
val res = ArrayBuffer.empty[(Int, Double)]
featuresSparse.foreachActive((i, v) => res += i -> v)
// Besides non 0 values, we want to check the text features as well
textFeatureIndices.foreach(i => if (!featuresSparse.indices.contains(i)) res += i -> 0.0)
// Besides non 0 values, we want to check the text/date features as well
(textFeatureIndices ++ dateFeatureIndices).foreach(i => if (!featuresSparse.indices.contains(i)) res += i -> 0.0)
val featureArray = res.toArray
val featureSize = featuresSparse.size

Expand Down
Loading

0 comments on commit 87aca8d

Please sign in to comment.