diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala index ead0aec74e..7d089e5151 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala @@ -138,6 +138,8 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReade val originalStage = workflow.stages.find(_.uid == stageUid) originalStage match { case Some(os) => new OpPipelineStageReader(os).loadFromJson(j, path = path).asInstanceOf[OPStage] + case None if stageUid.startsWith("FeatureGeneratorStage_") => + new OpPipelineStageReader(Seq()).loadFromJson(j, path).asInstanceOf[OPStage] case None => throw new RuntimeException(s"Workflow does not contain a stage with uid: $stageUid") } } diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala index aa229d8c14..73e797d795 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala @@ -32,7 +32,7 @@ package com.salesforce.op import com.salesforce.op.features.FeatureJsonHelper import com.salesforce.op.filters.RawFeatureFilterResults -import com.salesforce.op.stages.{OpPipelineStageBase, OpPipelineStageWriter} +import com.salesforce.op.stages.{FeatureGeneratorStage, OPStage, OpPipelineStageBase, OpPipelineStageWriter} import enumeratum._ import org.apache.hadoop.fs.Path import org.apache.spark.ml.util.MLWriter @@ -98,13 +98,22 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter { * @return array of serialized stages */ private def stagesJArray(path: String): JArray = { - val stages: Seq[OpPipelineStageBase] = model.stages + val stages: Seq[OpPipelineStageBase] = getFeatureGenStages(model.stages) ++ model.stages val stagesJson: Seq[JObject] = stages .map(_.write.asInstanceOf[OpPipelineStageWriter].writeToJson(path)) .filter(_.children.nonEmpty) JArray(stagesJson.toList) } + private def getFeatureGenStages(stages: Seq[OPStage]): Seq[OpPipelineStageBase] = { + for { + stage <- stages + inputFeatures <- stage.getInputFeatures() + orgStage = inputFeatures.originStage + if orgStage.isInstanceOf[FeatureGeneratorStage[_, _]] + } yield orgStage + } + /** * Gets all features to be serialized. * @@ -134,14 +143,23 @@ private[op] object OpWorkflowModelReadWriteShared { */ object FieldNames extends Enum[FieldNames] { val values = findValues + case object Uid extends FieldNames("uid") + case object ResultFeaturesUids extends FieldNames("resultFeaturesUids") + case object BlacklistedFeaturesUids extends FieldNames("blacklistedFeaturesUids") + case object Stages extends FieldNames("stages") + case object AllFeatures extends FieldNames("allFeatures") + case object Parameters extends FieldNames("parameters") + case object TrainParameters extends FieldNames("trainParameters") + case object RawFeatureFilterResultsFieldName extends FieldNames("rawFeatureFilterResults") + } } diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala index 74f5bdd8e0..245834fc36 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala @@ -52,7 +52,7 @@ import org.scalatest.{BeforeAndAfterEach, FlatSpec} import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ - +import OpWorkflowModelReaderWriterTest._ @RunWith(classOf[JUnitRunner]) class OpWorkflowModelReaderWriterTest @@ -145,7 +145,7 @@ class OpWorkflowModelReaderWriterTest } trait SwSingleStageFlow { - val vec = FeatureBuilder.OPVector[Passenger].extract(_ => OPVector.empty).asPredictor + val vec = FeatureBuilder.OPVector[Passenger].extract(emptyVectFnc).asPredictor val scaler = new StandardScaler().setWithStd(false).setWithMean(false) val schema = FeatureSparkTypes.toStructType(vec) val data = spark.createDataFrame(List(Row(Vectors.dense(1.0))).asJava, schema) @@ -172,7 +172,7 @@ class OpWorkflowModelReaderWriterTest it should "have a single stage" in new SingleStageFlow { val stagesM = (jsonModel \ Stages.entryName).extract[JArray] - stagesM.values.size shouldBe 1 + stagesM.values.size shouldBe 3 } it should "have 3 features" in new SingleStageFlow { @@ -193,7 +193,7 @@ class OpWorkflowModelReaderWriterTest "MultiStage OpWorkflowWriter" should "recover all relevant stages" in new MultiStageFlow { val stagesM = (jsonModel \ Stages.entryName).extract[JArray] - stagesM.values.size shouldBe 2 + stagesM.values.size shouldBe 5 } it should "recover all relevant features" in new MultiStageFlow { @@ -379,4 +379,6 @@ trait UIDReset { object OpWorkflowModelReaderWriterTest { def mapFnc0: OPVector => Real = v => Real(v.value.toArray.headOption) + + def emptyVectFnc: (Passenger => OPVector) = _ => OPVector.empty } diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index 985b48fc81..2d0cf977ae 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -44,6 +44,7 @@ import com.salesforce.op.stages.impl.tuning._ import com.salesforce.op.test.{Passenger, PassengerSparkFixtureTest, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} +import org.apache.log4j.Level import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{BooleanParam, ParamMap} import org.apache.spark.ml.tuning.ParamGridBuilder diff --git a/features/src/main/scala/com/salesforce/op/features/FeatureBuilder.scala b/features/src/main/scala/com/salesforce/op/features/FeatureBuilder.scala index 53b093ea1d..51fc9fd2a6 100644 --- a/features/src/main/scala/com/salesforce/op/features/FeatureBuilder.scala +++ b/features/src/main/scala/com/salesforce/op/features/FeatureBuilder.scala @@ -52,44 +52,77 @@ object FeatureBuilder { // Lists def TextList[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, TextList] = TextList(name.value) + def DateList[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, DateList] = DateList(name.value) + def DateTimeList[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, DateTimeList] = DateTimeList(name.value) + def Geolocation[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Geolocation] = Geolocation(name.value) // Maps def Base64Map[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Base64Map] = Base64Map(name.value) + def BinaryMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, BinaryMap] = BinaryMap(name.value) + def ComboBoxMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, ComboBoxMap] = ComboBoxMap(name.value) + def CurrencyMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, CurrencyMap] = CurrencyMap(name.value) + def DateMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, DateMap] = DateMap(name.value) + def DateTimeMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, DateTimeMap] = DateTimeMap(name.value) + def EmailMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, EmailMap] = EmailMap(name.value) + def IDMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, IDMap] = IDMap(name.value) + def IntegralMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, IntegralMap] = IntegralMap(name.value) + def MultiPickListMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, MultiPickListMap] = MultiPickListMap(name.value) + def PercentMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PercentMap] = PercentMap(name.value) + def PhoneMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PhoneMap] = PhoneMap(name.value) + def PickListMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PickListMap] = PickListMap(name.value) + def RealMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, RealMap] = RealMap(name.value) + def TextAreaMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, TextAreaMap] = TextAreaMap(name.value) + def TextMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, TextMap] = TextMap(name.value) + def URLMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, URLMap] = URLMap(name.value) + def CountryMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, CountryMap] = CountryMap(name.value) + def StateMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, StateMap] = StateMap(name.value) + def CityMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, CityMap] = CityMap(name.value) + def PostalCodeMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PostalCodeMap] = PostalCodeMap(name.value) + def StreetMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, StreetMap] = StreetMap(name.value) + def GeolocationMap[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, GeolocationMap] = GeolocationMap(name.value) + def Prediction[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Prediction] = Prediction(name.value) // Numerics def Binary[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Binary] = Binary(name.value) + def Currency[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Currency] = Currency(name.value) + def Date[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Date] = Date(name.value) + def DateTime[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, DateTime] = DateTime(name.value) + def Integral[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Integral] = Integral(name.value) + def Percent[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Percent] = Percent(name.value) + def Real[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Real] = Real(name.value) + def RealNN[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, RealNN] = RealNN(name.value) // Sets @@ -97,18 +130,31 @@ object FeatureBuilder { // Text def Base64[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Base64] = Base64(name.value) + def ComboBox[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, ComboBox] = ComboBox(name.value) + def Email[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Email] = Email(name.value) + def ID[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, ID] = ID(name.value) + def Phone[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Phone] = Phone(name.value) + def PickList[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PickList] = PickList(name.value) + def Text[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Text] = Text(name.value) + def TextArea[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, TextArea] = TextArea(name.value) + def URL[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, URL] = URL(name.value) + def Country[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Country] = Country(name.value) + def State[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, State] = State(name.value) + def City[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, City] = City(name.value) + def PostalCode[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, PostalCode] = PostalCode(name.value) + def Street[I: WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilder[I, Street] = Street(name.value) // Vector @@ -116,44 +162,77 @@ object FeatureBuilder { // Lists def TextList[I: WeakTypeTag](name: String): FeatureBuilder[I, TextList] = FeatureBuilder[I, TextList](name = name) + def DateList[I: WeakTypeTag](name: String): FeatureBuilder[I, DateList] = FeatureBuilder[I, DateList](name = name) + def DateTimeList[I: WeakTypeTag](name: String): FeatureBuilder[I, DateTimeList] = FeatureBuilder[I, DateTimeList](name = name) + def Geolocation[I: WeakTypeTag](name: String): FeatureBuilder[I, Geolocation] = FeatureBuilder[I, Geolocation](name = name) // Maps def Base64Map[I: WeakTypeTag](name: String): FeatureBuilder[I, Base64Map] = FeatureBuilder[I, Base64Map](name = name) + def BinaryMap[I: WeakTypeTag](name: String): FeatureBuilder[I, BinaryMap] = FeatureBuilder[I, BinaryMap](name = name) + def ComboBoxMap[I: WeakTypeTag](name: String): FeatureBuilder[I, ComboBoxMap] = FeatureBuilder[I, ComboBoxMap](name = name) + def CurrencyMap[I: WeakTypeTag](name: String): FeatureBuilder[I, CurrencyMap] = FeatureBuilder[I, CurrencyMap](name = name) + def DateMap[I: WeakTypeTag](name: String): FeatureBuilder[I, DateMap] = FeatureBuilder[I, DateMap](name = name) + def DateTimeMap[I: WeakTypeTag](name: String): FeatureBuilder[I, DateTimeMap] = FeatureBuilder[I, DateTimeMap](name = name) + def EmailMap[I: WeakTypeTag](name: String): FeatureBuilder[I, EmailMap] = FeatureBuilder[I, EmailMap](name = name) + def IDMap[I: WeakTypeTag](name: String): FeatureBuilder[I, IDMap] = FeatureBuilder[I, IDMap](name = name) + def IntegralMap[I: WeakTypeTag](name: String): FeatureBuilder[I, IntegralMap] = FeatureBuilder[I, IntegralMap](name = name) + def MultiPickListMap[I: WeakTypeTag](name: String): FeatureBuilder[I, MultiPickListMap] = FeatureBuilder[I, MultiPickListMap](name = name) + def PercentMap[I: WeakTypeTag](name: String): FeatureBuilder[I, PercentMap] = FeatureBuilder[I, PercentMap](name = name) + def PhoneMap[I: WeakTypeTag](name: String): FeatureBuilder[I, PhoneMap] = FeatureBuilder[I, PhoneMap](name = name) + def PickListMap[I: WeakTypeTag](name: String): FeatureBuilder[I, PickListMap] = FeatureBuilder[I, PickListMap](name = name) + def RealMap[I: WeakTypeTag](name: String): FeatureBuilder[I, RealMap] = FeatureBuilder[I, RealMap](name = name) + def TextAreaMap[I: WeakTypeTag](name: String): FeatureBuilder[I, TextAreaMap] = FeatureBuilder[I, TextAreaMap](name = name) + def TextMap[I: WeakTypeTag](name: String): FeatureBuilder[I, TextMap] = FeatureBuilder[I, TextMap](name = name) + def URLMap[I: WeakTypeTag](name: String): FeatureBuilder[I, URLMap] = FeatureBuilder[I, URLMap](name = name) + def CountryMap[I: WeakTypeTag](name: String): FeatureBuilder[I, CountryMap] = FeatureBuilder[I, CountryMap](name = name) + def StateMap[I: WeakTypeTag](name: String): FeatureBuilder[I, StateMap] = FeatureBuilder[I, StateMap](name = name) + def CityMap[I: WeakTypeTag](name: String): FeatureBuilder[I, CityMap] = FeatureBuilder[I, CityMap](name = name) + def PostalCodeMap[I: WeakTypeTag](name: String): FeatureBuilder[I, PostalCodeMap] = FeatureBuilder[I, PostalCodeMap](name = name) + def StreetMap[I: WeakTypeTag](name: String): FeatureBuilder[I, StreetMap] = FeatureBuilder[I, StreetMap](name = name) + def GeolocationMap[I: WeakTypeTag](name: String): FeatureBuilder[I, GeolocationMap] = FeatureBuilder[I, GeolocationMap](name = name) + def Prediction[I: WeakTypeTag](name: String): FeatureBuilder[I, Prediction] = FeatureBuilder[I, Prediction](name = name) // Numerics def Binary[I: WeakTypeTag](name: String): FeatureBuilder[I, Binary] = FeatureBuilder[I, Binary](name = name) + def Currency[I: WeakTypeTag](name: String): FeatureBuilder[I, Currency] = FeatureBuilder[I, Currency](name = name) + def Date[I: WeakTypeTag](name: String): FeatureBuilder[I, Date] = FeatureBuilder[I, Date](name = name) + def DateTime[I: WeakTypeTag](name: String): FeatureBuilder[I, DateTime] = FeatureBuilder[I, DateTime](name = name) + def Integral[I: WeakTypeTag](name: String): FeatureBuilder[I, Integral] = FeatureBuilder[I, Integral](name = name) + def Percent[I: WeakTypeTag](name: String): FeatureBuilder[I, Percent] = FeatureBuilder[I, Percent](name = name) + def Real[I: WeakTypeTag](name: String): FeatureBuilder[I, Real] = FeatureBuilder[I, Real](name = name) + def RealNN[I: WeakTypeTag](name: String): FeatureBuilder[I, RealNN] = FeatureBuilder[I, RealNN](name = name) // Sets @@ -161,18 +240,31 @@ object FeatureBuilder { // Text def Base64[I: WeakTypeTag](name: String): FeatureBuilder[I, Base64] = FeatureBuilder[I, Base64](name = name) + def ComboBox[I: WeakTypeTag](name: String): FeatureBuilder[I, ComboBox] = FeatureBuilder[I, ComboBox](name = name) + def Email[I: WeakTypeTag](name: String): FeatureBuilder[I, Email] = FeatureBuilder[I, Email](name = name) + def ID[I: WeakTypeTag](name: String): FeatureBuilder[I, ID] = FeatureBuilder[I, ID](name = name) + def Phone[I: WeakTypeTag](name: String): FeatureBuilder[I, Phone] = FeatureBuilder[I, Phone](name = name) + def PickList[I: WeakTypeTag](name: String): FeatureBuilder[I, PickList] = FeatureBuilder[I, PickList](name = name) + def Text[I: WeakTypeTag](name: String): FeatureBuilder[I, Text] = FeatureBuilder[I, Text](name = name) + def TextArea[I: WeakTypeTag](name: String): FeatureBuilder[I, TextArea] = FeatureBuilder[I, TextArea](name = name) + def URL[I: WeakTypeTag](name: String): FeatureBuilder[I, URL] = FeatureBuilder[I, URL](name = name) + def Country[I: WeakTypeTag](name: String): FeatureBuilder[I, Country] = FeatureBuilder[I, Country](name = name) + def State[I: WeakTypeTag](name: String): FeatureBuilder[I, State] = FeatureBuilder[I, State](name = name) + def City[I: WeakTypeTag](name: String): FeatureBuilder[I, City] = FeatureBuilder[I, City](name = name) + def PostalCode[I: WeakTypeTag](name: String): FeatureBuilder[I, PostalCode] = FeatureBuilder[I, PostalCode](name = name) + def Street[I: WeakTypeTag](name: String): FeatureBuilder[I, Street] = FeatureBuilder[I, Street](name = name) def apply[I: WeakTypeTag, O <: FeatureType : WeakTypeTag](name: String): FeatureBuilder[I, O] = new FeatureBuilder[I, O](name) @@ -184,7 +276,7 @@ object FeatureBuilder { * @param response response feature name * @param nonNullable optional non nullable feature names * @throws IllegalArgumentException if fails to map dataframe field type into a feature type - * @throws RuntimeException if fails to construct a response feature + * @throws RuntimeException if fails to construct a response feature * @return label and other features */ def fromDataFrame[ResponseType <: FeatureType : WeakTypeTag]( @@ -215,20 +307,32 @@ object FeatureBuilder { } responseFeature -> features } + def fromRow[O <: FeatureType : WeakTypeTag](implicit name: sourcecode.Name): FeatureBuilderWithExtract[Row, O] = fromRow[O](name.value, None) + def fromRow[O <: FeatureType : WeakTypeTag](name: String): FeatureBuilderWithExtract[Row, O] = fromRow[O](name, None) + def fromRow[O <: FeatureType : WeakTypeTag](index: Int)(implicit name: sourcecode.Name): FeatureBuilderWithExtract[Row, O] = fromRow[O](name.value, Some(index)) + def fromRow[O <: FeatureType : WeakTypeTag](name: String, index: Option[Int]): FeatureBuilderWithExtract[Row, O] = { - val c = FeatureTypeSparkConverter[O]() + new FeatureBuilderWithExtract[Row, O]( name = name, - extractFn = (r: Row) => c.fromSpark(index.map(r.get).getOrElse(r.getAny(name))), - extractSource = "(r: Row) => c.fromSpark(index.map(r.get).getOrElse(r.getAny(name)))" + extractFn = FromRowExtractFn(index, name), + extractSource = s"FromRowExtractFn($index, $name)" ) } + // scalastyle:on } +case class FromRowExtractFn[O <: FeatureType](index: Option[Int], name: String) + (implicit tto: WeakTypeTag[O]) extends Function1[Row, O] with Serializable { + val c = FeatureTypeSparkConverter[O]() + + override def apply(r: Row): O = c.fromSpark(index.map(r.get).getOrElse(r.getAny(name))) +} + /** * Feature Builder allows building features * @@ -244,7 +348,7 @@ final class FeatureBuilder[I, O <: FeatureType](val name: String) { * @param fn a function to extract value of the feature from the raw data */ def extract(fn: I => O): FeatureBuilderWithExtract[I, O] = - macro FeatureBuilderMacros.extract[I, O] + macro FeatureBuilderMacros.extract[I, O] /** * Feature extract method - a function to extract value of the feature from the raw data. @@ -254,7 +358,7 @@ final class FeatureBuilder[I, O <: FeatureType](val name: String) { * @param default the default value */ def extract(fn: I => O, default: O): FeatureBuilderWithExtract[I, O] = - macro FeatureBuilderMacros.extractWithDefault[I, O] + macro FeatureBuilderMacros.extractWithDefault[I, O] } @@ -270,33 +374,14 @@ final class FeatureBuilderWithExtract[I, O <: FeatureType] val name: String, val extractFn: I => O, val extractSource: String -)(implicit val tti: WeakTypeTag[I], val tto: WeakTypeTag[O]) { +)(implicit val tti: WeakTypeTag[I], val tto: WeakTypeTag[O], val ttov: WeakTypeTag[O#Value]) { var aggregator: MonoidAggregator[Event[O], _, O] = MonoidAggregatorDefaults.aggregatorOf[O](tto) var aggregateWindow: Option[Duration] = None - /** - * Feature aggregation function with zero value - * @param zero a zero element for aggregation - * @param fn aggregation function - */ - def aggregate(zero: O#Value, fn: (O#Value, O#Value) => O#Value): this.type = { - aggregator = new CustomMonoidAggregator[O](associativeFn = fn, zero = zero)(tto) - this - } - - /** - * Feature aggregation function with zero value of [[FeatureTypeDefaults.default[O].value]] - * @param fn aggregation function - */ - def aggregate(fn: (O#Value, O#Value) => O#Value): this.type = { - val zero = FeatureTypeDefaults.default[O](tto).value - aggregator = new CustomMonoidAggregator[O](associativeFn = fn, zero = zero)(tto) - this - } - /** * Feature aggregation with a monoid aggregator + * * @param monoid a monoid aggregator */ def aggregate(monoid: MonoidAggregator[Event[O], _, O]): this.type = { @@ -306,6 +391,7 @@ final class FeatureBuilderWithExtract[I, O <: FeatureType] /** * Aggregation time window + * * @param time a time period during which to include features in aggregation */ def window(time: Duration): this.type = { @@ -315,12 +401,14 @@ final class FeatureBuilderWithExtract[I, O <: FeatureType] /** * Make a predictor feature + * * @return a predictor feature */ def asPredictor: Feature[O] = makeFeature(isResponse = false) /** * Make a response feature + * * @return a response feature */ def asResponse: Feature[O] = makeFeature(isResponse = true) @@ -334,7 +422,7 @@ final class FeatureBuilderWithExtract[I, O <: FeatureType] outputName = name, outputIsResponse = isResponse, aggregateWindow = aggregateWindow - )(tti, tto) + )(Left(tti), tto) originStage.getOutput().asInstanceOf[Feature[O]] } diff --git a/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala b/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala index 17346b0d3c..40ea7e28c6 100644 --- a/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala +++ b/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala @@ -32,15 +32,18 @@ package com.salesforce.op.stages import com.salesforce.op.UID import com.salesforce.op.aggregators.{Event, FeatureAggregator, GenericFeatureAggregator} -import com.salesforce.op.features.types.FeatureType -import com.salesforce.op.features.{Feature, FeatureLike, FeatureUID, OPFeature} +import com.salesforce.op.features.types.{FeatureType, Text} +import com.salesforce.op.features._ +import com.salesforce.op.utils.reflection.ReflectionUtils import com.twitter.algebird.MonoidAggregator import org.apache.spark.ml.PipelineStage import org.apache.spark.util.ClosureUtils import org.joda.time.Duration +import org.json4s.JValue +import org.json4s.JsonDSL._ import scala.reflect.runtime.universe.WeakTypeTag -import scala.util.Try +import scala.util.{Failure, Success, Try} /** * Origin stage for first features in workflow @@ -58,20 +61,29 @@ import scala.util.Try * @tparam I input data type * @tparam O output feature type */ + +@ReaderWriter(classOf[FeatureGeneratorStageReaderWriter[_, _ <: FeatureType]]) final class FeatureGeneratorStage[I, O <: FeatureType] ( val extractFn: I => O, val extractSource: String, val aggregator: MonoidAggregator[Event[O], _, O], - outputName: String, + val outputName: String, override val outputIsResponse: Boolean, val aggregateWindow: Option[Duration] = None, val uid: String = UID[FeatureGeneratorStage[I, O]] )( - implicit val tti: WeakTypeTag[I], + implicit val _tti: Either[WeakTypeTag[I], String], val tto: WeakTypeTag[O] ) extends PipelineStage with OpPipelineStage[O] with HasIn1 { + // this hack is required as Spark can't serialize run-time created + // TypeTags (because it is following the ReflectionUtils...) + def tti: WeakTypeTag[I] = _tti match { + case Left(x) => x + case Right(n) => ReflectionUtils.weakTypeTagForName(n).asInstanceOf[WeakTypeTag[I]] + } + setOutputFeatureName(outputName) override type InputFeatures = OPFeature @@ -107,3 +119,83 @@ final class FeatureGeneratorStage[I, O <: FeatureType] */ override def checkSerializable: Try[Unit] = ClosureUtils.checkSerializable(extractFn) } + + +class FeatureGeneratorStageReaderWriter[I, O <: FeatureType] + extends OpPipelineStageJsonReaderWriter[FeatureGeneratorStage[I, O]] with LambdaSerializer { + + private val FromRowExtractFnName = classOf[FromRowExtractFn[_]].getName + + /** + * Read stage from json + * + * @param stageClass stage class + * @param json json to read stage from + * @return read result + */ + override def read(stageClass: Class[FeatureGeneratorStage[I, O]], json: JValue): Try[FeatureGeneratorStage[I, O]] = { + Try { + val tto = FeatureType.featureTypeTag((json \ "tto").extract[String]).asInstanceOf[WeakTypeTag[O]] + val ttiName = (json \ "tti").extract[String] + val extractFnStr = (json \ "extractFn").extract[String] + + val extractFn = extractFnStr match { + case FromRowExtractFnName => { + val extractFnName = (json \ "extractFnName").extract[String] + val extractFnIdx = (json \ "extractFnIdx").extractOpt[Int] + ReflectionUtils.classForName(extractFnStr) + .getConstructors.head.newInstance(extractFnIdx, extractFnName, tto) + .asInstanceOf[Function1[I, O]] + } + case _ => ReflectionUtils.getInstanceOfObject[Function1[I, O]](extractFnStr) + } + + val aggregator = ReflectionUtils. + getInstanceOfObject[MonoidAggregator[Event[O], _, O]]((json \ "aggregator").extract[String]) + + val outputName = (json \ "outputName").extract[String] + val extractSource = (json \ "extractSource").extract[String] + val uid = (json \ "uid").extract[String] + val outputIsResponse = (json \ "outputIsResponse").extract[Boolean] + val aggregateWindow = (json \ "aggregateWindow").extractOpt[Int].map(x => Duration.standardSeconds(x)) + + new FeatureGeneratorStage(extractFn, extractSource, aggregator, + outputName, outputIsResponse, aggregateWindow, uid)(_tti = Right(ttiName), tto) + } + + } + + /** + * Write stage to json + * + * @param stage stage instance to write + * @return write result + */ + override def write(stage: FeatureGeneratorStage[I, O]): Try[JValue] = { + for { + extractFn <- Try { + stage.extractFn match { + case c: FromRowExtractFn[_] => c.getClass.getName + case _ => serializeFunction("extractFn", stage.extractFn).value.toString + } + } + aggregatorFn <- Try(serializeFunction("aggregator", stage.aggregator)) + } yield { + val res = ("tti" -> stage.tti.tpe.typeSymbol.fullName) ~ + ("tto" -> FeatureType.typeName(stage.tto)) ~ + ("aggregator" -> aggregatorFn.value.toString) ~ + ("extractFn" -> extractFn) ~ + ("outputName" -> stage.outputName) ~ + ("aggregateWindow" -> stage.aggregateWindow.map(_.toStandardSeconds.getSeconds)) ~ + ("uid" -> stage.uid) ~ + ("extractSource" -> stage.extractSource) ~ + ("outputIsResponse" -> stage.outputIsResponse) + + stage.extractFn match { + case x: FromRowExtractFn[_] => res ~ ("extractFnIdx" -> x.index) ~ ("extractFnName" -> x.name) + case _ => res + } + } + + } +} diff --git a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageJsonReaderWriter.scala b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageJsonReaderWriter.scala index 25b6446723..0f030a289e 100644 --- a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageJsonReaderWriter.scala +++ b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageJsonReaderWriter.scala @@ -76,7 +76,7 @@ trait OpPipelineStageJsonReaderWriter[StageType <: OpPipelineStageBase] extends final class DefaultOpPipelineStageJsonReaderWriter[StageType <: OpPipelineStageBase] ( implicit val ct: ClassTag[StageType] -) extends OpPipelineStageJsonReaderWriter[StageType] { +) extends OpPipelineStageJsonReaderWriter[StageType] with LambdaSerializer { /** * Read stage from json @@ -205,19 +205,22 @@ final class DefaultOpPipelineStageJsonReaderWriter[StageType <: OpPipelineStageB Extraction.decompose(args.toMap) } - private def serializeFunction(argName: String, function: AnyRef): AnyValue = { + + private def jsonSerialize(v: Any): JValue = render(Extraction.decompose(v)) +} + +trait LambdaSerializer { + protected def serializeFunction(argName: String, function: AnyRef): AnyValue = { try { val functionClass = function.getClass // Test that function has no external dependencies and can be constructed without ctor args - functionClass.getConstructors.head.newInstance() + functionClass.getConstructors.headOption.foreach(_.newInstance()) AnyValue(AnyValueTypes.ClassInstance, functionClass.getName, Option(functionClass.getName)) } catch { case error: Exception => throw new RuntimeException( - s"Function argument '$argName' cannot be serialized. " + + s"Function argument '$argName' [${function.getClass.getName}] cannot be serialized. " + "Make sure your function does not have any external dependencies, " + "e.g. use any out of scope variables.", error) } } - - private def jsonSerialize(v: Any): JValue = render(Extraction.decompose(v)) } diff --git a/features/src/test/scala/com/salesforce/op/features/FeatureBuilderTest.scala b/features/src/test/scala/com/salesforce/op/features/FeatureBuilderTest.scala index 756f19d7e7..c33315e969 100644 --- a/features/src/test/scala/com/salesforce/op/features/FeatureBuilderTest.scala +++ b/features/src/test/scala/com/salesforce/op/features/FeatureBuilderTest.scala @@ -160,7 +160,7 @@ class FeatureBuilderTest extends FlatSpec with TestSparkContext { val feature = FeatureBuilder.Real[Passenger] .extract(p => Option(p.getAge).map(_.toDouble).toReal) - .aggregate((v1, _) => v1) + .aggregate(TestCustomMonoidAggregator) .asPredictor assertFeature[Passenger, Real](feature)(name = name, in = passenger, out = 1.toReal, @@ -171,7 +171,7 @@ class FeatureBuilderTest extends FlatSpec with TestSparkContext { it should "build an aggregated feature with a custom aggregate function with zero" in { val feature = FeatureBuilder.Real[Passenger] .extract(p => Option(p.getAge).map(_.toDouble).toReal) - .aggregate(Real.empty.v, (v1, _) => v1) + .aggregate(TestCustomMonoidAggregator) .asPredictor assertFeature[Passenger, Real](feature)(name = name, in = passenger, out = 1.toReal, @@ -181,6 +181,9 @@ class FeatureBuilderTest extends FlatSpec with TestSparkContext { } +object TestCustomMonoidAggregator extends CustomMonoidAggregator[Real](zero = Real.empty.v, + associativeFn = (v1, _) => v1) + /** * Assert feature instance on a given input/output */ diff --git a/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala b/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala index 2645f1cdf9..04011fefb2 100644 --- a/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala +++ b/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala @@ -34,36 +34,70 @@ import com.salesforce.op.features.types._ import com.salesforce.op.features.{FeatureBuilder, OPFeature} import com.salesforce.op.utils.tuples.RichTuple._ import org.joda.time.Duration - +import PassengerFeaturesTestLambdas._ +import com.salesforce.op.aggregators.CustomMonoidAggregator trait PassengerFeaturesTest { val age = FeatureBuilder.Real[Passenger] - .extract(_.getAge.toReal) - .aggregate((l, r) => (l -> r).map(breeze.linalg.max(_, _))) + .extract(ageFnc) + .aggregate(TestMonoidAggregator) .asPredictor - val gender = FeatureBuilder.MultiPickList[Passenger].extract(p => Set(p.getGender).toMultiPickList).asPredictor - val genderPL = FeatureBuilder.PickList[Passenger].extract(p => p.getGender.toPickList).asPredictor + val gender = FeatureBuilder.MultiPickList[Passenger].extract(genderFnc).asPredictor + val genderPL = FeatureBuilder.PickList[Passenger].extract(genderPLFnc).asPredictor val height = FeatureBuilder.RealNN[Passenger] - .extract(p => Option(p.getHeight).map(_.toDouble).toRealNN(0.0)) + .extract(heightFnc) .window(Duration.millis(300)) .asPredictor - val heightNoWindow = FeatureBuilder.Real[Passenger].extract(_.getHeight.toReal).asPredictor - val weight = FeatureBuilder.Real[Passenger].extract(_.getWeight.toReal).asPredictor - val description = FeatureBuilder.Text[Passenger].extract(_.getDescription.toText).asPredictor - val boarded = FeatureBuilder.DateList[Passenger].extract(p => Seq(p.getBoarded.toLong).toDateList).asPredictor - val stringMap = FeatureBuilder.TextMap[Passenger].extract(p => p.getStringMap.toTextMap).asPredictor - val numericMap = FeatureBuilder.RealMap[Passenger].extract(p => p.getNumericMap.toRealMap).asPredictor - val booleanMap = FeatureBuilder.BinaryMap[Passenger].extract(p => p.getBooleanMap.toBinaryMap).asPredictor - val survived = FeatureBuilder.Binary[Passenger].extract(p => Option(p.getSurvived).map(_ == 1).toBinary).asResponse - val boardedTime = FeatureBuilder.Date[Passenger].extract(_.getBoarded.toLong.toDate).asPredictor - val boardedTimeAsDateTime = FeatureBuilder.DateTime[Passenger].extract(_.getBoarded.toLong.toDateTime).asPredictor + val heightNoWindow = FeatureBuilder.Real[Passenger].extract(heightToReal).asPredictor + val weight = FeatureBuilder.Real[Passenger].extract(weightToReal).asPredictor + val description = FeatureBuilder.Text[Passenger].extract(descrToText).asPredictor + val boarded = FeatureBuilder.DateList[Passenger].extract(boardedToDL).asPredictor + val stringMap = FeatureBuilder.TextMap[Passenger].extract(stringMapFnc).asPredictor + val numericMap = FeatureBuilder.RealMap[Passenger].extract(numericMapFnc).asPredictor + val booleanMap = FeatureBuilder.BinaryMap[Passenger].extract(booleanMapFnc).asPredictor + val survived = FeatureBuilder.Binary[Passenger].extract(survivedFnc).asResponse + val boardedTime = FeatureBuilder.Date[Passenger].extract(boardedTimeFnc).asPredictor + val boardedTimeAsDateTime = FeatureBuilder.DateTime[Passenger].extract(boardedDTFnc).asPredictor val rawFeatures: Array[OPFeature] = Array( survived, age, gender, height, weight, description, boarded, stringMap, numericMap, booleanMap ) } + +object TestMonoidAggregator + extends CustomMonoidAggregator[Real](None, (l, r) => (l -> r).map(breeze.linalg.max(_, _))) with Serializable + +object PassengerFeaturesTestLambdas { + def genderFnc: (Passenger => MultiPickList) = p => Set(p.getGender).toMultiPickList + + def genderPLFnc: (Passenger => PickList) = p => p.getGender.toPickList + + def heightFnc: (Passenger => RealNN) = p => Option(p.getHeight).map(_.toDouble).toRealNN(0.0) + + def heightToReal: (Passenger => Real) = _.getHeight.toReal + + def weightToReal: (Passenger => Real) = _.getWeight.toReal + + def descrToText: (Passenger => Text) = _.getDescription.toText + + def boardedToDL: (Passenger => DateList) = p => Seq(p.getBoarded.toLong).toDateList + + def stringMapFnc: (Passenger => TextMap) = p => p.getStringMap.toTextMap + + def numericMapFnc: (Passenger => RealMap) = p => p.getNumericMap.toRealMap + + def booleanMapFnc: (Passenger => BinaryMap) = p => p.getBooleanMap.toBinaryMap + + def survivedFnc: (Passenger => Binary) = p => Option(p.getSurvived).map(_ == 1).toBinary + + def boardedTimeFnc: (Passenger => Date) = _.getBoarded.toLong.toDate + + def boardedDTFnc: (Passenger => DateTime) = _.getBoarded.toLong.toDateTime + + def ageFnc: (Passenger => Real) = _.getAge.toReal +} diff --git a/readers/src/test/scala/com/salesforce/op/readers/DataReadersTest.scala b/readers/src/test/scala/com/salesforce/op/readers/DataReadersTest.scala index 4cfcbb48e8..21c52053d5 100644 --- a/readers/src/test/scala/com/salesforce/op/readers/DataReadersTest.scala +++ b/readers/src/test/scala/com/salesforce/op/readers/DataReadersTest.scala @@ -31,7 +31,7 @@ package com.salesforce.op.readers import com.salesforce.op.OpParams -import com.salesforce.op.aggregators.CutOffTime +import com.salesforce.op.aggregators.{CustomMonoidAggregator, CutOffTime} import com.salesforce.op.features.FeatureBuilder import com.salesforce.op.features.types._ import com.salesforce.op.test._ @@ -62,7 +62,7 @@ class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest with TestC val survivedResponse = FeatureBuilder.Binary[PassengerCaseClass] .extract(_.survived.toBinary) - .aggregate(zero = Some(true), (l, r) => Some(l.getOrElse(false) && r.getOrElse(false))) + .aggregate(TestCustomMonoidAggregator) .asResponse val aggregateParameters = AggregateParams( @@ -175,7 +175,7 @@ class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest with TestC } } - aggReaders.foreach( reader => + aggReaders.foreach(reader => Spec(reader.getClass) should "read and aggregate data correctly" in { val data = reader.readDataset().collect() data.foreach(_ shouldBe a[PassengerCaseClass]) @@ -183,13 +183,13 @@ class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest with TestC val aggregatedData = reader.generateDataFrame(rawFeatures = Array(agePredictor, survivedResponse)).collect() aggregatedData.length shouldBe 6 - aggregatedData.collect { case r if r.get(0) == "4" => r} shouldEqual Array(Row("4", 60, false)) + aggregatedData.collect { case r if r.get(0) == "4" => r } shouldEqual Array(Row("4", 60, false)) reader.fullTypeName shouldBe typeOf[PassengerCaseClass].toString } ) - conditionalReaders.foreach( reader => + conditionalReaders.foreach(reader => Spec(reader.getClass) should "read and conditionally aggregate data correctly" in { val data = reader.readDataset().collect() data.foreach(_ shouldBe a[PassengerCaseClass]) @@ -204,3 +204,5 @@ class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest with TestC ) } +object TestCustomMonoidAggregator extends CustomMonoidAggregator[Binary](zero = Some(true), + (l, r) => Some(l.getOrElse(false) && r.getOrElse(false))) diff --git a/readers/src/test/scala/com/salesforce/op/readers/JoinedDataReaderDataGenerationTest.scala b/readers/src/test/scala/com/salesforce/op/readers/JoinedDataReaderDataGenerationTest.scala index 609a3552a7..99ee8b5381 100644 --- a/readers/src/test/scala/com/salesforce/op/readers/JoinedDataReaderDataGenerationTest.scala +++ b/readers/src/test/scala/com/salesforce/op/readers/JoinedDataReaderDataGenerationTest.scala @@ -30,7 +30,7 @@ package com.salesforce.op.readers -import com.salesforce.op.aggregators.CutOffTime +import com.salesforce.op.aggregators.{CustomMonoidAggregator, CutOffTime} import com.salesforce.op.features.types._ import com.salesforce.op.features.{FeatureBuilder, OPFeature} import com.salesforce.op.test._ @@ -51,13 +51,13 @@ class JoinedDataReaderDataGenerationTest extends FlatSpec with PassengerSparkFix val newWeight = FeatureBuilder.RealNN[PassengerCSV] .extract(_.getWeight.toDouble.toRealNN) - .aggregate(zero = Some(Double.MaxValue), (a, b) => Some(math.min(a.v.getOrElse(0.0), b.v.getOrElse(0.0)))) + .aggregate(TestMinCustomMonoidAggregator) .asPredictor val newHeight = FeatureBuilder.RealNN[PassengerCSV] .extract(_.getHeight.toDouble.toRealNN) - .aggregate(zero = Some(0.0), (a, b) => Some(math.max(a.v.getOrElse(0.0), b.v.getOrElse(0.0)))) + .aggregate(TestMaxCustomMonoidAggregator) .asPredictor val recordTime = FeatureBuilder.DateTime[PassengerCSV].extract(_.getRecordDate.toLong.toDateTime).asPredictor @@ -323,3 +323,10 @@ class JoinedDataReaderDataGenerationTest extends FlatSpec with PassengerSparkFix } } + +object TestMinCustomMonoidAggregator extends CustomMonoidAggregator[RealNN](zero = Some(Double.MaxValue), + (a, b) => Some(math.min(a.v.getOrElse(0.0), b.v.getOrElse(0.0)))) + + +object TestMaxCustomMonoidAggregator extends CustomMonoidAggregator[RealNN]( + zero = Some(0.0), (a, b) => Some(math.max(a.v.getOrElse(0.0), b.v.getOrElse(0.0)))) diff --git a/utils/src/main/scala/com/salesforce/op/utils/reflection/ReflectionUtils.scala b/utils/src/main/scala/com/salesforce/op/utils/reflection/ReflectionUtils.scala index cf7366063e..64a226688d 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/reflection/ReflectionUtils.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/reflection/ReflectionUtils.scala @@ -234,6 +234,65 @@ object ReflectionUtils { }) } + /** + * Create a WeakTypeTag for Type + * + * @param rtm runtime mirror + * @param tpe type + * @tparam T type T + * @return TypeTag[T] + */ + def weakTypeTagForType[T](tpe: Type): WeakTypeTag[T] = { + WeakTypeTag(runtimeMirror(), new api.TypeCreator { + def apply[U <: api.Universe with Singleton](m: api.Mirror[U]): U#Type = + if (m eq runtimeMirror()) tpe.asInstanceOf[U#Type] + else throw new IllegalArgumentException(s"Type tag defined in cannot be migrated to other mirrors.") + }) + } + + + /** + * Returns a Type Tag by string name + * + * @param rtm runtime mirror + * @param n class name + * @return TypeTag[_] + */ + def typeTagForName(rtm: Mirror = runtimeMirror(), n: String): TypeTag[_] = { + val clazz = classForName(n) + typeTagForType(rtm, rtm.classSymbol(clazz).toType) + } + + /** + * Returns a Weak Type Tag by string name + * + * @param rtm runtime mirror + * @param n class name + * @return TypeTag[_] + */ + def weakTypeTagForName(n: String): WeakTypeTag[_] = { + val clazz = classForName(n) + weakTypeTagForType(runtimeMirror().classSymbol(clazz).toType) + } + + + /** + * A helper function to get instance of lambda function or object + * @param name full name + * @return + */ + def getInstanceOfObject[T](name: String): T = { + val clazz = ReflectionUtils.classForName(name) + + val res = clazz.getConstructors.headOption match { + case Some(c) => c.newInstance() + case _ => { + clazz.getField("MODULE$").get(clazz) + } + } + res.asInstanceOf[T] + } + /** * Create a ClassTag for a WeakTypeTag *