In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Reuse the active Spark session if one exists; otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

# Sample orders: one tuple per order, in the same field order as `columns`.
# Dates start out as ISO-formatted strings and are converted to dates below.
data = [
    (1, "2023-01-01", 100),
    (1, "2023-01-02", 150),
    (1, "2023-01-05", 200),
    (2, "2023-02-10", 300),
    (2, "2023-02-11", 350),
    (3, "2023-03-01", 400),
]

columns = ["customer_id", "order_date", "amount"]

# Build the DataFrame, then replace the string `order_date` column with a
# real date column so date functions (e.g. datediff) can be applied to it.
df = (
    spark.createDataFrame(data, schema=columns)
    .withColumn("order_date", F.to_date(F.col("order_date")))
)

df.show()


+-----------+----------+------+
|customer_id|order_date|amount|
+-----------+----------+------+
|          1|2023-01-01|   100|
|          1|2023-01-02|   150|
|          1|2023-01-05|   200|
|          2|2023-02-10|   300|
|          2|2023-02-11|   350|
|          3|2023-03-01|   400|
+-----------+----------+------+



In [0]:
# Find orders placed exactly one day after the same customer's previous order.

# Within each customer, walk the orders in chronological order.
window = Window.partitionBy("customer_id").orderBy("order_date")

# Date of the preceding order in that window; null for a customer's first
# order, which therefore never satisfies the equality filter below.
prev_purchase_date = F.lag("order_date").over(window)

# True when the current order is exactly one day after the previous one.
is_next_day_purchase = F.datediff("order_date", "prev_purchase_date") == 1

(
    df.withColumn("prev_purchase_date", prev_purchase_date)
    .filter(is_next_day_purchase)
    .show()
)

+-----------+----------+------+------------------+
|customer_id|order_date|amount|prev_purchase_date|
+-----------+----------+------+------------------+
|          1|2023-01-02|   150|        2023-01-01|
|          2|2023-02-11|   350|        2023-02-10|
+-----------+----------+------+------------------+

