Skip to content

Commit

Permalink
Instead of using a private inverse transform, add an invert function …
Browse files Browse the repository at this point in the history
…so we can use it in a pipeline
  • Loading branch information
holdenk committed Jul 8, 2015
1 parent 88779c1 commit 557bef8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}

/** Transform the schema for the inverse transformation */
protected def invertSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be a numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}
}

/**
Expand Down Expand Up @@ -110,18 +126,6 @@ class StringIndexerModel private[ml] (
map
}

private lazy val indexToLabel: OpenHashMap[Double, String] = {
val n = labels.length
val map = new OpenHashMap[Double, String](n)
var i = 0
while (i < n) {
map.update(i, labels(i))
i += 1
}
map
}


/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

Expand Down Expand Up @@ -164,18 +168,63 @@ class StringIndexerModel private[ml] (
copyValues(copied, extra)
}

def invertTransform(dataset: DataFrame): DataFrame = {
/**
* Return a model to perform the inverse transformation.
* Note: by default we keep the original columns during this transformation
* so the invert should only be needed if you do something beyond simply
* applying the original transformation.
*/
def invert(): StringIndexerInvertModel = {
new StringIndexerInvertModel(uid, labels)
.setInputCol(getOutputCol)
.setOutputCol(getInputCol)
}
}

class StringIndexerInvertModel private[ml] (
override val uid: String,
labels: Array[String]) extends Model[StringIndexerInvertModel] with StringIndexerBase {

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)


private val indexToLabel: OpenHashMap[Double, String] = {
val n = labels.length
val map = new OpenHashMap[Double, String](n)
var i = 0
while (i < n) {
map.update(i, labels(i))
i += 1
}
map
}

override def transformSchema(schema: StructType): StructType = {
invertSchema(schema)
}

override def transform(dataset: DataFrame): DataFrame = {
val indexer = udf { index: Double =>
if (indexToLabel.contains(index)) {
indexToLabel(index)
} else {
// TODO: handle unseen labels
throw new SparkException(s"Unseen index: $index ??")
}
}
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"),
indexer(dataset($(outputCol))).as(inputColName))
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}

override def copy(extra: ParamMap): StringIndexerInvertModel = {
val copied = new StringIndexerInvertModel(uid, labels)
copyValues(copied, extra)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
// convert reverse our transform
val reversed = indexer.invertTransform(transformed.select("id", "labelIndex"))
val reversed = indexer.invert().transform(transformed.select("id", "labelIndex"))
.select("id", "label")
assert(df.collect().toSet === reversed.collect().toSet)
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}

test("StringIndexer with a numeric input column") {
Expand Down

0 comments on commit 557bef8

Please sign in to comment.