## Setting Environment

In [1]:
import os
import sys
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

## Creating DataFrame

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder\
        .appName("Pivot")\
        .getOrCreate()

In [3]:
data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]
df = spark.createDataFrame(data = data, schema = columns)
df.printSchema()

root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)



In [4]:
df.show()

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
| Banana|  1000|    USA|
|Carrots|  1500|    USA|
|  Beans|  1600|    USA|
| Orange|  2000|    USA|
| Orange|  2000|    USA|
| Banana|   400|  China|
|Carrots|  1200|  China|
|  Beans|  1500|  China|
| Orange|  4000|  China|
| Banana|  2000| Canada|
|Carrots|  2000| Canada|
|  Beans|  2000| Mexico|
+-------+------+-------+



## Pivot DataFrame

In [5]:
pivotDf = df.groupBy("Product")\
            .pivot("Country")\
            .sum("Amount")

In [10]:
pivotDf.show()

+-------+------+-----+------+----+
|Product|Canada|China|Mexico| USA|
+-------+------+-----+------+----+
| Orange|  null| 4000|  null|4000|
|  Beans|  null| 1500|  2000|1600|
| Banana|  2000|  400|  null|1000|
|Carrots|  2000| 1200|  null|1500|
+-------+------+-----+------+----+



In [9]:
pivotDF = df.groupBy("Country")\
            .pivot("Product")\
            .sum("Amount")

# The column in the groupBy is created as the first column with its rows as values
# The Pivot column is spread as the columns and
# Sum is the aggregate value of each column and corresponding rows

In [8]:
pivotDF.show()

+-------+------+-----+-------+------+
|Country|Banana|Beans|Carrots|Orange|
+-------+------+-----+-------+------+
|  China|   400| 1500|   1200|  4000|
|    USA|  1000| 1600|   1500|  4000|
| Mexico|  null| 2000|   null|  null|
| Canada|  2000| null|   2000|  null|
+-------+------+-----+-------+------+



In [12]:
# We can choose the list order to print accordingly as columns.
countries = ["USA", "China", "Mexico", "Canada"]
pivotDF1 = df.groupBy("Product")\
            .pivot("Country", countries)\
            .sum("Amount")
pivotDF1.show()

+-------+----+-----+------+------+
|Product| USA|China|Mexico|Canada|
+-------+----+-----+------+------+
| Orange|4000| 4000|  null|  null|
|  Beans|1600| 1500|  2000|  null|
| Banana|1000|  400|  null|  2000|
|Carrots|1500| 1200|  null|  2000|
+-------+----+-----+------+------+



In [14]:
pivotDF = df.groupBy("Product","Country") \
      .sum("Amount") \
      .groupBy("Product") \
      .pivot("Country") \
      .sum("sum(Amount)") 
pivotDF.show(truncate=False)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



## Unpivot DataFrame

In [21]:
from pyspark.sql.functions import expr
unpivotExpr = "stack(2, 'Cana', Canada, 'Mexico', Mexico) as (Country, Total)"
unpivotDF = pivotDF.select("Product", expr(unpivotExpr))
unpivotDF.show()

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
| Orange|   Cana| null|
| Orange| Mexico| null|
|  Beans|   Cana| null|
|  Beans| Mexico| 2000|
| Banana|   Cana| 2000|
| Banana| Mexico| null|
|Carrots|   Cana| 2000|
|Carrots| Mexico| null|
+-------+-------+-----+



In [22]:
#Printing Only not null values
from pyspark.sql.functions import expr
unpivotExpr = "stack(2, 'Cana', Canada, 'Mexico', Mexico) as (Country, Total)"
unpivotDF = pivotDF.select("Product", expr(unpivotExpr)).where("Total is not Null")
unpivotDF.show()

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|  Beans| Mexico| 2000|
| Banana|   Cana| 2000|
|Carrots|   Cana| 2000|
+-------+-------+-----+

