In [None]:
# Reset this cell's output area. With wait=True the old output is only
# removed once new output is available to replace it.
from IPython.display import display, clear_output
clear_output(wait=True)

In [None]:
%matplotlib notebook

import os
import time
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, explode, to_timestamp
from pyspark.sql.types import StructType, StringType, IntegerType, DoubleType
from pymongo import MongoClient
import pandas as pd
plt.ion()  # Interactive mode: figures are shown/updated without blocking execution


# MongoDB connection
# NOTE(review): host/port are hard-coded for a specific deployment; adjust
# `hostip` when running against a different MongoDB instance.
hostip = "10.192.41.222"
DB_NAME = "awas_db"

client = MongoClient(host=hostip, port=27017)
# Use the DB_NAME constant so the database name is defined in exactly one place.
db = client[DB_NAME]
violation_col = db["violation"]

# Connectivity sanity check: total document count plus a small sample.
print(violation_col.count_documents({}))

for doc in violation_col.find().limit(5):
    print(doc)


# Submit arguments for PySpark: a comma-separated list of Maven coordinates
# passed to --packages, followed by the mandatory "pyspark-shell" token.
os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([
    "--packages",
    ",".join([
        "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0",
        "org.apache.spark:spark-streaming-kafka-0-10_2.12:3.5.0",
        "org.mongodb.spark:mongo-spark-connector_2.12:10.3.0",
    ]),
    "pyspark-shell",
])

# # Initialize Spark
# spark = SparkSession.builder \
#     .appName("AWAS-Speed-Enforcement") \
#     .master("local[*]") \
#     .config("spark.mongodb.read.connection.uri", f"mongodb://{hostip}:27017/{DB_NAME}") \
#     .config("spark.mongodb.write.connection.uri", f"mongodb://{hostip}:27017/{DB_NAME}") \
#     .getOrCreate()
# spark.sparkContext.setLogLevel("WARN")
# 
# spark.read \
#     .format("mongodb") \
#     .option("collection", "violation") \
#     .load()

# ====== Initialize plotting ======
def init_plot():
    """Create the dashboard figure with its two labelled, empty panels.

    The panels occupy positions 1 and 2 of a 2x2 subplot grid: violation
    counts on the left and average speed on the right.

    Returns:
        tuple: ``(fig, ax1, ax2)`` - the figure, the violations axes and
        the average-speed axes.
    """
    figure = plt.figure(figsize=(9.5, 6))
    figure.subplots_adjust(hspace=0.6)

    panels = []
    for position, y_label, heading in (
        (221, 'Number of Violations', 'Violations vs Time'),
        (222, 'Average Speed', 'Avg Speed vs Time'),
    ):
        axis = figure.add_subplot(position)
        axis.set_xlabel('Arrival Time')
        axis.set_ylabel(y_label)
        axis.set_title(heading)
        panels.append(axis)

    figure.suptitle('AWAS Real-time Streaming Visualization')

    # Render once so the (still empty) figure is visible before data arrives.
    figure.canvas.draw()
    plt.pause(0.01)

    violations_ax, speed_ax = panels
    return figure, violations_ax, speed_ax

# ====== Annotations ======
def annotate_max(x, y, ax):
    """Mark the largest value of ``y`` on ``ax`` with a red arrow and label.

    Args:
        x: Indexable sequence of x positions, parallel to ``y``.
        y: Sized, iterable sequence of values (list, tuple, array, ...).
        ax: Axes-like object providing ``annotate``.

    When the maximum occurs more than once, the first occurrence is marked.
    Nothing is drawn when ``y`` is empty.
    """
    if len(y) == 0:
        return

    # enumerate + key avoids relying on list.index, so any sequence works;
    # max() keeps the first of several equal maxima, as list.index would.
    xpos, ymax = max(enumerate(y), key=lambda pair: pair[1])

    # The label sits 5 data units above the annotated point.
    ax.annotate(f'Max: {ymax}', xy=(x[xpos], ymax), xytext=(x[xpos], ymax + 5),
                arrowprops=dict(facecolor='red', shrink=0.05))

def annotate_min(x, y, ax):
    """Mark the smallest value of ``y`` on ``ax`` with an orange arrow and label.

    Args:
        x: Indexable sequence of x positions, parallel to ``y``.
        y: Sized, iterable sequence of values (list, tuple, array, ...).
        ax: Axes-like object providing ``annotate``.

    When the minimum occurs more than once, the first occurrence is marked.
    Nothing is drawn when ``y`` is empty.
    """
    if len(y) == 0:
        return

    # enumerate + key avoids relying on list.index, so any sequence works;
    # min() keeps the first of several equal minima, as list.index would.
    xpos, ymin = min(enumerate(y), key=lambda pair: pair[1])

    # The label sits 5 data units above the annotated point.
    ax.annotate(f'Min: {ymin}', xy=(x[xpos], ymin), xytext=(x[xpos], ymin + 5),
                arrowprops=dict(facecolor='orange', shrink=0.05))

def annotate_avg(x, y, ax):
    """Draw a dashed horizontal line at the mean of ``y`` and label it.

    Args:
        x: Indexable sequence of x positions; the label is anchored at the
            last one.
        y: Sized sequence of numeric values to average.
        ax: Axes-like object providing ``axhline`` and ``annotate``.

    Nothing is drawn when either sequence is empty, which avoids a
    division by zero and an out-of-range lookup of ``x[-1]``.
    """
    if len(y) == 0 or len(x) == 0:
        return

    avg_val = sum(y) / len(y)
    last_x = x[-1]

    ax.axhline(avg_val, color='blue', linestyle='--', linewidth=1)
    # The label sits 5 data units above the average line, at the right-most x.
    ax.annotate(f'Avg: {avg_val:.2f}', xy=(last_x, avg_val), xytext=(last_x, avg_val + 5),
                arrowprops=dict(facecolor='blue', arrowstyle='->'), fontsize=9)

# def fetch_violation_data():
#     pipeline = [
#         {"$unwind": "$violations"},
#         {"$project": {
#             "_id": 0,
#             "timestamp": "$violations.timestamp_start"
#         }}
#     ]
#     cursor = violation_col.aggregate(pipeline)
#     timestamps = [doc["timestamp"] for doc in cursor if "timestamp" in doc]
#     return pd.DataFrame(timestamps, columns=["timestamp"])

def fetch_violation_data():
    """Load every individual violation from MongoDB into a DataFrame.

    Runs an aggregation on the module-level ``violation_col`` collection:
    each document's ``violations`` array is unwound into one record per
    element, and each record is projected down to ``timestamp`` (from
    ``timestamp_start``) and ``recorded_speed``.

    Returns:
        pandas.DataFrame: One row per unwound violation; empty when the
        aggregation yields no records.
    """
    unwind_stage = {"$unwind": "$violations"}
    project_stage = {
        "$project": {
            "timestamp": "$violations.timestamp_start",
            "recorded_speed": "$violations.recorded_speed",
        }
    }
    records = list(violation_col.aggregate([unwind_stage, project_stage]))
    return pd.DataFrame(records)

def _aggregate_per_minute(df):
    """Bucket raw violation rows into one-minute bins.

    Rows whose ``timestamp`` cannot be parsed are dropped.

    Returns:
        tuple: ``(vdf, sdf)`` where ``vdf`` has columns ``rounded_time`` and
        ``num_violations`` (row count per minute) and ``sdf`` has columns
        ``rounded_time`` and ``avg_speed`` (mean ``recorded_speed`` per minute).
    """
    events = (
        df.assign(timestamp=pd.to_datetime(df["timestamp"], errors='coerce'))
        .dropna(subset=["timestamp"])
        .assign(rounded_time=lambda frame: frame["timestamp"].dt.floor("min"))
    )
    by_minute = events.groupby("rounded_time")
    vdf = by_minute.size().reset_index(name="num_violations")
    sdf = by_minute["recorded_speed"].mean().reset_index(name="avg_speed")
    return vdf, sdf


def _redraw_panels(fig, ax1, ax2, vdf, sdf):
    """Clear both axes and redraw them from the aggregated frames."""
    ax1.clear()
    ax2.clear()

    # clear() also wipes labels and titles, so they are restored on each redraw.
    ax1.set_xlabel('Arrival Time')
    ax1.set_ylabel('Number of Violations')
    # One title carrying both the description and the refresh time; previously
    # a second set_title call silently overwrote the descriptive title.
    ax1.set_title(f"Violations vs Time (updated {pd.Timestamp.now():%H:%M:%S})")

    ax2.set_xlabel('Arrival Time')
    ax2.set_ylabel('Average Speed')
    ax2.set_title('Avg Speed vs Time')

    # Violation count per minute
    x1, y1 = list(vdf["rounded_time"]), list(vdf["num_violations"])
    ax1.plot(x1, y1, color='red')
    annotate_max(x1, y1, ax1)
    annotate_min(x1, y1, ax1)

    # Average speed per minute
    x2, y2 = list(sdf["rounded_time"]), list(sdf["avg_speed"])
    ax2.plot(x2, y2, color='blue')
    annotate_max(x2, y2, ax2)
    annotate_min(x2, y2, ax2)
    annotate_avg(x2, y2, ax2)

    # Synchronous draw: draw_idle() only *requests* a redraw, which the
    # notebook backend cannot service while the kernel is busy in this loop.
    fig.canvas.draw()
    print(f"Plotted {len(x1)} points. Max time: {max(x1)}")


def stream_plot_loop(fig, ax1, ax2, interval_sec=5, max_ticks=None):
    """Poll MongoDB and refresh the two dashboard panels on a fixed interval.

    Each tick fetches all violations via ``fetch_violation_data``, aggregates
    them per minute and redraws ``ax1`` (violation count) and ``ax2``
    (average speed). Errors raised during a tick are printed and the loop
    carries on with the next tick.

    Args:
        fig: Figure that owns ``ax1`` and ``ax2``.
        ax1: Axes for the violations-per-minute line.
        ax2: Axes for the average-speed-per-minute line.
        interval_sec: Seconds to wait between ticks.
        max_ticks: Optional number of ticks after which the loop returns.
            ``None`` (the default) keeps the original behaviour of running
            until the kernel is interrupted.
    """
    tick = 0
    while max_ticks is None or tick < max_ticks:
        tick += 1
        print("Plotting tick...")
        try:
            df = fetch_violation_data()
            if df.empty:
                print("No data yet...")
            else:
                vdf, sdf = _aggregate_per_minute(df)
                if not vdf.empty and not sdf.empty:
                    _redraw_panels(fig, ax1, ax2, vdf, sdf)
        except Exception as e:
            # Best effort: report the failure and try again on the next tick.
            print(f"[ERROR] {e}")

        # A single wait point for every outcome. plt.pause keeps GUI event
        # loops responsive while waiting; the lower bound avoids a
        # non-positive interval, which would otherwise block indefinitely.
        plt.pause(max(interval_sec, 0.01))


# ====== Run ======
# if __name__ == "__main__":
# Build the figure once, then hand it to the polling loop, which keeps
# redrawing it until the kernel is interrupted.
fig, ax1, ax2 = init_plot()
plt.show()  # Show the plot without blocking (interactive mode was enabled via plt.ion())
stream_plot_loop(fig, ax1, ax2)