From 2aa6f43d988316ad90ae568a3556ee1ab1758a06 Mon Sep 17 00:00:00 2001
From: martinzapletal
Date: Thu, 9 Jul 2015 07:37:56 -0400
Subject: [PATCH] SPARK-8484. Added TrainValidationSplit for hyper-parameter tuning. It randomly splits the input dataset into train and validation sets and uses the evaluation metric on the validation set to select the best model.
---
 .../spark/ml/tuning/CrossValidator.scala      | 33 +------------------
 .../ml/tuning/TrainValidationSplit.scala      |  9 ++---
 2 files changed, 6 insertions(+), 36 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e2444ab65b43b..f979319cc4b58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
-private[ml] trait CrossValidatorParams extends Params {
-
-  /**
-   * param for the estimator to be cross-validated
-   * @group param
-   */
-  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
-
-  /** @group getParam */
-  def getEstimator: Estimator[_] = $(estimator)
-
-  /**
-   * param for estimator param maps
-   * @group param
-   */
-  val estimatorParamMaps: Param[Array[ParamMap]] =
-    new Param(this, "estimatorParamMaps", "param maps for the estimator")
-
-  /** @group getParam */
-  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
-
-  /**
-   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
-   * metric
-   * @group param
-   */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
-    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
-
-  /** @group getParam */
-  def getEvaluator: Evaluator = $(evaluator)
-
+private[ml] trait CrossValidatorParams extends ValidatorParams {
   /**
    * Param for number of folds for cross validation. Must be >= 2.
    * Default: 3
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 90ac9cfd4a5c2..ffd01314813cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.ml.tuning
 
+import scala.reflect.ClassTag
+
 import com.github.fommil.netlib.F2jBLAS
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.evaluation.Evaluator
@@ -30,13 +33,11 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.BernoulliCellSampler
 
-import scala.reflect.ClassTag
 
 /**
  * Params for [[TrainValidatorSplit]] and [[TrainValidatorSplitModel]].
  */
 private[ml] trait TrainValidatorSplitParams extends ValidatorParams {
-
   /**
    * Param for ratio between train and validation data. Must be between 0 and 1.
    * Default: 0.75
@@ -51,7 +52,7 @@ private[ml] trait TrainValidatorSplitParams extends ValidatorParams {
   setDefault(trainRatio -> 0.75)
 }
 
-  /**
+/**
  * :: Experimental ::
  * Validation for hyper-parameter tuning.
 * Randomly splits the input dataset into train and validation sets.
@@ -78,7 +79,7 @@ class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValid
   /** @group setParam */
   def setTrainRatio(value: Double): this.type = set(trainRatio, value)
 
-  private def sample[T: ClassTag](
+  private[this] def sample[T: ClassTag](
       rdd: RDD[T],
      lb: Double,
      ub: Double,
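
For reviewers, a minimal usage sketch of the API this patch introduces follows. It assumes TrainValidatorSplit exposes a no-arg constructor plus setEstimator, setEvaluator and setEstimatorParamMaps through the shared ValidatorParams trait, mirroring CrossValidator; the estimator, evaluator, param grid and sqlContext used below are illustrative only and not part of this patch.

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidatorSplit}
    import org.apache.spark.mllib.linalg.Vectors

    // Toy training data; `sqlContext` is assumed to be in scope (e.g. spark-shell).
    val training = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1)),
      (0.0, Vectors.dense(2.0, 1.0)),
      (1.0, Vectors.dense(0.5, 0.3))
    )).toDF("label", "features")

    val lr = new LinearRegression()

    // Hyper-parameter combinations to evaluate.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
      .build()

    // 75% of the rows train each candidate model; the remaining 25% are scored
    // by the evaluator, and the best-scoring ParamMap is refit on the full data.
    // The no-arg constructor and the estimator/evaluator setters are assumed here.
    val tvs = new TrainValidatorSplit()
      .setEstimator(lr)
      .setEvaluator(new RegressionEvaluator())
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.75)

    val model = tvs.fit(training)

Unlike CrossValidator, each ParamMap is evaluated only once on a single random split, so tuning is cheaper but the metric estimate is noisier.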