In [None]:
import time, os
from utils.generic import load_environment_variables
from utils.generate_data import generate_nyctaxistream_async
from utils.kafka import write_to_kafka_async,list_kafka_topics,delete_kafka_topic,create_kafka_topic,close_kafka_admin
from utils.spark import start_spark_session,read_from_kafka,read_spark_parquet,write_spark_stream
import asyncio
from delta.tables import DeltaTable
from dotenv import load_dotenv

# Load environment variables via the project helper (presumably from a .env file -- confirm in utils.generic).
config = load_environment_variables()


def _require_env(name):
    """Return the value of environment variable ``name``.

    Raises:
        EnvironmentError: if the variable is unset or empty, so a misconfigured
            environment fails here with a clear message instead of later with an
            obscure TypeError/AttributeError deep inside Spark/Delta calls.
    """
    value = os.getenv(name)
    if not value:
        raise EnvironmentError(f"Required environment variable {name!r} is not set")
    return value


topicname = "nyctaxistream"
kafka_bootstrap_servers = os.getenv("KAFKA_BOOTSTRAP_SERVERS")
delta_path = _require_env("DELTA_PATH_NYCTAXI")
write_folder_path = os.getenv("WRITE_FOLDER_PATH_NYCTAXI")
schema = os.getenv("SCHEMA_NYCTAXI")
# Comma-separated list; surrounding whitespace and empty entries are dropped so
# "a, b" yields ["a", "b"] and an unset variable yields [].
merge_keys = [
    key.strip()
    for key in os.getenv("MERGE_KEYS_NYCTAXI", "").split(",")
    if key.strip()
]
raw_file_path = os.path.join(_require_env("RAW_PATH_NYCTAXI"), _require_env("RAW_FILE_NAME"))


In [None]:
spark = start_spark_session()

# ===============================
#  Spark session tuning configs
# ===============================
# spark.conf.set() only affects runtime (SQL/session) settings. Core settings that
# are read when the SparkContext starts -- e.g. spark.default.parallelism,
# spark.network.timeout, spark.executor.heartbeatInterval, spark.rpc.message.maxSize --
# have no effect here and must be passed to the session builder instead
# (e.g. inside start_spark_session).

# --- Parallelism and partitioning ---
spark.conf.set("spark.sql.shuffle.partitions", "8")          # Reduce shuffle writers
spark.conf.set("spark.sql.files.maxPartitionBytes", "256m")  # Max bytes per input split when reading files (Spark default: 128m)

# --- Parquet / Delta write optimizations ---
spark.conf.set("spark.sql.parquet.compression.codec", "snappy")  # Lightweight compression
# The keys below are parquet-mr writer properties. Session confs are copied into the
# Hadoop configuration used for file writes (including Delta writes), so setting the
# plain parquet.* keys here takes effect. The previous spark.sql.parquet.enableDictionary /
# spark.sql.parquet.writer.maxBlockSize / spark.sql.parquet.writer.maxRowGroupSize keys
# are not Spark configs and were silently ignored.
spark.conf.set("parquet.enable.dictionary", "false")         # Disable dictionary encoding (reduces writer memory)
spark.conf.set("parquet.block.size", str(32 * 1024 * 1024))  # 32 MiB row groups (smaller in-memory buffers); value must be plain bytes

# # --- Optional Delta-specific safety ---
# spark.conf.set("spark.databricks.delta.optimizeWrite.enabled", "false")  # Turn off small-batch aggregation (uses memory)
# spark.conf.set("spark.databricks.delta.autoCompact.enabled", "false")    # Disable automatic compaction

# # --- Networking / timeout resiliency ---
# # NOTE: these are core configs -- set them on the session builder, not via spark.conf.set.
# spark.conf.set("spark.network.timeout", "600s")               # Prevent premature heartbeats
# spark.conf.set("spark.executor.heartbeatInterval", "60s")     # Give executors longer to report health
# spark.conf.set("spark.rpc.message.maxSize", "256")            # Increase for large commits

# # --- Arrow / Pandas interop (disable unless you need it) ---
# spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")

# # --- Miscellaneous ---
# spark.conf.set("spark.sql.adaptive.enabled", "true")          # Allow AQE to optimize join/shuffle
# spark.conf.set("spark.sql.execution.reuseSubquery", "true")   # Avoid recomputation in nested queries

In [None]:
# Bootstrap the Delta table: if nothing at delta_path is a Delta table yet, read the
# raw parquet file and write it out as a new Delta table. The status message is only
# printed after the check/creation actually succeeds (a failure propagates instead of
# being followed by a misleading "exists" message).
if DeltaTable.isDeltaTable(spark, delta_path):
    print(f"Delta table exists at {delta_path}")
else:
    df_raw = read_spark_parquet(spark, raw_file_path)
    df_raw.write.format("delta").mode("overwrite").save(delta_path)
    print(f"Delta table created at {delta_path}")

In [None]:
# Kafka source options -- adjust (e.g. startingOffsets) to experiment with replay behaviour.
# NOTE(review): kafka_bootstrap_servers is loaded in the config cell but not passed here;
# presumably read_from_kafka reads it itself -- confirm in utils.spark.
params_read_from_kafka = dict(
    spark=spark,
    topic=topicname,
    schema=schema,
    startingOffsets="earliest",
)

df_stream = read_from_kafka(**params_read_from_kafka)

In [None]:
# Delete rows whose dispatching_base_num is B03494 or B03496 so that we can see them
# re-ingested from the Kafka stream. This is destructive: it removes those rows from
# the Delta table at delta_path. (DeltaTable is imported in the first cell.)

# Load the Delta table
delta_table = DeltaTable.forPath(spark, delta_path)

# Delete rows where dispatching_base_num is B03494 or B03496
delta_table.delete("dispatching_base_num IN ('B03494', 'B03496')")

In [None]:
# Streaming sink configuration -- tweak trigger_mode / write_to to experiment.
params_write_spark_stream = {
    "parsed_df": df_stream,
    "spark": spark,
    "topic": "nycstream",
    "write_to": "memory",
    "trigger_mode": "once",
    "output_path": delta_path,
    "output_mode": "append",
    "merge_keys": merge_keys,
    "interval_seconds": 5,
}

query = write_spark_stream(**params_write_spark_stream)

is_continuous = params_write_spark_stream["trigger_mode"] == "continuous"

try:
    # Block until the query terminates. A bounded trigger ("once") returns after its
    # run; a continuous query only returns if it is stopped, so code placed after this
    # call is not reached in continuous mode.
    query.awaitTermination()
    # Show the sink contents once a bounded run has finished. The "nycstream" table is
    # assumed to be the memory sink's query name (taken from "topic") -- confirm in
    # write_spark_stream.
    if not is_continuous and params_write_spark_stream["write_to"] == "memory":
        spark.sql("select * from nycstream").show(10)
finally:
    # Continuous queries are left running (e.g. after a kernel interrupt) so later
    # cells can keep querying the sink; bounded runs are stopped explicitly.
    if not is_continuous:
        query.stop()

In [None]:
# Row count of the "nycstream" table queried here (presumably the in-memory sink registered by the streaming cell above -- confirm).
spark.sql("select count(*) from nycstream").show(10)

In [None]:
# Count rows in the Delta table at delta_path whose dispatching_base_num is B03494.
b03494_count = (
    spark.read.load(delta_path, format="delta")
    .where("dispatching_base_num = 'B03494'")
    .count()
)
b03494_count


In [None]:
# Load the Delta table (DeltaTable is imported in the first cell)
delta_table = DeltaTable.forPath(spark, delta_path)

# Rows for the two dispatching bases removed by the delete cell above
matching_rows = delta_table.toDF().filter("dispatching_base_num IN ('B03494', 'B03496')")

# Uncomment to inspect the matching rows
# matching_rows.show(truncate=False)

# Number of matching rows currently in the table
print("Count:", matching_rows.count())

In [None]:
#### TODO
# - The streaming DataFrame has to be re-created when Kafka is (re)started; no live viewing yet.
# - Continuous trigger mode: foreachBatch is not supported.
# - Experiment with different startingOffsets values.