# 0. Setting up a Spark Session

In [1]:
from pyspark.sql import SparkSession
import os
# Kafka connection settings used by the streaming cells of this notebook.
KAFKA_BOOTSTRAP_SERVERS = "kafka:9092"
KAFKA_TOPIC = "transaction_data"

# Ask the PySpark launcher to fetch the Kafka connector packages. This is set
# before the session is built below so the submit arguments are in place when
# Spark starts.
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.5.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0 pyspark-shell'

# Create (or reuse) the session for this notebook against the master at
# spark://spark:7077.
spark = (
    SparkSession.builder
    .appName("transaction_consumer")
    .master("spark://spark:7077")
    .getOrCreate()
)

# Limit Spark's log output to warnings and above.
spark.sparkContext.setLogLevel("WARN")

# 1. Reading Kafka Stream

In [2]:
# Name of the streaming query; the memory sink exposes its rows as a table of
# the same name, which is what spark.sql("... FROM transaction_data ...") reads.
STREAM_QUERY_NAME = "transaction_data"

# Make the cell safe to re-run: Spark refuses to start a second active query
# with the same name, so stop any previous instance before starting a new one.
for active_query in spark.streams.active:
    if active_query.name == STREAM_QUERY_NAME:
        active_query.stop()

# Subscribe to the Kafka topic, reading from the earliest available offset.
df = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP_SERVERS)
    .option("subscribe", KAFKA_TOPIC)
    .option("startingOffsets", "earliest")
    .load()
)

# Append every incoming record to the in-memory table.
# NOTE(review): no aggregation or de-duplication follows the watermark, so it
# probably has no effect here -- confirm before relying on it.
# NOTE(review): the memory sink retains all received rows, so a long-running
# stream will keep growing -- confirm this is acceptable for the demo.
query = (
    df
    .withWatermark("timestamp", "2 seconds")
    .writeStream
    .outputMode("append")
    .format("memory")
    .queryName(STREAM_QUERY_NAME)
    .start()
)

In [3]:
from IPython.display import display, clear_output
from time import sleep 
from pyspark.ml.pipeline import PipelineModel
from pyspark.sql.functions import first, col
import pyspark.sql.utils

# Polling parameters for the monitoring loop below.
# NOTE(review): 30 presumably corresponds to one Kafka record per key of a
# single transaction (V1..V28, Amount, Class) -- confirm against the producer.
SNAPSHOT_ROWS = 30         # most-recent records fetched per poll
PREVIEW_ROWS = 5           # rows printed as a preview on each poll
POLL_INTERVAL_SECONDS = 3  # pause between polls

# Columns that are cast from string to double before being passed to the model.
# Built once here instead of on every loop iteration.
doubles = [f"V{n}" for n in range(1, 29)] + ["Amount"]

# Load the fitted pipeline saved at /src/gbt_model.model.
classifier_model = PipelineModel.load("/src/gbt_model.model")

# Poll the in-memory stream table until a transaction is classified as
# fraudulent (prediction != 0), at which point the loop exits.
while True:
    clear_output(wait=True)
    display(query.status)

    latest_rows = spark.sql(
        "SELECT timestamp, CAST(key AS STRING), CAST(value AS STRING) "
        f"FROM transaction_data ORDER BY timestamp DESC LIMIT {SNAPSHOT_ROWS}"
    )
    # The table keeps receiving rows while this loop runs, and every action on
    # the lazy query would re-read it. Collect once and rebuild a local
    # DataFrame so the count, the preview and the prediction all use exactly
    # the same rows.
    snapshot_rows = latest_rows.collect()
    input_data = spark.createDataFrame(snapshot_rows, schema=latest_rows.schema)
    length = len(snapshot_rows)

    print(f"Transaction data [{length}/{SNAPSHOT_ROWS}]:")
    # show() prints the table itself and returns None, so it must not be
    # wrapped in display() (that rendered a stray "None" in the output).
    input_data.show(PREVIEW_ROWS)

    if length < SNAPSHOT_ROWS:
        print("[!] Incomplete transaction data stream!")
    else:
        try:
            # Turn the (key, value) rows into a single wide row: one column per key.
            input_data = input_data.groupby().pivot("key").agg(first(col("value")))

            # Values arrive as strings; cast the listed columns to double.
            # Referencing a column that is absent from the pivot raises
            # AnalysisException, which is handled below.
            for column_name in doubles:
                input_data = input_data.withColumn(column_name, col(column_name).cast('double'))

            preds = classifier_model.transform(input_data).collect()[0].asDict()
            # NOTE(review): the label says "Probabilities" but the value printed
            # is the model's rawPrediction column -- confirm whether the
            # 'probability' column was intended.
            print(f"Probabilities: {preds['rawPrediction']} Result: {preds['prediction']}")
            if preds['prediction'] == 0:
                print("[i] Transaction OK")
            else:
                print("[!] Fraudulent transaction detected!")
                break
        except pyspark.sql.utils.AnalysisException:
            print("[i] Fragmented data, skipping...")

    sleep(POLL_INTERVAL_SECONDS)

{'message': 'Waiting for data to arrive',
 'isDataAvailable': False,
 'isTriggerActive': False}

Transaction data [30/30]:
+--------------------+------+-----------------+
|           timestamp|   key|            value|
+--------------------+------+-----------------+
|2024-01-07 08:23:...|Amount|              2.0|
|2024-01-07 08:23:...|   V12| -2.1498630541179|
|2024-01-07 08:23:...| Class|              0.0|
|2024-01-07 08:23:...|    V1|0.503301988400519|
|2024-01-07 08:23:...|   V10|-1.97461696232863|
+--------------------+------+-----------------+
only showing top 5 rows



None

Probabilities: [-0.2826651861428179,0.2826651861428179] Result: 1.0
[!] Fraudulent transaction detected!
