In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
import time

In [3]:
# Obtain the SparkSession for this notebook: getOrCreate() returns an existing
# session if one is already active, otherwise it builds a new one named "ExampleApp".
spark = (
    SparkSession.builder
    .appName("ExampleApp")
    .getOrCreate()
)

In [4]:
# Load the gzipped departures CSV; the first line of the file supplies the column names.
departures_df = (
    spark.read
    .option('header', True)
    .csv('../data/AA_DFW_2014_Departures_Short.csv.gz')
)

In [5]:
# Time two consecutive counts of the de-duplicated departures. The first count
# has to evaluate distinct() and populate the cache; the second one can reuse
# the cached rows, so it is expected to be noticeably faster.
start_time = time.time()

# Keep only unique rows and mark the result for caching. cache() only marks
# the DataFrame; the data is materialised by the first action below.
departures_df = departures_df.distinct().cache()

# First count: computes the unique rows and fills the cache
first_count = departures_df.count()
first_elapsed = time.time() - start_time
print("Counting %d rows took %f seconds" % (first_count, first_elapsed))

# Second count: should be served from the cached DataFrame
start_time = time.time()
second_count = departures_df.count()
second_elapsed = time.time() - start_time
print("Counting %d rows again took %f seconds" % (second_count, second_elapsed))

Counting 157198 rows took 6.182705 seconds
Counting 157198 rows again took 1.149794 seconds


In [6]:
# Show the cache status of departures_df, evict it, then show the status again.
cache_status_msg = "Is departures_df cached?: %s"

# Status before eviction
print(cache_status_msg % departures_df.is_cached)

# Evict departures_df from the cache
print("Removing departures_df from cache")
departures_df.unpersist()

# Status after eviction
print(cache_status_msg % departures_df.is_cached)

Is departures_df cached?: True
Removing departures_df from cache
Is departures_df cached?: False


In [7]:
# # Import the full and split files into DataFrames
# full_df = spark.read.csv('departures_full.txt.gz')
# split_df = spark.read.csv('departures_0*.txt.gz')

# # Print the count and run time for each DataFrame
# start_time_a = time.time()
# print("Total rows in full DataFrame:\t%d" % full_df.count())
# print("Time to run: %f" % (time.time() - start_time_a))

# start_time_b = time.time()
# print("Total rows in split DataFrame:\t%d" % split_df.count())
# print("Time to run: %f" % (time.time() - start_time_b))

# Result: importing the split files is faster than importing the single full file.

In [8]:
# Read three settings back from the running session's configuration, in order:
#   spark.app.name               -> name given to the SparkSession builder
#   spark.driver.port            -> TCP port of the driver
#   spark.sql.shuffle.partitions -> partitions used when shuffling for joins/aggregations
app_name, driver_tcp_port, num_partitions = (
    spark.conf.get(conf_key)
    for conf_key in (
        'spark.app.name',
        'spark.driver.port',
        'spark.sql.shuffle.partitions',
    )
)

# Show the results
print("Name: %s" % app_name)
print("Driver TCP port: %s" % driver_tcp_port)
print("Number of partitions: %s" % num_partitions)

Name: ExampleApp
Driver TCP port: 65435
Number of partitions: 200


In [9]:
# Record how many partitions the current departures_df has before the config change
before = departures_df.rdd.getNumPartitions()

# Configure Spark to use 500 partitions when shuffling data for joins or
# aggregations (distinct() below triggers such a shuffle). Note: this changes
# the session-wide setting for every later cell as well.
spark.conf.set('spark.sql.shuffle.partitions', 500)

# Recreate the DataFrame using the departures data file. header=True matches
# the original load of this file; without it the header line would be read as
# a data row (adding one extra "unique" row) and the columns would be renamed
# _c0, _c1, ... instead of using the file's column names.
departures_df = spark.read.csv(
    '../data/AA_DFW_2014_Departures_Short.csv.gz', header=True
).distinct()

# Print the number of partitions for each instance.
# NOTE(review): the recorded output shows a small count (3) rather than 500 —
# presumably adaptive query execution (spark.sql.adaptive.*) is coalescing the
# shuffle partitions; confirm against the Spark version/config in use.
print("Partition count before change: %d" % before)
print("Partition count after change: %d" % departures_df.rdd.getNumPartitions())

Partition count before change: 2
Partition count after change: 3


In [None]:
# # Join the flights_df and airports_df DataFrames
# normal_df = flights_df.join(airports_df, \
#     flights_df["Destination Airport"] == airports_df["IATA"] )

# # Show the query plan
# normal_df.explain()

In [None]:
# start_time = time.time()
# # Count the number of rows in the normal DataFrame
# normal_count = normal_df.count()
# normal_duration = time.time() - start_time

# start_time = time.time()
# # Count the number of rows in the broadcast DataFrame
# broadcast_count = broadcast_df.count()
# broadcast_duration = time.time() - start_time

# # Print the counts and the duration of the tests
# print("Normal count:\t\t%d\tduration: %f" % (normal_count, normal_duration))
# print("Broadcast count:\t%d\tduration: %f" % (broadcast_count, broadcast_duration))

# Normal count:		119910	duration: 0.626342
# Broadcast count:	119910	duration: 0.309472