In [1]:
# Checking if we are connected to the Spark cluster
sc

### `Importing useful libraries`

In [2]:
from pyspark.sql.functions import broadcast
from pyspark.sql import functions as F
from datetime import date
import pandas as pd
import numpy as np

import pyspark.sql.types as T
from pyspark.sql.window import Window
from pyspark.sql.types import DoubleType

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FloatType
from pyspark.sql.functions import *

#### `Logical Operations`

  * Filter products based on value
  
  
  * Create a flag based on values
  
  
  * Modifying/ Creating columns using a list

In [3]:
# Pulling datasets from GCP
customers = sqlContext.table("an_training_easl.training1_cards")
products = sqlContext.table("an_training_easl.training1_products")
transactions = sqlContext.table("an_training_easl.training1_transactions")

In [4]:
# Subsetting dataset based on values
prod_sub = products.where((F.col("category_id") == "C1") | (F.col("category_desc") == "yogurt"))
prod_sub.show()

+-------+-----------+----------+-------------+
|prod_id|category_id|department|category_desc|
+-------+-----------+----------+-------------+
|     P1|         C1|        D1|         Milk|
|     P3|         C3|        D1|       yogurt|
|     P8|         C1|        D1|         Milk|
|    P10|         C3|        D1|       yogurt|
|    P15|         C1|        D1|         Milk|
|    P17|         C3|        D1|       yogurt|
|    P22|         C1|        D1|         Milk|
|    P24|         C3|        D1|       yogurt|
|    P29|         C1|        D1|         Milk|
|    P31|         C3|        D1|       yogurt|
+-------+-----------+----------+-------------+



In [5]:
# Creating a column called Flag based on certain values/ like case when in SQL
prod_sub = products.withColumn("Flag", F.when((F.col("category_id") == "C1")|\
                                              (F.col("category_desc") == "yogurt"), F.lit(1)) \
                                       .otherwise(F.lit(0))
                              )
prod_sub.show(n=5)

+-------+-----------+----------+-------------+----+
|prod_id|category_id|department|category_desc|Flag|
+-------+-----------+----------+-------------+----+
|     P1|         C1|        D1|         Milk|   1|
|     P2|         C2|        D2|      Shampoo|   0|
|     P3|         C3|        D1|       yogurt|   1|
|     P4|         C4|        D2|         soap|   0|
|     P5|         C5|        D2|floor cleaner|   0|
+-------+-----------+----------+-------------+----+
only showing top 5 rows



In [6]:
# Can be skipped as this uses for loops and list
# Pre-appending 'col_' before each column
prod_sub.select([F.col(column).alias('col_'+column) for column in products.columns if column in prod_sub.columns]).show(n=5)

+-----------+---------------+--------------+-----------------+
|col_prod_id|col_category_id|col_department|col_category_desc|
+-----------+---------------+--------------+-----------------+
|         P1|             C1|            D1|             Milk|
|         P2|             C2|            D2|          Shampoo|
|         P3|             C3|            D1|           yogurt|
|         P4|             C4|            D2|             soap|
|         P5|             C5|            D2|    floor cleaner|
+-----------+---------------+--------------+-----------------+
only showing top 5 rows



#### `Group by and aggregations on Dataframes`

  * Find distribution
  
  
      * number of values in each column
      * maximum and minimun of each column
      * Average values
      
      
  * Rename columns using withColumnRenamed

In [7]:
# Summarizing dataframe at a shabit level and finding count, max and average values
shabits_dist1 = customers.groupBy('shabits') \
                         .agg({'card_id': 'count',
                               'income' : 'max',
                               'age'    : 'mean'})
shabits_dist1.show()

+-------+-----------+--------+--------------+
|shabits|max(income)|avg(age)|count(card_id)|
+-------+-----------+--------+--------------+
|     VL|   999432.0|    49.6|             5|
|     Po|    55674.0|    31.0|             1|
|     Vl|    50000.0|    25.0|             1|
|     PO|   558700.0|    38.0|             6|
|     LP|    20000.0|    24.0|             1|
|     PR|  1000000.0|    40.4|             5|
+-------+-----------+--------+--------------+



In [8]:
#In case of multiple aggregation on single column, dict is not useful
shabits_dist2 = customers.groupBy('shabits')\
                         .agg(F.count('card_id').alias('customers'), \
                              F.max('income').alias('max_income'), \
                              F.min('income').alias('min_income'), \
                              F.mean('age').alias('average_age'))
shabits_dist2.show()

+-------+---------+----------+----------+-----------+
|shabits|customers|max_income|min_income|average_age|
+-------+---------+----------+----------+-----------+
|     VL|        5|  999432.0|   45342.0|       49.6|
|     Po|        1|   55674.0|   55674.0|       31.0|
|     Vl|        1|   50000.0|   50000.0|       25.0|
|     PO|        6|  558700.0|   12000.0|       38.0|
|     LP|        1|   20000.0|   20000.0|       24.0|
|     PR|        5| 1000000.0|   21098.0|       40.4|
+-------+---------+----------+----------+-----------+



In [9]:
#Rename aggregated columns using withColumnRenamed or select with alias
shabits_dist_renamed = shabits_dist1.withColumnRenamed('max(income)', 'max_income') \
                                    .withColumnRenamed('avg(age)', 'avg_age') \
                                    .withColumnRenamed('count(card_id)', '# of Customers')
shabits_dist_renamed.show()

+-------+----------+-------+--------------+
|shabits|max_income|avg_age|# of Customers|
+-------+----------+-------+--------------+
|     VL|  999432.0|   49.6|             5|
|     Po|   55674.0|   31.0|             1|
|     Vl|   50000.0|   25.0|             1|
|     PO|  558700.0|   38.0|             6|
|     LP|   20000.0|   24.0|             1|
|     PR| 1000000.0|   40.4|             5|
+-------+----------+-------+--------------+



#### `Joins On DataFrames`
  
  * Join 2 different dataframes together
  
  
  * Find out summarized value on joined dataframe

In [10]:
# Printing schema of transaction dataframe
transactions.printSchema()

root
 |-- transaction_id: long (nullable = true)
 |-- date_id: string (nullable = true)
 |-- card_id: long (nullable = true)
 |-- prod_id: string (nullable = true)
 |-- qty: long (nullable = true)
 |-- amount: float (nullable = true)
 |-- week_id: integer (nullable = true)



In [11]:
# Printing schema of products dataframe
products.printSchema()

root
 |-- prod_id: string (nullable = true)
 |-- category_id: string (nullable = true)
 |-- department: string (nullable = true)
 |-- category_desc: string (nullable = true)



In [12]:
# Joining transactions and products dataframe together using prod_id as the join key
prod_trans = transactions.join(products, transactions.prod_id == products.prod_id, 'inner')
prod_trans.show()

#For 'left', 'right', 'outer' -  replace 'inner' with the needed join type

+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|prod_id|category_id|department|category_desc|
+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|      10134103|01-05-2017|    101|    P20|  1|  45.0| 201718|    P20|         C6|        D3|     biscuits|
|      10162849|01-05-2017|    103|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10187085|01-05-2017|    105|    P31|  1|  28.0| 201718|    P31|         C3|        D1|       yogurt|
|      10413355|02-05-2017|    107|    P13|  1|  70.0| 201718|    P13|         C6|        D3|     biscuits|
|      10543839|02-05-2017|    111|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10600256|02-05-2017|    113|    P25|  1|  40.0| 201718|    P25|         C4|        D2|         soap|
|      10664653|02-05-2017| 

In [13]:
# Joining datasets and summarizing/ aggregating values
prod_trans2 = transactions.join(products, transactions.prod_id == products.prod_id, 'inner') \
                          .groupBy(products.category_desc, products.category_id) \
                          .agg(F.sum('amount').alias("tot_amount"), F.sum('qty').alias("tot_qty"))
prod_trans2.show()

+-------------+-----------+----------+-------+
|category_desc|category_id|tot_amount|tot_qty|
+-------------+-----------+----------+-------+
|         soda|         C7|    2470.0|     31|
|         soap|         C4|     760.0|     22|
|       yogurt|         C3|     952.0|     25|
|      Shampoo|         C2|    6121.0|     33|
|floor cleaner|         C5|    3750.0|     29|
|     biscuits|         C6|    1700.0|     38|
|         Milk|         C1|     765.0|     38|
+-------------+-----------+----------+-------+



#### `Advanced Functionalities`
  
  * Summarizing values based on several filters and Creating pivot
  
  
  * Creating pivot on multiple indexes
  
  
  * Creating pivot with multiple aggregations
  
  
  * Pivot with filtered values

In [14]:
# Creating a transaction table with selected columns and data filtered on weeks and card_id
trans=sqlContext.table('an_training_easl.training1_transactions') \
                    .select('week_id','prod_id','card_id', 'amount', 'qty','transaction_id') \
                    .where((F.col('week_id').between(201718,201719)) & 
                           (F.col('card_id').isNotNull()))

# Picking up non-duplicate values from products table
prods=sqlContext.table('an_training_easl.training1_products') \
                    .select('prod_id','department','category_desc') \
                    .dropDuplicates()

# Joining the datasets together
df = trans.join(broadcast(prods),"prod_id","inner") 
df.show(10)
### We'll see what a broadcast join is later, for the time being consider it just as a join

+-------+-------+-------+------+---+--------------+----------+-------------+
|prod_id|week_id|card_id|amount|qty|transaction_id|department|category_desc|
+-------+-------+-------+------+---+--------------+----------+-------------+
|    P20| 201718|    101|  45.0|  1|      10134103|        D3|     biscuits|
|    P35| 201718|    103| 120.0|  1|      10162849|        D3|         soda|
|    P31| 201718|    105|  28.0|  1|      10187085|        D1|       yogurt|
|    P13| 201718|    107|  70.0|  1|      10413355|        D3|     biscuits|
|    P35| 201718|    111| 120.0|  1|      10543839|        D3|         soda|
|    P25| 201718|    113|  40.0|  1|      10600256|        D2|         soap|
|    P33| 201718|    115| 110.0|  1|      10664653|        D2|floor cleaner|
|    P23| 201718|    117| 250.0|  1|      11026889|        D2|      Shampoo|
|    P13| 201718|    120|  70.0|  1|      11026939|        D3|     biscuits|
|     P2| 201718|    122| 250.0|  1|      11221390|        D2|      Shampoo|

In [15]:
# Pivoting the above dataset on category description and adding up the amounts
pivot_df = df.groupby('category_desc').pivot('week_id').sum('amount')
pivot_df.show(10)

+-------------+------+------+
|category_desc|201718|201719|
+-------------+------+------+
|         Milk| 300.0| 465.0|
|       yogurt| 417.0| 535.0|
|         soap| 280.0| 480.0|
|floor cleaner|2710.0|1040.0|
|         soda|1630.0| 840.0|
|      Shampoo|3148.0|2973.0|
|     biscuits| 974.0| 726.0|
+-------------+------+------+



In [16]:
# Pivoting with multiple indexes
pivot_df = df.groupby('category_desc','department').pivot('week_id').sum('amount').sort('category_desc')
pivot_df.show(10)

+-------------+----------+------+------+
|category_desc|department|201718|201719|
+-------------+----------+------+------+
|         Milk|        D1| 300.0| 465.0|
|      Shampoo|        D2|3148.0|2973.0|
|     biscuits|        D3| 974.0| 726.0|
|floor cleaner|        D2|2710.0|1040.0|
|         soap|        D2| 280.0| 480.0|
|         soda|        D3|1630.0| 840.0|
|       yogurt|        D1| 417.0| 535.0|
+-------------+----------+------+------+



In [17]:
# multiple aggregations using pivot on the same dataset
pivot_df = df.groupby('category_desc') \
             .pivot('week_id') \
             .agg(F.sum('amount').alias('sum'), \
                  F.countDistinct('prod_id').alias('count'))

pivot_df.show(10)

+-------------+----------+------------+----------+------------+
|category_desc|201718_sum|201718_count|201719_sum|201719_count|
+-------------+----------+------------+----------+------------+
|         Milk|     300.0|           5|     465.0|           5|
|       yogurt|     417.0|           5|     535.0|           5|
|         soap|     280.0|           4|     480.0|           4|
|floor cleaner|    2710.0|           5|    1040.0|           5|
|         soda|    1630.0|           5|     840.0|           5|
|      Shampoo|    3148.0|           5|    2973.0|           5|
|     biscuits|     974.0|           5|     726.0|           5|
+-------------+----------+------------+----------+------------+



In [18]:
pivot_df = df.groupby('category_desc') \
             .pivot('week_id',[201718]) \
             .agg(F.sum('amount').alias('sum'), \
                  F.countDistinct('prod_id').alias('count'))
        
pivot_df.show(10)

+-------------+----------+------------+
|category_desc|201718_sum|201718_count|
+-------------+----------+------------+
|         Milk|     300.0|           5|
|       yogurt|     417.0|           5|
|         soap|     280.0|           4|
|floor cleaner|    2710.0|           5|
|         soda|    1630.0|           5|
|      Shampoo|    3148.0|           5|
|     biscuits|     974.0|           5|
+-------------+----------+------------+



#### `Advanced Functionalities`
  
  * Sampling using limit
  
  
  * Creating a random sample
  
  
  * Creating a stratified sample

In [19]:
# Pulling out products data
prods = sqlContext.table("an_training_easl.training1_products")

# Limiting values to 10 rows
prods.limit(10).show()

# Checking the count of rows
print("The count of rows is ->")
print(prods.count())
print(prods.limit(10).count())

+-------+-----------+----------+-------------+
|prod_id|category_id|department|category_desc|
+-------+-----------+----------+-------------+
|     P1|         C1|        D1|         Milk|
|     P2|         C2|        D2|      Shampoo|
|     P3|         C3|        D1|       yogurt|
|     P4|         C4|        D2|         soap|
|     P5|         C5|        D2|floor cleaner|
|     P6|         C6|        D3|     biscuits|
|     P7|         C7|        D3|         soda|
|     P8|         C1|        D1|         Milk|
|     P9|         C2|        D2|      Shampoo|
|    P10|         C3|        D1|       yogurt|
+-------+-----------+----------+-------------+

The count of rows is ->
35
10


In [20]:
# Pulling a random sample

# Sample function arguments are: -
# sample(withReplacement, fraction, seed=None)

sample = prods.sample(False,0.5,222)

print("The count of rows is ->")
sample.count()

The count of rows is ->


16

In [21]:
#Creating a stratified sample

print "Actual data distribution by dept\n"
prods.groupBy("department").count().orderBy("department").show()

print "sampled data distribution by dept\n"

# Startified sampling
sampled = prods.sampleBy("department", fractions={'D1': 0.5, 'D2': 0.2, 'D3': 0.1}, seed=0)

# Aggregating counts
sampled.groupBy("department").count().orderBy("department").show()

Actual data distribution by dept

+----------+-----+
|department|count|
+----------+-----+
|        D1|   10|
|        D2|   15|
|        D3|   10|
+----------+-----+

sampled data distribution by dept

+----------+-----+
|department|count|
+----------+-----+
|        D1|    5|
|        D3|    2|
+----------+-----+



#### For curious ones!

In [22]:
# Creating an id column
from pyspark.sql.functions import monotonically_increasing_id 
df_index = prods.select("*").withColumn("id", monotonically_increasing_id())
df_index.show()

+-------+-----------+----------+-------------+---+
|prod_id|category_id|department|category_desc| id|
+-------+-----------+----------+-------------+---+
|     P1|         C1|        D1|         Milk|  0|
|     P2|         C2|        D2|      Shampoo|  1|
|     P3|         C3|        D1|       yogurt|  2|
|     P4|         C4|        D2|         soap|  3|
|     P5|         C5|        D2|floor cleaner|  4|
|     P6|         C6|        D3|     biscuits|  5|
|     P7|         C7|        D3|         soda|  6|
|     P8|         C1|        D1|         Milk|  7|
|     P9|         C2|        D2|      Shampoo|  8|
|    P10|         C3|        D1|       yogurt|  9|
|    P11|         C4|        D2|         soap| 10|
|    P12|         C5|        D2|floor cleaner| 11|
|    P13|         C6|        D3|     biscuits| 12|
|    P14|         C7|        D3|         soda| 13|
|    P15|         C1|        D1|         Milk| 14|
|    P16|         C2|        D2|      Shampoo| 15|
|    P17|         C3|        D1

see http://spark.apache.org/docs/2.2.0/api/python/pyspark.sql.html#pyspark.sql.functions.monotonically_increasing_id

*A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. The current implementation puts the partition ID in the upper 31 bits, and the record number within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records.*

#### `Advanced Functionalities - Spark User Defined Functions - (UDF)`

  * Use a Python function in PySpark
  
<span style='color:Orange'>*UDFs might sound a good option. Python UDF are not. They should be avoided if possible, because they require data serialisation and deserialization from the JVM to the python layer and are always a performance bottleneck*</span>

In [23]:
# Creating a function to take square of a value
def square_number(s):
    return s * s

sqlContext.udf.register("squaredWithPython", square_number)

# Creating a lambda function
square_udf_int = F.udf(lambda z: square_number(z), DoubleType())

In [24]:
# Showing 2 rows of the transaction dataframe
transactions = sqlContext.table("an_training_easl.training1_transactions")
transactions.show(2)

+--------------+----------+-------+-------+---+------+-------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|
+--------------+----------+-------+-------+---+------+-------+
|      10134103|01-05-2017|    101|    P20|  1|  45.0| 201718|
|      10162849|01-05-2017|    103|    P35|  1| 120.0| 201718|
+--------------+----------+-------+-------+---+------+-------+
only showing top 2 rows



In [25]:
# Calling the udf on the transaction dataframe
transactions_withSquare=transactions.select('amount',square_udf_int('amount').alias('int_squared'))
transactions_withSquare.show(2)

+------+-----------+
|amount|int_squared|
+------+-----------+
|  45.0|     2025.0|
| 120.0|    14400.0|
+------+-----------+
only showing top 2 rows



#### `Advanced Functionalities - Some More String functions`

  * Padding on the right of a column
  
  
  * Translating values into other values
  
  
  * Formatting values through concatenation

In [26]:
# R-padding values in card_id column
transactions.select('card_id',rpad(transactions['card_id'],6,'0').alias('card_id')).show()

+-------+-------+
|card_id|card_id|
+-------+-------+
|    101| 101000|
|    103| 103000|
|    105| 105000|
|    107| 107000|
|    111| 111000|
|    113| 113000|
|    115| 115000|
|    117| 117000|
|    120| 120000|
|    122| 122000|
|    123| 123000|
|    127| 127000|
|    133| 133000|
|    137| 137000|
|    191| 191000|
|    101| 101000|
|    103| 103000|
|    105| 105000|
|    107| 107000|
|    111| 111000|
+-------+-------+
only showing top 20 rows



In [27]:
## Replacing each occurrence of a character with another character
## In the following example every 'P' is replaced by 'A', '3' by 'B', '1' by 'C', '2' by 'D'
transactions.select('prod_id',translate(transactions['prod_id'],'P312','ABCD').alias('prod_id')).show()

+-------+-------+
|prod_id|prod_id|
+-------+-------+
|    P20|    AD0|
|    P35|    AB5|
|    P31|    ABC|
|    P13|    ACB|
|    P35|    AB5|
|    P25|    AD5|
|    P33|    ABB|
|    P23|    ADB|
|    P13|    ACB|
|     P2|     AD|
|    P24|    AD4|
|    P12|    ACD|
|    P35|    AB5|
|     P9|     A9|
|    P34|    AB4|
|    P29|    AD9|
|    P15|    AC5|
|    P33|    ABB|
|     P7|     A7|
|    P16|    AC6|
+-------+-------+
only showing top 20 rows



In [28]:
# Formatting values in string
transactions.select('week_id',F.format_string('Week: %s', transactions['week_id']).alias('Week')).show()

+-------+------------+
|week_id|        Week|
+-------+------------+
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
| 201718|Week: 201718|
+-------+------------+
only showing top 20 rows



#### `Efficiency Tips`

  * Broadcast Joins
  
  
  * Persisting tables
  
  
  * Repartitioning and Coalesce

#### `Broadcast Joins`

  * Broadcast joins are used when one of the bigger tables needs to be joined with smaller tables
  
  
  * Use broadcast join of transactions table and prod table where product table is broadcasted

In [29]:
# Broadcast Join using the broadcast on the smaller dataframe
transactions.join(F.broadcast(products), transactions.prod_id == products.prod_id, 'inner').show()

+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|prod_id|category_id|department|category_desc|
+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|      10134103|01-05-2017|    101|    P20|  1|  45.0| 201718|    P20|         C6|        D3|     biscuits|
|      10162849|01-05-2017|    103|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10187085|01-05-2017|    105|    P31|  1|  28.0| 201718|    P31|         C3|        D1|       yogurt|
|      10413355|02-05-2017|    107|    P13|  1|  70.0| 201718|    P13|         C6|        D3|     biscuits|
|      10543839|02-05-2017|    111|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10600256|02-05-2017|    113|    P25|  1|  40.0| 201718|    P25|         C4|        D2|         soap|
|      10664653|02-05-2017| 

#### `Persisting Dataframes`

  * Persist dataframes if they are used multiple times and the cost of creation of these dataframes is higher
  
  
  * Eg. Persist prod_trans table to avoid the table to be recreated when it's used after the first time

In [30]:
# Persisting dataframe and calling an action on it
prod_trans = prod_trans.persist()

# An action for the first time takes some time
prod_trans.show()
#unpersist it using DataFrame.unpersist()

+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|prod_id|category_id|department|category_desc|
+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|      10134103|01-05-2017|    101|    P20|  1|  45.0| 201718|    P20|         C6|        D3|     biscuits|
|      10162849|01-05-2017|    103|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10187085|01-05-2017|    105|    P31|  1|  28.0| 201718|    P31|         C3|        D1|       yogurt|
|      10413355|02-05-2017|    107|    P13|  1|  70.0| 201718|    P13|         C6|        D3|     biscuits|
|      10543839|02-05-2017|    111|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10600256|02-05-2017|    113|    P25|  1|  40.0| 201718|    P25|         C4|        D2|         soap|
|      10664653|02-05-2017| 

In [31]:
# Best Practice - Unpersist the dataset when it is no longer used (multiple times)
prod_trans.unpersist()

DataFrame[transaction_id: bigint, date_id: string, card_id: bigint, prod_id: string, qty: bigint, amount: float, week_id: int, prod_id: string, category_id: string, department: string, category_desc: string]

#### `Repartition and Coalesce`

  * Repartition the table using repartition(no_of_partitions) if you want equal distibutions across cluster
  
  
  * Coalesce the table if you dont want shuffle overhead unlike repartition

In [32]:
print(products.rdd.getNumPartitions())
print(products.count())

1
35


In [33]:
# Re-partition the dataframe into 2
products_rp = products.repartition(2)
products_rp.rdd.getNumPartitions()

2

In [34]:
# Coalesce the dataframe into 1
products_cl = products.coalesce(1)
products_cl.rdd.getNumPartitions()

1

### Spark APIs vs. SQL

#### Efficiency Tip - Spark APIs should be preferred - it gives you much more control on what happens




#### Important - Spark IS NOT a RDBMS, so avoid writing long SQL statements. Complicated SQL code is often a good candidate for performance issues

In [35]:
# Using SQL to create a dataframe
df = sqlContext.sql("""select a.week_id,a.prod_id,a.card_id,a.amount,qty,a.transaction_id, b.category_desc
                       from an_training_easl.training1_transactions a
                       inner join 
                       (select distinct prod_id, category_desc from an_training_easl.training1_products) b
                       on a.prod_id = b.prod_id
                       where week_id > 201718 and a.card_id is not null""")
df.show()

+-------+-------+-------+------+---+--------------+-------------+
|week_id|prod_id|card_id|amount|qty|transaction_id|category_desc|
+-------+-------+-------+------+---+--------------+-------------+
| 201719|     P9|    113| 190.0|  1|      28007758|      Shampoo|
| 201719|     P5|    115| 150.0|  1|      28014164|floor cleaner|
| 201719|    P18|    117|  25.0|  1|      28221466|         soap|
| 201719|    P28|    127|  35.0|  1|      26959555|         soda|
| 201719|     P1|    133|  21.0|  1|      25884734|         Milk|
| 201719|    P17|    133|  55.0|  1|      27127095|       yogurt|
| 201719|    P12|    137| 110.0|  1|      26031912|floor cleaner|
| 201719|     P6|    137|  27.0|  1|      27147180|     biscuits|
| 201719|     P1|    191|  21.0|  1|      26136122|         Milk|
| 201719|    P32|    191|  30.0|  1|      27160634|         soap|
| 201719|     P4|    203|  40.0|  1|      26275286|         soap|
| 201719|     P7|    203|  35.0|  1|      27360181|         soda|
| 201719| 

In [36]:
# Checking the type of object created
type(df)

pyspark.sql.dataframe.DataFrame

### Windowing functions

In [38]:
trnDF = transactions.where("date_id between '01-05-2017' and '03-05-2017'") \
                    .join(F.broadcast(products), transactions.prod_id == products.prod_id, 'inner') \
                    .persist()

In [39]:
trnDF.show()

+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|prod_id|category_id|department|category_desc|
+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+
|      10134103|01-05-2017|    101|    P20|  1|  45.0| 201718|    P20|         C6|        D3|     biscuits|
|      10162849|01-05-2017|    103|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10187085|01-05-2017|    105|    P31|  1|  28.0| 201718|    P31|         C3|        D1|       yogurt|
|      10413355|02-05-2017|    107|    P13|  1|  70.0| 201718|    P13|         C6|        D3|     biscuits|
|      10543839|02-05-2017|    111|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|
|      10600256|02-05-2017|    113|    P25|  1|  40.0| 201718|    P25|         C4|        D2|         soap|
|      10664653|02-05-2017| 

In [40]:
from pyspark.sql.window import Window

w = Window.partitionBy(trnDF['date_id']).orderBy(trnDF['amount'].desc())

In [41]:
trn2 = trnDF.select('*', F.rank().over(w).alias("day_rank"))    

In [42]:
trn2.show()

+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+--------+
|transaction_id|   date_id|card_id|prod_id|qty|amount|week_id|prod_id|category_id|department|category_desc|day_rank|
+--------------+----------+-------+-------+---+------+-------+-------+-----------+----------+-------------+--------+
|      11026889|02-05-2017|    117|    P23|  1| 250.0| 201718|    P23|         C2|        D2|      Shampoo|       1|
|      11221390|02-05-2017|    122|     P2|  1| 250.0| 201718|     P2|         C2|        D2|      Shampoo|       1|
|      10543839|02-05-2017|    111|    P35|  1| 120.0| 201718|    P35|         C7|        D3|         soda|       3|
|      10664653|02-05-2017|    115|    P33|  1| 110.0| 201718|    P33|         C5|        D2|floor cleaner|       4|
|      10413355|02-05-2017|    107|    P13|  1|  70.0| 201718|    P13|         C6|        D3|     biscuits|       5|
|      11026939|02-05-2017|    120|    P13|  1|  70.0| 201718|  