Scenario: StreamPulse Revenue Pipeline (Slow!)
The StreamPulse revenue pipeline has been running for 45+ minutes in production. Management wants it under 10 minutes. The pipeline:

1. Reads listening events, subscription data, and ad revenue
2. Joins them into a unified revenue DataFrame
3. Produces 6 reports: genre revenue, regional breakdown, subscription analysis, ad performance, artist payouts, and a daily summary

The current code has multiple performance anti-patterns. Your job: fix them.



In [1]:
# Part 1: Set Up the "Slow" Environment

# Intentionally misconfigured SparkSession:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import random
import time

# ANTI-PATTERN: broadcast joins disabled, AQE disabled, too many shuffle partitions for local mode
spark = SparkSession.builder \
    .appName("StreamPulse-Revenue-SLOW") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.sql.adaptive.enabled", "false") \
    .config("spark.sql.autoBroadcastJoinThreshold", "-1") \
    .getOrCreate()

print("✅ SparkSession created (intentionally misconfigured)")

✅ SparkSession created (intentionally misconfigured)


In [3]:
# Generate the revenue dataset:

import builtins
import random

N = 600000

# Events (large)
event_data = []
for i in range(N):
    event_data.append((
        f"EVT-{i+1:07d}",
        f"USR-{random.randint(1, 100000):06d}",
        f"ART-{random.randint(1, 5000):05d}",
        random.choice(["Pop", "Rock", "Hip-Hop", "Jazz", "Electronic", "R&B", "Country", "Classical"]),
        random.choice(["North America", "Europe", "Asia Pacific", "Latin America", "Africa"]),
        random.randint(15, 350),
        random.choice([True, False]),
        random.choice(["mobile", "desktop", "smart_speaker", "tablet", "car", "tv"]),
        f"2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}",
    ))

events = spark.createDataFrame(event_data,
    ["event_id", "user_id", "artist_id", "genre", "region",
     "duration_sec", "completed", "device", "event_date"]) \
    .withColumn("event_date", col("event_date").cast("date")) \
    .withColumn("month", month(col("event_date")))

events.write.parquet("revenue_data/events", mode="overwrite")

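# Subscriptions (medium - 100K users with subscription info)
# Price is derived from the plan so that "free" users never carry a paid price.
# ASSUMPTION: the plan-to-price mapping follows the order of the original two lists.
PLAN_PRICES = {"free": 0.0, "individual": 9.99, "family": 14.99, "student": 4.99}
sub_data = []
for i in range(100000):
    plan = random.choice(list(PLAN_PRICES))
    sub_data.append((f"USR-{i+1:06d}", plan, PLAN_PRICES[plan],
                     random.choice(["US", "UK", "DE", "JP", "BR", "IN", "KR", "FR"])))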
subscriptions = spark.createDataFrame(sub_data, ["user_id", "plan", "monthly_price", "country"])
subscriptions.write.parquet("revenue_data/subscriptions", mode="overwrite")

# Ad rates (tiny - 8 genres x 6 devices = 48 rows)
ad_data = []
for genre in ["Pop", "Rock", "Hip-Hop", "Jazz", "Electronic", "R&B", "Country", "Classical"]:
    for device in ["mobile", "desktop", "smart_speaker", "tablet", "car", "tv"]:
        cpm = builtins.round(random.uniform(1.5, 8.0), 2)
        ad_data.append((genre, device, cpm))
ad_rates = spark.createDataFrame(ad_data, ["ad_genre", "ad_device", "cpm"])
ad_rates.write.parquet("revenue_data/ad_rates", mode="overwrite")

# Artist payout rates (small - 5000 artists)
payout_data = [(f"ART-{i+1:05d}", builtins.round(random.uniform(0.003, 0.008), 4),
                random.choice(["major", "indie", "unsigned"]))
               for i in range(5000)]
payouts = spark.createDataFrame(payout_data, ["artist_id", "per_stream_rate", "label_type"])
payouts.write.parquet("revenue_data/payouts", mode="overwrite")

# Reload from disk
events = spark.read.parquet("revenue_data/events")
subscriptions = spark.read.parquet("revenue_data/subscriptions")
ad_rates = spark.read.parquet("revenue_data/ad_rates")
payouts = spark.read.parquet("revenue_data/payouts")

print(f"Events: {events.count()} | Subs: {subscriptions.count()} | "
      f"Ad rates: {ad_rates.count()} | Payouts: {payouts.count()}")

Events: 600000 | Subs: 100000 | Ad rates: 48 | Payouts: 5000


In [5]:
# Part 2: The Unoptimized Pipeline
from pyspark.sql.functions import countDistinct

print("=" * 60)
print("RUNNING UNOPTIMIZED PIPELINE (BASELINE)")
print("=" * 60)

total_start = time.time()

# Build enriched revenue DataFrame (NOT cached, recomputed every time)
def build_revenue():
    return events \
        .join(subscriptions, "user_id") \
        .join(ad_rates,
              (events.genre == ad_rates.ad_genre) & (events.device == ad_rates.ad_device)) \
        .join(payouts, "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))

# Report 1: Genre Revenue
revenue = build_revenue()
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start

# Report 2: Regional Breakdown
revenue = build_revenue()
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start

# Report 3: Subscription Analysis
revenue = build_revenue()
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start

# Report 4: Ad Performance
revenue = build_revenue()
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start

# Report 5: Artist Payouts
revenue = build_revenue()
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start

# Report 6: Daily Summary
revenue = build_revenue()
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start

baseline_total = time.time() - total_start

print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n⏱️  BASELINE TOTAL: {baseline_total:.2f}s")

RUNNING UNOPTIMIZED PIPELINE (BASELINE)

Report 1 (genre):        54.86s
Report 2 (regional):     25.24s
Report 3 (subscription): 32.94s
Report 4 (ad perf):      21.38s
Report 5 (payouts):      15.00s
Report 6 (daily):        39.85s

⏱️  BASELINE TOTAL: 190.12s
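
The baseline cell repeats the same `time.time()` start/stop pattern for each of the six reports. One way to cut that repetition is a small context manager. The sketch below is plain Python; the names `timed` and `timings` are hypothetical and not part of the notebook. It uses `time.perf_counter()`, a monotonic clock intended for measuring intervals.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, timings):
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Store the elapsed seconds even if the block raises.
        timings[label] = time.perf_counter() - start


timings = {}
with timed("demo", timings):
    # Stand-in for a Spark action such as revenue.groupBy(...).collect()
    squares = [n * n for n in range(100_000)]

for label, seconds in timings.items():
    print(f"{label}: {seconds:.2f}s")
```

In the notebook, each report would wrap its `.collect()` call in `with timed("Report 1 (genre)", timings):`, and the summary loop would replace the six separate `print` lines.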


In [6]:
print("\nBASELINE PLAN:")
build_revenue().groupBy("genre").agg(sum("ad_revenue")).explain(mode="formatted")


BASELINE PLAN:
== Physical Plan ==
* HashAggregate (33)
+- Exchange (32)
   +- * HashAggregate (31)
      +- * Project (30)
         +- * SortMergeJoin Inner (29)
            :- * Sort (23)
            :  +- Exchange (22)
            :     +- * Project (21)
            :        +- * SortMergeJoin Inner (20)
            :           :- * Sort (14)
            :           :  +- Exchange (13)
            :           :     +- * Project (12)
            :           :        +- * SortMergeJoin Inner (11)
            :           :           :- * Sort (5)
            :           :           :  +- Exchange (4)
            :           :           :     +- * Filter (3)
            :           :           :        +- * ColumnarToRow (2)
            :           :           :           +- Scan parquet  (1)
            :           :           +- * Sort (10)
            :           :              +- Exchange (9)
            :           :                 +- * Filter (8)
            :           :       
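
Every join in the baseline plan above is a `SortMergeJoin`. Instead of reading the whole tree, one quick check is to count join operator names in the plan text. The sketch below assumes the `explain()` output has been pasted into a string; `count_join_strategies` is a hypothetical helper, and the excerpt is abbreviated from the plan printed above.

```python
import re
from collections import Counter

# Physical join operator names as they appear in Spark plan output.
JOIN_OPERATORS = (
    "SortMergeJoin",
    "BroadcastHashJoin",
    "ShuffledHashJoin",
    "BroadcastNestedLoopJoin",
)


def count_join_strategies(plan_text):
    """Count how often each join operator name appears in a plan string."""
    pattern = re.compile(r"\b(" + "|".join(JOIN_OPERATORS) + r")\b")
    return Counter(pattern.findall(plan_text))


# Abbreviated copy of the join lines from the baseline plan.
baseline_excerpt = """
+- * SortMergeJoin Inner (29)
   :- * Sort (23)
   :  +- * SortMergeJoin Inner (20)
   :     +- * SortMergeJoin Inner (11)
"""

counts = count_join_strategies(baseline_excerpt)
print(counts)
```

For this excerpt the counter reports three `SortMergeJoin` operators and no `BroadcastHashJoin`, which is the symptom of broadcast joins being disabled.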

Anti-Pattern Documentation Table

| # | Anti-Pattern                                           | Impact                                                          |
| - | ------------------------------------------------------ | --------------------------------------------------------------- |
| 1 | Broadcast disabled (`autoBroadcastJoinThreshold = -1`) | Forces SortMergeJoin even for tiny tables → unnecessary shuffle |
| 2 | `build_revenue()` called 6 times                       | Entire join pipeline recomputed 6 times                         |
| 3 | No caching of enriched DataFrame                       | Full lineage re-executed for every report                       |
| 4 | Too many shuffle partitions (200 in local mode)        | Excessive task overhead and scheduling cost                     |
| 5 | No column pruning before joins                         | Unnecessary data read and shuffled across cluster               |



In [7]:
# OPTIMIZATION 1: Enable Broadcast Joins
# What we fix:
# Small tables (ad_rates, payouts) should use BroadcastHashJoin instead of SortMergeJoin.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")  # 10MB

from pyspark.sql.functions import broadcast

def build_revenue_opt1():
    return events \
        .join(subscriptions, "user_id") \
        .join(
            broadcast(ad_rates),
            (events.genre == ad_rates.ad_genre) &
            (events.device == ad_rates.ad_device)
        ) \
        .join(
            broadcast(payouts),
            "artist_id"
        ) \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))

start = time.time()
build_revenue_opt1().groupBy("genre").agg(sum("ad_revenue")).collect()
print("Optimization 1 time:", time.time() - start)

Optimization 1 time: 5.916804552078247
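
A broadcast join avoids shuffling the large table: the small table is copied to every worker, indexed by the join key, and each large-table row is matched with a lookup. The plain-Python sketch below illustrates only that build-and-probe idea; it is not Spark's implementation, and `broadcast_hash_join` and the sample rows are hypothetical.

```python
from collections import defaultdict


def broadcast_hash_join(large_rows, small_rows, key):
    """Inner-join two lists of dicts on `key`, hashing the small side once."""
    # Build phase: index the small table by join key (the "broadcast" side).
    index = defaultdict(list)
    for row in small_rows:
        index[row[key]].append(row)

    # Probe phase: stream the large table and look up matches.
    # The large side is never sorted or repartitioned.
    joined = []
    for row in large_rows:
        for match in index.get(row[key], []):
            joined.append({**row, **match})
    return joined


events_sample = [
    {"event_id": "EVT-0000001", "artist_id": "ART-00001"},
    {"event_id": "EVT-0000002", "artist_id": "ART-00002"},
    {"event_id": "EVT-0000003", "artist_id": "ART-00001"},
]
payouts_sample = [
    {"artist_id": "ART-00001", "per_stream_rate": 0.004},
    {"artist_id": "ART-00002", "per_stream_rate": 0.006},
]

joined_sample = broadcast_hash_join(events_sample, payouts_sample, "artist_id")
for row in joined_sample:
    print(row)
```

A sort-merge join, by contrast, has to repartition and sort both inputs on the join key first, which is where the `Exchange` and `Sort` nodes in the baseline plan come from.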


In [8]:
# OPTIMIZATION 2: Cache Enriched DataFrame
# What we fix:
# The baseline rebuilt the full join 6 times, once per report.

# Code
revenue_cached = build_revenue_opt1().cache()

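# Materialize the cache (this build cost is not included in the timing below)
revenue_cached.count()

# Time six simplified aggregations, one per report, served from the cache
start = time.time()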

revenue_cached.groupBy("genre").agg(sum("ad_revenue")).collect()
revenue_cached.groupBy("region", "country").agg(count("*")).collect()
revenue_cached.groupBy("plan").agg(count("*")).collect()
revenue_cached.groupBy("device", "genre").agg(count("*")).collect()
revenue_cached.groupBy("artist_id").agg(sum("stream_payout")).collect()
revenue_cached.groupBy("event_date").agg(count("*")).collect()

print("Optimization 2 time:", time.time() - start)

revenue_cached.unpersist()

Optimization 2 time: 14.344804763793945


DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, completed: boolean, device: string, event_date: date, month: int, plan: string, monthly_price: double, country: string, ad_genre: string, ad_device: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double, is_premium: boolean]

In [9]:
# OPTIMIZATION 3: Reduce Shuffle Partitions
# What we fix:
# 200 shuffle partitions is too many for local mode.

# Config
spark.conf.set("spark.sql.shuffle.partitions", "8")
# Re-run
revenue_opt3 = build_revenue_opt1().cache()
revenue_opt3.count()

start = time.time()
revenue_opt3.groupBy("genre").agg(sum("ad_revenue")).collect()
print("Optimization 3 time:", time.time() - start)

revenue_opt3.unpersist()

Optimization 3 time: 0.5991461277008057


DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, completed: boolean, device: string, event_date: date, month: int, plan: string, monthly_price: double, country: string, ad_genre: string, ad_device: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double, is_premium: boolean]
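
The value 8 is hard-coded in the cell above, and the final report describes it as matched to local cores. If that is the intent, the number could be derived from the machine instead. The sketch below is a rule-of-thumb helper under that assumption, not a Spark recommendation; `suggest_shuffle_partitions` and `partitions_per_core` are hypothetical names.

```python
import os


def suggest_shuffle_partitions(cores=None, partitions_per_core=1):
    """Return a shuffle-partition count proportional to the available cores."""
    if cores is None:
        # os.cpu_count() can return None on some platforms; fall back to 1.
        cores = os.cpu_count() or 1
    return max(1, cores * partitions_per_core)


n_partitions = suggest_shuffle_partitions()
print(f"Suggested spark.sql.shuffle.partitions: {n_partitions}")

# In the notebook this could be applied with:
# spark.conf.set("spark.sql.shuffle.partitions", str(n_partitions))
```

On a real cluster the right value also depends on data volume per partition, so treat the result as a starting point to measure against.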

In [10]:
# OPTIMIZATION 4: Column Pruning
# What we fix:
# The baseline read columns that no report needs.

# Code
events_small = events.select(
    "event_id", "user_id", "artist_id",
    "genre", "region", "duration_sec",
    "device", "event_date"
)

subscriptions_small = subscriptions.select(
    "user_id", "plan", "country"
)

ad_rates_small = ad_rates.select(
    "ad_genre", "ad_device", "cpm"
)

payouts_small = payouts.select(
    "artist_id", "per_stream_rate", "label_type"
)

revenue_opt4 = events_small \
    .join(subscriptions_small, "user_id") \
    .join(
        broadcast(ad_rates_small),
        (col("genre") == col("ad_genre")) &
        (col("device") == col("ad_device"))
    ) \
    .join(
        broadcast(payouts_small),
        "artist_id"
    ) \
    .withColumn("ad_revenue", col("cpm") / 1000) \
    .withColumn("stream_payout", col("per_stream_rate")) \
    .drop("ad_genre", "ad_device") \
    .cache()

revenue_opt4.count()

start = time.time()
revenue_opt4.groupBy("genre").agg(sum("ad_revenue")).collect()
print("Optimization 4 time:", time.time() - start)

revenue_opt4.unpersist()

Optimization 4 time: 0.20737957954406738


DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, device: string, event_date: date, plan: string, country: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double]

In [11]:
# OPTIMIZATION 5: Filter Early
# What we fix:
# If reports only need certain months, filter BEFORE the joins.

# Code example (Q1 only)
events_filtered = events.filter(col("month").isin([1, 2, 3]))

revenue_opt5 = events_filtered \
    .join(subscriptions_small, "user_id") \
    .join(
        broadcast(ad_rates_small),
        (col("genre") == col("ad_genre")) &
        (col("device") == col("ad_device"))
    ) \
    .join(
        broadcast(payouts_small),
        "artist_id"
    ) \
    .withColumn("ad_revenue", col("cpm") / 1000) \
    .withColumn("stream_payout", col("per_stream_rate")) \
    .cache()

revenue_opt5.count()

start = time.time()
revenue_opt5.groupBy("genre").agg(sum("ad_revenue")).collect()
print("Optimization 5 time:", time.time() - start)

revenue_opt5.unpersist()

Optimization 5 time: 0.15427231788635254


DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, completed: boolean, device: string, event_date: date, month: int, plan: string, country: string, ad_genre: string, ad_device: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double]

In [12]:
print("=" * 60)
print("RUNNING FULLY OPTIMIZED PIPELINE")
print("=" * 60)

# Apply the optimized config (10MB broadcast threshold, 8 shuffle partitions)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")
spark.conf.set("spark.sql.shuffle.partitions", "8")

total_start = time.time()

# Build enriched DataFrame ONCE with all optimizations
revenue_opt = events \
    .select("event_id", "user_id", "artist_id", "genre", "region",
            "duration_sec", "completed", "device", "event_date", "month") \
    .join(subscriptions.select("user_id", "plan", "country"), "user_id") \
    .join(broadcast(ad_rates),
          (col("genre") == col("ad_genre")) & (col("device") == col("ad_device"))) \
    .join(broadcast(payouts), "artist_id") \
    .withColumn("ad_revenue", col("cpm") / 1000) \
    .withColumn("stream_payout", col("per_stream_rate")) \
    .drop("ad_genre", "ad_device")

# Cache the shared DataFrame
revenue_opt.cache()
cache_start = time.time()
row_count = revenue_opt.count()
cache_time = time.time() - cache_start
print(f"✅ Cached {row_count} rows in {cache_time:.2f}s")

# Run all 6 reports from cache
r1 = revenue_opt.groupBy("genre").agg(sum("ad_revenue"), countDistinct("user_id")).collect()
r2 = revenue_opt.groupBy("region", "country").agg(count("*"), sum("ad_revenue")).collect()
r3 = revenue_opt.groupBy("plan").agg(countDistinct("user_id"), count("*"), avg("duration_sec")).collect()
r4 = revenue_opt.groupBy("device", "genre").agg(sum("ad_revenue"), count("*")).collect()
r5 = revenue_opt.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout"), count("*")) \
    .orderBy(desc("sum(stream_payout)")).limit(100).collect()
r6 = revenue_opt.groupBy("event_date") \
    .agg(count("*"), sum("ad_revenue"), countDistinct("user_id")) \
    .orderBy("event_date").collect()

optimized_total = time.time() - total_start

print(f"\n⏱️  OPTIMIZED TOTAL: {optimized_total:.2f}s")
print(f"⏱️  BASELINE TOTAL:  {baseline_total:.2f}s")
print(f"📈 SPEEDUP:          {baseline_total/optimized_total:.1f}x")
print(f"📉 TIME SAVED:       {baseline_total - optimized_total:.2f}s ({(1-optimized_total/baseline_total)*100:.0f}%)")

# Verify the plan
print("\nOPTIMIZED PLAN:")
revenue_opt.groupBy("genre").agg(sum("ad_revenue")).explain(mode="formatted")

revenue_opt.unpersist()


RUNNING FULLY OPTIMIZED PIPELINE
✅ Cached 600000 rows in 7.88s

⏱️  OPTIMIZED TOTAL: 19.98s
⏱️  BASELINE TOTAL:  190.12s
📈 SPEEDUP:          9.5x
📉 TIME SAVED:       170.15s (89%)

OPTIMIZED PLAN:
== Physical Plan ==
* HashAggregate (26)
+- Exchange (25)
   +- * HashAggregate (24)
      +- InMemoryTableScan (1)
            +- InMemoryRelation (2)
                  +- * Project (23)
                     +- * BroadcastHashJoin Inner BuildRight (22)
                        :- * Project (17)
                        :  +- * BroadcastHashJoin Inner BuildRight (16)
                        :     :- * Project (11)
                        :     :  +- * BroadcastHashJoin Inner BuildRight (10)
                        :     :     :- * Filter (5)
                        :     :     :  +- * ColumnarToRow (4)
                        :     :     :     +- Scan parquet  (3)
                        :     :     +- BroadcastExchange (9)
                        :     :        +- * Filter (8)


DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, completed: boolean, device: string, event_date: date, month: int, plan: string, country: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double]
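
The summary figures above follow directly from the two totals. The sketch below redoes that arithmetic with the rounded timings printed in the output, then projects the result onto the 45-minute production run from the scenario. The projection is an assumption that this notebook does not test: it supposes production speeds up by the same ratio as this local run.

```python
# Timings copied from the notebook output above (rounded to 2 decimals).
baseline_s = 190.12
optimized_s = 19.98

speedup = baseline_s / optimized_s
saved_s = baseline_s - optimized_s
saved_pct = (1 - optimized_s / baseline_s) * 100

print(f"Speedup:    {speedup:.1f}x")
print(f"Time saved: {saved_s:.2f}s ({saved_pct:.0f}%)")

# ASSUMPTION: the production job speeds up by the same ratio as this local run.
production_min = 45
projected_min = production_min / speedup
print(f"Projected production runtime: {projected_min:.1f} min (target: under 10 min)")
```

Working from the rounded totals gives 170.14s saved, while the notebook prints 170.15s, presumably because it subtracts the unrounded timings. Under the stated assumption the projected runtime is under the 10-minute target, but only a run on production data and hardware can confirm that.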

Execution Plan Improvements

- SortMergeJoin → BroadcastHashJoin
- 6 full recomputations → 1 cached computation
- 200 shuffle partitions → 8
- Reduced input schema size
- Reduced shuffle spill

Performance Principles Demonstrated

- Broadcast small dimension tables
- Cache reused transformations
- Reduce shuffle partitions to match cluster size
- Prune unused columns
- Filter early to reduce join size
- Avoid recomputation of expensive joins

Create a final optimization report:


print("=" * 65)
print("OPTIMIZATION REPORT: StreamPulse Revenue Pipeline")
print("=" * 65)

print(f"""
Pipeline: Revenue Analytics (6 reports from joined data)

CONFIGURATION CHANGES:
  spark.sql.autoBroadcastJoinThreshold: -1 → 10MB
  spark.sql.shuffle.partitions: 200 → 8
  spark.sql.adaptive.enabled: false (left unchanged for testing)

CODE CHANGES:
  1. broadcast() on ad_rates (48 rows) and payouts (5K rows)
  2. .cache() on enriched DataFrame (built once, used 6 times)
  3. Column pruning on all source tables
  4. Single build_revenue() call instead of 6 separate calls

RESULTS:
  Baseline:  {baseline_total:.2f}s
  Optimized: {optimized_total:.2f}s
  Speedup:   {baseline_total/optimized_total:.1f}x

PLAN IMPROVEMENTS:
  - SortMergeJoin → BroadcastHashJoin (ad_rates, payouts)
  - 6 full recomputations → 1 computation + 6 cache reads
  - 200 shuffle partitions → 8 (matched to local cores)
  - ReadSchema reduced (column pruning)
""")