Commit 13ddb4f

Workflow independent model loading (#274)
tovbinm committed Jun 14, 2019
1 parent 428eb4c commit 13ddb4f
Showing 70 changed files with 2,087 additions and 840 deletions.
7 changes: 4 additions & 3 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala
@@ -135,9 +135,10 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
* @return Updated instance of feature
*/
def getUpdatedFeatures(features: Array[OPFeature]): Array[OPFeature] = {
val allFeatures = rawFeatures ++ blacklistedFeatures ++ stages.map(_.getOutput())
features.map{f => allFeatures.find(_.sameOrigin(f))
.getOrElse(throw new IllegalArgumentException(s"feature $f is not a part of this workflow"))
val allFeatures = getRawFeatures() ++ getBlacklist() ++ getStages().map(_.getOutput())
features.map { f =>
allFeatures.find(_.sameOrigin(f))
.getOrElse(throw new IllegalArgumentException(s"feature $f is not a part of this workflow"))
}
}

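The updated `getUpdatedFeatures` resolves features through the public accessors (`getRawFeatures()`, `getBlacklist()`, `getStages()`) rather than the private fields, so it behaves the same on a freshly trained model and on one loaded from disk. A minimal usage sketch under that reading — `workflow`, `price`, `prediction`, and the path are illustrative names, not part of this change:

```scala
import com.salesforce.op.{OpWorkflow, OpWorkflowModel}
import com.salesforce.op.features.OPFeature

// Swap feature references built against the original workflow for the loaded
// model's own copies (matched by origin); getUpdatedFeatures throws
// IllegalArgumentException if a feature was never part of the workflow.
def refreshFeatures(workflow: OpWorkflow, price: OPFeature, prediction: OPFeature): Array[OPFeature] = {
  val model: OpWorkflowModel = workflow.loadModel("/tmp/op-model")
  model.getUpdatedFeatures(Array(price, prediction))
}
```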
157 changes: 86 additions & 71 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala
@@ -30,15 +30,17 @@

package com.salesforce.op

import com.salesforce.op.OpWorkflowModelReadWriteShared.{FieldNames => FN}
import com.salesforce.op.OpWorkflowModelReadWriteShared.FieldNames._
import com.salesforce.op.features.{FeatureJsonHelper, OPFeature, TransientFeature}
import com.salesforce.op.filters.{FeatureDistribution, RawFeatureFilterResults}
import com.salesforce.op.stages.OpPipelineStageReadWriteShared._
import com.salesforce.op.stages.OpPipelineStageReaderWriter._
import com.salesforce.op.stages._
import org.apache.spark.ml.util.MLReader
import org.json4s.JsonAST.{JArray, JNothing, JValue}
import org.json4s.jackson.JsonMethods.parse

import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

/**
@@ -50,7 +50,6 @@ import scala.util.{Failure, Success, Try}
*/
class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReader[OpWorkflowModel] {


/**
* Load a previously trained workflow model from path
*
@@ -72,7 +73,9 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReade
* @param path to the trained workflow model
* @return workflow model
*/
def loadJson(json: String, path: String): Try[OpWorkflowModel] = Try(parse(json)).flatMap(loadJson(_, path = path))
def loadJson(json: String, path: String): Try[OpWorkflowModel] = {
Try(parse(json)).flatMap(loadJson(_, path = path))
}

/**
* Load Workflow instance from json
@@ -81,98 +84,110 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReade
* @param path to the trained workflow model
* @return workflow model instance
*/
def loadJson(json: JValue, path: String): Try[OpWorkflowModel] = workflowOpt match {
case None =>
throw new NotImplementedError("Loading models without the original workflow is currently not supported")

case Some(workflow) =>
for {
trainParams <- OpParams.fromString((json \ TrainParameters.entryName).extract[String])
params <- OpParams.fromString((json \ Parameters.entryName).extract[String])
model <- Try(new OpWorkflowModel(uid = (json \ Uid.entryName).extract[String], trainParams))
(stages, resultFeatures) <- Try(resolveFeaturesAndStages(workflow, json, path))
blacklist <- Try(resolveBlacklist(workflow, json))
blacklistMapKeys <- Try(resolveBlacklistMapKeys(json))
results <- resolveRawFeatureFilterResults(json)
} yield model
.setStages(stages.filterNot(_.isInstanceOf[FeatureGeneratorStage[_, _]]))
.setFeatures(resultFeatures)
.setParameters(params)
.setBlacklist(blacklist)
.setBlacklistMapKeys(blacklistMapKeys)
.setRawFeatureFilterResults(results)
def loadJson(json: JValue, path: String): Try[OpWorkflowModel] = {
for {
trainingParams <- OpParams.fromString((json \ TrainParameters.entryName).extract[String])
params <- OpParams.fromString((json \ Parameters.entryName).extract[String])
model <- Try(new OpWorkflowModel(uid = (json \ Uid.entryName).extract[String], trainingParams))
stages <- loadStages(json, workflowOpt, path)
resolvedFeatures <- resolveFeatures(json, stages)
resultFeatures <- resolveResultFeatures(json, resolvedFeatures)
blacklist <- resolveBlacklist(json, workflowOpt, resolvedFeatures, path)
blacklistMapKeys <- resolveBlacklistMapKeys(json)
rffResults <- resolveRawFeatureFilterResults(json)
} yield model
.setStages(stages.filterNot(_.isInstanceOf[FeatureGeneratorStage[_, _]]))
.setFeatures(resultFeatures)
.setParameters(params)
.setBlacklist(blacklist)
.setBlacklistMapKeys(blacklistMapKeys)
.setRawFeatureFilterResults(rffResults)
}

private def resolveBlacklist(workflow: OpWorkflow, json: JValue): Array[OPFeature] = {
if ((json \ BlacklistedFeaturesUids.entryName) != JNothing) { // for backwards compatibility
val blacklistIds = (json \ BlacklistedFeaturesUids.entryName).extract[JArray].arr
val allFeatures = workflow.getRawFeatures() ++ workflow.getBlacklist() ++
workflow.getStages().flatMap(_.getInputFeatures()) ++
workflow.getResultFeatures()
blacklistIds.flatMap(uid => allFeatures.find(_.uid == uid.extract[String])).toArray
} else {
Array.empty[OPFeature]
}
private def loadStages(json: JValue, wfOpt: Option[OpWorkflow], path: String): Try[Array[OPStage]] = {
wfOpt.map(wf => loadStages(json, Stages, wf, path)).getOrElse(loadStages(json, Stages, path).map(_._1))
}

private def resolveBlacklistMapKeys(json: JValue): Map[String, Set[String]] = {
(json \ BlacklistedMapKeys.entryName).extractOpt[Map[String, List[String]]] match {
case Some(blackMapKeys) => blackMapKeys.map { case (k, vs) => k -> vs.toSet }
case None => Map.empty
private def loadStages(json: JValue, field: FN, path: String): Try[(Array[OPStage], Array[OPFeature])] = Try {
val stagesJs = (json \ field.entryName).extract[JArray].arr
val (recoveredStages, recoveredFeatures) = ArrayBuffer.empty[OPStage] -> ArrayBuffer.empty[OPFeature]
for {j <- stagesJs} {
val stage = new OpPipelineStageReader(recoveredFeatures).loadFromJson(j, path = path).asInstanceOf[OPStage]
recoveredStages += stage
recoveredFeatures += stage.getOutput()
}
recoveredStages.toArray -> recoveredFeatures.toArray
}

private def resolveFeaturesAndStages
(
workflow: OpWorkflow,
json: JValue,
path: String
): (Array[OPStage], Array[OPFeature]) = {
val stages = loadStages(workflow, json, path)
val stagesMap = stages.map(stage => stage.uid -> stage).toMap[String, OPStage]
val featuresMap = resolveFeatures(json, stagesMap)
resolveStages(stages, featuresMap)

val resultIds = (json \ ResultFeaturesUids.entryName).extract[Array[String]]
val resultFeatures = featuresMap.filterKeys(resultIds.toSet).values

stages.toArray -> resultFeatures.toArray
}

private def loadStages(workflow: OpWorkflow, json: JValue, path: String): Seq[OPStage] = {
val stagesJs = (json \ Stages.entryName).extract[JArray].arr
private def loadStages(json: JValue, field: FN, workflow: OpWorkflow, path: String): Try[Array[OPStage]] = Try {
val generators = workflow.getRawFeatures().map(_.originStage)
val stagesJs = (json \ field.entryName).extract[JArray].arr
val recoveredStages = stagesJs.flatMap { j =>
val stageUidOpt = (j \ Uid.entryName).extractOpt[String]
stageUidOpt.map { stageUid =>
val originalStage = workflow.getStages().find(_.uid == stageUid)
originalStage match {
case Some(os) => new OpPipelineStageReader(os).loadFromJson(j, path = path).asInstanceOf[OPStage]
case None => throw new RuntimeException(s"Workflow does not contain a stage with uid: $stageUid")
}
val stageUid = (j \ Uid.entryName).extract[String]
val originalStage = workflow.getStages().find(_.uid == stageUid)
originalStage match {
case Some(os) => Option(
new OpPipelineStageReader(os).loadFromJson(j, path = path).asInstanceOf[OPStage]
)
case None if generators.exists(_.uid == stageUid) => None // skip the generators since they are already in the workflow
case None => throw new RuntimeException(s"Workflow does not contain a stage with uid: $stageUid")
}
}
val generators = workflow.getRawFeatures().map(_.originStage)
generators ++ recoveredStages
}

private def resolveFeatures(json: JValue, stages: Map[String, OPStage]): Map[String, OPFeature] = {
val results = (json \ AllFeatures.entryName).extract[JArray].arr
private def resolveFeatures(json: JValue, stages: Array[OPStage]): Try[Array[OPFeature]] = Try {
val featuresArr = (json \ AllFeatures.entryName).extract[JArray].arr
val stagesMap = stages.map(stage => stage.uid -> stage).toMap[String, OPStage]

// should have been serialized in topological order
// so that parent features can be used to construct each new feature
results.foldLeft(Map.empty[String, OPFeature])((featMap, feat) =>
FeatureJsonHelper.fromJson(feat, stages, featMap) match {
case Success(f) => featMap + (f.uid -> f)
val featuresMap = featuresArr.foldLeft(Map.empty[String, OPFeature])((featMap, feat) =>
FeatureJsonHelper.fromJson(feat, stagesMap, featMap) match {
case Failure(e) => throw new RuntimeException(s"Error resolving feature: $feat", e)
case Success(f) => featMap + (f.uid -> f)
}
)
}

private def resolveStages(stages: Seq[OPStage], featuresMap: Map[String, OPFeature]): Unit = {
// set input features to stages
for {stage <- stages} {
val inputIds = stage.getTransientFeatures().map(_.uid)
val inFeatures = inputIds.map(id => TransientFeature(featuresMap(id))) // features are order dependent
stage.set(stage.inputFeatures, inFeatures)
}
featuresMap.values.toArray
}

private def resolveResultFeatures(json: JValue, features: Array[OPFeature]): Try[Array[OPFeature]] = Try {
val resultIds = (json \ ResultFeaturesUids.entryName).extract[Array[String]].toSet
features.filter(f => resultIds.contains(f.uid))
}

private def resolveBlacklist
(
json: JValue,
wfOpt: Option[OpWorkflow],
features: Array[OPFeature],
path: String
): Try[Array[OPFeature]] = {
if ((json \ BlacklistedFeaturesUids.entryName) != JNothing) { // for backwards compatibility
for {
feats <- wfOpt
.map(wf => Success(wf.getAllFeatures() ++ wf.getBlacklist()))
.getOrElse(loadStages(json, BlacklistedStages, path).map(_._2))
allFeatures = features ++ feats
blacklistIds = (json \ BlacklistedFeaturesUids.entryName).extract[Array[String]]
} yield blacklistIds.flatMap(uid => allFeatures.find(_.uid == uid))
} else {
Success(Array.empty[OPFeature])
}
}

private def resolveBlacklistMapKeys(json: JValue): Try[Map[String, Set[String]]] = Try {
(json \ BlacklistedMapKeys.entryName).extractOpt[Map[String, List[String]]] match {
case Some(blackMapKeys) => blackMapKeys.map { case (k, vs) => k -> vs.toSet }
case None => Map.empty
}
}

private def resolveRawFeatureFilterResults(json: JValue): Try[RawFeatureFilterResults] = {
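The net effect of the reader rewrite: passing `workflowOpt = None` no longer throws `NotImplementedError`. Stages are deserialized from the JSON in order — each `OpPipelineStageReader` receives the features recovered so far — and the result features, blacklist, and raw feature filter results are then resolved from the reconstructed set. A sketch of both entry points, with an illustrative path:

```scala
import com.salesforce.op.{OpWorkflowModel, OpWorkflowModelReader}

// Loading with the original workflow in scope still works:
// val model = new OpWorkflowModelReader(Some(workflow)).load("/tmp/op-model")

// Workflow-independent loading, the new capability in this change: stages and
// features are rebuilt purely from the serialized JSON.
val model: OpWorkflowModel = new OpWorkflowModelReader(None).load("/tmp/op-model")
```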
35 changes: 29 additions & 6 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala
@@ -32,7 +32,7 @@ package com.salesforce.op

import com.salesforce.op.features.FeatureJsonHelper
import com.salesforce.op.filters.RawFeatureFilterResults
import com.salesforce.op.stages.{OpPipelineStageBase, OpPipelineStageWriter}
import com.salesforce.op.stages.{OPStage, OpPipelineStageWriter}
import enumeratum._
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.util.MLWriter
@@ -54,8 +54,7 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
implicit val jsonFormats: Formats = DefaultFormats

override protected def saveImpl(path: String): Unit = {
sc.parallelize(Seq(toJsonString(path)), 1)
.saveAsTextFile(OpWorkflowModelReadWriteShared.jsonPath(path))
sc.parallelize(Seq(toJsonString(path)), 1).saveAsTextFile(OpWorkflowModelReadWriteShared.jsonPath(path))
}

/**
@@ -78,6 +77,7 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
(FN.ResultFeaturesUids.entryName -> resultFeaturesJArray) ~
(FN.BlacklistedFeaturesUids.entryName -> blacklistFeaturesJArray()) ~
(FN.BlacklistedMapKeys.entryName -> blacklistMapKeys()) ~
(FN.BlacklistedStages.entryName -> blackListedStagesJArray(path)) ~
(FN.Stages.entryName -> stagesJArray(path)) ~
(FN.AllFeatures.entryName -> allFeaturesJArray) ~
(FN.Parameters.entryName -> model.getParameters().toJson(pretty = false)) ~
@@ -96,14 +96,36 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
JObject(model.getBlacklistMapKeys().map { case (k, vs) => k -> JArray(vs.map(JString).toList) }.toList)

/**
* Serialize all the workflow model stages
* Serialize all the model stages
*
* @param path path to store the spark params for stages
* @return array of serialized stages
*/
private def stagesJArray(path: String): JArray = {
val stages: Seq[OpPipelineStageBase] = model.getStages()
val stagesJson: Seq[JObject] = stages
val stages = model.getRawFeatures().map(_.originStage) ++ model.getStages()
stagesJArray(stages, path)
}

/**
* Serialize all the blacklisted model stages
*
* @param path path to store the spark params for stages
* @return array of serialized stages
*/
private def blackListedStagesJArray(path: String): JArray = {
val blacklistStages = model.getBlacklist().map(_.originStage)
stagesJArray(blacklistStages, path)
}

/**
* Serialize the stages
*
* @param stages the stages to serialize
* @param path path to store the spark params for stages
* @return array of serialized stages
*/
private def stagesJArray(stages: Array[OPStage], path: String): JArray = {
val stagesJson = stages
.map(_.write.asInstanceOf[OpPipelineStageWriter].writeToJson(path))
.filter(_.children.nonEmpty)
JArray(stagesJson.toList)
@@ -140,6 +162,7 @@ private[op] object OpWorkflowModelReadWriteShared {
case object ResultFeaturesUids extends FieldNames("resultFeaturesUids")
case object BlacklistedFeaturesUids extends FieldNames("blacklistedFeaturesUids")
case object BlacklistedMapKeys extends FieldNames("blacklistedMapKeys")
case object BlacklistedStages extends FieldNames("blacklistedStages")
case object Stages extends FieldNames("stages")
case object AllFeatures extends FieldNames("allFeatures")
case object Parameters extends FieldNames("parameters")
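The writer side mirrors this: raw feature generator stages are now serialized alongside the fitted stages, and blacklisted stages get their own `blacklistedStages` field — which is exactly what the reader's no-workflow path consumes. Saving is unchanged from the caller's perspective; a sketch, with `model` standing in for a trained `OpWorkflowModel` and an illustrative path:

```scala
import com.salesforce.op.{OpWorkflowModel, OpWorkflowModelWriter}

def saveModel(model: OpWorkflowModel): Unit = {
  // Writes a single JSON file containing the stages (now including raw
  // feature generators), blacklisted stages, all features, and parameters.
  new OpWorkflowModelWriter(model).save("/tmp/op-model")
}
```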
19 changes: 15 additions & 4 deletions core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala
@@ -49,11 +49,15 @@ trait RichDateFeature {

/**
* Convert to DateList feature
*
* @return
*/
def toDateList(): FeatureLike[DateList] = {
f.transformWith(
new UnaryLambdaTransformer[Date, DateList](operationName = "dateToList", _.value.toSeq.toDateList)
new UnaryLambdaTransformer[Date, DateList](
operationName = "dateToList",
RichDateFeatureLambdas.toDateList
)
)
}

@@ -70,7 +74,7 @@
*
* @param timePeriod The time period to extract from the timestamp
* @param others Other features of same type
* enum from: DayOfMonth, DayOfWeek, DayOfYear, HourOfDay, WeekOfMonth, WeekOfYear
* enum from: DayOfMonth, DayOfWeek, DayOfYear, HourOfDay, WeekOfMonth, WeekOfYear
*/
def toUnitCircle
(
@@ -126,13 +130,14 @@

/**
* Convert to DateTimeList feature
*
* @return
*/
def toDateTimeList(): FeatureLike[DateTimeList] = {
f.transformWith(
new UnaryLambdaTransformer[DateTime, DateTimeList](
operationName = "dateTimeToList",
_.value.toSeq.toDateTimeList
RichDateFeatureLambdas.toDateTimeList
)
)
}
@@ -150,7 +155,7 @@
*
* @param timePeriod The time period to extract from the timestamp
* @param others Other features of same type
* enum from: DayOfMonth, DayOfWeek, DayOfYear, HourOfDay, WeekOfMonth, WeekOfYear
* enum from: DayOfMonth, DayOfWeek, DayOfYear, HourOfDay, WeekOfMonth, WeekOfYear
*/
def toUnitCircle(
timePeriod: TimePeriod = TimePeriod.HourOfDay,
@@ -197,3 +202,9 @@
}

}

object RichDateFeatureLambdas {
def toDateList: Date => DateList = (x: Date) => x.value.toSeq.toDateList

def toDateTimeList: DateTime => DateTimeList = (x: DateTime) => x.value.toSeq.toDateTimeList
}
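Moving the transformer lambdas out of the `RichDateFeature` trait into the standalone `RichDateFeatureLambdas` object is presumably what makes these stages loadable without the workflow: a function defined on a top-level object can be re-created by name during deserialization, whereas an inline anonymous function captures the enclosing trait instance. Usage through the DSL should be unchanged — a sketch with an illustrative feature name (`signupDate`) and the project's usual enrichment imports assumed:

```scala
import com.salesforce.op._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types.{Date, DateList}

// `signupDate` is a placeholder for a FeatureLike[Date] defined elsewhere;
// toDateList() is now backed by the named RichDateFeatureLambdas.toDateList
// rather than an anonymous inline lambda.
def asDateList(signupDate: FeatureLike[Date]): FeatureLike[DateList] =
  signupDate.toDateList()
```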
