In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Attach to the already-running Spark session if there is one, otherwise create it.
spark = SparkSession.builder.getOrCreate()

# Column names, in the same order as the fields of each tuple in `data`.
columns = ["order_id", "customer_id", "order_date", "amount"]

# Toy orders table: one tuple per order.
# order_date is a plain string here; none of the later cells use it.
data = [
    (1, "C1", "2024-01-01", 100),
    (2, "C1", "2024-01-05", 200),
    (3, "C2", "2024-01-03", 50),
    (4, "C2", "2024-01-10", 70),
    (5, "C3", "2024-01-02", 500),
    (6, "C3", "2024-01-08", 300),
    (7, "C4", "2024-01-04", 40),
]

# Only column names are supplied, so Spark infers each column's type from the Python values.
df = spark.createDataFrame(data, schema=columns)
df.show()


+--------+-----------+----------+------+
|order_id|customer_id|order_date|amount|
+--------+-----------+----------+------+
|       1|         C1|2024-01-01|   100|
|       2|         C1|2024-01-05|   200|
|       3|         C2|2024-01-03|    50|
|       4|         C2|2024-01-10|    70|
|       5|         C3|2024-01-02|   500|
|       6|         C3|2024-01-08|   300|
|       7|         C4|2024-01-04|    40|
+--------+-----------+----------+------+



In [0]:
# Total order amount per customer (one row per customer_id).
customer_total = (
    df
    .groupBy("customer_id")
    .agg(F.sum("amount").alias("total_amount"))
)

# Ungrouped aggregate, so a single row holding the mean of the per-customer totals.
# NOTE(review): `avg` is a generic name for a DataFrame; kept because the next cell refers to it.
avg = customer_total.agg(
    F.avg("total_amount").alias("avg_amount")
)

customer_total.show()

+-----------+------------+
|customer_id|total_amount|
+-----------+------------+
|         C1|         300|
|         C2|         120|
|         C3|         800|
|         C4|          40|
+-----------+------------+



In [0]:
# Customers whose total spend is strictly above the average per-customer total.
# `avg` has exactly one row, so the cross join just appends avg_amount to every customer row.
(
    customer_total
    .crossJoin(avg)
    .where(F.col("total_amount") > F.col("avg_amount"))
    .show()
)

+-----------+------------+----------+
|customer_id|total_amount|avg_amount|
+-----------+------------+----------+
|         C3|         800|     315.0|
+-----------+------------+----------+

