In [7]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, desc, row_number, lit, current_timestamp, 
    unix_timestamp, from_unixtime, struct, array, expr, 
    collect_list, window
)
from pyspark.sql.window import Window
from pyspark.sql.types import *
from delta import *
from delta.tables import *

import pandas as pd
import os
import time
import uuid
import json

In [5]:
WAREHOUSE_DIR = "/home/jovyan/spark-warehouse"

In [2]:
def create_spark_session(app_name="NYC Taxi Data ETL"):
    """
    start spark session with kafka and delta support / memory config setup too
    """
    builder = SparkSession.builder.appName(app_name) \
        .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3") \
        .config("spark.sql.session.timeZone", "UTC") \
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
        .config("spark.sql.warehouse.dir", WAREHOUSE_DIR) \
        .config("spark.sql.catalogImplementation", "hive") \
        .config("spark.driver.memory", "5g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.memory.offHeap.enabled", "true") \
        .config("spark.memory.offHeap.size", "2g") \
        .config("spark.driver.maxResultSize", "2g") \
        .config("spark.sql.shuffle.partitions", "100") \
        .config("spark.default.parallelism", "100") \
        .config("spark.memory.fraction", "0.8") \
        .config("spark.sql.debug.maxToStringFields", 100) \
        .enableHiveSupport()
    
    # delta config
    spark = configure_spark_with_delta_pip(builder).getOrCreate()
    
    # do not flood logs
    spark.sparkContext.setLogLevel("WARN")
    
    # Print configs for debugging
    print(f"Warehouse directory: {spark.conf.get('spark.sql.warehouse.dir')}")
    print(f"Catalog implementation: {spark.conf.get('spark.sql.catalogImplementation')}")
    
    return spark

In [3]:
def get_frequent_routes_batch(spark, window_minutes=30):
    """
    Query 1 Part 1: Find top 10 most frequent routes in batch mode
    
    Args:
        spark: SparkSession object
        window_minutes: Time window in minutes (default 30)
    
    Returns:
        DataFrame with top 10 most frequent routes
    """
    # Get current timestamp for window calculation
    current_time = spark.sql("SELECT current_timestamp() as now").collect()[0]['now']
    
    # Calculate the window start time
    window_start = current_time.timestamp() - (window_minutes * 60)
    window_start_timestamp = from_unixtime(lit(window_start)).cast("timestamp")
    
    # Query to find the top 10 most frequent routes
    query = f"""
    SELECT 
        start_cell, 
        end_cell, 
        COUNT(*) AS num_rides
    FROM 
        query1_base
    WHERE 
        dropoff_datetime > '{window_start_timestamp}'
    GROUP BY 
        start_cell, end_cell
    ORDER BY 
        num_rides DESC
    LIMIT 10
    """
    
    result = spark.sql(query)
    return result

def stream_frequent_routes(spark, output_path, stream_interval=5):
    """
    Query 1 Part 2: Stream results whenever top 10 frequent routes change
    
    Args:
        spark: SparkSession object
        output_path: Path to save the output stream
        stream_interval: Interval in seconds to check for changes
    """
    # Create schema for the output
    schema = StructType([
        StructField("pickup_datetime", TimestampType(), False),
        StructField("dropoff_datetime", TimestampType(), False),
        StructField("start_cell_id_1", StringType(), True),
        StructField("end_cell_id_1", StringType(), True),
        StructField("start_cell_id_2", StringType(), True),
        StructField("end_cell_id_2", StringType(), True),
        StructField("start_cell_id_3", StringType(), True),
        StructField("end_cell_id_3", StringType(), True),
        StructField("start_cell_id_4", StringType(), True),
        StructField("end_cell_id_4", StringType(), True),
        StructField("start_cell_id_5", StringType(), True),
        StructField("end_cell_id_5", StringType(), True),
        StructField("start_cell_id_6", StringType(), True),
        StructField("end_cell_id_6", StringType(), True),
        StructField("start_cell_id_7", StringType(), True),
        StructField("end_cell_id_7", StringType(), True),
        StructField("start_cell_id_8", StringType(), True),
        StructField("end_cell_id_8", StringType(), True),
        StructField("start_cell_id_9", StringType(), True),
        StructField("end_cell_id_9", StringType(), True),
        StructField("start_cell_id_10", StringType(), True),
        StructField("end_cell_id_10", StringType(), True),
        StructField("delay", DoubleType(), False)
    ])
    
    # Function to process each micro-batch
    def process_batch(batch_df, batch_id):
        nonlocal previous_top_routes
        
        # Get the latest event timestamp
        latest_event = batch_df.agg({"dropoff_datetime": "max"}).collect()[0][0]
        if latest_event is None:
            return
        
        # Mark the processing start time for delay calculation
        processing_start = time.time()
        
        # Get current top 10 routes
        current_top_routes = get_frequent_routes_batch(spark)
        
        # Convert to list of tuples for comparison
        route_tuples = [(row['start_cell'], row['end_cell']) for row in current_top_routes.collect()]
        
        # Check if routes have changed
        if set(route_tuples) != set(previous_top_routes):
            # Routes have changed, create output record
            latest_pickup = batch_df.agg({"pickup_datetime": "max"}).collect()[0][0]
            
            # Pad to ensure 10 routes (use NULL for missing routes)
            padded_routes = route_tuples + [(None, None)] * (10 - len(route_tuples))
            padded_routes = padded_routes[:10]  # Ensure exactly 10 entries
            
            # Calculate processing delay
            delay = time.time() - processing_start
            
            # Create output record
            output_data = [(
                latest_pickup,
                latest_event,
                padded_routes[0][0], padded_routes[0][1],
                padded_routes[1][0], padded_routes[1][1],
                padded_routes[2][0], padded_routes[2][1],
                padded_routes[3][0], padded_routes[3][1],
                padded_routes[4][0], padded_routes[4][1],
                padded_routes[5][0], padded_routes[5][1],
                padded_routes[6][0], padded_routes[6][1],
                padded_routes[7][0], padded_routes[7][1],
                padded_routes[8][0], padded_routes[8][1],
                padded_routes[9][0], padded_routes[9][1],
                delay
            )]
            
            # Create output DataFrame
            output_df = spark.createDataFrame(output_data, schema)
            
            # Write to output path
            output_df.write.mode("append").format("csv").save(output_path)
            
            # Update previous routes
            previous_top_routes = route_tuples
            
            print(f"Top routes changed at {latest_event}. Wrote update with delay: {delay:.6f}s")
    
    # Initialize with empty routes
    previous_top_routes = []
    
    # Create the streaming DataFrame
    stream_df = spark.readStream \
        .format("delta") \
        .table("enhanced_taxi_data") \
        .select("pickup_datetime", "dropoff_datetime", "pickup_cell_id_500m", "dropoff_cell_id_500m") \
        .where("pickup_cell_id_500m IS NOT NULL AND dropoff_cell_id_500m IS NOT NULL")
    
    # Define checkpoint location
    checkpoint_path = os.path.join(output_path, "_checkpoint")
    os.makedirs(checkpoint_path, exist_ok=True)
    
    # Start the streaming query
    query = stream_df.writeStream \
        .foreachBatch(process_batch) \
        .outputMode("update") \
        .trigger(processingTime=f"{stream_interval} seconds") \
        .option("checkpointLocation", checkpoint_path) \
        .start()
    
    return query

def run_query1(spark, output_dir, mode="batch"):
    """
    Run Query 1 in either batch or streaming mode
    
    Args:
        spark: SparkSession object
        output_dir: Directory to save results
        mode: "batch" for Part 1 or "stream" for Part 2
    """
    print(f"Running Query 1 ({mode} mode)...")
    
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    if mode == "batch":
        # Run batch mode (Part 1)
        result = get_frequent_routes_batch(spark)
        
        # Save results
        output_path = os.path.join(output_dir, "query1_part1_results.csv")
        result.write.mode("overwrite").csv(output_path)
        
        print(f"Query 1 Part 1 results saved to {output_path}")
        print("Top 10 most frequent routes:")
        result.show()
        
        return result
    
    elif mode == "stream":
        # Run streaming mode (Part 2)
        output_path = os.path.join(output_dir, "query1_part2_results")
        query = stream_frequent_routes(spark, output_path)
        
        print(f"Query 1 Part 2 streaming started. Results will be saved to {output_path}")
        print("Streaming query will run until explicitly stopped.")
        
        return query
    
    else:
        raise ValueError(f"Invalid mode: {mode}. Use 'batch' or 'stream'.")

In [None]:
max_time_query = "SELECT MAX(dropoff_datetime) as max_time FROM clean_taxi_data"
max_time = spark.sql(max_time_query).collect()[0]['max_time']

In [8]:
# Create Spark session
print("Initializing Spark session...")
spark = create_spark_session()
print(f"Spark version: {spark.version}")
    
# Show warehouse directory for debugging
warehouse_dir = spark.conf.get("spark.sql.warehouse.dir")
print(f"Warehouse directory: {warehouse_dir}")
    
# Set up output directory
output_dir = f"results"
os.makedirs(output_dir, exist_ok=True)

print(f"\n===== EXECUTING QUERY 1")
# Run in batch or streaming mode based on part
mode = "batch" 
query1_output = os.path.join(output_dir, "query1")
result = run_query1(spark, query1_output, mode)

Initializing Spark session...
Warehouse directory: file:/home/jovyan/spark-warehouse
Catalog implementation: hive
Spark version: 3.5.3
Warehouse directory: file:/home/jovyan/spark-warehouse

===== EXECUTING QUERY 1
Running Query 1 (batch mode)...


ParseException: 
[PARSE_SYNTAX_ERROR] Syntax error at or near 'CAST'.(line 9, pos 36)

== SQL ==

    SELECT 
        start_cell, 
        end_cell, 
        COUNT(*) AS num_rides
    FROM 
        query1_base
    WHERE 
        dropoff_datetime > 'Column<'CAST(from_unixtime(1.743111259746981E9, yyyy-MM-dd HH:mm:ss) AS TIMESTAMP)'>'
------------------------------------^^^
    GROUP BY 
        start_cell, end_cell
    ORDER BY 
        num_rides DESC
    LIMIT 10
    
