In [1]:
import os
os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([
    # Pull in the Kafka Structured Streaming connector (Scala 2.12 / Spark 3.5.0).
    # PySpark reads this variable when it launches the JVM, so it has to be set
    # before the SparkSession in the next cells is created.
    "--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0",
    "pyspark-shell",
])

In [2]:
!pip install redis





In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml import PipelineModel
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StringType
import redis
from pyspark.sql import SparkSession
from pyspark.sql.types import *
import pyspark.sql.functions as F
import pyspark

import os

# Extra JVM jars shipped in the image: S3A filesystem + AWS SDK, Delta Lake,
# and the spark-redis connector.
jars = [
    "/home/jovyan/spark_jars/hadoop-aws-3.3.4.jar",
    "/home/jovyan/spark_jars/aws-java-sdk-bundle-1.12.262.jar",
    "/home/jovyan/spark_jars/hadoop-common-3.3.4.jar",
    "/home/jovyan/spark_jars/delta-spark_2.12-3.2.0.jar",
    "/home/jovyan/spark_jars/delta-storage-3.2.0.jar",
    "/home/jovyan/spark_jars/spark-redis_2.12-3.5.0.jar"
]

# AWS credentials are read from the environment instead of being hardcoded in the
# notebook. NOTE(review): a key pair used to be committed here in plain text;
# treat it as leaked and rotate it. When the variables are unset, the explicit
# S3A keys are not configured and S3A uses its own credential provider chain.
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")

builder = (
    SparkSession.builder.appName('Stream Demo')
    .config("spark.jars", ",".join(jars))
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    .config("spark.hadoop.fs.s3a.endpoint", "s3.amazonaws.com")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.redis.host", "redis-server")
    .config("spark.redis.port", "6379")
    .config("spark.redis.db", "0")
)
if aws_access_key and aws_secret_key:
    builder = (
        builder
        .config("spark.hadoop.fs.s3a.access.key", aws_access_key)
        .config("spark.hadoop.fs.s3a.secret.key", aws_secret_key)
    )
spark = builder.getOrCreate()

# Set the legacy time parser policy to handle the date format correctly
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# Confirm the extra jars were registered (public API, not the private _jsc handle).
print(spark.sparkContext.getConf().get("spark.jars"))

spark

/home/jovyan/spark_jars/hadoop-aws-3.3.4.jar,/home/jovyan/spark_jars/aws-java-sdk-bundle-1.12.262.jar,/home/jovyan/spark_jars/hadoop-common-3.3.4.jar,/home/jovyan/spark_jars/delta-spark_2.12-3.2.0.jar,/home/jovyan/spark_jars/delta-storage-3.2.0.jar,/home/jovyan/spark_jars/spark-redis_2.12-3.5.0.jar


In [None]:
!pip install kafka-python

In [None]:
# FOR TESTING
# FOR TESTING: print one message from each input topic to check the payload shape.
from kafka import KafkaConsumer
import json


def peek_topic(topic, bootstrap_servers="kafka:9092", timeout_ms=30_000):
    """Print (and return) the next message published to ``topic`` as indented JSON.

    With ``auto_offset_reset="latest"`` only messages produced after the consumer
    joins are seen. Iteration stops after ``timeout_ms`` without a message instead
    of blocking the notebook forever, and the consumer is always closed.

    Returns:
        The decoded message value, or None when nothing arrived in time.
    """
    consumer = KafkaConsumer(
        topic,
        bootstrap_servers=bootstrap_servers,
        auto_offset_reset="latest",
        consumer_timeout_ms=timeout_ms,
        value_deserializer=lambda m: json.loads(m.decode("utf-8")),
    )
    try:
        for message in consumer:
            print(json.dumps(message.value, indent=4))
            return message.value  # just check one
        print(f"No message on '{topic}' within {timeout_ms} ms")
        return None
    finally:
        consumer.close()


for topic in ("weather-data", "traffic-data"):
    peek_topic(topic)

In [4]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.functions import date_trunc
from pyspark.sql.functions import unix_timestamp, from_unixtime, floor


def _read_kafka_topic(topic):
    """Subscribe to one Kafka topic as a streaming DataFrame, starting at the latest offsets."""
    return (
        spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "kafka:9092")
        .option("subscribe", topic)
        .option("startingOffsets", "latest")
        .load()
    )


weather_stream = _read_kafka_topic("weather-data")
traffic_stream = _read_kafka_topic("traffic-data")

# The Kafka `value` column is cast to a string so it can be parsed as JSON below.
weather_json_df, traffic_json_df = (
    raw_stream.selectExpr("CAST(value AS STRING) as value")
    for raw_stream in (weather_stream, traffic_stream)
)

# Sample traffic payload:
# {"latitude": 49.2838889, "longitude": -122.7933334, "current_speed": 28, "free_flow_speed": 28,
#  "current_travel_time": 131, "free_flow_travel_time": 131, "confidence": 1, "road_closure": false}
# The traffic schema additionally expects a "date" string.

# JSON schemas for the two Kafka topics; every field is declared nullable.
weather_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ('name', StringType()),
        ('latitude', DoubleType()),
        ('longitude', DoubleType()),
        ('date', StringType()),
        ('weather', StringType()),
        ('weather_description', StringType()),
        ('temp', DoubleType()),
        ('visibility', IntegerType()),
        ('clouds', IntegerType()),
        ('rain', DoubleType()),
        ('snow', DoubleType()),
    )
])

traffic_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("latitude", DoubleType()),
        ("longitude", DoubleType()),
        ("current_speed", IntegerType()),
        ("free_flow_speed", IntegerType()),
        ("current_travel_time", IntegerType()),
        ("free_flow_travel_time", IntegerType()),
        ("confidence", IntegerType()),
        ("road_closure", BooleanType()),
        ('date', StringType()),
    )
])

# Event timestamps are treated as UTC and converted to this zone's wall-clock time.
LOCAL_TZ = "America/Los_Angeles"
# Pattern of the traffic payload's "date" string (parsed under the LEGACY
# time-parser policy set on the session).
TRAFFIC_DATE_FORMAT = "EEE, dd MMM yyyy HH:mm:ss z"

# Parse the Kafka 'value' JSON string into a struct column named "data".
weather_parsed_df = weather_json_df.select(F.from_json("value", weather_schema).alias("data"))
traffic_parsed_df = traffic_json_df.select(F.from_json("value", traffic_schema).alias("data"))

# Flatten the structs into top-level columns. Weather dates use Spark's default
# timestamp parsing; traffic dates use TRAFFIC_DATE_FORMAT.
weather_flatten_df = weather_parsed_df.select(
    *(F.col(f"data.{name}").alias(name) for name in ("name", "latitude", "longitude")),
    F.from_utc_timestamp(F.to_timestamp(F.col("data.date")), LOCAL_TZ).alias("date"),
    *(F.col(f"data.{name}").alias(name) for name in (
        "weather", "weather_description", "temp", "visibility", "clouds", "rain", "snow",
    )),
)

traffic_flatten_df = (
    traffic_parsed_df.select(
        *(F.col(f"data.{name}").alias(name) for name in (
            "latitude", "longitude", "current_speed", "free_flow_speed",
            "current_travel_time", "free_flow_travel_time", "confidence", "road_closure",
        )),
        # Parse and convert in one expression instead of referring back to a sibling
        # alias ("date_utc") from the same select, which depends on Spark's
        # lateral-column-alias resolution.
        F.from_utc_timestamp(
            F.to_timestamp(F.col("data.date"), TRAFFIC_DATE_FORMAT), LOCAL_TZ
        ).alias("date"),
    )
    # Negative when current traffic is slower than free flow.
    .withColumn("speed_diff", F.col("current_speed") - F.col("free_flow_speed"))
)

# processing_time: wall-clock time the row was processed (not used by the windowed
# aggregations). Watermark: tolerate events up to 10 minutes late by event time
# `date` before their windows are finalized.
weather_flatten_df = (
    weather_flatten_df
    .withColumn("processing_time", F.current_timestamp())
    .withWatermark("date", "10 minutes")
)
traffic_flatten_df = (
    traffic_flatten_df
    .withColumn("processing_time", F.current_timestamp())
    .withWatermark("date", "10 minutes")
)

# Spatial bins (coordinates rounded to 3 decimals) are the join keys between the
# two streams. F.round is used explicitly so this does not depend on a star import
# shadowing Python's builtin round().
weather_flatten_df = (
    weather_flatten_df
    .withColumn("lat_bin", F.round(F.col("latitude"), 3))
    .withColumn("lon_bin", F.round(F.col("longitude"), 3))
    .withColumn("month", F.month(F.col("date")))
    .withColumn("day", F.dayofmonth(F.col("date")))
    .withColumn("hour", F.hour(F.col("date")))
    .withColumn("day_of_week", F.dayofweek(F.col("date")))
)

# Traffic calendar features are derived after aggregation, from the window start.
traffic_flatten_df = (
    traffic_flatten_df
    .withColumn("lat_bin", F.round(F.col("latitude"), 3))
    .withColumn("lon_bin", F.round(F.col("longitude"), 3))
)


def add_calendar_features(df):
    """Append the time_period, season, pct_is_weekend and pct_is_rush_hour columns.

    Expects integer ``hour``, ``month`` and ``day_of_week`` columns, where
    day_of_week follows Spark's ``dayofweek`` convention (1 = Sunday, 7 = Saturday).
    The two ``pct_*`` flags are 0/1 integers.
    """
    hour_c, month_c, dow_c = F.col("hour"), F.col("month"), F.col("day_of_week")
    return (
        df.withColumn("time_period",
                      F.when(hour_c.between(6, 11), "morning")
                       .when(hour_c.between(12, 17), "afternoon")
                       .when(hour_c.between(18, 23), "evening")
                       .otherwise("night"))
          .withColumn("season",
                      F.when(month_c.isin(3, 4, 5), "spring")
                       .when(month_c.isin(6, 7, 8), "summer")
                       .when(month_c.isin(9, 10, 11), "autumn")
                       .when(month_c.isin(12, 1, 2), "winter"))
          .withColumn("pct_is_weekend", F.when(dow_c.isin(1, 7), 1).otherwise(0))
          .withColumn("pct_is_rush_hour",
                      F.when(hour_c.between(7, 9) | hour_c.between(16, 18), 1).otherwise(0))
    )


# Weather: 1-minute tumbling windows per municipality and spatial bin.
# NOTE: F.first() picks an arbitrary row of the window (no ordering is applied),
# so the "last_*" columns hold *some* observation from the window.
weather_agg_df = (
    weather_flatten_df
    .groupBy(F.window(F.col("date"), "1 minute"), F.col("name"), F.col("lat_bin"), F.col("lon_bin"))
    .agg(
        F.avg("temp").alias("avg_temp"),
        F.avg("visibility").alias("avg_visibility"),
        F.avg("clouds").alias("avg_clouds"),
        F.max("rain").alias("max_rain"),
        F.max("snow").alias("max_snow"),
        F.first("weather").alias("last_weather"),
        F.first("weather_description").alias("last_weather_description"),
        F.first("hour").alias("hour"),
        F.first("month").alias("month"),
        F.first("day_of_week").alias("day_of_week"),
    )
    .withColumn("municipality", F.lower(F.col("name")))
    .drop("name")
    .withColumn("time_bin", F.col("window.start"))
    .withColumn("last_weather", F.lower(F.col("last_weather")))
    .withColumn("last_weather_description", F.lower(F.col("last_weather_description")))
    .transform(add_calendar_features)
)

# Traffic: 1-minute tumbling windows per spatial bin; calendar fields come from the
# window start. had_closure is 1.0 if any sample in the window reported a closure.
traffic_agg_df = (
    traffic_flatten_df
    .groupBy(F.window(F.col("date"), "1 minute"), F.col("lat_bin"), F.col("lon_bin"))
    .agg(
        F.avg("current_speed").alias("avg_speed"),
        F.avg("free_flow_speed").alias("avg_flow_speed"),
        F.avg("current_travel_time").alias("avg_travel_time"),
        F.avg("free_flow_travel_time").alias("avg_flow_travel_time"),
        F.avg("speed_diff").alias("avg_speed_diff"),
        F.max(F.col("road_closure").cast("double")).alias("had_closure"),
    )
    .withColumn("time_bin", F.col("window.start"))
    .withColumn("month", F.month(F.col("window.start")))
    .withColumn("day", F.dayofmonth(F.col("window.start")))
    .withColumn("hour", F.hour(F.col("window.start")))
    .withColumn("day_of_week", F.dayofweek(F.col("window.start")))
    .transform(add_calendar_features)
)

In [None]:
# Streaming DataFrames cannot be .show()n directly (Spark raises "Queries with
# streaming sources must be executed with writeStream.start()"). Instead, run a
# short-lived query that prints each micro-batch, then stop it.
PREVIEW_SECONDS = 30


def _show_batches(label):
    """Return a foreachBatch callback that prints each micro-batch under ``label``."""
    def _show(batch_df, batch_id):
        print(f"--- {label}: batch {batch_id} ---")
        batch_df.show()
    return _show


flatten_previews = [
    weather_flatten_df.writeStream.foreachBatch(_show_batches("weather_flatten_df")).start(),
    traffic_flatten_df.writeStream.foreachBatch(_show_batches("traffic_flatten_df")).start(),
]
try:
    flatten_previews[0].awaitTermination(PREVIEW_SECONDS)
finally:
    for preview in flatten_previews:
        preview.stop()

In [None]:
# Which (lat_bin, lon_bin, time_bin) keys appear in BOTH aggregated streams?
# countDistinct and .show() are not supported on streaming DataFrames, so snapshot
# each stream's keys into an in-memory table (complete mode keeps every window
# seen so far), then run the comparison as an ordinary batch query.
PREVIEW_SECONDS = 90

key_queries = [
    agg_df.select("lat_bin", "lon_bin", "time_bin")
    .writeStream
    .format("memory")
    .queryName(table_name)
    .outputMode("complete")
    .start()
    for table_name, agg_df in (
        ("weather_keys_snapshot", weather_agg_df),
        ("traffic_keys_snapshot", traffic_agg_df),
    )
]
try:
    key_queries[0].awaitTermination(PREVIEW_SECONDS)
finally:
    for key_query in key_queries:
        key_query.stop()

weather_keys = spark.table("weather_keys_snapshot").withColumn("source", F.lit("weather"))
traffic_keys = spark.table("traffic_keys_snapshot").withColumn("source", F.lit("traffic"))

all_keys = weather_keys.union(traffic_keys)

key_counts = all_keys.groupBy("lat_bin", "lon_bin", "time_bin").agg(F.countDistinct("source").alias("sources_present"))

key_counts.filter("sources_present = 2").show(truncate=False)


In [None]:
# Latest 1-minute windows seen on each aggregated stream. .show() cannot run on a
# streaming DataFrame, so snapshot the window starts into in-memory tables
# (complete mode) for a short while and query those instead.
PREVIEW_SECONDS = 90

time_bin_tables = ("weather_time_bins", "traffic_time_bins")
time_bin_queries = [
    agg_df.select("time_bin")
    .writeStream
    .format("memory")
    .queryName(table_name)
    .outputMode("complete")
    .start()
    for table_name, agg_df in zip(time_bin_tables, (weather_agg_df, traffic_agg_df))
]
try:
    time_bin_queries[0].awaitTermination(PREVIEW_SECONDS)
finally:
    for time_bin_query in time_bin_queries:
        time_bin_query.stop()

for table_name in time_bin_tables:
    spark.table(table_name).orderBy("time_bin", ascending=False).show(5, False)

In [None]:
# TESTING: print the parsed event timestamps of both streams to the console.
spark.streams.resetTerminated()  # forget queries stopped in earlier cells

weather_debug_query = (
    weather_flatten_df
    .select("date")
    .writeStream
    .format("console")
    .outputMode("append")
    .option("truncate", False)
    .start()
)

traffic_debug_query = (
    traffic_flatten_df
    .select("date")
    .writeStream
    .format("console")
    .outputMode("append")
    .option("truncate", False)
    .start()
)

# Return when either query stops or fails. Awaiting them one after another would
# block on the first and never surface a failure of the second.
spark.streams.awaitAnyTermination()

In [None]:
# Stream every parsed record (event time plus the whole row as JSON) to the console.
spark.streams.resetTerminated()  # forget queries stopped in earlier cells

weather_debug_query = (
    weather_flatten_df
    .selectExpr("CAST(date AS STRING) AS date", "to_json(struct(*)) AS json")
    .writeStream
    .format("console")
    .option("truncate", False)
    .outputMode("append")
    .start()
)

traffic_debug_query = (
    traffic_flatten_df
    .selectExpr("CAST(date AS STRING) AS date", "to_json(struct(*)) AS json")
    .writeStream
    .format("console")
    .option("truncate", False)
    .outputMode("append")
    .start()
)

# Return when either query stops or fails. Sequential awaitTermination() calls
# would block on the first query and hide a failure of the second.
spark.streams.awaitAnyTermination()

In [5]:

train_df = spark.read.parquet("/home/jovyan/work/final_real_df")

# Both aggregates carry calendar columns; keep the weather side's and drop the
# traffic side's before joining. Joining a projection (rather than renaming columns
# on traffic_agg_df itself) leaves traffic_agg_df unchanged for later cells.
traffic_for_join = traffic_agg_df.drop(
    "day", "month", "hour", "day_of_week",
    "time_period", "season", "pct_is_weekend", "pct_is_rush_hour",
)

combined_df = (
    weather_agg_df
    .join(traffic_for_join, on=["lat_bin", "lon_bin", "time_bin"], how="inner")
    # Cast the 0/1 indicator flags to double.
    .withColumn("pct_is_weekend", F.col("pct_is_weekend").cast("double"))
    .withColumn("pct_is_rush_hour", F.col("pct_is_rush_hour").cast("double"))
    # Both inputs carry a `window` struct; dropping by name removes both copies.
    # month/hour/day_of_week were only needed to derive the features above.
    .drop("window", "month", "hour", "day_of_week")
)
                

# Training-set statistics used to fill model inputs the live streams cannot
# supply, computed in one aggregation job instead of one collect() per statistic.
_MEAN_COLUMNS = (
    "pct_speed_involved", "pct_drug_involved", "pct_impaired_involved",
    "avg_total_vehicles_involved", "avg_total_casualty", "avg_speed_limit",
    "pct_distraction_involved", "pct_intersection_crash", "pct_pedestrian_involved",
    "avg_speed", "avg_flow_speed", "avg_travel_time", "avg_flow_travel_time",
    "avg_speed_diff", "had_closure",
)
train_stats = train_df.agg(
    *[F.mean(c).alias(c) for c in _MEAN_COLUMNS],
    F.mode("weather").alias("weather"),
    F.mode("road_condition").alias("road_condition"),
).first()

mean_pct_speed_involved = train_stats["pct_speed_involved"]
mean_pct_drug_involved = train_stats["pct_drug_involved"]
mean_pct_impaired_involved = train_stats["pct_impaired_involved"]
mean_avg_total_vehicles_involved = train_stats["avg_total_vehicles_involved"]
mean_avg_total_casualty = train_stats["avg_total_casualty"]
mean_avg_speed_limit = train_stats["avg_speed_limit"]
mean_pct_distraction_involved = train_stats["pct_distraction_involved"]
mean_pct_intersection_crash = train_stats["pct_intersection_crash"]
mean_pct_pedestrian_involved = train_stats["pct_pedestrian_involved"]
mode_weather = train_stats["weather"]
mode_road_condition = train_stats["road_condition"]

mean_avg_speed = train_stats["avg_speed"]
mean_avg_flow_speed = train_stats["avg_flow_speed"]
mean_avg_travel_time = train_stats["avg_travel_time"]
mean_avg_flow_travel_time = train_stats["avg_flow_travel_time"]
mean_avg_speed_diff = train_stats["avg_speed_diff"]
mean_had_closure = train_stats["had_closure"]

# Live traffic aggregates: keep the streamed value, fall back to the training mean when null.
combined_df = combined_df.withColumns({
    column_name: F.coalesce(F.col(column_name), F.lit(fallback))
    for column_name, fallback in (
        ("avg_speed", mean_avg_speed),
        ("avg_flow_speed", mean_avg_flow_speed),
        ("avg_travel_time", mean_avg_travel_time),
        ("avg_flow_travel_time", mean_avg_flow_travel_time),
        ("avg_speed_diff", mean_avg_speed_diff),
        ("had_closure", mean_had_closure),
    )
})

# Features with no live source are filled with constant training statistics.
# NOTE(review): `weather` is filled with the training mode even though the stream
# provides `last_weather`; confirm the model is meant to ignore live weather here.
# `hotspot_risk_level` is a placeholder; it is replaced from the model's prediction later.
combined_df = combined_df.withColumns({
    "avg_speed_limit": F.lit(mean_avg_speed_limit),
    "avg_total_vehicles_involved": F.lit(mean_avg_total_vehicles_involved),
    "avg_total_casualty": F.lit(mean_avg_total_casualty),
    "pct_intersection_crash": F.lit(mean_pct_intersection_crash),
    "pct_pedestrian_involved": F.lit(mean_pct_pedestrian_involved),
    "pct_distraction_involved": F.lit(mean_pct_distraction_involved),
    "pct_drug_involved": F.lit(mean_pct_drug_involved),
    "pct_impaired_involved": F.lit(mean_pct_impaired_involved),
    "pct_speed_involved": F.lit(mean_pct_speed_involved),
    "weather": F.lit(mode_weather),
    "road_condition": F.lit(mode_road_condition),
    "hotspot_risk_level": F.lit("unknown"),
})

combined_df.printSchema()

root
 |-- lat_bin: double (nullable = true)
 |-- lon_bin: double (nullable = true)
 |-- time_bin: timestamp (nullable = true)
 |-- avg_temp: double (nullable = true)
 |-- avg_visibility: double (nullable = true)
 |-- avg_clouds: double (nullable = true)
 |-- max_rain: double (nullable = true)
 |-- max_snow: double (nullable = true)
 |-- last_weather: string (nullable = true)
 |-- last_weather_description: string (nullable = true)
 |-- municipality: string (nullable = true)
 |-- time_period: string (nullable = false)
 |-- season: string (nullable = true)
 |-- pct_is_weekend: double (nullable = false)
 |-- pct_is_rush_hour: double (nullable = false)
 |-- avg_speed: double (nullable = true)
 |-- avg_flow_speed: double (nullable = true)
 |-- avg_travel_time: double (nullable = true)
 |-- avg_flow_travel_time: double (nullable = true)
 |-- avg_speed_diff: double (nullable = true)
 |-- had_closure: double (nullable = true)
 |-- avg_speed_limit: double (nullable = false)
 |-- avg_total_vehicl

In [None]:
from pyspark.sql.functions import col

# Most recent joined windows. .show() cannot run on the streaming combined_df, so
# print the newest time bins of each micro-batch for a short while, then stop.
PREVIEW_SECONDS = 60


def _show_latest_time_bins(batch_df, batch_id):
    """foreachBatch callback: print this batch's time bins, newest first."""
    print(f"--- combined_df batch {batch_id} ---")
    batch_df.select("time_bin").orderBy("time_bin", ascending=False).show()


combined_preview = combined_df.writeStream.foreachBatch(_show_latest_time_bins).start()
try:
    combined_preview.awaitTermination(PREVIEW_SECONDS)
finally:
    combined_preview.stop()


In [None]:
# Row counts and newest windows of each aggregated stream. count() and show()
# cannot run on streaming DataFrames, so snapshot both into in-memory tables
# (complete mode = every window seen so far) and inspect those as batch tables.
PREVIEW_SECONDS = 90

snapshot_queries = [
    agg_df.writeStream
    .format("memory")
    .queryName(table_name)
    .outputMode("complete")
    .start()
    for table_name, agg_df in (
        ("weather_agg_snapshot", weather_agg_df),
        ("traffic_agg_snapshot", traffic_agg_df),
    )
]
try:
    snapshot_queries[0].awaitTermination(PREVIEW_SECONDS)
finally:
    for snapshot_query in snapshot_queries:
        snapshot_query.stop()

weather_snapshot = spark.table("weather_agg_snapshot")
traffic_snapshot = spark.table("traffic_agg_snapshot")

print("[DEBUG] Weather count:", weather_snapshot.count())
print("[DEBUG] Traffic count:", traffic_snapshot.count())

weather_snapshot.select("time_bin").distinct().orderBy(F.col("time_bin").desc()).show(5)
traffic_snapshot.select("time_bin").distinct().orderBy(F.col("time_bin").desc()).show(5)

In [None]:
# Show the joined schema, then stream every joined row to the console.
# awaitTermination() blocks this cell until the query stops.
combined_df.printSchema()

query = (
    combined_df.writeStream
    .format("console")
    .outputMode("append")
    .option("truncate", False)
    .start()
)
query.awaitTermination()



In [None]:
from pyspark.sql.functions import col, sum

# Per-column null counts of each joined micro-batch. .show() cannot run on the
# streaming combined_df, so compute the counts inside foreachBatch (where each
# batch is a static DataFrame) for a short while, then stop the query.
# F.sum is used explicitly rather than a bare `sum` that shadows the builtin.
PREVIEW_SECONDS = 60


def _show_null_counts(batch_df, batch_id):
    """foreachBatch callback: print how many nulls each column has in this batch."""
    print(f"--- combined_df null counts, batch {batch_id} ---")
    batch_df.select(
        [F.sum(F.col(c).isNull().cast("int")).alias(c) for c in batch_df.columns]
    ).show()


null_count_preview = combined_df.writeStream.foreachBatch(_show_null_counts).start()
try:
    null_count_preview.awaitTermination(PREVIEW_SECONDS)
finally:
    null_count_preview.stop()


In [None]:
# Load the fitted Spark ML pipeline and score the joined stream.
pipeline_path = "/home/jovyan/work/rfc_model"  
pipeline_model = PipelineModel.load(pipeline_path)

predictions_df = pipeline_model.transform(combined_df)

# Translate the numeric class index in `prediction` into a readable risk label;
# any other value maps to "unknown".
# NOTE(review): presumably this order matches the label indexing used in
# training — confirm against the training pipeline.
risk_label = None
for class_index, label in ((0.0, 'low'), (1.0, 'moderate'), (2.0, 'high')):
    is_class = F.col("prediction") == class_index
    risk_label = F.when(is_class, label) if risk_label is None else risk_label.when(is_class, label)

prediction_df = predictions_df.withColumn("hotspot_risk_level", risk_label.otherwise("unknown"))

In [None]:
from pyspark.sql.functions import col, to_json, struct
import redis
import json
import logging
from datetime import datetime, timezone

# Module logger used by the Redis writers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Redis connection, best effort: a failed ping is logged, not raised, so the
# rest of the notebook keeps running.
redis_host = "redis-server"  # Use "localhost" if running Redis outside Docker
redis_port = 6379
try:
    redis_client = redis.StrictRedis(host=redis_host, port=redis_port, decode_responses=True)
    redis_client.ping()
except Exception as e:
    logger.error(f"Failed to connect to Redis at {redis_host}:{redis_port}: {e}")
else:
    logger.info(f"Successfully connected to Redis at {redis_host}:{redis_port}")

# Redis writers used as foreachBatch sinks. Each maps one row of a micro-batch to
# one Redis key holding a JSON document; a later batch overwrites the same key.


def _write_rows_to_redis(df, batch_id, label, to_entry):
    """Write one micro-batch to Redis as JSON strings through a single pipeline.

    Args:
        df: micro-batch DataFrame passed in by ``foreachBatch``.
        batch_id: micro-batch id (only used in log messages).
        label: data kind used in log messages, e.g. ``"weather"``.
        to_entry: callable mapping one collected Row to a ``(key, value_dict)`` pair.

    Errors are logged with their traceback instead of being raised, so a Redis
    problem does not stop the streaming query; Spark then treats the batch as
    processed, i.e. it is not retried.
    """
    if df.isEmpty():
        logger.warning(f"Batch {batch_id} is empty, no data to write to Redis")
        return

    try:
        # collect() brings the whole micro-batch to the driver; acceptable only
        # because these aggregated batches are small.
        rows = df.collect()
        logger.info(f"Processing batch {batch_id}, number of rows: {len(rows)}")

        redis_pipe = redis_client.pipeline()
        for row in rows:
            key, value = to_entry(row)
            redis_pipe.set(key, json.dumps(value))  # store the document as a JSON string
        redis_pipe.execute()  # buffered SETs are sent together here
        logger.info(f"Successfully wrote {len(rows)} {label} rows to Redis")
    except Exception:
        logger.exception(f"Error writing {label} data to Redis")


def _weather_entry(row):
    """Key/value for one row of weather_agg_df (municipality + spatial bin)."""
    key = f"weather:{row.municipality}:{row.lat_bin}:{row.lon_bin}"
    value = {
        "name": row.municipality,
        "avg_temp": row.avg_temp,
        "avg_visibility": row.avg_visibility,
        "avg_clouds": row.avg_clouds,
        "max_rain": row.max_rain,
        "max_snow": row.max_snow,
        "weather": row.last_weather,
        "weather_description": row.last_weather_description,
    }
    return key, value


def _traffic_entry(row):
    """Key/value for one row of traffic_agg_df (spatial bin)."""
    key = f"traffic:{row.lat_bin}_{row.lon_bin}"
    value = {
        "avg_speed": row.avg_speed,
        "avg_flow_speed": row.avg_flow_speed,
        "avg_travel_time": row.avg_travel_time,
        "avg_flow_travel_time": row.avg_flow_travel_time,
        "avg_speed_diff": row.avg_speed_diff,
        "had_closure": row.had_closure,
    }
    return key, value


def _prediction_entry(row):
    """Key/value for one scored row (joined weather + traffic + model output)."""
    municipality = getattr(row, "municipality", "unknown") or "unknown"
    key = f"hotspot:{municipality}:{row.lat_bin}_{row.lon_bin}"
    value = {
        "municipality": municipality,
        "latitude": row.lat_bin,
        "longitude": row.lon_bin,
        "weather": getattr(row, "last_weather", "unknown"),
        "weather_description": row.last_weather_description,
        "avg_temp": row.avg_temp,
        "max_rain": row.max_rain,
        "max_snow": row.max_snow,
        "avg_speed": getattr(row, "avg_speed", None),
        "avg_flow_speed": row.avg_flow_speed,
        "avg_travel_time": row.avg_travel_time,
        "avg_flow_travel_time": row.avg_flow_travel_time,
        "hotspot_risk_level": row.hotspot_risk_level,
        # Time the row was written, not the event window.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return key, value


def write_weather_to_redis(df, batch_id):
    """foreachBatch sink for weather_agg_df.

    The aggregated frame exposes ``municipality``/``lat_bin``/``lon_bin``; the raw
    ``name``/``latitude``/``longitude`` columns do not survive the aggregation.
    Keys: ``weather:<municipality>:<lat_bin>:<lon_bin>``.
    """
    _write_rows_to_redis(df, batch_id, "weather", _weather_entry)


def write_traffic_to_redis(df, batch_id):
    """foreachBatch sink for traffic_agg_df. Keys: ``traffic:<lat_bin>_<lon_bin>``."""
    _write_rows_to_redis(df, batch_id, "traffic", _traffic_entry)


def write_predictions_to_redis(df, batch_id):
    """foreachBatch sink for prediction_df. Keys: ``hotspot:<municipality>:<lat_bin>_<lon_bin>``."""
    _write_rows_to_redis(df, batch_id, "prediction", _prediction_entry)

In [None]:
# TESTING: stream both aggregates (update mode: each window as it changes) and the
# scored rows to the console.
spark.streams.resetTerminated()  # forget queries stopped in earlier cells

traffic_console_query = (
    traffic_agg_df.writeStream
    .format("console")
    .outputMode("update")
    .option("truncate", False)
    .start()
)

weather_console_query = (
    weather_agg_df.writeStream
    .format("console")
    .outputMode("update")
    .option("truncate", False)
    .start()
)

prediction_query = (
    prediction_df.writeStream
    .format("console")
    .outputMode('append')
    .option("truncate", False)
    .start()
)


print("Streaming started... waiting for data...")

# Return as soon as ANY of the three queries stops or fails. Awaiting them one by
# one would block on the first and never surface a failure of the others.
spark.streams.awaitAnyTermination()

In [None]:
# Stop every streaming query still running in this session, whichever cell
# started it. The names weather_query/traffic_query are only defined by later
# cells, so referencing them here raised NameError on a top-to-bottom run.
for active_query in spark.streams.active:
    print(f"Stopping query {active_query.name or active_query.id}")
    active_query.stop()

In [None]:
# Stream scored rows to the console; blocks this cell until the query stops.
prediction_query = (
    prediction_df.writeStream
    .format("console")
    .outputMode('append')
    .option("truncate", False)
    .start()
)
prediction_query.awaitTermination()

In [None]:
# Stream weather windows to the console. In append mode a window is printed once,
# after the event-time watermark has passed its end. Blocks until the query stops.
weather_query = (
    weather_agg_df.writeStream
    .format("console")
    .outputMode('append')
    .option("truncate", False)
    .start()
)
weather_query.awaitTermination()

In [None]:
# Stream traffic windows to the console. In append mode a window is printed once,
# after the event-time watermark has passed its end. Blocks until the query stops.
traffic_query = (
    traffic_agg_df.writeStream
    .format("console")
    .outputMode('append')
    .option("truncate", False)
    .start()
)

traffic_query.awaitTermination()

In [None]:
# Final write to Redis
from pyspark.sql.functions import col
import logging
import time

# Forget queries stopped in earlier cells so awaitAnyTermination() below reacts
# only to the three queries started here.
spark.streams.resetTerminated()

# NOTE(review): no checkpointLocation is set, so Spark uses a temporary checkpoint
# and progress is not kept across restarts — add one if restarts must resume.
REDIS_TRIGGER = '10 seconds'

weather_query = (
    weather_agg_df.writeStream
    .foreachBatch(write_weather_to_redis)
    .outputMode('append')
    .trigger(processingTime=REDIS_TRIGGER)
    .start()
)

traffic_query = (
    traffic_agg_df.writeStream
    .foreachBatch(write_traffic_to_redis)
    .outputMode('append')
    .trigger(processingTime=REDIS_TRIGGER)
    .start()
)

prediction_query = (
    prediction_df.writeStream
    .foreachBatch(write_predictions_to_redis)
    .outputMode('append')
    .trigger(processingTime=REDIS_TRIGGER)
    .start()
)


print("Streaming started... waiting for data...")

# Return (or raise) as soon as ANY writer stops or fails. Sequential
# awaitTermination() calls would block on weather_query and hide failures of the
# traffic or prediction writers.
spark.streams.awaitAnyTermination()