Skip to content

Commit

Permalink
Make labels a local param for StringIndexerInverse
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Jul 30, 2015
1 parent 8450d0b commit 71e8d66
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class StringIndexerModel private[ml] (
@Experimental
class StringIndexerInverse private[ml] (
override val uid: String) extends Transformer
with HasInputCol with HasOutputCol with HasLabels {
with HasInputCol with HasOutputCol {

def this(labels: Option[Array[String]] = None) =
this(Identifiable.randomUID("strIdxInv"))
Expand All @@ -190,6 +190,17 @@ class StringIndexerInverse private[ml] (
/** @group setParam */
def setLabels(value: Array[String]): this.type = set(labels, value)

/**
* Param for array of labels.
* @group param
*/
final val labels: StringArrayParam = new StringArrayParam(this, "labels", "array of labels")

setDefault(labels, null)

/** @group getParam */
final def getLabels: Array[String] = $(labels)

/** Transform the schema for the inverse transformation */
override def transformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ private[shared] object SharedParamsCodeGen {
isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
ParamDesc[Array[String]]("labels", "array of labels", Some("null")),
ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
Some("\"rawPrediction\"")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,6 @@ private[ml] trait HasLabelCol extends Params {
final def getLabelCol: String = $(labelCol)
}

/**
* Trait for shared param labels (default: null).
*/
private[ml] trait HasLabels extends Params {

/**
* Param for array of labels.
* @group param
*/
final val labels: StringArrayParam = new StringArrayParam(this, "labels", "array of labels")

setDefault(labels, null)

/** @group getParam */
final def getLabels: Array[String] = $(labels)
}

/**
* Trait for shared param predictionCol (default: "prediction").
*/
Expand Down

0 comments on commit 71e8d66

Please sign in to comment.