### PartitionBy

In PySpark, you can select the top N rows of each group by partitioning the data with a window. Use `Window.partitionBy()` to define the groups, run `row_number()` over each ordered partition, and finally filter on the resulting row number to keep the top N rows.

In [0]:
from pyspark.sql import Row
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Data
data = [
    ("James","Sales",3000),
    ("Michael","Sales",4600),
    ("Robert","Sales",4100),
    ("Maria","Finance",3000),
    ("Raman","Finance",3000),
    ("Scott","Finance",3300),
    ("Jen","Finance",3900),
    ("Jeff","Marketing",3000),
    ("Kumar","Marketing",2000)
]

# Create DataFrame
df = spark.createDataFrame(data, ["Name","Department","Salary"])
df.show()



In [0]:
# Define window specification
windowDept = Window.partitionBy("Department").orderBy(col("Salary").desc())

# Apply window and filter top 2 salaries per department
df.withColumn("row", row_number().over(windowDept)) \
  .filter(col("row") <= 2) \
  .drop("row") \
  .show()

In [0]:
#PySark SQL
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn <= 2").show()