diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4992d5db1f57a..cdbd06200c7bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -50,7 +50,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
     val outputFields = inputFields :+ attr.toStructField()
     StructType(outputFields)
   }
-
 }
 
 /**
@@ -159,7 +158,7 @@ class StringIndexerModel private[ml] (
    * Note: By default we keep the original columns during this transformation, so the inverse
    * should only be used on new columns such as predicted labels.
    */
-  def invert(inputCol: String, outputCol: String): StringIndexerInverseTransformer = {
+  def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
     val labelsCol: String = $(this.outputCol)
     new StringIndexerInverseTransformer(labelsCol)
       .setInputCol(inputCol)
@@ -174,7 +173,7 @@ class StringIndexerModel private[ml] (
  * so the inverse should only be used on new columns such as predicted labels.
  */
 @Experimental
-class StringIndexerInverseTransformer private[ml] (
+class StringIndexerInverse private[ml] (
     override val uid: String,
     val labelsCol: String) extends Transformer
   with HasInputCol with HasOutputCol {
@@ -210,7 +209,7 @@ class StringIndexerInverseTransformer private[ml] (
     val values = attr.values.get
     val indexer = udf { index: Double =>
       val idx = index.toInt
-      if (0 <= idx && idx <= values.size) {
+      if (0 <= idx && idx < values.size) {
         values(idx)
       } else {
         throw new SparkException(s"Unseen index: $index ??")
@@ -218,7 +217,7 @@ class StringIndexerInverseTransformer private[ml] (
     }
     val outputColName = $(outputCol)
     dataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as(outputColName))
+      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
   }
 
   override def copy(extra: ParamMap): StringIndexerInverseTransformer = {
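
For reference, a minimal usage sketch of the inverse transformer added here, assuming a fitted StringIndexerModel named `indexerModel` and a DataFrame `predictions` with a Double `prediction` column (both names are illustrative, not part of the diff). It relies only on the `invert(inputCol, outputCol)` method shown above; the corrected bound check (`idx < values.size`) and the DoubleType cast from the last two hunks are what make this index-to-label round trip safe.

// Hypothetical usage sketch: `indexerModel` and `predictions` are placeholder names.
// `invert` returns a Transformer that maps a Double index column back to the
// original string labels recorded by the StringIndexerModel.
val inverse = indexerModel.invert("prediction", "predictedLabel")
val labeled = inverse.transform(predictions)
labeled.select("prediction", "predictedLabel").show()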