# PySpark on Docker Compose
Distributed tabular ETL with Apache Spark — petabyte-scale joins with fault tolerance and Adaptive Query Execution (AQE).

## Setup

Start the Spark stack and launch Jupyter:

```bash
# 1. Build images
docker compose build

# 2. Start MinIO + Spark + App
docker compose up -d minio minio-setup spark-master
docker compose up -d --scale spark-worker=1 spark-worker app

# 3. Upload sample data
./scripts/upload-data.sh

# 4. Launch Jupyter Lab
docker compose exec app jupyter lab --ip 0.0.0.0 --port 8888 --allow-root --no-browser --notebook-dir=/app/notebook
```

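Then open http://localhost:8888 in your browser. If Jupyter asks for a token, use the URL with `?token=...` that the `jupyter lab` command prints to the terminal.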

## What is Spark?

Apache Spark is one of the most widely used engines for **petabyte-scale tabular ETL**. Key concepts:

- **SparkSession** — entry point to all Spark functionality
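- **DataFrames** — distributed collections of rows with named columns and a SQL-like API
- **Lazy evaluation** — transformations build a DAG; actions trigger execution
- **Adaptive Query Execution (AQE)** — re-optimizes the query plan at runtime using shuffle statistics (coalescing small partitions, switching to broadcast joins, splitting skewed partitions)
- **Fault tolerance** — lost partitions are recomputed from their lineage, the recorded chain of transformations that produced them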

## Architecture

```
Driver (app container) → Executor (spark-worker) → MinIO (S3 storage)
```

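The driver (the notebook kernel in the `app` container) builds the query plan and schedules tasks, the standalone master (`spark-master`) allocates executors on the workers, and the executors in the `spark-worker` containers run those tasks in parallel. Input and output data live in MinIO, which Spark reads and writes through the S3A connector as the shared storage layer.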

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = (
    SparkSession.builder.appName("Notebook_Spark_ETL")
    # Standalone master started by the spark-master service
    .master("spark://spark-master:7077")
    # S3A connector settings for MinIO (default MinIO credentials, local use only)
    .config("spark.hadoop.fs.s3a.endpoint", "http://minio:9000")
    .config("spark.hadoop.fs.s3a.access.key", "minioadmin")
    .config("spark.hadoop.fs.s3a.secret.key", "minioadmin")
    # MinIO addresses buckets as URL paths (http://minio:9000/lake), not subdomains
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    # Adaptive Query Execution (already on by default since Spark 3.2)
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)
# Port 8080 serves the standalone master's web UI
print("Spark UI: http://localhost:8080")
print(f"Connected to: {spark.sparkContext.master}")
```

```
Spark UI: http://localhost:8080
Connected to: spark://spark-master:7077
```
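
Hard-coding `minioadmin` is acceptable for a throwaway local stack, but credentials can be kept out of the notebook by reading them from environment variables. The sketch below is plain Python, so it can be tested without a cluster; the variable names `S3_ENDPOINT`, `S3_ACCESS_KEY`, and `S3_SECRET_KEY` and the helpers `s3a_settings` and `apply_settings` are hypothetical, and the variables would need to be exported into the `app` container.

```python
import os
from typing import Mapping, Optional


def s3a_settings(env: Optional[Mapping[str, str]] = None) -> dict:
    """Return the notebook's S3A options with credentials taken from (hypothetical) env vars.

    A missing S3_ACCESS_KEY or S3_SECRET_KEY raises KeyError rather than
    silently falling back to a default credential.
    """
    env = os.environ if env is None else env
    return {
        "spark.hadoop.fs.s3a.endpoint": env.get("S3_ENDPOINT", "http://minio:9000"),
        "spark.hadoop.fs.s3a.access.key": env["S3_ACCESS_KEY"],
        "spark.hadoop.fs.s3a.secret.key": env["S3_SECRET_KEY"],
        "spark.hadoop.fs.s3a.path.style.access": "true",
        "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
    }


def apply_settings(builder, settings: Mapping[str, str]):
    """Chain one .config(key, value) call per setting onto a SparkSession builder."""
    for key, value in settings.items():
        builder = builder.config(key, value)
    return builder
```

In the notebook, `apply_settings(SparkSession.builder.appName("Notebook_Spark_ETL").master("spark://spark-master:7077"), s3a_settings())` would replace the five hard-coded S3A `.config(...)` calls before `.getOrCreate()`.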


## Explore Data

```python
trips = spark.read.parquet("s3a://lake/taxi/yellow_tripdata_2024-01.parquet")
zones = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv("s3a://lake/taxi/taxi_zone_lookup.csv")
)

print(f"Trips: {trips.count():,} rows")
trips.printSchema()
trips.show(5)

print(f"\nZones: {zones.count()} rows")
zones.show(5)
```


```
Trips: 2,964,624 rows
root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)

+--------+--------------------+---------------------+---------------+-------------+----------+---------
```

## Filter Operations

```python
# High-value trips: fare > $10 and distance > 5 miles
filtered = trips.filter((F.col("fare_amount") > 10.0) & (F.col("trip_distance") > 5.0))
print(f"High-value trips: {filtered.count():,} (out of {trips.count():,})")
filtered.select("trip_distance", "fare_amount", "tip_amount", "total_amount").show(10)
```

```
High-value trips: 455,098 (out of 2,964,624)
+-------------+-----------+----------+------------+
|trip_distance|fare_amount|tip_amount|total_amount|
+-------------+-----------+----------+------------+
|        10.82|       45.7|      10.0|       64.95|
|         5.44|       31.0|       0.0|        36.0|
|          8.2|       59.0|     14.15|       85.09|
|         23.9|      120.0|       0.0|      127.94|
|         5.88|       28.9|       2.5|        36.4|
|          5.1|       28.9|       0.0|        33.9|
|         8.89|       47.8|      7.92|       60.72|
|        11.51|       44.3|     11.25|       67.49|
|         5.28|       31.0|       7.2|        43.2|
|        11.48|       47.8|     10.56|       63.36|
+-------------+-----------+----------+------------+
only showing top 10 rows
```


## GroupBy + Aggregation

```python
# Revenue breakdown by payment type
payment_stats = (
    trips.groupBy("payment_type")
    .agg(
        F.count("*").alias("trip_count"),
        F.sum("total_amount").alias("total_revenue"),
        F.avg("tip_amount").alias("avg_tip"),
        F.avg("trip_distance").alias("avg_distance"),
    )
    .orderBy(F.desc("total_revenue"))
)
payment_stats.show()
```

```
+------------+----------+-------------------+--------------------+------------------+
|payment_type|trip_count|      total_revenue|             avg_tip|      avg_distance|
+------------+----------+-------------------+--------------------+------------------+
|           1|   2319046|6.553359931006052E7|   4.169670627490679| 3.264648614128438|
|           2|    439191|1.005066921999955E7|0.002296016994883775| 3.259125551297723|
|           0|    140162| 3617824.6300004106|  1.5459567500463327|11.674403475977822|
|           3|     19597| 171581.03999999986|0.014559881614532833|2.1592182476909856|
|           4|     46628|           82710.08| 0.04212511795487692| 3.140549455262926|
+------------+----------+-------------------+--------------------+------------------+
```



## Join — Enrich with Zone Names

```python
# Join trips with zone lookup to get human-readable pickup locations
enriched = trips.join(zones, trips["PULocationID"] == zones["LocationID"], "inner")
enriched.select("Borough", "Zone", "fare_amount", "trip_distance", "tip_amount").show(10)
```

```
+---------+--------------------+-----------+-------------+----------+
|  Borough|                Zone|fare_amount|trip_distance|tip_amount|
+---------+--------------------+-----------+-------------+----------+
|Manhattan|Penn Station/Madi...|       17.7|         1.72|       0.0|
|Manhattan|     Lenox Hill East|       10.0|          1.8|      3.75|
|Manhattan|Upper East Side N...|       23.3|          4.7|       3.0|
|Manhattan|        East Village|       10.0|          1.4|       2.0|
|Manhattan|                SoHo|        7.9|          0.8|       3.2|
|Manhattan|     Lower East Side|       29.6|          4.7|       6.9|
|   Queens|   LaGuardia Airport|       45.7|        10.82|      10.0|
|Manhattan|West Chelsea/Huds...|       25.4|          3.0|       0.0|
|Manhattan|      Midtown Center|       31.0|         5.44|       0.0|
|Manhattan|Greenwich Village...|        3.0|         0.04|       0.0|
+---------+--------------------+-----------+-------------+----------+
only showing top 10 rows
```

## Window Functions + Ranking

```python
# Top zones per borough by revenue
window = Window.partitionBy("Borough").orderBy(F.desc("revenue"))

zone_revenue = (
    enriched.groupBy("Borough", "Zone")
    .agg(
        F.sum("total_amount").alias("revenue"),
        F.count("*").alias("trips"),
        F.avg("tip_amount").alias("avg_tip"),
    )
    .withColumn("rank", F.row_number().over(window))
    .filter(F.col("rank") <= 3)
    .orderBy("Borough", "rank")
)
zone_revenue.show(20, truncate=False)
```

```
+-------------+--------------------------------+--------------------+------+-------------------+----+
|Borough      |Zone                            |revenue             |trips |avg_tip            |rank|
+-------------+--------------------------------+--------------------+------+-------------------+----+
|Bronx        |Mott Haven/Port Morris          |19992.030000000002  |663   |0.5844042232277527 |1   |
|Bronx        |Co-Op City                      |18409.319999999992  |408   |0.16073529411764706|2   |
|Bronx        |East Concourse/Concourse Village|10710.930000000002  |340   |0.29167647058823526|3   |
|Brooklyn     |Downtown Brooklyn/MetroTech     |45331.860000000015  |1393  |2.946956209619527  |1   |
|Brooklyn     |Brooklyn Heights                |40401.74000000001   |1257  |3.329649960222753  |2   |
|Brooklyn     |Crown Heights North             |32585.66000000001   |1046  |0.8158891013384322 |3   |
|EWR          |Newark Airport                  |30738.030000000006  |295   |12.649
```

## Full ETL Pipeline

Filter → Join → Aggregate → Rank → Write to S3

```python
window = Window.partitionBy("Borough").orderBy(F.desc("revenue"))

report = (
    # Trips over $10 with a positive distance (looser than the > 5 mile filter above)
    trips.filter((F.col("fare_amount") > 10.0) & (F.col("trip_distance") > 0))
    .join(zones, trips["PULocationID"] == zones["LocationID"], "inner")
    .groupBy("Borough", "Zone")
    .agg(
        F.sum("total_amount").alias("revenue"),
        F.count("*").alias("trips"),
        F.avg("tip_amount").alias("avg_tip"),
        F.avg("trip_distance").alias("avg_distance"),
    )
    .withColumn("rank", F.row_number().over(window))
    .orderBy("Borough", "rank")
)

# One Parquet directory per borough (Borough=.../) under the warehouse bucket
report.write.partitionBy("Borough").mode("overwrite").parquet(
    "s3a://warehouse/notebook_report/"
)
# count() is a second action: without caching, Spark recomputes the pipeline
print(f"Wrote {report.count():,} rows to s3a://warehouse/notebook_report/")
```

```
Wrote 254 rows to s3a://warehouse/notebook_report/
```


## Read Back Results

```python
saved = spark.read.parquet("s3a://warehouse/notebook_report/")
saved.filter(F.col("rank") <= 3).orderBy("Borough", "rank").show(20, truncate=False)
```

```
+--------------------------------+--------------------+------+-------------------+------------------+----+-------------+
|Zone                            |revenue             |trips |avg_tip            |avg_distance      |rank|Borough      |
+--------------------------------+--------------------+------+-------------------+------------------+----+-------------+
|Co-Op City                      |15927.419999999998  |360   |0.1515             |10.53916666666667 |1   |Bronx        |
|Mott Haven/Port Morris          |14461.579999999998  |383   |0.6086684073107049 |7.175587467362926 |2   |Bronx        |
|East Concourse/Concourse Village|9430.170000000002   |289   |0.13951557093425607|6.745051903114186 |3   |Bronx        |
|Downtown Brooklyn/MetroTech     |40539.560000000005  |1099  |3.2706278434940863 |5.977333939945408 |1   |Brooklyn     |
|Brooklyn Heights                |36904.32            |995   |3.6088341708542715 |78.14236180904523 |2   |Brooklyn     |
|East New York                  
```

## Cleanup

```python
spark.stop()
print("SparkSession stopped.")
```

```
SparkSession stopped.
```
