## Select Top N rows from each group

In [1]:
import findspark
findspark.init()

In [2]:
import pyspark
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName("PySpark_Practice_03").getOrCreate()

In [4]:
from pyspark.sql.functions import current_timestamp, to_timestamp
from pyspark.sql.functions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, DateType, BooleanType

In [5]:
sampledata = (("Nitya", "Sales", 3000), \
                ("Abhi", "Sales", 4600), \
                ("Rakesh", "Sales", 4100), \
                ("Sandeep", "finance", 3000), \
                ("Abhishek", "Sales", 3000), \
                ("Shyan", "finance", 3300), \
                ("Madan", "finance", 3900), \
                ("Jarin", "marketing", 3000), \
                ("kumar", "marketing", 2000))

In [6]:
columns = ["employee_name", "department", "salary"]

In [7]:
df = spark.createDataFrame(sampledata,columns)

In [8]:
df.show(truncate=False)

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|Nitya        |Sales     |3000  |
|Abhi         |Sales     |4600  |
|Rakesh       |Sales     |4100  |
|Sandeep      |finance   |3000  |
|Abhishek     |Sales     |3000  |
|Shyan        |finance   |3300  |
|Madan        |finance   |3900  |
|Jarin        |marketing |3000  |
|kumar        |marketing |2000  |
+-------------+----------+------+



In [9]:
df.printSchema()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- salary: long (nullable = true)



In [10]:
from pyspark.sql import Window

In [11]:
windowSpec = Window.partitionBy("department").orderBy("salary")

In [16]:
# Row_number
df1 = df.withColumn("row", row_number().over(windowSpec))
df1.show()

+-------------+----------+------+---+
|employee_name|department|salary|row|
+-------------+----------+------+---+
|        Nitya|     Sales|  3000|  1|
|     Abhishek|     Sales|  3000|  2|
|       Rakesh|     Sales|  4100|  3|
|         Abhi|     Sales|  4600|  4|
|      Sandeep|   finance|  3000|  1|
|        Shyan|   finance|  3300|  2|
|        Madan|   finance|  3900|  3|
|        kumar| marketing|  2000|  1|
|        Jarin| marketing|  3000|  2|
+-------------+----------+------+---+



In [17]:
df2 = df1.filter(col("row") < 3)
df2.show()

+-------------+----------+------+---+
|employee_name|department|salary|row|
+-------------+----------+------+---+
|        Nitya|     Sales|  3000|  1|
|     Abhishek|     Sales|  3000|  2|
|      Sandeep|   finance|  3000|  1|
|        Shyan|   finance|  3300|  2|
|        kumar| marketing|  2000|  1|
|        Jarin| marketing|  3000|  2|
+-------------+----------+------+---+

