Skip to content

Commit

Permalink
Add an inverse test using only meta data, pass labels when calling inverse method
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Jul 29, 2015
1 parent f3e0c64 commit 5aa38bf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap

/**
Expand Down Expand Up @@ -159,8 +159,7 @@ class StringIndexerModel private[ml] (
* should only be used on new columns such as predicted labels.
*/
def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
val labelsCol: String = $(this.outputCol)
new StringIndexerInverseTransformer(labelsCol)
new StringIndexerInverse(Some(labels))
.setInputCol(inputCol)
.setOutputCol(outputCol)
}
Expand All @@ -175,10 +174,11 @@ class StringIndexerModel private[ml] (
@Experimental
class StringIndexerInverse private[ml] (
override val uid: String,
val labelsCol: String) extends Transformer
val labels: Option[Array[String]]) extends Transformer
with HasInputCol with HasOutputCol {

def this(labelsCol: String) = this(Identifiable.randomUID("strIdxInv"), labelsCol)
def this(labels: Option[Array[String]] = None) =
this(Identifiable.randomUID("strIdxInv"), labels)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
Expand All @@ -204,9 +204,10 @@ class StringIndexerInverse private[ml] (

override def transform(dataset: DataFrame): DataFrame = {
val inputColSchema = dataset.schema($(inputCol))
val attr = Attribute.fromStructField(inputColSchema)
.asInstanceOf[NominalAttribute]
val values = attr.values.get
val values = labels.getOrElse{
Attribute.fromStructField(inputColSchema)
.asInstanceOf[NominalAttribute].values.get
}
val indexer = udf { index: Double =>
val idx = index.toInt
if (0 <= idx && idx < values.size) {
Expand All @@ -220,8 +221,8 @@ class StringIndexerInverse private[ml] (
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}

override def copy(extra: ParamMap): StringIndexerInverseTransformer = {
val copied = new StringIndexerInverseTransformer(uid)
override def copy(extra: ParamMap): StringIndexerInverse = {
val copied = new StringIndexerInverse(uid, labels)
copyValues(copied, extra)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
.select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
// Check invert using only metadata
val inverse2 = new StringIndexerInverse()
.setInputCol("labelIndex")
.setOutputCol("label2")
val reversed2 = inverse2.transform(transformed).select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}

test("StringIndexer with a numeric input column") {
Expand Down

0 comments on commit 5aa38bf

Please sign in to comment.