Skip to content

Commit

Permalink
Pick the best model when training GradientBoostedTrees with validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 25, 2015
1 parent d641fbb commit ea2fae2
Showing 1 changed file with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,10 @@ object GradientBoostedTrees extends Logging {
logDebug("error of gbt = " + loss.computeError(partialModel, input))

if (validate) {
// Stop training early if
// 1. Reduction in error is less than the validationTol or
// 2. If the error increases, that is if the model is overfit.
// Record the best model if the reduction in error is more than the validationTol.
// We want the model returned corresponding to the best validation error.
val currentValidateError = loss.computeError(partialModel, validationInput)
if (bestValidateError - currentValidateError < validationTol) {
return new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
baseLearners.slice(0, bestM),
baseLearnerWeights.slice(0, bestM))
} else if (currentValidateError < bestValidateError) {
if (currentValidateError < bestValidateError - validationTol) {
bestValidateError = currentValidateError
bestM = m + 1
}
Expand All @@ -251,9 +244,15 @@ object GradientBoostedTrees extends Logging {

logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
if (validate) {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
baseLearners.slice(0, bestM),
baseLearnerWeights.slice(0, bestM))
} else {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
}
}

}

0 comments on commit ea2fae2

Please sign in to comment.