In [62]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types

In [1]:
# Download the trip data
!mkdir data/raw/yellow/2024
!wget --directory=data/raw/yellow/2024/ https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-10.parquet

--2026-01-17 17:13:52--  https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-10.parquet
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 3.169.167.13, 3.169.167.152, 3.169.167.112, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)|3.169.167.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64346071 (61M) [binary/octet-stream]
Saving to: ‘data/raw/yellow/2024/yellow_tripdata_2024-10.parquet’


2026-01-17 17:13:54 (41.2 MB/s) - ‘data/raw/yellow/2024/yellow_tripdata_2024-10.parquet’ saved [64346071/64346071]



In [132]:
# Download the zone data
!wget --directory=data/raw/zones/ https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv

--2026-01-17 18:23:50--  https://d37ci6vzurychx.cloudfront.net/misc/taxi_zone_lookup.csv
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 3.169.167.222, 3.169.167.13, 3.169.167.152, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)|3.169.167.222|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12331 (12K) [text/csv]
Saving to: ‘data/raw/zones/taxi_zone_lookup.csv’


2026-01-17 18:23:50 (653 MB/s) - ‘data/raw/zones/taxi_zone_lookup.csv’ saved [12331/12331]



In [None]:
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("hw") \
    .getOrCreate()

In [4]:
spark.version

'4.1.1'

In [175]:
df = spark.read.parquet("data/raw/yellow/2024/yellow_tripdata_2024-10.parquet")
df.head(3)

[Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2024, 10, 1, 0, 30, 44), tpep_dropoff_datetime=datetime.datetime(2024, 10, 1, 0, 48, 26), passenger_count=1, trip_distance=3.0, RatecodeID=1, store_and_fwd_flag='N', PULocationID=162, DOLocationID=246, payment_type=1, fare_amount=18.4, extra=1.0, mta_tax=0.5, tip_amount=1.5, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=24.9, congestion_surcharge=2.5, Airport_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2024, 10, 1, 0, 12, 20), tpep_dropoff_datetime=datetime.datetime(2024, 10, 1, 0, 25, 25), passenger_count=1, trip_distance=2.2, RatecodeID=1, store_and_fwd_flag='N', PULocationID=48, DOLocationID=236, payment_type=1, fare_amount=14.2, extra=3.5, mta_tax=0.5, tip_amount=3.8, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=23.0, congestion_surcharge=2.5, Airport_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2024, 10, 1, 0, 4, 46), tpep_dropoff_datetime=datetime.datetime(2

In [9]:
df.repartition(4) \
    .write.parquet("data/pq/yellow/2024/")

                                                                                

In [165]:
!ls -lh data/pq/yellow/2024

total 188928
-rw-r--r--  1 sashkawarner  staff     0B Jan 17 17:19 _SUCCESS
-rw-r--r--  1 sashkawarner  staff    22M Jan 17 17:19 part-00000-1265ba44-eefc-48a3-945f-1832d6902bb5-c000.snappy.parquet
-rw-r--r--  1 sashkawarner  staff    22M Jan 17 17:19 part-00001-1265ba44-eefc-48a3-945f-1832d6902bb5-c000.snappy.parquet
-rw-r--r--  1 sashkawarner  staff    22M Jan 17 17:19 part-00002-1265ba44-eefc-48a3-945f-1832d6902bb5-c000.snappy.parquet
-rw-r--r--  1 sashkawarner  staff    22M Jan 17 17:19 part-00003-1265ba44-eefc-48a3-945f-1832d6902bb5-c000.snappy.parquet


In [176]:
df.columns

['VendorID',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'RatecodeID',
 'store_and_fwd_flag',
 'PULocationID',
 'DOLocationID',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'Airport_fee']

In [177]:
# Format pickup date and count # of trips for provided date filter
df = df.withColumn("pickup_date", F.date_format("tpep_pickup_datetime", "yyyy-MM-dd"))
df.filter(df.pickup_date == '2024-10-15') \
    .count()

128893

In [178]:
# Same as above but using SQL
df.createOrReplaceTempView("trips")
spark.sql("""
    SELECT
        COUNT(*)
    FROM trips
    WHERE pickup_date = '2024-10-15';
""").show()

+--------+
|count(1)|
+--------+
|  128893|
+--------+



In [179]:
# Convert dates to TimestampType
df = df.withColumn("tpep_pickup_datetime", F.to_timestamp(df.tpep_pickup_datetime)) \
    .withColumn("tpep_dropoff_datetime", F.to_timestamp(df.tpep_dropoff_datetime))
df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- pickup_date: string (nullable = true)



In [181]:
# Calculate trip duration
df = df.withColumn("trip_duration", (df.tpep_dropoff_datetime.cast("long") - df.tpep_pickup_datetime.cast("long")) / 3600)
time_cols = ["tpep_pickup_datetime", "tpep_dropoff_datetime", "trip_duration"]
df.select(time_cols) \
    .show(3)

+--------------------+---------------------+-------------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|      trip_duration|
+--------------------+---------------------+-------------------+
| 2024-10-01 00:30:44|  2024-10-01 00:48:26|              0.295|
| 2024-10-01 00:12:20|  2024-10-01 00:25:25|0.21805555555555556|
| 2024-10-01 00:04:46|  2024-10-01 00:13:52|0.15166666666666667|
+--------------------+---------------------+-------------------+
only showing top 3 rows


In [182]:
df.sort(F.desc("trip_duration")) \
    .select(time_cols) \
    .show(3)



+--------------------+---------------------+------------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|     trip_duration|
+--------------------+---------------------+------------------+
| 2024-10-16 13:03:49|  2024-10-23 07:40:53|162.61777777777777|
| 2024-10-03 18:47:25|  2024-10-09 18:06:55|           143.325|
| 2024-10-22 16:00:55|  2024-10-28 09:46:33|137.76055555555556|
+--------------------+---------------------+------------------+
only showing top 3 rows


                                                                                

In [183]:
df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- pickup_date: string (nullable = true)
 |-- trip_duration: double (nullable = true)



In [184]:
zone_schema = types.StructType([
    types.StructField(name="LocationID", dataType=types.IntegerType(), nullable=False),
    types.StructField(name="Borough", dataType=types.StringType(), nullable=False),
    types.StructField(name="Zone", dataType=types.StringType(), nullable=False),
    types.StructField(name="service_zone", dataType=types.StringType(), nullable=False),
])
zones = spark.read \
    .option("header", "true") \
    .schema(zone_schema) \
    .csv("data/raw/zones/")

zones.printSchema()
zones.show()

root
 |-- LocationID: integer (nullable = true)
 |-- Borough: string (nullable = true)
 |-- Zone: string (nullable = true)
 |-- service_zone: string (nullable = true)

+----------+-------------+--------------------+------------+
|LocationID|      Borough|                Zone|service_zone|
+----------+-------------+--------------------+------------+
|         1|          EWR|      Newark Airport|         EWR|
|         2|       Queens|         Jamaica Bay|   Boro Zone|
|         3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|         4|    Manhattan|       Alphabet City| Yellow Zone|
|         5|Staten Island|       Arden Heights|   Boro Zone|
|         6|Staten Island|Arrochar/Fort Wad...|   Boro Zone|
|         7|       Queens|             Astoria|   Boro Zone|
|         8|       Queens|        Astoria Park|   Boro Zone|
|         9|       Queens|          Auburndale|   Boro Zone|
|        10|       Queens|        Baisley Park|   Boro Zone|
|        11|     Brooklyn|          Bat

In [187]:
# Update the trips view
df.createOrReplaceTempView("trips")

In [188]:
# Load zones into a temp view
zones.createOrReplaceTempView("zones")

In [198]:
spark.sql("""
    SELECT
        Zone,
        COUNT(*)
    FROM trips AS t
    LEFT JOIN zones AS z
        ON t.PULocationID = z.LocationID
    GROUP BY z.Zone
    ORDER BY COUNT(*);
""").show(3, truncate=False)

+---------------------------------------------+--------+
|Zone                                         |count(1)|
+---------------------------------------------+--------+
|Governor's Island/Ellis Island/Liberty Island|1       |
|Rikers Island                                |2       |
|Arden Heights                                |2       |
+---------------------------------------------+--------+
only showing top 3 rows


In [199]:
# Same as above, but using DataFrames
df_join = df.join(other=zones, on=df.PULocationID == zones.LocationID, how="left")
join_cols = ["pickup_date", "PULocationID", "LocationID", "Zone", "trip_distance", "trip_duration"]
# df_join.select(join_cols) \
#     .show()
df_join.groupBy("Zone") \
    .count() \
    .sort("count") \
    .show(3, truncate=False)

+---------------------------------------------+-----+
|Zone                                         |count|
+---------------------------------------------+-----+
|Governor's Island/Ellis Island/Liberty Island|1    |
|Rikers Island                                |2    |
|Arden Heights                                |2    |
+---------------------------------------------+-----+
only showing top 3 rows
