In [1]:
from pyspark.sql import SparkSession
# Create the Spark session for this notebook, or reuse one that is already
# running under the same builder configuration.
spark = (
    SparkSession.builder
    .appName("example")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/27 12:51:42 WARN Utils: Your hostname, tb24-workstation, resolves to a loopback address: 127.0.1.1; using 192.168.1.61 instead (on interface enp7s0)
26/01/27 12:51:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/27 12:51:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Load the 2024 yellow-taxi trip files into a single DataFrame.
#
# The `spark` session from the first cell is reused here; the builder call
# that used to be repeated in this cell was redundant, because the cell
# already depends on the first cell for the `SparkSession` import.

# NOTE(review): absolute, machine-specific location — adjust when running
# this notebook anywhere other than the original workstation.
YELLOW_TRIP_DIR = "/home/tb24/projects/lakehouse/data/2024/"

# Only files whose names match this glob are read from YELLOW_TRIP_DIR.
YELLOW_TRIP_GLOB = "yellow_tripdata_2024*.parquet"

raw = spark.read.option("pathGlobFilter", YELLOW_TRIP_GLOB).parquet(YELLOW_TRIP_DIR)

# Print the inferred schema so the column names and types are visible
# before the cleaning steps below.
raw.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



### Perform cleaning on Yellow Trip Data

In [3]:
from pyspark.sql import functions as F

# Build the staging frame: rename the source columns to snake_case, then derive
# the trip duration and the pickup year/month.
stg = (
    # One mapping-based rename instead of a chain of single-column renames;
    # every column keeps its original position.
    raw.withColumnsRenamed(
        {
            "VendorID": "vendor_id",
            "tpep_pickup_datetime": "pickup_ts",
            "tpep_dropoff_datetime": "dropoff_ts",
            "PULocationID": "pickup_location_id",
            "DOLocationID": "dropoff_location_id",
            "Airport_fee": "airport_fee",
            "RatecodeID": "rate_code_id",
        }
    )
    # Difference between dropoff and pickup, divided by 60.0 and cast to long.
    # NOTE(review): the cast to long presumably drops the fractional part, so
    # trips shorter than one minute would come out as 0 — confirm this is the
    # intended granularity.
    .withColumn(
        "trip_duration_min",
        ((F.col("dropoff_ts") - F.col("pickup_ts")) / 60.0).cast("long"),
    )
    # Calendar year and month are taken from the pickup timestamp.
    .withColumn("trip_year", F.year("pickup_ts"))
    .withColumn("trip_month", F.month("pickup_ts"))
)

stg.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+
|vendor_id|          pickup_ts|         dropoff_ts|passenger_count|trip_distance|rate_code_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|trip_duration_min|trip_year|trip_month|
+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+
|        1|2024-06-01 00:03:46|

In [4]:
# Flag valid trips: a trip counts as valid when its duration in minutes is
# strictly positive.
# NOTE(review): trip_duration_min is a long, so a trip whose duration was
# reduced to 0 by the upstream cast would be flagged as not valid — confirm
# that is the desired rule.
stg = stg.withColumn(
    "is_valid_trip",
    F.col("trip_duration_min") > 0,
)
stg.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+-------------+
|vendor_id|          pickup_ts|         dropoff_ts|passenger_count|trip_distance|rate_code_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|trip_duration_min|trip_year|trip_month|is_valid_trip|
+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+----

In [5]:
# Standardize money columns to two decimal places.
# `congestion_surcharge` is included so that every monetary amount in the
# schema ends up with the same decimal(10,2) type; it was previously left as
# a double while all the other amounts were converted.
money_cols = [
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "improvement_surcharge",
    "total_amount",
    "congestion_surcharge",
    "airport_fee",
]

# Cast all money columns in one projection. Calling `withColumn` once per
# column in a loop adds one projection per call to the query plan;
# `withColumns` applies the whole mapping at once and keeps column order.
stg = stg.withColumns(
    {name: F.col(name).cast("decimal(10,2)") for name in money_cols}
)

# Mark trips that have tips, and express the tip as a fraction of the fare.
# `tip_ratio` is rounded to two decimals and falls back to 0.0 whenever the
# fare is not greater than zero, which also avoids dividing by a zero fare.
stg = stg.withColumn("has_tip", F.col("tip_amount") > 0).withColumn(
    "tip_ratio",
    F.when(
        F.col("fare_amount") > 0, F.round(F.col("tip_amount") / F.col("fare_amount"), 2)
    ).otherwise(0.0),
)

stg.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+-------------+-------+---------+
|vendor_id|          pickup_ts|         dropoff_ts|passenger_count|trip_distance|rate_code_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|trip_duration_min|trip_year|trip_month|is_valid_trip|has_tip|tip_ratio|
+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------

In [7]:
# Drop rows with "unreasonable" passenger counts: keep only trips with 1 to 6
# passengers, bounds inclusive. `passenger_count` is a long in the loaded
# schema, so the inclusive range 1..6 selects the same rows as `> 0` combined
# with `<= 6`.
stg = stg.filter(F.col("passenger_count").between(1, 6))
stg.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+---------+----------+-------------+-------+---------+
|vendor_id|          pickup_ts|         dropoff_ts|passenger_count|trip_distance|rate_code_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|trip_duration_min|trip_year|trip_month|is_valid_trip|has_tip|tip_ratio|
+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------

In [None]:
# Categorical cleanup
