In [0]:
import pyspark.sql.functions as f

In [0]:
# Load the yellow-taxi table as the starting DataFrame for this notebook.
df = spark.table('workspace.nyc_cleansed.yellow_taxi')

Renaming Columns

In [0]:
# Source column name -> name used for the rest of the notebook.
rename_cols = dict(
    tpep_pickup_datetime="pickup_datetime",
    tpep_dropoff_datetime="dropoff_datetime",
    pulocationid="pickup_location_id",
    dolocationid="dropoff_location_id",
    payment_type="payment_type_id",
)

# Apply the renames one column at a time.
for old_name, new_name in rename_cols.items():
    df = df.withColumnRenamed(old_name, new_name)

# Columns kept after renaming, in output order.
select_cols = [
    # trip timing
    'pickup_datetime', 'dropoff_datetime',
    # trip attributes
    'passenger_count', 'trip_distance',
    'pickup_location_id', 'dropoff_location_id',
    'payment_type_id',
    # monetary amounts
    'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount',
    'improvement_surcharge', 'total_amount', 'congestion_surcharge',
    'airport_fee',
]

# Project down to exactly the columns listed above.
df = df.select(select_cols)

Filtering Out Bad Records

In [0]:
# A row is kept only when all three conditions hold:
#   1. fare_amount or total_amount is strictly positive,
#   2. pickup happens strictly before dropoff,
#   3. passenger_count is strictly positive.
has_positive_amount = (f.col('fare_amount') > 0.0) | (f.col('total_amount') > 0.0)
pickup_before_dropoff = f.col('pickup_datetime') < f.col('dropoff_datetime')
has_passengers = f.col('passenger_count') > 0

df = df.where(has_positive_amount & pickup_before_dropoff & has_passengers)

Adding Primary Key

In [0]:
# Build a surrogate key `trip_id` from the columns that identify a trip.
#
# concat_ws(sep, *cols) expects the separator STRING as its first argument.
# Previously the first column (pickup_datetime) was passed in that position,
# so it was consumed as the separator argument instead of being one of the
# concatenated key fields. An explicit '|' delimiter is now supplied and all
# five columns are concatenated as values.
# NOTE(review): this changes the generated trip_id values compared to the old
# expression — confirm nothing downstream depends on previously written ids.
trip_key_cols = [
    'pickup_datetime',
    'dropoff_datetime',
    'pickup_location_id',
    'dropoff_location_id',
    'trip_distance',
]

df = df.withColumn(
    'trip_id',
    f.md5(f.concat_ws('|', *[f.col(name) for name in trip_key_cols])),
)

# Re-format the 32-character md5 hex digest into 8-4-4-4-12 groups
# separated by hyphens.
df = df.withColumn('trip_id',f.regexp_replace(
                                            f.col('trip_id'),
                                            "(.{8})(.{4})(.{4})(.{4})(.{12})",
                                            "$1-$2-$3-$4-$5"
                                            )
                )

In [0]:
from pyspark.sql.window import Window

In [0]:
# De-duplicate on trip_id: within each trip_id, number the rows by
# total_amount (largest first) and keep only the first one.
# row_number() gives no defined order among rows with equal total_amount.
df_window = Window.partitionBy('trip_id').orderBy(f.col('total_amount').desc())

df = df.withColumn('rnk', f.row_number().over(df_window))
# The helper column `rnk` is not dropped here, so it stays in `df` (always 1).
df = df.where(f.col('rnk') == 1)

In [0]:
# Sanity check: list any trip_id that still occurs more than once.
# An empty result means trip_id is unique in `df`.
trip_id_counts = df.groupBy('trip_id').agg(f.count('*').alias('cnt'))
display(trip_id_counts.where(f.col('cnt') > 1))