Skip to content

Commit

Permalink
Add FeatureBuilder.fromSchema (#325)
Browse files: browse the repository at this point in the history
  • Loading branch information
takezoe authored and tovbinm committed May 26, 2019
1 parent c47ff42 commit b08dd94
Showing 1 changed file with 18 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import com.salesforce.op.features.types._
import com.salesforce.op.stages.{FeatureGeneratorStage, OpPipelineStage}
import com.salesforce.op.utils.spark.RichRow._
import com.twitter.algebird.MonoidAggregator
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row}
import org.joda.time.Duration

Expand Down Expand Up @@ -178,22 +179,22 @@ object FeatureBuilder {
def apply[I: WeakTypeTag, O <: FeatureType : WeakTypeTag](
  name: String
): FeatureBuilder[I, O] = {
  // Straight delegation to the FeatureBuilder constructor with the supplied name.
  new FeatureBuilder[I, O](name)
}

/**
* Builds features from a [[DataFrame]]
* Builds features from a [[StructType]]
*
* @param data input [[DataFrame]]
* @param schema input [[StructType]]
* @param response response feature name
* @param nonNullable optional non nullable feature names
* @throws IllegalArgumentException if fails to map dataframe field type into a feature type
* @throws RuntimeException if fails to construct a response feature
* @return label and other features
*/
def fromDataFrame[ResponseType <: FeatureType : WeakTypeTag](
data: DataFrame,
def fromSchema[ResponseType <: FeatureType : WeakTypeTag](
schema: StructType,
response: String,
nonNullable: Set[String] = Set.empty
): (Feature[ResponseType], Array[Feature[_ <: FeatureType]]) = {
val allFeatures: Array[Feature[_ <: FeatureType]] =
data.schema.fields.zipWithIndex.map { case (field, index) =>
schema.fields.zipWithIndex.map { case (field, index) =>
val isResponse = field.name == response
val isNullable = !isResponse && !nonNullable.contains(field.name)
val wtt: WeakTypeTag[_ <: FeatureType] = FeatureSparkTypes.featureTypeTagOf(field.dataType, isNullable)
Expand All @@ -215,6 +216,18 @@ object FeatureBuilder {
}
responseFeature -> features
}

/**
 * Builds features from the schema of a [[DataFrame]].
 *
 * Only `data.schema` is read here; everything else is delegated to `fromSchema`.
 *
 * @param data        input [[DataFrame]]
 * @param response    response feature name
 * @param nonNullable optional non nullable feature names
 * @tparam ResponseType feature type of the response feature
 * @throws IllegalArgumentException if fails to map dataframe field type into a feature type
 * @throws RuntimeException if fails to construct a response feature
 * @return label and other features
 */
def fromDataFrame[ResponseType <: FeatureType : WeakTypeTag](
  data: DataFrame,
  response: String,
  nonNullable: Set[String] = Set.empty
): (Feature[ResponseType], Array[Feature[_ <: FeatureType]]) =
  fromSchema[ResponseType](
    schema = data.schema,
    response = response,
    nonNullable = nonNullable
  )
def fromRow[O <: FeatureType : WeakTypeTag](
  implicit name: sourcecode.Name
): FeatureBuilderWithExtract[Row, O] = {
  // Uses the value of the implicit sourcecode.Name as the feature name and passes no index.
  fromRow[O](name.value, None)
}
def fromRow[O <: FeatureType : WeakTypeTag](
  name: String
): FeatureBuilderWithExtract[Row, O] = {
  // Explicit feature name, no index: forwards to the (name, index) overload with None.
  fromRow[O](name, None)
}
def fromRow[O <: FeatureType : WeakTypeTag](index: Int)(
  implicit name: sourcecode.Name
): FeatureBuilderWithExtract[Row, O] = {
  // Feature name comes from the implicit sourcecode.Name; the given index is wrapped in Some.
  fromRow[O](name.value, Some(index))
}
Expand Down

0 comments on commit b08dd94

Please sign in to comment.