Skip to content

Commit

Permalink
Cleanup helloworld example (#226)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tovbinm committed Feb 16, 2019
1 parent a0af563 commit a893747
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 31 deletions.
2 changes: 1 addition & 1 deletion helloworld/src/main/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ log4j.logger.com.github.fommil.netlib=ERROR
log4j.logger.org.apache.avro.mapreduce.AvroKeyInputFormat=ERROR

# TransmogrifAI logging
log4j.logger.com.salesforce.op=INFO
log4j.logger.com.salesforce.op=ERROR
log4j.logger.com.salesforce.op.utils.spark.OpSparkListener=OFF

# Helloworld logging
Expand Down
38 changes: 13 additions & 25 deletions helloworld/src/main/scala/com/salesforce/hw/OpTitanicSimple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ object OpTitanicSimple {
println(s"Using user-supplied CSV file path: $csvFilePath")

// Set up a SparkSession as normal
val conf = new SparkConf().setAppName(this.getClass.getSimpleName.stripSuffix("$"))
implicit val spark = SparkSession.builder.config(conf).getOrCreate()
implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate()
import spark.implicits._ // Needed for Encoders for the Passenger case class

////////////////////////////////////////////////////////////////////////////////
// RAW FEATURE DEFINITIONS
Expand Down Expand Up @@ -125,49 +125,37 @@ object OpTitanicSimple {
val passengerFeatures = Seq(
pClass, name, age, sibSp, parCh, ticket,
cabin, embarked, familySize, estimatedCostOfTickets,
pivotedSex, ageGroup
pivotedSex, ageGroup, normedAge
).transmogrify()

// Optionally check the features with a sanity checker
val sanityCheck = true
val finalFeatures = if (sanityCheck) survived.sanityCheck(passengerFeatures) else passengerFeatures
val checkedFeatures = survived.sanityCheck(passengerFeatures, removeBadFeatures = true)

// Define the model we want to use (here a simple logistic regression) and get the resulting output
val prediction =
BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression)
).setInput(survived, finalFeatures).getOutput()
val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
modelTypesToUse = Seq(OpLogisticRegression)
).setInput(survived, checkedFeatures).getOutput()

val evaluator = Evaluators.BinaryClassification().setLabelCol(survived).setPredictionCol(prediction)

////////////////////////////////////////////////////////////////////////////////
// WORKFLOW
/////////////////////////////////////////////////////////////////////////////////

import spark.implicits._ // Needed for Encoders for the Passenger case class
// Define a way to read data into our Passenger class from our CSV file
val trainDataReader = DataReaders.Simple.csvCase[Passenger](
path = Option(csvFilePath),
key = _.id.toString
)
val dataReader = DataReaders.Simple.csvCase[Passenger](path = Option(csvFilePath), key = _.id.toString)

// Define a new workflow and attach our data reader
val workflow =
new OpWorkflow()
.setResultFeatures(survived, prediction)
.setReader(trainDataReader)
val workflow = new OpWorkflow().setResultFeatures(survived, prediction).setReader(dataReader)

// Fit the workflow to the data
val fittedWorkflow = workflow.train()
println(s"Summary: ${fittedWorkflow.summary()}")
val model = workflow.train()
println(s"Model summary:\n${model.summaryPretty()}")

// Manifest the result features of the workflow
println("Scoring the model")
val (dataframe, metrics) = fittedWorkflow.scoreAndEvaluate(evaluator = evaluator)
val (scores, metrics) = model.scoreAndEvaluate(evaluator = evaluator)

println("Transformed dataframe columns:")
dataframe.columns.foreach(println)
println("Metrics:")
println(metrics)
println("Metrics:\n" + metrics)
}
}
6 changes: 3 additions & 3 deletions helloworld/src/main/scala/com/salesforce/hw/iris/OpIris.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ object OpIris extends OpAppWithRunner with IrisFeatures {
val path = getFinalReadPath(params)
val myFile = spark.sparkContext.textFile(path)

Left(myFile.filter(_.nonEmpty).zipWithIndex.map { case (x, number) =>
val words = x.split(",")
new Iris(number.toInt, words(0).toDouble, words(1).toDouble, words(2).toDouble, words(3).toDouble, words(4))
Left(myFile.filter(_.nonEmpty).zipWithIndex.map { case (x, id) =>
val Array(sepalLength, sepalWidth, petalLength, petalWidth, klass) = x.split(",")
new Iris(id.toInt, sepalLength.toDouble, sepalWidth.toDouble, petalLength.toDouble, petalWidth.toDouble, klass)
})
}
}
Expand Down
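3 changes: 1 addition & 2 deletions helloworld/src/main/scala/com/salesforce/hw/OpTitanicMini.scala
Original file line number Diff line number Diff line change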
Expand Up @@ -38,7 +38,6 @@ import com.salesforce.op.stages.impl.classification.BinaryClassificationModelsTo
import com.salesforce.op.stages.impl.classification._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.log4j.{Level, LogManager}

/**
* A minimal Titanic Survival example with TransmogrifAI
Expand All @@ -62,7 +61,6 @@ object OpTitanicMini {
)

def main(args: Array[String]): Unit = {
LogManager.getLogger("com.salesforce.op").setLevel(Level.ERROR)
implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate()
import spark.implicits._

Expand All @@ -81,6 +79,7 @@ object OpTitanicMini {
val prediction = BinaryClassificationModelSelector
.withCrossValidation(modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier))
.setInput(survived, checkedFeatures).getOutput()

val model = new OpWorkflow().setInputDataset(passengersData).setResultFeatures(prediction).train()

println("Model summary:\n" + model.summaryPretty())
Expand Down

0 comments on commit a893747

Please sign in to comment.