diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index ffd01314813cc..d92b66989227a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -65,8 +65,6 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainValid
 
   def this() = this(Identifiable.randomUID("cv"))
 
-  private val f2jBLAS = new F2jBLAS
-
   /** @group setParam */
   def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
 
@@ -104,6 +102,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainValid
     val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
 
     // multi-model training
+    logDebug(s"Train split with multiple sets of parameters.")
     val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
     trainingDataset.unpersist()
     var i = 0
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index ad023c5b81e46..089019dd8b0a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -18,13 +18,16 @@
 package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.tuning.CrossValidatorSuite.{MyEvaluator, MyEstimator}
-import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
 
 class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("train validation with logistic regression") {
@@ -81,6 +84,8 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   test("validateParams should check estimatorParamMaps") {
+    import TrainValidationSplitSuite._
+
     val est = new MyEstimator("est")
     val eval = new MyEvaluator
     val paramMaps = new ParamGridBuilder()
@@ -97,7 +102,38 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-     cv.validateParams()
+      cv.validateParams()
     }
   }
-}
\ No newline at end of file
+}
+
+object TrainValidationSplitSuite {
+
+  abstract class MyModel extends Model[MyModel]
+
+  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+    override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+    override def fit(dataset: DataFrame): MyModel = {
+      throw new UnsupportedOperationException
+    }
+
+    override def transformSchema(schema: StructType): StructType = {
+      throw new UnsupportedOperationException
+    }
+
+    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
+  }
+
+  class MyEvaluator extends Evaluator {
+
+    override def evaluate(dataset: DataFrame): Double = {
+      throw new UnsupportedOperationException
+    }
+
+    override val uid: String = "eval"
+
+    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
+  }
+}
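
For context on the estimator touched here: `TrainValidationSplit` scores each candidate `ParamMap` on a single random train/validation split (one pass, versus `CrossValidator`'s k folds), and the `logDebug` added above fires just before that multi-model training step. Below is a minimal usage sketch; the `dataset` DataFrame (with the default `label`/`features` columns) is an assumed placeholder, not part of this patch:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()

// One model is fit per ParamMap in this grid (2 x 2 = 4 fits) --
// the multi-model training the new logDebug line announces.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)  // expands to both true and false
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.75)  // 75% training, 25% validation

// `dataset` is assumed to hold "label" and "features" columns,
// e.g. built from LinearDataGenerator as in the suite above.
val model = tvs.fit(dataset)
```

Note that the `MyEstimator`/`MyEvaluator` stubs moved into the suite's companion object can safely throw `UnsupportedOperationException` from `fit`, `transformSchema`, and `evaluate`: the `validateParams` test exercises only parameter validation and never calls those methods.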