Descale feature contribution for Linear Regression & Logistic Regression #345

Merged: 43 commits, Jul 25, 2019
Changes from 1 commit
fe8e6df
move PR to a branch on Tmog
TuanNguyen27 Jun 25, 2019
4ad6204
add condition to descale only when it is Linear / Logistic regression…
TuanNguyen27 Jun 25, 2019
1e2bc81
fix modelinsighttest
TuanNguyen27 Jun 25, 2019
5cbe132
only compute descaled contrib if a model is present & fits our criteria
TuanNguyen27 Jun 25, 2019
673dad8
fix test failure with a check for empty list of feature contribution
TuanNguyen27 Jun 25, 2019
4c44263
addressing comments
TuanNguyen27 Jun 27, 2019
d3fa01b
more comment addressing
TuanNguyen27 Jun 27, 2019
7a882f2
test in progress, still broken
TuanNguyen27 Jun 27, 2019
8ba36c2
Merge branch 'master' into tn/descaleLR
tovbinm Jun 27, 2019
e4da5ab
seems to be working
TuanNguyen27 Jun 27, 2019
0767e84
first version of test
TuanNguyen27 Jun 28, 2019
992db1a
Merge branch 'tn/descaleLR' of https://github.com/salesforce/Transmog…
TuanNguyen27 Jun 28, 2019
7eaa209
fix scala style
TuanNguyen27 Jun 28, 2019
2c5d2f7
Merge branch 'master' into tn/descaleLR
tovbinm Jul 2, 2019
99236f4
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 3, 2019
e1bb482
addressing comments
TuanNguyen27 Jul 8, 2019
00fb0d6
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 8, 2019
27a2449
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 9, 2019
cb798db
change log to warn
TuanNguyen27 Jul 9, 2019
3a541ed
Merge branch 'master' into tn/descaleLR
leahmcguire Jul 11, 2019
083448a
Merge branch 'master' into tn/descaleLR
leahmcguire Jul 11, 2019
606b6e1
fix an error in calculating standard deviation for discrete distribution
TuanNguyen27 Jul 11, 2019
1643635
Merge branch 'tn/descaleLR' of https://github.com/salesforce/Transmog…
TuanNguyen27 Jul 11, 2019
6ca53fc
Merge branch 'master' into tn/descaleLR
leahmcguire Jul 11, 2019
4c432e6
correctly pull out standard deviation for each feature
TuanNguyen27 Jul 11, 2019
4f252bd
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 12, 2019
4be8752
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 12, 2019
0e63491
descale entire contribution vector & clearly separate out between lin…
TuanNguyen27 Jul 12, 2019
e6de82b
Merge branch 'tn/descaleLR' of https://github.com/salesforce/Transmog…
TuanNguyen27 Jul 12, 2019
54d28a1
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 14, 2019
4677c96
fix scala idiom
TuanNguyen27 Jul 15, 2019
fa4221f
Merge branch 'tn/descaleLR' of https://github.com/salesforce/Transmog…
TuanNguyen27 Jul 15, 2019
a5901d8
remove logistic regression pattern matching
TuanNguyen27 Jul 15, 2019
90ff504
add citations for future readability
TuanNguyen27 Jul 15, 2019
a7dea4e
refactor & add test for binary logistic regression case
TuanNguyen27 Jul 18, 2019
35bdfe8
Merge branch 'master' into tn/descaleLR
TuanNguyen27 Jul 18, 2019
c6bae48
remove redundant import
TuanNguyen27 Jul 18, 2019
bc60187
fix scala style
TuanNguyen27 Jul 19, 2019
a6839b2
fix scala style again
TuanNguyen27 Jul 19, 2019
23b2443
Update warning message
TuanNguyen27 Jul 22, 2019
36c8420
update failure threshold so test will pass
TuanNguyen27 Jul 22, 2019
c80cc1a
update test to be ratio instead of absolute difference
TuanNguyen27 Jul 24, 2019
51627e9
small update to set tolerance threshold
TuanNguyen27 Jul 25, 2019
43 changes: 21 additions & 22 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -562,22 +562,23 @@ case object ModelInsights {
val sparkFtrContrib = keptIndex
.map(i => contributions.map(_.applyOrElse(i, (_: Int) => 0.0))).getOrElse(Seq.empty)
val LRStandardization = checkLRStandardization(model).getOrElse(false)
val labelStd = label.distribution.getOrElse(1.0) match {
case Continuous(_, _, _, variance) => math.sqrt(variance)
// for (binary) logistic regression we only need to multiply by feature standard deviation
case Discrete(domain, prob) =>
def computeVariance(domain: Seq[String], prob: Seq[Double]): Double = {
val floatDomain = domain.map(_.toDouble)
val sqfloatDomain = floatDomain.map(math.pow(_, 2))
val weighted = (floatDomain, prob).zipped map (_ * _)
val sqweighted = (sqfloatDomain, prob).zipped map (_ * _)
val mean = weighted.sum
return sqweighted.sum - mean
}
computeVariance(domain, prob)
}
// TODO: throw exception if (labelStd == 0)

val labelStd = label.distribution match {
case Some(Continuous(_, _, _, variance)) => math.sqrt(variance)
case Some(Discrete(domain, prob)) =>
val (weighted, sqweighted) = (domain zip prob).foldLeft((0.0, 0.0)) {
case ((weightSum, sqweightSum), (d, p)) =>
val floatD = d.toDouble
val weight = floatD * p
val sqweight = floatD * weight
(weightSum + weight, sqweightSum + sqweight)
}
math.sqrt(sqweighted - weighted * weighted)
case Some(_) => throw new Exception("Unsupported distribution type for the label")
case None => throw new Exception("Label does not exist, please check your data")
}
if (labelStd == 0) throw new Exception("The standard deviation of the label is zero, " +
"so the coefficients and intercepts of the model will be zeros, training is not needed.")
h.parentFeatureOrigins ->
Insights(
derivedFeatureName = h.columnName,
@@ -671,13 +672,11 @@ case object ModelInsights {
}

private[op] def checkLRStandardization(model: Option[Model[_]]): Option[Boolean] = {
val stage = model.flatMap {
case m: SparkWrapperParams[_] => m.getSparkMlStage()
case _ => None
}
stage.collect {
case m: LogisticRegressionModel => true && m.getStandardization
case m: LinearRegressionModel => true && m.getStandardization
for {
(m: SparkWrapperParams[_]) <- model
stage <- m.getSparkMlStage()
} yield stage match {
case s: LogisticRegressionModel | LinearRegressionModel => s.getStandardization
case _ => false
}
}
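The `Discrete` branch of the diff computes the label's standard deviation from a value/probability table via E[X²] − E[X]² and then takes the square root. A minimal standalone sketch of that computation (a hypothetical helper mirroring the `Discrete(domain, prob)` shape, not code from this PR):

```scala
object DiscreteLabelStd {
  // Standard deviation of a discrete distribution: domain holds the label
  // values as strings (as in Discrete(domain, prob)), prob their probabilities.
  def labelStd(domain: Seq[String], prob: Seq[Double]): Double = {
    val values = domain.map(_.toDouble)
    val mean   = values.zip(prob).map { case (v, p) => v * p }.sum       // E[X]
    val meanSq = values.zip(prob).map { case (v, p) => v * v * p }.sum   // E[X^2]
    math.sqrt(meanSq - mean * mean)                                      // sqrt(Var(X))
  }

  def main(args: Array[String]): Unit = {
    // Balanced binary label {0, 1}: variance = 0.25, so std = 0.5
    println(labelStd(Seq("0", "1"), Seq(0.5, 0.5)))
  }
}
```

For a binary 0/1 label this reduces to sqrt(p(1 − p)), which is the quantity the descaling logic multiplies against the per-feature standard deviations.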