## Instantiating Spark

Import the PySpark library and the `SparkSession` class


In [1]:
import pyspark
from pyspark.sql import SparkSession
import warnings

# Suppress every Python warning for the rest of the session
# (presumably to keep cell output tidy; this also hides deprecation warnings).
warnings.filterwarnings('ignore')

In [2]:
# Start (or reuse) a Spark session called 'demo' running locally;
# "local[*]" asks Spark to use all available cores on this machine.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('demo')
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/13 14:33:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Read data

Download a large Parquet file (about 295 MB) of trip data to work with.

In [3]:
!wget https://d37ci6vzurychx.cloudfront.net/trip-data/fhvhv_tripdata_2021-01.parquet

--2025-03-13 14:34:22--  https://d37ci6vzurychx.cloudfront.net/trip-data/fhvhv_tripdata_2021-01.parquet
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 3.164.82.197, 3.164.82.160, 3.164.82.112, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)|3.164.82.197|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 308924937 (295M) [application/x-www-form-urlencoded]
Saving to: ‘fhvhv_tripdata_2021-01.parquet’


2025-03-13 14:34:53 (9.86 MB/s) - ‘fhvhv_tripdata_2021-01.parquet’ saved [308924937/308924937]



In [4]:
# Read the Parquet file. Parquet embeds its own schema, so the column types
# (timestamps, longs, doubles, strings) come straight from the file.
# CSV-only options such as "header" have no effect on Parquet and are omitted.
df = spark.read.parquet("fhvhv_tripdata_2021-01.parquet")

df.show()

                                                                                

+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+-----------+----+----------+-------------------+-----------------+------------------+----------------+--------------+
|hvfhs_license_num|dispatching_base_num|originating_base_num|   request_datetime|  on_scene_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls| bcf|sales_tax|congestion_surcharge|airport_fee|tips|driver_pay|shared_request_flag|shared_match_flag|access_a_ride_flag|wav_request_flag|wav_match_flag|
+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+-----------+--

In [5]:
# Return the first five rows as a list of Row objects (values keep their types, they are not strings)
df.head(5)

                                                                                

[Row(hvfhs_license_num='HV0003', dispatching_base_num='B02682', originating_base_num='B02682', request_datetime=datetime.datetime(2021, 1, 1, 0, 28, 9), on_scene_datetime=datetime.datetime(2021, 1, 1, 0, 31, 42), pickup_datetime=datetime.datetime(2021, 1, 1, 0, 33, 44), dropoff_datetime=datetime.datetime(2021, 1, 1, 0, 49, 7), PULocationID=230, DOLocationID=166, trip_miles=5.26, trip_time=923, base_passenger_fare=22.28, tolls=0.0, bcf=0.67, sales_tax=1.98, congestion_surcharge=2.75, airport_fee=None, tips=0.0, driver_pay=14.99, shared_request_flag='N', shared_match_flag='N', access_a_ride_flag=' ', wav_request_flag='N', wav_match_flag='N'),
 Row(hvfhs_license_num='HV0003', dispatching_base_num='B02682', originating_base_num='B02682', request_datetime=datetime.datetime(2021, 1, 1, 0, 45, 56), on_scene_datetime=datetime.datetime(2021, 1, 1, 0, 55, 19), pickup_datetime=datetime.datetime(2021, 1, 1, 0, 55, 19), dropoff_datetime=datetime.datetime(2021, 1, 1, 1, 18, 21), PULocationID=152, DO

In [6]:
# Schema of the DataFrame: a StructType listing each field's name, data type and nullability.
# The types come from the Parquet file, so they are not all strings.
df.schema

StructType([StructField('hvfhs_license_num', StringType(), True), StructField('dispatching_base_num', StringType(), True), StructField('originating_base_num', StringType(), True), StructField('request_datetime', TimestampNTZType(), True), StructField('on_scene_datetime', TimestampNTZType(), True), StructField('pickup_datetime', TimestampNTZType(), True), StructField('dropoff_datetime', TimestampNTZType(), True), StructField('PULocationID', LongType(), True), StructField('DOLocationID', LongType(), True), StructField('trip_miles', DoubleType(), True), StructField('trip_time', LongType(), True), StructField('base_passenger_fare', DoubleType(), True), StructField('tolls', DoubleType(), True), StructField('bcf', DoubleType(), True), StructField('sales_tax', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('airport_fee', DoubleType(), True), StructField('tips', DoubleType(), True), StructField('driver_pay', DoubleType(), True), StructField('shared_re

In [7]:
# Column names of the DataFrame as a plain Python list
df.columns

['hvfhs_license_num',
 'dispatching_base_num',
 'originating_base_num',
 'request_datetime',
 'on_scene_datetime',
 'pickup_datetime',
 'dropoff_datetime',
 'PULocationID',
 'DOLocationID',
 'trip_miles',
 'trip_time',
 'base_passenger_fare',
 'tolls',
 'bcf',
 'sales_tax',
 'congestion_surcharge',
 'airport_fee',
 'tips',
 'driver_pay',
 'shared_request_flag',
 'shared_match_flag',
 'access_a_ride_flag',
 'wav_request_flag',
 'wav_match_flag']

## Creating a schema using Pandas

Write the first 1001 lines (the header plus 1000 rows) of `output.csv` to a smaller file, `head.csv`.

In [8]:
# Keep the first 1001 lines of output.csv in head.csv.
# NOTE(review): output.csv is not created by any cell shown in this notebook;
# presumably it is a CSV export of the trip data produced elsewhere. Confirm it exists before running.
!head -n 1001 output.csv > head.csv

In [9]:
# Count the lines in head.csv (expected: 1001)
!wc -l head.csv

1001 head.csv


We will now use Pandas to infer the datatypes. Let's import pandas and read the small .csv file.

In [10]:
import pandas as pd

# Load the small CSV sample with pandas, letting it infer the column dtypes
df_pandas = pd.read_csv("head.csv")

In [11]:
# Let's see how pandas infers the datatypes - it does a good job, except that the
# datetime fields are left as generic `object` (string) columns
df_pandas.dtypes

hvfhs_license_num        object
dispatching_base_num     object
originating_base_num     object
request_datetime         object
on_scene_datetime        object
pickup_datetime          object
dropoff_datetime         object
PULocationID              int64
DOLocationID              int64
trip_miles              float64
trip_time                 int64
base_passenger_fare     float64
tolls                   float64
bcf                     float64
sales_tax               float64
congestion_surcharge    float64
tips                    float64
driver_pay              float64
shared_request_flag      object
shared_match_flag        object
access_a_ride_flag       object
wav_request_flag         object
wav_match_flag           object
dtype: object

In [12]:
# Convert the pandas DataFrame into a Spark DataFrame and read its .schema attribute.
# This shows the dtypes pandas inferred (seen above via .dtypes) expressed as Spark types:
# object -> StringType, int64 -> LongType, float64 -> DoubleType.
spark.createDataFrame(df_pandas).schema

StructType([StructField('hvfhs_license_num', StringType(), True), StructField('dispatching_base_num', StringType(), True), StructField('originating_base_num', StringType(), True), StructField('request_datetime', StringType(), True), StructField('on_scene_datetime', StringType(), True), StructField('pickup_datetime', StringType(), True), StructField('dropoff_datetime', StringType(), True), StructField('PULocationID', LongType(), True), StructField('DOLocationID', LongType(), True), StructField('trip_miles', DoubleType(), True), StructField('trip_time', LongType(), True), StructField('base_passenger_fare', DoubleType(), True), StructField('tolls', DoubleType(), True), StructField('bcf', DoubleType(), True), StructField('sales_tax', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('tips', DoubleType(), True), StructField('driver_pay', DoubleType(), True), StructField('shared_request_flag', StringType(), True), StructField('shared_match_flag', Strin

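The `StructType(...)` output above uses Spark's schema notation, which comes from Scala, so we need to turn it into Python code to declare the schema for our DataFrame. Let's take this opportunity to also change the data types to more suitable (and, where possible, more optimised) ones than what was inferred above — for example, the datetime columns, which pandas left as strings, should be timestamps.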

- Integer = 4 bytes
- Long = 8 bytes

> Note: columns stored as INT64 in Parquet are read by Spark as `LongType`, not `IntegerType`, so the integer columns are kept as `LongType` in the schema below.

In [13]:
from pyspark.sql import types

In [14]:
# Hand-written schema, based on the one Spark derived from the pandas sample.
# The four datetime columns (inferred as strings there) are declared as TimestampType;
# integer columns stay LongType and decimals stay DoubleType. Every field is nullable.
# airport_fee, which is present in the Parquet file, is not listed here.

schema = types.StructType(
    [
        types.StructField(column_name, column_type, True)
        for column_name, column_type in [
            ('hvfhs_license_num', types.StringType()),
            ('dispatching_base_num', types.StringType()),
            ('originating_base_num', types.StringType()),
            ('request_datetime', types.TimestampType()),
            ('on_scene_datetime', types.TimestampType()),
            ('pickup_datetime', types.TimestampType()),
            ('dropoff_datetime', types.TimestampType()),
            ('PULocationID', types.LongType()),
            ('DOLocationID', types.LongType()),
            ('trip_miles', types.DoubleType()),
            ('trip_time', types.LongType()),
            ('base_passenger_fare', types.DoubleType()),
            ('tolls', types.DoubleType()),
            ('bcf', types.DoubleType()),
            ('sales_tax', types.DoubleType()),
            ('congestion_surcharge', types.DoubleType()),
            ('tips', types.DoubleType()),
            ('driver_pay', types.DoubleType()),
            ('shared_request_flag', types.StringType()),
            ('shared_match_flag', types.StringType()),
            ('access_a_ride_flag', types.StringType()),
            ('wav_request_flag', types.StringType()),
            ('wav_match_flag', types.StringType()),
        ]
    ]
)

In [15]:
# Read the Parquet file into a Spark DataFrame again, this time applying the
# explicit schema declared above, then peek at the first five rows.

df = (
    spark.read
    .schema(schema)
    .parquet("fhvhv_tripdata_2021-01.parquet")
)
df.head(5)

                                                                                

[Row(hvfhs_license_num='HV0003', dispatching_base_num='B02682', originating_base_num='B02682', request_datetime=datetime.datetime(2021, 1, 1, 0, 28, 9), on_scene_datetime=datetime.datetime(2021, 1, 1, 0, 31, 42), pickup_datetime=datetime.datetime(2021, 1, 1, 0, 33, 44), dropoff_datetime=datetime.datetime(2021, 1, 1, 0, 49, 7), PULocationID=230, DOLocationID=166, trip_miles=5.26, trip_time=923, base_passenger_fare=22.28, tolls=0.0, bcf=0.67, sales_tax=1.98, congestion_surcharge=2.75, tips=0.0, driver_pay=14.99, shared_request_flag='N', shared_match_flag='N', access_a_ride_flag=' ', wav_request_flag='N', wav_match_flag='N'),
 Row(hvfhs_license_num='HV0003', dispatching_base_num='B02682', originating_base_num='B02682', request_datetime=datetime.datetime(2021, 1, 1, 0, 45, 56), on_scene_datetime=datetime.datetime(2021, 1, 1, 0, 55, 19), pickup_datetime=datetime.datetime(2021, 1, 1, 0, 55, 19), dropoff_datetime=datetime.datetime(2021, 1, 1, 1, 18, 21), PULocationID=152, DOLocationID=167, tr

## Partitions

In [16]:
# Ask Spark to split the DataFrame into 24 partitions. repartition() is lazy: nothing is
# executed here, the result is just a new DataFrame with the repartitioning planned.
# The data is only redistributed when an action runs (for example the .write below).
df = df.repartition(24)

In [17]:
# Save the repartitioned DataFrame as Parquet under fhvhv/2021/01/,
# replacing whatever a previous run left in that folder.
df.write.mode('overwrite').parquet('fhvhv/2021/01/')

                                                                                

In [19]:
# List what Spark wrote: a _SUCCESS marker plus the snappy-compressed Parquet part files
!ls -lh fhvhv/2021/01/

total 526M
-rw-r--r-- 1 peter peter   0 Mar 13 14:37 _SUCCESS
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00000-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00001-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00002-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00003-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00004-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00005-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00006-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00007-c310a6b6-ef55-417e-a1cf-d843d3658b8b-c000.snappy.parquet
-rw-r--r-- 1 peter peter 22M Mar 13 14:37 part-00008-c310a

## Spark DataFrames

In [20]:
# Create a Spark DataFrame from the folder of partitioned Parquet files written earlier
df = spark.read.parquet('fhvhv/2021/01/')

# Check the schema (no explicit schema is passed here; it is taken from the Parquet files)
df.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- request_datetime: timestamp (nullable = true)
 |-- on_scene_datetime: timestamp (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: long (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- access_a_ride_flag: string (nullable = true)
 |-- wav_request_flag: string

## Functions and UDFs

In [21]:
from pyspark.sql import functions as F

In [22]:
# Preview the first rows of the reloaded DataFrame as a text table
df.show()

+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+-----------------+------------------+----------------+--------------+
|hvfhs_license_num|dispatching_base_num|originating_base_num|   request_datetime|  on_scene_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls| bcf|sales_tax|congestion_surcharge|tips|driver_pay|shared_request_flag|shared_match_flag|access_a_ride_flag|wav_request_flag|wav_match_flag|
+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+--

In [23]:

# Built-in column functions: F.to_date() reduces each timestamp to its calendar date.
# Add the two date columns, keep them alongside the location ids, and print a preview.
(
    df
    .withColumn('pickup_date', F.to_date(df['pickup_datetime']))
    .withColumn('dropoff_date', F.to_date(df['dropoff_datetime']))
    .select('pickup_date', 'dropoff_date', 'PULocationID', 'DOLocationID')
    .show()
)

+-----------+------------+------------+------------+
|pickup_date|dropoff_date|PULocationID|DOLocationID|
+-----------+------------+------------+------------+
| 2021-01-25|  2021-01-25|          86|          86|
| 2021-01-07|  2021-01-07|         238|          74|
| 2021-01-06|  2021-01-06|          41|          41|
| 2021-01-22|  2021-01-22|         216|         216|
| 2021-01-30|  2021-01-30|          39|          76|
| 2021-01-10|  2021-01-10|          61|         225|
| 2021-01-18|  2021-01-18|         107|          68|
| 2021-01-14|  2021-01-14|         129|          75|
| 2021-01-01|  2021-01-01|         208|         140|
| 2021-01-10|  2021-01-10|         119|         265|
| 2021-01-02|  2021-01-02|          94|         174|
| 2021-01-21|  2021-01-21|         231|         234|
| 2021-01-08|  2021-01-08|          89|          91|
| 2021-01-30|  2021-01-30|          79|         141|
| 2021-01-15|  2021-01-15|          61|          61|
| 2021-01-03|  2021-01-03|         246|       

In [24]:
# Creating a user-defined function (UDF)

def crazy_stuff(base_num):
    """Encode a base number such as 'B02884' as a short, tagged hex id.

    The first character is dropped and the remaining characters are parsed as
    an integer. That integer is rendered as lowercase hex, zero-padded to at
    least three digits, behind a prefix chosen by divisibility:

    * 's/' when the number is divisible by 7,
    * 'a/' otherwise, when it is divisible by 3,
    * 'e/' in every other case.

    Args:
        base_num: Base number string (one leading character followed by
            digits), or None / an empty string for a missing value.

    Returns:
        The encoded id (e.g. 's/b44' for 'B02884'), or None when base_num is
        None or empty.

    Raises:
        ValueError: If the characters after the first one are not an integer.
    """
    # Spark hands NULL column values to a Python UDF as None. Without this
    # guard a single NULL raises TypeError on the slice below and fails the job.
    if not base_num:
        return None

    num = int(base_num[1:])
    if num % 7 == 0:
        prefix = 's'
    elif num % 3 == 0:
        prefix = 'a'
    else:
        prefix = 'e'
    return f'{prefix}/{num:03x}'

# Wrap the plain Python function as a Spark UDF whose result is a string column
crazy_stuff_udf = F.udf(crazy_stuff, returnType=types.StringType())

In [25]:
# Same date columns as before, plus a 'base_id' column computed row by row
# by the Python UDF from dispatching_base_num; print a preview.
(
    df
    .withColumn('pickup_date', F.to_date(df['pickup_datetime']))
    .withColumn('dropoff_date', F.to_date(df['dropoff_datetime']))
    .withColumn('base_id', crazy_stuff_udf(df['dispatching_base_num']))
    .select('base_id', 'pickup_date', 'dropoff_date', 'PULocationID', 'DOLocationID')
    .show()
)

[Stage 14:>                                                         (0 + 1) / 1]

+-------+-----------+------------+------------+------------+
|base_id|pickup_date|dropoff_date|PULocationID|DOLocationID|
+-------+-----------+------------+------------+------------+
|  e/9ce| 2021-01-25|  2021-01-25|          86|          86|
|  s/b3d| 2021-01-07|  2021-01-07|         238|          74|
|  e/b30| 2021-01-06|  2021-01-06|          41|          41|
|  s/b3d| 2021-01-22|  2021-01-22|         216|         216|
|  e/b38| 2021-01-30|  2021-01-30|          39|          76|
|  e/b32| 2021-01-10|  2021-01-10|          61|         225|
|  e/a39| 2021-01-18|  2021-01-18|         107|          68|
|  e/9ce| 2021-01-14|  2021-01-14|         129|          75|
|  e/b47| 2021-01-01|  2021-01-01|         208|         140|
|  e/9ce| 2021-01-10|  2021-01-10|         119|         265|
|  s/b3d| 2021-01-02|  2021-01-02|          94|         174|
|  e/9ce| 2021-01-21|  2021-01-21|         231|         234|
|  e/b32| 2021-01-08|  2021-01-08|          89|          91|
|  e/9ce| 2021-01-30|  2

                                                                                