In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from warnings import simplefilter
simplefilter('ignore')

In [2]:
# Start a local Spark session with 4 worker threads. getOrCreate() hands back the
# already-running session when this cell is re-executed instead of building a new one.
session_builder = SparkSession.builder.appName("pyspark-sql.sandbox").master("local[4]")
spark = session_builder.getOrCreate()

# When finished, call spark.stop() to release the local Spark resources.

22/08/31 16:28:33 WARN Utils: Your hostname, mpb-m1max.local resolves to a loopback address: 127.0.0.1; using 192.168.1.12 instead (on interface en0)
22/08/31 16:28:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/31 16:28:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the 2021 green-taxi trips: every parquet file two directory levels below
# green/2021 (presumably one sub-folder per month), then show the inferred schema.
green_path = "../data/part/green/2021/*/*"
df_green = spark.read.parquet(green_path)
df_green.printSchema()

                                                                                

root
 |-- VendorID: long (nullable = true)
 |-- lpep_pickup_datetime: timestamp (nullable = true)
 |-- lpep_dropoff_datetime: timestamp (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- ehail_fee: integer (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- payment_type: double (nullable = true)
 |-- trip_type: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [4]:
# Load the 2021 yellow-taxi trips: every parquet file two directory levels below
# yellow/2021 (presumably one sub-folder per month), then show the inferred schema.
yellow_path = "../data/part/yellow/2021/*/*"
df_yellow = spark.read.parquet(yellow_path)
df_yellow.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)



In [5]:
# Give both datasets the same timestamp column names: green trips use the
# "lpep_" prefix and yellow trips the "tpep_" prefix (see the schemas above).
def _rename_columns(df, renames):
    """Return ``df`` with each column renamed according to ``renames`` (old name -> new name).

    Renames are applied in the mapping's insertion order.
    """
    for old_name, new_name in renames.items():
        df = df.withColumnRenamed(old_name, new_name)
    return df


df_green = _rename_columns(df_green, {
    "lpep_pickup_datetime": "pickup_datetime",
    "lpep_dropoff_datetime": "dropoff_datetime",
})
df_yellow = _rename_columns(df_yellow, {
    "tpep_pickup_datetime": "pickup_datetime",
    "tpep_dropoff_datetime": "dropoff_datetime",
})

In [6]:
# Columns present in both datasets, listed in the green dataset's column order
yellow_column_set = set(df_yellow.columns)
tripdata_columns = [name for name in df_green.columns if name in yellow_column_set]
tripdata_columns

['VendorID',
 'pickup_datetime',
 'dropoff_datetime',
 'store_and_fwd_flag',
 'RatecodeID',
 'PULocationID',
 'DOLocationID',
 'passenger_count',
 'trip_distance',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'payment_type',
 'congestion_surcharge']

In [7]:
# Keep only the columns both services share (same order for both) and tag every
# row with the taxi service it came from.
df_green = df_green.select(*tripdata_columns).withColumn("service", F.lit("green"))
df_yellow = df_yellow.select(*tripdata_columns).withColumn("service", F.lit("yellow"))

# Stack both datasets with UNION ALL semantics (duplicate rows are kept).
# unionByName matches columns by name instead of by position, so a future change
# to either select cannot silently shift values into the wrong column.
# NOTE(review): payment_type is double for green but long for yellow (see the
# schemas above); Spark coerces it to a common type in the union — presumably
# double, confirm.
df_tripdata = df_green.unionByName(df_yellow)

In [8]:
# Check the number of rows for each taxi type
%time df_tripdata.groupBy("service").count().show()



+-------+--------+
|service|   count|
+-------+--------+
|  green|  319081|
| yellow|10150822|
+-------+--------+

CPU times: user 4.78 ms, sys: 2.2 ms, total: 6.98 ms
Wall time: 3.78 s



                                                                                

## PySpark SQL

In [9]:
# Make the combined trips queryable from Spark SQL as "nyc_taxi_tripdata".
# The view lives as long as this SparkSession and replaces any view of the same name.
df_tripdata.createOrReplaceTempView(name="nyc_taxi_tripdata")

In [10]:
# Run a SQL query
%time spark.sql("SELECT service, COUNT(*) AS count FROM nyc_taxi_tripdata GROUP BY service").show()



+-------+--------+
|service|   count|
+-------+--------+
|  green|  319081|
| yellow|10150822|
+-------+--------+

CPU times: user 2.55 ms, sys: 1.25 ms, total: 3.8 ms
Wall time: 1.28 s



                                                                                

In [11]:
%%time 

# Fancy query
df_result = spark.sql("""
SELECT 
    -- Reveneue grouping
    DATE_TRUNC('month', pickup_datetime) AS revenue_month, 
    PULocationID AS revenue_zone,
    service, 

    -- Revenue calculation 
    ROUND(SUM(fare_amount), 2) AS revenue_monthly_fare,
    ROUND(SUM(extra), 2) AS revenue_monthly_extra,
    ROUND(SUM(mta_tax), 2) AS revenue_monthly_mta_tax,
    ROUND(SUM(tip_amount), 2) AS revenue_monthly_tip_amount,
    ROUND(SUM(tolls_amount), 2) AS revenue_monthly_tolls_amount,
    ROUND(SUM(improvement_surcharge), 2) AS revenue_monthly_improvement_surcharge,
    ROUND(SUM(total_amount), 2) AS revenue_monthly_total_amount,
    ROUND(SUM(congestion_surcharge), 2) AS revenue_monthly_congestion_surcharge,

    -- Additional calculations
    ROUND(AVG(passenger_count), 2) AS avg_montly_passenger_count,
    ROUND(AVG(trip_distance), 2) AS avg_montly_trip_distance
FROM
    nyc_taxi_tripdata
GROUP BY
    revenue_month, revenue_zone, service
"""
)

CPU times: user 1.11 ms, sys: 841 µs, total: 1.95 ms
Wall time: 222 ms


In [12]:
%%time 

# Check some datapoints
df_result.select(
    "revenue_month",
    "revenue_zone", 
    "service", 
    "revenue_monthly_fare").show(10)



+-------------------+------------+-------+--------------------+
|      revenue_month|revenue_zone|service|revenue_monthly_fare|
+-------------------+------------+-------+--------------------+
|2021-10-01 00:00:00|         225|  green|            15349.12|
|2021-10-01 00:00:00|          23|  green|             7239.53|
|2021-10-01 00:00:00|          11|  green|             2438.41|
|2021-10-01 00:00:00|          79|  green|             1013.59|
|2021-10-01 00:00:00|           6|  green|              879.45|
|2021-09-01 00:00:00|          82|  green|                58.5|
|2021-10-01 00:00:00|         205|  green|            13412.16|
|2021-10-01 00:00:00|          60|  green|             4616.21|
|2021-10-01 00:00:00|           4|  green|              951.82|
|2021-10-01 00:00:00|         129|  green|            30595.45|
+-------------------+------------+-------+--------------------+
only showing top 10 rows

CPU times: user 5.33 ms, sys: 2.14 ms, total: 7.47 ms
Wall time: 3.86 s



                                                                                

In [13]:
%%time

# Write Resulting table to parquet
df_result.write.parquet("../data/reports/revenue/tripdata_all/", mode = "overwrite")

[Stage 13:>                                                         (0 + 1) / 1]

CPU times: user 4.18 ms, sys: 1.5 ms, total: 5.69 ms
Wall time: 5.03 s



                                                                                