# In [0]:
from pyspark.sql.functions import col, when, unix_timestamp, hour
from pyspark.sql.types import *

# In [0]:
def clean_nyc_taxi_data(df, *, max_passengers: int = 6, max_trip_minutes: float = 180):
    """Clean a NYC yellow-taxi trip DataFrame and add derived columns.

    Drops rows that lack a pickup/dropoff timestamp or fare, have a
    non-positive trip distance or fare, have a passenger count outside
    ``(0, max_passengers]``, or have a trip duration outside
    ``(0, max_trip_minutes)`` minutes. Then removes exact duplicate rows,
    maps ``store_and_fwd_flag`` and ``payment_type`` codes to readable labels,
    and adds ``trip_time_minutes``, ``pickup_hour`` and ``tip_percentage``.

    NOTE(review): no type casting is performed here; the timestamp and
    numeric columns are presumably already typed by the loader - confirm.

    Args:
        df: Spark DataFrame containing the ``tpep_*`` trip columns used below.
        max_passengers: Inclusive upper bound on ``passenger_count``
            (default 6, the previously hard-coded limit).
        max_trip_minutes: Exclusive upper bound on trip duration in minutes
            (default 180, the previously hard-coded limit).

    Returns:
        The cleaned DataFrame.
    """
    # Step 1: Drop rows missing critical fields and filter invalid/outlier data.
    # These are per-row predicates, so running them before deduplication keeps
    # exactly the same rows while shrinking the input to the dedup shuffle.
    df = df.dropna(subset=["tpep_pickup_datetime", "tpep_dropoff_datetime", "fare_amount"])
    df = df.filter((col("trip_distance") > 0) & (col("fare_amount") > 0))
    df = df.filter(
        (col("passenger_count") > 0) & (col("passenger_count") <= max_passengers)
    )

    df = df.withColumn(
        "trip_time_minutes",
        (unix_timestamp("tpep_dropoff_datetime") - unix_timestamp("tpep_pickup_datetime")) / 60,
    )
    df = df.filter(
        (col("trip_time_minutes") > 0) & (col("trip_time_minutes") < max_trip_minutes)
    )

    # Step 2: Drop exact duplicates. This must happen BEFORE the categorical
    # mapping below, which collapses distinct raw values (e.g. null and "N"
    # both become "No") and would otherwise merge rows that were distinct.
    df = df.dropDuplicates()

    # Step 3: Standardize categorical values.
    df = df.withColumn(
        "store_and_fwd_flag",
        when(col("store_and_fwd_flag") == "Y", "Yes").otherwise("No"),
    )

    # Payment codes follow the TLC yellow-taxi data dictionary, which also
    # defines 6 = voided trip; confirm against the dictionary version in use.
    df = df.withColumn(
        "payment_type",
        when(col("payment_type") == 1, "Credit Card")
        .when(col("payment_type") == 2, "Cash")
        .when(col("payment_type") == 3, "No Charge")
        .when(col("payment_type") == 4, "Dispute")
        .when(col("payment_type") == 5, "Unknown")
        .when(col("payment_type") == 6, "Voided Trip")
        .otherwise("Other"),
    )

    # Step 4: Add derived columns. fare_amount > 0 is guaranteed by the Step 1
    # filter, so the tip ratio needs no zero-division guard.
    df = df.withColumn("pickup_hour", hour("tpep_pickup_datetime")).withColumn(
        "tip_percentage", (col("tip_amount") / col("fare_amount")) * 100
    )

    return df