## Spark for beginners

In [None]:
from pyspark.sql import SparkSession
import re

spark = SparkSession.builder.getOrCreate()

#### Problem 1: Count word frequency from a text file (ignore case, strip punctuation).


Shows RDD API: flatMap, map, reduceByKey.

Avoid groupByKey for counting: reduceByKey sums values inside each partition before the shuffle (a map-side combiner), so less data crosses the network.
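
To make the combiner point concrete, here is a minimal plain-Python sketch (no Spark involved) that counts how many records would cross the shuffle under each strategy. The two partitions and the helper names `shuffle_group_by_key` and `shuffle_reduce_by_key` are hypothetical, made up for illustration; real Spark shuffles are more involved than this.

```python
from collections import Counter, defaultdict

# Two hypothetical partitions of (word, 1) pairs, as produced by the map step.
partitions = [
    [("hello", 1), ("world", 1), ("hello", 1), ("spark", 1)],
    [("world", 1), ("of", 1), ("spark", 1), ("spark", 1)],
]

def shuffle_group_by_key(parts):
    """groupByKey style: every pair crosses the shuffle, then values are summed."""
    shuffled = [pair for part in parts for pair in part]
    grouped = defaultdict(list)
    for key, value in shuffled:
        grouped[key].append(value)
    return {key: sum(values) for key, values in grouped.items()}, len(shuffled)

def shuffle_reduce_by_key(parts):
    """reduceByKey style: sum inside each partition first, then merge the partials."""
    partials = []
    for part in parts:
        local = Counter()
        for key, value in part:
            local[key] += value
        partials.extend(local.items())  # only one record per key per partition
    merged = Counter()
    for key, value in partials:
        merged[key] += value
    return dict(merged), len(partials)

grouped_counts, grouped_shuffled = shuffle_group_by_key(partitions)
reduced_counts, reduced_shuffled = shuffle_reduce_by_key(partitions)
print(grouped_shuffled, reduced_shuffled)
```

Both strategies give the same counts; the difference is only in how many records are moved, and the gap grows as keys repeat more often within a partition.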

Interview tip: Explain lazy evaluation: transformations such as flatMap and map only record what to do, and an action such as collect() triggers the actual computation.
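
As a rough analogy only (this is plain Python, not Spark), the built-in `map` is also lazy: nothing runs until the result is consumed. The helper `to_pair` and the `calls` list are hypothetical names used to observe when the work happens.

```python
calls = []  # records every word that to_pair actually processes

def to_pair(word):
    calls.append(word)
    return (word, 1)

words = ["hello", "world", "hello"]

pipeline = map(to_pair, words)  # like a transformation: nothing has run yet
calls_before = len(calls)

pairs = list(pipeline)          # like an action: forces the computation
calls_after = len(calls)

print(calls_before, calls_after)
```

In Spark the same idea applies at cluster scale: the chain of transformations is only a plan until an action asks for a result.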

In [3]:
sc = spark.sparkContext 

lines = sc.parallelize([ 
    "Hello world",
    "Hello Spark",
    "world of Spark"
]) 

def normalize(line):
    # Lowercase the line and keep only runs of word characters, which drops punctuation.
    return re.findall(r"\w+", line.lower())

counts = (lines
          .flatMap(normalize)               # split each line into normalized words
          .map(lambda w: (w, 1))            # pair each word with a count of 1
          .reduceByKey(lambda a, b: a + b)  # sum the counts per word
         )

counts.collect()
# for key, value in counts.collect():
#     print(key, value)

[('of', 1), ('world', 2), ('hello', 2), ('spark', 2)]

#### Problem 2: Given a CSV of user events (user_id, event_type, timestamp), compute number of click events per user.

Use DataFrame API. Avoid collecting large result sets.

If reading files, use spark.read.csv(..., header=True, schema=schema).

Interview tip: Mention that schema inference (inferSchema=True) needs an extra pass over the data, and that passing an explicit schema avoids this cost and fixes the column types up front.

In [99]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

df = spark.read.csv('spark-warehouse/user_event.csv', header=True)
df.show()

+-------+----------+-------------------+
|user_id|event_type|          timestamp|
+-------+----------+-------------------+
|      1|     click|2025-09-30 12:00:00|
|      1|      view|2025-09-30 12:00:10|
|      2|     click|2025-09-30 12:01:00|
|      1|     click|2025-09-30 12:02:00|
+-------+----------+-------------------+



In [98]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
schema = StructType([
    StructField("user_id", IntegerType(), True),
    StructField("event_type", StringType(), True),
    StructField("timestamp", StringType(), True)
])

df = spark.read.csv('spark-warehouse/user_event.csv', schema=schema, sep=',', header=True, mode='FAILFAST')
df.show()

+-------+----------+-------------------+
|user_id|event_type|          timestamp|
+-------+----------+-------------------+
|      1|     click|2025-09-30 12:00:00|
|      1|      view|2025-09-30 12:00:10|
|      2|     click|2025-09-30 12:01:00|
|      1|     click|2025-09-30 12:02:00|
+-------+----------+-------------------+



In [106]:
df.select('event_type').distinct().show() 

+----------+
|event_type|
+----------+
|      view|
|     click|
+----------+



In [108]:
from pyspark.sql.functions import col
df.filter(col("event_type")=="click").show() 

+-------+----------+-------------------+
|user_id|event_type|          timestamp|
+-------+----------+-------------------+
|      1|     click|2025-09-30 12:00:00|
|      2|     click|2025-09-30 12:01:00|
|      1|     click|2025-09-30 12:02:00|
+-------+----------+-------------------+



In [117]:
df.filter(col("event_type")=="click") \
.groupBy("user_id") \
.count() \
.withColumnRenamed("count", "click_count_per_user") \
.show() 


+-------+--------------------+
|user_id|click_count_per_user|
+-------+--------------------+
|      1|                   2|
|      2|                   1|
+-------+--------------------+



### 3. Top N per group using Window functions

#### Problem: Given sales(product_id, category, amount), for each category return top 3 products by total sales.

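Use Window with row_number() rather than rank() if you want exactly N rows per group: rank() and dense_rank() give tied rows the same number, so a tie can return more than N rows.

If groups are huge, the window can be expensive because all rows of a category are sorted together; consider pre-aggregating and limiting rows per partition before the window for scale.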

Interview tip: Explain difference between rank, dense_rank, row_number.
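
The difference only shows up when there are ties, which the sample data in this section does not have. Here is a minimal plain-Python sketch (not Spark) of what the three functions assign, using hypothetical totals with one tie; `rank_functions` is an illustrative helper, not a Spark API.

```python
def rank_functions(values):
    """Return (value, row_number, rank, dense_rank) tuples, ordered descending."""
    ordered = sorted(values, reverse=True)
    rows = []
    rank = dense = 0
    previous = None
    for position, value in enumerate(ordered, start=1):
        if value != previous:
            rank = position  # rank jumps to the current position, leaving gaps after ties
            dense += 1       # dense_rank advances by one, never leaving gaps
            previous = value
        rows.append((value, position, rank, dense))  # row_number is just the position
    return rows

totals = [200, 150, 150, 130]  # hypothetical totals with a tie at 150
result = rank_functions(totals)

top3_by_row_number = [r[0] for r in result if r[1] <= 3]
top3_by_rank = [r[0] for r in result if r[2] <= 3]
top3_by_dense_rank = [r[0] for r in result if r[3] <= 3]
print(result)
```

With these totals, filtering on `<= 3` keeps three rows for row_number and rank but four rows for dense_rank, because the tie does not use up a rank. Which of the tied rows row_number places first depends on the ordering, so add a tie-breaker column if the result must be deterministic.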


In [135]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

rows = [
    (1,"A",100),(2,"A",200),(3,"A",150),(4,"A",130),
    (5,"B",300),(6,"B",100)
]

df = spark.createDataFrame(rows, schema=["product_id", "category", "amount"])
df.show()

+----------+--------+------+
|product_id|category|amount|
+----------+--------+------+
|         1|       A|   100|
|         2|       A|   200|
|         3|       A|   150|
|         4|       A|   130|
|         5|       B|   300|
|         6|       B|   100|
+----------+--------+------+



In [154]:
from pyspark.sql.functions import sum as _sum
df.groupBy("product_id", "category").agg(_sum("amount").alias("total_sales")).show()


+----------+--------+-----------+
|product_id|category|total_sales|
+----------+--------+-----------+
|         1|       A|        100|
|         2|       A|        200|
|         3|       A|        150|
|         4|       A|        130|
|         5|       B|        300|
|         6|       B|        100|
+----------+--------+-----------+



Using row_number()

In [171]:
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# Total sales per product, then number the products within each category.
agg_df = df.groupBy("product_id", "category").agg(_sum("amount").alias("total_sales"))
w = Window.partitionBy("category").orderBy(col("total_sales").desc())
(
    agg_df.withColumn("rn", row_number().over(w))
    .filter(col("rn") <= 3)  # keep the top 3 rows per category
    .drop("rn")
    .show()
)

+----------+--------+-----------+
|product_id|category|total_sales|
+----------+--------+-----------+
|         2|       A|        200|
|         3|       A|        150|
|         4|       A|        130|
|         5|       B|        300|
|         6|       B|        100|
+----------+--------+-----------+



Using dense_rank()

In [172]:
from pyspark.sql.functions import col, dense_rank
from pyspark.sql.window import Window

# Same query with dense_rank(): tied totals share a rank, so ties can return more than 3 rows.
agg_df = df.groupBy("product_id", "category").agg(_sum("amount").alias("total_sales"))
w = Window.partitionBy("category").orderBy(col("total_sales").desc())
(
    agg_df.withColumn("rn", dense_rank().over(w))
    .filter(col("rn") <= 3)
    .drop("rn")
    .show()
)

+----------+--------+-----------+
|product_id|category|total_sales|
+----------+--------+-----------+
|         2|       A|        200|
|         3|       A|        150|
|         4|       A|        130|
|         5|       B|        300|
|         6|       B|        100|
+----------+--------+-----------+



### 4. Efficient join: broadcast join when one table is small

#### Problem: Join orders (very large) with country_lookup (small) to add country_name to each order.

Use broadcast() to avoid a large shuffle when one side is small.

Spark may auto-broadcast small datasets on its own (config spark.sql.autoBroadcastJoinThreshold, 10 MB by default); mention this.

Interview tip: discuss memory tradeoffs: the broadcast table is collected on the driver and copied to every executor, so it must fit in memory on each of them.
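
The idea behind a broadcast join can be sketched in plain Python (not Spark): every partition of the large side gets its own full copy of the small lookup and joins locally, so no order rows need to move. The data, the column layout, and the helper `broadcast_join` below are hypothetical. In PySpark the corresponding hint is `broadcast()` from `pyspark.sql.functions`, wrapped around the small DataFrame in the join.

```python
# Hypothetical small lookup: country code -> country name.
country_lookup = {"US": "United States", "DE": "Germany"}

# Hypothetical large side, already split into partitions of (order_id, country_code).
order_partitions = [
    [(1, "US"), (2, "DE")],
    [(3, "US"), (4, "FR")],
]

def broadcast_join(partitions, lookup):
    """Join each partition locally against its own copy of the lookup."""
    joined = []
    for part in partitions:
        local_lookup = dict(lookup)  # stands in for the copy shipped to each executor
        joined.append([
            # Left-join semantics in this sketch: unmatched codes get None.
            (order_id, code, local_lookup.get(code))
            for order_id, code in part
        ])
    return joined

result = broadcast_join(order_partitions, country_lookup)
print(result)
```

The orders stay in their original partitions, which is the point of the technique; the cost is one copy of the lookup per executor, which is why it only pays off when that side is small.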