From 3bf263a186ffc66e2e60729a61e8af0c392d712c Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 16 Mar 2016 17:31:55 -0700
Subject: [PATCH] [SPARK-13761][ML] Deprecate validateParams

## What changes were proposed in this pull request?

Deprecate the validateParams() method here:
https://github.com/apache/spark/blob/035d3acdf3c1be5b309a861d5c5beb803b946b5e/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L553

Move all functionality in overridden methods to transformSchema(). Check the docs to make sure they indicate that complex Param interaction checks should be done in transformSchema().

## How was this patch tested?

unit tests

Author: Yuhao Yang

Closes #11620 from hhbyyh/depreValid.
---
 .../main/scala/org/apache/spark/ml/Pipeline.scala  | 14 --------------
 .../main/scala/org/apache/spark/ml/Predictor.scala |  1 -
 .../scala/org/apache/spark/ml/Transformer.scala    |  1 -
 .../org/apache/spark/ml/clustering/KMeans.scala    |  1 -
 .../scala/org/apache/spark/ml/clustering/LDA.scala |  9 ++-------
 .../org/apache/spark/ml/feature/Bucketizer.scala   |  1 -
 .../apache/spark/ml/feature/ChiSqSelector.scala    |  2 --
 .../apache/spark/ml/feature/CountVectorizer.scala  |  1 -
 .../org/apache/spark/ml/feature/HashingTF.scala    |  1 -
 .../scala/org/apache/spark/ml/feature/IDF.scala    |  1 -
 .../org/apache/spark/ml/feature/Interaction.scala  | 13 ++++---------
 .../org/apache/spark/ml/feature/MaxAbsScaler.scala |  1 -
 .../org/apache/spark/ml/feature/MinMaxScaler.scala |  5 +----
 .../apache/spark/ml/feature/OneHotEncoder.scala    |  1 -
 .../scala/org/apache/spark/ml/feature/PCA.scala    |  2 --
 .../spark/ml/feature/QuantileDiscretizer.scala     |  1 -
 .../org/apache/spark/ml/feature/RFormula.scala     |  4 ----
 .../apache/spark/ml/feature/SQLTransformer.scala   |  1 -
 .../apache/spark/ml/feature/StandardScaler.scala   |  2 --
 .../apache/spark/ml/feature/StopWordsRemover.scala |  1 -
 .../apache/spark/ml/feature/StringIndexer.scala    |  2 --
 .../apache/spark/ml/feature/VectorAssembler.scala  |  1 -
 .../apache/spark/ml/feature/VectorIndexer.scala    |  2 --
 .../org/apache/spark/ml/feature/VectorSlicer.scala |  8 ++------
 .../org/apache/spark/ml/feature/Word2Vec.scala     |  1 -
 .../scala/org/apache/spark/ml/param/params.scala   |  7 ++++---
 .../org/apache/spark/ml/recommendation/ALS.scala   |  2 --
 .../ml/regression/AFTSurvivalRegression.scala      |  1 -
 .../regression/GeneralizedLinearRegression.scala   |  7 ++++++-
 .../spark/ml/regression/IsotonicRegression.scala   |  1 -
 .../org/apache/spark/ml/clustering/LDASuite.scala  | 12 +++++++-----
 .../spark/ml/feature/MinMaxScalerSuite.scala       | 10 ++++++----
 .../spark/ml/feature/VectorSlicerSuite.scala       |  8 ++++----
 33 files changed, 36 insertions(+), 89 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index cbac7bbf49fc4..f4c6214a56360 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
   @Since("1.2.0")
   def getStages: Array[PipelineStage] = $(stages).clone()
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    $(stages).foreach(_.validateParams())
-  }
-
   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val theStages = $(stages)
     require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
@@ -297,12 +290,6 @@ class PipelineModel private[ml] (
     this(uid, stages.asScala.toArray)
   }
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    stages.foreach(_.validateParams())
-  }
-
   @Since("1.2.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -311,7 +298,6 @@ class PipelineModel private[ml] (
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b27ee6c5a414..ebe48700f8717 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
       schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
-    validateParams()
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index fdce273193b7c..1f3325ad09ef1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 79332b0d02157..ab00127899edf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 6304b20d544ad..0de82b49ff6f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
-  }
-
@Since("1.6.0") - override def validateParams(): Unit = { if (isSet(docConcentration)) { if (getDocConcentration.length != 1) { require(getDocConcentration.length == getK, s"LDA docConcentration was of length" + @@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 0c75317d82703..33abc7c99d4b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { - validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 4abc459f5369a..b9e9d56853605 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String) } override def transformSchema(schema: StructType): StructType = { - validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) @@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) val newField = prepOutputField(schema) val outputFields = schema.fields :+ newField diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index cf151458f0917..f7d08b39a9746 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. 
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
     SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 8af00581f7e54..61a78d73c4347 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cebbe5c162f79..f36cf503a0b80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7d2a1da990fce..d3fe6e528f0b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
   // optimistic schema; does not contain any ML attributes
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require(get(inputCols).isDefined, "Input cols must be defined first.")
+    require(get(outputCol).isDefined, "Output col must be defined first.")
+    require($(inputCols).length > 0, "Input cols must have non-zero length.")
+    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
     StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
   }
 
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
-    validateParams()
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)
@@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
-  @Since("1.6.0")
-  override def validateParams(): Unit = {
-    require(get(inputCols).isDefined, "Input cols must be defined first.")
-    require(get(outputCol).isDefined, "Output col must be defined first.")
-    require($(inputCols).length > 0, "Input cols must have non-zero length.")
-    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
-  }
 }
 
 @Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 09fad236422b7..7de5a4d5d314c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 3b4209bbc49ad..b13684a1cb76a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(min) < $(max), s"The specified min(${$(min)}) is larger than or equal to max(${$(max)})")
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
     StructType(outputFields)
   }
 
-  override def validateParams(): Unit = {
-    require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index fa5013d3c939e..4f67042629c5a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 80b124f74716d..305c3d187fcbb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -133,7 +132,6 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 18896fcc4d8c1..e830d2a9adc41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
   def setSeed(value: Long): this.type = set(seed, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c21da218b36d6..ab5f4a1a9a6c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
     } else {
@@ -200,7 +199,6 @@ class RFormulaModel private[feature](
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(withFeatures)) {
@@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
   }
 
@@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(
       schema.fields.filter(_.name != vectorCol) ++
         schema.fields.filter(_.name == vectorCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index af6494b234cee..e0ca45b9a6190 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9952d3bc9f1a5..26ee8e1bf1669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0d4c968633295..0a0e0b0960c88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String) } override def transformSchema(schema: StructType): StructType = { - validateParams() val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7dd794b9d7d1d..c579a0d68ec6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], @@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String) final def getLabels: Array[String] = $(labels) override def transformSchema(schema: StructType): StructType = { - validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType.isInstanceOf[NumericType], diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7ff5ad143f80b..957e8e7a5983c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String) } override def transformSchema(schema: StructType): StructType = { - validateParams() val inputColNames = $(inputCols) val outputColName = $(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 5c11760fab9b2..bf4aef2a74c71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod } override def transformSchema(schema: StructType): StructType = { - validateParams() // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). 
     val dataType = new VectorUDT
@@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val dataType = new VectorUDT
     require(isDefined(inputCol),
       s"VectorIndexerModel requires input column parameter: $inputCol")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 300d63bd3a0da..b60e82de00c08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def validateParams(): Unit = {
-    require($(indices).length > 0 || $(names).length > 0,
-      s"VectorSlicer requires that at least one feature be selected.")
-  }
-
   override def transform(dataset: DataFrame): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
@@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(indices).length > 0 || $(names).length > 0,
+      s"VectorSlicer requires that at least one feature be selected.")
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
 
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 3d3c7bdc2f4d8..95bae1c8a3127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 42411d2d8af9c..d7837b67303f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
   /**
    * Assert that the given value is valid for this parameter.
    *
-   * Note: Parameter checks involving interactions between multiple parameters should be
-   * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
-   * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+   * Note: Parameter checks involving interactions between multiple parameters and input/output
+   * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
    *
    * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
    * should be specified via [[ParamPair]].
@@ -555,7 +554,9 @@ trait Params extends Identifiable with Serializable {
    * Parameter value checks which do not depend on other parameters are handled by
    * [[Param.validate()]]. This method does not handle input/output column parameters;
    * those are checked during schema validation.
+   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
    */
+  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
   def validateParams(): Unit = {
     // Do nothing by default. Override to handle Param interactions.
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index dacdac9a1df16..f3bc9f095a8c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
@@ -220,7 +219,6 @@ class ALSModel private[ml] (
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e4339d67b928d..0901642d392d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,7 +99,6 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
   protected def validateAndTransformSchema(
       schema: StructType,
      fitting: Boolean): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index b4e47c80735ca..46ba5589ff85f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Params for Generalized Linear Regression.
@@ -77,7 +78,10 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   import GeneralizedLinearRegression._
 
   @Since("2.0.0")
-  override def validateParams(): Unit = {
+  override def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     if ($(solver) == "irls") {
       setDefault(maxIter -> 25)
     }
@@ -86,6 +90,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
       Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
         s"with ${$(family)} family does not support ${$(link)} link function.")
     }
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 36b006c10e1fb..20a09982014ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,7 +105,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
   protected[ml] def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
-    validateParams()
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
       if (hasWeightCol) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65eac176..dd3f4c6e5391d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       new LDA().setTopicConcentration(-1.1)
     }
 
-    // validateParams()
-    lda.validateParams()
+    val dummyDF = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+    // validate parameters
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(1.1)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
     withClue("LDA docConcentration validity check failed for bad array length") {
       intercept[IllegalArgumentException] {
-        lda.validateParams()
+        lda.transformSchema(dummyDF.schema)
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc07b684d..87206c777e352 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
 
   test("MinMaxScaler arguments max must be larger than min") {
     withClue("arguments max must be larger than min") {
+      val dummyDF = sqlContext.createDataFrame(Seq(
+        (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(10).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(0).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5df383b..6bb4678dc5f97 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
-    val slicer = new VectorSlicer
+    val slicer = new VectorSlicer().setInputCol("feature")
     ParamsSuite.checkParams(slicer)
     assert(slicer.getIndices.length === 0)
     assert(slicer.getNames.length === 0)
     withClue("VectorSlicer should not have any features selected by default") {
       intercept[IllegalArgumentException] {
-        slicer.validateParams()
+        slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
       }
     }
   }
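
Note for third-party PipelineStage implementations: after this patch, a stage that previously overrode the now-deprecated validateParams() should fold those checks into transformSchema(), just as the Interaction, MinMaxScaler, and VectorSlicer changes above do. A minimal sketch of the migrated pattern follows, written against the 1.6-era DataFrame API used in this patch; the ClipTransformer stage and its lower/upper Params are hypothetical and not part of this commit:

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical third-party stage (not part of this patch) showing where
// multi-Param checks live after SPARK-13761: in transformSchema(), not in
// the deprecated validateParams().
class ClipTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("clip"))

  final val lower = new DoubleParam(this, "lower", "lower clip bound")
  final val upper = new DoubleParam(this, "upper", "upper clip bound")

  def setLower(value: Double): this.type = set(lower, value)
  def setUpper(value: Double): this.type = set(upper, value)

  override def transformSchema(schema: StructType): StructType = {
    // Param-interaction check that would formerly have been an override of
    // validateParams(); it now runs whenever the schema is validated.
    require($(lower) < $(upper),
      s"lower (${$(lower)}) must be strictly less than upper (${$(upper)})")
    schema  // this sketch adds no columns
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // transformSchema() runs first, so the Param checks fire before any
    // data is touched.
    transformSchema(dataset.schema, logging = true)
    dataset  // clipping omitted; the focus here is the validation pattern
  }

  override def copy(extra: ParamMap): ClipTransformer = defaultCopy(extra)
}
```

Because a Pipeline calls transformSchema() on every stage up front, checks placed there fire at both fit and transform time, before any data is processed; this is also why the updated tests above exercise transformSchema(dummyDF.schema) rather than validateParams().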