### Skewed Data Simulation and Solution

In [35]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, concat_ws, sum, count, split, array, lit
import datetime

In [36]:
# 1. Create Sample Skewed Data Directly in Python Lists
num_rows = 1000000
skewed_customer_id = 'customerA'
customer_ids_skewed = [skewed_customer_id] * (num_rows // 2)
customer_ids_others = [f'customer{i}' for i in range(1, 50)] * (num_rows // 2)
customer_ids = customer_ids_skewed + customer_ids_others
order_ids = range(1, num_rows + 1)
created_ats = [datetime.datetime(2023, 10, 26) for _ in range(num_rows)]
total_prices = [i * 0.1 for i in range(num_rows)]

In [37]:
data = []
for i in range(num_rows):
    data.append((order_ids[i], customer_ids[i], order_ids[i], created_ats[i], total_prices[i]))

In [38]:
# 2. Initialize SparkSession
spark = SparkSession.builder.appName("SaltingGroupByExample") \
  .config("spark.executor.memory", "1g") \
  .config("spark.driver.memory", "1g") \
  .config("spark.sql.autoBroadcastJoinThreshold", "100mb") \
  .config("spark.sql.shuffle.partitions", "90") \
  .config("spark.sql.adaptive.enabled", "true") \
  .getOrCreate()

In [39]:
# 3. Create DataFrame directly from Python List (In-Memory DataFrame)
orders_df = spark.createDataFrame(
  data,
  schema=[
    "id",
    "customer_id",
    "order_id",
    "created_at",
    "total_price"
  ]
)

In [40]:
# 4. Demonstrate Data Skew (Optional)
print("--- Count of orders per customer (showing skew) ---")
orders_df.groupBy("customer_id") \
  .count() \
  .orderBy(col("count").desc()) \
  .show(n=60, truncate=False)

--- Count of orders per customer (showing skew) ---


25/02/25 18:39:39 WARN TaskSetManager: Stage 0 contains a task of very large size (5102 KiB). The maximum recommended task size is 1000 KiB.
[Stage 0:>                                                          (0 + 6) / 6]

+-----------+------+
|customer_id|count |
+-----------+------+
|customerA  |500000|
|customer2  |10205 |
|customer1  |10205 |
|customer4  |10205 |
|customer3  |10205 |
|customer35 |10204 |
|customer14 |10204 |
|customer41 |10204 |
|customer6  |10204 |
|customer19 |10204 |
|customer30 |10204 |
|customer12 |10204 |
|customer43 |10204 |
|customer7  |10204 |
|customer44 |10204 |
|customer32 |10204 |
|customer47 |10204 |
|customer49 |10204 |
|customer36 |10204 |
|customer20 |10204 |
|customer5  |10204 |
|customer13 |10204 |
|customer40 |10204 |
|customer15 |10204 |
|customer31 |10204 |
|customer8  |10204 |
|customer37 |10204 |
|customer9  |10204 |
|customer29 |10204 |
|customer10 |10204 |
|customer38 |10204 |
|customer23 |10204 |
|customer33 |10204 |
|customer22 |10204 |
|customer25 |10204 |
|customer34 |10204 |
|customer46 |10204 |
|customer42 |10204 |
|customer11 |10204 |
|customer28 |10204 |
|customer27 |10204 |
|customer45 |10204 |
|customer26 |10204 |
|customer48 |10204 |
|customer18 |

                                                                                

In [41]:
# 5. GroupBy and Sum without Salting (Potential Skew)
print("\n--- GroupBy and Sum WITHOUT Salting (Potential Skew) ---")
start_time_no_salt = datetime.datetime.now()

grouped_no_salt_df = orders_df.groupBy("customer_id") \
  .agg(sum("total_price")
    .alias("total_spending")
  ).orderBy(
  col("total_spending").desc()
  ) \
  .show(n=60, truncate=False)

end_time_no_salt = datetime.datetime.now()
duration_no_salt = (end_time_no_salt - start_time_no_salt).total_seconds()
print(f"GroupBy without salting took: {duration_no_salt:.2f} seconds")

25/02/25 18:39:42 WARN TaskSetManager: Stage 3 contains a task of very large size (5102 KiB). The maximum recommended task size is 1000 KiB.



--- GroupBy and Sum WITHOUT Salting (Potential Skew) ---


[Stage 3:>                                                          (0 + 6) / 6]

+-----------+-------------------+
|customer_id|total_spending     |
+-----------+-------------------+
|customerA  |1.2499975E10       |
|customer4  |7.653760205E8      |
|customer3  |7.65375E8          |
|customer2  |7.653739795E8      |
|customer1  |7.65372959E8       |
|customer49 |7.653219385999999E8|
|customer48 |7.653209182E8      |
|customer47 |7.653198978E8      |
|customer46 |7.653188774000001E8|
|customer45 |7.65317857E8       |
|customer44 |7.653168365999999E8|
|customer43 |7.653158162E8      |
|customer42 |7.653147958000001E8|
|customer41 |7.653137754E8      |
|customer40 |7.65312755E8       |
|customer39 |7.653117346E8      |
|customer38 |7.653107142E8      |
|customer37 |7.653096938000001E8|
|customer36 |7.653086734E8      |
|customer35 |7.65307653E8       |
|customer34 |7.653066326E8      |
|customer33 |7.653056122E8      |
|customer32 |7.653045918000001E8|
|customer31 |7.653035714E8      |
|customer30 |7.65302551E8       |
|customer29 |7.653015306E8      |
|customer28 |7

                                                                                

In [42]:
# 6. GroupBy and Sum WITH Salting
print("\n--- GroupBy and Sum WITH Salting ---")
salt_count = 10  # Number of salts to use

start_time_salted = datetime.datetime.now()

# a) Generate salt values as a Python list
salt_values = list(range(salt_count))


--- GroupBy and Sum WITH Salting ---


In [43]:
# b) Add Salt to customer_id using array and explode
salted_orders_df = orders_df.withColumn(
  "salt", explode(
    array(*[lit(s) for s in salt_values])
  )) \
  .withColumn(
    "salted_customer_id",
    concat_ws(
      "_",
      col("customer_id"),
      col("salt")
    )
  )

In [44]:
# b) GroupBy on salted_customer_id and Aggregate
salted_grouped_df = salted_orders_df.groupBy("salted_customer_id") \
  .agg(
    sum("total_price")
    .alias("salted_total_spending")
  )


In [None]:
# c) Re-aggregate to original customer_id (remove salt and group again)
desalted_grouped_df = salted_grouped_df.withColumn(
  "original_customer_id",
  split(
    col("salted_customer_id"),
    "_"
  ).getItem(0)
) \
  .groupBy("original_customer_id") \
  .agg(
    sum("salted_total_spending")
    .alias("total_spending_salted")
  )

In [46]:
desalted_grouped_df.orderBy(
  col("total_spending_salted")
  .desc()
).show(n=60, truncate=False)

end_time_salted = datetime.datetime.now()
duration_salted = (end_time_salted - start_time_salted).total_seconds()
print(f"GroupBy with salting took: {duration_salted:.2f} seconds")

25/02/25 18:39:44 WARN TaskSetManager: Stage 6 contains a task of very large size (5102 KiB). The maximum recommended task size is 1000 KiB.
[Stage 6:>                                                          (0 + 6) / 6]

+--------------------+---------------------+
|original_customer_id|total_spending_salted|
+--------------------+---------------------+
|customerA           |1.2499975E11         |
|customer4           |7.653760205E9        |
|customer3           |7.65375E9            |
|customer2           |7.653739795E9        |
|customer1           |7.65372959E9         |
|customer49          |7.653219386E9        |
|customer48          |7.653209181999999E9  |
|customer47          |7.653198978000001E9  |
|customer46          |7.653188773999998E9  |
|customer45          |7.65317857E9         |
|customer44          |7.653168366E9        |
|customer43          |7.653158161999999E9  |
|customer42          |7.653147958000001E9  |
|customer41          |7.653137753999998E9  |
|customer40          |7.65312755E9         |
|customer39          |7.653117346000002E9  |
|customer38          |7.653107141999999E9  |
|customer37          |7.653096938000001E9  |
|customer36          |7.653086733999998E9  |
|customer3

                                                                                

In [47]:
print("\n--- Performance Comparison ---")
print(f"GroupBy without Salting Duration: {duration_no_salt:.2f} seconds")
print(f"GroupBy with Salting Duration: {duration_salted:.2f} seconds")
print("\n--- Note: In a local[*] mode with small data, the difference might not be drastic.")
print("--- In a cluster with significant skew and large datasets, salting's benefits will be much more pronounced. ---")


--- Performance Comparison ---
GroupBy without Salting Duration: 1.16 seconds
GroupBy with Salting Duration: 1.70 seconds

--- Note: In a local[*] mode with small data, the difference might not be drastic.
--- In a cluster with significant skew and large datasets, salting's benefits will be much more pronounced. ---


In [48]:
spark.stop()