In [1]:
# Spark entry point used by the cells below.
from pyspark.sql import SparkSession
# NOTE(review): SQLContext is not referenced in any cell visible here — confirm it is still needed.
from pyspark.sql import SQLContext

# Quick confirmation that the pyspark imports resolved.
print('Import successfull')

Import successfull


<u><h1>First Part</h1></u>
----------------------
A very simple script that initializes a Spark session and imports a flat file.

In [3]:
# Start the Spark session.
# SparkSession is the application's entry point to the Spark APIs.
# getOrCreate() returns the session that is already running when there is one,
# and only builds a new one otherwise.
scSpark = (
    SparkSession.builder
    .appName("reading csv")
    .getOrCreate()
)

# Load the flat file: the first line is the header and fields are comma-separated.
# The frame is cached because it is both counted and displayed below.
file_data = './Data/data.csv'
sdfData = scSpark.read.csv(file_data, header=True, sep=",").cache()

record_count = sdfData.count()
print(f'Total Records = {record_count}')
sdfData.show()

Total Records = 4
+------+---+--------+
|  name|age| country|
+------+---+--------+
| adnan| 40|Pakistan|
|  maaz|  9|Pakistan|
| musab|  4|Pakistan|
|ayesha| 32|Pakistan|
+------+---+--------+



<u><h1>Second Part</h1></u>
We will explore using SQL queries in Spark. We will be using real data, which can be found on <a href="https://www.kaggle.com/aungpyaeap/supermarket-sales" target="_blank">Kaggle</a>.
<br>
<h2>Extract</h2>


In [19]:
# Extract the data
# No schema is supplied and inferSchema is not enabled, so every column is
# loaded as a string (see the printed schema).
data_file = './Data/supermarket_sales - Sheet1.csv'
data = scSpark.read.csv(data_file, header=True, sep=",").cache()

# Explore the dataset
# Number of records
print(f'Total Records = {data.count()}')

# Schema
# printSchema() and show() write to stdout themselves and return None, so they
# are called directly; wrapping them in print() only adds a stray "None" line.
print('\nSchema')
print('--------------------------------------------------')
data.printSchema()
print('--------------------------------------------------\n')

# Summary statistics
print('Summary statistics')
print('--------------------------------------------------\n')
data.summary().show()

Total Records = 1000

Schema
--------------------------------------------------
root
 |-- Invoice ID: string (nullable = true)
 |-- Branch: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Customer type: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Product line: string (nullable = true)
 |-- Unit price: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- Tax 5%: string (nullable = true)
 |-- Total: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- Time: string (nullable = true)
 |-- Payment: string (nullable = true)
 |-- cogs: string (nullable = true)
 |-- gross margin percentage: string (nullable = true)
 |-- gross income: string (nullable = true)
 |-- Rating: string (nullable = true)

None
--------------------------------------------------

Summary statistics
--------------------------------------------------

+-------+-----------+------+--------+-------------+------+--------------------+------------------+----

<h2>Transform</h2>

In [20]:
# Group the sales by gender and count the rows in each group

gender = data.groupBy('Gender').count()
# show() prints the table itself and returns None, so it is not wrapped in print().
gender.show()

+------+-----+
|Gender|count|
+------+-----+
|Female|  501|
|  Male|  499|
+------+-----+

None


In [22]:
# Register the DataFrame as a temporary view named 'sales' so it can be
# queried with SQL, then select all fields from it.
# createOrReplaceTempView is the replacement for the deprecated
# registerTempTable; the view name and the query are unchanged.

data.createOrReplaceTempView('sales')
output = scSpark.sql('SELECT * FROM sales')
output.show()

+-----------+------+---------+-------------+------+--------------------+----------+--------+-------+--------+---------+-----+-----------+------+-----------------------+------------+------+
| Invoice ID|Branch|     City|Customer type|Gender|        Product line|Unit price|Quantity| Tax 5%|   Total|     Date| Time|    Payment|  cogs|gross margin percentage|gross income|Rating|
+-----------+------+---------+-------------+------+--------------------+----------+--------+-------+--------+---------+-----+-----------+------+-----------------------+------------+------+
|750-67-8428|     A|   Yangon|       Member|Female|   Health and beauty|     74.69|       7|26.1415|548.9715| 1/5/2019|13:08|    Ewallet|522.83|            4.761904762|     26.1415|   9.1|
|226-31-3081|     C|Naypyitaw|       Normal|Female|Electronic access...|     15.28|       5|   3.82|   80.22| 3/8/2019|10:29|       Cash|  76.4|            4.761904762|        3.82|   9.6|
|631-41-3108|     A|   Yangon|       Normal|  Male|  Ho

In [23]:
# Modify the previous query to add a WHERE clause.
# - The column is spelled `Unit price`, exactly as in the schema, so the query
#   does not depend on case-insensitive column resolution.
# - Both columns were loaded as strings, so they are cast explicitly before the
#   numeric comparison instead of relying on implicit coercion.
# - A triple-quoted string replaces the backslash continuations that sat inside
#   the string literal.

output2 = scSpark.sql("""
    SELECT *
    FROM sales
    WHERE CAST(`Unit price` AS DOUBLE) < 15
      AND CAST(Quantity AS INT) < 10
""")
output2.show()

+-----------+------+---------+-------------+------+--------------------+----------+--------+------+--------+---------+-----+-----------+------+-----------------------+------------+------+
| Invoice ID|Branch|     City|Customer type|Gender|        Product line|Unit price|Quantity|Tax 5%|   Total|     Date| Time|    Payment|  cogs|gross margin percentage|gross income|Rating|
+-----------+------+---------+-------------+------+--------------------+----------+--------+------+--------+---------+-----+-----------+------+-----------------------+------------+------+
|351-62-0822|     B| Mandalay|       Member|Female| Fashion accessories|     14.48|       4| 2.896|  60.816| 2/6/2019|18:07|    Ewallet| 57.92|            4.761904762|       2.896|   4.5|
|871-39-9221|     C|Naypyitaw|       Normal|Female|Electronic access...|     12.45|       6| 3.735|  78.435| 2/9/2019|13:11|       Cash|  74.7|            4.761904762|       3.735|   4.1|
|586-25-0848|     A|   Yangon|       Normal|Female|   Sports

In [27]:
# Aggregate: number of sales rows per city

city_totals_sql = 'SELECT COUNT(*) as total, City FROM sales GROUP BY City'
output3 = scSpark.sql(city_totals_sql)
output3.show()

+-----+---------+
|total|     City|
+-----+---------+
|  328|Naypyitaw|
|  332| Mandalay|
|  340|   Yangon|
+-----+---------+



<h2>Load</h2>

In [28]:
# Multiple files will be created under the 'filtered.json' output path.
# mode('overwrite') lets the cell be re-run: with the default save mode the
# write raises an error once the output path already exists.
output3.write.format('json').mode('overwrite').save('filtered.json')

In [30]:
# To save as only one file, coalesce the DataFrame to a single partition first.
# mode('overwrite') lets the cell be re-run: with the default save mode the
# write raises an error once the output path already exists.
output3.coalesce(1).write.format('json').mode('overwrite').save('filtered_onefile.json')

## We can use PySpark to extract data from a database as well

We will access MySQL with PySpark.
The dataset can be found <a href='https://archive.ics.uci.edu/ml/datasets/Wine+Quality' target=_blank>here</a>.

- First we will use pandas to import the data and label red and white wines
  - no cleaning, EDA, or other processing is done here
  - this step is just to create a MySQL database
- We will then write the results to a MySQL database
- Then access the database using PySpark


In [11]:
import pandas as pd
import mysql.connector

import os
from getpass import getpass

# Load data (the source files are semicolon-separated)
red_wines = pd.read_csv('./winequality/winequality-red.csv', sep=";")
white_wines = pd.read_csv('./winequality/winequality-white.csv', sep=";")

# Add a column labeling the type of wine: 1 for red, 0 for white
red_wines['is_red'] = 1
white_wines['is_red'] = 0

# Concatenate the red and white wine dataframes into one
wines = pd.concat([red_wines, white_wines])

# Connect to MySQL and add a cursor to execute queries.
# Credentials are read from the environment (MYSQL_USER / MYSQL_PASSWORD) rather
# than being hardcoded in the notebook; when no password is set it is prompted for.
db_user = os.environ.get('MYSQL_USER', 'root')
db_password = os.environ.get('MYSQL_PASSWORD') or getpass('MySQL password: ')
db_con = mysql.connector.connect(user=db_user, password=db_password)
db_cursor = db_con.cursor()

# Create the database and the table
db_cursor.execute('CREATE DATABASE IF NOT EXISTS TestDB;')
db_cursor.execute('USE TestDB')

# Column names in the table mirror the ones from the csv file (shortened, no spaces)
db_cursor.execute("CREATE TABLE IF NOT EXISTS Wines(fixed_acidity FLOAT, volatile_acidity FLOAT, \
                   citric_acid FLOAT, residual_sugar FLOAT, chlorides FLOAT, \
                   free_so2 FLOAT, total_so2 FLOAT, density FLOAT, pH FLOAT, \
                   sulphates FLOAT, alcohol FLOAT, quality INT, is_red INT);")

# One tuple per row, in the same column order as the INSERT statement below
wine_tuples = list(wines.itertuples(index=False, name=None))

# Parameterized insert: the connector binds and escapes the values instead of
# the statement being assembled by string concatenation.
insert_sql = (
    "INSERT INTO Wines(fixed_acidity, volatile_acidity, citric_acid, "
    "residual_sugar, chlorides, free_so2, total_so2, density, pH, "
    "sulphates, alcohol, quality, is_red) "
    "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
)
db_cursor.executemany(insert_sql, wine_tuples)

# Commit so the inserted rows are persisted and visible to other connections.
# NOTE: the table is created with IF NOT EXISTS, so re-running this cell
# inserts the same rows again.
db_con.commit()



In [11]:
# Sanity check: fetch two rows back from the Wines table and print them next to
# the first two rows of the pandas dataframe.
db_cursor.execute('SELECT * FROM Wines LIMIT 2;')
result1 = db_cursor.fetchall()
for row in result1:
    print(row)

print(wines.head(2))

# The pandas output is a little messy because the long column names make it wrap
# over several blocks, but comparing it with the MySQL rows shows that the data matches.

(7.4, 0.7, 0.0, 1.9, 0.076, 11.0, 34.0, 0.9978, 3.51, 0.56, 9.4, 5, 1)
(7.8, 0.88, 0.0, 2.6, 0.098, 25.0, 67.0, 0.9968, 3.2, 0.68, 9.8, 5, 1)
   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \
0            7.4              0.70          0.0             1.9      0.076   
1            7.8              0.88          0.0             2.6      0.098   

   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \
0                 11.0                  34.0   0.9978  3.51       0.56   
1                 25.0                  67.0   0.9968  3.20       0.68   

   alcohol  quality  is_red  
0      9.4        5       1  
1      9.8        5       1  


None


## Access MySQL with Pyspark

Doing this can very quickly lead to massive headaches...<br>
We are establishing a JDBC connection.<br>
Make sure you have downloaded the correct MySQL connector (specifically `mysql-connector-java:<version>`) and that the jar file is in the right location.


In [2]:
# Build (or reuse) a local-mode Spark session configured with the MySQL
# connector jar in spark.jars.
mysql_connector_jar = "/usr/share/java/mysql-connector-java-8.0.22.jar"
spark = (
    SparkSession.builder
    .config("spark.jars", mysql_connector_jar)
    .master("local")
    .appName("PySpark_MySQL_test")
    .getOrCreate()
)

In [4]:
import findspark

# Ask findspark to add the MySQL connector Maven package for PySpark.
# NOTE(review): the version requested here (8.0.11) differs from the 8.0.22 jar
# passed via spark.jars when the session is built — confirm which one is intended.
# NOTE(review): this cell comes after the cell that creates the session; presumably
# the package list has to be in place before Spark starts — verify the cell order.
findspark.add_packages('mysql:mysql-connector-java:8.0.11')

In [7]:
import os
from getpass import getpass

# Read the Wines table from the local TestDB MySQL database over JDBC.
# Credentials are read from the environment (MYSQL_USER / MYSQL_PASSWORD) rather
# than being hardcoded in the notebook; when no password is set it is prompted for.
jdbc_user = os.environ.get('MYSQL_USER', 'root')
jdbc_password = os.environ.get('MYSQL_PASSWORD') or getpass('MySQL password: ')

# com.mysql.cj.jdbc.Driver is the driver class of Connector/J 8.0;
# the older com.mysql.jdbc.Driver name is deprecated there.
wine_df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:mysql://localhost:3306/TestDB")
    .option("driver", "com.mysql.cj.jdbc.Driver")
    .option("dbtable", "Wines")
    .option("user", jdbc_user)
    .option("password", jdbc_password)
    .load()
)

In [43]:
# Explore a bit :)
# show() and printSchema() print their output themselves and return None, so
# they are called directly; wrapping them in print() only adds "None" lines.

# Show the first 5 rows of the dataframe
wine_df.show(n=5)

# Count the number of rows
print(f'\nThe dataframe has {wine_df.count()} rows')

# Show the number of partitions in the underlying RDD
print(f'Number of partitions = {wine_df.rdd.getNumPartitions()}')

# Get the datatype of each column
wine_df.printSchema()

# Get basic statistics of the dataframe
wine_df.describe().show()

# Calling describe() on the whole dataframe gives a messy output,
# so we can call it on a single column
wine_df.describe('volatile_acidity').show()

# We can also print the summary statistics of more than one column at a time
wine_df.describe(['volatile_acidity', 'citric_acid', 'alcohol', 'quality']).show()

# If we absolutely need the basic summary statistics of the whole dataframe
# formatted to look nice, we can iterate over each column, compute the
# statistics of that column, round the numbers to the desired format and
# concatenate the results into a new dataframe.
# Kind of a pain...


+-------------+----------------+-----------+--------------+---------+--------+---------+-------+----+---------+-------+-------+------+
|fixed_acidity|volatile_acidity|citric_acid|residual_sugar|chlorides|free_so2|total_so2|density|  pH|sulphates|alcohol|quality|is_red|
+-------------+----------------+-----------+--------------+---------+--------+---------+-------+----+---------+-------+-------+------+
|          7.4|             0.7|        0.0|           1.9|    0.076|    11.0|     34.0| 0.9978|3.51|     0.56|    9.4|      5|     1|
|          7.8|            0.88|        0.0|           2.6|    0.098|    25.0|     67.0| 0.9968| 3.2|     0.68|    9.8|      5|     1|
|          7.8|            0.76|       0.04|           2.3|    0.092|    15.0|     54.0|  0.997|3.26|     0.65|    9.8|      5|     1|
|         11.2|            0.28|       0.56|           1.9|    0.075|    17.0|     60.0|  0.998|3.16|     0.58|    9.8|      6|     1|
|          7.4|             0.7|        0.0|           