Skip to content

Commit

Permalink
(refs #3) Improve type conversion
Browse files: browse the repository at this point in the history
  • Loading branch information
takezoe committed May 27, 2019
1 parent f29b974 commit f4e2f89
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
10 changes: 5 additions & 5 deletions data/import_titanic.py
Expand Up @@ -19,15 +19,15 @@ def import_events(client, file):
header = next(reader)
for row in reader:
properties = {
"survived": row[1],
"survived": float(row[1]),
"pClass": row[2],
"name": row[3],
"sex": row[4],
"age": row[5],
"sibSp": row[6],
"parCh": row[7],
"age": None if row[5] == '' else float(row[5]),
"sibSp": None if row[6] == '' else int(row[6]),
"parCh": None if row[7] == '' else int(row[7]),
"ticket": row[8],
"fare": row[9],
"fare": None if row[9] == '' else float(row[9]),
"cabin": row[10],
"embarked": row[11]
}
Expand Down
16 changes: 6 additions & 10 deletions src/main/scala/Algorithm.scala
Expand Up @@ -10,7 +10,7 @@ import grizzled.slf4j.Logger
import com.salesforce.op.features.{Feature, FeatureSparkTypes}
import com.salesforce.op.features.types._
import com.salesforce.op.local._
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, OpLogisticRegression}
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector}
import org.apache.commons.io.FileUtils
import org.apache.predictionio.data.storage.Event
import org.apache.spark.sql.{Row, SparkSession}
Expand All @@ -37,11 +37,10 @@ case class AlgorithmParams(target: String, schema: Seq[Field]) extends Params {
def row(event: Event): Row = {
Row(
(schema.map { field =>
// TODO Better type conversion
val (value, default) = field.`type` match {
case "string" => (event.properties.getOpt[String](field.field), "")
case "double" => (event.properties.getOpt[String](field.field).filter(_.nonEmpty).map(_.toDouble), 0d)
case "int" => (event.properties.getOpt[String](field.field).filter(_.nonEmpty).map(_.toInt), 0)
case "double" => (event.properties.getOpt[Double](field.field), 0d)
case "int" => (event.properties.getOpt[Int](field.field), 0)
}
value match {
case Some(x) => x
Expand Down Expand Up @@ -82,12 +81,9 @@ case class AlgorithmParams(target: String, schema: Seq[Field]) extends Params {

def query(map: Map[String, Any]): Map[String, Any] = {
map.map { case (key, value) =>
// TODO Better type conversion
val field = schema.find(_.field == key).get
key -> (field.`type` match {
case "string" => value.toString
case "double" => value.toString.toDouble
case "int" => value.toString.toInt
key -> (value match {
case x: BigInt => x.toInt
case x => x
})
} + (target -> 0d)
}
Expand Down

0 comments on commit f4e2f89

Please sign in to comment.