In [1]:
import math, time, datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, udf, expr, when, lit, array, percentile_approx, desc
)
from pyspark.sql.types import (
    StructType, StructField, StringType, TimestampType,
    DoubleType, IntegerType, ArrayType
)

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, window, unix_timestamp, max
from pyspark.sql.window import Window
from pyspark.sql import functions as F
import math
import time

from delta import *
from delta.tables import *
from pyspark.sql.functions import col, to_json, struct, lit, current_timestamp, expr, when, from_json, window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
    TimestampType,
    IntegerType,
)
import pandas as pd
import os
import uuid
import json

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

In [2]:
WAREHOUSE_DIR = "./spark-warehouse"

def create_spark_session(app_name="FrequentRoutes"):
    """
    start spark session with kafka and delta support / memory config setup too
    """
    builder = SparkSession.builder.appName(app_name) \
        .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3") \
        .config("spark.sql.session.timeZone", "UTC") \
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
        .config("spark.sql.warehouse.dir", WAREHOUSE_DIR) \
        .config("spark.sql.catalogImplementation", "hive") \
        .config("spark.driver.memory", "5g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.memory.offHeap.enabled", "true") \
        .config("spark.memory.offHeap.size", "2g") \
        .config("spark.driver.maxResultSize", "2g") \
        .config("spark.sql.shuffle.partitions", "100") \
        .config("spark.default.parallelism", "100") \
        .config("spark.memory.fraction", "0.8") \
        .config("spark.sql.debug.maxToStringFields", 100) \
        .enableHiveSupport()
    
    # delta config
    spark = configure_spark_with_delta_pip(builder).getOrCreate()
    
    # do not flood logs
    spark.sparkContext.setLogLevel("WARN")
    
    # Print configs for debugging
    print(f"Warehouse directory: {spark.conf.get('spark.sql.warehouse.dir')}")
    print(f"Catalog implementation: {spark.conf.get('spark.sql.catalogImplementation')}")
    
    return spark

spark = create_spark_session()

:: loading settings :: url = jar:file:/gpfs/helios/home/fidankarimova/myenv/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /gpfs/helios/home/fidankarimova/.ivy2/cache
The jars for the packages stored in: /gpfs/helios/home/fidankarimova/.ivy2/jars
io.delta#delta-spark_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-630e1920-9d42-4a94-a611-d5306931a5f9;1.0
	confs: [default]
	found io.delta#delta-spark_2.12;3.3.0 in central
	found io.delta#delta-storage;3.3.0 in central
	found org.antlr#antlr4-runtime;4.9.3 in central
:: resolution report :: resolve 166ms :: artifacts dl 20ms
	:: modules in use:
	io.delta#delta-spark_2.12;3.3.0 from central in [default]
	io.delta#delta-storage;3.3.0 from central in [default]
	org.antlr#antlr4-runtime;4.9.3 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	-----------------------------------------------------------------

Warehouse directory: file:/gpfs/helios/home/fidankarimova/input/spark-warehouse/spark-warehouse
Catalog implementation: hive


In [3]:
# Part 1: Define Schema and Read Cleaned Data as a Stream
#######################################
def create_raw_taxi_schema():
    return StructType([
        StructField("medallion", StringType(), True),
        StructField("hack_license", StringType(), True),
        StructField("pickup_datetime", TimestampType(), True),
        StructField("dropoff_datetime", TimestampType(), True),
        StructField("trip_time_in_secs", IntegerType(), True),
        StructField("trip_distance", DoubleType(), True),
        StructField("pickup_longitude", DoubleType(), True),
        StructField("pickup_latitude", DoubleType(), True),
        StructField("dropoff_longitude", DoubleType(), True),
        StructField("dropoff_latitude", DoubleType(), True),
        StructField("payment_type", StringType(), True),
        StructField("fare_amount", DoubleType(), True),
        StructField("surcharge", DoubleType(), True),
        StructField("mta_tax", DoubleType(), True),
        StructField("tip_amount", DoubleType(), True),
        StructField("tolls_amount", DoubleType(), True),
        StructField("total_amount", DoubleType(), True),
    ])

raw_schema = create_raw_taxi_schema()
CLEAN_OUTPUT_PATH = "clean_taxi_data"
batch_df = spark.read.format("delta").load(CLEAN_OUTPUT_PATH)
df = spark.readStream.format("delta").load(CLEAN_OUTPUT_PATH)

In [4]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, ArrayType

LAT_REF = 41.474937
LON_REF = -74.913585
EARTH_RADIUS = 6371000.0  # in meters

def latlon_to_meters(lat, lon):
    """Approx x,y distance from reference using equirectangular approximation."""
    lat_r = math.radians(lat)
    lon_r = math.radians(lon)
    lat_ref_r = math.radians(LAT_REF)
    lon_ref_r = math.radians(LON_REF)
    x = EARTH_RADIUS * (lon_r - lon_ref_r) * math.cos((lat_r + lat_ref_r) / 2)
    y = EARTH_RADIUS * (lat_r - lat_ref_r)
    return (x, y)

def get_250m_cell(lat, lon):
    """Return [cell_x, cell_y] for a 250m x 250m grid. None if out of range."""
    (x_m, y_m) = latlon_to_meters(lat, lon)
    cell_x = int(math.floor(x_m / 250.0)) + 1
    cell_y = int(math.floor((-1 * y_m) / 250.0)) + 1
    if 1 <= cell_x <= 600 and 1 <= cell_y <= 600:
        return [cell_x, cell_y]
    return None

udf_get_250m_cell = udf(get_250m_cell, ArrayType(IntegerType()))

df_cells = (
    df
    # Add columns for start_cell_x, start_cell_y, end_cell_x, end_cell_y
    .withColumn("start_cell",
        udf_get_250m_cell(col("pickup_latitude"), col("pickup_longitude")))
    .withColumn("start_cell_x", expr("start_cell[0]"))
    .withColumn("start_cell_y", expr("start_cell[1]"))
    .withColumn("end_cell",
        udf_get_250m_cell(col("dropoff_latitude"), col("dropoff_longitude")))
    .withColumn("end_cell_x", expr("end_cell[0]"))
    .withColumn("end_cell_y", expr("end_cell[1]"))
    # Keep only rows with valid cells
    .filter(col("start_cell_x").isNotNull() & col("end_cell_x").isNotNull())
)

In [6]:
# STEP B: Profit Aggregation (last 15 min)
########################################
# Windowed aggregator on dropoff_datetime
# Median fare+tip for trips that ended in last 15 min, grouped by pickup cell.
from pyspark.sql.functions import percentile_approx, window
from pyspark.sql.functions import max as spark_max

profit_agg = (
    df_cells
    .withColumn("fare_plus_tip", col("fare_amount") + col("tip_amount"))
    .withWatermark("dropoff_datetime", "20 minutes")
    .groupBy(
        window("dropoff_datetime", "15 minutes"),
        col("start_cell_x"),
        col("start_cell_y")
    )
    .agg(
        percentile_approx("fare_plus_tip", 0.5).alias("median_fare_tip"),
        spark_max("pickup_datetime").alias("trigger_pickup"),
        spark_max("dropoff_datetime").alias("trigger_dropoff")
    )
)

In [7]:
########################################
# 3. Track "empty" taxis with ephemeral dictionary in foreachBatch
########################################
# We will store ephemeral state in a global dictionary:
# empty_state[medallion] = (dropoff_time, end_cell_x, end_cell_y)
# If we see a subsequent pickup after dropoff_time, remove it.
# We'll mark a taxi as empty if dropoff_time < 30 min ago.

empty_state = {}  # global dictionary
last_top10 = None  # track previous top 10

########################################
# 3.1. We'll define a microbatch function that:
#   - collects the new rows
#   - updates empty_state for each row
#   - for each taxi, if dropoff_time < 30 min ago and no new pickup -> that taxi is empty
#   - produce a (cell_x, cell_y, 1) row for each empty taxi
#   - store them in a local aggregator: empty_count_map[(cell_x, cell_y)] = number_of_empty
########################################

def process_empty_taxis(batch_df, batch_id):
    global empty_state
    rows = batch_df.collect()  # gather micro-batch
    if not rows:
        return
    
    # Update ephemeral dictionary
    now_ts = datetime.datetime.utcnow()
    
    for r in rows:
        med = r["medallion"]
        pick_t = r["pickup_datetime"]
        drop_t = r["dropoff_datetime"]
        end_x = r["end_cell_x"]
        end_y = r["end_cell_y"]
        
        # update dropoff
        if drop_t:
            empty_state[med] = (drop_t, end_x, end_y)
        
        # if there's a subsequent pickup after the last dropoff, remove it
        old_val = empty_state.get(med, None)
        if old_val and pick_t:
            (old_drop_t, old_cellx, old_celly) = old_val
            if pick_t > old_drop_t:
                # not empty
                del empty_state[med]

# We'll write df_cells to a sink with foreachBatch=process_empty_taxis
# But that won't produce final output. We'll store empty results in ephemeral dict.
# Then next step is to gather "current empties" for a join with the profit aggregator in a separate foreachBatch.

empty_query = (
    df_cells
    .writeStream
    .outputMode("append")  # just to trigger micro-batches
    .format("console")     # or any sink
    .option("truncate", "false")
    .trigger(processingTime="10 seconds")
    .foreachBatch(process_empty_taxis)
    .start()
)

########################################
# 4. Next, we combine ephemeral empty-taxi info with 15-min profit aggregator
#    in a foreachBatch on the profit aggregator
########################################

def process_profit(batch_df, batch_id):
    global empty_state, last_top10
    
    # This micro-batch has the windowed aggregator for median_fare_tip
    # We'll build a local aggregator for empties: (cell_x, cell_y) -> count
    now_ts = datetime.datetime.utcnow()
    
    # 4.1. Build an ephemeral aggregator for empties
    empty_map = {}
    for med, (drop_time, cx, cy) in empty_state.items():
        # if drop_time < 30 min
        if drop_time is not None:
            delta_sec = (now_ts - drop_time).total_seconds()
            if delta_sec <= 1800:
                key = (cx, cy)
                empty_map[key] = empty_map.get(key, 0) + 1
    
    # 4.2. Convert batch_df to a local list
    profit_rows = batch_df.collect()
    if not profit_rows:
        return
    
    # We'll produce final joined rows: (cell_x, cell_y, median_fare_tip, empty_taxis, profitability, pickup, dropoff)
    final_list = []
    for pr in profit_rows:
        sx = pr["start_cell_x"]
        sy = pr["start_cell_y"]
        medf = pr["median_fare_tip"]
        empties = empty_map.get((sx, sy), 0)
        profitval = (medf / empties) if empties > 0 else None
        final_list.append({
            "cell_x": sx,
            "cell_y": sy,
            "median_fare_tip": medf,
            "empty_taxis": empties,
            "profitability": profitval,
            "pickup": pr["trigger_pickup"],
            "dropoff": pr["trigger_dropoff"]
        })
    
    # 4.3. Sort descending by profitability, pick top 10
    final_list.sort(key=lambda row: (row["profitability"] if row["profitability"] else 0), reverse=True)
    top10 = final_list[:10]
    
    # 4.4. Compare with last_top10 (optional). For simplicity we always print here
    if top10:
        # Build a single line output
        out = {}
        out["pickup_datetime"] = str(top10[0]["pickup"]) if top10[0]["pickup"] else None
        out["dropoff_datetime"] = str(top10[0]["dropoff"]) if top10[0]["dropoff"] else None
        
        for i in range(len(top10)):
            idx = i+1
            rowi = top10[i]
            out[f"profitable_cell_id_{idx}"] = f"{rowi['cell_x']}.{rowi['cell_y']}"
            out[f"empty_taxies_in_cell_id_{idx}"] = rowi["empty_taxis"]
            out[f"median_profit_in_cell_id_{idx}"] = rowi["median_fare_tip"]
            out[f"profitability_of_cell_{idx}"] = rowi["profitability"]
        
        # Fill up to 10
        for j in range(len(top10)+1, 11):
            out[f"profitable_cell_id_{j}"] = None
            out[f"empty_taxies_in_cell_id_{j}"] = None
            out[f"median_profit_in_cell_id_{j}"] = None
            out[f"profitability_of_cell_{j}"] = None
        
        out["delay"] = 1.0  # placeholder
        print(f"\n=== Top 10 Profitable Areas (batch={batch_id}) ===")
        print(out)

# 4.5. Write the profit aggregator in a foreachBatch, calling process_profit
profit_final_query = (
    profit_agg
    .writeStream
    .outputMode("update")
    .format("console")     # or any sink
    .option("truncate", "false")
    .trigger(processingTime="10 seconds")
    .foreachBatch(process_profit)
    .start()
)

# profit_final_query.awaitTermination()


25/03/30 17:43:39 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-d63ef4d8-21fd-42ed-921a-cc04808368d1. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/03/30 17:43:39 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
25/03/30 17:43:39 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-01bd86f7-ac72-467e-9ae2-80939a30e821. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/03/30 17:43:39 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not support


=== Top 10 Profitable Areas (batch=0) ===
{'pickup_datetime': '2013-01-01 02:03:00', 'dropoff_datetime': '2013-01-01 02:08:00', 'profitable_cell_id_1': '309.334', 'empty_taxies_in_cell_id_1': 0, 'median_profit_in_cell_id_1': 8.25, 'profitability_of_cell_1': None, 'profitable_cell_id_2': '327.342', 'empty_taxies_in_cell_id_2': 0, 'median_profit_in_cell_id_2': 5.0, 'profitability_of_cell_2': None, 'profitable_cell_id_3': '310.333', 'empty_taxies_in_cell_id_3': 0, 'median_profit_in_cell_id_3': 9.1, 'profitability_of_cell_3': None, 'profitable_cell_id_4': '317.301', 'empty_taxies_in_cell_id_4': 0, 'median_profit_in_cell_id_4': 9.5, 'profitability_of_cell_4': None, 'profitable_cell_id_5': '317.316', 'empty_taxies_in_cell_id_5': 0, 'median_profit_in_cell_id_5': 4.5, 'profitability_of_cell_5': None, 'profitable_cell_id_6': '307.320', 'empty_taxies_in_cell_id_6': 0, 'median_profit_in_cell_id_6': 9.5, 'profitability_of_cell_6': None, 'profitable_cell_id_7': '309.331', 'empty_taxies_in_cell_id_

25/03/30 17:45:59 WARN ProcessingTimeExecutor: Current batch is falling behind. The trigger interval is 10000 milliseconds, but spent 139107 milliseconds

In [5]:
# Part 2: 250m x 250m Cell Mapping
#######################################
LAT_REF = 41.474937
LON_REF = -74.913585
EARTH_RADIUS = 6371000.0

def latlon_to_meters(lat, lon):
    lat_r = math.radians(lat)
    lon_r = math.radians(lon)
    lat_ref_r = math.radians(LAT_REF)
    lon_ref_r = math.radians(LON_REF)
    x = EARTH_RADIUS * (lon_r - lon_ref_r) * math.cos((lat_r + lat_ref_r)/2)
    y = EARTH_RADIUS * (lat_r - lat_ref_r)
    return (x, y)

def get_250m_cell(lat, lon):
    (x_m, y_m) = latlon_to_meters(lat, lon)
    cell_x = int(math.floor(x_m / 250.0)) + 1
    cell_y = int(math.floor((-1 * y_m) / 250.0)) + 1
    if 1 <= cell_x <= 600 and 1 <= cell_y <= 600:
        return [cell_x, cell_y]
    return None

@udf(ArrayType(IntegerType()))
def udf_get_250m_cell(lat, lon):
    if lat is None or lon is None:
        return None
    return get_250m_cell(lat, lon)

df_cells = (df
    .withColumn("start_cell", udf_get_250m_cell(col("pickup_latitude"), col("pickup_longitude")))
    .withColumn("end_cell", udf_get_250m_cell(col("dropoff_latitude"), col("dropoff_longitude")))
    .withColumn("start_cell_x", expr("start_cell[0]"))
    .withColumn("start_cell_y", expr("start_cell[1]"))
    .withColumn("end_cell_x", expr("end_cell[0]"))
    .withColumn("end_cell_y", expr("end_cell[1]"))
    .filter(col("start_cell_x").isNotNull() & col("end_cell_x").isNotNull())
)


In [6]:
# Part 3: Ephemeral Python State for Empty Taxi Tracking and Profit Aggregation
#######################################
# Global dictionary to track for each taxi its last dropoff state:
dropoff_state = {}  # key: medallion, value: (dropoff_datetime, end_cell_x, end_cell_y)
last_top10 = None   # to track previous top 10 result

def median(values):
    if not values:
        return None
    s = sorted(values)
    n = len(s)
    mid = n // 2
    return s[mid] if n % 2 == 1 else (s[mid-1] + s[mid]) / 2.0

In [15]:
from pyspark.sql.functions import percentile_approx, window, desc, max as spark_max, approx_count_distinct

# Step 1: Add fare_plus_tip column
with_fare = df_cells.withColumn("fare_plus_tip", col("fare_amount") + col("tip_amount"))

# Step 2: Profit Aggregation: Median fare+tip per 15-minute window grouped by start_cell
profit_df = (
    with_fare
    .withWatermark("dropoff_datetime", "20 minutes")
    .groupBy(
        window("dropoff_datetime", "15 minutes"),
        col("start_cell_x"), col("start_cell_y")
    )
    .agg(
        percentile_approx("fare_plus_tip", 0.5).alias("median_fare_tip"),
        spark_max("pickup_datetime").alias("trigger_pickup"),
        spark_max("dropoff_datetime").alias("trigger_dropoff")
    )
)

profit_query = (
    profit_df
    .writeStream
    .outputMode("append")          # Or "update", depending on your aggregation logic
    .format("console")
    .option("truncate", False)
    .start()
)

25/03/30 16:58:34 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-2455a1c5-75c0-4a2c-a815-c2090470edbc. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/03/30 16:58:34 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[Stage 32:(10 + 12) / 100][Stage 34:>  (0 + 0) / 50][Stage 35:>  (0 + 0) / 50]

In [16]:
# Step 3: Empty Taxi Aggregation: Count distinct medallions in end_cell per 30-minute window
empty_df = (
    df_cells
    .withWatermark("dropoff_datetime", "35 minutes")
    .groupBy(
        window("dropoff_datetime", "30 minutes"),
        col("end_cell_x").alias("cell_x"),
        col("end_cell_y").alias("cell_y")
    )
    .agg(
        approx_count_distinct("medallion").alias("empty_taxis")
    )
)
empty_df

DataFrame[window: struct<start:timestamp,end:timestamp>, cell_x: int, cell_y: int, empty_taxis: bigint]

[Stage 32:(14 + 12) / 100][Stage 34:>  (0 + 0) / 50][Stage 35:>  (0 + 0) / 50]

In [9]:
# Step 4: Join the two aggregations on matching cell coordinates
joined_df = (
    profit_df.join(
        empty_df,
        (profit_df.start_cell_x == empty_df.cell_x) &
        (profit_df.start_cell_y == empty_df.cell_y),
        "inner"
    )
    .select(
        profit_df.trigger_pickup,
        profit_df.trigger_dropoff,
        profit_df.start_cell_x,
        profit_df.start_cell_y,
        profit_df.median_fare_tip,
        empty_df.empty_taxis
    )
    .withColumn(
        "profitability",
        when(col("empty_taxis") > 0, col("median_fare_tip") / col("empty_taxis"))
    )
)
joined_df

DataFrame[trigger_pickup: timestamp, trigger_dropoff: timestamp, start_cell_x: int, start_cell_y: int, median_fare_tip: double, empty_taxis: bigint, profitability: double]

In [10]:
# Step 5: Output top 10 profitable cells per batch
last_top10 = None

def print_top_10(batch_df, batch_id):
    global last_top10
    top10 = batch_df.orderBy(desc("profitability")).limit(10).collect()
    if top10 != last_top10:
        print(f"\n=== TOP 10 CHANGED (batch {batch_id}) ===")
        for i, row in enumerate(top10, start=1):
            print(f"{i:02d}. Cell {row['start_cell_x']}.{row['start_cell_y']} → Profitability: {row['profitability']:.4f} | Median Profit: {row['median_fare_tip']:.2f} | Empty Taxis: {row['empty_taxis']}")
        last_top10 = top10

# Start streaming query
query = (
    joined_df
    .writeStream
    .foreachBatch(print_top_10)
    .outputMode("append")
    .start()
)

# query.awaitTermination()

25/03/30 16:35:48 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-764bde61-479c-47ac-92c2-e7405dfdab0e. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/03/30 16:35:48 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
                                                                                


=== TOP 10 CHANGED (batch 0) ===


[Stage 32:>                                                      (0 + 12) / 100]

In [14]:
for query in spark.streams.active:
    print(f"Query name: {query.name}")
    print(f"Status: {query.status}")
    print(f"Is active: {query.isActive}")

Query name: None
Status: {'message': 'No new data but cleaning up state', 'isDataAvailable': False, 'isTriggerActive': True}
Is active: True
Query name: None
Status: {'message': 'Getting offsets from DeltaSource[file:/gpfs/helios/home/fidankarimova/input/spark-warehouse/clean_taxi_data]', 'isDataAvailable': False, 'isTriggerActive': True}
Is active: True


[Stage 32:=>            (10 + 12) / 100][Stage 34:>                (0 + 0) / 50]