In [119]:
import os
import pyspark
from pyspark.sql import SparkSession
import findspark
from functools import reduce
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from pyspark.sql.functions import col, when, unix_timestamp, dayofweek, udf, hour, minute, unix_timestamp, concat_ws
from IPython.display import Image, display
import matplotlib.pyplot as plt
import pandas as pd
from pyspark.sql.types import IntegerType
import holidays

# 01_data_cleaning.ipynb

- Remove or correct invalid records
- Create new simple features (e.g. `trip_duration_minutes`)
- Save the cleaned dataset for the next step

In [140]:
# Spark session for this notebook.
# NOTE: spark.driver.memory only takes effect when the driver JVM is launched. If a
# session already exists in this kernel, getOrCreate() reuses it and Spark logs
# "Using an existing Spark session; only runtime SQL configurations will take effect",
# i.e. the memory setting below is ignored. Restart the kernel to apply it.
DRIVER_MEMORY = "8g"

spark = (
    SparkSession.builder
    .appName("NYC Taxi 2024 Cleaning")
    .config("spark.driver.memory", DRIVER_MEMORY)
    .getOrCreate()
)

# Make the silently-ignored setting visible instead of discovering it later through
# heap-pressure warnings. Spark's documented default driver memory is 1g.
effective_driver_memory = spark.sparkContext.getConf().get("spark.driver.memory", "1g")
if effective_driver_memory != DRIVER_MEMORY:
    print(
        f"WARNING: spark.driver.memory is {effective_driver_memory!r}, not {DRIVER_MEMORY!r}; "
        "restart the kernel so the requested driver memory takes effect."
    )

25/04/26 01:00:04 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [5]:
# Read the dataset produced at the end of the EDA notebook, then inspect its schema
df = spark.read.format("parquet").load("full_dataset_after_eda.parquet")
df.printSchema()


                                                                                

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- month: integer (nullable = true)



# Correct invalid data and regroup categorical values

## VendorID

In [8]:
df = df.withColumn("VendorID", F.col("VendorID").astype("string"))  # treat VendorID as a label, not a number

In [9]:
# Collapse vendor codes 6 and 7 into a single "Other" level; other codes are kept as-is
OTHER_VENDOR_CODES = ["6", "7"]
vendor = F.col("VendorID")
df = df.withColumn(
    "VendorID",
    F.when(vendor.isin(OTHER_VENDOR_CODES), F.lit("Other")).otherwise(vendor),
)

# tpep_pickup_datetime and tpep_dropoff_datetime

In [11]:
# Keep only trips whose dropoff happens strictly after the pickup
valid_timing = F.col("tpep_dropoff_datetime") > F.col("tpep_pickup_datetime")
df = df.where(valid_timing)

# passenger_count

In [13]:
# Keep trips with at least one passenger (null counts fail the comparison and are dropped too)
df = df.where(F.col("passenger_count") > 0)

In [14]:
df = df.withColumn("passenger_count", F.col("passenger_count").astype("string"))  # categorical label from here on

In [15]:
# Collapse large parties into a single "BIG GROUP" level. passenger_count is a string at
# this point, so cast it to int for the comparison. This way every count of 5 or more
# (e.g. 8 or 9, if present) is grouped, not only the literal values "5", "6" and "7".
# NOTE(review): assumes spark.sql.ansi.enabled is false (the Spark 3.x default). Under
# that setting a non-numeric value such as an existing "BIG GROUP" casts to null and is
# left unchanged; with ANSI mode on, that cast would raise instead.
BIG_GROUP_MIN_PASSENGERS = 5
df = df.withColumn(
    "passenger_count",
    when(col("passenger_count").cast("int") >= BIG_GROUP_MIN_PASSENGERS, "BIG GROUP")
    .otherwise(col("passenger_count"))
)

# trip_distance

In [17]:
# Discard trips with a zero or negative distance
df = df.where(F.col("trip_distance") > 0)

In [18]:
# Discard trips longer than 100 miles (rows are removed, not clipped)
MAX_TRIP_DISTANCE_MILES = 100
df = df.where(F.col("trip_distance") <= MAX_TRIP_DISTANCE_MILES)

# store_and_fwd_flag

In [20]:
# store_and_fwd_flag is not kept in the cleaned dataset
columns_to_drop = ["store_and_fwd_flag"]
df = df.drop(*columns_to_drop)


# PULocationID and DOLocationID

In [22]:
# Treat pickup and dropoff location IDs as categorical labels
for loc_col in ("PULocationID", "DOLocationID"):
    df = df.withColumn(loc_col, F.col(loc_col).astype("string"))

In [30]:
# Trip counts per pickup / dropoff location, keeping only locations seen at least 1000 times
MIN_LOCATION_TRIPS = 1000
pu_counts = df.groupBy("PULocationID").count().where(F.col("count") >= MIN_LOCATION_TRIPS)
do_counts = df.groupBy("DOLocationID").count().where(F.col("count") >= MIN_LOCATION_TRIPS)

In [32]:
# Bring the frequent pickup and dropoff location IDs back to the driver as plain lists
popular_pu = [row["PULocationID"] for row in pu_counts.select("PULocationID").collect()]
popular_do = [row["DOLocationID"] for row in do_counts.select("DOLocationID").collect()]

# Any location that is not in its frequent list is relabelled "other"
for loc_col, keep_ids in (("PULocationID", popular_pu), ("DOLocationID", popular_do)):
    df = df.withColumn(
        loc_col,
        when(col(loc_col).isin(keep_ids), col(loc_col)).otherwise("other"),
    )

                                                                                

# payment_type

In [35]:
df = df.withColumn("payment_type", F.col("payment_type").astype("string"))  # payment code as a label

In [37]:
# Exclude payment_type code "5" (rows with a null payment_type fail the test and are dropped too)
df = df.where(F.col("payment_type") != "5")

# fare_amount

In [40]:
# Discard trips whose fare is zero or negative
df = df.where(F.col("fare_amount") > 0)

In [42]:
# Drop rows whose fare_amount exceeds 200 (rows are removed, not clipped)
MAX_FARE_AMOUNT = 200
df = df.filter(col("fare_amount") <= MAX_FARE_AMOUNT)

# extra

In [45]:
# Remove rows with a negative extra charge. A trip with no extra surcharge records 0,
# so zero is kept. Dropping it would discard every trip without a surcharge and skew the
# remaining data. This matches the mta_tax, tip_amount, tolls_amount and
# congestion_surcharge filters in this notebook, which all accept 0.
df = df.filter(col("extra") >= 0)

# mta_tax

In [48]:
df = df.where(F.col("mta_tax").between(0, 0.5))  # keep mta_tax within [0, 0.5], bounds inclusive

# tip_amount

In [51]:
# Remove rows with a negative tip; zero tips are valid and kept
df = df.filter(col("tip_amount") >= 0)

In [53]:
# Drop rows whose tip_amount exceeds 200 (rows are removed, not clipped)
MAX_TIP_AMOUNT = 200
df = df.filter(col("tip_amount") <= MAX_TIP_AMOUNT)

# tolls_amount

In [56]:
df = df.where(F.col("tolls_amount").between(0, 30))  # keep tolls within [0, 30], bounds inclusive

# improvement_surcharge

In [59]:
# Keep only the two accepted improvement_surcharge values
ALLOWED_IMPROVEMENT_SURCHARGES = [0.3, 1.0]
df = df.where(F.col("improvement_surcharge").isin(ALLOWED_IMPROVEMENT_SURCHARGES))

# total_amount

In [62]:
df = df.where(F.col("total_amount").between(0, 200))  # keep total_amount within [0, 200], bounds inclusive

# congestion_surcharge

In [65]:
# Remove rows with a negative congestion_surcharge; zero (no surcharge) is kept.
# Null values fail the comparison and are dropped as well.
df = df.filter(col("congestion_surcharge") >= 0)

# Airport_fee

In [75]:
# A missing Airport_fee is treated as no fee (0.0)
df = df.withColumn("Airport_fee", F.coalesce(F.col("Airport_fee"), F.lit(0.0)))

In [77]:
df = df.where(F.col("Airport_fee").isin([0.0, 1.75]))  # only "no fee" and the 1.75 fee are accepted

In [132]:
# Re-encode Airport_fee as a 0/1 indicator: 1 when the 1.75 fee was charged, else 0
airport_fee_charged = F.col("Airport_fee") == 1.75
df = df.withColumn("Airport_fee", F.when(airport_fee_charged, F.lit(1)).otherwise(F.lit(0)))

# Create new features

# tpep_pickup_datetime and tpep_dropoff_datetime

| New Feature             | Purpose                                                 |
|:------------------------|:--------------------------------------------------------|
| `is_weekend`            | Captures weekend behavior                               |
| `is_holiday`            | Captures holiday behavior (different travel patterns)   |
| `pickup_hour_decimal`   | Time of day as fractional hours (minute resolution)     |
| `trip_duration_minutes` | Total trip time in minutes                              |
| `trip_speed_mph`        | Average speed in miles per hour (distance / duration)   |

In [88]:
# Spark's dayofweek() numbers days 1 = Sunday ... 7 = Saturday, so 1 and 7 are the weekend
WEEKEND_DAYS = [1, 7]
pickup_dow = dayofweek(F.col("tpep_pickup_datetime"))
df = df.withColumn("is_weekend", F.when(pickup_dow.isin(WEEKEND_DAYS), 1).otherwise(0))

In [86]:
# US holidays for 2024 from the `holidays` package, looked up by calendar date
us_holidays = holidays.US(years=[2024])


def is_holiday(date):
    """Return 1 if the pickup timestamp falls on a US holiday, otherwise 0.

    A null timestamp is reported as 0 (not a holiday).
    """
    if date is not None and date.date() in us_holidays:
        return 1
    return 0


# Wrap as an integer-returning Spark UDF and evaluate it on every pickup time
is_holiday_udf = udf(is_holiday, IntegerType())
df = df.withColumn("is_holiday", is_holiday_udf(col("tpep_pickup_datetime")))

In [90]:
# Pickup time of day as fractional hours at minute resolution (seconds ignored): 14:30 -> 14.5
pickup_ts = F.col("tpep_pickup_datetime")
df = df.withColumn("pickup_hour_decimal", hour(pickup_ts) + minute(pickup_ts) / 60)

In [94]:
# Trip length in minutes, derived from the difference of the two timestamps in whole seconds
duration_seconds = (
    unix_timestamp(F.col("tpep_dropoff_datetime"))
    - unix_timestamp(F.col("tpep_pickup_datetime"))
)
df = df.withColumn("trip_duration_minutes", duration_seconds / 60)


In [104]:
# Average speed: trip distance divided by the trip duration converted from minutes to hours
duration_hours = F.col("trip_duration_minutes") / 60
df = df.withColumn("trip_speed_mph", F.col("trip_distance") / duration_hours)

# passenger_count

| New Feature        | Purpose                                    |
|:--------------------|:-------------------------------------------|
| `is_shared_ride`    | Flag indicating if more than 1 passenger was on the ride |

In [109]:
# passenger_count holds string labels at this point ("1".."4", "BIG GROUP", ...), and counts
# of 0 were filtered out earlier, so every value other than "1" means more than one passenger.
# A numeric test such as `> 1` cannot parse "BIG GROUP" as a number (it becomes null, or an
# error in ANSI mode), so large-group trips were not flagged as shared.
df = df.withColumn(
    "is_shared_ride",
    when(col("passenger_count") != "1", 1).otherwise(0)
)

# tolls_amount

| New Feature        | Purpose                                    |
|:--------------------|:-------------------------------------------|
| `has_toll`          | Flag indicating if a toll was paid during the trip |

In [112]:
# 1 when any toll was charged on the trip, otherwise 0
paid_toll = F.col("tolls_amount") > 0
df = df.withColumn("has_toll", F.when(paid_toll, 1).otherwise(0))

## PULocationID and DOLocationID

| New Feature    | Purpose                                        |
|:---------------|:------------------------------------------------|
| `PU_DO_pair`   | Combines pickup and dropoff location IDs to capture trip origin-destination patterns |

In [121]:
# Origin-destination key of the form "<PULocationID>-<DOLocationID>"
df = df.withColumn("PU_DO_pair", concat_ws("-", "PULocationID", "DOLocationID"))

## congestion_surcharge

| New Feature            | Purpose                                               |
|:------------------------|:------------------------------------------------------|
| `has_congestion_fee`     | Flags trips that were charged a congestion surcharge, indicating high-traffic areas or times |

In [130]:
# 1 when a congestion surcharge was applied to the trip, otherwise 0
charged_congestion = F.col("congestion_surcharge") > 0
df = df.withColumn("has_congestion_fee", F.when(charged_congestion, 1).otherwise(0))

In [144]:
# Write the cleaned data as 200 partitions, replacing any previous output at this path
(
    df.repartition(200)
    .write.format("parquet")
    .mode("overwrite")
    .save("full_dataset_after_cleaning.parquet")
)

25/04/26 01:03:51 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:51 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:51 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:52 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:52 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:52 WARN MemoryManager: Total allocation exceeds 95,00% (1 020 054 720 bytes) of heap memory
Scaling row group sizes to 95,00% for 8 writers
25/04/26 01:03:52 WARN MemoryManager: Total allocation exceeds 95,00% 