Skip to content

Commit

Permalink
SPARK-8484. Styling.
Browse files Browse the repository at this point in the history
  • Loading branch information
zapletal-martin committed Jul 9, 2015
1 parent 2aa6f43 commit 3bc1853
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValid

def this() = this(Identifiable.randomUID("cv"))

private val f2jBLAS = new F2jBLAS

/** @group setParam */
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)

Expand Down Expand Up @@ -104,6 +102,7 @@ class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValid
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()

// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
trainingDataset.unpersist()
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
package org.apache.spark.ml.tuning

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.CrossValidatorSuite.{MyEvaluator, MyEstimator}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
test("train validation with logistic regression") {
Expand Down Expand Up @@ -81,6 +84,8 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
}

test("validateParams should check estimatorParamMaps") {
import TrainValidationSplitSuite._

val est = new MyEstimator("est")
val eval = new MyEvaluator
val paramMaps = new ParamGridBuilder()
Expand All @@ -97,7 +102,38 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
cv.validateParams()
cv.validateParams()
}
}
}
}

object TrainValidationSplitSuite {

  /** Minimal model stub; its behavior is irrelevant to the param-validation tests. */
  abstract class MyModel extends Model[MyModel]

  /**
   * Test-only estimator exposing an `inputCol` param so that
   * `validateParams` can be exercised without running a real algorithm.
   */
  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {

    // Fails when inputCol is empty — the condition the suite asserts on.
    override def validateParams(): Unit = require($(inputCol).nonEmpty)

    // Fitting is deliberately unsupported; the tests never invoke it.
    override def fit(dataset: DataFrame): MyModel =
      throw new UnsupportedOperationException

    override def transformSchema(schema: StructType): StructType =
      throw new UnsupportedOperationException

    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
  }

  /** Test-only evaluator; actual evaluation is never exercised. */
  class MyEvaluator extends Evaluator {

    override val uid: String = "eval"

    // Evaluation is deliberately unsupported; the tests never invoke it.
    override def evaluate(dataset: DataFrame): Double =
      throw new UnsupportedOperationException

    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
  }
}

0 comments on commit 3bc1853

Please sign in to comment.