Skip to content

Commit

Permalink
Upgrade XGBoost to 0.81 (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm committed Nov 9, 2018
1 parent cdfa7f4 commit d0785f0
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 20 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ configure(allProjs) {
scoveragePluginVersion = '1.3.1'
hadrianVersion = '0.8.5'
aardpfarkVersion = '0.1.0-SNAPSHOT'
xgboostVersion = '0.80'
xgboostVersion = '0.81'
akkaSlf4jVersion = '2.3.11'

mainClassName = 'com.salesforce.Main'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ case object ModelSelectorSummary {
JsonUtils.fromString[Map[String, Map[String, Any]]](json).map{ d =>
val asMetrics = d.flatMap{ case (_, values) => values.map{
case (nm: String, mp: Map[String, Any]@unchecked) =>
val valsJson = JsonUtils.toJsonString(mp) // gross but it works TODO try to find a better way
val valsJson = JsonUtils.toJsonString(mp) // TODO: gross but it works. try to find a better way
nm match {
case OpEvaluatorNames.Binary.humanFriendlyName =>
nm -> JsonUtils.fromString[BinaryClassificationMetrics](valsJson).get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,6 @@ class OpClassifierModelTest extends FlatSpec with TestSparkContext with OpXGBoos
.setLabelCol(labelF.name)
val spk = cl.fit(rawDF)
val op = toOP(spk, spk.uid).setInput(labelF, featureV)

// ******************************************************
// TODO: remove equality tolerance once XGBoost rounding bug in XGBoostClassifier.transform(probabilityUDF) is fixed
// TODO: ETA - will be added in XGBoost version 0.81
implicit val doubleEquality = new Equality[Double] {
def areEqual(a: Double, b: Any): Boolean = b match {
case s: Double => (a.isNaN && s.isNaN) || math.abs(a - s) < 0.0000001
case _ => false
}
}
implicit val doubleArrayEquality = new Equality[Array[Double]] {
def areEqual(a: Array[Double], b: Any): Boolean = b match {
case s: Array[_] if a.length == s.length => a.zip(s).forall(v => doubleEquality.areEqual(v._1, v._2))
case _ => false
}
}
// ******************************************************
compareOutputs(spk.transform(rawDF), op.transform(rawDF))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class RichDatasetTest extends FlatSpec with TestSparkContext {

import com.salesforce.op.utils.spark.RichDataType._
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichMetadata._
Expand Down

0 comments on commit d0785f0

Please sign in to comment.