Skip to content

Commit

Permalink
SPARK-8484. Added TrainValidationSplit for hyper-parameter tuning. It…
Browse files Browse the repository at this point in the history
… randomly splits the input dataset into train and validation sets and uses the evaluation metric on the validation set to select the best model.
  • Loading branch information
zapletal-martin committed Jul 9, 2015
1 parent 21662eb commit 2aa6f43
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
private[ml] trait CrossValidatorParams extends Params {

/**
* param for the estimator to be cross-validated
* @group param
*/
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")

/** @group getParam */
def getEstimator: Estimator[_] = $(estimator)

/**
* param for estimator param maps
* @group param
*/
val estimatorParamMaps: Param[Array[ParamMap]] =
new Param(this, "estimatorParamMaps", "param maps for the estimator")

/** @group getParam */
def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)

/**
* param for the evaluator used to select hyper-parameters that maximize the cross-validated
* metric
* @group param
*/
val evaluator: Param[Evaluator] = new Param(this, "evaluator",
"evaluator used to select hyper-parameters that maximize the cross-validated metric")

/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)

private[ml] trait CrossValidatorParams extends ValidatorParams {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.ml.tuning

import scala.reflect.ClassTag

import com.github.fommil.netlib.F2jBLAS

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.evaluation.Evaluator
Expand All @@ -30,13 +33,11 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.util.random.BernoulliCellSampler

import scala.reflect.ClassTag

/**
* Params for [[TrainValidatorSplit]] and [[TrainValidatorSplitModel]].
*/
private[ml] trait TrainValidatorSplitParams extends ValidatorParams {

/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
Expand All @@ -51,7 +52,7 @@ private[ml] trait TrainValidatorSplitParams extends ValidatorParams {
setDefault(trainRatio -> 0.75)
}

/**
/**
* :: Experimental ::
* Validation for hyper-parameter tuning.
* Randomly splits the input dataset into train and validation sets.
Expand All @@ -78,7 +79,7 @@ class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValid
/** @group setParam */
def setTrainRatio(value: Double): this.type = set(trainRatio, value)

private def sample[T: ClassTag](
private[this] def sample[T: ClassTag](
rdd: RDD[T],
lb: Double,
ub: Double,
Expand Down

0 comments on commit 2aa6f43

Please sign in to comment.