- **Name:** 18_dataframe_optimizations
- **Author:** Shamas Imran
- **Description:** Optimizing DataFrame performance
- **Date:** 19-Aug-2025
<!--
REVISION HISTORY
Version          Date        Author           Description
01           19-Aug-2025   Shamas Imran       Demonstrated caching and persistence  
                                              Used repartition and coalesce  
                                              Analyzed execution plan with explain  
-->

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Build (or reuse) the Spark session for this notebook.
session_builder = SparkSession.builder.appName("DataFrameOptimizations")
spark = session_builder.getOrCreate()

# Sample DataFrames
# df_large: one million sequential ids plus an integer "value" column derived
# from an unseeded random number scaled by 100.
random_value = (F.rand() * 100).cast("int")
df_large = spark.range(0, 1_000_000).withColumn("value", random_value)

# df_small: 100-row lookup table pairing each id with the label "cat_<id>".
category_rows = [(i, f"cat_{i}") for i in range(100)]
df_small = spark.createDataFrame(category_rows, ["id", "category"])

In [0]:
                                        # ---------------------------------------------------------
                                        # 1. Repartition vs Coalesce
                                        # ---------------------------------------------------------
def count_nonempty_partitions(df):
    """Return the number of partitions of *df* that contain at least one row.

    Uses only the DataFrame API (``spark_partition_id``), so it does not need
    ``df.rdd`` -- the RDD APIs were noted as unavailable on Databricks
    serverless compute. Partitions that hold no rows are not counted, so the
    result can be lower than the configured partition count for tiny inputs.
    This triggers a Spark job.
    """
    partition_ids = df.select(F.spark_partition_id().alias("partition_id"))
    return partition_ids.distinct().count()


# Repartition increases or decreases partitions (full shuffle)
df_repart = df_large.repartition(10)
count_df_repart = count_nonempty_partitions(df_repart)
print("Number of partitions after repartition:", count_df_repart)

# Coalesce only decreases partitions (avoids full shuffle if possible)
df_coalesce = df_large.coalesce(2)
count_df_coalesce = count_nonempty_partitions(df_coalesce)
print("Number of partitions after coalesce:", count_df_coalesce)


In [0]:
                                        # ---------------------------------------------------------
                                        # 2. Caching and Persistence
                                        # ---------------------------------------------------------
from pyspark import StorageLevel

# Cache: marks the DataFrame for storage at the default level
# (MEMORY_AND_DISK). Caching is lazy -- nothing is stored until an action runs.
df_cached = df_large.cache()
print("First count triggers caching:", df_cached.count())  # first action
print("Second count reads from cache:", df_cached.count())

# Persist: allows specifying storage levels
# df_large was already cached above, and Spark ignores a persist() request for
# data that is already cached (it keeps the existing storage level). Drop the
# existing cache entry first so that MEMORY_ONLY actually takes effect.
df_large.unpersist()
df_persist = df_large.persist(StorageLevel.MEMORY_ONLY)

In [0]:
# ---------------------------------------------------------
# 3. Broadcast Join Optimization
# ---------------------------------------------------------
# Mark the small lookup table with a broadcast hint so the join can be
# performed without shuffling the large side.
small_broadcast = F.broadcast(df_small)
join_condition = df_large.value == df_small.id
df_join = df_large.join(small_broadcast, on=join_condition, how="inner")

# Print the physical plan; a BroadcastHashJoin node shows the hint was applied.
df_join.explain()

In [0]:
# ---------------------------------------------------------
# 4. Handling Data Skew
# ---------------------------------------------------------
# Number of salt buckets. The large side draws a random salt in
# [0, NUM_SALTS) and the small side is replicated once per salt value, so both
# sides must use the same number or matching rows would be silently dropped.
NUM_SALTS = 10

# Simulate skew: about 90% of rows get key = 1, the rest keep their "value".
df_skewed = df_large.withColumn("key", F.when(F.rand() < 0.9, 1).otherwise(F.col("value")))

# Skew mitigation: add a "salt" key to spread the hot key over several buckets.
df_skewed_salted = df_skewed.withColumn("salt", (F.rand() * NUM_SALTS).cast("int"))

# Replicate every lookup row once per salt value (0 .. NUM_SALTS - 1) so each
# salted row on the large side still finds its match.
df_small_salted = df_small.withColumn(
    "salt", F.explode(F.sequence(F.lit(0), F.lit(NUM_SALTS - 1)))
)

# Join on the original key plus the salt. This only builds the plan; no action
# is run on the result in this cell.
df_salted_join = df_skewed_salted.join(
    df_small_salted,
    (df_skewed_salted.key == df_small_salted.id) & (df_skewed_salted.salt == df_small_salted.salt),
    "inner"
)

In [0]:
# ---------------------------------------------------------
# 5. Measuring Execution Plans
# ---------------------------------------------------------
# Extended mode prints the logical plans in addition to the physical plan
# for the broadcast join built earlier.
df_join.explain(extended=True)
