# Streaming application using Spark Structured Streaming

### 1. Write code to create a SparkSession, which uses four cores with a proper application name, use the Melbourne timezone, and make sure a checkpoint location has been set.


In [1]:
from IPython.display import clear_output

def foreach_batch_function(df, epoch_id):
    clear_output(wait=True) 
    df.show(10, False) 

    
topic = "big-data-a2-topic"
host_ip = "118.139.10.179"

In [2]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.4.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.4.0 pyspark-shell'

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, split, regexp_extract

master = "local[4]"
app_name = "spark streaming"
spark_conf = SparkConf().setMaster(master).setAppName(app_name)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()

spark.conf.set("spark.sql.session.timeZone", "Australia/Melbourne")
spark.conf.set("spark.sql.streaming.statefulOperator.checkCorrectness.enabled", "false")
spark.conf.set("spark.sql.streaming.stateStore.stateSchemaCheck", "false")



# Retrieve Spark context and set its log level to 'ERROR'.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

### 2. Similar to assignment 2A, write code to define the data schema for the data files, following the data types suggested in the metadata file. Load the static datasets (e.g. customer, product, category) into data frames. 



In [3]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType, TimestampType

# 1. Customer.csv Schema
customer_schema = StructType([
    StructField("#", IntegerType(), True),
    StructField("customer_id", IntegerType(), True),
    StructField("first_name", StringType(), True),
    StructField("last_name", StringType(), True),
    StructField("username", StringType(), True),
    StructField("email", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("birthdate", DateType(), True),
    StructField("device_type", StringType(), True),
    StructField("device_id", StringType(), True),
    StructField("device_version", StringType(), True),
    StructField("home_location_lat", FloatType(), True),
    StructField("home_location_long", FloatType(), True),
    StructField("home_location", StringType(), True),
    StructField("home_country", StringType(), True),
    StructField("first_join_date", DateType(), True)
])

# Load CSV files with predefined schemas
directory_path = "./dataset"
file_schemas = {
    "customer": customer_schema,
}

for file_name, schema in file_schemas.items():
    file_path = os.path.join(directory_path, file_name + ".csv")
    df_name = file_name + "_df"
    globals()[df_name] = spark.read.format('csv')\
        .option('header', True).option('escape', '"')\
        .schema(schema)\
        .load(file_path)
    # drop redundant info
    globals()[df_name] = globals()[df_name].drop("#")
    
    # show schema
    print(f"Schema for {file_name}:")
    globals()[df_name].printSchema()

Schema for customer:
root
 |-- customer_id: integer (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- username: string (nullable = true)
 |-- email: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- birthdate: date (nullable = true)
 |-- device_type: string (nullable = true)
 |-- device_id: string (nullable = true)
 |-- device_version: string (nullable = true)
 |-- home_location_lat: float (nullable = true)
 |-- home_location_long: float (nullable = true)
 |-- home_location: string (nullable = true)
 |-- home_country: string (nullable = true)
 |-- first_join_date: date (nullable = true)



### 3 Using the Kafka topic from the producer in Task 1, ingest the streaming data into Spark Streaming, assuming all data comes in the String format. Except for the 'ts' column, you shall receive it as an Int type.




In [4]:
rt_stream =  spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
    .option("subscribe", topic) \
    .load()

rt_stream.printSchema()

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

# main schema
main_schema = StructType([    
    StructField('session_id', StringType(), True), 
    StructField('event_name', StringType(), True),
    StructField('event_id', StringType(), True),     
    StructField('traffic_source', StringType(), True),     
    StructField('customer_id', StringType(), True),
    StructField('ts', TimestampType(), True),
])

# branch schema
metadata_schema = StructType([
    StructField("product_id", IntegerType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("item_price", IntegerType(), True),
    StructField("payment_status", StringType(), True),
    StructField("search_keywords", StringType(), True),
    StructField("promo_code", StringType(), True),
    StructField("promo_amount", IntegerType(), True)
])

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [5]:
# Get value of the kafka message
stream_df = rt_stream.selectExpr("CAST(value AS STRING) AS value", "CAST(timestamp as timestamp)")
stream_df.printSchema()

root
 |-- value: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)



### 4 Then, the streaming data format should be transformed into the proper formats following the metadata file schema, similar to assignment 2A.  
Perform the following tasks:  
a) For the 'ts' column, convert it to the timestamp format, we will use it as event_time.  
b) If the data is late for more than 1 minute, discard it.  



In [6]:
from pyspark.sql.functions import split, trim, regexp_extract, regexp_replace, col, from_json

cleaned_value = regexp_replace(stream_df["value"], "^\[|\]$", "")
split_col = split(cleaned_value, ",")

stream_df = stream_df.withColumn("event_metadata", regexp_extract(cleaned_value, "\{.*\}", 0))

stream_df = stream_df.select(
    trim(regexp_replace(split_col.getItem(1), "['\"]", "")).alias('session_id'),
    trim(regexp_replace(split_col.getItem(2), "['\"]", "")).alias('event_name'),
    trim(regexp_replace(split_col.getItem(3), "['\"]", "")).alias('event_id'),
    trim(regexp_replace(split_col.getItem(4), "['\"]", "")).alias('traffic_source'),
    "event_metadata",
    trim(regexp_replace(split_col.getItem(6), "['\"]", "")).alias('customer_id'),
    split_col.getItem(7).cast("int").alias('ts'),
    "timestamp"
)

stream_df.printSchema()

root
 |-- session_id: string (nullable = true)
 |-- event_name: string (nullable = true)
 |-- event_id: string (nullable = true)
 |-- traffic_source: string (nullable = true)
 |-- event_metadata: string (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- ts: integer (nullable = true)
 |-- timestamp: timestamp (nullable = true)



In [7]:
from pyspark.sql.functions import from_unixtime, current_timestamp, expr, minute, col, unix_timestamp, current_timestamp

# ts to timestamp
stream_df = stream_df.withColumn("ts", from_unixtime(col("ts")).cast("timestamp"))


# Keep only rows from the last 1 minute
stream_df = stream_df.where(
    (unix_timestamp(current_timestamp()) - unix_timestamp(col("timestamp"))) <= 60
)

# watermark
stream_df = stream_df.withWatermark("timestamp", "1 minutes")

### 5  
Aggregate the streaming data frame by session id and create features you used in your assignment 2A model. (note: customer ID has already been included in the stream.)   
Then, join the static data frames with the streaming data frame as our final data for prediction.  
Perform data type/column conversion according to your ML model, and print out the Schema.


In [8]:
from pyspark.sql.functions import month, current_timestamp
from pyspark.sql.functions import sum, count, when, col, first, collect_list

category_1 = ["ADD_PROMO", "ADD_TO_CART"]
category_2 = ["VIEW_PROMO", "VIEW_ITEM", "SEARCH"]
category_3 = ["SCROLL", "HOMEPAGE", "CLICK"]

agg_df = stream_df.withColumn(
    "is_promotion",
    when(col("event_name") == "ADD_PROMO", 1).otherwise(0)
).withColumn(
    "event_category",
    when(col("event_name").isin(category_1), "num_cat_highvalue")
    .when(col("event_name").isin(category_2), "num_cat_midvalue")
    .otherwise("num_cat_lowvalue")
)


In [9]:
from pyspark.sql.functions import collect_list, explode, from_json, col, first, last, struct
from pyspark.sql.types import StructType, StructField, StringType, IntegerType



# # DataFrame 2: Event Metadata
df_event_metadata = agg_df.groupBy("session_id").agg(
    collect_list("event_metadata").alias("event_metadata_list")
)

df_event_metadata = df_event_metadata.select(
    "session_id",
    explode(col("event_metadata_list")).alias("single_event_metadata")
)

df_event_metadata = df_event_metadata.withColumn("metadata", from_json("single_event_metadata", metadata_schema))


df_event_metadata = df_event_metadata.select(
    "session_id",
    "metadata.product_id",
    "metadata.quantity",
    "metadata.item_price"
).filter(col("metadata.product_id").isNotNull())  # Filter out null product_ids


df_event_metadata = df_event_metadata.groupBy("session_id", "product_id").agg(
    last("quantity").alias("quantity"),
    last("item_price").alias("item_price")
)

df_event_metadata = df_event_metadata.dropDuplicates(["session_id", "product_id"])


df_event_metadata = df_event_metadata.groupBy("session_id").agg(
    collect_list(struct("product_id", "quantity", "item_price")).alias("event_metadata_list")
)

query_event_metadata = df_event_metadata.writeStream.outputMode("complete") \
        .format("memory")\
        .queryName("event_metadata_table")\
        .trigger(processingTime='5 seconds')\
        .start()

event_metadata_df = spark.sql("select * from event_metadata_table")

In [10]:
# query_event_metadata.stop()

In [11]:
# DataFrame 3: Aggregated Data
from pyspark.sql.functions import window

df_agg = agg_df.groupBy("session_id",
    window("timestamp", "1 minutes")
                       ).agg(
    first("customer_id").alias("customer_id"),
    first("ts").alias("ts"),
    first("traffic_source").alias("traffic_source"),
    first("timestamp").alias("timestamp"),
    sum(when(col("event_category") == "num_cat_highvalue", 1).otherwise(0)).alias("num_cat_highvalue"),
    sum(when(col("event_category") == "num_cat_midvalue", 1).otherwise(0)).alias("num_cat_midvalue"),
    sum(when(col("event_category") == "num_cat_lowvalue", 1).otherwise(0)).alias("num_cat_lowvalue"),
    (when(sum(col("is_promotion")) > 0, 1).otherwise(0)).alias("is_promotion"),
    count("*").alias("total_actions")
)

agg_df = df_agg.join(event_metadata_df, on="session_id", how="left")

agg_df = agg_df.withColumn("high_value_ratio", 
                                         (col("num_cat_highvalue") / col("total_actions")) * 100)
agg_df = agg_df.withColumn("low_value_ratio", 
                                         (col("num_cat_lowvalue") / col("total_actions")) * 100)

agg_df = agg_df.drop("total_actions")

agg_df.printSchema()

root
 |-- session_id: string (nullable = true)
 |-- window: struct (nullable = false)
 |    |-- start: timestamp (nullable = true)
 |    |-- end: timestamp (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- ts: timestamp (nullable = true)
 |-- traffic_source: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- num_cat_highvalue: long (nullable = true)
 |-- num_cat_midvalue: long (nullable = true)
 |-- num_cat_lowvalue: long (nullable = true)
 |-- is_promotion: integer (nullable = false)
 |-- event_metadata_list: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- item_price: integer (nullable = true)
 |-- high_value_ratio: double (nullable = true)
 |-- low_value_ratio: double (nullable = true)



In [None]:
# query_agg = agg_df.writeStream.outputMode("append")\
#         .foreachBatch(foreach_batch_function)\
#         .trigger(processingTime='10 seconds')\
#         .start()

In [None]:
# query_agg.stop()

In [12]:
from pyspark.sql.functions import year, datediff, current_date, broadcast


join_df = agg_df.join(broadcast(customer_df), on="customer_id", how="inner")


join_df = join_df.withColumn("month", month("ts"))

join_df = join_df.withColumn("season", 
    when(col("month").between(3, 5), "Spring")
    .when(col("month").between(6, 8), "Summer")
    .when(col("month").between(9, 11), "Autumn")
    .otherwise("Winter")
)

join_df = join_df.withColumn("age", (datediff(current_date(), col("birthdate"))/365).cast('int'))
join_df = join_df.withColumn("first_join_year", year("first_join_date"))

join_df.printSchema()

root
 |-- customer_id: string (nullable = true)
 |-- session_id: string (nullable = true)
 |-- window: struct (nullable = false)
 |    |-- start: timestamp (nullable = true)
 |    |-- end: timestamp (nullable = true)
 |-- ts: timestamp (nullable = true)
 |-- traffic_source: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- num_cat_highvalue: long (nullable = true)
 |-- num_cat_midvalue: long (nullable = true)
 |-- num_cat_lowvalue: long (nullable = true)
 |-- is_promotion: integer (nullable = false)
 |-- event_metadata_list: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- item_price: integer (nullable = true)
 |-- high_value_ratio: double (nullable = true)
 |-- low_value_ratio: double (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- username: string (nullable = tru

### 6 Load your ML model, and use the model to predict if each session will purchase according to the requirements below:
a) Every 10 seconds, show the total number of potential sales transactions (prediction = 1) in the last 1 minute.   
b) Every 30 seconds, show the total potential revenue in the last 30 seconds. “Potiential revenue” here is definded as: When prediction=1, extract customer shopping cart detail from metadata (sum of all items of ADD_TO_CART events).  
c) Every 1 minute, show the top 10 best-selling products by total quantity. (note: No historical data is required, only the top 10 in each 1 minute window.)  


In [13]:
columns_to_select = ["traffic_source", "season", "gender", "device_type", "home_location", "home_country",\
                     "home_location_lat", "home_location_long", "event_metadata_list", "high_value_ratio",\
                     "low_value_ratio", "is_promotion", "month", "age", "first_join_year", "timestamp", "customer_id"]
transformed_df = join_df.select(*columns_to_select)
transformed_df = transformed_df.dropna(how="any")

# Displaying the schema
transformed_df.printSchema()


root
 |-- traffic_source: string (nullable = true)
 |-- season: string (nullable = false)
 |-- gender: string (nullable = true)
 |-- device_type: string (nullable = true)
 |-- home_location: string (nullable = true)
 |-- home_country: string (nullable = true)
 |-- home_location_lat: float (nullable = true)
 |-- home_location_long: float (nullable = true)
 |-- event_metadata_list: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- item_price: integer (nullable = true)
 |-- high_value_ratio: double (nullable = true)
 |-- low_value_ratio: double (nullable = true)
 |-- is_promotion: integer (nullable = false)
 |-- month: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- first_join_year: integer (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- customer_id: string (nullable = true)



In [None]:
# query_trans = transformed_df.writeStream.outputMode("append")\
#         .foreachBatch(foreach_batch_function)\
#         .trigger(processingTime='1 seconds')\
#         .start()

In [None]:
# query_trans.stop()

In [14]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml import Pipeline

In [15]:
# StringIndexer columns
categoryInputCols = ["traffic_source", "season", "gender", "device_type", "home_location"]
numericInputCols = ["high_value_ratio", "low_value_ratio", "is_promotion", "month", "age",
                    "first_join_year"]

outputCols = [f'{x}_idx' for x in categoryInputCols]

# StringIndexer 
inputIndexer = StringIndexer(inputCols=categoryInputCols, outputCols=outputCols, handleInvalid="keep")

# OneHotEncoder columns
inputCols_OHE = [x for x in outputCols]
outputCols_OHE = [f'{x}_vec' for x in categoryInputCols]

# OneHotEncoder
encoder = OneHotEncoder(inputCols=inputCols_OHE, outputCols=outputCols_OHE)

# Assembler columns
assemblerInputs = outputCols_OHE + numericInputCols

# Assembler
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features", handleInvalid = "keep")    

# Gradient Boosted Tree Classifier
gbt = GBTClassifier(labelCol="made_purchase", featuresCol="features")

# pipeline for Gradient Boosted Tree
pipeline_gbt = Pipeline(stages=[inputIndexer, encoder, assembler, gbt])

In [16]:
# 6a
from pyspark.ml import PipelineModel

model = PipelineModel.load("./model_gbt/")
predictions_df = model.transform(transformed_df)
predictions_df.printSchema()

root
 |-- traffic_source: string (nullable = true)
 |-- season: string (nullable = false)
 |-- gender: string (nullable = true)
 |-- device_type: string (nullable = true)
 |-- home_location: string (nullable = true)
 |-- home_country: string (nullable = true)
 |-- home_location_lat: float (nullable = true)
 |-- home_location_long: float (nullable = true)
 |-- event_metadata_list: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- item_price: integer (nullable = true)
 |-- high_value_ratio: double (nullable = true)
 |-- low_value_ratio: double (nullable = true)
 |-- is_promotion: integer (nullable = false)
 |-- month: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- first_join_year: integer (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- traffic_source_idx: double (null

In [None]:
# query_predictions = predictions_df.writeStream.outputMode("append")\
#         .foreachBatch("console")\
#         .trigger(processingTime='1 seconds')\
#         .start()

In [None]:
# query_predictions.stop()

In [17]:
from pyspark.sql.functions import window, sum, col
from pyspark.sql.functions import current_timestamp, unix_timestamp

# 6a) Every 10 seconds, show the total number of potential sales transactions (prediction = 1) in the last 1 minute

import uuid

base_checkpoint_path = "./checkpoints/"
count_checkpoint_path = base_checkpoint_path + str(uuid.uuid4())

filter_df = predictions_df.filter(col("prediction") == 1.0)

windowedCounts = filter_df \
    .withWatermark("timestamp", "10 seconds") \
    .groupBy(window(col("timestamp"), "1 minute", "10 seconds")) \
    .count()

query_prediction = windowedCounts \
    .writeStream \
    .outputMode("update") \
    .option("truncate", "false") \
    .option("checkpointLocation", count_checkpoint_path) \
    .format("console") \
    .trigger(processingTime="10 seconds") \
    .start()

In [None]:
# query_prediction.stop()

In [18]:
# 6b
# b) Every 30 seconds, show the total potential revenue in the last 30 seconds.

from pyspark.sql.functions import explode, sum, window

revenue_checkpoint_path = base_checkpoint_path + str(uuid.uuid4())


# Explode the event_meta_list to extract each product
exploded_df = filter_df.withColumn("event_meta_item", explode(col("event_metadata_list")))

# Calculate potential revenue for each product
revenue_df = exploded_df.withColumn("revenue", col("event_meta_item.quantity") * col("event_meta_item.item_price"))

# Calculate total potential revenue in the last 30 seconds using window function
windowedRevenue = revenue_df \
    .withWatermark("timestamp", "30 seconds") \
    .groupBy(window(col("timestamp"), "30 seconds")) \
    .sum("revenue")\
    .withColumnRenamed("sum(revenue)", "revenue") \

# Stream the results
query_revenue = windowedRevenue.writeStream.outputMode("update") \
    .format("console") \
    .option("truncate", "false") \
    .option("checkpointLocation", revenue_checkpoint_path) \
    .trigger(processingTime="30 seconds") \
    .start()

In [None]:
# query_revenue.stop()

In [None]:
# # 6c) Every 1 minute, show the top 10 best-selling products by total quantity.
# from pyspark.sql.functions import explode, sum, window, col, desc

# # 1. Group by product_id and window of 1 minute sliding every 10 seconds
# product_agg = exploded_df.withWatermark("timestamp", "1 minute") \
#                          .groupBy("event_meta_item.product_id", 
#                                   window(col("timestamp"), "1 minute", "10 seconds")
#                                  ).agg(
#     sum("event_meta_item.quantity").alias("total_quantity")
# )

# # 2. Stream the results using console output
# query_product = product_agg.writeStream.outputMode("complete") \
#     .format("console") \
#     .option("truncate", "false") \
#     .trigger(processingTime='10 seconds') \
#     .start()


In [19]:
# # 6c) Every 1 minute, show the top 10 best-selling products by total quantity.

product_checkpoint_path = base_checkpoint_path + str(uuid.uuid4())

def process_order(df, epoch_id):
    df = df.orderBy(desc("total_quantity")).limit(10)
    df.show()

# 1. Group by product_id and window of 1 minute sliding every 10 seconds
product_agg = exploded_df.withWatermark("timestamp", "1 minute") \
                         .groupBy("event_meta_item.product_id", 
                                  window(col("timestamp"), "1 minute", "10 seconds")
                                 ).agg(
    sum("event_meta_item.quantity").alias("total_quantity")
)

# 2. Stream the results using foreachBatch
query_product = product_agg.writeStream.outputMode("update") \
    .format("console") \
    .option("truncate", "false") \
    .option("checkpointLocation", product_checkpoint_path) \
    .trigger(processingTime="10 seconds") \
    .start()

In [None]:
# query_product.stop()

### 7  
a) Persist the prediction result along with cart metadata in parquet format; after that, read the parquet file and show the results to verify it is saved properly.  
b) Persist the 30-second sales prediction in another parquet file.

In [20]:
# 7a
from time import sleep
import shutil


selected_df = filter_df.withWatermark("timestamp", "1 minutes")\
            .select("prediction", "event_metadata_list", "timestamp", "customer_id")

output_path = "./output/predictions.parquet"
shutil.rmtree(output_path, ignore_errors=True)
    
def write_to_parquet(batch_df, batch_id):
    batch_df.write.parquet(output_path, mode="append")

query_sink = selected_df.writeStream \
    .outputMode("update") \
    .foreachBatch(write_to_parquet) \
    .trigger(processingTime='10 seconds') \
    .start()

# Wait for the streaming query to process some data
query_sink.awaitTermination(timeout=60)  # Wait for 60 seconds or until the query is terminated

# Now, read the Parquet file
prediction_sink = spark.read.load(output_path)
prediction_sink.show()

+----------+-------------------+---------+-----------+
|prediction|event_metadata_list|timestamp|customer_id|
+----------+-------------------+---------+-----------+
+----------+-------------------+---------+-----------+



In [24]:
prediction_sink = spark.read.load(output_path)
prediction_sink.show()

+----------+--------------------+--------------------+-----------+
|prediction| event_metadata_list|           timestamp|customer_id|
+----------+--------------------+--------------------+-----------+
|       1.0|[{39214, 1, 15964...|2023-10-19 10:20:...|      55963|
|       1.0|[{20705, 1, 27147...|2023-10-19 10:20:...|       8997|
|       1.0|[{19130, 4, 18968...|2023-10-19 10:20:...|      66840|
|       1.0|[{13287, 1, 16265...|2023-10-19 10:18:...|      72301|
|       1.0|[{35078, 1, 32088...|2023-10-19 10:20:...|      48622|
|       1.0|[{53941, 1, 30837...|2023-10-19 10:20:...|      23696|
|       1.0|[{49605, 1, 25232...|2023-10-19 10:20:...|      51927|
|       1.0|[{11841, 2, 38563...|2023-10-19 10:18:...|      60118|
|       1.0|[{47681, 1, 51961...|2023-10-19 10:20:...|      39406|
|       1.0|[{47350, 2, 17022...|2023-10-19 10:20:...|      95334|
|       1.0|[{51938, 1, 14871...|2023-10-19 10:18:...|      58344|
|       1.0|[{22984, 7, 24040...|2023-10-19 10:20:...|      48

In [None]:
# query_sink.stop()

In [21]:
# 7b
# Save to Parquet
import shutil

output_path_r = "./output/revenue.parquet"
shutil.rmtree(output_path_r, ignore_errors=True)
    
def write_new_parquet(batch_df, batch_id):
    batch_df.write.parquet(output_path_r, mode="append")
#     batch_df.write.mode("append").save(output_path_r)

query_sink_r = windowedRevenue.writeStream \
    .outputMode("update") \
    .foreachBatch(write_new_parquet) \
    .trigger(processingTime='5 seconds') \
    .start()

query_sink_r.awaitTermination(timeout=60)

False

In [None]:
# query_sink_r.stop()

In [25]:
sale_sink = spark.read.load(output_path_r)
sale_sink.show()

+--------------------+--------+
|              window| revenue|
+--------------------+--------+
|{2023-10-19 10:20...|10662908|
|{2023-10-19 10:20...|19669485|
+--------------------+--------+



### 8  
Read the parquet files as a data stream, for 7a) join customer information and send to a Kafka topic with an appropriate name to the data visualisation. For 7b) Send the message directly to another Kafka topic.

In [22]:
# Stream 1
from pyspark.sql.functions import to_json, struct

static_df = spark.read.parquet(output_path)
schema = static_df.schema

parquet_stream = spark.readStream.schema(schema).parquet(output_path)
parquet_stream = parquet_stream.join(broadcast(customer_df), on="customer_id", how="inner")
parquet_stream = parquet_stream.selectExpr("to_json(struct(*)) AS value")
parquet_stream.printSchema()

query_send = parquet_stream.writeStream \
    .outputMode("append")\
    .format("kafka") \
    .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
    .option("topic", "for_visualisation") \
    .option("checkpointLocation", "./output/check/for_visualisation") \
    .start()

root
 |-- value: string (nullable = true)



In [None]:
# from pyspark.sql import functions as F


# parquet_stream = spark.readStream.schema(schema).parquet(output_path)
# parquet_stream = parquet_stream.withWatermark("timestamp", "30 seconds")
# parquet_stream = parquet_stream.alias("stream").join(broadcast(customer_df), on="customer_id", how="inner")

# current_time = F.current_timestamp()
# five_minutes_ago = current_time - F.expr("INTERVAL 30 SECONDS")
# filtered_stream = parquet_stream.filter((parquet_stream["stream.timestamp"] >= five_minutes_ago) & (parquet_stream["stream.timestamp"] <= current_time))

# filtered_stream = filtered_stream.selectExpr("to_json(struct(*)) AS value")

# query_send = filtered_stream.writeStream \
#     .outputMode("append")\
#     .format("kafka") \
#     .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
#     .option("topic", "for_visualisation") \
#     .option("checkpointLocation", "./output/check/for_visualisation") \
#     .trigger(processingTime='30 seconds') \
#     .start()


In [None]:
query_send.status

In [None]:
# query_send.stop()

In [23]:
# Stream 2
static_df_2 = spark.read.parquet(output_path_r)
schema_2 = static_df_2.schema

parquet_stream_2 = spark.readStream.schema(schema_2).parquet(output_path_r)
parquet_stream_2 = parquet_stream_2.selectExpr("to_json(struct(*)) AS value")
parquet_stream_2.printSchema()

query_topic = parquet_stream_2.writeStream \
    .outputMode("append")\
    .format("kafka") \
    .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
    .option("topic", "another_topic") \
    .option("checkpointLocation", "./output/check/another_topic") \
    .start()

root
 |-- value: string (nullable = true)



In [None]:
# from pyspark.sql import functions as F

# static_df_2 = spark.read.parquet(output_path_r)
# schema_2 = static_df_2.schema


# parquet_stream_2 = spark.readStream.schema(schema_2).parquet(output_path_r)
# parquet_stream_2 = parquet_stream_2.withColumn("start_time", parquet_stream_2.window.start)
# parquet_stream_2 = parquet_stream_2.withWatermark("start_time", "30 seconds")

# current_time = F.current_timestamp()
# thirty_seconds_ago = current_time - F.expr("INTERVAL 30 SECONDS")
# filtered_stream_2 = parquet_stream_2.filter((parquet_stream_2["start_time"] >= thirty_seconds_ago) & (parquet_stream_2["start_time"] <= current_time))

# filtered_stream_2 = filtered_stream_2.selectExpr("to_json(struct(*)) AS value")

# query_topic = filtered_stream_2.writeStream \
#     .outputMode("append")\
#     .format("kafka") \
#     .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
#     .option("topic", "another_topic") \
#     .option("checkpointLocation", "./output/check/another_topic") \
#     .trigger(processingTime='30 seconds') \
#     .start()


In [None]:
query_topic.status

In [None]:
# query_topic.stop()

In [None]:
# # Stream 3
# query_skip = filter_df.selectExpr("to_json(struct(*)) AS value") \
#     .writeStream \
#     .outputMode("append")\
#     .format("kafka") \
#     .option("kafka.bootstrap.servers", f"{host_ip}:9092") \
#     .option("topic", "skip") \
#     .option("checkpointLocation", "./output/skip_new")\
#     .start()

In [None]:
# query_skip.status

In [None]:
# query_skip.stop()