# Data Wrangling with Spark

This is the code used in the previous screencast. Run each code cell to understand what the code does and how it works.

These first three cells import libraries, instantiate a SparkSession, and then read in the data set

In [49]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, to_timestamp, to_date, isnan, when, count, col
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import desc
from pyspark.sql.functions import asc
from pyspark.sql.functions import sum as Fsum

import datetime

import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
spark = SparkSession \
    .builder \
    .appName("Wrangling Data") \
    .getOrCreate()

In [3]:
path = "titanic.json"
df = spark.read.json(path)

# Data Exploration 

The next cells explore the data set.

In [4]:
## Get the first 5 rows
df.head(5)

[Row(Age=22.0, Cabin=None, Embarked='S', Fare=7.25, Name='Braund, Mr. Owen Harris', Parch=0, PassengerId=1, Pclass=3, Sex='male', SibSp=1, Survived=0, Ticket='A/5 21171'),
 Row(Age=38.0, Cabin='C85', Embarked='C', Fare=71.2833, Name='Cumings, Mrs. John Bradley (Florence Briggs Thayer)', Parch=0, PassengerId=2, Pclass=1, Sex='female', SibSp=1, Survived=1, Ticket='PC 17599'),
 Row(Age=26.0, Cabin=None, Embarked='S', Fare=7.925, Name='Heikkinen, Miss. Laina', Parch=0, PassengerId=3, Pclass=3, Sex='female', SibSp=0, Survived=1, Ticket='STON/O2. 3101282'),
 Row(Age=35.0, Cabin='C123', Embarked='S', Fare=53.1, Name='Futrelle, Mrs. Jacques Heath (Lily May Peel)', Parch=0, PassengerId=4, Pclass=1, Sex='female', SibSp=1, Survived=1, Ticket='113803'),
 Row(Age=35.0, Cabin=None, Embarked='S', Fare=8.05, Name='Allen, Mr. William Henry', Parch=0, PassengerId=5, Pclass=3, Sex='male', SibSp=0, Survived=0, Ticket='373450')]

In [5]:
## check the dataframe attributes
df.printSchema()

root
 |-- Age: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Name: string (nullable = true)
 |-- Parch: long (nullable = true)
 |-- PassengerId: long (nullable = true)
 |-- Pclass: long (nullable = true)
 |-- Sex: string (nullable = true)
 |-- SibSp: long (nullable = true)
 |-- Survived: long (nullable = true)
 |-- Ticket: string (nullable = true)



In [6]:
## Get dataframe columns statistics
df.describe().show()

+-------+------------------+-----+--------+-----------------+--------------------+-------------------+-----------------+------------------+------+------------------+-------------------+------------------+
|summary|               Age|Cabin|Embarked|             Fare|                Name|              Parch|      PassengerId|            Pclass|   Sex|             SibSp|           Survived|            Ticket|
+-------+------------------+-----+--------+-----------------+--------------------+-------------------+-----------------+------------------+------+------------------+-------------------+------------------+
|  count|               714|  204|     889|              891|                 891|                891|              891|               891|   891|               891|                891|               891|
|   mean| 29.69911764705882| null|    null| 32.2042079685746|                null|0.38159371492704824|            446.0| 2.308641975308642|  null|0.5230078563411896| 0.383838383838

In [7]:
## check the Age statistics
df.describe("Age").show()

+-------+------------------+
|summary|               Age|
+-------+------------------+
|  count|               714|
|   mean| 29.69911764705882|
| stddev|14.526497332334035|
|    min|              0.42|
|    max|              80.0|
+-------+------------------+



In [8]:
## Get the number of rows
df.count()

891

In [9]:
## get the number of columns
len(df.columns)

12

In [10]:
## Get unique Parch values using dropDuplicates()
df.select("Parch").dropDuplicates().sort("Parch").show()

+-----+
|Parch|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
+-----+



In [11]:
df.select('Parch').distinct().sort('Parch').show()

+-----+
|Parch|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
+-----+



In [12]:
## get unique Parch values using distinct()
df.select('Parch').distinct().sort('Parch').show()

+-----+
|Parch|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
+-----+



In [13]:
## select details based on some condition using where
df.select(["Name", "Age", "Pclass", "Sex", "Survived"]).where(df.PassengerId == "1").collect()

[Row(Name='Braund, Mr. Owen Harris', Age=22.0, Pclass=3, Sex='male', Survived=0)]

In [14]:
## select details based on some condition using filter
df.select(["Name", "Age", "Pclass", "Sex", "Survived"]).filter(df.PassengerId == "1").collect()

[Row(Name='Braund, Mr. Owen Harris', Age=22.0, Pclass=3, Sex='male', Survived=0)]

## Use of functions as udf

In [16]:
get_prefix = udf(lambda x: x.split(',')[1].strip(" ").split(' ')[0])

In [17]:
## Get the name prefix of each passenger and store the prefix values in NamePrefix column
df = df.withColumn("NamePrefix", get_prefix(df.Name))

In [20]:
## As you can see Name Prefix column is added
df.head(5)

[Row(Age=22.0, Cabin=None, Embarked='S', Fare=7.25, Name='Braund, Mr. Owen Harris', Parch=0, PassengerId=1, Pclass=3, Sex='male', SibSp=1, Survived=0, Ticket='A/5 21171', NamePrefix='Mr.'),
 Row(Age=38.0, Cabin='C85', Embarked='C', Fare=71.2833, Name='Cumings, Mrs. John Bradley (Florence Briggs Thayer)', Parch=0, PassengerId=2, Pclass=1, Sex='female', SibSp=1, Survived=1, Ticket='PC 17599', NamePrefix='Mrs.'),
 Row(Age=26.0, Cabin=None, Embarked='S', Fare=7.925, Name='Heikkinen, Miss. Laina', Parch=0, PassengerId=3, Pclass=3, Sex='female', SibSp=0, Survived=1, Ticket='STON/O2. 3101282', NamePrefix='Miss.'),
 Row(Age=35.0, Cabin='C123', Embarked='S', Fare=53.1, Name='Futrelle, Mrs. Jacques Heath (Lily May Peel)', Parch=0, PassengerId=4, Pclass=1, Sex='female', SibSp=1, Survived=1, Ticket='113803', NamePrefix='Mrs.'),
 Row(Age=35.0, Cabin=None, Embarked='S', Fare=8.05, Name='Allen, Mr. William Henry', Parch=0, PassengerId=5, Pclass=3, Sex='male', SibSp=0, Survived=0, Ticket='373450', Nam

## Get some statistics using Groupby

In [22]:
## Check how many passengers survived based on their name prefix
name_prefix_survived = df.filter(df.Survived == "1").groupby(df.NamePrefix).count().orderBy(desc('count'))

In [23]:
name_prefix_survived.show()

+----------+-----+
|NamePrefix|count|
+----------+-----+
|     Miss.|  127|
|      Mrs.|   99|
|       Mr.|   81|
|   Master.|   23|
|       Dr.|    3|
|     Mlle.|    2|
|       Ms.|    1|
|    Major.|    1|
|      Sir.|    1|
|      Col.|    1|
|      Mme.|    1|
|       the|    1|
|     Lady.|    1|
+----------+-----+



# Drop Rows with Missing Values

As you'll see, it turns out there are some missing values.

In [54]:
## Check Null values in each column
df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]).show()

+---+-----+--------+----+----+-----+-----------+------+---+-----+--------+------+----------+
|Age|Cabin|Embarked|Fare|Name|Parch|PassengerId|Pclass|Sex|SibSp|Survived|Ticket|NamePrefix|
+---+-----+--------+----+----+-----+-----------+------+---+-----+--------+------+----------+
|177|  687|       2|   0|   0|    0|          0|     0|  0|    0|       0|     0|         0|
+---+-----+--------+----+----+-----+-----------+------+---+-----+--------+------+----------+



In [41]:
## Check Null values in Age column
len(df.where(col("Age").isNull()).collect())

177

In [25]:
df_valid = df.dropna(how = "any", subset = ["Age"])

In [26]:
## It dropped around 177 rows
df_valid.count()

714

# Use Window function

Use a window function and cumulative sum to distinguish each user's data.

In [56]:
from pyspark.sql import Window

In [57]:
## Created some dummy data to apply Window function
user_log = spark.read.csv("data_for_window.csv", header=True)

In [58]:
user_log.show()

+---+---------+----------+
| id|    start|some_value|
+---+---------+----------+
|  1| 1/1/2015|        20|
|  1| 1/6/2015|        10|
|  1| 1/7/2015|        25|
|  1|1/12/2015|        30|
|  2| 1/1/2015|         5|
|  2| 1/3/2015|        30|
|  2| 2/1/2015|        20|
+---+---------+----------+



In [59]:
user_log.printSchema()

root
 |-- id: string (nullable = true)
 |-- start: string (nullable = true)
 |-- some_value: string (nullable = true)



As you can see start column which contains dates is of type string so we need to first convert it to date type

In [60]:
user_log = user_log.withColumn("start_dates",to_date(user_log.start, 'mm/dd/yyyy'))
user_log.show()

+---+---------+----------+-----------+
| id|    start|some_value|start_dates|
+---+---------+----------+-----------+
|  1| 1/1/2015|        20| 2015-01-01|
|  1| 1/6/2015|        10| 2015-01-06|
|  1| 1/7/2015|        25| 2015-01-07|
|  1|1/12/2015|        30| 2015-01-12|
|  2| 1/1/2015|         5| 2015-01-01|
|  2| 1/3/2015|        30| 2015-01-03|
|  2| 2/1/2015|        20| 2015-01-01|
+---+---------+----------+-----------+



In [62]:
# as you can see start_dates is a column of type Date
user_log.printSchema()

root
 |-- id: string (nullable = true)
 |-- start: string (nullable = true)
 |-- some_value: string (nullable = true)
 |-- start_dates: date (nullable = true)



Create a window object for each id i.e. any function we apply over window such as mean or sum, that will apply on each id

In [68]:
windowval = Window.partitionBy("id").orderBy(desc("start_dates")).rangeBetween(Window.unboundedPreceding, 0)

In [69]:
user_log_valid = user_log.withColumn("range_sum", Fsum("some_value").over(windowval))

As you can see, range sum column is having the cumulative sum of some_value for each id

In [70]:
user_log_valid.show()

+---+---------+----------+-----------+---------+
| id|    start|some_value|start_dates|range_sum|
+---+---------+----------+-----------+---------+
|  1|1/12/2015|        30| 2015-01-12|     30.0|
|  1| 1/7/2015|        25| 2015-01-07|     55.0|
|  1| 1/6/2015|        10| 2015-01-06|     65.0|
|  1| 1/1/2015|        20| 2015-01-01|     85.0|
|  2| 1/3/2015|        30| 2015-01-03|     30.0|
|  2| 1/1/2015|         5| 2015-01-01|     55.0|
|  2| 2/1/2015|        20| 2015-01-01|     55.0|
+---+---------+----------+-----------+---------+

