In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Reuse the existing SparkSession if there is one (e.g. the notebook's), else create one.
spark = SparkSession.builder.getOrCreate()

# Small in-memory sample: one tuple per row, in the order given by `columns` below.
data = [
    ("Electronics", "Phone", 10, 500),
    ("Electronics", "Laptop", 5, 1000),
    ("Electronics", "Tablet", 8, 600),
    ("Clothing", "Shirt", 20, 40),
    ("Clothing", "Jeans", 10, 60),
    ("Clothing", "Jacket", 5, 120),
    ("Clothing", "Shoes", 7, 80)
]

columns = ["category", "product", "quantity", "price"]

# Build the frame and derive total_revenue = quantity * price in a single expression.
df = (
    spark.createDataFrame(data, columns)
    .withColumn("total_revenue", F.col("quantity") * F.col("price"))
)
df.show()

+-----------+-------+--------+-----+-------------+
|   category|product|quantity|price|total_revenue|
+-----------+-------+--------+-----+-------------+
|Electronics|  Phone|      10|  500|         5000|
|Electronics| Laptop|       5| 1000|         5000|
|Electronics| Tablet|       8|  600|         4800|
|   Clothing|  Shirt|      20|   40|          800|
|   Clothing|  Jeans|      10|   60|          600|
|   Clothing| Jacket|       5|  120|          600|
|   Clothing|  Shoes|       7|   80|          560|
+-----------+-------+--------+-----+-------------+



In [0]:
# Top-N products per category by total_revenue.
# row_number() keeps exactly TOP_N rows per category, so ties in total_revenue
# must be broken deterministically. Otherwise Spark is free to return either of
# two equally ranked rows (here Jeans/Jacket at 600 and Phone/Laptop at 5000).
# Use F.rank() instead of F.row_number() to keep every product tied at the cut-off.
TOP_N = 2

window = (
    Window.partitionBy(F.col("category"))
    .orderBy(
        F.col("total_revenue").desc(),
        F.col("quantity").desc(),  # tie-break: more units sold ranks higher
        F.col("product").asc(),    # last-resort tie-break on product name
    )
)

(
    df.withColumn("rn", F.row_number().over(window))
    .filter(F.col("rn") <= TOP_N)
    .orderBy(F.col("category"), F.col("rn"))  # show() order is otherwise unspecified
    .select(F.col("category"), F.col("product"), F.col("total_revenue"))
    .show()
)

+-----------+-------+-------------+
|   category|product|total_revenue|
+-----------+-------+-------------+
|   Clothing|  Shirt|          800|
|   Clothing|  Jeans|          600|
|Electronics|  Phone|         5000|
|Electronics| Laptop|         5000|
+-----------+-------+-------------+

