diff --git a/features/src/main/java/com/salesforce/op/stages/ReaderWriter.java b/features/src/main/java/com/salesforce/op/stages/ReaderWriter.java index 55adb8349a..e5b5c1b1a6 100644 --- a/features/src/main/java/com/salesforce/op/stages/ReaderWriter.java +++ b/features/src/main/java/com/salesforce/op/stages/ReaderWriter.java @@ -3,9 +3,8 @@ import java.lang.annotation.*; /** - * Stage class annotation to specify custom reader/writer implementation of [[OpPipelineStageReaderWriter]]. - * Reader/writer implementation must extend [[OpPipelineStageReaderWriter]] trait - * and has a single no arguments constructor. + * Stage of value class annotation to specify custom reader/writer implementation of [[ValueReaderWriter]]. + * Reader/writer implementation must extend [[ValueReaderWriter]] trait and has a single no arguments constructor. */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @@ -13,8 +12,8 @@ public @interface ReaderWriter { /** - * Reader/writer class extending [[OpPipelineStageReaderWriter]] to use when reading/writing the stage. - * It must extend [[OpPipelineStageReaderWriter]] trait and has a single no arguments constructor. + * Reader/writer class extending [[ValueReaderWriter]] to use when reading/writing the stage or it's arguments. + * It must extend [[ValueReaderWriter]] trait and has a single no arguments constructor. */ Class value(); diff --git a/features/src/main/scala/com/salesforce/op/stages/DefaultOpPipelineStageReaderWriter.scala b/features/src/main/scala/com/salesforce/op/stages/DefaultOpPipelineStageReaderWriter.scala index 389ba93746..bdc2b75a51 100644 --- a/features/src/main/scala/com/salesforce/op/stages/DefaultOpPipelineStageReaderWriter.scala +++ b/features/src/main/scala/com/salesforce/op/stages/DefaultOpPipelineStageReaderWriter.scala @@ -34,12 +34,10 @@ import com.salesforce.op.features.types.FeatureType import com.salesforce.op.stages.OpPipelineStageReaderWriter._ import com.salesforce.op.utils.reflection.ReflectionUtils import org.apache.spark.ml.PipelineStage -import org.json4s.{JObject, JValue} -import org.json4s.jackson.JsonMethods.render -import org.json4s.{Extraction, _} +import org.json4s.{Extraction, JObject, JValue, _} -import scala.reflect.{ClassTag, ManifestFactory} import scala.reflect.runtime.universe._ +import scala.reflect.{ClassTag, ManifestFactory} import scala.util.{Failure, Success, Try} /** @@ -47,10 +45,8 @@ import scala.util.{Failure, Success, Try} * * @tparam StageType stage type to read/write */ -final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase] -( - implicit val ct: ClassTag[StageType] -) extends OpPipelineStageReaderWriter[StageType] with OpPipelineStageSerializationFuns { +final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase](implicit val ct: ClassTag[StageType]) + extends OpPipelineStageReaderWriter[StageType] with OpPipelineStageSerializationFuns { /** * Read stage from json @@ -179,6 +175,4 @@ final class DefaultOpPipelineStageReaderWriter[StageType <: OpPipelineStageBase] Extraction.decompose(args.toMap) } - - private def jsonSerialize(v: Any): JValue = render(Extraction.decompose(v)) } diff --git a/features/src/main/scala/com/salesforce/op/stages/DefaultValueReaderWriter.scala b/features/src/main/scala/com/salesforce/op/stages/DefaultValueReaderWriter.scala new file mode 100644 index 0000000000..a885aa09c3 --- /dev/null +++ b/features/src/main/scala/com/salesforce/op/stages/DefaultValueReaderWriter.scala @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages + +import com.salesforce.op.utils.reflection.ReflectionUtils +import org.json4s.JValue +import org.json4s.JsonAST.{JObject, JString} + +import scala.reflect.ClassTag +import scala.util.Try + + +/** + * Default value reader/writer implementation used to (de)serialize stage arguments from/to trained models + * based on their class name and no args ctor. + * + * @param valueName value name + * @tparam T value type to read/write + */ +final class DefaultValueReaderWriter[T <: AnyRef](valueName: String)(implicit val ct: ClassTag[T]) + extends ValueReaderWriter[T] with OpPipelineStageReadWriteFormats with OpPipelineStageSerializationFuns { + + /** + * Read value from json + * + * @param valueClass value class + * @param json json to read argument value from + * @return read result + */ + def read(valueClass: Class[T], json: JValue): Try[T] = Try { + val className = (json \ "className").extract[String] + ReflectionUtils.newInstance[T](className) + } + + /** + * Write value to json + * + * @param value value to write + * @return write result + */ + def write(value: T): Try[JValue] = Try { + val arg = serializeArgument(valueName, value) + JObject("className" -> JString(arg.value.toString)) + } + +} diff --git a/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala b/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala index bbbe6d9029..5f652884e8 100644 --- a/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala +++ b/features/src/main/scala/com/salesforce/op/stages/FeatureGeneratorStage.scala @@ -40,8 +40,8 @@ import org.apache.spark.ml.PipelineStage import org.apache.spark.util.ClosureUtils import org.joda.time.Duration import org.json4s.JValue -import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ +import com.salesforce.op.stages.ValueReaderWriter._ import scala.reflect.runtime.universe.WeakTypeTag import scala.util.Try @@ -136,35 +136,38 @@ class FeatureGeneratorStageReaderWriter[I, O <: FeatureType] * @param json json to read stage from * @return read result */ - def read(stageClass: Class[FeatureGeneratorStage[I, O]], json: JValue): Try[FeatureGeneratorStage[I, O]] = { - Try { - val tti = (json \ "tti").extract[String] - val tto = FeatureType.featureTypeTag((json \ "tto").extract[String]).asInstanceOf[WeakTypeTag[O]] - - val extractFnJson = json \ "extractFn" - val extractFnClassName = (extractFnJson \ "className").extract[String] - val extractFn = extractFnClassName match { - case c if classOf[FromRowExtractFn[_]].getName == c => - val index = (extractFnJson \ "index").extractOpt[Int] - val name = (extractFnJson \ "name").extract[String] - FromRowExtractFn[O](index, name)(tto).asInstanceOf[Function1[I, O]] - case c => - ReflectionUtils.newInstance[Function1[I, O]](c) - } - - val aggregatorClassName = (json \ "aggregator" \ "className").extract[String] - val aggregator = ReflectionUtils.newInstance[MonoidAggregator[Event[O], _, O]](aggregatorClassName) - - val outputName = (json \ "outputName").extract[String] - val extractSource = (json \ "extractSource").extract[String] - val uid = (json \ "uid").extract[String] - val outputIsResponse = (json \ "outputIsResponse").extract[Boolean] - val aggregateWindow = (json \ "aggregateWindow").extractOpt[Long].map(Duration.millis) - - new FeatureGeneratorStage[I, O](extractFn, extractSource, aggregator, - outputName, outputIsResponse, aggregateWindow, uid, Right(tti))(tto) + def read(stageClass: Class[FeatureGeneratorStage[I, O]], json: JValue): Try[FeatureGeneratorStage[I, O]] = Try { + val tti = (json \ "tti").extract[String] + val tto = FeatureType.featureTypeTag((json \ "tto").extract[String]).asInstanceOf[WeakTypeTag[O]] + + val extractFnJson = json \ "extractFn" + val extractFn = (extractFnJson \ "className").extract[String] match { + case extractFnClassName if classOf[FromRowExtractFn[_]].getName == extractFnClassName => + val index = (extractFnJson \ "index").extractOpt[Int] + val name = (extractFnJson \ "name").extract[String] + FromRowExtractFn[O](index, name)(tto).asInstanceOf[Function1[I, O]] + case extractFnClassName => + val extractFnClass = ReflectionUtils.classForName(extractFnClassName).asInstanceOf[Class[I => O]] + readerWriterFor(extractFnClass, "extractFn") + .read(extractFnClass, extractFnJson \ "value").get } + val aggregatorJson = json \ "aggregator" + val aggregatorClassName = (aggregatorJson \ "className").extract[String] + val aggregatorClass = ReflectionUtils.classForName(aggregatorClassName) + .asInstanceOf[Class[MonoidAggregator[Event[O], _, O]]] + val aggregator = + readerWriterFor(aggregatorClass, "aggregator") + .read(aggregatorClass, aggregatorJson \ "value").get + + val outputName = (json \ "outputName").extract[String] + val extractSource = (json \ "extractSource").extract[String] + val uid = (json \ "uid").extract[String] + val outputIsResponse = (json \ "outputIsResponse").extract[Boolean] + val aggregateWindow = (json \ "aggregateWindow").extractOpt[Long].map(Duration.millis) + + new FeatureGeneratorStage[I, O](extractFn, extractSource, aggregator, + outputName, outputIsResponse, aggregateWindow, uid, Right(tti))(tto) } /** @@ -175,17 +178,23 @@ class FeatureGeneratorStageReaderWriter[I, O <: FeatureType] */ def write(stage: FeatureGeneratorStage[I, O]): Try[JValue] = { for { - extractFn <- Try { + extractFn <- { stage.extractFn match { - case e: FromRowExtractFn[_] => - ("className" -> e.getClass.getName) ~ ("index" -> e.index) ~ ("name" -> e.name) - case e => - ("className" -> serializeArgument("extractFn", e).value.toString) ~ JObject() + case extract: FromRowExtractFn[_] => Try { + ("className" -> extract.getClass.getName) ~ ("index" -> extract.index) ~ ("name" -> extract.name) + } + case extract => { + val extractClass = extract.getClass.asInstanceOf[Class[I => O]] + readerWriterFor(extractClass, "extractFn") + .write(extract).map { j => ("className" -> extractClass.getName) ~ ("value" -> j) } + } } } - aggregator <- Try( - ("className" -> serializeArgument("aggregator", stage.aggregator).value.toString) ~ JObject() - ) + aggregator <- { + val aggregatorClass = stage.aggregator.getClass.asInstanceOf[Class[MonoidAggregator[Event[O], _, O]]] + readerWriterFor[MonoidAggregator[Event[O], _, O]](aggregatorClass, "aggregator") + .write(stage.aggregator).map { j => ("className" -> aggregatorClass.getName) ~ ("value" -> j) } + } } yield { ("tti" -> stage.tti.tpe.typeSymbol.fullName) ~ ("tto" -> FeatureType.typeName(stage.tto)) ~ diff --git a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageReaderWriter.scala b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageReaderWriter.scala index 30d60096a5..c8bff874df 100644 --- a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageReaderWriter.scala +++ b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageReaderWriter.scala @@ -36,44 +36,125 @@ import com.salesforce.op.utils.json.{EnumEntrySerializer, SpecialDoubleSerialize import com.salesforce.op.utils.reflection.ReflectionUtils import enumeratum.{Enum, EnumEntry} import org.json4s.ext.JodaTimeSerializers +import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization -import org.json4s.{Formats, FullTypeHints, JValue} +import org.json4s.{Extraction, Formats, FullTypeHints, JValue} import org.slf4j.LoggerFactory import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} - /** - * Stage reader/writer implementation used to (de)serialize stages from/to trained models + * Value reader/writer implementation used to (de)serialize stage arguments from/to trained models * - * @tparam StageType stage type to read/write + * @tparam T value type to read/write */ -trait OpPipelineStageReaderWriter[StageType <: OpPipelineStageBase] extends OpPipelineStageReadWriteFormats { +trait ValueReaderWriter[T <: AnyRef] { /** - * Read stage from json + * Read value from json * - * @param stageClass stage class - * @param json json to read stage from + * @param valueClass value class + * @param json json to read argument value from * @return read result */ - def read(stageClass: Class[StageType], json: JValue): Try[StageType] + def read(valueClass: Class[T], json: JValue): Try[T] /** - * Write stage to json + * Write value to json * - * @param stage stage instance to write + * @param value value to write * @return write result */ - def write(stage: StageType): Try[JValue] + def write(value: T): Try[JValue] + +} + + +object ValueReaderWriter { + + private val log = LoggerFactory.getLogger(ValueReaderWriter.getClass) + + /** + * Retrieve reader/writer implementation: either the custom one specified with [[ReaderWriter]] annotation + * or the default one [[DefaultValueReaderWriter]] + * + * @param valueClass value class + * @param valueName value name + * @tparam T value type + * @return reader/writer implementation + */ + def readerWriterFor[T <: AnyRef : ClassTag](valueClass: Class[T], valueName: String): ValueReaderWriter[T] = { + readerWriterForOrDefault[T, ValueReaderWriter[T]](valueClass, new DefaultValueReaderWriter[T](valueName)) + } + + /** + * Retrieve reader/writer implementation: either the custom one specified with [[ReaderWriter]] annotation + * or the default one + * + * @param valueClass stage class + * @param defaultReaderWriter default reader writer instance maker + * @tparam T value type + * @return reader/writer implementation + */ + private[op] def readerWriterForOrDefault[T <: AnyRef : ClassTag, RW <: ValueReaderWriter[T]] + ( + valueClass: Class[T], + defaultReaderWriter: => RW + ): RW = { + if (!valueClass.isAnnotationPresent(classOf[ReaderWriter])) defaultReaderWriter + else { + Try { + val readerWriterClass = valueClass.getAnnotation[ReaderWriter](classOf[ReaderWriter]).value() + ReflectionUtils.newInstance[RW](readerWriterClass.getName) + } match { + case Success(readerWriter) => + if (log.isDebugEnabled) { + log.debug(s"Using reader/writer of type '${readerWriter.getClass.getName}'" + + s"to (de)serialize value of type '${valueClass.getName}'") + } + readerWriter + case Failure(e) => throw new RuntimeException( + s"Failed to create reader/writer instance for value class ${valueClass.getName}", e) + } + } + } } +/** + * Stage reader/writer implementation used to (de)serialize stages from/to trained models + * + * @tparam StageType stage type to read/write + */ +trait OpPipelineStageReaderWriter[StageType <: OpPipelineStageBase] + extends ValueReaderWriter[StageType] with OpPipelineStageReadWriteFormats + + object OpPipelineStageReaderWriter extends OpPipelineStageReadWriteFormats { - private val log = LoggerFactory.getLogger(OpPipelineStageReaderWriter.getClass) + /** + * Retrieve reader/writer implementation: either the custom one specified with [[ReaderWriter]] annotation + * or the default one [[DefaultOpPipelineStageReaderWriter]] + * + * @param stageClass stage class + * @tparam StageType stage type + * @return reader/writer implementation + */ + def readerWriterFor[StageType <: OpPipelineStageBase : ClassTag] + ( + stageClass: Class[StageType] + ): OpPipelineStageReaderWriter[StageType] = { + ValueReaderWriter.readerWriterForOrDefault[StageType, OpPipelineStageReaderWriter[StageType]]( + stageClass, defaultReaderWriter = new DefaultOpPipelineStageReaderWriter[StageType]() + ) + } + + /** + * Serialize value to json + */ + def jsonSerialize(v: Any): JValue = render(Extraction.decompose(v)(formats))(formats) /** * Stage json field names @@ -113,38 +194,6 @@ object OpPipelineStageReaderWriter extends OpPipelineStageReadWriteFormats { */ case class AnyValue(`type`: AnyValueTypes, value: Any, valueClass: Option[String]) - /** - * Retrieve reader/writer implementation: either the custom one specified with [[ReaderWriter]] annotation - * or the default one [[DefaultOpPipelineStageReaderWriter]] - * - * @param stageClass stage class - * @tparam StageType stage type - * @return reader/writer implementation - */ - def readerWriterFor[StageType <: OpPipelineStageBase : ClassTag] - ( - stageClass: Class[StageType] - ): OpPipelineStageReaderWriter[StageType] = { - if (!stageClass.isAnnotationPresent(classOf[ReaderWriter])) { - new DefaultOpPipelineStageReaderWriter[StageType]() - } - else { - Try { - val readerWriterClass = stageClass.getAnnotation[ReaderWriter](classOf[ReaderWriter]).value() - ReflectionUtils.newInstance[OpPipelineStageReaderWriter[StageType]](readerWriterClass.getName) - } match { - case Success(readerWriter) => - if (log.isDebugEnabled) { - log.debug(s"Using reader/writer of type '${readerWriter.getClass.getName}'" - + s"to (de)serialize stage of type '${stageClass.getName}'") - } - readerWriter - case Failure(e) => throw new RuntimeException( - s"Failed to create reader/writer instance for stage class ${stageClass.getName}", e) - } - } - } - } diff --git a/features/src/test/scala/com/salesforce/op/stages/FeatureGeneratorStageTest.scala b/features/src/test/scala/com/salesforce/op/stages/FeatureGeneratorStageTest.scala index 5be19f31e4..6e859c5055 100644 --- a/features/src/test/scala/com/salesforce/op/stages/FeatureGeneratorStageTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/FeatureGeneratorStageTest.scala @@ -31,16 +31,21 @@ package com.salesforce.op.stages import com.salesforce.op.aggregators.{CutOffTime, Event, MonoidAggregatorDefaults} -import com.salesforce.op.features.Feature -import com.salesforce.op.features.types.{FeatureType, FeatureTypeSparkConverter} +import com.salesforce.op.features.{Feature, FeatureBuilder} +import com.salesforce.op.features.types._ import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} import com.salesforce.op.utils.spark.RichRow._ import org.apache.spark.sql.Row -import org.junit.runner.RunWith +import org.json4s.{DefaultFormats, JValue} import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner +import org.json4s.JValue +import org.json4s.JsonAST.{JInt, JObject} +import org.json4s.JsonDSL._ +import org.junit.runner.RunWith import scala.reflect.runtime.universe.WeakTypeTag +import scala.util.Try @RunWith(classOf[JUnitRunner]) @@ -77,6 +82,18 @@ class FeatureGeneratorStageTest extends FlatSpec with TestSparkContext { assertAggregateFeatures(recovered) } + it should "serialize to/from json with a parametrized extract function" in { + val multiplier = 10 + val multiplied = FeatureBuilder.Integral[Int].extract(new IntMultiplyExtractor(multiplier)).asPredictor + val featureGenerator = multiplied.originStage + val featureGenJson = featureGenerator.write.asInstanceOf[OpPipelineStageWriter].writeToJsonString("") + val recoveredStage = new OpPipelineStageReader(Seq.empty).loadFromJsonString(featureGenJson, "") + recoveredStage shouldBe a[FeatureGeneratorStage[_, _]] + val extractFn = recoveredStage.asInstanceOf[FeatureGeneratorStage[Int, Integral]].extractFn + extractFn shouldBe a[IntMultiplyExtractor] + extractFn.apply(7) shouldBe 70.toIntegral + } + def assertExtractFeatures(fgs: FeaturesAndGenerators): Unit = { for {(feature, featureGenerator) <- fgs} { rows.map { row => @@ -103,3 +120,18 @@ class FeatureGeneratorStageTest extends FlatSpec with TestSparkContext { } } + +@ReaderWriter(classOf[IntMultiplyExtractorReadWrite]) +class IntMultiplyExtractor(val multiplier: Int) extends Function1[Int, Integral] { + def apply(i: Int): Integral = (i * multiplier).toIntegral +} + +class IntMultiplyExtractorReadWrite extends ValueReaderWriter[IntMultiplyExtractor] { + implicit val formats = DefaultFormats + def read(valueClass: Class[IntMultiplyExtractor], json: JValue): Try[IntMultiplyExtractor] = Try { + new IntMultiplyExtractor((json \ "multiplier").extract[Int]) + } + def write(value: IntMultiplyExtractor): Try[JValue] = Try { + "multiplier" -> JInt(value.multiplier) + } +}