# Pivot - Transform the Rows into Columns

In [2]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Pivot').master('local[1]').getOrCreate()

In [3]:

from pyspark.sql.functions import expr
#Create spark session
data = [("Banana",1000,"USA"), ("Carrots",1500,"USA"), ("Beans",1600,"USA"), \
      ("Orange",2000,"USA"),("Orange",2000,"USA"),("Banana",400,"China"), \
      ("Carrots",1200,"China"),("Beans",1500,"China"),("Orange",4000,"China"), \
      ("Banana",2000,"Canada"),("Carrots",2000,"Canada"),("Beans",2000,"Mexico")]

columns= ["Product","Amount","Country"]

df = spark.createDataFrame(data = data, schema = columns)

df.printSchema()
df.show(truncate=False)


root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



In [5]:

pivotDF = df.groupBy("Product").pivot("Country").sum("Amount").fillna(0)

pivotDF.printSchema()
pivotDF.show(truncate=False)


root
 |-- Product: string (nullable = true)
 |-- Canada: long (nullable = true)
 |-- China: long (nullable = true)
 |-- Mexico: long (nullable = true)
 |-- USA: long (nullable = true)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |0     |4000 |0     |4000|
|Beans  |0     |1500 |2000  |1600|
|Banana |2000  |400  |0     |1000|
|Carrots|2000  |1200 |0     |1500|
+-------+------+-----+------+----+



In [6]:
#Better Performance

countries = ["USA","China","Canada","Mexico"]
pivotDF1 = df.groupBy("Product").pivot("Country", countries).sum("Amount")

pivotDF1.show(truncate=False)


+-------+----+-----+------+------+
|Product|USA |China|Canada|Mexico|
+-------+----+-----+------+------+
|Orange |4000|4000 |null  |null  |
|Beans  |1600|1500 |null  |2000  |
|Banana |1000|400  |2000  |null  |
|Carrots|1500|1200 |2000  |null  |
+-------+----+-----+------+------+



# Unpivot

In [7]:
'''Unpivot is a reverse operation, we can achieve by rotating column values into rows values. PySpark SQL doesn’t have unpivot function hence will use the stack() function'''

'Unpivot is a reverse operation, we can achieve by rotating column values into rows values. PySpark SQL doesn’t have unpivot function hence will use the stack() function'

In [10]:

from pyspark.sql.functions import expr
unpivotExpr = "stack(3, 'Canada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unPivotDF = pivotDF.select("Product", expr(unpivotExpr)) \
    .where("Total is not null")
unPivotDF.show(truncate=False)
#unPivotDF.show()

+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |Canada |0    |
|Orange |China  |4000 |
|Orange |Mexico |0    |
|Beans  |Canada |0    |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Banana |Mexico |0    |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
|Carrots|Mexico |0    |
+-------+-------+-----+



In [11]:
spark.stop()