# 1. Pivot rows to columns and back

In data transformation it is very common to turn rows into columns. In this tutorial, we show how to pivot (rows to columns) and unpivot (columns back to rows) properly in PySpark.

## 1.1 Prepare the Spark environment

```python
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as f
import os
```

```python
# Create the Spark session: local mode, or a cluster on Kubernetes.
local = False

if local:
    spark = SparkSession.builder \
        .master("local[4]") \
        .appName("Pivot_and_UnPivot") \
        .config("spark.executor.memory", "4g") \
        .getOrCreate()
else:
    # IMAGE_NAME, KUBERNETES_SERVICE_ACCOUNT and KUBERNETES_NAMESPACE must be
    # set as environment variables of the driver.
    spark = SparkSession.builder \
        .master("k8s://https://kubernetes.default.svc:443") \
        .appName("Pivot_and_UnPivot") \
        .config("spark.kubernetes.container.image", os.environ["IMAGE_NAME"]) \
        .config("spark.kubernetes.authenticate.driver.serviceAccountName", os.environ["KUBERNETES_SERVICE_ACCOUNT"]) \
        .config("spark.executor.instances", "4") \
        .config("spark.executor.memory", "2g") \
        .config("spark.kubernetes.namespace", os.environ["KUBERNETES_NAMESPACE"]) \
        .getOrCreate()
```


## 1.2 Prepare the data

The source data frame has three columns ("Product", "Amount", "Country"). Together they describe how much of each product each country exported in a year.

For example, the row ("Banana", 1000, "USA") means the USA exported 1000 tons of bananas.


```python
data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]

columns = ["Product", "Amount", "Country"]
df = spark.createDataFrame(data=data, schema=columns)
print("main output: source data schema")
df.printSchema()
print("main output: source data")
df.show(truncate=False)
```

```text
main output: source data schema
root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

main output: source data
+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+
```

## 1.3 First example
Let's first understand what a pivot does. Below we build a data frame `df_product_sum` that calculates the total amount per product.


```python
df_product_sum = df.groupBy("Product").agg(f.sum("Amount").alias("amount_sum"))
df_product_sum.show()
```

```text
+-------+----------+
|Product|amount_sum|
+-------+----------+
|  Beans|      5100|
| Banana|      3400|
|Carrots|      4700|
| Orange|      8000|
+-------+----------+
```



### Use pivot function
Now we want the product names to become column names. The code below shows how to pivot rows into columns.
Note that pivot can only be called after a groupBy, because it is a method of the grouped data and not of the data frame itself. Here we leave the groupBy arguments empty, which means all rows form a single group.

The pivot function takes a column name and creates a new column for each distinct value in that column.
Several rows may share the same value in that column, so each new cell can receive several values. For that reason pivot must be followed by an aggregation function that reduces them to a single value. In this example we use `first`. Each product appears only once in `df_product_sum`, so `first` simply returns its total.

```python
df_product_sum.groupBy().pivot("Product").agg(f.first("amount_sum")).show()
```

```text
+------+-----+-------+------+
|Banana|Beans|Carrots|Orange|
+------+-----+-------+------+
|  3400| 5100|   4700|  8000|
+------+-----+-------+------+
```
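To make the semantics concrete, here is a minimal plain-Python sketch (no Spark involved) of what `groupBy().pivot("Product").agg(f.first("amount_sum"))` computes. It assumes the rows are given as a list of dicts, and the helper names `pivot_single_group` and `first` are hypothetical. With no grouping columns all rows fall into one group, so the result is a single row with one entry per distinct product.

```python
def pivot_single_group(rows, pivot_key, value_key, agg):
    """Plain-Python analogue of df.groupBy().pivot(pivot_key).agg(agg(value_key)).

    With no grouping columns every row belongs to one group, so the result
    is a single row (a dict) with one entry per distinct value of pivot_key.
    """
    cells = {}
    for row in rows:
        # Collect every value that lands in the cell for this pivot value.
        cells.setdefault(row[pivot_key], []).append(row[value_key])
    # Discovered pivot values are sorted, as Spark does when no list is given.
    return {name: agg(values) for name, values in sorted(cells.items())}


def first(values):
    """Analogue of f.first(): keep the first value seen (order-dependent)."""
    return values[0]


product_sum = [
    {"Product": "Beans", "amount_sum": 5100},
    {"Product": "Banana", "amount_sum": 3400},
    {"Product": "Carrots", "amount_sum": 4700},
    {"Product": "Orange", "amount_sum": 8000},
]

print(pivot_single_group(product_sum, "Product", "amount_sum", first))
# {'Banana': 3400, 'Beans': 5100, 'Carrots': 4700, 'Orange': 8000}
```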



## 1.4 A more advanced example
In the example above, we saw that pivoting with two dimensions is easy. Now we add another dimension: we want to see the export amount for each country and product.

First, let's see which countries export each product.

```python
# show the list of exporting countries for each Product
df.groupBy("Product").agg(f.collect_list("Country")).show(truncate=False)
```

```text
+-------+---------------------+
|Product|collect_list(Country)|
+-------+---------------------+
|Orange |[China, USA, USA]    |
|Beans  |[China, USA, Mexico] |
|Carrots|[China, USA, Canada] |
|Banana |[USA, Canada, China] |
+-------+---------------------+
```



```python
# show the Amount by Product and Country
df.groupBy("Product", "Country").sum("Amount").show(truncate=False)
```

```text
+-------+-------+-----------+
|Product|Country|sum(Amount)|
+-------+-------+-----------+
|Carrots|USA    |1500       |
|Banana |USA    |1000       |
|Beans  |USA    |1600       |
|Banana |China  |400        |
|Orange |USA    |4000       |
|Orange |China  |4000       |
|Carrots|China  |1200       |
|Beans  |China  |1500       |
|Carrots|Canada |2000       |
|Beans  |Mexico |2000       |
|Banana |Canada |2000       |
+-------+-------+-----------+
```



Grouping by both Product and Country gives the export amount for each country and product, but this long format is not easy to read. So we pivot the distinct country values into columns.

```python
# pivot() turns each distinct Country value into a column; each cell is the sum of Amount.
# Product/country combinations that do not exist in the data are filled with null.
pivot_country = df.groupBy("Product").pivot("Country").sum("Amount")
pivot_country.printSchema()
pivot_country.show(truncate=False)
```

```text
root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+
```
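The grouped pivot can be sketched the same way. The plain-Python helper `pivot` below is hypothetical and is not Spark's implementation. It assumes rows are dicts and only mimics the observable result: one output row per Product, one column per distinct Country (sorted, as Spark does when no value list is given), the sum in each cell, and None where Spark shows null. Note the separate first pass over the data that discovers the distinct country values. This is the extra work that section 1.5 tries to avoid.

```python
def pivot(rows, group_key, pivot_key, value_key, agg):
    """Plain-Python analogue of df.groupBy(group_key).pivot(pivot_key).agg(agg(value_key))."""
    # Pass 1: discover the distinct pivot values; they become the new columns
    # (sorted, as Spark does when no value list is given).
    pivot_values = sorted({row[pivot_key] for row in rows})
    # Pass 2: bucket the values by group and pivot value.
    cells = {}
    for row in rows:
        by_pivot = cells.setdefault(row[group_key], {})
        by_pivot.setdefault(row[pivot_key], []).append(row[value_key])
    result = []
    for group, by_pivot in cells.items():
        out = {group_key: group}
        for value in pivot_values:
            bucket = by_pivot.get(value)
            # No input rows for this combination -> None (Spark shows null).
            out[value] = agg(bucket) if bucket else None
        result.append(out)
    return result


data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]
rows = [dict(zip(["Product", "Amount", "Country"], t)) for t in data]

pivoted = pivot(rows, "Product", "Country", "Amount", sum)
for r in pivoted:
    print(r)
```

The two Orange/USA rows (2000 each) end up summed to 4000 in a single cell. Rows come out in first-appearance order here. Spark does not guarantee row order after a groupBy either.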



## 1.5 Performance issues

Pivot is an expensive operation. If the data frame that you want to pivot is large, you can optimize it with the following solutions.
- Solution 1: Provide the list of pivot-column values (i.e. the future column names) that you want to pivot.
- Solution 2: Use a two-phase groupBy.

### 1.5.1 Provide a column list 

Providing the list of values that should become columns can improve performance a lot. Without it, Spark must first run an extra job over the whole data frame to compute the distinct values of the pivot column. It needs those values because it must know the output columns before it can pivot.

You only need to build a list that contains the values you want as columns, then pass this list as the second argument of the pivot function.


```python
country_list = ["USA", "China", "Canada", "Mexico"]
pivot_country2 = df.groupBy("Product").pivot("Country", country_list).sum("Amount")
pivot_country2.show(truncate=False)
```

```text
+-------+----+-----+------+------+
|Product|USA |China|Canada|Mexico|
+-------+----+-----+------+------+
|Orange |4000|4000 |null  |null  |
|Beans  |1600|1500 |null  |2000  |
|Banana |1000|400  |2000  |null  |
|Carrots|1500|1200 |2000  |null  |
+-------+----+-----+------+------+
```

Note that when you provide a list, the columns follow the order of the list. Only automatically discovered values are sorted alphabetically, as in `pivot_country` above.
What happens if we remove USA from the list?

```python
country_list1 = ["China", "Canada", "Mexico"]
pivot_country3 = df.groupBy("Product").pivot("Country", country_list1).sum("Amount")
pivot_country3.show(truncate=False)
```

```text
+-------+-----+------+------+
|Product|China|Canada|Mexico|
+-------+-----+------+------+
|Orange |4000 |null  |null  |
|Beans  |1500 |null  |2000  |
|Banana |400  |2000  |null  |
|Carrots|1200 |2000  |null  |
+-------+-----+------+------+
```



Spark simply ignores all rows whose Country is "USA". The resulting data frame contains only the columns defined in the list, in the order given.
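Here is a plain-Python sketch of pivoting with an explicit value list. It makes the same assumptions as before (rows as dicts), and the helper `pivot_with_values` is hypothetical. No discovery pass is needed, and the columns follow the order of the list. Rows whose pivot value is not listed contribute nothing to any cell, although their group still appears in the output.

```python
def pivot_with_values(rows, group_key, pivot_key, value_key, agg, values):
    """Plain-Python analogue of df.groupBy(group_key).pivot(pivot_key, values).agg(...)."""
    wanted = set(values)
    cells = {}
    for row in rows:
        # Every group is kept, even if none of its rows has a listed value.
        by_pivot = cells.setdefault(row[group_key], {})
        if row[pivot_key] in wanted:  # unlisted values such as "USA" are skipped
            by_pivot.setdefault(row[pivot_key], []).append(row[value_key])
    result = []
    for group, by_pivot in cells.items():
        out = {group_key: group}
        for value in values:  # columns follow the order of the given list
            bucket = by_pivot.get(value)
            out[value] = agg(bucket) if bucket else None
        result.append(out)
    return result


data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]
rows = [dict(zip(["Product", "Amount", "Country"], t)) for t in data]

pivoted3 = pivot_with_values(rows, "Product", "Country", "Amount", sum,
                             ["China", "Canada", "Mexico"])
for r in pivoted3:
    print(r)
```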

### 1.5.2 Use two-phase groupBy

The optimization idea is similar to providing a column list: reduce the work Spark has to do when it pivots. In the two-phase approach, the first groupBy pre-aggregates the data by (Product, Country). This produces at most one row per distinct combination, which is typically far smaller than the raw data. The second phase pivots this small data frame, so both the discovery of the distinct pivot values and the pivot itself run on fewer rows.

#### Phase 1. Use groupBy to pre-aggregate each distinct (Product, Country) pair

```python
df_tmp = df.groupBy("Product", "Country").agg(f.sum("Amount").alias("sum_amount"))
df_tmp.show()
```

```text
+-------+-------+----------+
|Product|Country|sum_amount|
+-------+-------+----------+
|Carrots| Canada|      2000|
|  Beans| Mexico|      2000|
| Banana| Canada|      2000|
| Banana|  China|       400|
| Orange|    USA|      4000|
| Orange|  China|      4000|
|Carrots|  China|      1200|
|  Beans|  China|      1500|
|Carrots|    USA|      1500|
| Banana|    USA|      1000|
|  Beans|    USA|      1600|
+-------+-------+----------+
```



#### Phase 2. Pivot the column
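A small note: the `sum()` called after `pivot()` is the `GroupedData.sum()` method, not the `f.sum()` function from `pyspark.sql.functions`. They compute the same aggregate but belong to different APIs; `.sum("sum_amount")` is shorthand for `.agg(f.sum("sum_amount"))`.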

```python
df_pivot_country4 = df_tmp.groupBy("Product").pivot("Country").sum("sum_amount")
df_pivot_country4.show()
```

```text
+-------+------+-----+------+----+
|Product|Canada|China|Mexico| USA|
+-------+------+-----+------+----+
| Orange|  null| 4000|  null|4000|
|  Beans|  null| 1500|  2000|1600|
| Banana|  2000|  400|  null|1000|
|Carrots|  2000| 1200|  null|1500|
+-------+------+-----+------+----+
```
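Here is a plain-Python sketch of the two-phase approach, assuming the same list-of-dicts rows. The helpers `pre_aggregate` and `pivot_sum` are hypothetical. Phase 1 collapses the 12 input rows into 11 (Product, Country) totals, because the two Orange/USA rows are merged. Phase 2 then pivots that smaller input. The final check confirms that pivoting the pre-aggregated rows gives the same table as pivoting the raw rows, since summing in two steps yields the same totals.

```python
def pre_aggregate(rows, keys, value_key, out_key):
    """Phase 1: one output row per distinct combination of `keys`, summing value_key."""
    totals = {}
    for row in rows:
        combo = tuple(row[k] for k in keys)
        totals[combo] = totals.get(combo, 0) + row[value_key]
    return [{**dict(zip(keys, combo)), out_key: total} for combo, total in totals.items()]


def pivot_sum(rows, group_key, pivot_key, value_key):
    """Phase 2: pivot with sum; missing combinations become None (null in Spark)."""
    columns = sorted({row[pivot_key] for row in rows})
    cells = {}
    for row in rows:
        by_pivot = cells.setdefault(row[group_key], {})
        by_pivot[row[pivot_key]] = by_pivot.get(row[pivot_key], 0) + row[value_key]
    return {group: {c: by_pivot.get(c) for c in columns} for group, by_pivot in cells.items()}


data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]
rows = [dict(zip(["Product", "Amount", "Country"], t)) for t in data]

phase1 = pre_aggregate(rows, ["Product", "Country"], "Amount", "sum_amount")
two_phase = pivot_sum(phase1, "Product", "Country", "sum_amount")
one_phase = pivot_sum(rows, "Product", "Country", "Amount")
print(len(rows), "->", len(phase1))  # 12 -> 11
print(two_phase == one_phase)  # True
```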



## 1.6 Unpivot
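Unpivot is the reverse operation: it rotates column values back into row values.
PySpark versions before 3.4 have no built-in unpivot function, so we use the SQL `stack()` function through `f.expr()`. Spark 3.4 added `DataFrame.unpivot()`, with the alias `melt()`. The code below converts the country columns back into rows.

The first argument of `stack` is the number of rows each input row is split into. Here it equals the number of ('label', column) pairs, 3.
Each pair produces one output row. The first element is a string literal that becomes the value of the new Country column. The second element is a column of the source data frame, and its value goes into the Total column. These column names must exist in the source data frame, otherwise Spark raises an error, while the labels can be any string.

To show that the label is just a literal, the example below uses 'Can_ada' as the label for the Canada column, and 'Can_ada' appears as-is in the resulting data frame. USA is not listed, so its values are dropped. The `where("Total is not null")` filter removes the null cells that the pivot produced for missing product/country combinations.

The **as (Country, Total)** clause defines the column names in the resulting data frame.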
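To see how `stack` lays out its arguments, here is a plain-Python analogue. The function `stack` below is a hypothetical sketch of how we understand Spark SQL's `stack(n, expr1, ..., exprk)` to behave: the k expressions are split row by row into n rows of ceil(k / n) columns, and any missing trailing slots are padded with null (None here). The sketch then unpivots one pivoted row the way `unpivot_expr1` does, including the null filter.

```python
import math


def stack(n, *exprs):
    """Plain-Python analogue of Spark SQL's stack(n, expr1, ..., exprk).

    The k expressions are laid out row by row into n rows of ceil(k / n)
    columns; missing trailing slots are padded with None.
    """
    width = math.ceil(len(exprs) / n)
    padded = list(exprs) + [None] * (n * width - len(exprs))
    return [tuple(padded[i * width:(i + 1) * width]) for i in range(n)]


# One row of pivot_country (Beans), as a dict.
pivoted_row = {"Product": "Beans", "Canada": None, "China": 1500, "Mexico": 2000, "USA": 1600}

pairs = stack(3, "Can_ada", pivoted_row["Canada"],
              "China", pivoted_row["China"],
              "Mexico", pivoted_row["Mexico"])
# Each pair becomes (Country, Total); keep non-null totals, like .where("Total is not null").
unpivoted = [(pivoted_row["Product"], country, total)
             for country, total in pairs if total is not None]
print(unpivoted)
# [('Beans', 'China', 1500), ('Beans', 'Mexico', 2000)]
```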

```python
unpivot_expr1 = "stack(3, 'Can_ada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unpivot_df1 = pivot_country.select("Product", f.expr(unpivot_expr1)).where("Total is not null")
unpivot_df1.show(truncate=False)
unpivot_df1.printSchema()
```

```text
+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Can_ada|2000 |
|Banana |China  |400  |
|Carrots|Can_ada|2000 |
|Carrots|China  |1200 |
+-------+-------+-----+

root
 |-- Product: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Total: long (nullable = true)
```



The example below also includes the USA column, which brings back all 11 (Product, Country) totals.

```python
unpivot_expr2 = "stack(4, 'Canada', Canada, 'China', China, 'Mexico', Mexico, 'USA', USA) as (Country,Total)"
unpivot_df2 = pivot_country.select("Product", f.expr(unpivot_expr2)).where("Total is not null")
unpivot_df2.show(truncate=False)
unpivot_df2.printSchema()
```

```text
+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Orange |USA    |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Beans  |USA    |1600 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Banana |USA    |1000 |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
|Carrots|USA    |1500 |
+-------+-------+-----+

root
 |-- Product: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Total: long (nullable = true)
```
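Finally, here is a plain-Python sketch of the full round trip. The helpers `group_sum`, `pivot_sum` and `unpivot` are hypothetical and only mimic the Spark calls above. Pivoting and then unpivoting all four country columns (dropping nulls) gives back the 11 grouped (Product, Country) totals, not the 12 original rows. The two Orange/USA rows stay merged as 4000, because the aggregation inside pivot cannot be undone.

```python
from collections import defaultdict

data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]
rows = [dict(zip(["Product", "Amount", "Country"], t)) for t in data]


def group_sum(rows):
    """Analogue of df.groupBy("Product", "Country").sum("Amount")."""
    totals = defaultdict(int)
    for r in rows:
        totals[(r["Product"], r["Country"])] += r["Amount"]
    return totals


def pivot_sum(rows):
    """Analogue of df.groupBy("Product").pivot("Country").sum("Amount")."""
    countries = sorted({r["Country"] for r in rows})
    totals = group_sum(rows)
    products = list(dict.fromkeys(r["Product"] for r in rows))
    # .get() on a defaultdict does not insert; missing combinations give None.
    return [{"Product": p, **{c: totals.get((p, c)) for c in countries}} for p in products]


def unpivot(pivoted, columns):
    """Analogue of stack(len(columns), 'c1', c1, ...) as (Country, Total) + where Total is not null."""
    return [(row["Product"], c, row[c]) for row in pivoted for c in columns if row[c] is not None]


back = unpivot(pivot_sum(rows), ["Canada", "China", "Mexico", "USA"])
expected = {(p, c, t) for (p, c), t in group_sum(rows).items()}
print(len(back), set(back) == expected)  # 11 True
```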

