diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cdbd06200c7bd..4b4fc74097862 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{NumericType, StringType, StructType} +import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -159,8 +159,7 @@ class StringIndexerModel private[ml] ( * should only be used on new columns such as predicted labels. */ def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - val labelsCol: String = $(this.outputCol) - new StringIndexerInverseTransformer(labelsCol) + new StringIndexerInverse(Some(labels)) .setInputCol(inputCol) .setOutputCol(outputCol) } @@ -175,10 +174,11 @@ class StringIndexerModel private[ml] ( @Experimental class StringIndexerInverse private[ml] ( override val uid: String, - val labelsCol: String) extends Transformer + val labels: Option[Array[String]]) extends Transformer with HasInputCol with HasOutputCol { - def this(labelsCol: String) = this(Identifiable.randomUID("strIdxInv"), labelsCol) + def this(labels: Option[Array[String]] = None) = + this(Identifiable.randomUID("strIdxInv"), labels) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -204,9 +204,10 @@ class StringIndexerInverse private[ml] ( override def transform(dataset: DataFrame): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) - val attr = Attribute.fromStructField(inputColSchema) - .asInstanceOf[NominalAttribute] - val values = attr.values.get + val values = labels.getOrElse{ + Attribute.fromStructField(inputColSchema) + .asInstanceOf[NominalAttribute].values.get + } val indexer = udf { index: Double => val idx = index.toInt if (0 <= idx && idx < values.size) { @@ -220,8 +221,8 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverseTransformer = { - val copied = new StringIndexerInverseTransformer(uid) + override def copy(extra: ParamMap): StringIndexerInverse = { + val copied = new StringIndexerInverse(uid, labels) copyValues(copied, extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d13f4999b5487..d0295a0fe2fc1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -53,6 +53,13 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .select("id", "label2") assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + // Check invert using only metadata + val inverse2 = new StringIndexerInverse() + .setInputCol("labelIndex") + .setOutputCol("label2") + val reversed2 = inverse2.transform(transformed).select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexer with a numeric input column") {