## Pivot and Unpivot a Spark DataFrame

In [1]:
import os

import findspark

# Locate the Spark installation without tying the notebook to one machine:
# prefer the SPARK_HOME environment variable and only fall back to the
# original local Windows install path when it is unset or empty.
DEFAULT_SPARK_HOME = 'C:\\spark\\spark-3.0.0-hadoop2.7'
findspark.init(os.environ.get('SPARK_HOME') or DEFAULT_SPARK_HOME)

# pyspark must be imported only after findspark.init() has run.
import pyspark

In [2]:
from pyspark.sql import SparkSession

In [3]:
# Create the session named 'Basic', or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName('Basic')
    .getOrCreate()
)

In [4]:
# Bare expression: echo the session object as this cell's result.
spark

In [7]:
# Sample sales data: one row per (Product, Amount, Country) record.
# Note that ("Orange", 2000, "USA") is listed twice, so any sum over
# Orange/USA adds up to 4000.
data_df = spark.createDataFrame(
    [
        ("Banana", 1000, "USA"),
        ("Carrots", 1500, "USA"),
        ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"),
        ("Orange", 2000, "USA"),
        ("Banana", 400, "China"),
        ("Carrots", 1200, "China"),
        ("Beans", 1500, "China"),
        ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"),
        ("Carrots", 2000, "Canada"),
        ("Beans", 2000, "Mexico"),
    ],
    ["Product", "Amount", "Country"],
)

In [8]:
# Print the rows of data_df as a text table to check the input data.
data_df.show()

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
| Banana|  1000|    USA|
|Carrots|  1500|    USA|
|  Beans|  1600|    USA|
| Orange|  2000|    USA|
| Orange|  2000|    USA|
| Banana|   400|  China|
|Carrots|  1200|  China|
|  Beans|  1500|  China|
| Orange|  4000|  China|
| Banana|  2000| Canada|
|Carrots|  2000| Canada|
|  Beans|  2000| Mexico|
+-------+------+-------+



In [14]:
# Pivot: one row per Product, one column per distinct Country,
# each cell holding the summed Amount (null where no rows exist).
(
    data_df
    .groupBy('Product')
    .pivot('Country')
    .sum('Amount')
    .show()
)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico| USA|
+-------+------+-----+------+----+
| Orange|  null| 4000|  null|4000|
|  Beans|  null| 1500|  2000|1600|
| Banana|  2000|  400|  null|1000|
|Carrots|  2000| 1200|  null|1500|
+-------+------+-----+------+----+



In [16]:
# The transposed view: one row per Country, one column per distinct Product,
# each cell holding the summed Amount (null where no rows exist).
(
    data_df
    .groupBy('Country')
    .pivot('Product')
    .sum('Amount')
    .show()
)

+-------+------+-----+-------+------+
|Country|Banana|Beans|Carrots|Orange|
+-------+------+-----+-------+------+
|  China|   400| 1500|   1200|  4000|
|    USA|  1000| 1600|   1500|  4000|
| Mexico|  null| 2000|   null|  null|
| Canada|  2000| null|   2000|  null|
+-------+------+-----+-------+------+



In [24]:
# This cell and the two-step aggregation in the next cell produce the same
# pivoted totals as the plain pivot above, but are cheaper to run.
#
# Pivot performance was improved from Spark 2.0 onwards. On older versions
# pivot is a very expensive operation, so when the pivot column's values are
# known in advance it is recommended to pass them as the second argument,
# as done here. The output columns then follow the order of that list.
(
    data_df
    .groupBy('Product')
    .pivot('Country', ["USA", "China", "Canada", "Mexico"])
    .sum('Amount')
    .show()
)


+-------+----+-----+------+------+
|Product| USA|China|Canada|Mexico|
+-------+----+-----+------+------+
| Orange|4000| 4000|  null|  null|
|  Beans|1600| 1500|  null|  2000|
| Banana|1000|  400|  2000|  null|
|Carrots|1500| 1200|  2000|  null|
+-------+----+-----+------+------+



In [23]:
# Two-phase aggregation: first reduce to one total per (Product, Country),
# then pivot those pre-aggregated rows. 'sum(Amount)' is the column name
# produced by the first sum().
(
    data_df
    .groupBy('Product', 'Country')
    .sum('Amount')
    .groupBy('Product')
    .pivot('Country')
    .sum('sum(Amount)')
    .show()
)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico| USA|
+-------+------+-----+------+----+
| Orange|  null| 4000|  null|4000|
|  Beans|  null| 1500|  2000|1600|
| Banana|  2000|  400|  null|1000|
|Carrots|  2000| 1200|  null|1500|
+-------+------+-----+------+----+



In [30]:
# Average Amount per Product (no pivot); the result column is 'avg(Amount)'.
(
    data_df
    .groupBy('Product')
    .agg({'Amount': 'mean'})
    .show()
)

+-------+------------------+
|Product|       avg(Amount)|
+-------+------------------+
| Orange|2666.6666666666665|
|  Beans|            1700.0|
| Banana|1133.3333333333333|
|Carrots|1566.6666666666667|
+-------+------------------+



In [29]:
# NOTE(review): `row` is not defined in any cell visible in this notebook, so
# this cell presumably relies on a since-deleted cell and would fail on a fresh
# kernel (Restart & Run All). The recorded output is a Column expression, not a
# value, which suggests `row` was a Column rather than a collected Row — TODO
# confirm the intent, then define `row` explicitly or remove this cell.
row[0][0]

Column<b'Product[0][0]'>