# Aggregations in Spark

In [1]:
import findspark
findspark.init()

import pyspark # only run after findspark.init()
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

In [3]:
df = spark.read.format("csv")\
    .option("header", "true")\
    .option("inferSchema", "true")\
    .load("../data/retail-data/all/*.csv")\
    .coalesce(5)

df.cache()
df.createOrReplaceTempView("dfTable")

In [4]:
df.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)



#### Count 

In [5]:
# df.count() is an action which returns the result immediately
df.count()

541909

In [8]:
# Below count function is transformation which is lazily evaluated
from pyspark.sql.functions import expr

df.select(expr("count(*)")).show()

+--------+
|count(1)|
+--------+
|  541909|
+--------+



In [9]:
## Count(*) includes null value count while count(column_name) doesn't.

In [15]:
from pyspark.sql.functions import count, countDistinct, approx_count_distinct

df.select(count("*")).show()

+--------+
|count(1)|
+--------+
|  541909|
+--------+



In [14]:
df.select(countDistinct("InvoiceNo")).show()

+-------------------------+
|count(DISTINCT InvoiceNo)|
+-------------------------+
|                    25900|
+-------------------------+



In [16]:
# Approx count distinct is much faster

df.select(approx_count_distinct("InvoiceNo", 0.1)).show() # 0.1 - maximum estimation error allowed

+--------------------------------+
|approx_count_distinct(InvoiceNo)|
+--------------------------------+
|                           25085|
+--------------------------------+



#### Min, Max

In [17]:
from pyspark.sql.functions import min, max

df.select(min("Quantity"), max("Quantity")).show()

+-------------+-------------+
|min(Quantity)|max(Quantity)|
+-------------+-------------+
|       -80995|        80995|
+-------------+-------------+



#### Sum

In [18]:
from pyspark.sql.functions import sum

df.select(sum("Quantity")).show() 

+-------------+
|sum(Quantity)|
+-------------+
|      5176450|
+-------------+



In [19]:
from pyspark.sql.functions import sumDistinct

df.select(sumDistinct("Quantity")).show()

+----------------------+
|sum(DISTINCT Quantity)|
+----------------------+
|                 29310|
+----------------------+



#### Average

In [20]:
from pyspark.sql.functions import sum, count, avg, expr


df.select(
    count("Quantity").alias("total_transactions"),
    sum("Quantity").alias("total_purchases"),
    avg("Quantity").alias("avg_purchases"),
    expr("mean(Quantity)").alias("mean_purchases"))\
    .selectExpr(
    "total_purchases/total_transactions",
    "avg_purchases",
    "mean_purchases").show()


+--------------------------------------+----------------+----------------+
|(total_purchases / total_transactions)|   avg_purchases|  mean_purchases|
+--------------------------------------+----------------+----------------+
|                      9.55224954743324|9.55224954743324|9.55224954743324|
+--------------------------------------+----------------+----------------+



#### Variance and Standard Deviation

In [21]:
# Population statistics
from pyspark.sql.functions import var_pop, stddev_pop

# Sample statistics
from pyspark.sql.functions import var_samp, stddev_samp

df.select(var_pop("Quantity"), var_samp("Quantity"),
        stddev_pop("Quantity"), stddev_samp("Quantity")).show()

+------------------+------------------+--------------------+---------------------+
| var_pop(Quantity)|var_samp(Quantity)|stddev_pop(Quantity)|stddev_samp(Quantity)|
+------------------+------------------+--------------------+---------------------+
|47559.303646609056|47559.391409298754|  218.08095663447796|   218.08115785023418|
+------------------+------------------+--------------------+---------------------+



#### Aggregating to Complex Types

In [24]:
from pyspark.sql.functions import collect_set, collect_list

df.agg(collect_set("Country"), collect_list("Country")).show()

+--------------------+---------------------+
|collect_set(Country)|collect_list(Country)|
+--------------------+---------------------+
|[Portugal, Italy,...| [United Kingdom, ...|
+--------------------+---------------------+



### Grouping

In [25]:
df.groupBy("InvoiceNo", "CustomerId").count().show()

+---------+----------+-----+
|InvoiceNo|CustomerId|count|
+---------+----------+-----+
|   536846|     14573|   76|
|   537026|     12395|   12|
|   537883|     14437|    5|
|   538068|     17978|   12|
|   538279|     14952|    7|
|   538800|     16458|   10|
|   538942|     17346|   12|
|  C539947|     13854|    1|
|   540096|     13253|   16|
|   540530|     14755|   27|
|   541225|     14099|   19|
|   541978|     13551|    4|
|   542093|     17677|   16|
|   536596|      null|    6|
|   537252|      null|    1|
|   538041|      null|    1|
|   543188|     12567|   63|
|   543590|     17377|   19|
|  C543757|     13115|    1|
|  C544318|     12989|    1|
+---------+----------+-----+
only showing top 20 rows



In [27]:
# Grouping with expression (expr)

df.groupBy("InvoiceNo").agg(count("Quantity").alias("quan"),expr("count(Quantity)")).show()

+---------+----+---------------+
|InvoiceNo|quan|count(Quantity)|
+---------+----+---------------+
|   536596|   6|              6|
|   536938|  14|             14|
|   537252|   1|              1|
|   537691|  20|             20|
|   538041|   1|              1|
|   538184|  26|             26|
|   538517|  53|             53|
|   538879|  19|             19|
|   539275|   6|              6|
|   539630|  12|             12|
|   540499|  24|             24|
|   540540|  22|             22|
|  C540850|   1|              1|
|   540976|  48|             48|
|   541432|   4|              4|
|   541518| 101|            101|
|   541783|  35|             35|
|   542026|   9|              9|
|   542375|   6|              6|
|  C542604|   8|              8|
+---------+----+---------------+
only showing top 20 rows



In [33]:
df.groupBy("Country").agg(countDistinct("InvoiceNo").alias("CountDistinct")).orderBy("CountDistinct", ascending=False).show()

+---------------+-------------+
|        Country|CountDistinct|
+---------------+-------------+
| United Kingdom|        23494|
|        Germany|          603|
|         France|          461|
|           EIRE|          360|
|        Belgium|          119|
|          Spain|          105|
|    Netherlands|          101|
|    Switzerland|           74|
|       Portugal|           71|
|      Australia|           69|
|          Italy|           55|
|        Finland|           48|
|         Sweden|           46|
|         Norway|           40|
|Channel Islands|           33|
|          Japan|           28|
|         Poland|           24|
|        Denmark|           21|
|         Cyprus|           20|
|        Austria|           19|
+---------------+-------------+
only showing top 20 rows



In [34]:
df.groupBy("InvoiceNo").agg(expr("avg(Quantity)"),expr("stddev_pop(Quantity)")).show()

+---------+------------------+--------------------+
|InvoiceNo|     avg(Quantity)|stddev_pop(Quantity)|
+---------+------------------+--------------------+
|   536596|               1.5|  1.1180339887498947|
|   536938|33.142857142857146|  20.698023172885524|
|   537252|              31.0|                 0.0|
|   537691|              8.15|   5.597097462078001|
|   538041|              30.0|                 0.0|
|   538184|12.076923076923077|   8.142590198943392|
|   538517|3.0377358490566038|  2.3946659604837897|
|   538879|21.157894736842106|  11.811070444356483|
|   539275|              26.0|  12.806248474865697|
|   539630|20.333333333333332|  10.225241100118645|
|   540499|              3.75|  2.6653642652865788|
|   540540|2.1363636363636362|  1.0572457590557278|
|  C540850|              -1.0|                 0.0|
|   540976|10.520833333333334|   6.496760677872902|
|   541432|             12.25|  10.825317547305483|
|   541518| 23.10891089108911|  20.550782784878713|
|   541783|1

### Window Functions

Spark supports three kinds of window functions: ranking functions, analytic functions,
and aggregate functions

In [35]:
from pyspark.sql.functions import col, to_date

dfWithDate = df.withColumn("date", to_date(col("InvoiceDate"), "MM/d/yyyy H:mm"))
dfWithDate.createOrReplaceTempView("dfWithDate")

In [36]:
# Import Window Function
from pyspark.sql.window import Window

from pyspark.sql.functions import desc

# Window.paritionBy.orderBy
## Frame specification (the rowsBetween statement) states which rows will be included in the frame 
## based on its reference to the currentb input row.

windowSpec = Window\
            .partitionBy("CustomerId", "date")\
            .orderBy(desc("Quantity"))\
            .rowsBetween(Window.unboundedPreceding, Window.currentRow)

In [39]:
# Running aggregation over windows

from pyspark.sql.functions import max
from pyspark.sql.functions import dense_rank, rank

maxPurchaseQuantity = max(col("Quantity")).over(windowSpec)
purchaseDenseRank = dense_rank().over(windowSpec)
purchaseRank = rank().over(windowSpec)

In [40]:
dfWithDate.where("CustomerId IS NOT NULL").orderBy("CustomerId")\
            .select(
                    col("CustomerId"),
                    col("date"),
                    col("Quantity"),
                    purchaseRank.alias("quantityRank"),
                    purchaseDenseRank.alias("quantityDenseRank"),
                    maxPurchaseQuantity.alias("maxPurchaseQuantity")).show()


+----------+----------+--------+------------+-----------------+-------------------+
|CustomerId|      date|Quantity|quantityRank|quantityDenseRank|maxPurchaseQuantity|
+----------+----------+--------+------------+-----------------+-------------------+
|     12346|2011-01-18|   74215|           1|                1|              74215|
|     12346|2011-01-18|  -74215|           2|                2|              74215|
|     12347|2010-12-07|      36|           1|                1|                 36|
|     12347|2010-12-07|      30|           2|                2|                 36|
|     12347|2010-12-07|      24|           3|                3|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|             

### Roll Ups

In [42]:
# Create DataFrame without null values
dfNoNull = dfWithDate.drop()

dfNoNull.createOrReplaceTempView("dfNoNull")

In [43]:
## DataFrame that includes:
## the grand total over all dates, 
## the grand total for each date in the DataFrame,
## and the subtotal for each country on each date in the DataFrame

rolledUpDF = dfNoNull.rollup("Date", "Country").agg(sum("Quantity"))\
    .selectExpr("Date", "Country", "`sum(Quantity)` as total_quantity")\
    .orderBy("Date")

rolledUpDF.show()

+----------+--------------+--------------+
|      Date|       Country|total_quantity|
+----------+--------------+--------------+
|      null|          null|       5176450|
|2010-12-01|United Kingdom|         23949|
|2010-12-01|          EIRE|           243|
|2010-12-01|       Germany|           117|
|2010-12-01|     Australia|           107|
|2010-12-01|   Netherlands|            97|
|2010-12-01|        France|           449|
|2010-12-01|          null|         26814|
|2010-12-01|        Norway|          1852|
|2010-12-02|       Germany|           146|
|2010-12-02|          null|         21023|
|2010-12-02|          EIRE|             4|
|2010-12-02|United Kingdom|         20873|
|2010-12-03|       Belgium|           528|
|2010-12-03|      Portugal|            65|
|2010-12-03|       Germany|           170|
|2010-12-03|   Switzerland|           110|
|2010-12-03|        Poland|           140|
|2010-12-03|          null|         14830|
|2010-12-03|         Spain|           400|
+----------

Now where you see the null values is where you’ll find the grand totals. A null in both rollup
columns specifies the grand total across both of those columns. 

### Cube

To pose this as a question again, can you make a table that includes the following?

    The total across all dates and countries
    The total for each date across all countries
    The total for each country on each date
    The total for each country across all dates


In [44]:
from pyspark.sql.functions import sum

dfNoNull.cube("Date", "Country").agg(sum(col("Quantity")))\
                                .select("Date", "Country", "sum(Quantity)").orderBy("Date").show()


+----+--------------------+-------------+
|Date|             Country|sum(Quantity)|
+----+--------------------+-------------+
|null|               Japan|        25218|
|null|            Portugal|        16180|
|null|           Hong Kong|         4769|
|null|                 RSA|          352|
|null|           Australia|        83653|
|null|United Arab Emirates|          982|
|null|         Unspecified|         3300|
|null|             Finland|        10666|
|null|             Germany|       117448|
|null|                null|      5176450|
|null|     Channel Islands|         9479|
|null|           Singapore|         5234|
|null|  European Community|          497|
|null|             Lebanon|          386|
|null|              Cyprus|         6317|
|null|              Norway|        19247|
|null|               Spain|        26824|
|null|                 USA|         1034|
|null|             Denmark|         8188|
|null|      Czech Republic|          592|
+----+--------------------+-------

### Pivot

In [47]:
pivoted = dfWithDate.groupBy("date").pivot("Country").sum()

In [53]:
pivoted.printSchema()

root
 |-- date: date (nullable = true)
 |-- Australia_sum(CAST(Quantity AS BIGINT)): long (nullable = true)
 |-- Australia_sum(UnitPrice): double (nullable = true)
 |-- Australia_sum(CAST(CustomerID AS BIGINT)): long (nullable = true)
 |-- Austria_sum(CAST(Quantity AS BIGINT)): long (nullable = true)
 |-- Austria_sum(UnitPrice): double (nullable = true)
 |-- Austria_sum(CAST(CustomerID AS BIGINT)): long (nullable = true)
 |-- Bahrain_sum(CAST(Quantity AS BIGINT)): long (nullable = true)
 |-- Bahrain_sum(UnitPrice): double (nullable = true)
 |-- Bahrain_sum(CAST(CustomerID AS BIGINT)): long (nullable = true)
 |-- Belgium_sum(CAST(Quantity AS BIGINT)): long (nullable = true)
 |-- Belgium_sum(UnitPrice): double (nullable = true)
 |-- Belgium_sum(CAST(CustomerID AS BIGINT)): long (nullable = true)
 |-- Brazil_sum(CAST(Quantity AS BIGINT)): long (nullable = true)
 |-- Brazil_sum(UnitPrice): double (nullable = true)
 |-- Brazil_sum(CAST(CustomerID AS BIGINT)): long (nullable = true)
 |-- Can

In [56]:
pivoted.select("date", "Australia_sum(CAST(Quantity AS BIGINT))").show()

+----------+---------------------------------------+
|      date|Australia_sum(CAST(Quantity AS BIGINT))|
+----------+---------------------------------------+
|2011-10-07|                                   null|
|2011-05-06|                                   null|
|2011-01-30|                                   null|
|2011-11-18|                                   null|
|2011-07-18|                                   null|
|2011-08-21|                                   null|
|2011-01-23|                                   null|
|2011-07-07|                                   null|
|2010-12-15|                                   null|
|2011-11-14|                                   null|
|2010-12-01|                                    107|
|2011-04-06|                                   null|
|2011-06-21|                                   null|
|2011-02-21|                                   null|
|2011-09-04|                                   null|
|2011-08-30|                                  