In [None]:
import os

from pyspark.sql import SparkSession

# Build (or reuse) the Spark session for this notebook, setting the driver and
# executor memory options to "4g" on the builder before the session is created.
session_builder = SparkSession.builder.appName("nyc-tlc-yellow-tripdata")
for conf_key, conf_value in (
    ("spark.driver.memory", "4g"),
    ("spark.executor.memory", "4g"),
):
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.getOrCreate()

In [None]:
# Uses the `spark` session created in the first cell. This cell already depends
# on that cell (for the SparkSession import), so rebuilding the session here
# would only return the same session again.

# Directory holding the 2015 yellow-taxi parquet files. Set the TLC_DATA_DIR
# environment variable to read from a different location; the default is the
# original local path.
TLC_DATA_DIR = os.environ.get(
    "TLC_DATA_DIR", "/home/tb24/projects/lakehouse/data/2015/"
)

# pathGlobFilter restricts the read to files whose names match the pattern.
raw = spark.read.option("pathGlobFilter", "yellow_tripdata_2015-*.parquet").parquet(
    TLC_DATA_DIR
)

raw.printSchema()

In [None]:
# Preview the first five rows of the raw frame.
raw.show(n=5)

### Perform cleaning on Yellow Trip Data

In [None]:
from pyspark.sql import functions as F

# Source column name -> snake_case staging name.
# NOTE(review): withColumnRenamed is documented as a no-op when the source
# column is missing, so a name absent from `raw` is silently skipped here.
column_renames = {
    "VendorID": "vendor_id",
    "tpep_pickup_datetime": "pickup_ts",
    "tpep_dropoff_datetime": "dropoff_ts",
    "PULocationID": "pickup_location_id",
    "DOLocationID": "dropoff_location_id",
    "Airport_fee": "airport_fee",
    "RatecodeID": "rate_code_id",
}

stg = raw
for source_name, staged_name in column_renames.items():
    stg = stg.withColumnRenamed(source_name, staged_name)

# Subtracting two timestamp columns directly produces an interval rather than a
# number, and whether an interval can be cast to long depends on the Spark
# version. Convert both timestamps to epoch seconds first so the duration is
# plain numeric arithmetic.
trip_duration_sec = F.unix_timestamp("dropoff_ts") - F.unix_timestamp("pickup_ts")

stg = (
    # Whole minutes: the cast to long drops the fractional part.
    stg.withColumn("trip_duration_min", (trip_duration_sec / 60.0).cast("long"))
    # Partition-style calendar fields derived from the pickup time.
    .withColumn("year", F.year("pickup_ts"))
    .withColumn("month", F.month("pickup_ts"))
)

stg.show(5)

In [None]:
# Flag each trip: is_valid_trip is the result of trip_duration_min > 0.
has_positive_duration = F.col("trip_duration_min") > 0
stg = stg.withColumn("is_valid_trip", has_positive_duration)
stg.show(n=5)

In [None]:
# Monetary columns that are cast to a fixed two-decimal type below.
money_cols = [
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "improvement_surcharge",
    "total_amount",
    "airport_fee",
]
for money_col in money_cols:
    # Replace each column in place with its decimal(10,2) version.
    stg = stg.withColumn(money_col, F.col(money_col).cast("decimal(10,2)"))

# Tip features, built as named expressions before being attached to the frame.
tip_amount = F.col("tip_amount")
fare_amount = F.col("fare_amount")

# Tip as a fraction of the fare, rounded to two places; falls back to 0.0 when
# the fare_amount > 0 condition does not hold, which avoids dividing by zero.
tip_ratio = F.when(
    fare_amount > 0, F.round(tip_amount / fare_amount, 2)
).otherwise(0.0)

stg = stg.withColumn("has_tip", tip_amount > 0)
stg = stg.withColumn("tip_ratio", tip_ratio)

stg.show(n=5)

In [None]:
# Keep only rows whose passenger_count is greater than 0 and at most 6;
# everything else is treated as an "unreasonable" count and dropped.
passenger_count = F.col("passenger_count")
has_passengers = passenger_count > 0
within_capacity = passenger_count <= 6
stg = stg.where(has_passengers & within_capacity)
stg.show(n=5)