In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import *

In [3]:
spark = SparkSession.builder \
.master("local") \
.appName("window_groupby_df") \
.getOrCreate()

In [4]:
simpleData = [("James","Sales",3000),
      ("Michael","Sales",4600),
      ("Robert","Sales",4100),
      ("Maria","Finance",3000),
      ("Raman","Finance",3000),
      ("Scott","Finance",3300),
      ("Jen","Finance",3900),
      ("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)
             ]

columns = ("employee_name","department","salary")

df1 = spark.createDataFrame(data=simpleData, schema=columns)

In [6]:
mywindow = Window.partitionBy("department").orderBy(col("salary").desc())

In [8]:
df1.withColumn("row",row_number().over(mywindow)) \
.where(col("row")==1).drop().show()

+-------------+----------+------+---+
|employee_name|department|salary|row|
+-------------+----------+------+---+
|      Michael|     Sales|  4600|  1|
|          Jen|   Finance|  3900|  1|
|         Jeff| Marketing|  3000|  1|
+-------------+----------+------+---+



In [13]:
w3 = Window.partitionBy("department").orderBy(col("salary").desc())
w4 = Window.partitionBy("department")

In [14]:
df1.withColumn("row",row_number().over(w3)) \
      .withColumn("avg", avg(col("salary")).over(w4)) \
      .withColumn("sum", sum(col("salary")).over(w4)) \
      .withColumn("min", min(col("salary")).over(w4)) \
      .withColumn("max", max(col("salary")).over(w4)) \
      .where(col("row")==1).select("department","avg","sum","min","max") \
      .show()

+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+

