From 71e8d66c6e9ab665cf20c3c477e161754b5bc8d1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 29 Jul 2015 17:02:56 -0700 Subject: [PATCH] Make labels a local param for StringIndexerInverse --- .../apache/spark/ml/feature/StringIndexer.scala | 13 ++++++++++++- .../ml/param/shared/SharedParamsCodeGen.scala | 1 - .../spark/ml/param/shared/sharedParams.scala | 17 ----------------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b7766414beeaf..720fd84d9f1c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -176,7 +176,7 @@ class StringIndexerModel private[ml] ( @Experimental class StringIndexerInverse private[ml] ( override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with HasLabels { + with HasInputCol with HasOutputCol { def this(labels: Option[Array[String]] = None) = this(Identifiable.randomUID("strIdxInv")) @@ -190,6 +190,17 @@ class StringIndexerInverse private[ml] ( /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) + /** + * Param for array of labels. + * @group param + */ + final val labels: StringArrayParam = new StringArrayParam(this, "labels", "array of labels") + + setDefault(labels, null) + + /** @group getParam */ + final def getLabels: Array[String] = $(labels) + /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 02bde6e4d1024..f7ae1de522e01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -37,7 +37,6 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), - ParamDesc[Array[String]]("labels", "array of labels", Some("null")), ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", Some("\"rawPrediction\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fd9c32dfc2a38..65e48e4ee5083 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -87,23 +87,6 @@ private[ml] trait HasLabelCol extends Params { final def getLabelCol: String = $(labelCol) } -/** - * Trait for shared param labels (default: null). - */ -private[ml] trait HasLabels extends Params { - - /** - * Param for array of labels. - * @group param - */ - final val labels: StringArrayParam = new StringArrayParam(this, "labels", "array of labels") - - setDefault(labels, null) - - /** @group getParam */ - final def getLabels: Array[String] = $(labels) -} - /** * Trait for shared param predictionCol (default: "prediction"). */