In [0]:
from pyspark.sql.functions import expr, from_json, window, sum, to_timestamp, col
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType

In [0]:
class TradeSummary:
  """Streaming pipeline that summarises trades into 15-minute buy/sell totals.

  The pipeline reads the `kafka_bz` table as a stream, parses the JSON held
  in its `value` column, aggregates per 15-minute window on `CreatedTime`
  and writes the result to the `trade_summary` Delta table.

  NOTE(review): `spark` is not defined in this class; it is presumably the
  session injected by the notebook runtime -- confirm before running
  outside a notebook.
  """

  def __init__(self, base_data_dir: str = "/FileStore/test") -> None:
    """Create the pipeline.

    Args:
      base_data_dir: Root directory under which the streaming checkpoint
        is stored (`<base_data_dir>/checkpoint/trade_summary`). Defaults
        to the location that was previously hard-coded.
    """
    self.base_data_dir = base_data_dir

  def getSchema(self):
    """Return the schema used to parse the JSON trade payload."""
    return StructType([
      StructField("CreatedTime", StringType()),
      StructField("Type", StringType()),
      StructField("Amount", IntegerType()),
      StructField("BrokerCode", StringType()),
    ])

  def readBronze(self):
    """Return a streaming DataFrame over the `kafka_bz` table."""
    return spark.readStream.table("kafka_bz")

  def getTrade(self, kafka_df):
    """Parse raw records into typed trade rows with Buy/Sell columns.

    Args:
      kafka_df: DataFrame with a `value` column holding the JSON payload.

    Returns:
      A DataFrame with the payload fields flattened to top-level columns,
      `CreatedTime` converted to a timestamp, and two extra columns:
      `Buy` (Amount when Type is 'BUY', else 0) and `Sell` (Amount when
      Type is 'SELL', else 0).
    """
    parsed_df = kafka_df.select(
      from_json(kafka_df.value, self.getSchema()).alias("value")
    )
    flattened_df = parsed_df.select("value.*")
    return (
      flattened_df
      # CreatedTime arrives as a string; convert it so it can be windowed.
      .withColumn("CreatedTime", expr("to_timestamp(CreatedTime, 'yyyy-MM-dd HH:mm:ss')"))
      # Split Amount into separate columns so each side can be summed.
      .withColumn("Buy", expr("CASE WHEN Type == 'BUY' THEN Amount ELSE 0 END"))
      .withColumn("Sell", expr("CASE WHEN Type == 'SELL' THEN Amount ELSE 0 END"))
    )

  def getAggregate(self, trade_df):
    """Sum Buy and Sell amounts per 15-minute window of `CreatedTime`.

    Args:
      trade_df: DataFrame with `CreatedTime`, `Buy` and `Sell` columns,
        as produced by `getTrade`.

    Returns:
      A DataFrame with columns `start`, `end`, `TotalBuy`, `TotalSell`.
    """
    time_window = window(trade_df.CreatedTime, "15 minutes")
    totals_df = trade_df.groupBy(time_window).agg(
      sum("Buy").alias("TotalBuy"),
      sum("Sell").alias("TotalSell"),
    )
    # Expose the window bounds as plain columns instead of a struct.
    return totals_df.select("window.start", "window.end", "TotalBuy", "TotalSell")

  def saveResults(self, results_df):
    """Start the streaming write of the aggregates to `trade_summary`.

    The query is named "trade-summary", writes in Delta format using the
    "complete" output mode, and keeps its checkpoint under
    `<base_data_dir>/checkpoint/trade_summary`.

    Args:
      results_df: Streaming DataFrame to persist.

    Returns:
      Whatever `toTable` returns for the started query.
    """
    checkpoint_dir = f"{self.base_data_dir}/checkpoint/trade_summary"
    return (
      results_df.writeStream
      .queryName("trade-summary")
      .format("delta")
      .option("checkpointLocation", checkpoint_dir)
      .outputMode("complete")
      .toTable("trade_summary")
    )

  def process(self):
    """Wire the stages together and start the streaming query.

    Returns:
      The value returned by `saveResults` for the started query.
    """
    kafka_df = self.readBronze()
    trade_df = self.getTrade(kafka_df)
    result_df = self.getAggregate(trade_df)
    return self.saveResults(result_df)

    