In [16]:
import os
import pandas as pd
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.3.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,org.mongodb.spark:mongo-spark-connector_2.12:10.1.1 pyspark-shell'
from pymongo import MongoClient
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from pyspark.sql.types import (
    StructType, StringType, IntegerType, DoubleType, TimestampType, BooleanType
)
from pyspark.sql.functions import (
    col, expr, from_json
)
import uuid

class SparkInst:
    """Holds the notebook's local SparkSession and helpers for reading camera events from Kafka."""

    def __init__(self, app_name: str, batch_interval: int, kafka_output_topic: str):
        """
        Create (or reuse) a local SparkSession and define the camera-event schema.

        Args:
            app_name (str): The name of the Spark application.
            batch_interval (int): Stored as ``self.batch_interval``; not used elsewhere in this class.
            kafka_output_topic (str): Stored as ``self.kafka_output_topic``; not used elsewhere in this class.
        """
        # Kafka and MongoDB connector packages requested through PySpark's submit arguments.
        os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.3.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,org.mongodb.spark:mongo-spark-connector_2.12:10.1.1 pyspark-shell'
        self.batch_interval = batch_interval
        self.kafka_output_topic = kafka_output_topic

        # Layout of the JSON payload carried in each Kafka message value.
        event_fields = (
            ("batch_id", IntegerType()),
            ("event_id", StringType()),
            ("car_plate", StringType()),
            ("camera_id", IntegerType()),
            ("timestamp", TimestampType()),
            ("speed_reading", DoubleType()),
            ("producer_id", StringType()),
            ("sent_at", TimestampType()),
        )
        schema = StructType()
        for field_name, field_type in event_fields:
            schema.add(field_name, field_type)
        self.eventSchema = schema

        self.spark = (
            SparkSession.builder
            .appName(app_name)
            .master("local[*]")
            .getOrCreate()
        )

        # immediately bump the KafkaDataConsumer logger to ERROR
        log4j = self.spark.sparkContext._jvm.org.apache.log4j
        consumer_logger = log4j.LogManager.getLogger("org.apache.spark.sql.kafka010.KafkaDataConsumer")
        consumer_logger.setLevel(log4j.Level.ERROR)

    def get_session(self):
        """Return the SparkSession created in ``__init__``."""
        return self.spark

    def attach_kafka_stream(self, topic_name:str, hostip:str, watermark_time:str):
        """
        Build a streaming DataFrame of parsed camera events from one Kafka topic.

        Args:
            topic_name (str): Kafka topic to subscribe to.
            hostip (str): Broker host; port 9092 is appended.
            watermark_time (str): Delay threshold applied to the ``timestamp`` column.

        Returns:
            Streaming DataFrame with one column per ``eventSchema`` field.
        """
        raw_messages = (
            self.spark.readStream
            .format("kafka")
            .option("kafka.bootstrap.servers", f"{hostip}:9092")
            .option("subscribe", topic_name)
            .option("startingOffsets", "earliest")
            .load()
        )
        # The message value is JSON text; parse it and flatten the struct into columns.
        json_text = raw_messages.selectExpr("CAST(value AS STRING) as json")
        parsed = json_text.select(from_json(col("json"), self.eventSchema).alias("data"))
        events = parsed.select("data.*")
        return events.withWatermark("timestamp", watermark_time)

    def essentialData_broadcast(self, sdf):
        """
        Collect the camera speed limits from a Spark DataFrame and broadcast them.

        Args:
            sdf (DataFrame): Spark DataFrame with ``camera_id`` and ``speed_limit`` columns.

        Returns:
            Broadcast variable containing a dictionary of camera_id to speed_limit
        """
        limits_df = sdf.select("camera_id", "speed_limit")
        pairs = limits_df.rdd.map(lambda record: (record["camera_id"], record["speed_limit"]))
        # collectAsMap() brings the pairs to the driver as a plain dict before broadcasting.
        return self.spark.sparkContext.broadcast(pairs.collectAsMap())



class DbWriter():
    """
    Row-by-row streaming sink that records speeding violations in MongoDB.

    Provides ``open`` / ``process`` / ``close`` methods. One MongoDB document is kept
    per (car_plate, calendar day); each document carries a ``violations`` array that
    new violations are appended to.
    """

    def __init__(self, mongo_host, mongo_port, mongo_db, mongo_coll):
        """
        Store connection settings; no connection is made until ``open`` is called.

        Args:
            mongo_host: MongoDB host name or IP.
            mongo_port: MongoDB port.
            mongo_db: Database name.
            mongo_coll: Collection that receives the violation documents.
        """
        self.mongo_host = mongo_host
        self.mongo_port = mongo_port
        self.mongo_db   = mongo_db
        self.mongo_coll = mongo_coll
        self.client     = None
        self.violation_coll  = None

    def open(self, partition_id: str, epoch_id: str) -> bool:
        """Connect to MongoDB and resolve the target collection. Always returns True."""
        # Local import: pymongo is resolved when open() runs.
        from pymongo import MongoClient
        self.client = MongoClient(host=self.mongo_host, port=self.mongo_port)
        self.violation_coll  = self.client[self.mongo_db][self.mongo_coll]
        return True

    @staticmethod
    def _to_datetime(value):
        """Parse ISO-8601 strings into datetime; return any other value unchanged."""
        return datetime.fromisoformat(value) if isinstance(value, str) else value

    @staticmethod
    def _day_bucket(moment):
        """Return midnight of the calendar day that ``moment`` falls on."""
        return datetime(moment.year, moment.month, moment.day)

    @staticmethod
    def _violation(kind, camera_start, camera_end, time_start, time_end, speed):
        """Build one entry of a document's ``violations`` array."""
        return {
            "type": kind,
            "camera_id_start": camera_start,
            "camera_id_end": camera_end,
            "timestamp_start": time_start,
            "timestamp_end": time_end,
            "measured_speed": speed,
        }

    def _store(self, car_plate, day, violations):
        """Append ``violations`` to the (car_plate, day) document, creating it if absent."""
        if not violations:
            return
        # A single atomic upsert replaces the earlier find_one + insert/update sequence:
        # it cannot lose entries when two tasks touch the same document, and it sends
        # only the new entries instead of rewriting the whole array once per violation.
        self.violation_coll.update_one(
            {"car_plate": car_plate, "date": day},
            {
                "$push": {"violations": {"$each": violations}},
                # Only applied when the upsert creates the document.
                "$setOnInsert": {"violation_id": str(uuid.uuid4())},
            },
            upsert=True,
        )

    def process(self, row):
        """
        Persist the violations flagged on one joined A/B/C row.

        Reads ``car_plate`` plus, for each suffix in a/b/c, ``timestamp_*``,
        ``camera_id_*``, ``speed_reading_*`` and ``speed_flag_instant_*``; and for the
        ab/bc pairs ``speed_flag_average_*`` and ``avg_speed_reading_*``.
        Instantaneous violations are filed under their own camera's day; an average
        violation is filed under the day of the segment's end camera.

        Any exception is printed and swallowed so one bad row does not stop the task.
        """
        try:
            print(f"\nProcessing: {row.asDict()}")
            moments = {
                suffix: self._to_datetime(getattr(row, f"timestamp_{suffix}"))
                for suffix in "abc"
            }
            pending = {suffix: [] for suffix in "abc"}

            for suffix in "abc":
                if getattr(row, f"speed_flag_instant_{suffix}"):
                    pending[suffix].append(self._violation(
                        "instantaneous",
                        getattr(row, f"camera_id_{suffix}"),
                        None,
                        moments[suffix],
                        None,
                        getattr(row, f"speed_reading_{suffix}"),
                    ))

            for first, second in (("a", "b"), ("b", "c")):
                if getattr(row, f"speed_flag_average_{first}{second}"):
                    pending[second].append(self._violation(
                        "average",
                        getattr(row, f"camera_id_{first}"),
                        getattr(row, f"camera_id_{second}"),
                        moments[first],
                        moments[second],
                        getattr(row, f"avg_speed_reading_{first}{second}"),
                    ))

            for suffix in "abc":
                self._store(row.car_plate, self._day_bucket(moments[suffix]), pending[suffix])

            if not any(pending.values()):
                print(f"No violations detected for {row.car_plate} from {moments['a']} to {moments['c']}")
        except Exception as e:
            # this will print on the executor logs
            print(f"[DbWriter][ERROR] failed to process row {row}: {e}")
            # optionally, you could write to a dead-letter collection instead

    def close(self, error):
        """Report ``error`` if one is given, then close the MongoDB client if it was opened."""
        if error:
            # this also shows up in the executor log
            print(f"[DbWriter][ERROR] task shutting down due to: {error}")
        if self.client:
            self.client.close()

In [17]:
# Request the Kafka and MongoDB connector packages, then build the shared Spark wrapper
# that the later cells use through `spark_job`.
submit_args = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.3.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,org.mongodb.spark:mongo-spark-connector_2.12:10.1.1 pyspark-shell'
os.environ['PYSPARK_SUBMIT_ARGS'] = submit_args
spark_job = SparkInst(
    app_name="AWAS SYSTEM",
    batch_interval=5,
    kafka_output_topic="violations",
)

In [18]:
import sys
import os
# add the folder where util.py lives
from pyspark.sql.functions import udf, col, window, lit
from pyspark.sql.types import StringType
import pandas as pd

# Load the static camera table (camera_id -> speed_limit) used to flag speeding.
df_pd = pd.read_csv("data/camera.csv")
# Drop the `_id` column if the CSV carries one; rebinding instead of mutating in
# place keeps this cell idempotent when it is re-run.
df_pd = df_pd.drop(columns=['_id'], errors='ignore')
spark_df = spark_job.get_session().createDataFrame(df_pd)

# Broadcast the limits for use inside the UDF, then reuse the collected dict for the
# driver-side copy instead of running a second Spark job for the same data.
broadcast_map = spark_job.essentialData_broadcast(spark_df)
speed_limit_map = dict(broadcast_map.value)

def mark_speeding(camera_id: int, speed: float, ops: str, limits=None) -> bool:
    """
    Decide whether a speed exceeds the limit registered for a camera.

    Args:
        camera_id: Camera whose speed limit applies.
        speed: Measured (instantaneous or averaged) speed; ``None`` is treated as
            "no violation" rather than raising inside the UDF.
        ops: Kind of check, ``"instant"`` or ``"average"``. Both compare the speed
            against the same limit; any other value yields ``False``.
        limits: Optional camera_id -> speed_limit mapping. Defaults to the value of
            the module-level ``broadcast_map``.

    Returns:
        True when ``speed`` is strictly greater than the camera's limit, else False
        (including when the camera has no known limit).
    """
    if speed is None or ops not in ("instant", "average"):
        return False
    if limits is None:
        limits = broadcast_map.value
    limit = limits.get(camera_id)
    if limit is None:
        return False
    return bool(speed > limit)

# Spark UDF wrapper around mark_speeding; yields a boolean column.
speeding_udf = udf(mark_speeding, BooleanType())

# Step 5: Apply UDF to each streaming dataframe
def add_speed_flag(df, ops: str):
    """
    Append a ``speed_flag_<ops>`` boolean column to ``df``.

    The flag is computed by ``speeding_udf`` from the ``camera_id`` and
    ``speed_reading`` columns, with ``ops`` passed as a literal.
    """
    flag_column = f"speed_flag_{ops}"
    flag_value = speeding_udf(col("camera_id"), col("speed_reading"), lit(ops))
    return df.withColumn(flag_column, flag_value)


  for column, series in pdf.iteritems():
  for column, series in pdf.iteritems():


In [19]:
from pyspark.sql.functions import expr, col, lit

# Attach Kafka streams
# One stream per camera topic, all read from the same broker with the same watermark.
KAFKA_HOST = "172.22.32.1"
EVENT_WATERMARK = "24 hours"
stream_a = spark_job.attach_kafka_stream("camera_event_a", KAFKA_HOST, EVENT_WATERMARK)
stream_b = spark_job.attach_kafka_stream("camera_event_b", KAFKA_HOST, EVENT_WATERMARK)
stream_c = spark_job.attach_kafka_stream("camera_event_c", KAFKA_HOST, EVENT_WATERMARK)
from pyspark.sql.functions import expr, col, lit

# Flag instantaneous speeding and drop the fields nothing downstream reads.
# batch_id must NOT be dropped: the per-camera projections below and both joins
# select it, so dropping it makes the query fail analysis with an unresolved column.
UNUSED_EVENT_FIELDS = ("event_id", "sent_at")
stream_a_flagged = add_speed_flag(stream_a.drop(*UNUSED_EVENT_FIELDS), "instant")
stream_b_flagged = add_speed_flag(stream_b.drop(*UNUSED_EVENT_FIELDS), "instant")
stream_c_flagged = add_speed_flag(stream_c.drop(*UNUSED_EVENT_FIELDS), "instant")


def suffix_event_columns(flagged_df, suffix: str):
    """
    Project a flagged camera stream onto per-camera column names for joining.

    ``car_plate`` keeps its name (it is the join key); every other column gets
    ``_<suffix>`` appended, except ``producer_id`` which becomes ``producer_<suffix>``.
    """
    return flagged_df.selectExpr(
        "car_plate",
        f"batch_id as batch_id_{suffix}",
        f"camera_id as camera_id_{suffix}",
        f"timestamp as timestamp_{suffix}",
        f"speed_reading as speed_reading_{suffix}",
        f"producer_id as producer_{suffix}",
        f"speed_flag_instant as speed_flag_instant_{suffix}",
    )


# Rename for joining
a = suffix_event_columns(stream_a_flagged, "a")
b = suffix_event_columns(stream_b_flagged, "b")
c = suffix_event_columns(stream_c_flagged, "c")

# Join A & B
# Same plate, with the camera-B sighting strictly after the camera-A sighting and
# no more than 10 minutes later.
ab_condition = (
    (col("a.car_plate") == col("b.car_plate")) &
    (col("a.timestamp_a") < col("b.timestamp_b")) &
    (col("b.timestamp_b") <= col("a.timestamp_a") + expr("interval 10 minutes"))
)

# Mean of the two instantaneous readings, checked against camera A's limit.
avg_speed_ab = (col("a.speed_reading_a") + col("b.speed_reading_b")) / 2
avg_flag_ab = speeding_udf(col("a.camera_id_a"), avg_speed_ab, lit("average"))

ab_join = (
    b.alias("b")
    .join(a.alias("a"), ab_condition, "inner")
    .select(
        col("a.car_plate"),
        col("a.batch_id_a"),
        col("a.camera_id_a"),
        col("a.timestamp_a"),
        col("a.speed_reading_a"),
        col("a.speed_flag_instant_a"),
        avg_speed_ab.alias("avg_speed_reading_ab"),
        avg_flag_ab.alias("speed_flag_average_ab"),
        col("b.batch_id_b"),
        col("b.camera_id_b"),
        col("b.timestamp_b"),
        col("b.speed_reading_b"),
        col("b.speed_flag_instant_b"),
    )
)

# Join AB & C
# Same plate, with the camera-C sighting strictly after the camera-B sighting and
# no more than 10 minutes later.
bc_condition = (
    (col("ab.car_plate") == col("c.car_plate")) &
    (col("c.timestamp_c") > col("ab.timestamp_b")) &
    (col("c.timestamp_c") <= col("ab.timestamp_b") + expr("interval 10 minutes"))
)

# Mean of the B and C readings, checked against camera B's limit.
avg_speed_bc = (col("ab.speed_reading_b") + col("c.speed_reading_c")) / 2
avg_flag_bc = speeding_udf(col("ab.camera_id_b"), avg_speed_bc, lit("average"))

abc_join = (
    ab_join.alias("ab")
    .join(c.alias("c"), bc_condition, "inner")
    .select(
        col("ab.*"),
        avg_speed_bc.alias("avg_speed_reading_bc"),
        avg_flag_bc.alias("speed_flag_average_bc"),
        col("c.batch_id_c"),
        col("c.camera_id_c"),
        col("c.timestamp_c"),
        col("c.speed_reading_c"),
        col("c.speed_flag_instant_c"),
    )
)

In [20]:
import os
import shutil
from pyspark.sql.streaming import StreamingQueryException

checkpoint_dir = "./stream_checkpoints"

# 1) Clean up any existing checkpoint directory before starting.
# NOTE(review): with startingOffsets="earliest" this makes every run re-read the
# topics from the beginning — confirm that re-processing is intended.
if os.path.isdir(checkpoint_dir):
    shutil.rmtree(checkpoint_dir)
    print(f"Deleted existing checkpoint directory: {checkpoint_dir}")

# Write every joined row to MongoDB.
# DbWriter implements open/process/close, which is the row-level writer contract of
# `foreach`. `foreachBatch` instead expects a function taking (batch_df, epoch_id),
# so passing a DbWriter instance to it fails. The console format and its
# numRows/truncate options are removed because a foreach sink replaces the format.
query = (
    abc_join.writeStream
    .foreach(DbWriter(
        mongo_host="172.22.32.1",
        mongo_port=27017,
        mongo_db="fit3182_db",
        mongo_coll="Violation"
    ))
    .option("checkpointLocation", checkpoint_dir)
    .outputMode("append")
    .start()
)

# Write to the console
# query = (
#     abc_join.writeStream
#     .format("console")
#     .option("checkpointLocation", "./stream_checkpoints")
#     .outputMode("append")
#     .option("numRows", 1000)
#     .option("truncate", False)  # Optional: show full column contents
#     .start()
# )

# Block until the stream ends; whatever ends it, the query is stopped afterwards.
try:
    query.awaitTermination()
except StreamingQueryException as stream_err:
    # The streaming query itself failed.
    print(f"Streaming error: {stream_err}")
except KeyboardInterrupt:
    # Manual interrupt from the notebook (kernel interrupt / CTRL-C).
    print("Interrupted by CTRL-C. Stopping query.")
finally:
    query.stop()

Deleted existing checkpoint directory: ./stream_checkpoints


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/opt/conda/lib/python3.8/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/opt/conda/lib/python3.8/socket.py", line 669, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


Interrupted by CTRL-C. Stopping query.
