## Reddit Practice Problem: Surging Subreddits

### The Scenario

You're an Analytics Engineering Lead at Reddit. The "Discovery" team wants to identify subreddits that are experiencing a sudden surge in user engagement. They hypothesize that a rapid increase in the number of comments per post is a key indicator of a subreddit "going viral" or becoming a hot topic.

Your task is to build a data pipeline that identifies the **top 10 subreddits with the highest week-over-week percentage growth in their average comments per post**.

### Input DataFrames

You are given two primary sources of data:

1.  **`posts` DataFrame**: Contains information about each post submitted.

      * `post_id` (string)
      * `subreddit_id` (string)
      * `created_utc` (timestamp)

2.  **`comments` DataFrame**: A log of all comments made on posts.

      * `comment_id` (string)
      * `post_id` (string)
      * `created_utc` (timestamp)

### The Task

Write a PySpark script that performs the following steps:

1.  **Calculate Weekly Comment Counts:** For each post, determine how many comments it received, counting posts with no comments as zero. Then aggregate this data to find the total number of comments and the total number of posts for each subreddit, for each week.
      * A "week" can be defined using the `weekofyear` function.
2.  **Calculate Average Comments Per Post:** Using the weekly aggregated data from Step 1, calculate the average number of comments per post for each subreddit for each week.
3.  **Find Previous Week's Average:** For each subreddit and each week, you need to find the average comments per post from the *previous* week. This is the key step and will require a **window function** partitioned by subreddit and ordered by week.
4.  **Calculate Week-over-Week Growth:** Calculate the percentage growth from the previous week's average to the current week's average. The formula is: `((current_avg - previous_avg) / previous_avg) * 100`.
      * Handle cases where the previous week's average was zero to avoid division errors.
5.  **Filter and Rank:** Filter for the most recent complete week in the dataset. Then, rank the subreddits by their week-over-week growth in descending order and return the top 10.
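
The five steps above can be checked without Spark. The sketch below is a Spark-free reference in plain Python, useful for confirming expected values by hand. It reuses the dummy data from the PySpark solution and assumes in-memory lists and string timestamps. `top_surging` is a hypothetical helper name. `datetime.isocalendar()` stands in for `weekofyear`, because both use ISO 8601 week numbering.

```python
from collections import defaultdict
from datetime import datetime

# Dummy data copied from the PySpark solution: (post_id, subreddit_id, created_utc)
posts = [("p1", "sub1", "2025-08-20 10:00:00"), ("p2", "sub1", "2025-08-28 11:00:00"),
         ("p3", "sub2", "2025-08-21 12:00:00"), ("p4", "sub2", "2025-08-29 13:00:00"),
         ("p5", "sub3", "2025-08-22 14:00:00"), ("p6", "sub3", "2025-08-30 15:00:00")]
# (comment_id, post_id)
comments = [("c1", "p1"), ("c2", "p1"), ("c3", "p2"), ("c4", "p3"), ("c5", "p4"),
            ("c6", "p4"), ("c7", "p4"), ("c8", "p5"), ("c9", "p6"), ("c10", "p6")]


def top_surging(posts, comments, top_n=10):
    """Hypothetical plain-Python reference for steps 1-5."""
    # Step 1: comments per post; posts without comments default to 0
    comment_count = defaultdict(int)
    for _, post_id in comments:
        comment_count[post_id] += 1

    # Step 1 (cont.): posts and comments per (subreddit, ISO week number)
    num_posts, num_comments = defaultdict(int), defaultdict(int)
    for post_id, sub, ts in posts:
        week = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S").isocalendar()[1]
        num_posts[(sub, week)] += 1
        num_comments[(sub, week)] += comment_count[post_id]

    # Step 2: average comments per post, collected per subreddit as (week, avg) pairs
    by_sub = defaultdict(list)
    for (sub, week), n in num_posts.items():
        by_sub[sub].append((week, num_comments[(sub, week)] / n))

    # Steps 3-4: pair each week with the previous row (like lag) and compute growth
    rows = []
    for sub, weeks in by_sub.items():
        weeks.sort()
        for (week, cur), (_, prev) in zip(weeks[1:], weeks[:-1]):
            if prev != 0:  # skip a zero baseline instead of dividing by zero
                rows.append((sub, week, cur, prev, (cur - prev) / prev * 100))

    # Step 5: keep the latest week, sort by growth, and assign rank()-style ranks
    latest = max(week for _, week in num_posts)
    current = sorted((r for r in rows if r[1] == latest), key=lambda r: r[4], reverse=True)
    ranked = []
    for i, r in enumerate(current):
        # Tied growth values share a rank and the next rank skips ahead, as with rank()
        rnk = ranked[-1][5] if ranked and ranked[-1][4] == r[4] else i + 1
        ranked.append(r + (rnk,))
    return [r for r in ranked if r[5] <= top_n]


print(top_surging(posts, comments))
# [('sub2', 35, 3.0, 1.0, 200.0, 1), ('sub3', 35, 2.0, 1.0, 100.0, 2), ('sub1', 35, 1.0, 2.0, -50.0, 3)]
```

The pairing uses the previous row, not the previous calendar week. A subreddit that skipped a week is therefore compared with its last active week. This matches what `lag` does in the Spark version.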

### Expected Final Output

A DataFrame with the following schema, showing the top 10 surging subreddits for the most recent week:

  * `subreddit_id`
  * `week`
  * `avg_comments_per_post`
  * `previous_week_avg_comments`
  * `wow_growth_percentage`
  * `rank`

-----

### Follow-up Questions for a Lead Role

1.  **Optimization:** This job could be slow if the `comments` table is huge. How would you optimize the initial join between `posts` and `comments`? What if there's data skew in a few very popular posts? (Probes knowledge of partitioning, broadcasting, and salting).
2.  **Data Modeling:** How would you productionize this logic? Would you create an intermediate, aggregated table? What would be the "grain" of that table (daily, weekly)? How would you handle late-arriving data?
3.  **Edge Cases:** What are the flaws in this "surge" logic? What if a subreddit is new and has no "previous week" data? How should its growth be represented? What if a subreddit's activity is so low that going from 1 comment to 2 registers as 100% growth? How might you refine the logic to account for this?
4.  **Definition of "Week":** We used `weekofyear`. What are the potential issues with this function, especially around the new year? (e.g., Week 52 vs. Week 1). What is a more robust way to define a "week"? (Probes deeper SQL/Spark function knowledge, like using `date_trunc`).
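
For question 1, one common answer is to salt the skewed join key. The sketch below simulates salting in plain Python under assumed data: a hypothetical viral post `p1` with 1,000 comments and an assumed fan-out of `NUM_SALTS = 4`. Each row on the big side gets a random salt. The small side is replicated once per salt value. The join then runs on the composite `(post_id, salt)` key. The hot post's rows are split across several keys, and in Spark across several tasks, while the join result stays the same.

In PySpark the first options are usually simpler. One is to broadcast the small side with `broadcast()` from `pyspark.sql.functions`. Another is to let adaptive query execution split skewed partitions (`spark.sql.adaptive.skewJoin.enabled`). Salting is the manual fallback.

```python
import random
from collections import Counter

NUM_SALTS = 4  # assumed fan-out; tune to the observed skew

# Hypothetical skewed data: p1 is a viral post, p2 and p3 are ordinary
comments = [(f"c{i}", "p1") for i in range(1000)] + [("c_a", "p2"), ("c_b", "p3")]
posts = {"p1": "sub1", "p2": "sub1", "p3": "sub2"}  # post_id -> subreddit_id

# Big side: tag each comment with a random salt in [0, NUM_SALTS)
rng = random.Random(42)
salted_comments = [(cid, pid, rng.randrange(NUM_SALTS)) for cid, pid in comments]

# Small side: replicate every post once per salt value
salted_posts = {(pid, s): sub for pid, sub in posts.items() for s in range(NUM_SALTS)}

# Join on the composite (post_id, salt) key
joined = [(cid, pid, salted_posts[(pid, s)]) for cid, pid, s in salted_comments]

# Rows per join key: p1's 1,000 rows are now spread over NUM_SALTS keys
load = Counter((pid, s) for _, pid, s in salted_comments)
print(max(load.values()), "rows on the busiest key, versus 1000 without salting")
```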

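For question 4, the core problem is that a week number is not unique across years. Spark's `weekofyear` follows ISO 8601, with weeks starting on Monday. Python's `date.isocalendar()` uses the same rule, so the plain-Python sketch below can show the effect.

`week_start` is a hypothetical helper that returns the Monday of the week containing a date. That is the value `date_trunc("week", ...)` produces in Spark. Grouping and ordering by that date keeps weeks distinct and in chronological order across a year boundary.

```python
from datetime import date, timedelta


def iso_week_number(d: date) -> int:
    # ISO 8601 week number, the rule Spark's weekofyear uses
    return d.isocalendar()[1]


def week_start(d: date) -> date:
    # Monday of the week containing d (weekday() is 0 for Monday)
    return d - timedelta(days=d.weekday())


# Two dates a year apart collapse into the same "week 1"
for d in [date(2025, 1, 1), date(2025, 12, 31)]:
    print(d, iso_week_number(d), week_start(d))
# 2025-01-01 1 2024-12-30
# 2025-12-31 1 2025-12-29

# Ordering by week number puts late December after early January
mondays = [date(2024, 12, 23), date(2024, 12, 30), date(2025, 1, 6)]
print([iso_week_number(d) for d in sorted(mondays, key=iso_week_number)])  # [1, 2, 52]
print(sorted(mondays, key=week_start) == mondays)  # True
```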
-----

### PySpark Solution Example

```python
from pyspark.sql import SparkSession, Window
# Spark's sum/max are imported under aliases. Calling Python's builtin sum("comment_count")
# iterates over the string's characters and raises
# "TypeError: unsupported operand type(s) for +: 'int' and 'str'".
from pyspark.sql.functions import (col, count, weekofyear, lag, when, lit, desc, rank,
                                   coalesce, sum as spark_sum, max as spark_max)

# --- 1. Setup Spark Session and Create Dummy Data ---
spark = SparkSession.builder.appName("RedditSurgeAnalysis").getOrCreate()

posts_data = [("p1", "sub1", "2025-08-20 10:00:00"), ("p2", "sub1", "2025-08-28 11:00:00"),
              ("p3", "sub2", "2025-08-21 12:00:00"), ("p4", "sub2", "2025-08-29 13:00:00"),
              ("p5", "sub3", "2025-08-22 14:00:00"), ("p6", "sub3", "2025-08-30 15:00:00")]
posts = spark.createDataFrame(posts_data, ["post_id", "subreddit_id", "created_utc"]) \
    .withColumn("created_utc", col("created_utc").cast("timestamp"))

# The comments' created_utc column is omitted here because this pipeline never reads it.
comments_data = [("c1", "p1"), ("c2", "p1"), ("c3", "p2"),  # sub1
                 ("c4", "p3"), ("c5", "p4"), ("c6", "p4"), ("c7", "p4"),  # sub2
                 ("c8", "p5"), ("c9", "p6"), ("c10", "p6")]  # sub3
comments = spark.createDataFrame(comments_data, ["comment_id", "post_id"])

# --- 2. Calculate Weekly Comment Counts & Averages ---
comments_per_post = comments.groupBy("post_id").agg(count("comment_id").alias("comment_count"))

# Left join so posts with zero comments still count toward num_posts (with 0 comments).
post_comments = posts.join(comments_per_post, "post_id", "left") \
    .withColumn("comment_count", coalesce(col("comment_count"), lit(0)))

weekly_agg = post_comments.groupBy("subreddit_id", weekofyear("created_utc").alias("week")) \
    .agg(
        count("post_id").alias("num_posts"),
        spark_sum("comment_count").alias("num_comments")
    )

weekly_avg = weekly_agg.withColumn("avg_comments_per_post", col("num_comments") / col("num_posts"))

# --- 3. Use Window Function to Get Previous Week's Data ---
# lag() returns the previous row, i.e. the subreddit's previous *active* week.
subreddit_window = Window.partitionBy("subreddit_id").orderBy("week")

weekly_comparison = weekly_avg.withColumn(
    "previous_week_avg_comments",
    lag(col("avg_comments_per_post"), 1).over(subreddit_window)
)

# --- 4. Calculate Week-over-Week Growth ---
wow_growth = weekly_comparison.withColumn(
    "wow_growth_percentage",
    when(
        (col("previous_week_avg_comments").isNotNull()) & (col("previous_week_avg_comments") != 0),
        ((col("avg_comments_per_post") - col("previous_week_avg_comments"))
         / col("previous_week_avg_comments")) * 100
    ).otherwise(lit(None))  # No baseline: new subreddit, or a previous average of 0
).filter(col("wow_growth_percentage").isNotNull())

# --- 5. Filter for the Most Recent Week and Rank ---
# Assumes the latest week in the data is complete; exclude a partial current week in production.
most_recent_week = weekly_avg.agg(spark_max("week")).first()[0]  # 35 for the dummy data

# A global (unpartitioned) window moves all rows to one partition. That is acceptable
# here because only one week's rows remain after the filter.
final_ranked = wow_growth.filter(col("week") == most_recent_week) \
    .withColumn("rank", rank().over(Window.orderBy(desc("wow_growth_percentage")))) \
    .filter(col("rank") <= 10) \
    .select("subreddit_id", "week", "avg_comments_per_post",
            "previous_week_avg_comments", "wow_growth_percentage", "rank")

# --- Show Final Result ---
print("Top 10 Surging Subreddits:")
final_ranked.show()

# Expected output from dummy data:
# +------------+----+---------------------+--------------------------+---------------------+----+
# |subreddit_id|week|avg_comments_per_post|previous_week_avg_comments|wow_growth_percentage|rank|
# +------------+----+---------------------+--------------------------+---------------------+----+
# |        sub2|  35|                  3.0|                       1.0|                200.0|   1|
# |        sub3|  35|                  2.0|                       1.0|                100.0|   2|
# |        sub1|  35|                  1.0|                       2.0|                -50.0|   3|
# +------------+----+---------------------+--------------------------+---------------------+----+
```
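
Follow-up question 3 points out that tiny subreddits can dominate this ranking, because going from 1 comment to 2 counts as +100%. One possible refinement is to leave growth undefined unless both weeks have enough posts. The plain-Python sketch below uses a hypothetical `guarded_growth` helper and an assumed threshold of `min_posts=20`. In the PySpark version, this logic would become extra conditions in step 4's `when(...)`. A second `lag` over `num_posts` would supply the previous week's post count.

```python
from typing import Optional


def guarded_growth(cur_avg: float, prev_avg: Optional[float],
                   cur_posts: int, prev_posts: int, min_posts: int = 20) -> Optional[float]:
    """Week-over-week growth in percent, or None when the comparison is not meaningful."""
    if prev_avg is None:  # new subreddit: no baseline week to compare against
        return None
    if prev_avg == 0:  # avoid dividing by zero
        return None
    if min(cur_posts, prev_posts) < min_posts:  # too little activity to trust the ratio
        return None
    return (cur_avg - prev_avg) / prev_avg * 100


print(guarded_growth(2.0, 1.0, cur_posts=1, prev_posts=1))    # None: 1 -> 2 comments on one post
print(guarded_growth(3.0, 2.0, cur_posts=30, prev_posts=30))  # 50.0
print(guarded_growth(3.0, None, cur_posts=30, prev_posts=0))  # None: no previous week
```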
