In [None]:
# IMPORTANT: remove the space that follows each comma in the CSV header row; schema parsing was failing because of it.
# If the `!` commands below fail inside the notebook, run them from a shell instead.
# !for file in input/trip_data_*.csv; do sed -i '1s/, /,/g' "$file"; done
# !for file in input/sample.csv; do sed -i '1s/, /,/g' "$file"; done

## Dependencies

In [None]:
from pyspark.sql import SparkSession
from delta import *
from delta.tables import *
from pyspark.sql.functions import col, to_json, struct, lit, current_timestamp, expr, when, from_json
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
    TimestampType,
    IntegerType,
)
import pandas as pd
import os
import time
import uuid
import json

## Schema definition fn

In [None]:
# Single absolute warehouse path, reused so every component agrees on one location:
# it is passed to Spark as `spark.sql.warehouse.dir` in create_spark_session() and is
# the root under which the Delta output and checkpoint paths are built.
WAREHOUSE_DIR = "/home/jovyan/spark-warehouse"

# Create the schema definition for NYC taxi data
def create_taxi_schema():
    """Build the Spark schema describing one NYC taxi trip record.

    Returns:
        StructType: fourteen nullable fields in declaration order — string
        identifiers and codes, pickup/dropoff timestamps, integer passenger
        count and trip duration, and double-typed distance and coordinates.
    """
    # (column name, Spark type class) pairs; each type is instantiated per field.
    column_specs = (
        ("medallion", StringType),
        ("hack_license", StringType),
        ("vendor_id", StringType),
        ("rate_code", StringType),
        ("store_and_fwd_flag", StringType),
        ("pickup_datetime", TimestampType),
        ("dropoff_datetime", TimestampType),
        ("passenger_count", IntegerType),
        ("trip_time_in_secs", IntegerType),
        ("trip_distance", DoubleType),
        ("pickup_longitude", DoubleType),
        ("pickup_latitude", DoubleType),
        ("dropoff_longitude", DoubleType),
        ("dropoff_latitude", DoubleType),
    )
    fields = [
        StructField(column_name, spark_type(), True)
        for column_name, spark_type in column_specs
    ]
    return StructType(fields)

## Spark Session fn

In [None]:
def create_spark_session(app_name="NYC Taxi Data ETL"):
    """
    Start a Spark session with Kafka and Delta Lake support plus memory tuning.

    The Kafka connector is handed to `configure_spark_with_delta_pip` through
    `extra_packages`. That helper sets `spark.jars.packages` itself, so a
    coordinate configured directly on the builder would be replaced by the
    Delta artifact and the `kafka` data source would not be loaded.

    Args:
        app_name: Spark application name.

    Returns:
        SparkSession: session with Hive support enabled, the Delta SQL
        extension/catalog configured, and log level set to WARN.
    """
    # Extra Maven coordinates that must be resolved alongside the Delta artifact.
    kafka_package = "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3"

    builder = (
        SparkSession.builder.appName(app_name)
        .config("spark.sql.session.timeZone", "UTC")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .config("spark.sql.warehouse.dir", WAREHOUSE_DIR)
        .config("spark.sql.catalogImplementation", "hive")
        .config("spark.driver.memory", "5g")
        .config("spark.executor.memory", "4g")
        .config("spark.memory.offHeap.enabled", "true")
        .config("spark.memory.offHeap.size", "2g")
        .config("spark.driver.maxResultSize", "2g")
        .config("spark.sql.shuffle.partitions", "100")
        .config("spark.default.parallelism", "100")
        .config("spark.memory.fraction", "0.8")
        .config("spark.sql.debug.maxToStringFields", 100)
        .enableHiveSupport()
    )

    # delta config: adds the Delta artifact and keeps the Kafka connector
    spark = configure_spark_with_delta_pip(
        builder, extra_packages=[kafka_package]
    ).getOrCreate()

    # do not flood logs
    spark.sparkContext.setLogLevel("WARN")

    # Print configs for debugging
    print(f"Warehouse directory: {spark.conf.get('spark.sql.warehouse.dir')}")
    print(f"Catalog implementation: {spark.conf.get('spark.sql.catalogImplementation')}")

    return spark

In [None]:
# Initialize the Spark session
spark = create_spark_session()

Warehouse directory: file:/home/jovyan/spark-warehouse
Catalog implementation: hive


In [None]:
# Define the taxi schema
raw_taxi_schema = create_taxi_schema()

In [None]:
# Function to ingest CSV data to Kafka
def ingest_csv_to_kafka(csv_path, batch_size=1000, topic="raw-taxi-data",
                        bootstrap_servers="kafka:9092"):
    """
    Read a CSV file and publish every row to a Kafka topic as a JSON message.

    Each row becomes one record whose key is the row's `medallion` (as a
    string) and whose value is the JSON encoding of all columns.

    Args:
        csv_path: Path (or glob) of the CSV input; read with a header row and
            the notebook-level `raw_taxi_schema`.
        batch_size: Unused; kept so existing callers keep working. It was
            previously forwarded as `maxOffsetsPerTrigger`, which is a rate
            limit for the streaming Kafka *source* and has no effect on a
            batch write.
        topic: Destination Kafka topic.
        bootstrap_servers: Kafka bootstrap server list (`host:port[,...]`).

    Returns:
        None. Progress is reported with `print`.
    """
    print(f"Ingesting data from {csv_path} to Kafka topic '{topic}'")

    # Load the data
    taxi_data = spark.read.csv(csv_path, header=True, schema=raw_taxi_schema)

    # Get total count for reporting (note: this is a full extra pass over the CSV)
    total_count = taxi_data.count()
    print(f"Total records to process: {total_count}")

    # The Kafka sink only looks at columns named `key`, `value`, `topic`, ...;
    # a differently named column (e.g. `kafka_key`) is ignored and the record
    # would be written with a null key. Name the column `key` so the medallion
    # really is used as the message/partition key.
    kafka_batch = taxi_data.select(
        col("medallion").cast("string").alias("key"),
        to_json(struct(*[col(c) for c in taxi_data.columns])).alias("value"),
    )

    # Write to Kafka in one go (Spark will handle the batching internally)
    (
        kafka_batch.write.format("kafka")
        .option("kafka.bootstrap.servers", bootstrap_servers)
        .option("topic", topic)
        .save()
    )

    print(f"Finished ingesting data to Kafka topic '{topic}'")

# Function to create a table once streaming data is available
def _delta_data_ready(output_path):
    """Return True when output_path holds a parquet file and a non-empty _delta_log."""
    if not os.path.isdir(output_path):
        return False
    has_parquet = any(".parquet" in entry for entry in os.listdir(output_path))
    delta_log_dir = os.path.join(output_path, "_delta_log")
    return (
        has_parquet
        and os.path.isdir(delta_log_dir)
        and len(os.listdir(delta_log_dir)) > 0
    )


def create_table_if_exists(output_path, table_name, max_attempts=30, poll_interval_secs=120):
    """
    Wait until Delta data exists at `output_path`, then register a table on it.

    The path is polled up to `max_attempts` times. The first check happens
    immediately; each later check is preceded by a `poll_interval_secs` sleep,
    so the worst-case wait is `(max_attempts - 1) * poll_interval_secs` seconds
    (about 58 minutes with the defaults). Once data is present the table is
    dropped if it exists and re-created as an external Delta table at that
    location; a failure there is printed and retried on the next attempt.

    Args:
        output_path: Directory the streaming query writes its Delta table to.
        table_name: Name of the table to (re)create.
        max_attempts: Maximum number of polls before giving up.
        poll_interval_secs: Seconds to sleep between polls.

    Returns:
        bool: True if the table was created, False otherwise. A warning is
        printed when no data showed up, or when data exists but the table
        could not be created.
    """
    data_exists = False
    table_created = False

    for attempt in range(max_attempts):
        # Only wait between polls; do not delay the very first check.
        if attempt > 0:
            time.sleep(poll_interval_secs)
        try:
            if not data_exists and _delta_data_ready(output_path):
                print(f"Data exists in {output_path}")
                data_exists = True

            if data_exists:
                # Create external table with explicit location
                spark.sql(f"DROP TABLE IF EXISTS {table_name}")
                spark.sql(f"CREATE TABLE {table_name} USING DELTA LOCATION '{output_path}'")
                print(f"Created table {table_name} using data at {output_path}")
                table_created = True
                break
        except Exception as e:
            # Best-effort polling: report the problem and try again next round.
            print(f"Waiting for data: {e}")

    if not data_exists:
        print(f"WARNING: No data found in {output_path} after waiting. Table may not be created.")
    elif not table_created:
        print(f"WARNING: Data found in {output_path} but table {table_name} could not be created.")

    return table_created

# Function to clean the taxi data
from pyspark.sql.functions import col, when

def clean_taxi_data(df):
    """
    Clean the taxi data by dropping invalid trips and adding derived columns.

    Rows are kept only when the medallion is present, distance, passenger
    count and duration are all strictly positive, and every coordinate lies
    in the valid longitude/latitude range. Rows where one of the compared
    columns is null are dropped as well, since the predicate is then not true.

    Two columns are added:
      * `trip_speed_mph`: `trip_distance / (trip_time_in_secs / 3600)`
        (0 when the duration is not positive).
      * `is_valid_trip`: True when distance and duration are positive and the
        computed speed is below 100.

    Args:
        df: DataFrame (batch or streaming) with the taxi trip columns.

    Returns:
        DataFrame: filtered rows with the two extra columns.
    """

    # Define valid range conditions for latitude and longitude
    valid_coords = (
        (col("pickup_longitude").between(-180, 180)) &
        (col("pickup_latitude").between(-90, 90)) &
        (col("dropoff_longitude").between(-180, 180)) &
        (col("dropoff_latitude").between(-90, 90))
    )

    # Each comparison must be parenthesised: Python's `&` binds tighter than
    # `>`, so `a & b > 0 & c > 0` would parse as a chained comparison and
    # raise "Cannot convert column into bool" instead of building a filter.
    is_complete_trip = (
        col("medallion").isNotNull() &        # Ensure taxi identifier exists
        (col("trip_distance") > 0) &          # Trips must have distance
        (col("passenger_count") > 0) &        # Must have at least one passenger
        (col("trip_time_in_secs") > 0) &      # Duration must be positive
        valid_coords                          # Coordinates must be valid
    )

    # Filter invalid trips
    cleaned_df = df.filter(is_complete_trip)

    # Compute trip speed in mph
    cleaned_df = cleaned_df.withColumn(
        "trip_speed_mph",
        when(col("trip_time_in_secs") > 0,
             col("trip_distance") / (col("trip_time_in_secs") / 3600)
        ).otherwise(0)
    )

    # Flag valid trips based on speed
    cleaned_df = cleaned_df.withColumn(
        "is_valid_trip",
        (col("trip_distance") > 0) &
        (col("trip_time_in_secs") > 0) &
        (col("trip_speed_mph") < 100)  # Filter unrealistic speeds
    )

    return cleaned_df

In [3]:
# Step 1: Define paths
# Metastore table names; also reused as the directory names under the warehouse.
RAW_TABLE_NAME = "raw_taxi_data"
CLEAN_TABLE_NAME = "clean_taxi_data"

# Delta table locations the streaming queries write to.
RAW_OUTPUT_PATH = os.path.join(WAREHOUSE_DIR, RAW_TABLE_NAME)
CLEAN_OUTPUT_PATH = os.path.join(WAREHOUSE_DIR, CLEAN_TABLE_NAME)

# Streaming checkpoint locations, one per query, kept under the warehouse directory.
CHECKPOINT_DIR = os.path.join(WAREHOUSE_DIR, "streaming/checkpoints")
RAW_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "raw_taxi_data") 
CLEAN_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "clean_taxi_data")


In [4]:
# Ensure checkpoint directories exist
os.makedirs(RAW_CHECKPOINT_PATH, exist_ok=True)
os.makedirs(CLEAN_CHECKPOINT_PATH, exist_ok=True)


In [5]:
# Step 2: Ingest CSV data to Kafka (producer)
def run_producer():
    """Publish the rows of `input/trip_data_8.csv` to Kafka via ingest_csv_to_kafka."""
    ingest_csv_to_kafka("input/trip_data_8.csv", batch_size=10000)

In [6]:
# Step 3: Set up the consumer to read from Kafka and write to Delta table
def run_consumer(skip_table_creation=False):
    """Stream the `raw-taxi-data` Kafka topic into the raw Delta table.

    Args:
        skip_table_creation: When True, the metastore table for the raw
            output path is not (re)created after the stream has started.

    Returns:
        The started streaming query writing to RAW_OUTPUT_PATH.
    """
    # 1. Subscribe to the topic from its earliest offsets
    kafka_source = (
        spark.readStream.format("kafka")
        .option("kafka.bootstrap.servers", "kafka:9092")
        .option("subscribe", "raw-taxi-data")
        .option("startingOffsets", "earliest")
        .load()
    )

    # 2. Decode key/value to strings, then expand the JSON payload into columns
    decoded = kafka_source.select(
        col("key").cast("string").alias("kafka_key"),
        col("value").cast("string").alias("json_data"),
        col("timestamp").alias("kafka_timestamp"),
    )
    parsed = decoded.select(
        "kafka_key",
        from_json("json_data", raw_taxi_schema).alias("data"),
        "kafka_timestamp",
    )
    raw_data_df = parsed.select("kafka_key", "kafka_timestamp", "data.*")

    # 3. Append the parsed rows to the raw Delta location every 10 seconds
    writer = (
        raw_data_df.writeStream.outputMode("append")
        .format("delta")
        .queryName("raw_taxi_query")
        .trigger(processingTime="10 second")
        .option("checkpointLocation", RAW_CHECKPOINT_PATH)
    )
    raw_query = writer.start(RAW_OUTPUT_PATH)

    # Give the stream a head start before looking for output files
    print("Waiting for raw data to be processed...")
    time.sleep(60)

    # 4. Register the table in the metastore unless the caller opted out
    if skip_table_creation:
        return raw_query
    create_table_if_exists(RAW_OUTPUT_PATH, RAW_TABLE_NAME)
    return raw_query

In [7]:
# Step 4: Clean data pipeline
def run_cleaner(raw_query):
    """Stream the raw Delta table through clean_taxi_data into the clean Delta table.

    Args:
        raw_query: Handle of the raw streaming query; accepted but not used
            inside this function.

    Returns:
        The started streaming query writing to CLEAN_OUTPUT_PATH.
    """
    # 1-2. Stream the raw Delta location and apply the cleaning rules
    raw_stream = spark.readStream.format("delta").load(RAW_OUTPUT_PATH)
    clean_df = clean_taxi_data(raw_stream)

    # 3. Append the cleaned rows to the clean Delta location every 10 seconds
    writer = (
        clean_df.writeStream.outputMode("append")
        .format("delta")
        .queryName("clean_taxi_query")
        .trigger(processingTime="10 second")
        .option("checkpointLocation", CLEAN_CHECKPOINT_PATH)
    )
    clean_query = writer.start(CLEAN_OUTPUT_PATH)

    # Give the stream a head start before looking for output files
    print("Waiting for clean data to be processed...")
    time.sleep(30)

    # 4. Register the table in the metastore
    create_table_if_exists(CLEAN_OUTPUT_PATH, CLEAN_TABLE_NAME)

    return clean_query


In [8]:
# Step 5: Verify tables
def verify_tables():
    """Print the table list, then the schema and a 5-row sample of the raw and clean tables."""
    print("Available tables:")
    spark.sql("SHOW TABLES").show()

    # (capitalised label, lower-case label, table name) for both pipeline stages
    stages = (
        ("Raw", "raw", RAW_TABLE_NAME),
        ("Clean", "clean", CLEAN_TABLE_NAME),
    )

    # Schemas first, for both tables ...
    for title, _, table in stages:
        print(f"\n{title} table schema:")
        spark.sql(f"DESCRIBE EXTENDED {table}").show(truncate=False)

    # ... then the samples, in the same raw -> clean order.
    for _, label, table in stages:
        print(f"\nSample data from {label} table:")
        spark.sql(f"SELECT * FROM {table} LIMIT 5").show()


## Pipeline

The data is moved from `csv -> kafka topic`. Because we produce the data ourselves in this step, it is written to Kafka as a single batch job rather than as a stream.

In [None]:
# Run the producer
run_producer()

Next, we consume the data from the Kafka topic as a stream and keep a handle to the resulting streaming query.

In [None]:
# Run the consumer
# Keep skip_table_creation=False when running from scratch so the raw table gets
# registered; pass True to skip the table-creation step (e.g. on a re-run where the
# table is already registered).
raw_query = run_consumer(skip_table_creation=False)

In [None]:
# Run the cleaner
clean_query = run_cleaner(raw_query)

In [None]:
# Verify tables
verify_tables()

In [9]:
# Print query status
print("\nRaw query status:")
print(raw_query.status)
print(f"Is active: {raw_query.isActive}")

print("\nClean query status:")
print(clean_query.status)
print(f"Is active: {clean_query.isActive}")

# raw_query.stop()
# clean_query.stop()

Waiting for raw data to be processed...
Data exists in /home/jovyan/spark-warehouse/raw_taxi_data
Created table raw_taxi_data using data at /home/jovyan/spark-warehouse/raw_taxi_data
Waiting for clean data to be processed...
Data exists in /home/jovyan/spark-warehouse/clean_taxi_data
Created table clean_taxi_data using data at /home/jovyan/spark-warehouse/clean_taxi_data
Available tables:
+---------+---------------+-----------+
|namespace|      tableName|isTemporary|
+---------+---------------+-----------+
|  default|clean_taxi_data|      false|
|  default|  raw_taxi_data|      false|
+---------+---------------+-----------+


Raw table schema:
+----------------------------+-----------------------------------+-------+
|col_name                    |data_type                          |comment|
+----------------------------+-----------------------------------+-------+
|kafka_key                   |string                             |NULL   |
|kafka_timestamp             |timestamp         

In [10]:
for query in spark.streams.active:
    print(f"Query name: {query.name}")
    print(f"Status: {query.status}")
    print(f"Is active: {query.isActive}")

Query name: clean_taxi_query
Status: {'message': 'Waiting for next trigger', 'isDataAvailable': False, 'isTriggerActive': False}
Is active: True


In [11]:
raw_query.stop()
clean_query.stop()

# EDA

In [None]:
We already have the tables persisted.

In [12]:
raw_df = spark.table(RAW_TABLE_NAME)

In [13]:
raw_df.count()

50388436

In [14]:
cleaned_df = spark.table(CLEAN_TABLE_NAME)

In [15]:
cleaned_df.count()

49208520