In [None]:
# ==========================================
# Task 10: Spark Query Optimization (Colab)
# Explain Plan + Partitioning + Benchmark + Cache
# ==========================================

!pip -q install pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, rand, when
import time

# Obtain the Spark session that every later step of this notebook runs on.
spark = (
    SparkSession.builder
    .appName("Task10-Optimization")
    .getOrCreate()
)

print("Spark Started")

# ------------------------------------------
# 1) Create a demo silver.events dataset
# ------------------------------------------
# Number of synthetic rows. Written with conventional thousands grouping so
# the magnitude (two hundred thousand) cannot be misread as two million.
# Lower it if Colab becomes slow, or raise it for a heavier run (500k, 1M).
n = 200_000

df = spark.range(0, n).withColumnRenamed("id", "row_id")

# Every column except price is derived from row_id with modular arithmetic:
#   user_id    -> row_id % 50000
#   product_id -> row_id % 2000
#   event_type -> "purchase" / "view" / "cart" for row_id % 3 == 0 / 1 / 2
#   event_date -> integer day number in 1..30 (not a real date type)
#   price      -> rand() scaled by 500
events = df.select(
    (col("row_id") % 50000).alias("user_id"),
    (col("row_id") % 2000).alias("product_id"),
    when((col("row_id") % 3) == 0, "purchase")
    .when((col("row_id") % 3) == 1, "view")
    .otherwise("cart").alias("event_type"),
    (col("row_id") % 30 + 1).alias("event_date"),
    (rand() * 500).alias("price")
)

# Expose the dataframe to Spark SQL under the name used by the queries below.
events.createOrReplaceTempView("silver_events")

print("Created temp view: silver_events")
print("Rows:", events.count())

# ------------------------------------------
# 2) Explain Query Plan
# ------------------------------------------
# Print the extended plan for a simple event_type filter on the temp view.
print("\n 1) EXPLAIN Query Plan")
(
    spark
    .sql("SELECT * FROM silver_events WHERE event_type = 'purchase'")
    .explain(extended=True)
)

# ------------------------------------------
# 3) Partitioned Table (Parquet partitioning simulation)
# Databricks equivalent: PARTITIONED BY (event_date, event_type)
# Locally: persist the data as a partitioned parquet directory on disk
# ------------------------------------------
output_path = "/content/silver_events_part"

print("\n 2) Writing partitioned table (parquet) to simulate partition pruning...")
(
    events.write
    .mode("overwrite")
    .partitionBy("event_date", "event_type")
    .parquet(output_path)
)

# Read the directory back and register it under its own view name so the
# same SQL can be run against the partitioned copy.
events_part = spark.read.parquet(output_path)
events_part.createOrReplaceTempView("silver_events_part")

print(" Partitioned table created & loaded: silver_events_part")

# ------------------------------------------
# 4) Benchmark: query on non-partitioned vs partitioned
# ------------------------------------------
def benchmark(query: str, label: str) -> None:
    """Run a SQL query, count its rows, and print how long that took.

    The query is executed through the module-level ``spark`` session and
    forced to run by calling ``count()``; the timing covers that whole call.

    Args:
        query: SQL text to execute.
        label: Short description printed at the start of the result line.
    """
    # perf_counter() is a monotonic, high-resolution clock meant for
    # measuring intervals. time.time() follows the system wall clock and can
    # jump when that clock is adjusted, which distorts short measurements.
    start = time.perf_counter()
    cnt = spark.sql(query).count()
    elapsed = time.perf_counter() - start
    print(f" {label} | rows = {cnt} | time = {elapsed:.3f}s")

# Run the identical filter against both views so the two timings can be
# compared line by line.
print("\n 3) Benchmark (non-partitioned vs partitioned)")

benchmark(
    query="SELECT * FROM silver_events WHERE event_type='purchase' AND event_date = 10",
    label="Non-partitioned filter",
)
benchmark(
    query="SELECT * FROM silver_events_part WHERE event_type='purchase' AND event_date = 10",
    label="Partitioned filter (pruning expected)",
)

# ------------------------------------------
# 5) Cache for iterative queries
# ------------------------------------------
print("\n 4) Cache table for iterative queries")

# Baseline: time the query before anything is cached. Without this run both
# measurements below are taken against an already materialized cache, so
# there would be nothing to compare the cached timings against.
benchmark(
    "SELECT * FROM silver_events_part WHERE user_id=12345",
    "Before cache: user_id filter"
)

events_part.cache()          # Mark the dataframe for caching (lazy)
events_part.count()          # Run an action so the cache is materialized

benchmark(
    "SELECT * FROM silver_events_part WHERE user_id=12345",
    "After cache: user_id filter"
)

benchmark(
    "SELECT * FROM silver_events_part WHERE user_id=12345",
    "After cache (repeat): should be faster"
)

print("\n Task 10 Completed Successfully!")

Spark Started
Created temp view: silver_events
Rows: 200000

 1) EXPLAIN Query Plan
== Parsed Logical Plan ==
'Project [*]
+- 'Filter ('event_type = purchase)
   +- 'UnresolvedRelation [silver_events], [], false

== Analyzed Logical Plan ==
user_id: bigint, product_id: bigint, event_type: string, event_date: bigint, price: double
Project [user_id#2L, product_id#3L, event_type#4, event_date#5L, price#6]
+- Filter (event_type#4 = purchase)
   +- SubqueryAlias silver_events
      +- View (`silver_events`, [user_id#2L, product_id#3L, event_type#4, event_date#5L, price#6])
         +- Project [(row_id#1L % cast(50000 as bigint)) AS user_id#2L, (row_id#1L % cast(2000 as bigint)) AS product_id#3L, CASE WHEN ((row_id#1L % cast(3 as bigint)) = cast(0 as bigint)) THEN purchase WHEN ((row_id#1L % cast(3 as bigint)) = cast(1 as bigint)) THEN view ELSE cart END AS event_type#4, ((row_id#1L % cast(30 as bigint)) + cast(1 as bigint)) AS event_date#5L, (rand(5601479030039924967) * cast(500 as double))