In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Obtain (or reuse) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("salary_growth")
    .getOrCreate()
)

# Sample salary history: one (emp_id, year, salary) record per employee per year.
columns = ["emp_id", "year", "salary"]

data = [
    # employee 1: raise in 2021, then a cut in 2022
    (1, 2020, 50000), (1, 2021, 55000), (1, 2022, 53000),
    # employee 2: raise in 2021
    (2, 2020, 60000), (2, 2021, 65000),
    # employee 3: first record is 2021, raise in 2022
    (3, 2021, 70000), (3, 2022, 75000),
]

# Column types are inferred from the Python tuples; `columns` supplies the names.
df = spark.createDataFrame(data, schema=columns)
df.show()


+------+----+------+
|emp_id|year|salary|
+------+----+------+
|     1|2020| 50000|
|     1|2021| 55000|
|     1|2022| 53000|
|     2|2020| 60000|
|     2|2021| 65000|
|     3|2021| 70000|
|     3|2022| 75000|
+------+----+------+



In [0]:
# Per-employee window, ordered chronologically, so lag() sees the previous record
# of the same employee.
window = Window.partitionBy("emp_id").orderBy("year")

# Attach the previous record's salary and keep only the rows where salary increased.
# The first record of each employee has no predecessor (prev_salary is null), so the
# comparison is null and the row is dropped by the filter.
(
    df
    .withColumn("prev_salary", F.lag("salary", 1).over(window))
    .where(F.col("prev_salary") < F.col("salary"))
    .show()
)


+------+----+------+-----------+
|emp_id|year|salary|prev_salary|
+------+----+------+-----------+
|     1|2021| 55000|      50000|
|     2|2021| 65000|      60000|
|     3|2022| 75000|      70000|
+------+----+------+-----------+



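Time: O(N log N) => dominated by sorting the rows by year within each emp_id window partition; the lag and filter pass over the rows is linear
Space: O(N) => needs memory to hold the window state (the buffered rows of each partition)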