In [1]:
import sys
import os

In [2]:
os.environ.get('JAVA_HOME')

'C:\\Program Files\\Java\\jdk1.8.0_311'

In [3]:
import findspark
findspark.init()

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [5]:
spark = SparkSession.builder.master("local[4]").appName("SparkSQL").getOrCreate()

In [6]:
data = (("Banana", 1000, "USA"), 
        ("Carrots", 1500, "USA"),
        ("Beans", 1600, "USA"), 
        ("Orange", 2000, "USA"),
        ("Orange", 2000, "USA"), 
        ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), 
        ("Beans", 1500, "China"),
        ("Orange", 4000, "China"), 
        ("Banana", 2000, "Canada"),
        ("Carrots", 2000, "Canada"), 
        ("Beans", 2000, "Mexico"))
rdd = spark.sparkContext.parallelize(data)

In [7]:
rdd.take(5)

[('Banana', 1000, 'USA'),
 ('Carrots', 1500, 'USA'),
 ('Beans', 1600, 'USA'),
 ('Orange', 2000, 'USA'),
 ('Orange', 2000, 'USA')]

In [9]:
df = rdd.toDF(["Product", "Amount", "Country"])
df.show()

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
| Banana|  1000|    USA|
|Carrots|  1500|    USA|
|  Beans|  1600|    USA|
| Orange|  2000|    USA|
| Orange|  2000|    USA|
| Banana|   400|  China|
|Carrots|  1200|  China|
|  Beans|  1500|  China|
| Orange|  4000|  China|
| Banana|  2000| Canada|
|Carrots|  2000| Canada|
|  Beans|  2000| Mexico|
+-------+------+-------+



In [12]:
import time
start_time = time.time()
df.groupBy("Country").pivot("Product").sum("Amount").show()
print(f"Execution time: {time.time() - start_time}")

+-------+------+-----+-------+------+
|Country|Banana|Beans|Carrots|Orange|
+-------+------+-----+-------+------+
|  China|   400| 1500|   1200|  4000|
|    USA|  1000| 1600|   1500|  4000|
| Mexico|  null| 2000|   null|  null|
| Canada|  2000| null|   2000|  null|
+-------+------+-----+-------+------+

Execution time: 9.09967827796936


In [23]:
products = df.select("Product").distinct().rdd.map(lambda row : row[0]).collect()
products

['Beans', 'Banana', 'Carrots', 'Orange']

## New Optimized Approach
Spark 2.0 on-wards performance has been improved on Pivot. Pivot is a costly operation. Hence, to enhance performance, provide column data (if known) as an argument to function

In [24]:
import time
start_time = time.time()
df.groupBy("Country").pivot("Product", products).sum("Amount").show()
print(f"Execution time: {time.time() - start_time}")

+-------+-----+------+-------+------+
|Country|Beans|Banana|Carrots|Orange|
+-------+-----+------+-------+------+
|  China| 1500|   400|   1200|  4000|
|    USA| 1600|  1000|   1500|  4000|
| Mexico| 2000|  null|   null|  null|
| Canada| null|  2000|   2000|  null|
+-------+-----+------+-------+------+

Execution time: 4.363990068435669


## Unpivot

In [27]:
unpivoted_df = pivoted_df \
    .selectExpr("Country", "stack(4, 'Banana', Banana, 'Beans', Beans, 'Carrots', Carrots, 'Orange', Orange) as (Country, Total)") \
    .where("Total is not null")
unpivoted_df.show()

+-------+-------+-----+
|Country|Country|Total|
+-------+-------+-----+
|  China| Banana|  400|
|  China|  Beans| 1500|
|  China|Carrots| 1200|
|  China| Orange| 4000|
|    USA| Banana| 1000|
|    USA|  Beans| 1600|
|    USA|Carrots| 1500|
|    USA| Orange| 4000|
| Mexico|  Beans| 2000|
| Canada| Banana| 2000|
| Canada|Carrots| 2000|
+-------+-------+-----+

