In [56]:
import os
import sys

from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg, col, lit

from custom_exception import CustomException

In [57]:
# Define user variables
#
# NOTE(review): every second argument below starts with "/". os.path.join
# discards all earlier components when a later one is absolute, so on POSIX
# os.getcwd() has no effect and these resolve to /rawdata/... and
# /pipelinedata/... at the filesystem root. If paths relative to the working
# directory were intended, drop the leading slash - confirm where the data
# actually lives before changing this.

# Raw input CSV files (one per taxi type).
yellow_taxi_path = os.path.join(os.getcwd(), "/rawdata/yellow_tripdata_2021-01.csv")
green_taxi_path = os.path.join(os.getcwd(), "/rawdata/green_tripdata_2021-01.csv")

# Output directories for the Bronze / Silver / Gold pipeline layers.
bronze_dir = os.path.join(os.getcwd(), "/pipelinedata/Bronze")
silver_dir = os.path.join(os.getcwd(), "/pipelinedata/Silver")
gold_dir = os.path.join(os.getcwd(), "/pipelinedata/Gold")

# Target file locations inside the layer directories.
# NOTE(review): none of the targets below are referenced by the code visible
# in this part of the notebook - presumably consumed by later cells.
yellow_tripdata_parquet = f"{bronze_dir}/yellow_tripdata.parquet"
green_tripdata_parquet = f"{bronze_dir}/green_tripdata.parquet"
merged_tripdata_parquet = f"{silver_dir}/merged_tripdata.parquet"

yellow_tripdata_valid_parquet = f"{silver_dir}/yellow_tripdata_valid.parquet"
# NOTE(review): the variable name says "parquet" but the path ends in .csv - confirm the intended format.
yellow_tripdata_invalid_parquet = f"{silver_dir}/yellow_tripdata_invalid.csv"

locations_csv = f"{gold_dir}/locations.csv"
vendors_csv = f"{gold_dir}/vendors.csv"

In [58]:
# define application variables
# Module-level handle to the Spark session: starts as None, is assigned by the
# main block and read by stop_spark_session().
spark = None

In [59]:
def create_spark_session(app_name="TaxiDataPipeline", master="local"):
    """Return the SparkSession for this pipeline, creating it if necessary.

    Args:
        app_name: Application name passed to the session builder.
        master: Spark master URL passed to the session builder.

    Returns:
        The session returned by ``SparkSession.builder...getOrCreate()``.

    Raises:
        SystemExit: With exit code 1 when the session cannot be created.
    """
    try:
        session_builder = SparkSession.builder.appName(app_name).master(master)
        return session_builder.getOrCreate()
    except Exception as e:
        print(f"Error creating Spark session: {e}")
        # sys.exit always raises SystemExit. The bare `exit` name is an
        # interactive-shell helper that IPython/Jupyter replaces with its own
        # object, so it is not a dependable way to abort from inside a function.
        sys.exit(1)

In [60]:
def read_csv_files(spark, file_path):
    """Load one CSV file into a DataFrame, using the first row as the header.

    No schema is supplied and schema inference is not requested, so the
    columns keep whatever default types the CSV reader assigns.

    Args:
        spark: Session object exposing ``read.csv``.
        file_path: Location of the CSV file to load.

    Returns:
        The DataFrame returned by ``spark.read.csv``.

    Raises:
        SystemExit: With exit code 1 when the read fails.
    """
    try:
        # read the csv file and return the dataframe
        return spark.read.csv(file_path, header=True)
    except Exception as e:
        print(f"Error reading data in step1: {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [61]:
def rename_transform_columns(df, taxi_type):
    """Standardise column names and keep only the columns the pipeline uses.

    Args:
        df: DataFrame holding the raw trip columns.
        taxi_type: ``'Y'`` for yellow taxi data (``tpep_*`` timestamp columns)
            or ``'G'`` for green taxi data (``lpep_*`` timestamp columns).

    Returns:
        A DataFrame with the columns VendorId, PickUpDateTime,
        DropOffDateTime, PickUpLocationId, DropOffLocationId, PassengerCount,
        TripDistance, TipAmount and TotalAmount, in that order.

    Raises:
        SystemExit: With exit code 1 when the taxi type is unknown or the
            rename/select fails. The underlying error is printed first.
    """
    # Prefix of the pickup/dropoff timestamp columns for each taxi type.
    datetime_prefix_by_type = {'Y': "tpep", 'G': "lpep"}

    # Renames shared by both taxi types (source name -> standard name).
    column_renames = {
        "VendorID": "VendorId",
        "PULocationID": "PickUpLocationId",
        "DOLocationID": "DropOffLocationId",
        "passenger_count": "PassengerCount",
        "trip_distance": "TripDistance",
        "tip_amount": "TipAmount",
        "total_amount": "TotalAmount",
    }

    try:
        if taxi_type not in datetime_prefix_by_type:
            # raise an exception of unknown taxi type
            raise CustomException(f"Unknown taxi type : {taxi_type}", 1001)

        prefix = datetime_prefix_by_type[taxi_type]
        column_renames[f"{prefix}_pickup_datetime"] = "PickUpDateTime"
        column_renames[f"{prefix}_dropoff_datetime"] = "DropOffDateTime"

        # rename columns for standardization
        for source_name, standard_name in column_renames.items():
            df = df.withColumnRenamed(source_name, standard_name)

        return df.select(
            "VendorId",
            "PickUpDateTime",
            "DropOffDateTime",
            "PickUpLocationId",
            "DropOffLocationId",
            "PassengerCount",
            "TripDistance",
            "TipAmount",
            "TotalAmount",
        )

    except Exception as e:
        print(f"Error renaming columns : {e} for taxi_type : {taxi_type}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [62]:
def merge_dataframes(first_df, second_df):
    """Append the rows of ``second_df`` to ``first_df`` via ``unionByName``.

    Args:
        first_df: DataFrame the union is called on.
        second_df: DataFrame whose rows are appended.

    Returns:
        The result of ``first_df.unionByName(second_df)``.

    Raises:
        SystemExit: With exit code 1 when the union fails.
    """
    try:
        # merge dataframes
        return first_df.unionByName(second_df)
    except Exception as e:
        print(f"Error merging dataframes : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [63]:
def filter_df(df, col):
    """Return the rows of ``df`` that satisfy a filter condition.

    Args:
        df: DataFrame to filter.
        col: Condition handed unchanged to ``df.filter``. Inside this function
            the parameter name hides the ``col`` helper imported from
            ``pyspark.sql.functions``.

    Returns:
        The result of ``df.filter(col)``.

    Raises:
        SystemExit: With exit code 1 when the filter fails.
    """
    try:
        # filter the dataframes
        return df.filter(col)
    except Exception as e:
        print(f"Error filtering dataframes : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [64]:
def replace_null_values(df, col, value):
    """Fill the nulls of a single column with a replacement value.

    Note that ``DataFrame.fillna`` silently skips subset columns whose data
    type does not match the type of ``value`` (for example a numeric value for
    a string column), so pass a value of the column's own type.

    Args:
        df: DataFrame to update.
        col: Name of the column whose nulls are replaced. Inside this function
            the parameter name hides the ``col`` helper imported from
            ``pyspark.sql.functions``.
        value: Replacement for null entries.

    Returns:
        The result of ``df.fillna(value, subset=[col])``.

    Raises:
        SystemExit: With exit code 1 when the fill fails.
    """
    try:
        # replace null values with the given value
        return df.fillna(value, subset=[col])
    except Exception as e:
        print(f"Error replacing null values : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [65]:
def deduplicate_df(df, cols):
    """Drop rows that are duplicates with respect to the given columns.

    Args:
        df: DataFrame to deduplicate.
        cols: Column names handed unchanged to ``df.dropDuplicates``.

    Returns:
        The result of ``df.dropDuplicates(cols)``.

    Raises:
        SystemExit: With exit code 1 when deduplication fails.
    """
    try:
        # deduplicate the dataframes
        return df.dropDuplicates(cols)
    except Exception as e:
        print(f"Error deduplicating dataframes : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [66]:
def aggregate_locations_df(df):
    """Aggregate fares, tips and distance per pickup location.

    Args:
        df: Trip DataFrame with PickUpLocationId, TotalAmount, TipAmount and
            TripDistance columns.

    Returns:
        One row per pickup location with the columns LocationId, TotalFares
        (sum of TotalAmount), TotalTips (sum of TipAmount), AverageDistance
        (average TripDistance) and LocationType (the literal "PickUp").

    Raises:
        SystemExit: With exit code 1 when the aggregation fails.
    """
    try:
        # Calculate aggregations for Locations
        # (`sum` is the Spark column function imported at the top of the
        # notebook, not Python's builtin.)
        per_pickup = df.groupBy("PickUpLocationId").agg(
            sum("TotalAmount").alias("TotalFares"),
            sum("TipAmount").alias("TotalTips"),
            avg("TripDistance").alias("AverageDistance"),
        )
        return per_pickup.withColumn("LocationType", lit("PickUp")) \
            .withColumnRenamed("PickUpLocationId", "LocationId")

    except Exception as e:
        print(f"Error aggregating locations : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [67]:
def aggregate_dropoffs(df):
    """Compute the average trip distance per drop-off location.

    TotalFares and TotalTips are filled with the literal 0, and the columns
    are selected in the order LocationId, TotalFares, TotalTips,
    AverageDistance, LocationType.

    Args:
        df: Trip DataFrame with DropOffLocationId and TripDistance columns.

    Returns:
        One row per drop-off location, with LocationType set to the literal
        "DropOff".

    Raises:
        SystemExit: With exit code 1 when the aggregation fails.
    """
    try:
        # Calculate the average distance by dropoff location separately
        per_dropoff = df.groupBy("DropOffLocationId").agg(
            avg("TripDistance").alias("AverageDistance")
        )
        return per_dropoff.withColumn("TotalFares", lit(0)) \
            .withColumn("TotalTips", lit(0)) \
            .withColumn("LocationType", lit("DropOff")) \
            .withColumnRenamed("DropOffLocationId", "LocationId") \
            .select("LocationId", "TotalFares", "TotalTips", "AverageDistance", "LocationType")

    except Exception as e:
        print(f"Error aggregating dropoffs : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [68]:
def aggregate_vendors(df):
    """Aggregate fares and tips per vendor.

    Args:
        df: Trip DataFrame with VendorId, TotalAmount and TipAmount columns.

    Returns:
        One row per VendorId with TotalFares (sum of TotalAmount), TotalTips
        (sum of TipAmount), AverageFare (average TotalAmount) and AverageTips
        (average TipAmount).

    Raises:
        SystemExit: With exit code 1 when the aggregation fails.
    """
    try:
        # Calculate aggregations for Vendors
        # (`sum` is the Spark column function imported at the top of the
        # notebook, not Python's builtin.)
        per_vendor = df.groupBy("VendorId")
        return per_vendor.agg(
            sum("TotalAmount").alias("TotalFares"),
            sum("TipAmount").alias("TotalTips"),
            avg("TotalAmount").alias("AverageFare"),
            avg("TipAmount").alias("AverageTips"),
        )

    except Exception as e:
        print(f"Error aggregating vendors : {e}")
        # sys.exit (not the interactive-only `exit` helper) so the abort is a
        # real SystemExit in scripts and notebooks alike.
        sys.exit(1)

In [69]:
def stop_spark_session():
    """Stop the module-level Spark session if one has been assigned."""
    if not spark:
        # The global is still unset (or otherwise falsy): nothing to stop.
        return
    spark.stop()

In [70]:
# Main execution
if __name__ == "__main__":

    # create sparkSession
    spark = create_spark_session()

    # Run the pipeline under try/finally so the session is stopped even when
    # a step aborts or raises.
    try:

        # step 1 - load the raw data into initial df's
        yellow_df = read_csv_files(spark=spark, file_path=yellow_taxi_path)
        green_df = read_csv_files(spark=spark, file_path=green_taxi_path)

        print(f"The number of rows in the yellow is: {yellow_df.count()}")
        print(f"The number of rows in the green is: {green_df.count()}")

        # step 2 - rename and reduce columns
        yellow_df = rename_transform_columns(df=yellow_df, taxi_type='Y')
        green_df = rename_transform_columns(df=green_df, taxi_type='G')

        # Fixed seed so the previewed sample is the same on every run.
        sample_yellow = yellow_df.sample(withReplacement=False, fraction=0.1, seed=42)
        sample_green = green_df.sample(withReplacement=False, fraction=0.1, seed=42)

        sample_yellow.show(10)
        sample_green.show(10)

        merged_df = merge_dataframes(first_df=yellow_df, second_df=green_df)

        # Count the rows for each distinct PassengerCount value
        distinct_counts = merged_df.groupBy("PassengerCount").count()

        # Show the results
        distinct_counts.show()

        # step 3 - apply validation rules

        # validation rule 1 based on the passenger count
        valid_df = filter_df(df=merged_df, col=col("PassengerCount") >= 1)
        invalid_df = filter_df(merged_df, (col("PassengerCount") < 1) | col("PassengerCount").isNull())

        print(f"The number of rows in the valid is: {valid_df.count()}")
        print(f"The number of rows in the invalid is: {invalid_df.count()}")

        # validation rule 2 based on the vendor id
        # The CSVs are read without a schema, so VendorId is a string column.
        # fillna skips columns whose type differs from the fill value, which
        # made the numeric 999 a silent no-op; the placeholder is a string.
        valid_df = replace_null_values(df=valid_df, col="VendorId", value="999")

        # Count the rows for each distinct VendorId value
        distinct_counts = valid_df.groupBy("VendorId").count()

        # Show the results
        distinct_counts.show()

        # step 4 - deduplicate the data
        deduped_df = deduplicate_df(df=valid_df, cols=["PickUpLocationId", "PickUpDateTime", "DropOffDateTime", "DropOffLocationId", "VendorId"])

        # step 5 - apply aggregations
        locations_df = aggregate_locations_df(df=deduped_df)
        dropoff_df = aggregate_dropoffs(df=deduped_df)

        # NOTE(review): this reuses the name of the merged trip data from
        # step 2 and the result is not used again in this cell - confirm
        # whether the row count below should report this combined frame.
        merged_df = merge_dataframes(locations_df, dropoff_df)

        vendors_df = aggregate_vendors(df=deduped_df)

        print(f"The number of rows in the locations is: {locations_df.count()}")
        print(f"The number of rows in the vendors is: {vendors_df.count()}")

    finally:
        stop_spark_session()

The number of rows in the yellow is: 1369819
The number of rows in the green is: 76539
+--------+-------------------+-------------------+----------------+-----------------+--------------+------------+---------+-----------+
|VendorId|     PickUpDateTime|    DropOffDateTime|PickUpLocationId|DropOffLocationId|PassengerCount|TripDistance|TipAmount|TotalAmount|
+--------+-------------------+-------------------+----------------+-----------------+--------------+------------+---------+-----------+
|       2|2021-01-01 00:42:11|2021-01-01 00:44:24|              50|              142|             5|         .81|        0|        8.3|
|       2|2021-01-01 00:43:41|2021-01-01 00:48:17|             239|              142|             3|         .93|        0|        9.3|
|       1|2021-01-01 00:16:27|2021-01-01 00:25:36|             249|              137|             0|        2.20|        0|       12.8|
|       2|2021-01-01 00:55:19|2021-01-01 00:58:45|             263|               75|            