In [1]:
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()

# Caching a DataFrame

You've been assigned a task that requires running several analysis operations on a DataFrame. You've learned that caching can improve performance when reusing DataFrames and would like to implement it.

In [16]:
# Import the pyspark.sql.types library
from pyspark.sql.types import *

# Define a new schema using the StructType method
departure_schema = StructType([
  # Define a StructField for each field
  StructField('Date (MM/DD/YYYY)', DataType(), True),
  StructField('Flight Number', StringType(), True),
  StructField('Destination Airport', StringType(), True),
  StructField('Actual elapsed time (Minutes)', IntegerType(), True),
])

departures_df = spark.read.csv("dataset/AA_DFW_2017_Departures_Short.csv", header=True)
# departures_df.show(3)
departures_df.dtypes

[('Date (MM/DD/YYYY)', 'string'),
 ('Flight Number', 'string'),
 ('Destination Airport', 'string'),
 ('Actual elapsed time (Minutes)', 'string')]

In [19]:
import time
start_time = time.time()

# Add caching to the unique rows in departures_df
departures_df = departures_df.distinct().cache()

# Count the unique rows in departures_df, noting how long the operation takes
print("Counting %d rows took %f seconds" % (departures_df.count(), time.time() - start_time))



Counting 139358 rows took 1.847591 seconds


In [20]:

# Count the rows again, noting the variance in time of a cached DataFrame
start_time = time.time()
print("Counting %d rows again took %f seconds" % (departures_df.count(), time.time() - start_time))

Counting 139358 rows again took 0.148730 seconds


# Removing a DataFrame from cache

You've finished the analysis tasks with the departures_df DataFrame, but have some other processing to do. You'd like to remove the DataFrame from the cache to prevent any excess memory usage on your cluster.

In [21]:
# Determine if departures_df is in the cache
print("Is departures_df cached?: %s" % departures_df.is_cached)
print("Removing departures_df from cache")



Is departures_df cached?: True
Removing departures_df from cache


In [22]:
# Remove departures_df from the cache
departures_df.unpersist()

# Check the cache status again
print("Is departures_df cached?: %s" % departures_df.is_cached)

Is departures_df cached?: False


# File size optimization

Consider if you're given 2 large data files on a cluster with 10 nodes. Each file contains 10M rows of roughly the same size. While working with your data, the responsiveness is acceptable but the initial read from the files takes a considerable period of time. Note that you are the only one who will use the data and it changes for each run.

Which of the following is the best option to improve performance?

- Split the 2 files into 50 files of 400K rows each.

# File import performance

You've been given a large set of data to import into a Spark DataFrame. You'd like to test the difference in import speed by splitting up the file.

In [23]:
# # Import the full and split files into DataFrames
# full_df = spark.read.csv('departures_full.txt.gz')
# split_df = spark.read.csv('departures_*.txt.gz') 

# # Print the count and run time for each DataFrame
# start_time_a = time.time()
# print("Total rows in full DataFrame:\t%d" % full_df.count())
# print("Time to run: %f" % (time.time() - start_time_a))

# start_time_b = time.time()
# print("Total rows in split DataFrame:\t%d" % split_df.count())
# print("Time to run: %f" % (time.time() - start_time_b))

# Reading Spark configurations

You've recently configured a cluster via a cloud provider. Your only access is via the command shell or your python code. You'd like to verify some Spark settings to validate the configuration of the cluster

In [24]:
# Name of the Spark application instance
app_name = spark.conf.get("spark.app.name")

# Driver TCP port
driver_tcp_port = spark.conf.get("spark.driver.port")

# Number of join partitions
num_partitions = spark.conf.get('spark.sql.shuffle.partitions')

# Show the results
print("Name: %s" % app_name)
print("Driver TCP port: %s" % driver_tcp_port)
print("Number of partitions: %s" % num_partitions)

Name: example
Driver TCP port: 63379
Number of partitions: 200


# Writing Spark configurations

Now that you've reviewed some of the Spark configurations on your cluster, you want to modify some of the settings to tune Spark to your needs. You'll import some data to review that your changes have affected the cluster.

In [25]:
# # Store the number of partitions in variable
# before = departures_df.rdd.getNumPartitions()

# # Configure Spark to use 500 partitions
# spark.conf.set('spark.sql.shuffle.partitions', 500)

# # Recreate the DataFrame using the departures data file
# departures_df = spark.read.csv('departures.txt.gz').distinct()

# # Print the number of partitions for each instance
# print("Partition count before change: %d" % before)
# print("Partition count after change: %d" % departures_df.count())

# Normal joins

You've been given two DataFrames to combine into a single useful DataFrame. Your first task is to combine the DataFrames normally and view the execution plan.

In [26]:
flights_df = spark.read.csv("dataset/AA_DFW_2016_Departures_Short.csv", header=True)
airports_df =  spark.read.csv("dataset/AA_DFW_2017_Departures_Short.csv", header=True)
flights_df.show(3)
airports_df.show(3)


+-----------------+-------------+-------------------+-----------------------------+
|Date (MM/DD/YYYY)|Flight Number|Destination Airport|Actual elapsed time (Minutes)|
+-----------------+-------------+-------------------+-----------------------------+
|       01/01/2016|         0005|                HNL|                          529|
|       01/01/2016|         0007|                OGG|                          512|
|       01/01/2016|         0025|                PHL|                          161|
+-----------------+-------------+-------------------+-----------------------------+
only showing top 3 rows

+-----------------+-------------+-------------------+-----------------------------+
|Date (MM/DD/YYYY)|Flight Number|Destination Airport|Actual elapsed time (Minutes)|
+-----------------+-------------+-------------------+-----------------------------+
|       01/01/2017|         0005|                HNL|                          537|
|       01/01/2017|         0007|                OG

In [27]:
# Join the flights_df and aiports_df DataFrames
normal_df = flights_df.join(airports_df, \
    flights_df["Destination Airport"] == airports_df["Destination Airport"] )

# Show the query plan
normal_df.explain()


== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- BroadcastHashJoin [Destination Airport#516], [Destination Airport#541], Inner, BuildRight, false
   :- Filter isnotnull(Destination Airport#516)
   :  +- FileScan csv [Date (MM/DD/YYYY)#514,Flight Number#515,Destination Airport#516,Actual elapsed time (Minutes)#517] Batched: false, DataFilters: [isnotnull(Destination Airport#516)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/c:/Datacamp/Python/Cleaning Data with PySpark/dataset/AA_DFW_201..., PartitionFilters: [], PushedFilters: [IsNotNull(Destination Airport)], ReadSchema: struct<Date (MM/DD/YYYY):string,Flight Number:string,Destination Airport:string,Actual elapsed ti...
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[2, string, false]),false), [plan_id=580]
      +- Filter isnotnull(Destination Airport#541)
         +- FileScan csv [Date (MM/DD/YYYY)#539,Flight Number#540,Destination Airport#541,Actual elapsed time (Minutes)#542] Batched: false, Da

# Using broadcasting on Spark joins

Remember that table joins in Spark are split between the cluster workers. If the data is not local, various shuffle operations are required and can have a negative impact on performance. Instead, we're going to use Spark's broadcast operations to give each node a copy of the specified data.

In [29]:
# Import the broadcast method from pyspark.sql.functions
from pyspark.sql.functions import broadcast

# Join the flights_df and airports_df DataFrames using broadcasting
broadcast_df = flights_df.join(broadcast(airports_df), \
    flights_df["Destination Airport"] == airports_df["Destination Airport"] )

# Show the query plan and compare against the original
broadcast_df.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- BroadcastHashJoin [Destination Airport#516], [Destination Airport#541], Inner, BuildRight, false
   :- Filter isnotnull(Destination Airport#516)
   :  +- FileScan csv [Date (MM/DD/YYYY)#514,Flight Number#515,Destination Airport#516,Actual elapsed time (Minutes)#517] Batched: false, DataFilters: [isnotnull(Destination Airport#516)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/c:/Datacamp/Python/Cleaning Data with PySpark/dataset/AA_DFW_201..., PartitionFilters: [], PushedFilters: [IsNotNull(Destination Airport)], ReadSchema: struct<Date (MM/DD/YYYY):string,Flight Number:string,Destination Airport:string,Actual elapsed ti...
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[2, string, false]),false), [plan_id=603]
      +- Filter isnotnull(Destination Airport#541)
         +- FileScan csv [Date (MM/DD/YYYY)#539,Flight Number#540,Destination Airport#541,Actual elapsed time (Minutes)#542] Batched: false, Da

# Comparing broadcast vs normal joins

You've created two types of joins, normal and broadcasted. Now your manager would like to know what the performance improvement is by using Spark optimizations. If the results are promising, you'll be given more opportunity to tweak the Spark setup as needed.

In [30]:
start_time = time.time()
# Count the number of rows in the normal DataFrame
normal_count = normal_df.count()
normal_duration = time.time() - start_time

start_time = time.time()
# Count the number of rows in the broadcast DataFrame
broadcast_count = broadcast_df.count()
broadcast_duration = time.time() - start_time

# Print the counts and the duration of the tests
print("Normal count:\t\t%d\tduration: %f" % (normal_count, normal_duration))
print("Broadcast count:\t%d\tduration: %f" % (broadcast_count, broadcast_duration))

Normal count:		369539750	duration: 23.009646
Broadcast count:	369539750	duration: 19.118006
