In [0]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from pyspark.sql.functions import to_date
from pyspark.sql import functions as F
from pyspark.sql import Window

# Small in-memory sample of purchases: (customer_id, purchase_date as an ISO
# "yyyy-MM-dd" string, amount).
data = [
    (1, "2023-01-01", 100),
    (1, "2023-01-02", 150),
    (1, "2023-01-05", 200),
    (2, "2023-01-03", 300),
    (2, "2023-01-04", 400),
    (3, "2023-01-10", 500),
]

# Explicit schema; purchase_date is loaded as a string and parsed to a date below.
schema = (
    StructType()
    .add("customer_id", IntegerType(), True)
    .add("purchase_date", StringType(), True)
    .add("amount", IntegerType(), True)
)

# Build the frame and replace the string purchase_date with a proper date column.
df = spark.createDataFrame(data, schema).withColumn(
    "purchase_date", to_date(F.col("purchase_date"))
)
df.show()



+-----------+-------------+------+
|customer_id|purchase_date|amount|
+-----------+-------------+------+
|          1|   2023-01-01|   100|
|          1|   2023-01-02|   150|
|          1|   2023-01-05|   200|
|          2|   2023-01-03|   300|
|          2|   2023-01-04|   400|
|          3|   2023-01-10|   500|
+-----------+-------------+------+



In [0]:
# Customers who made purchases on two consecutive calendar days.
# Each purchase is compared with the same customer's previous purchase (ordered
# by date); a gap of exactly 1 day marks a consecutive-day pair.
window = Window.partitionBy(F.col("customer_id")).orderBy(F.col("purchase_date"))

# Kept on `df` (as before) so later cells can still use prev_purchase_dt.
df = df.withColumn("prev_purchase_dt", F.lag(F.col("purchase_date")).over(window))

# prev_purchase_dt is NULL for each customer's first purchase, hence the isNotNull guard.
# distinct(): a customer with several consecutive-day pairs must be listed only once.
# orderBy(): distinct() may shuffle rows, so sort for a deterministic display.
(
    df.filter(
        F.col("prev_purchase_dt").isNotNull()
        & (F.datediff(F.col("purchase_date"), F.col("prev_purchase_dt")) == 1)
    )
    .select("customer_id")
    .distinct()
    .orderBy("customer_id")
    .show()
)



+-----------+
|customer_id|
+-----------+
|          1|
|          2|
+-----------+

