Merge pull request apache#4 from mengxr/ml-package-docs
replace TypeTag with explicit datatype
jkbradley committed Dec 3, 2014
2 parents 41ad9b1 + 3b83ec0 commit ea34dc6
Showing 3 changed files with 17 additions and 10 deletions.
18 changes: 10 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,16 +18,14 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
 
-import scala.reflect.runtime.universe.TypeTag
-
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
 import org.apache.spark.sql.catalyst.types._
 
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
  * Abstract class for transformers that take one input column, apply transformation, and output the
  * result as a new column.
  */
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -99,6 +97,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
    */
   protected def createTransformFunc(paramMap: ParamMap): IN => OUT
 
+  /**
+   * Returns the data type of the output column.
+   */
+  protected def outputDataType: DataType
+
   /**
    * Validates the input type. Throw an exception if it is invalid.
    */
@@ -111,17 +114,16 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
     if (schema.fieldNames.contains(map(outputCol))) {
       throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
     }
-    val output = ScalaReflection.schemaFor[OUT]
     val outputFields = schema.fields :+
-      StructField(map(outputCol), output.dataType, output.nullable)
+      StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
     StructType(outputFields)
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     transformSchema(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val udf = this.createTransformFunc(map)
-    dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+    val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+    dataset.select(Star(None), udf as map(outputCol))
   }
 }
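
For context, a minimal sketch of what a concrete subclass looks like under the new contract. The StringLength class below is hypothetical, not part of this commit: it supplies the element-wise function as before, and declares its Catalyst output type explicitly instead of requiring a TypeTag for OUT.

package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.catalyst.types.{DataType, IntegerType}

// Hypothetical example, not part of this commit: maps a string column to
// its length. It lives under ml.feature because UnaryTransformer is private[ml].
class StringLength extends UnaryTransformer[String, Int, StringLength] {

  // Unchanged by this commit: the element-wise transform function.
  protected def createTransformFunc(paramMap: ParamMap): String => Int = _.length

  // New in this commit: the output column's Catalyst type is declared
  // explicitly instead of being derived from a TypeTag at runtime.
  override protected def outputDataType: DataType = IntegerType
}

transform then wraps that function in a ScalaUdf tagged with outputDataType, so the output schema no longer depends on runtime reflection; nullability falls out of !outputDataType.isPrimitive (IntegerType is primitive, so this column would be non-nullable).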
5 changes: 4 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.sql.catalyst.types.DataType
 
 /**
  * :: AlphaComponent ::
@@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform
   }
+
+  override protected def outputDataType: DataType = new VectorUDT()
 }
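
HashingTF produces mllib Vector values, whose Catalyst representation is the user-defined type VectorUDT, so it now names that type directly. A sketch of the field transformSchema appends under the new contract; the column name "features" is illustrative, and the VectorUDT usage mirrors Spark-internal code and may not be accessible to user code at this point:

import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.catalyst.types.{DataType, StructField}

// Sketch: the output field appended by transformSchema for HashingTF.
// The column name "features" is illustrative.
val outputType: DataType = new VectorUDT()
val field = StructField("features", outputType, nullable = true)
// nullable comes from !outputDataType.isPrimitive: a user-defined type
// is not primitive, so the vector column is marked nullable.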
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataType, StringType}
+import org.apache.spark.sql.{DataType, StringType, ArrayType}
 
 /**
  * :: AlphaComponent ::
@@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
   protected override def validateInputType(inputType: DataType): Unit = {
     require(inputType == StringType, s"Input type must be string type but got $inputType.")
   }
+
+  override protected def outputDataType: DataType = new ArrayType(StringType, false)
 }
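
Tokenizer declares ArrayType(StringType, false): an array of strings whose elements are never null, since splitting a string never yields null tokens. A sketch of the schema effect, with the column names "text" and "words" as illustrative stand-ins:

import org.apache.spark.sql.{ArrayType, DataType, StringType, StructField, StructType}

// Sketch: schema before and after Tokenizer.transformSchema, assuming
// inputCol = "text" and outputCol = "words" (names illustrative).
val before = StructType(Seq(StructField("text", StringType, nullable = true)))
val outputType: DataType = new ArrayType(StringType, false) // tokens are non-null
val after = StructType(before.fields :+
  StructField("words", outputType, nullable = true)) // !outputDataType.isPrimitive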
