In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# -------------------------------------------
# 1. Create Spark Session
# -------------------------------------------
# getOrCreate() returns the already-active SparkSession if there is one
# (e.g. a session pre-configured by the notebook host) and only builds a
# new one otherwise.
spark = (
    SparkSession.builder
    .appName("Running Total Salary")
    .getOrCreate()
)

# -------------------------------------------
# 2. Create DataFrame
# -------------------------------------------
# Column names only; Spark infers each column's type from the Python values
# in the sample rows below.
columns = ["employee", "department", "salary"]

# Sample rows, one tuple per employee, in the same order as `columns`.
data = [
    ("John",  "Sales", 5000),
    ("Mary",  "Sales", 7000),
    ("Rick",  "Sales", 6500),
    ("Alex",  "HR",    6000),
    ("Alice", "HR",    6500),
]

df = spark.createDataFrame(data, schema=columns)

print("Input DataFrame:")
df.show()


Input DataFrame:
+--------+----------+------+
|employee|department|salary|
+--------+----------+------+
|    John|     Sales|  5000|
|    Mary|     Sales|  7000|
|    Rick|     Sales|  6500|
|    Alex|        HR|  6000|
|   Alice|        HR|  6500|
+--------+----------+------+



In [0]:
# Running (cumulative) salary per department, highest salary first.
#
# Frame: ROWS from the first row of the partition up to the current row, so
# each row's total covers exactly the rows sorted before it plus itself.
# (When orderBy is set, Spark's default frame is RANGE-based, which would give
# rows with equal salaries one shared, lumped-together total.)
#
# Tie-breaker: ordering by salary alone leaves the relative order of equal
# salaries undefined, so a ROWS frame could produce run-dependent totals for
# tied rows. Sorting by employee as a secondary key makes the result
# deterministic. NOTE(review): assumes (department, employee) is unique —
# TODO confirm for real data.
#
# Time:  O(N log N) => shuffle by department + sort within each partition
# Space: O(N)       => rows buffered per partition for the window
window = (
    Window.partitionBy("department")
    .orderBy(F.col("salary").desc(), F.col("employee").asc())
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df.withColumn("running_total_salary", F.sum(F.col("salary")).over(window)).show()

+--------+----------+------+--------------------+
|employee|department|salary|running_total_salary|
+--------+----------+------+--------------------+
|   Alice|        HR|  6500|                6500|
|    Alex|        HR|  6000|               12500|
|    Mary|     Sales|  7000|                7000|
|    Rick|     Sales|  6500|               13500|
|    John|     Sales|  5000|               18500|
+--------+----------+------+--------------------+

