Skip to content

Commit

Permalink
Allow customizing serialization for FeatureGenerator extract function (
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm committed Jul 6, 2019
1 parent e6cc43f commit c25f4eb
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import java.lang.annotation.*;

/**
* Stage class annotation to specify custom reader/writer implementation of [[OpPipelineStageReaderWriter]].
* Reader/writer implementation must extend [[OpPipelineStageReaderWriter]] trait
* and has a single no arguments constructor.
* Stage or value class annotation to specify a custom reader/writer implementation of [[ValueReaderWriter]].
* The reader/writer implementation must extend the [[ValueReaderWriter]] trait and have a single no-argument constructor.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Inherited
public @interface ReaderWriter {

/**
* Reader/writer class extending [[OpPipelineStageReaderWriter]] to use when reading/writing the stage.
* It must extend [[OpPipelineStageReaderWriter]] trait and has a single no arguments constructor.
* Reader/writer class extending [[ValueReaderWriter]] to use when reading/writing the stage or its arguments.
* It must extend the [[ValueReaderWriter]] trait and have a single no-argument constructor.
*/
Class<?> value();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,19 @@ import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.stages.OpPipelineStageReaderWriter._
import com.salesforce.op.utils.reflection.ReflectionUtils
import org.apache.spark.ml.PipelineStage
import org.json4s.{JObject, JValue}
import org.json4s.jackson.JsonMethods.render
import org.json4s.{Extraction, _}
import org.json4s.{Extraction, JObject, JValue, _}

import scala.reflect.{ClassTag, ManifestFactory}
import scala.reflect.runtime.universe._
import scala.reflect.{ClassTag, ManifestFactory}
import scala.util.{Failure, Success, Try}

/**
* Default reader/writer for stages that uses reflection to reflect stage ctor arguments
*
* @tparam StageType stage type to read/write
*/
final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase]
(
implicit val ct: ClassTag[StageType]
) extends OpPipelineStageReaderWriter[StageType] with OpPipelineStageSerializationFuns {
final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase](implicit val ct: ClassTag[StageType])
extends OpPipelineStageReaderWriter[StageType] with OpPipelineStageSerializationFuns {

/**
* Read stage from json
Expand Down Expand Up @@ -179,6 +175,4 @@ final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase]
Extraction.decompose(args.toMap)
}


private def jsonSerialize(v: Any): JValue = render(Extraction.decompose(v))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.salesforce.op.stages

import com.salesforce.op.utils.reflection.ReflectionUtils
import org.json4s.JValue
import org.json4s.JsonAST.{JObject, JString}

import scala.reflect.ClassTag
import scala.util.Try


/**
 * Fallback [[ValueReaderWriter]] for stage argument values that are (de)serialized from/to trained
 * models using nothing but their class name and a no args ctor.
 *
 * The json produced by [[write]] and consumed by [[read]] is an object with a single
 * "className" string field.
 *
 * @param valueName value name, handed over to `serializeArgument` when writing
 * @tparam T value type to read/write
 */
final class DefaultValueReaderWriter[T <: AnyRef](valueName: String)(implicit val ct: ClassTag[T])
  extends ValueReaderWriter[T] with OpPipelineStageReadWriteFormats with OpPipelineStageSerializationFuns {

  /**
   * Recreate a value from json
   *
   * @param valueClass value class (not consulted here: the class to instantiate is named by the json)
   * @param json       json carrying the "className" field
   * @return new instance created via [[ReflectionUtils.newInstance]], or a failure if the field
   *         cannot be extracted or the instantiation throws
   */
  def read(valueClass: Class[T], json: JValue): Try[T] =
    Try((json \ "className").extract[String])
      .map(className => ReflectionUtils.newInstance[T](className))

  /**
   * Turn a value into json
   *
   * @param value value to write
   * @return json object whose "className" field holds the string form of the serialized
   *         argument value, or a failure if serialization throws
   */
  def write(value: T): Try[JValue] =
    Try(serializeArgument(valueName, value))
      .map(serialized => JObject("className" -> JString(serialized.value.toString)))

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import org.apache.spark.ml.PipelineStage
import org.apache.spark.util.ClosureUtils
import org.joda.time.Duration
import org.json4s.JValue
import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import com.salesforce.op.stages.ValueReaderWriter._

import scala.reflect.runtime.universe.WeakTypeTag
import scala.util.Try
Expand Down Expand Up @@ -136,35 +136,38 @@ class FeatureGeneratorStageReaderWriter[I, O <: FeatureType]
* @param json json to read stage from
* @return read result
*/
def read(stageClass: Class[FeatureGeneratorStage[I, O]], json: JValue): Try[FeatureGeneratorStage[I, O]] = {
Try {
val tti = (json \ "tti").extract[String]
val tto = FeatureType.featureTypeTag((json \ "tto").extract[String]).asInstanceOf[WeakTypeTag[O]]

val extractFnJson = json \ "extractFn"
val extractFnClassName = (extractFnJson \ "className").extract[String]
val extractFn = extractFnClassName match {
case c if classOf[FromRowExtractFn[_]].getName == c =>
val index = (extractFnJson \ "index").extractOpt[Int]
val name = (extractFnJson \ "name").extract[String]
FromRowExtractFn[O](index, name)(tto).asInstanceOf[Function1[I, O]]
case c =>
ReflectionUtils.newInstance[Function1[I, O]](c)
}

val aggregatorClassName = (json \ "aggregator" \ "className").extract[String]
val aggregator = ReflectionUtils.newInstance[MonoidAggregator[Event[O], _, O]](aggregatorClassName)

val outputName = (json \ "outputName").extract[String]
val extractSource = (json \ "extractSource").extract[String]
val uid = (json \ "uid").extract[String]
val outputIsResponse = (json \ "outputIsResponse").extract[Boolean]
val aggregateWindow = (json \ "aggregateWindow").extractOpt[Long].map(Duration.millis)

new FeatureGeneratorStage[I, O](extractFn, extractSource, aggregator,
outputName, outputIsResponse, aggregateWindow, uid, Right(tti))(tto)
def read(stageClass: Class[FeatureGeneratorStage[I, O]], json: JValue): Try[FeatureGeneratorStage[I, O]] = Try {
// Input type is kept as its stored name (handed to the stage as Right(tti) below), while the
// output type tag is looked up from the stored feature type name; the cast to WeakTypeTag[O] is unchecked
val tti = (json \ "tti").extract[String]
val tto = FeatureType.featureTypeTag((json \ "tto").extract[String]).asInstanceOf[WeakTypeTag[O]]

val extractFnJson = json \ "extractFn"
val extractFn = (extractFnJson \ "className").extract[String] match {
// FromRowExtractFn is rebuilt directly from its "index" and "name" fields
case extractFnClassName if classOf[FromRowExtractFn[_]].getName == extractFnClassName =>
val index = (extractFnJson \ "index").extractOpt[Int]
val name = (extractFnJson \ "name").extract[String]
FromRowExtractFn[O](index, name)(tto).asInstanceOf[Function1[I, O]]
// Any other extract function is read from the "value" field by the reader/writer resolved for its class
// (readerWriterFor presumably comes from the ValueReaderWriter._ import - TODO confirm);
// .get rethrows a failed read so that the enclosing Try captures it
case extractFnClassName =>
val extractFnClass = ReflectionUtils.classForName(extractFnClassName).asInstanceOf[Class[I => O]]
readerWriterFor(extractFnClass, "extractFn")
.read(extractFnClass, extractFnJson \ "value").get
}

// The aggregator follows the same scheme: load the class named by "className", then read the
// instance from the "value" field with the reader/writer resolved for that class
val aggregatorJson = json \ "aggregator"
val aggregatorClassName = (aggregatorJson \ "className").extract[String]
val aggregatorClass = ReflectionUtils.classForName(aggregatorClassName)
.asInstanceOf[Class[MonoidAggregator[Event[O], _, O]]]
val aggregator =
readerWriterFor(aggregatorClass, "aggregator")
.read(aggregatorClass, aggregatorJson \ "value").get

val outputName = (json \ "outputName").extract[String]
val extractSource = (json \ "extractSource").extract[String]
val uid = (json \ "uid").extract[String]
val outputIsResponse = (json \ "outputIsResponse").extract[Boolean]
// "aggregateWindow" is optional; when present it is a Long interpreted as milliseconds
val aggregateWindow = (json \ "aggregateWindow").extractOpt[Long].map(Duration.millis)

new FeatureGeneratorStage[I, O](extractFn, extractSource, aggregator,
outputName, outputIsResponse, aggregateWindow, uid, Right(tti))(tto)
}

/**
Expand All @@ -175,17 +178,23 @@ class FeatureGeneratorStageReaderWriter[I, O <: FeatureType]
*/
def write(stage: FeatureGeneratorStage[I, O]): Try[JValue] = {
for {
extractFn <- Try {
extractFn <- {
stage.extractFn match {
case e: FromRowExtractFn[_] =>
("className" -> e.getClass.getName) ~ ("index" -> e.index) ~ ("name" -> e.name)
case e =>
("className" -> serializeArgument("extractFn", e).value.toString) ~ JObject()
case extract: FromRowExtractFn[_] => Try {
("className" -> extract.getClass.getName) ~ ("index" -> extract.index) ~ ("name" -> extract.name)
}
case extract => {
val extractClass = extract.getClass.asInstanceOf[Class[I => O]]
readerWriterFor(extractClass, "extractFn")
.write(extract).map { j => ("className" -> extractClass.getName) ~ ("value" -> j) }
}
}
}
aggregator <- Try(
("className" -> serializeArgument("aggregator", stage.aggregator).value.toString) ~ JObject()
)
aggregator <- {
val aggregatorClass = stage.aggregator.getClass.asInstanceOf[Class[MonoidAggregator[Event[O], _, O]]]
readerWriterFor[MonoidAggregator[Event[O], _, O]](aggregatorClass, "aggregator")
.write(stage.aggregator).map { j => ("className" -> aggregatorClass.getName) ~ ("value" -> j) }
}
} yield {
("tti" -> stage.tti.tpe.typeSymbol.fullName) ~
("tto" -> FeatureType.typeName(stage.tto)) ~
Expand Down

0 comments on commit c25f4eb

Please sign in to comment.