diff --git a/helloworld/src/main/resources/log4j.properties b/helloworld/src/main/resources/log4j.properties index dfcff8ca09..f258dfda9b 100644 --- a/helloworld/src/main/resources/log4j.properties +++ b/helloworld/src/main/resources/log4j.properties @@ -39,7 +39,7 @@ log4j.logger.com.github.fommil.netlib=ERROR log4j.logger.org.apache.avro.mapreduce.AvroKeyInputFormat=ERROR # TransmogrifAI logging -log4j.logger.com.salesforce.op=INFO +log4j.logger.com.salesforce.op=ERROR log4j.logger.com.salesforce.op.utils.spark.OpSparkListener=OFF # Helloworld logging diff --git a/helloworld/src/main/scala/com/salesforce/hw/OpTitanicSimple.scala b/helloworld/src/main/scala/com/salesforce/hw/OpTitanicSimple.scala index 100c43b320..2ce5d033e1 100644 --- a/helloworld/src/main/scala/com/salesforce/hw/OpTitanicSimple.scala +++ b/helloworld/src/main/scala/com/salesforce/hw/OpTitanicSimple.scala @@ -90,8 +90,8 @@ object OpTitanicSimple { println(s"Using user-supplied CSV file path: $csvFilePath") // Set up a SparkSession as normal - val conf = new SparkConf().setAppName(this.getClass.getSimpleName.stripSuffix("$")) - implicit val spark = SparkSession.builder.config(conf).getOrCreate() + implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate() + import spark.implicits._ // Needed for Encoders for the Passenger case class //////////////////////////////////////////////////////////////////////////////// // RAW FEATURE DEFINITIONS @@ -125,18 +125,16 @@ object OpTitanicSimple { val passengerFeatures = Seq( pClass, name, age, sibSp, parCh, ticket, cabin, embarked, familySize, estimatedCostOfTickets, - pivotedSex, ageGroup + pivotedSex, ageGroup, normedAge ).transmogrify() // Optionally check the features with a sanity checker - val sanityCheck = true - val finalFeatures = if (sanityCheck) survived.sanityCheck(passengerFeatures) else passengerFeatures + val checkedFeatures = survived.sanityCheck(passengerFeatures, removeBadFeatures = true) // Define the model we want to use (here a simple logistic regression) and get the resulting output - val prediction = - BinaryClassificationModelSelector.withTrainValidationSplit( - modelTypesToUse = Seq(OpLogisticRegression) - ).setInput(survived, finalFeatures).getOutput() + val prediction = BinaryClassificationModelSelector.withTrainValidationSplit( + modelTypesToUse = Seq(OpLogisticRegression) + ).setInput(survived, checkedFeatures).getOutput() val evaluator = Evaluators.BinaryClassification().setLabelCol(survived).setPredictionCol(prediction) @@ -144,30 +142,20 @@ object OpTitanicSimple { // WORKFLOW ///////////////////////////////////////////////////////////////////////////////// - import spark.implicits._ // Needed for Encoders for the Passenger case class // Define a way to read data into our Passenger class from our CSV file - val trainDataReader = DataReaders.Simple.csvCase[Passenger]( - path = Option(csvFilePath), - key = _.id.toString - ) + val dataReader = DataReaders.Simple.csvCase[Passenger](path = Option(csvFilePath), key = _.id.toString) // Define a new workflow and attach our data reader - val workflow = - new OpWorkflow() - .setResultFeatures(survived, prediction) - .setReader(trainDataReader) + val workflow = new OpWorkflow().setResultFeatures(survived, prediction).setReader(dataReader) // Fit the workflow to the data - val fittedWorkflow = workflow.train() - println(s"Summary: ${fittedWorkflow.summary()}") + val model = workflow.train() + println(s"Model summary:\n${model.summaryPretty()}") // Manifest the result features of the workflow println("Scoring the model") - val (dataframe, metrics) = fittedWorkflow.scoreAndEvaluate(evaluator = evaluator) + val (scores, metrics) = model.scoreAndEvaluate(evaluator = evaluator) - println("Transformed dataframe columns:") - dataframe.columns.foreach(println) - println("Metrics:") - println(metrics) + println("Metrics:\n" + metrics) } } diff --git a/helloworld/src/main/scala/com/salesforce/hw/iris/OpIris.scala b/helloworld/src/main/scala/com/salesforce/hw/iris/OpIris.scala index cd450246a6..efd681f054 100644 --- a/helloworld/src/main/scala/com/salesforce/hw/iris/OpIris.scala +++ b/helloworld/src/main/scala/com/salesforce/hw/iris/OpIris.scala @@ -57,9 +57,9 @@ object OpIris extends OpAppWithRunner with IrisFeatures { val path = getFinalReadPath(params) val myFile = spark.sparkContext.textFile(path) - Left(myFile.filter(_.nonEmpty).zipWithIndex.map { case (x, number) => - val words = x.split(",") - new Iris(number.toInt, words(0).toDouble, words(1).toDouble, words(2).toDouble, words(3).toDouble, words(4)) + Left(myFile.filter(_.nonEmpty).zipWithIndex.map { case (x, id) => + val Array(sepalLength, sepalWidth, petalLength, petalWidth, klass) = x.split(",") + new Iris(id.toInt, sepalLength.toDouble, sepalWidth.toDouble, petalLength.toDouble, petalWidth.toDouble, klass) }) } } diff --git a/helloworld/src/main/scala/com/salesforce/hw/titanic/OpTitanicMini.scala b/helloworld/src/main/scala/com/salesforce/hw/titanic/OpTitanicMini.scala index 8b21c9d5a7..7c6dabf7e3 100644 --- a/helloworld/src/main/scala/com/salesforce/hw/titanic/OpTitanicMini.scala +++ b/helloworld/src/main/scala/com/salesforce/hw/titanic/OpTitanicMini.scala @@ -38,7 +38,6 @@ import com.salesforce.op.stages.impl.classification.BinaryClassificationModelsTo import com.salesforce.op.stages.impl.classification._ import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession -import org.apache.log4j.{Level, LogManager} /** * A minimal Titanic Survival example with TransmogrifAI @@ -62,7 +61,6 @@ object OpTitanicMini { ) def main(args: Array[String]): Unit = { - LogManager.getLogger("com.salesforce.op").setLevel(Level.ERROR) implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate() import spark.implicits._ @@ -81,6 +79,7 @@ object OpTitanicMini { val prediction = BinaryClassificationModelSelector .withCrossValidation(modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier)) .setInput(survived, checkedFeatures).getOutput() + val model = new OpWorkflow().setInputDataset(passengersData).setResultFeatures(prediction).train() println("Model summary:\n" + model.summaryPretty())