## Setup

In [0]:
import pandas as pd
import time
from pyspark.sql.functions import col, hour, from_unixtime, count, max, avg, sum as spark_sum, desc, when

# Databricks pre-creates the `spark` session; echo its version as a quick sanity check.
spark_version = spark.version
print(f"✅ Spark version: {spark_version}")
print("✅ Ready to go!")

✅ Spark version: 4.0.0
✅ Ready to go!


I'm using the NYC yellow taxi trip data for November 2019 (`yellow_tripdata_2019-11.csv.gz`). Here is its schema:
- VendorID: integer (nullable = true)
- tpep_pickup_datetime: timestamp (nullable = true)
- tpep_dropoff_datetime: timestamp (nullable = true)
- passenger_count: integer (nullable = true)
- trip_distance: double (nullable = true)
- RatecodeID: integer (nullable = true)
- store_and_fwd_flag: string (nullable = true)
- PULocationID: integer (nullable = true)
- DOLocationID: integer (nullable = true)
- payment_type: integer (nullable = true)
- fare_amount: double (nullable = true)
- extra: double (nullable = true)
- mta_tax: double (nullable = true)
- tip_amount: double (nullable = true)
- tolls_amount: double (nullable = true)
- improvement_surcharge: double (nullable = true)
- total_amount: double (nullable = true)
- congestion_surcharge: double (nullable = true)


In [0]:
# Path to the gzipped CSV (one month of yellow-cab trips, November 2019)
taxi_file = "dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2019-11.csv.gz"

# Explicit schema, in the same column order as the file. Supplying it instead of
# inferSchema=True avoids an extra full pass over the data just to infer types.
# That pass is especially costly here: a .gz file is not splittable, so it is
# decompressed by a single task.
TAXI_SCHEMA = """
    VendorID INT,
    tpep_pickup_datetime TIMESTAMP,
    tpep_dropoff_datetime TIMESTAMP,
    passenger_count INT,
    trip_distance DOUBLE,
    RatecodeID INT,
    store_and_fwd_flag STRING,
    PULocationID INT,
    DOLocationID INT,
    payment_type INT,
    fare_amount DOUBLE,
    extra DOUBLE,
    mta_tax DOUBLE,
    tip_amount DOUBLE,
    tolls_amount DOUBLE,
    improvement_surcharge DOUBLE,
    total_amount DOUBLE,
    congestion_surcharge DOUBLE
"""

# Read with Spark (lazy until an action such as count() runs)
taxi_df = spark.read.csv(
    taxi_file,
    header=True,          # first line holds column names; skip it as data
    schema=TAXI_SCHEMA,   # no type-inference pass needed
)

taxi_df.count()

6878111

This dataset has about 6.9 million rows. Let's union three copies of it to get an even bigger dataset of about 20.6 million rows.

In [0]:
# Build a larger DataFrame by unioning N_COPIES copies of the one-month dataset.
# union() is a lazy transformation: no data is read here. The source is only
# scanned when an action (e.g. count()) runs on taxi_df_large.
N_COPIES = 3

print(f"Creating large dataset ({N_COPIES}x duplication)...")

# Always start from taxi_df so re-running this cell does not keep growing the result.
taxi_df_large = taxi_df
for _ in range(N_COPIES - 1):
    taxi_df_large = taxi_df_large.union(taxi_df)

print(f"Built a lazy union of {N_COPIES} copies of taxi_df (row count is checked in the next cell)")

Creating large dataset (3x duplication)...
Created ~1 GB dataset with 20M rows


In [0]:
# Action: counts rows across every unioned copy (this triggers the actual read).
large_row_count = taxi_df_large.count()
large_row_count

20634333

In [0]:
# Show column names and types of the source DataFrame; taxi_df_large is a union
# of copies of taxi_df, so it shares this schema.
taxi_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



# Applying 2+ Filters

In [0]:
# Keep trips with a positive fare, a positive distance and at least one passenger.
# Rows with a null in any of these columns are dropped too, because a comparison
# against null is never true.
taxi_filtered = (
    taxi_df_large
    .filter(col("fare_amount") > 0)
    .filter(col("trip_distance") > 0)
    .filter(col("passenger_count") > 0)
)

taxi_filtered.count()

19853730

# Transformation

In [0]:
# Derive per-trip features from the filtered trips:
#   tip_percentage - tip as a percentage of the metered fare
#                    (fare_amount > 0 is guaranteed by the filter cell above)
#   pickup_hour    - hour of day of the pickup, in the Spark session time zone
taxi_transformed = (
    taxi_filtered
    .withColumn("tip_percentage", col("tip_amount") / col("fare_amount") * 100)
    # tpep_pickup_datetime is already a timestamp column, so hour() reads it directly.
    # There is no need to round-trip through epoch seconds and a formatted string
    # with cast("long") + from_unixtime().
    .withColumn("pickup_hour", hour(col("tpep_pickup_datetime")))
)

display(taxi_transformed.limit(10))

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,tip_percentage,pickup_hour
1,2019-11-01T00:35:48.000Z,2019-11-01T01:13:12.000Z,1,8.5,1,N,138,161,1,32.0,3.0,0.5,8.35,6.12,0.3,50.27,2.5,26.09375,0
1,2019-11-01T00:02:59.000Z,2019-11-01T00:05:26.000Z,1,0.3,1,N,239,142,2,3.5,3.0,0.5,0.0,0.0,0.3,7.3,2.5,0.0,0
1,2019-11-01T00:11:43.000Z,2019-11-01T00:19:43.000Z,1,1.9,1,N,142,262,2,8.5,3.0,0.5,0.0,0.0,0.3,12.3,2.5,0.0,0
1,2019-11-01T00:11:34.000Z,2019-11-01T00:18:02.000Z,1,0.5,1,N,246,186,2,5.5,3.0,0.5,0.0,0.0,0.3,9.3,2.5,0.0,0
1,2019-11-01T00:54:04.000Z,2019-11-01T01:05:24.000Z,2,2.4,1,N,246,239,1,10.0,3.0,0.5,2.75,0.0,0.3,16.55,2.5,27.500000000000004,0
1,2019-11-01T00:33:59.000Z,2019-11-01T00:39:56.000Z,1,1.6,1,N,50,238,1,7.0,3.0,0.5,2.15,0.0,0.3,12.95,2.5,30.71428571428571,0
1,2019-11-01T00:52:02.000Z,2019-11-01T01:35:35.000Z,1,9.3,1,N,48,61,1,33.5,3.0,0.5,7.45,0.0,0.3,44.75,2.5,22.23880597014925,0
1,2019-11-01T00:15:08.000Z,2019-11-01T00:19:25.000Z,1,0.6,1,N,142,239,2,5.0,3.0,0.5,0.0,0.0,0.3,8.8,2.5,0.0,0
1,2019-11-01T00:23:14.000Z,2019-11-01T00:36:08.000Z,1,1.9,1,N,142,186,1,11.0,3.0,0.5,2.95,0.0,0.3,17.75,2.5,26.81818181818182,0
1,2019-11-01T00:47:26.000Z,2019-11-01T00:59:08.000Z,1,2.1,1,N,163,170,1,10.0,3.0,0.5,2.75,0.0,0.3,16.55,2.5,27.500000000000004,0


# Complex Aggregation

In [0]:

# Per pickup hour: average fare, average tip percentage, maximum fare,
# total fare and number of trips, ordered by hour.
# No casts are needed: fare_amount is already a double (see the schema above),
# and tip_percentage is double / double * 100, which is also a double.
# Note: `max`, `count` and `avg` here are the pyspark.sql.functions versions
# imported in the setup cell (they shadow the Python builtins `max`/`count`).
hourly_complex_agg = (
    taxi_transformed
    .groupBy("pickup_hour")
    .agg(
        avg("fare_amount").alias("avg_fare"),
        avg("tip_percentage").alias("avg_tip_pct"),
        max("fare_amount").alias("max_fare"),
        spark_sum("fare_amount").alias("total_fare"),
        count("*").alias("total_trips"),
    )
    .orderBy("pickup_hour")
)

display(hourly_complex_agg.limit(10))

pickup_hour,avg_fare,avg_tip_pct,max_fare,total_fare,total_trips
0,13.302613736319934,26.924299041446247,442.0,7803579.27,586620
1,12.225052927227626,19.12361856931616,318.0,5366786.010000001,438999
2,11.954951097452083,18.91125985876737,550.5,3454287.4800000004,288942
3,12.772131607434174,17.686672075534933,300.0,2513134.02,196767
4,15.358676019996771,15.751605829538262,400.0,2285739.6,148824
5,17.390758171304793,16.638180383578916,455.0,2986497.5100000007,171729
6,13.604029999675504,17.58281081420658,425.0,5030824.71,369804
7,12.312448536075207,22.89384405970244,512.0,8600232.989999996,698499
8,12.47170109599748,19.27588696097258,386.0,10818315.209999995,867429
9,12.539675225075024,18.71763634282345,495.5,11281945.799999995,899700


# Write results 

In [0]:
# Select the hours whose average tip exceeds 20% of the fare. This is the subset
# that would be persisted; nothing is written to storage in this cell.
result_to_save = hourly_complex_agg.where(col('avg_tip_pct') > 20)

rows_to_save = result_to_save.count()
print(f"Preview of what would be saved: {rows_to_save:,} rows")
result_to_save.show(5)


Preview of what would be saved: 4 rows
+-----------+------------------+------------------+--------+--------------------+-----------+
|pickup_hour|          avg_fare|       avg_tip_pct|max_fare|          total_fare|total_trips|
+-----------+------------------+------------------+--------+--------------------+-----------+
|          0|13.302613736319934|26.924299041446247|   442.0|          7803579.27|     586620|
|          7|12.312448536075207|22.893844059702438|   512.0|   8600232.989999996|     698499|
|         10|12.690002187859948|23.598701732766017|   756.0|1.1832386459999997E7|     932418|
|         23|13.314039225476728|21.606143839976152|   616.0|1.0879194359999996E7|     817122|
+-----------+------------------+------------------+--------+--------------------+-----------+



# SQL Queries

In [0]:
# Expose the transformed trips to Spark SQL under the name "taxi".
taxi_transformed.createOrReplaceTempView("taxi")

# Query 1: the 5 pickup hours with the highest average fare.
top_hours_by_avg_fare_sql = """
SELECT pickup_hour, AVG(fare_amount) AS avg_fare
FROM taxi
GROUP BY pickup_hour
ORDER BY avg_fare DESC
LIMIT 5
"""

# Query 2: the 5 pickup hours with the most trips.
top_hours_by_trip_count_sql = """
SELECT pickup_hour, COUNT(*) AS num_trips
FROM taxi
GROUP BY pickup_hour
ORDER BY num_trips DESC
LIMIT 5
"""

# Run each query in order and print its result table.
for sql_text in (top_hours_by_avg_fare_sql, top_hours_by_trip_count_sql):
    spark.sql(sql_text).show()


+-----------+------------------+
|pickup_hour|          avg_fare|
+-----------+------------------+
|          5|17.390758171304793|
|          4|15.358676019996771|
|         16|14.146448771455097|
|         14| 14.10013912670659|
|         15|14.020786022254226|
+-----------+------------------+

+-----------+---------+
|pickup_hour|num_trips|
+-----------+---------+
|         18|  1308624|
|         19|  1233921|
|         17|  1190640|
|         15|  1128594|
|         21|  1122327|
+-----------+---------+



# Performance Analysis

In [0]:
# Print the physical plan Spark will use for the hourly aggregation
# ("simple" is the default explain mode: physical plan only).
print("Execution plan:")
hourly_complex_agg.explain(mode="simple")

Execution plan:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [pickup_hour#12398 ASC NULLS FIRST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#13243]
               +- PhotonShuffleExchangeSink rangepartitioning(pickup_hour#12398 ASC NULLS FIRST, 1024)
                  +- PhotonGroupingAgg(keys=[pickup_hour#12398], functions=[finalmerge_sum(merge sum#12878) AS sum(fare_amount)#12873, finalmerge_count(merge count#12880L) AS count(fare_amount)#12874L, finalmerge_avg(merge sum#12883, count#12884L) AS avg(tip_percentage)#12868, finalmerge_max(merge max#12886) AS max(fare_amount)#12869, finalmerge_count(merge count#12888L) AS count(1)#12866L])
                     +- PhotonShuffleExchangeSource
                        +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#13237]
                           +- PhotonShuffleExchangeSink hashpartiti


### Query Optimization and Filter Pushdown

Spark and Photon (Databricks' vectorized engine) aggressively optimized this pipeline using several key techniques. The most significant is the full utilization of **Photon**, indicated by the `Photon*` operators throughout the plan, which provides vectorized processing for substantial speedups. 

The plan also applies **column pruning**: only the columns the query needs (`tpep_pickup_datetime`, `passenger_count`, `trip_distance`, `fare_amount`, `tip_amount`) are requested from the scan, and the rest are never materialized. Because CSV is a row-oriented text format, every byte of the file is still read and decompressed; pruning saves the cost of parsing and converting the unused columns, not I/O. The explicit filters (`isnotnull`, `passenger_count > 0`, `fare_amount > 0.0`, `trip_distance > 0.0`) are likewise **pushed down** (predicate pushdown) into the `FileScan csv` operator.

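This filtering happens *during the scan*, before rows enter the main processing stages (`PhotonRowToColumnar`). Rows that fail the predicates are discarded as early as possible and never reach the aggregation or the shuffle. For a CSV source this does not reduce the bytes read from storage. That would require a columnar format such as Parquet or Delta, where pushed-down filters can skip whole row groups or files.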

### Performance Bottlenecks and Pipeline Strategy

The query utilized a common and highly effective pipeline optimization strategy: **Two-Stage Aggregation**. The `PhotonGroupingAgg` is split into a partial phase (calculating `partial_sum`, `partial_count`, etc.) followed by a final merge phase (`finalmerge_sum`, `finalmerge_count`). This partial aggregation is performed on the worker nodes *before* the first major data movement, or **Shuffle** (`PhotonShuffleExchangeSink`). 

The Shuffle is always a potential performance bottleneck, as it requires sending data across the network, but by aggregating locally first, Spark significantly reduces the amount of data that needs to be shuffled, minimizing this cost. A second Shuffle (`rangepartitioning`) is required later to perform the final sort (`PhotonSort`) for the `ORDER BY` clause of the query. The filters were placed right after the file scan and before any complex calculations (like generating `tip_percentage`), which is the optimal ordering for column pruning and filtering.

# Action vs Transformations

In [0]:
# `time` is imported in the setup cell; perf_counter() is the clock intended for
# measuring elapsed durations (monotonic, high resolution).

# Transformations - lazy (just build a plan)
start = time.perf_counter()
filtered = taxi_df_large.filter(col('passenger_count') > 2)
selected = filtered.select('fare_amount', 'trip_distance', 'passenger_count')
print(f"Transformations: {time.perf_counter() - start:.4f}s")

# Actions - eager (trigger execution)
print("\nAction 1:")
start = time.perf_counter()
# Named row_count rather than `count`: the setup cell imports
# pyspark.sql.functions.count, and the aggregation cell calls count("*").
# Rebinding `count` to an int here would break that cell on re-run.
row_count = selected.count()
print(f"count() = {row_count} rows, took {time.perf_counter() - start:.2f}s")

print("\nAction 2:")
start = time.perf_counter()
# show(5) likely finishes much faster than count(): it only needs the first
# few matching rows, not a full scan.
selected.show(5)
print(f"show() took {time.perf_counter() - start:.2f}s")

print("\nNotice: Each action re-executes the transformations!")

Transformations: 0.0004s

Action 1:
count() = 2503404 rows, took 8.86s

Action 2:
+-----------+-------------+---------------+
|fare_amount|trip_distance|passenger_count|
+-----------+-------------+---------------+
|        7.5|          1.1|              3|
|        6.5|          0.9|              3|
|       17.5|         4.15|              4|
|       15.5|          2.5|              4|
|        6.5|          0.9|              4|
+-----------+-------------+---------------+
only showing top 5 rows
show() took 0.33s

Notice: Each action re-executes the transformations!
