# Lab 5.1: Data Persistence with Spark

Let's explore data persistence and caching, and their importance in Spark performance tuning.

In [1]:
# Set up Spark Context
from pyspark import SparkContext, SparkConf

SparkContext.setSystemProperty('spark.executor.memory', '2g')
conf = SparkConf()
conf.set('spark.executor.instances', 15)
sc = SparkContext('yarn-client', 'Spark-lab5.1', conf=conf)

Create an RDD of weather data with the following four fields: station, date, metric, and value. Load only the year 2013 from this dataset; the file for 2013 resides at /user/jupyter/weather/2013.csv.

In [2]:
import pandas as pd

lines = sc.textFile("weather/2013.csv")
weather = lines.map(lambda line: line.split(',')) \
               .map(lambda row: [row[0], row[1], row[2], row[3]])  # schema: station, date, metric type, value
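
The two chained `map` calls can also be written as one named function, which is easier to check on a single line before running it on the cluster. This is only a sketch: it assumes the file is plain comma-separated text with no quoted commas, and both the function name `parse_weather_line` and the sample line are hypothetical, not taken from the dataset.

```python
def parse_weather_line(line):
    """Split one CSV line and keep the first four fields.

    Returns [station, date, metric, value]; every field stays a string.
    """
    fields = line.split(',')
    return [fields[0], fields[1], fields[2], fields[3]]


# Hypothetical sample line, made up for illustration only.
sample = "STATION0001,20130101,TMIN,-33,,,E,"
print(parse_weather_line(sample))
```

With this helper, the RDD could be built as `weather = lines.map(parse_weather_line)`.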

Without caching any data, use Spark to compute the maximum and minimum values of TMIN during 2013. 

Use IPython's "%%timeit" to measure the time it takes Spark to execute both queries.
Note: "%%timeit" runs the cell in a loop, so variables assigned inside the cell are NOT available in later cells.

More details on %%timeit are here: https://ipython.org/ipython-doc/dev/interactive/magics.html#magic-timeit
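
When a result has to survive beyond the cell, one alternative is to time the call by hand. The sketch below uses only the standard library; the helper name `timed` is hypothetical, and unlike "%%timeit" it runs the function exactly once and does no averaging.

```python
import time


def timed(fn):
    """Run fn() once and return (result, elapsed_seconds)."""
    start = time.time()
    result = fn()
    elapsed = time.time() - start
    return result, elapsed


# Stand-in workload; in this lab it could be `lambda: wf.min()`.
result, seconds = timed(lambda: min([3, -1, 7]))
print(result)
```

Because `result` is an ordinary variable, it remains available in later cells.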

In [3]:
%%timeit -n1 -r1

wf = weather.filter(lambda row: row[2]=='TMIN').map(lambda row: int(row[3]))
print wf.min()
print wf.max()

-994
7111
1 loops, best of 1: 1min 8s per loop


Now run this again, this time using cache() to optimize performance, and measure again with "%%timeit".

In [4]:
%%timeit -n1 -r1

wf = weather.filter(lambda row: row[2]=='TMIN').map(lambda row: int(row[3])).cache()
print wf.min()
print wf.max()

-994
7111
1 loops, best of 1: 33.8 s per loop


Check the Spark UI for the jobs' execution times. See how caching improved the run time of the second operation.
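
Why only the second action gets faster can be illustrated without a cluster. The toy class below is plain Python, not Spark's API or implementation; `ToyRDD` and its `loads` counter are hypothetical names. It mimics the one behavior that matters here: an uncached dataset is recomputed from its source for every action, while a cached one is computed by the first action and reused afterwards.

```python
class ToyRDD(object):
    """Toy stand-in for a lazily evaluated dataset (not Spark code)."""

    def __init__(self, source):
        self.source = source      # function that produces the data
        self.loads = 0            # how many times the source was read
        self._use_cache = False
        self._cached = None

    def cache(self):
        # Only mark the dataset for caching; nothing is computed yet.
        self._use_cache = True
        return self

    def _compute(self):
        if self._use_cache and self._cached is not None:
            return self._cached
        self.loads += 1
        data = list(self.source())
        if self._use_cache:
            self._cached = data
        return data

    def min(self):
        return min(self._compute())

    def max(self):
        return max(self._compute())


values = [5, -994, 7111, 12]

uncached = ToyRDD(lambda: values)
uncached.min()
uncached.max()

cached = ToyRDD(lambda: values).cache()
cached.min()
cached.max()

print(uncached.loads)   # source is read once per action
print(cached.loads)     # source is read by the first action only
```

This is consistent with the timings above: with cache(), min() still pays the full cost of reading and parsing the file, and only max() benefits from the cached data.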

Next, determine how many partitions exist for the weather dataset.

In [5]:
weather.getNumPartitions()

8

Print out the number of items in each partition (use mapPartitionsWithIndex).

In [6]:
print weather.mapPartitionsWithIndex(lambda index,iterator: ((index,sum(1 for _ in iterator)),)).collect()

[(0, 3813172), (1, 3817869), (2, 3824474), (3, 3815669), (4, 3814805), (5, 3824935), (6, 3833317), (7, 3155909)]
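
The lambda above is dense, so here is the same logic as a named generator function, exercised on plain Python lists that stand in for partitions. `count_partition` and the toy data are hypothetical; on the cluster, Spark calls the function once per partition, passing the partition index and an iterator over that partition's records.

```python
def count_partition(index, iterator):
    """Yield a single (partition index, record count) pair."""
    yield (index, sum(1 for _ in iterator))


# Toy "partitions" for a local check; on the cluster Spark supplies these.
toy_partitions = [["a", "b", "c"], ["d"], []]
counts = []
for i, part in enumerate(toy_partitions):
    counts.extend(count_partition(i, iter(part)))
print(counts)
```

On the cluster, the equivalent call would be `weather.mapPartitionsWithIndex(count_partition).collect()`.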
