In [1]:
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("nyc-tlc-yellow-tripdata")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/29 15:12:28 WARN Utils: Your hostname, tb24-workstation, resolves to a loopback address: 127.0.1.1; using 192.168.1.61 instead (on interface enp7s0)
26/01/29 15:12:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/29 15:12:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/01/29 15:12:29 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
spark = SparkSession.builder.appName("nyc-tlc-yellow-tripdata").getOrCreate()

raw = spark.read.option("pathGlobFilter", "yellow_tripdata_2019-*.parquet").parquet(
    "/home/tb24/projects/lakehouse/data/2019/"
)

raw.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: integer (nullable = true)



In [3]:
raw.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2019-03-01 00:24:41|  2019-03-01 00:25:31|            1.0|          0.0|       1.0|                 N|         145|         145|           2|        2.5|  0.5|    0.5|       0.

### Perform cleaning on Yellow Trip Data

In [4]:
from pyspark.sql import functions as F

stg = (
    raw.withColumnRenamed("VendorID", "vendor_id")
    .withColumnRenamed("tpep_pickup_datetime", "pickup_ts")
    .withColumnRenamed("tpep_dropoff_datetime", "dropoff_ts")
    .withColumnRenamed("PULocationID", "pickup_location_id")
    .withColumnRenamed("DOLocationID", "dropoff_location_id")
    .withColumnRenamed("Airport_fee", "airport_fee")
    .withColumnRenamed("RatecodeID", "rate_code_id")
    .withColumn(
        "trip_duration_min",
        (F.col("dropoff_ts") - F.col("pickup_ts")) / 60.0,
    )
    .withColumn("trip_duration_min", F.col("trip_duration_min").cast("long"))
    .withColumn("year", F.year("pickup_ts"))
    .withColumn("month", F.month("pickup_ts"))
)

stg.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+----+-----+
|vendor_id|          pickup_ts|         dropoff_ts|passenger_count|trip_distance|rate_code_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|trip_duration_min|year|month|
+---------+-------------------+-------------------+---------------+-------------+------------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------+----+-----+
|        1|2019-03-01 00:24:41|2019-03-01 00:25:31|          

In [None]:
# Mark valid trips
stg = stg.withColumn("is_valid_trip", F.col("trip_duration_min") > 0)
stg.show(5)

In [None]:
# standardize money columns to two decimal places
money_cols = [
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "improvement_surcharge",
    "total_amount",
    "airport_fee",
]
for col in money_cols:
    stg = stg.withColumn(col, F.col(col).cast("decimal(10,2)"))

# Mark trips that have tips
stg = stg.withColumn("has_tip", F.col("tip_amount") > 0).withColumn(
    "tip_ratio",
    F.when(
        F.col("fare_amount") > 0, F.round(F.col("tip_amount") / F.col("fare_amount"), 2)
    ).otherwise(0.0),
)

stg.show(5)

In [None]:
# Filter out "unreasonable" passenger counts
stg = stg.filter((F.col("passenger_count") > 0) & (F.col("passenger_count") <= 6))
stg.show(5)

In [None]:
# Categorical cleanup
