## Window Functions in PySpark

Apache Spark window functions let you perform **analytics across related rows** *without collapsing rows* (unlike `groupBy`). They’re essential for rankings, running totals, moving averages, and session-style analysis.

---

### What a Window Function Does

A window function computes a value for **each row** based on:

* **Partition** → how rows are grouped
* **Order** → how rows are sorted within each group
* **Frame** → which rows around the current row are considered
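
To make the three pieces concrete, here is a minimal plain-Python sketch (no Spark required) of what a window computation does conceptually. The helper name `apply_window` and the `(start, end)` frame encoding are illustrative assumptions, not part of the PySpark API, and Spark's real execution is distributed rather than a simple loop.

```python
from collections import defaultdict

def apply_window(rows, partition_key, order_key, frame, func, value_key):
    """Conceptual model of a window function (hypothetical helper, not a Spark API).

    frame is (start, end) as row offsets relative to the current row;
    None means unbounded on that side.
    Returns one output dict per input row, so no rows are collapsed.
    """
    # Partition: group rows by the partition key
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_key]].append(row)

    out = []
    for part in partitions.values():
        # Order: sort rows inside each partition
        ordered = sorted(part, key=lambda r: r[order_key])
        for i, row in enumerate(ordered):
            # Frame: pick the rows around the current row
            start, end = frame
            lo = 0 if start is None else max(0, i + start)
            hi = len(ordered) - 1 if end is None else min(len(ordered) - 1, i + end)
            values = [r[value_key] for r in ordered[lo:hi + 1]]
            out.append({**row, "result": func(values)})
    return out

rows = [
    {"name": "Amit",  "department": "IT", "salary": 60000, "salary_date": "2024-01-01"},
    {"name": "Ravi",  "department": "IT", "salary": 65000, "salary_date": "2024-02-01"},
    {"name": "Sneha", "department": "IT", "salary": 65000, "salary_date": "2024-03-01"},
    {"name": "Kiran", "department": "HR", "salary": 45000, "salary_date": "2024-01-01"},
    {"name": "Pooja", "department": "HR", "salary": 48000, "salary_date": "2024-02-01"},
]

# Running total: frame from the start of the partition (None) to the current row (0)
running = apply_window(rows, "department", "salary_date", (None, 0), sum, "salary")
for r in running:
    print(r["department"], r["name"], r["result"])
```

Every input row appears in the output with its own result, which is the key difference from a `groupBy` aggregation.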




## 1) Define a Window Specification

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window  # required by the window specs below
from pyspark.sql.types import *
from pyspark.sql.functions import *

In [None]:
spark = SparkSession.builder.appName("WindowApp").getOrCreate()

In [None]:
data = [
    (101, "Amit",   "IT",     60000, "2024-01-01"),
    (102, "Ravi",   "IT",     65000, "2024-02-01"),
    (103, "Sneha",  "IT",     65000, "2024-03-01"),
    (104, "Kiran",  "HR",     45000, "2024-01-01"),
    (105, "Pooja",  "HR",     48000, "2024-02-01"),
    (106, "Neha",   "HR",     48000, "2024-03-01"),
    (107, "Arjun",  "Sales",  55000, "2024-01-01"),
    (108, "Manoj",  "Sales",  52000, "2024-02-01"),
    (109, "Divya",  "Sales",  58000, "2024-03-01"),
]

company_schema = StructType([
    StructField("emp_id", IntegerType()),
    StructField("name", StringType()),
    StructField("department", StringType()),
    StructField("salary", IntegerType()),
    StructField("salary_date", StringType()),
])

company_df = spark.createDataFrame(data, schema=company_schema)


In [None]:
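import sys

# Sanity check: print the interpreter used by the driver and by a Python worker.
# The two should match.
print("Driver Python:", sys.executable)

spark.sparkContext.parallelize([1]).map(
    lambda x: sys.executable
).collect()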


In [4]:
company_df.show()


Combine `partitionBy`, `orderBy`, and optionally a **frame**:

In [None]:
window_spec = (
    Window
    .partitionBy("department")
    .orderBy("salary_date")  # the date column in company_df
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

---

## 2) Common Window Functions

### Ranking Functions

In [None]:
from pyspark.sql.functions import row_number, rank, dense_rank

# withColumn returns a new DataFrame, so chain the calls and keep the result
ranked_df = (
    company_df
    .withColumn("row_num", row_number().over(window_spec))
    .withColumn("rank", rank().over(window_spec))
    .withColumn("dense_rank", dense_rank().over(window_spec))
)

**Differences**

* `row_number()` → unique sequence, even when rows tie on the order key
* `rank()` → tied rows share a rank, with gaps after ties
* `dense_rank()` → tied rows share a rank, with no gaps
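
The difference is easiest to see on data with a tie. The following is a plain-Python model of the three functions (the name `rank_functions` is made up for this illustration); it assumes a descending sort on a single value. In Spark, the order among tied rows for `row_number()` is arbitrary unless the `orderBy` includes a tie-breaker.

```python
def rank_functions(values):
    """Return (value, row_number, rank, dense_rank) for values sorted descending.

    Plain-Python model of the three ranking functions (illustrative only).
    """
    ordered = sorted(values, reverse=True)
    result = []
    rank = dense = 0
    prev = object()  # sentinel that never equals a real value
    for row_num, v in enumerate(ordered, start=1):
        if v != prev:
            rank = row_num   # rank jumps to the row position, leaving gaps after ties
            dense += 1       # dense_rank advances by exactly one per distinct value
            prev = v
        result.append((v, row_num, rank, dense))
    return result

# Salaries with a tie at the top
for row in rank_functions([60000, 65000, 65000, 58000]):
    print(row)
# (65000, 1, 1, 1)
# (65000, 2, 1, 1)
# (60000, 3, 3, 2)
# (58000, 4, 4, 3)
```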

![Image](https://miro.medium.com/1%2AtuGFvhwk5rUtoQWcX4A6ng.gif)

![Image](https://media.licdn.com/dms/image/v2/D4D12AQFRS0AU_T_NQQ/article-cover_image-shrink_720_1280/article-cover_image-shrink_720_1280/0/1654768699532?e=2147483647\&t=ONuIV9v8k99yhD5U52ocS8i3WOsKk_p_xlfjmGVVXEg\&v=beta)

---

### Aggregate Functions (Windowed)

In [None]:
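from pyspark.sql.functions import sum, avg, max  # note: shadows Python's built-in sum and max

# No ordering and no frame: the aggregate covers the whole department
dept_spec = Window.partitionBy("department")

company_df.withColumn("dept_avg_salary", avg("salary").over(dept_spec))
# Ordered spec with a frame ending at the current row: a running total
company_df.withColumn("running_total", sum("salary").over(window_spec))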

> Same aggregates as `groupBy`, but **row-level output is preserved**.

---

### Analytical / Value Functions

In [None]:
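from pyspark.sql.functions import lag, lead, first, last

# lag/lead take an ordered window without an explicit frame (Spark rejects a conflicting frame)
ordered_spec = Window.partitionBy("department").orderBy("salary_date")
# first/last over the whole partition need an unbounded frame;
# with a frame that ends at the current row, last() returns the current row's value
full_spec = ordered_spec.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

company_df.withColumn("prev_salary", lag("salary", 1).over(ordered_spec))
company_df.withColumn("next_salary", lead("salary", 1).over(ordered_spec))
company_df.withColumn("first_sal", first("salary").over(full_spec))
company_df.withColumn("last_sal", last("salary").over(full_spec))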

Pass `ignorenulls=True` to skip nulls when needed:

In [None]:
last("salary", ignorenulls=True).over(window_spec)
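
A common use of `last(..., ignorenulls=True)` is forward-filling missing values. The sketch below models that behavior in plain Python, assuming a frame that runs from the start of the partition to the current row; the function name `last_over_running_frame` is hypothetical and only mirrors the semantics.

```python
def last_over_running_frame(values, ignorenulls=False):
    """Model last(col, ignorenulls) over a frame from partition start to current row."""
    out = []
    last_non_null = None
    for v in values:  # values are assumed to be in window order already
        if v is not None:
            last_non_null = v
        # Without ignorenulls, the last row of the frame is the current row itself
        out.append(last_non_null if ignorenulls else v)
    return out

salaries = [60000, None, None, 65000, None]
print(last_over_running_frame(salaries))
# [60000, None, None, 65000, None]
print(last_over_running_frame(salaries, ignorenulls=True))
# [60000, 60000, 60000, 65000, 65000]
```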

---

## 3) Rows vs Range Frames (Very Important)

### Rows-based frame

In [None]:
Window.orderBy("date").rowsBetween(-2, 0)

* Counts **physical rows** relative to the current row
* Deterministic only when the ordering is strict (no ties in the order key)

### Range-based frame

In [None]:
Window.orderBy("date").rangeBetween(-7, 0)

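* Uses **value ranges**: offsets apply to the value of the order key, not to row positions
* Rows that share the same order key always fall into the same frame
* Offsets are in the units of the order key, so a string column such as `"2024-01-01"` will not work; use a numeric key such as a day number or epoch seconds

> For time series with duplicate timestamps, **rowsBetween** guarantees a fixed number of rows per frame; add a tie-breaker column to `orderBy` so the result stays deterministic.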
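
The two frame types diverge as soon as the order key has duplicates. The following plain-Python sketch (helper names `rows_sum` and `range_sum` are invented for illustration) computes a sum over a frame of "1 preceding to current" both ways, using an integer day number as the order key.

```python
def rows_sum(values, start, end):
    """Sum over a rows frame: offsets count physical rows around the current row."""
    n = len(values)
    out = []
    for i in range(n):
        lo, hi = max(0, i + start), min(n - 1, i + end)
        out.append(sum(values[lo:hi + 1]))
    return out

def range_sum(order_keys, values, start, end):
    """Sum over a range frame: offsets are applied to the value of the order key."""
    out = []
    for k in order_keys:
        out.append(sum(v for key, v in zip(order_keys, values)
                       if k + start <= key <= k + end))
    return out

days   = [1, 2, 2, 4]      # already sorted; day 2 appears twice
amount = [10, 20, 30, 40]

print(rows_sum(amount, -1, 0))          # [10, 30, 50, 70]
print(range_sum(days, amount, -1, 0))   # [10, 60, 60, 40]
```

With the rows frame, the two day-2 rows get different sums, and which one gets which depends on how the tie happens to be ordered. With the range frame, both day-2 rows see the same set of rows, and day 4 sees only itself because no row has day 3.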

---

## 4) Practical Data Engineering Examples

### Latest Record per Key (De-duplication)

In [None]:
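w = Window.partitionBy("user_id").orderBy(col("updated_at").desc())

# Window functions cannot be used directly inside filter(), so add the row number as a column first
df.withColumn("rn", row_number().over(w)) \
  .filter(col("rn") == 1) \
  .drop("rn")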

### Top-N per Group

In [None]:
w = Window.partitionBy("category").orderBy(col("revenue").desc())

df.withColumn("r", dense_rank().over(w)) \
  .filter("r <= 3")
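
Note that the choice of ranking function changes how many rows a top-N filter returns when values tie. The sketch below is a plain-Python model (the `top_n` helper is hypothetical) comparing a `row_number`-style filter with a `dense_rank`-style filter on the same revenues.

```python
def top_n(revenues, n, method):
    """Keep rows whose rank is <= n, ordered by revenue descending (illustrative)."""
    ordered = sorted(revenues, reverse=True)
    kept = []
    dense = 0
    prev = None
    for row_num, rev in enumerate(ordered, start=1):
        if rev != prev:
            dense += 1       # new distinct value: dense rank moves up by one
            prev = rev
        rank = row_num if method == "row_number" else dense
        if rank <= n:
            kept.append(rev)
    return kept

revenues = [500, 500, 400, 300, 300, 200]
print(top_n(revenues, 3, "row_number"))  # [500, 500, 400]
print(top_n(revenues, 3, "dense_rank"))  # [500, 500, 400, 300, 300]
```

Use a `row_number`-style rank when exactly N rows per group are required, and a `dense_rank`-style rank when all rows within the top N distinct values should be kept.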

### Running Metrics (Finance / Logs)

In [None]:
w = Window.partitionBy("account").orderBy("date") \
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("balance", sum("amount").over(w))

---

## 5) Performance Notes (Critical for Interviews)

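* Window functions **trigger a shuffle** (plus a sort within each partition) unless the data is already distributed by the `partitionBy` keys
* Avoid **skewed or very low-cardinality partition keys**: all rows of one window partition are processed by a single task, and a window with no `partitionBy` moves all data into one partition
* Avoid very wide frames on huge datasets
* Compute functions that share a window spec together, and cache the result if several downstream actions reuse it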
* Window ≠ groupBy: window keeps **N rows**, groupBy reduces to **K rows**

---

## Quick Mental Model

| Aspect         | groupBy | Window |
| -------------- | ------- | ------ |
| Rows preserved | ❌       | ✅      |
| Aggregation    | ✅       | ✅      |
| Ranking / Lag  | ❌       | ✅      |
| Shuffle        | Yes     | Yes    |
