# Predict on data stream

In [None]:
// import "org.apache.spark %% spark-sql-kafka-0-10 % 2.4.0"

In [None]:
// Short alias for `sparkSession`. That name is not defined in the visible cells;
// presumably it is injected by the notebook kernel — confirm.
val spark = sparkSession
// Directory that holds the saved model. HOME is used when it is set; otherwise the
// JVM's `user.home` property is used, so the path never becomes the literal "null/data".
val dataDir = sys.env.getOrElse("HOME", System.getProperty("user.home")) + "/data"

### Load the saved model

In [None]:
import org.apache.spark.ml.{Pipeline, PipelineModel}

In [None]:
// Load the saved pipeline model from `<dataDir>/spark-linear-model`.
val model = PipelineModel.load(s"$dataDir/spark-linear-model")

### Create a structured stream from the Kafka `test` topic

In [None]:
// Streaming DataFrame subscribed to the Kafka topic `test`, reading from the
// earliest available offsets.
val rawData = spark.readStream
  .format("kafka")
  .options(Map(
    "kafka.bootstrap.servers" -> "192.168.58.111:9092",
    "subscribe" -> "test",
    "startingOffsets" -> "earliest"
  ))
  .load()

In [None]:
// Check whether `rawData` is a streaming DataFrame.
rawData.isStreaming

In [None]:
// Print the column names and types of the raw Kafka DataFrame.
rawData.printSchema()

### Case class to deserialize JSON messages into

In [None]:
/**
 * Target type for the JSON messages read from the stream.
 *
 * Field names and types define the schema used to parse each message.
 * NOTE(review): the meaning and units of the fields (e.g. whether `timestamp` is in
 * seconds or milliseconds) are not visible here — confirm against the producer.
 */
case class Trade(exchange: String, pair: String, timestamp: Long, price: Double, volume: Double)

import org.apache.spark.sql.Encoders
// Spark SQL schema derived from the fields of `Trade`.
val schema = Encoders.product[Trade].schema

Each Kafka message goes through three steps:

* the `value` column is cast to `String`
* the JSON string is parsed
* the parsed record is decoded into a `Trade` object

In [None]:
// Step 1: keep only the Kafka `value` column, cast to a string.
val rawValues = rawData
  .selectExpr("CAST(value AS STRING)")
  .as[String]

// Step 2: parse each string as JSON using the `Trade` schema, into a struct column `record`.
val jsonValues = rawValues.select(from_json($"value", schema).as("record"))

// Step 3: promote the struct's fields to top-level columns and type the rows as `Trade`.
val tradeData = jsonValues
  .select("record.*")
  .as[Trade]

### Inspect the content of the stream with an in-memory output

In [None]:
// Start a streaming query that appends the decoded trades to an in-memory table.
val visualizationQuery = tradeData
  .writeStream
  .format("memory")
  .outputMode("append")
  .queryName("visualization")    // the query name doubles as the SQL table name
  .start()

In [None]:
// DataFrame over the in-memory table filled by the `visualization` query.
val sampleDataset = spark.sql("select * from visualization")

### Count the amount of data processed

Use the dataframe api...

### Count the number of records per `pair`

### What is the latest `timestamp` for each `pair`?

You may need to `collect` the result to display it in the notebook

### Define an aggregate function for the latest price

In [None]:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/**
 * Aggregate function that returns the `price` paired with the greatest `timestamp`
 * among the aggregated rows.
 *
 * Rows whose `timestamp` or `price` is null are ignored. When several rows share the
 * greatest timestamp, the one that reached the buffer first is kept. If no row
 * qualifies (none seen, all null, or every timestamp <= -1) the result is `0.0`.
 */
class LastPrice extends UserDefinedAggregateFunction {
  // Input columns of the aggregate: (timestamp, price).
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("timestamp", LongType) :: StructField("price", DoubleType) :: Nil)

  // Internal state: the greatest timestamp seen so far and the price that came with it.
  override def bufferSchema: StructType = StructType(
    StructField("timestamp", LongType) ::
    StructField("last", DoubleType) :: Nil
  )

  // Output type of the aggregation function.
  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // Initial state: sentinel timestamp -1 (any timestamp >= 0 replaces it) and price 0.0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = -1L
    buffer(1) = 0.0D
  }

  // Fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip rows with a missing field: reading a null cell as a primitive would yield
    // 0 / 0.0, which beats the -1 sentinel and would be reported as the latest price.
    val hasValues = !input.isNullAt(0) && !input.isNullAt(1)
    if (hasValues && buffer.getLong(0) < input.getLong(0)) {
      buffer(0) = input.getLong(0)
      buffer(1) = input.getDouble(1)
    }
  }

  // Combine two partial buffers, keeping the entry with the greater timestamp.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1.getLong(0) < buffer2.getLong(0)) {
      buffer1(0) = buffer2.getLong(0)
      buffer1(1) = buffer2.getDouble(1)
    }
  }

  // Final result: the price stored next to the greatest timestamp.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(1)
  }
}

In [None]:
// Register the aggregate under the name "lastp" so it can be called from SQL expressions.
spark.udf.register("lastp", new LastPrice)

In [None]:
// Instance for use through the DataFrame API, e.g. `lastp($"timestamp", $"price")`.
val lastp = new LastPrice

### Test the aggregate

### Test the pivoted latest data

### Define the prediction function

In [None]:
// Build a single-row feature frame from the latest price of every pair and run the
// loaded model on it.
def predict() = {
  // One row per pair: the price at its greatest timestamp, plus that timestamp.
  val latestPerPair = sampleDataset
    .groupBy("pair")
    .agg(lastp($"timestamp",$"price") as "price", max($"timestamp") as "timestamp")

  // Collapse everything into one row with one price column per pair. The constant
  // `ts` key puts all rows in the same group before pivoting.
  val pivoted = latestPerPair
    .withColumn("ts", lit(1L))
    .groupBy("ts")
    .pivot("pair")
    .agg(min($"price"))

  model.transform(pivoted)
}

In [None]:
// Show the `ETHUSD` and `prediction` columns of the first row returned by `predict()`.
predict().select("ETHUSD","prediction").first

In [None]:
// Same query run again. `predict()` rebuilds its input from `sampleDataset` on every
// call, so the result presumably changes as more trades arrive — confirm.
predict().select("ETHUSD","prediction").first