In [1]:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from datetime import date
from pyspark.sql.functions import col, lead, datediff, when
from pyspark.sql.window import Window

# Attach to the active Spark session, or start one if none exists yet.
spark = SparkSession.builder.getOrCreate()

# Toy order history, one tuple per order:
#   (cust_id, order_id, order_date, order_status)
# order_status 'Y' marks a successful order, 'N' an unsuccessful one.
data = [
    # customer 1: three successful orders, gaps of 24 and then 35 days
    (1, 101, date(2023, 1, 1), 'Y'),
    (1, 102, date(2023, 1, 25), 'Y'),
    (1, 103, date(2023, 3, 1), 'Y'),
    # customer 2: unsuccessful order 202 sits between two successful
    # orders that are 46 days apart
    (2, 201, date(2023, 1, 10), 'Y'),
    (2, 202, date(2023, 2, 15), 'N'),
    (2, 203, date(2023, 2, 25), 'Y'),
    # customer 3: two successful orders 24 days apart
    (3, 301, date(2023, 2, 1), 'Y'),
    (3, 302, date(2023, 2, 25), 'Y'),
    # customer 4: no successful orders at all
    (4, 401, date(2023, 1, 1), 'N'),
    (4, 402, date(2023, 2, 1), 'N'),
]

# Explicit schema; every column is declared nullable.
schema = StructType([
    StructField(field_name, field_type, nullable=True)
    for field_name, field_type in [
        ("cust_id", IntegerType()),
        ("order_id", IntegerType()),
        ("order_date", DateType()),
        ("order_status", StringType()),
    ]
])

orders_df = spark.createDataFrame(data, schema=schema)

# Per-customer window, ordered chronologically by order date.
window_spec = Window.partitionBy(col('cust_id')).orderBy(col('order_date'))

# Keep only the successful orders (order_status == 'Y').
filtered_orders_df = orders_df.where(col('order_status') == 'Y')

# For each successful order, attach the date of the same customer's next
# successful order; a customer's last order has no successor and gets NULL.
rn_orders_df = filtered_orders_df.select(
    '*',
    lead('order_date', 1).over(window_spec).alias('next_odr_date'),
)
rn_orders_df.show()


+-------+--------+----------+------------+-------------+
|cust_id|order_id|order_date|order_status|next_odr_date|
+-------+--------+----------+------------+-------------+
|      1|     101|2023-01-01|           Y|   2023-01-25|
|      1|     102|2023-01-25|           Y|   2023-03-01|
|      1|     103|2023-03-01|           Y|         NULL|
|      2|     201|2023-01-10|           Y|   2023-02-25|
|      2|     203|2023-02-25|           Y|         NULL|
|      3|     301|2023-02-01|           Y|   2023-02-25|
|      3|     302|2023-02-25|           Y|         NULL|
+-------+--------+----------+------------+-------------+



In [2]:
# Retention: number of days between a successful order and the same
# customer's next successful order.

# Largest gap (inclusive, in days) for which the customer counts as retained.
RETENTION_WINDOW_DAYS = 30

retention_check_df = rn_orders_df.withColumn(
    'retention', datediff(col('next_odr_date'), col('order_date'))
)
retention_check_df.show()

# Customers with at least one follow-up order inside the retention window.
# Rows whose retention is NULL (no later order) do not satisfy the
# comparison, so they are dropped by the filter.
result = (
    retention_check_df
    .filter(col('retention') <= RETENTION_WINDOW_DAYS)
    .select('cust_id')
    .distinct()
)
result.show()

+-------+--------+----------+------------+-------------+---------+
|cust_id|order_id|order_date|order_status|next_odr_date|retention|
+-------+--------+----------+------------+-------------+---------+
|      1|     101|2023-01-01|           Y|   2023-01-25|       24|
|      1|     102|2023-01-25|           Y|   2023-03-01|       35|
|      1|     103|2023-03-01|           Y|         NULL|     NULL|
|      2|     201|2023-01-10|           Y|   2023-02-25|       46|
|      2|     203|2023-02-25|           Y|         NULL|     NULL|
|      3|     301|2023-02-01|           Y|   2023-02-25|       24|
|      3|     302|2023-02-25|           Y|         NULL|     NULL|
+-------+--------+----------+------------+-------------+---------+



In [4]:
# Stop the Spark session obtained in the first cell; nothing after this
# point in the notebook uses it.
spark.stop()