In [0]:
# Register the month-range parameters that the ADF pipeline passes into this notebook.
# Both widgets default to an empty string, so a value must be supplied (by ADF or by hand)
# before the filtering step converts them to integers.
for _widget_name in ("p_start_month", "p_end_month"):
    dbutils.widgets.text(_widget_name, "")
del _widget_name

In [0]:
import sys
import os
import importlib
from pyspark.sql.functions import col, when, current_timestamp, timestamp_diff, year, month

# Make the repository root (two directories above the notebook's working directory)
# importable, so the shared `modules` package can be resolved.
root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
if root not in sys.path:
    sys.path.append(root)

# Re-import helper modules that are already loaded in this session, so that edits to
# them are picked up without restarting the cluster/kernel.
for _module_name in ("modules.utils.date_utils", "modules.utils.table_utils"):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

from modules.utils.date_utils import get_month_start_n_months_ago
from modules.utils.table_utils import get_filtered_dataframe, upsert_delta_table

##### READ SOURCE DATA (FULL OR INCREMENTAL LOAD)

In [0]:
# Source (bronze) and target (silver) tables, plus the ADLS location of the silver Delta table.
source = "nyctaxi.01_bronze.green_trips_raw"
target = "nyctaxi.02_silver.green_trips_cleansed"
storage_path = "abfss://silver@stnyctaxigreen.dfs.core.windows.net/green_trips_cleansed"

# Incremental read window: from the start of the month three months ago up to the start of
# the month two months ago. Whether the end bound is exclusive, and how a full load differs
# from an incremental one, is decided inside get_filtered_dataframe (it also receives the
# target table). NOTE(review): this window is fixed relative to the run date and does not
# use the ADF p_start_month / p_end_month values applied in the filtering step; confirm the
# two always overlap.
three_months_start, two_months_start = (
    get_month_start_n_months_ago(months_back) for months_back in (3, 2)
)

df = get_filtered_dataframe(
    spark,
    source_table=source,
    target_table=target,
    start_date=three_months_start,
    end_date=two_months_start,
)

##### TRANSFORMATION AND MAPPING

In [0]:
# Lookup tables translating TLC integer codes into readable labels. Any code not listed
# here (including NULL) falls through to "Unknown".
VENDOR_LABELS = {
    1: "Creative Mobile Technologies, LLC",
    2: "Curb Mobility, LLC",
    6: "Myle Technologies Inc",
}
RATE_CODE_LABELS = {
    1: "Standard Rate",
    2: "JFK",
    3: "Newark",
    4: "Nassau or Westchester",
    5: "Negotiated Fare",
    6: "Group Ride",
}
PAYMENT_TYPE_LABELS = {
    0: "Flex Fare trip",
    1: "Credit card",
    2: "Cash",
    3: "No charge",
    4: "Dispute",
    6: "Voided trip",
}
TRIP_TYPE_LABELS = {
    1: "Street-hail",
    2: "Dispatch",
}

MILES_TO_KM = 1.60934  # 1 mile = 1.60934 kilometres


def _decode_codes(source_column, labels, output_name):
    """Build a CASE WHEN column that maps the codes in ``source_column`` to ``labels``.

    Branches are chained in the dict's insertion order. A value matching no branch
    (including NULL) becomes "Unknown". The resulting column is aliased ``output_name``.
    """
    code = col(source_column)
    case_expr = None
    for code_value, label in labels.items():
        condition = code == code_value
        case_expr = when(condition, label) if case_expr is None else case_expr.when(condition, label)
    return case_expr.otherwise("Unknown").alias(output_name)


df_cleansed = df.select(
    _decode_codes("VendorID", VENDOR_LABELS, "vendor"),
    "lpep_pickup_datetime",
    "lpep_dropoff_datetime",
    # Trip duration in minutes between pickup and dropoff.
    timestamp_diff(
        "MINUTE", col("lpep_pickup_datetime"), col("lpep_dropoff_datetime")
    ).alias("trip_duration"),
    "passenger_count",
    (col("trip_distance") * MILES_TO_KM).alias("trip_distance_km"),
    _decode_codes("RatecodeID", RATE_CODE_LABELS, "rate_type"),
    "store_and_fwd_flag",
    col("PULocationID").alias("pu_location_id"),
    col("DOLocationID").alias("do_location_id"),
    _decode_codes("payment_type", PAYMENT_TYPE_LABELS, "payment_type"),
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "improvement_surcharge",
    "total_amount",
    "congestion_surcharge",
    _decode_codes("trip_type", TRIP_TYPE_LABELS, "trip_type"),
    "cbd_congestion_fee",
    "load_timestamp",
)

##### FILTERING DATA

In [0]:
# Keep only 2025 pickups (removes stray records from other years, e.g. 2024) whose pickup
# month lies in the ADF-supplied range [p_start_month, p_end_month], inclusive at both ends.
# Keep only non-negative fare amounts (negative values are voids/errors). Zero fares are
# kept. Rows with a NULL fare_amount are dropped because the comparison evaluates to NULL.

TARGET_YEAR = 2025  # NOTE(review): hard-coded; any data read outside 2025 is filtered out entirely


def _read_month_param(widget_name):
    """Return the month (1-12) passed in the notebook widget ``widget_name`` as an int.

    Raises:
        ValueError: if the widget value is empty, not an integer, or outside 1-12. The
            message names the offending parameter, unlike a bare ``int('')`` failure.
    """
    raw_value = dbutils.widgets.get(widget_name)
    try:
        month_number = int(raw_value)
    except ValueError:
        raise ValueError(
            f"Pipeline parameter '{widget_name}' must be an integer month (1-12); got {raw_value!r}"
        ) from None
    if not 1 <= month_number <= 12:
        raise ValueError(
            f"Pipeline parameter '{widget_name}' must be between 1 and 12; got {month_number}"
        )
    return month_number


p_start_month = _read_month_param("p_start_month")
p_end_month = _read_month_param("p_end_month")

# A reversed range would silently filter out every row and merge nothing; fail loudly instead.
if p_start_month > p_end_month:
    raise ValueError(
        f"p_start_month ({p_start_month}) must not be greater than p_end_month ({p_end_month})"
    )

print(f"Filtering data: {p_start_month} to {p_end_month}")

df_cleansed = df_cleansed.filter(
    (year(col("lpep_pickup_datetime")) == TARGET_YEAR)
    & month(col("lpep_pickup_datetime")).between(p_start_month, p_end_month)
    & (col("fare_amount") >= 0)
)

##### UPSERT (MERGE) INTO THE SILVER TABLE TO AVOID DUPLICATES

In [0]:
# Match condition for the Delta MERGE. The `t.` / `s.` prefixes presumably refer to the
# target table and the incoming frame as aliased inside upsert_delta_table (confirm there).
# NOTE(review): a trip is keyed only by pickup timestamp + pickup zone. Two trips from the
# same zone at the same pickup time would collide, and NULL keys never match under `=`.
# Confirm this key is unique enough for the data.
merge_condition = "t.lpep_pickup_datetime = s.lpep_pickup_datetime AND t.pu_location_id = s.pu_location_id"

upsert_delta_table(
    spark,
    df_cleansed,
    target,
    merge_condition,
    storage_path,
)

# Sanity check: total row count of the target table after the merge.
print(f"Total records in {target}: {spark.read.table(target).count()}")