In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window

# ===========================================
# Step 1: get (or reuse) the Spark session
# ===========================================
spark = (
    SparkSession.builder
    .appName("Top Earner by Department")
    .getOrCreate()
)

# ===========================================
# Step 2: build the sample employee DataFrame
# ===========================================
# Column names for the (employee, department, salary) tuples below;
# Spark infers the column types from the Python values.
columns = ["employee", "department", "salary"]
data = [
    ("John", "Sales", 5000),
    ("Mary", "Sales", 7000),
    ("Alex", "HR", 6000),
    ("Alice", "HR", 6500),
]

df = spark.createDataFrame(data, schema=columns)

print("Input DataFrame:")
df.show(5)  # display at most the first 5 rows

Input DataFrame:
+--------+----------+------+
|employee|department|salary|
+--------+----------+------+
|    John|     Sales|  5000|
|    Mary|     Sales|  7000|
|    Alex|        HR|  6000|
|   Alice|        HR|  6500|
+--------+----------+------+



In [0]:
# Highest salary per department. The grouping key is renamed to `dept` so it
# stays distinct from `df.department` when the two frames are joined below.
max_df = (
    df
    .groupBy(F.col("department").alias("dept"))
    .agg(F.max("salary").alias("max_salary"))
)
max_df.show()

+-----+----------+
| dept|max_salary|
+-----+----------+
|Sales|      7000|
|   HR|      6500|
+-----+----------+



In [0]:
# Keep every employee whose salary equals their department's maximum
# (if several employees tie for the maximum, all of them are kept).
# `eqNullSafe` lets a NULL department match the NULL group produced by the
# groupBy in the previous cell; a plain `==` evaluates to NULL there and would
# silently drop those employees. Row order of the result is not guaranteed.
top_earners_df = (
    df.join(
        max_df,
        on=(
            df.department.eqNullSafe(max_df.dept)
            & (df.salary == max_df.max_salary)
        ),
    )
    # Drop the helper columns (`dept`, `max_salary`) that came from max_df.
    .select("employee", "department", "salary")
)
top_earners_df.show()

+--------+----------+------+
|employee|department|salary|
+--------+----------+------+
|    Mary|     Sales|  7000|
|   Alice|        HR|  6500|
+--------+----------+------+

