Skip to content

Commit

Permalink
Convert lambda functions into concrete classes to allow compatibility…
Browse files Browse the repository at this point in the history
… with Scala 2.12 (#357)
  • Loading branch information
tovbinm authored and leahmcguire committed Jul 11, 2019
1 parent a53decd commit 28eac0c
Show file tree
Hide file tree
Showing 23 changed files with 234 additions and 133 deletions.
14 changes: 10 additions & 4 deletions core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ trait RichDateFeature {
f.transformWith(
new UnaryLambdaTransformer[Date, DateList](
operationName = "dateToList",
RichDateFeatureLambdas.toDateList
new RichDateFeatureLambdas.ToDateList
)
)
}
Expand Down Expand Up @@ -137,7 +137,7 @@ trait RichDateFeature {
f.transformWith(
new UnaryLambdaTransformer[DateTime, DateTimeList](
operationName = "dateTimeToList",
RichDateFeatureLambdas.toDateTimeList
new RichDateFeatureLambdas.ToDateTimeList
)
)
}
Expand Down Expand Up @@ -204,7 +204,13 @@ trait RichDateFeature {
}

object RichDateFeatureLambdas {
def toDateList: Date => DateList = (x: Date) => x.value.toSeq.toDateList

def toDateTimeList: DateTime => DateTimeList = (x: DateTime) => x.value.toSeq.toDateTimeList
/**
 * Serializable function class that turns a [[Date]] into a [[DateList]]
 * built from `v.value.toSeq`.
 */
class ToDateList extends (Date => DateList) with Serializable {
  override def apply(v: Date): DateList = {
    val values = v.value.toSeq
    values.toDateList
  }
}

/**
 * Serializable function class that turns a [[DateTime]] into a [[DateTimeList]]
 * built from `v.value.toSeq`.
 *
 * The input type is [[DateTime]] (not [[Date]]) so the function type matches
 * the `UnaryLambdaTransformer[DateTime, DateTimeList]` it is supplied to.
 */
class ToDateTimeList extends Function1[DateTime, DateTimeList] with Serializable {
  def apply(v: DateTime): DateTimeList = v.value.toSeq.toDateTimeList
}

}
20 changes: 13 additions & 7 deletions core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

package com.salesforce.op.dsl

import com.salesforce.op.dsl.RichMapFeatureLambdas._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature._
Expand Down Expand Up @@ -1098,9 +1097,10 @@ trait RichMapFeature {
* @return prediction, rawPrediction, probability
*/
def tupled(): (FeatureLike[RealNN], FeatureLike[OPVector], FeatureLike[OPVector]) = {
(f.map[RealNN](predictionToRealNN),
f.map[OPVector](predictionToRaw),
f.map[OPVector](predictionToProbability)
import RichMapFeatureLambdas._
(f.map[RealNN](new PredictionToRealNN),
f.map[OPVector](new PredictionToRaw),
f.map[OPVector](new PredictionToProbability)
)
}

Expand All @@ -1121,11 +1121,17 @@ trait RichMapFeature {

object RichMapFeatureLambdas {

def predictionToRealNN: Prediction => RealNN = _.prediction.toRealNN
/**
 * Serializable function class extracting the `prediction` field of a
 * [[Prediction]] and converting it to [[RealNN]].
 */
class PredictionToRealNN extends (Prediction => RealNN) with Serializable {
  override def apply(p: Prediction): RealNN = {
    val predicted = p.prediction
    predicted.toRealNN
  }
}

def predictionToRaw: Prediction => OPVector = p => Vectors.dense(p.rawPrediction).toOPVector
/**
 * Serializable function class wrapping the `rawPrediction` of a [[Prediction]]
 * in a dense vector and converting it to [[OPVector]].
 */
class PredictionToRaw extends (Prediction => OPVector) with Serializable {
  override def apply(p: Prediction): OPVector = {
    val raw = Vectors.dense(p.rawPrediction)
    raw.toOPVector
  }
}

def predictionToProbability: Prediction => OPVector = p => Vectors.dense(p.probability).toOPVector
/**
 * Serializable function class wrapping the `probability` of a [[Prediction]]
 * in a dense vector and converting it to [[OPVector]].
 */
class PredictionToProbability extends (Prediction => OPVector) with Serializable {
  override def apply(p: Prediction): OPVector = {
    val probabilities = Vectors.dense(p.probability)
    probabilities.toOPVector
  }
}

}

Expand Down
56 changes: 38 additions & 18 deletions core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import com.salesforce.op.stages.impl.feature._
import com.salesforce.op.utils.text._

import scala.reflect.runtime.universe.TypeTag


trait RichTextFeature {
self: RichFeature =>

Expand All @@ -48,7 +50,7 @@ trait RichTextFeature {
*
* @return A new MultiPickList feature
*/
def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](textToMultiPickList)
def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](new TextToMultiPickList)


/**
Expand Down Expand Up @@ -560,14 +562,14 @@ trait RichTextFeature {
*
* @return email prefix
*/
def toEmailPrefix: FeatureLike[Text] = f.map[Text](emailToPrefix, "prefix")
def toEmailPrefix: FeatureLike[Text] = f.map[Text](new EmailPrefixToText, "prefix")

/**
* Extract email domains
*
* @return email domain
*/
def toEmailDomain: FeatureLike[Text] = f.map[Text](emailToDomain, "domain")
def toEmailDomain: FeatureLike[Text] = f.map[Text](new EmailDomainToText, "domain")

/**
* Check if email is valid
Expand Down Expand Up @@ -600,7 +602,7 @@ trait RichTextFeature {
others: Array[FeatureLike[Email]] = Array.empty,
maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality
): FeatureLike[OPVector] = {
val domains = (f +: others).map(_.map[PickList](emailToPickList))
val domains = (f +: others).map(_.map[PickList](new EmailDomainToPickList))
domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText,
trackNulls = trackNulls, maxPctCardinality = maxPctCardinality
)
Expand All @@ -613,19 +615,19 @@ trait RichTextFeature {
/**
* Extract url domain, i.e. salesforce.com, data.com etc.
*/
def toDomain: FeatureLike[Text] = f.map[Text](urlToDomain, "urlDomain")
def toDomain: FeatureLike[Text] = f.map[Text](new URLDomainToText, "urlDomain")

/**
* Extracts url protocol, i.e. http, https, ftp etc.
*/
def toProtocol: FeatureLike[Text] = f.map[Text](urlToProtocol, "urlProtocol")
def toProtocol: FeatureLike[Text] = f.map[Text](new URLProtocolToText, "urlProtocol")

/**
* Verifies if the url is of correct form of "Uniform Resource Identifiers (URI): Generic Syntax"
* RFC2396 (http://www.ietf.org/rfc/rfc2396.txt)
* Default valid protocols are: http, https, ftp.
*/
def isValidUrl: FeatureLike[Binary] = f.exists(urlIsValid)
def isValidUrl: FeatureLike[Binary] = f.exists(new URLIsValid)

/**
* Converts a sequence of [[URL]] features into a vector, extracting the domains of the valid urls
Expand All @@ -650,7 +652,7 @@ trait RichTextFeature {
others: Array[FeatureLike[URL]] = Array.empty,
maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality
): FeatureLike[OPVector] = {
val domains = (f +: others).map(_.map[PickList](urlToPickList))
val domains = (f +: others).map(_.map[PickList](new URLDomainToPickList))
domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText,
trackNulls = trackNulls, maxPctCardinality = maxPctCardinality
)
Expand Down Expand Up @@ -697,7 +699,7 @@ trait RichTextFeature {
): FeatureLike[OPVector] = {

val feats: Array[FeatureLike[PickList]] =
(f +: others).map(_.detectMimeTypes(typeHint).map[PickList](textToPickList))
(f +: others).map(_.detectMimeTypes(typeHint).map[PickList](new TextToPickList))

feats.head.vectorize(
topK = topK, minSupport = minSupport, cleanText = cleanText, trackNulls = trackNulls, others = feats.tail,
Expand Down Expand Up @@ -801,22 +803,40 @@ trait RichTextFeature {

object RichTextFeatureLambdas {

def emailToPickList: Email => PickList = _.domain.toPickList
/**
 * Serializable function class converting the `domain` of an [[Email]]
 * to a [[PickList]].
 */
class EmailDomainToPickList extends (Email => PickList) with Serializable {
  override def apply(v: Email): PickList = {
    val domain = v.domain
    domain.toPickList
  }
}

def emailToPrefix: Email => Text = _.prefix.toText
/**
 * Serializable function class converting the `domain` of an [[Email]]
 * to [[Text]].
 */
class EmailDomainToText extends (Email => Text) with Serializable {
  override def apply(v: Email): Text = {
    val domain = v.domain
    domain.toText
  }
}

def emailToDomain: Email => Text = _.domain.toText
/**
 * Serializable function class converting the `prefix` of an [[Email]]
 * to [[Text]].
 */
class EmailPrefixToText extends (Email => Text) with Serializable {
  override def apply(v: Email): Text = {
    val prefix = v.prefix
    prefix.toText
  }
}

def urlToPickList: URL => PickList = (v: URL) => if (v.isValid) v.domain.toPickList else PickList.empty
/**
 * Serializable function class converting the `domain` of a [[URL]] to a
 * [[PickList]]. A URL for which `isValid` is false yields `PickList.empty`.
 */
class URLDomainToPickList extends (URL => PickList) with Serializable {
  override def apply(v: URL): PickList =
    if (!v.isValid) PickList.empty
    else v.domain.toPickList
}

def urlToDomain: URL => Text = _.domain.toText
/**
 * Serializable function class converting the `domain` of a [[URL]] to [[Text]].
 */
class URLDomainToText extends (URL => Text) with Serializable {
  override def apply(v: URL): Text = {
    val domain = v.domain
    domain.toText
  }
}

def urlToProtocol: URL => Text = _.protocol.toText
/**
 * Serializable function class converting the `protocol` of a [[URL]] to [[Text]].
 */
class URLProtocolToText extends (URL => Text) with Serializable {
  override def apply(v: URL): Text = {
    val protocol = v.protocol
    protocol.toText
  }
}

def urlIsValid: URL => Boolean = _.isValid
/**
 * Serializable predicate class returning the `isValid` flag of a [[URL]].
 */
class URLIsValid extends (URL => Boolean) with Serializable {
  override def apply(v: URL): Boolean = {
    v.isValid
  }
}

def textToPickList: Text => PickList = _.value.toPickList
/**
 * Serializable function class converting the `value` of a [[Text]]
 * to a [[PickList]].
 */
class TextToPickList extends (Text => PickList) with Serializable {
  override def apply(v: Text): PickList = {
    val raw = v.value
    raw.toPickList
  }
}

def textToMultiPickList: Text => MultiPickList = _.value.toSet[String].toMultiPickList
/**
 * Serializable function class converting the `value` of a [[Text]], viewed as
 * a `Set[String]`, to a [[MultiPickList]].
 */
class TextToMultiPickList extends (Text => MultiPickList) with Serializable {
  override def apply(v: Text): MultiPickList = {
    val members = v.value.toSet[String]
    members.toMultiPickList
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ class EmailToPickListMapTransformer(uid: String = UID[EmailToPickListMapTransfor
operationName = "emailToPickListMap",
transformer = new UnaryLambdaTransformer[Email, PickList](
operationName = "emailToPickList",
transformFn = EmailToPickListMapTransformer.emailToPickList
transformFn = new EmailToPickListMapTransformer.EmailToPickList
)
)

object EmailToPickListMapTransformer {
def emailToPickList: Email => PickList = email => email.domain.toPickList

/**
 * Serializable function class converting the `domain` of an [[Email]]
 * to a [[PickList]].
 */
class EmailToPickList extends (Email => PickList) with Serializable {
  override def apply(v: Email): PickList = {
    val domain = v.domain
    domain.toPickList
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ package com.salesforce.op.stages.impl.feature
import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.utils.json.{JsonLike, JsonUtils}
import com.salesforce.op.utils.json.JsonUtils
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import org.json4s.JsonAST.{JField, JNothing}
import org.json4s.{CustomSerializer, JObject}

import scala.reflect.runtime.universe.TypeTag
import scala.util.{Failure, Try}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import scala.reflect.runtime.universe.TypeTag
class ToOccurTransformer[I <: FeatureType]
(
uid: String = UID[ToOccurTransformer[I]],
val matchFn: I => Boolean = ToOccurTransformer.defaultMatches[I]
val matchFn: I => Boolean = new ToOccurTransformer.DefaultMatches[I]
)(implicit tti: TypeTag[I])
extends UnaryTransformer[I, RealNN](operationName = "toOccur", uid = uid) {

Expand All @@ -60,11 +60,13 @@ class ToOccurTransformer[I <: FeatureType]

object ToOccurTransformer {

def defaultMatches[T <: FeatureType]: T => Boolean = {
case num: OPNumeric[_] if num.nonEmpty => num.toDouble.get > 0.0
case text: Text if text.nonEmpty => text.value.get.length > 0
case collection: OPCollection => collection.nonEmpty
case _ => false
/**
 * Default match function used by `ToOccurTransformer`.
 *
 * Returns true for a non-empty numeric whose double value is greater than
 * zero, for a non-empty text whose string is not empty, and for a non-empty
 * collection; every other input yields false.
 *
 * @tparam T feature type of the input
 */
class DefaultMatches[T <: FeatureType] extends (T => Boolean) with Serializable {
  override def apply(t: T): Boolean = t match {
    // `.get` is only reached after the `nonEmpty` guard has passed
    case n: OPNumeric[_] if n.nonEmpty => n.toDouble.get > 0.0
    case txt: Text if txt.nonEmpty => !txt.value.get.isEmpty
    case coll: OPCollection => coll.nonEmpty
    case _ => false
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class OpWorkflowModelReaderWriterTest
}

trait SwSingleStageFlow {
val vec = FeatureBuilder.OPVector[Passenger].extract(OpWorkflowModelReaderWriterTest.emptyVectorFn).asPredictor
val vec = FeatureBuilder.OPVector[Passenger].extract(new OpWorkflowModelReaderWriterTest.EmptyVectorFn).asPredictor
val scaler = new StandardScaler().setWithStd(false).setWithMean(false)
val schema = FeatureSparkTypes.toStructType(vec)
val data = spark.createDataFrame(List(Row(Vectors.dense(1.0))).asJava, schema)
Expand All @@ -158,12 +158,12 @@ class OpWorkflowModelReaderWriterTest

trait OldVectorizedFlow extends UIDReset {
val cat = Seq(gender, boarded, height, age, description).transmogrify()
val catHead = cat.map[Real](OpWorkflowModelReaderWriterTest.catHeadFn)
val catHead = cat.map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn)
val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead)
}

trait VectorizedFlow extends UIDReset {
val catHead = rawFeatures.transmogrify().map[Real](OpWorkflowModelReaderWriterTest.catHeadFn)
val catHead = rawFeatures.transmogrify().map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn)
val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead)
}

Expand Down Expand Up @@ -386,6 +386,13 @@ trait UIDReset {
}

object OpWorkflowModelReaderWriterTest {
def catHeadFn: OPVector => Real = v => Real(v.value.toArray.headOption)
def emptyVectorFn: Passenger => OPVector = _ => OPVector.empty

/**
 * Serializable function class producing a [[Real]] from the first element
 * (if any) of the vector's values.
 */
class CatHeadFn extends (OPVector => Real) with Serializable {
  override def apply(v: OPVector): Real = {
    val first = v.value.toArray.headOption
    Real(first)
  }
}

/**
 * Serializable function class that ignores its [[Passenger]] argument and
 * always returns `OPVector.empty`.
 */
class EmptyVectorFn extends (Passenger => OPVector) with Serializable {
  override def apply(v: Passenger): OPVector = {
    OPVector.empty
  }
}

}
51 changes: 25 additions & 26 deletions core/src/test/scala/com/salesforce/op/stages/Lambdas.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,37 @@ import com.salesforce.op.features.types.Real
import com.salesforce.op.features.types._

object Lambdas {
def fncUnary: Real => Real = (x: Real) => x.v.map(_ * 0.1234).toReal

def fncSequence: Seq[DateList] => Real = (x: Seq[DateList]) => {
val v = x.foldLeft(0.0)((a, b) => a + b.value.sum)
Math.round(v / 1E6).toReal
/**
 * Serializable function class that sums the values of every [[DateList]] in
 * the sequence, divides the total by 1E6, rounds it and converts it to [[Real]].
 */
class FncSequence extends (Seq[DateList] => Real) with Serializable {
  override def apply(x: Seq[DateList]): Real = {
    val total = x.foldLeft(0.0) { (acc, dates) => acc + dates.value.sum }
    val rounded = Math.round(total / 1E6)
    rounded.toReal
  }
}

def fncBinarySequence: (Real, Seq[DateList]) => Real = (y: Real, x: Seq[DateList]) => {
val v = x.foldLeft(0.0)((a, b) => a + b.value.sum)
(Math.round(v / 1E6) + y.value.getOrElse(0.0)).toReal
/**
 * Serializable function class that sums the values of every [[DateList]] in
 * `x`, divides the total by 1E6 and rounds it, then adds the value of `y`
 * (0.0 when `y` has no value) and converts the result to [[Real]].
 */
class FncBinarySequence extends ((Real, Seq[DateList]) => Real) with Serializable {
  override def apply(y: Real, x: Seq[DateList]): Real = {
    val total = x.foldLeft(0.0) { (acc, dates) => acc + dates.value.sum }
    val rounded = Math.round(total / 1E6)
    val offset = y.value.getOrElse(0.0)
    (rounded + offset).toReal
  }
}

def fncBinary: (Real, Real) => Real = (x: Real, y: Real) => (
for {
yv <- y.value
xv <- x.value
} yield xv * yv
).toReal
/**
 * Serializable function class multiplying the value of a [[Real]], when
 * present, by 0.1234.
 */
class FncUnary extends (Real => Real) with Serializable {
  override def apply(x: Real): Real = {
    val scaled = x.v.map(value => value * 0.1234)
    scaled.toReal
  }
}

def fncTernary: (Real, Real, Real) => Real = (x: Real, y: Real, z: Real) =>
(for {
xv <- x.value
yv <- y.value
zv <- z.value
} yield xv * yv + zv).toReal
/**
 * Serializable function class multiplying the values of two [[Real]]s.
 * The product is only defined when both `x` and `y` carry a value.
 */
class FncBinary extends ((Real, Real) => Real) with Serializable {
  override def apply(x: Real, y: Real): Real = {
    val product = y.value.flatMap(yv => x.value.map(xv => xv * yv))
    product.toReal
  }
}

def fncQuaternary: (Real, Real, Text, Real) => Real = (x: Real, y: Real, t: Text, z: Real) =>
(for {
xv <- x.value
yv <- y.value
tv <- t.value
zv <- z.value
} yield xv * yv + zv * tv.length).toReal
/**
 * Serializable function class computing `x * y + z` over three [[Real]]s.
 * The result is only defined when all three inputs carry a value.
 */
class FncTernary extends ((Real, Real, Real) => Real) with Serializable {
  override def apply(x: Real, y: Real, z: Real): Real = {
    val combined = for {
      yv <- y.value
      xv <- x.value
      zv <- z.value
    } yield xv * yv + zv
    combined.toReal
  }
}

/**
 * Serializable function class computing `x * y + z * length(t)` over three
 * [[Real]]s and a [[Text]]. The result is only defined when all four inputs
 * carry a value.
 */
class FncQuaternary extends ((Real, Real, Text, Real) => Real) with Serializable {
  override def apply(x: Real, y: Real, t: Text, z: Real): Real = {
    val combined = for {
      yv <- y.value
      xv <- x.value
      tv <- t.value
      zv <- z.value
    } yield xv * yv + zv * tv.length
    combined.toReal
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class OpPipelineStagesTest

val testOp = new com.salesforce.op.stages.base.unary.UnaryLambdaTransformer[Real, Real](
operationName = "test",
transformFn = OpPipelineStagesTest.fnc0,
transformFn = new OpPipelineStagesTest.RealIdentity,
uid = "myID"
)

Expand Down Expand Up @@ -162,7 +162,10 @@ class OpPipelineStagesTest
}

object OpPipelineStagesTest {
def fnc0: Real => Real = x => x

/**
 * Serializable identity function over [[Real]]: returns its argument unchanged.
 */
class RealIdentity extends (Real => Real) with Serializable {
  override def apply(v: Real): Real = {
    v
  }
}

class TestStage(implicit val tto: TypeTag[RealNN], val ttov: TypeTag[RealNN#Value])
extends Pipeline with OpPipelineStage1[RealNN, RealNN] {
Expand Down
Loading

0 comments on commit 28eac0c

Please sign in to comment.