# Spark Plan Analysis and Query Optimization

This notebook provides a comprehensive guide to analyzing Spark's logical and physical execution plans and using those insights to optimize your Spark SQL queries and DataFrame operations.

## Table of Contents

1. Introduction to Spark Execution Plans
2. Setting Up Our Environment
3. Key Optimization Areas in Spark Plans
   - Scan Operations and Table Statistics
   - Filter Pushdown
   - Join Strategies
   - Partition Pruning
   - Shuffle Operations
   - Data Skew
   - Whole-Stage Codegen
   - Caching
4. End-to-End Optimization Example
5. Best Practices

Let's get started by understanding what Spark execution plans are and why they matter.

## 1. Introduction to Spark Execution Plans

Spark uses a query optimizer called Catalyst to transform your high-level DataFrame operations or SQL queries into an efficient execution plan. This process involves several phases:

1. **Unresolved Logical Plan**: Initial representation of your query
2. **Resolved Logical Plan**: Column and table references are resolved
3. **Optimized Logical Plan**: Catalyst applies rule-based optimizations
4. **Physical Plan**: Converts logical plan to actual execution strategy
5. **Selected Physical Plan**: The most efficient execution plan is chosen
6. **Executed Plan**: The final plan after adaptations during runtime

Understanding these plans is key to optimizing Spark performance. The `explain()` method is our primary tool for examining these plans.

## 2. Setting Up Our Environment

Let's start by creating a Spark session and some sample data for our experiments:

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, avg, sum, max, min, lit, concat, broadcast, expr
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
import time

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Spark Plan Analysis") \
    .config("spark.sql.shuffle.partitions", 10) \
    .getOrCreate()

print(f"Spark version: {spark.version}")
print("Session initialized successfully!")

Spark version: 3.5.1
Session initialized successfully!



In [2]:
# Create sample employee data
employee_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("dept_id", IntegerType(), False),
    StructField("salary", DoubleType(), False),
    StructField("hire_date", DateType(), False)
])

# Department schema
department_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("department_name", StringType(), False),
    StructField("location", StringType(), False),
    StructField("budget", DoubleType(), True)
])

# Generate sample data
import random
from datetime import datetime, timedelta

# Generate departments
dept_data = [
    (1, "Engineering", "San Francisco", 1000000.0),
    (2, "Sales", "New York", 800000.0),
    (3, "Marketing", "Chicago", 600000.0),
    (4, "HR", "Seattle", 400000.0),
    (5, "Finance", "Boston", 750000.0)
]

departments_df = spark.createDataFrame(dept_data, department_schema)
departments_df.createOrReplaceTempView("departments")

# Generate employees
employee_data = []
names = ["John", "Emma", "Michael", "Sophia", "James", "Olivia", "William", "Ava", "Alexander", "Mia"]
surnames = ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Davis", "Garcia", "Rodriguez", "Wilson"]
start_date = datetime(2015, 1, 1)
end_date = datetime(2023, 12, 31)
delta = (end_date - start_date).days

for i in range(1, 1001):  # Generate 1000 employees
    name = f"{random.choice(names)} {random.choice(surnames)}"
    dept_id = random.randint(1, 5)
    salary = round(random.uniform(50000, 150000), 2)
    random_days = random.randint(0, delta)
    hire_date = start_date + timedelta(days=random_days)
    employee_data.append((i, name, dept_id, salary, hire_date))

employees_df = spark.createDataFrame(employee_data, employee_schema)
employees_df.createOrReplaceTempView("employees")

# Also create a partitioned version of our employees table
employees_df.write.mode("overwrite").partitionBy("dept_id").saveAsTable("partitioned_employees")

print(f"Created sample datasets with {employees_df.count()} employees and {departments_df.count()} departments")


Created sample datasets with 1000 employees and 5 departments


## 3. Key Optimization Areas in Spark Plans

Now let's examine the key areas to check in a Spark plan when optimizing a query.

### 3.1 Scan Operations and Table Statistics

Scan operations determine how Spark reads data from sources. Inefficient scans can severely impact performance.

#### What to Look For in the Plan:
- Full table scans vs. filtered scans
- Reading too many columns
- Missing statistics
- File format efficiency

In [3]:
# Example 1: Analyzing a simple scan operation
query1 = employees_df.select("*")
print("Full table scan with all columns:")
query1.explain()

Full table scan with all columns:
== Physical Plan ==
*(1) Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]




In [4]:
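# Example 2: Column pruning. With this RDD-backed DataFrame the scan still lists every
# column and the Project drops the unused ones afterwards; with Parquet or ORC the
# pruned columns are not read from storage at all.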
query2 = employees_df.select("id", "name", "salary")
print("Column pruning in action:")
query2.explain()

Column pruning in action:
== Physical Plan ==
*(1) Project [id#8, name#9, salary#11]
+- *(1) Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]




#### Table Statistics

Table and column statistics help the optimizer choose join strategies (for example, whether a table is small enough to broadcast) and, with cost-based optimization enabled, join order.

In [5]:
# # ANALYZE TABLE works on catalog tables, so to analyze the in-memory
# # DataFrames you would first save them as permanent tables:
# employees_df.write.saveAsTable("employees_table")
# departments_df.write.saveAsTable("departments_table")

# # Now you can analyze them
# spark.sql("ANALYZE TABLE employees_table COMPUTE STATISTICS")
# spark.sql("ANALYZE TABLE departments_table COMPUTE STATISTICS")

In [6]:
# You can analyze the partitioned table instead, which is permanent
spark.sql("ANALYZE TABLE partitioned_employees COMPUTE STATISTICS")
spark.sql("ANALYZE TABLE partitioned_employees COMPUTE STATISTICS FOR COLUMNS id, name, salary")

# Check statistics
print("Table Statistics for partitioned_employees:")
spark.sql("DESCRIBE TABLE EXTENDED partitioned_employees").filter(col("col_name").like("%Statistics%")).show(truncate=False)

Table Statistics for partitioned_employees:
+----------+----------------------+-------+
|col_name  |data_type             |comment|
+----------+----------------------+-------+
|Statistics|99624 bytes, 1000 rows|       |
+----------+----------------------+-------+



#### Best Practices for Scan Optimization:

1. **Select only necessary columns**: Reduces I/O and memory usage
2. **Use appropriate file formats**: Parquet/ORC > CSV/JSON for analytical workloads
3. **Compute and maintain statistics**: For better query planning
4. **Use partitioning**: For large tables to enable partition pruning

### 3.2 Filter Pushdown

Filter pushdown moves filter conditions down to the data source, so the source can skip data before Spark loads it (Parquet, for example, can skip entire row groups whose min/max statistics rule out a match).

#### What to Look For in the Plan:
- `PushedFilters` in the scan operation
- Filters applied before or after reading data

In [7]:
# Example 1: Checking if filters are pushed down
query3 = employees_df.filter(col("salary") > 100000)
print("Filter on in-memory DataFrame:")
query3.explain()

Filter on in-memory DataFrame:
== Physical Plan ==
*(1) Filter (salary#11 > 100000.0)
+- *(1) Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]




In [8]:
# Example 2: Filter pushdown to Parquet files
parquet_employees = spark.read.table("partitioned_employees")
query4 = parquet_employees.filter(col("salary") > 100000)
print("Filter on Parquet-backed DataFrame:")
query4.explain()

Filter on Parquet-backed DataFrame:
== Physical Plan ==
*(1) Filter (isnotnull(salary#10176) AND (salary#10176 > 100000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.partitioned_employees[id#10174,name#10175,salary#10176,hire_date#10177,dept_id#10178] Batched: true, DataFilters: [isnotnull(salary#10176), (salary#10176 > 100000.0)], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/partitioned_employees], PartitionFilters: [], PushedFilters: [IsNotNull(salary), GreaterThan(salary,100000.0)], ReadSchema: struct<id:int,name:string,salary:double,hire_date:date>




#### Best Practices for Filter Optimization:

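1. **Apply filters as early as possible**: Catalyst pushes most deterministic filters below joins on its own, but filtering before joins or aggregations still matters where it cannot, such as with non-deterministic expressions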
2. **Use file formats that support predicate pushdown**: Parquet, ORC
3. **Filter on partitioned columns**: For partition pruning
4. **Use compatible filter expressions**: Some complex expressions can't be pushed down

Let's compare writing the filter before the join with writing it after the join:

In [9]:
# Performance comparison: filter written before vs. after the join
def time_execution(df):
    """Return (elapsed seconds, row count) for a single df.count() run."""
    start = time.time()
    rows = df.count()  # Force execution
    end = time.time()
    return end - start, rows

# Filter written before the join
result1 = employees_df.filter(col("salary") > 80000) \
    .join(departments_df, employees_df["dept_id"] == departments_df["id"]) \
    .select("name", "department_name", "salary")
time1, count1 = time_execution(result1)

# Filter written after the join (Catalyst still pushes it below the join)
result2 = employees_df \
    .join(departments_df, employees_df["dept_id"] == departments_df["id"]) \
    .filter(col("salary") > 80000) \
    .select("name", "department_name", "salary")
time2, count2 = time_execution(result2)

print(f"Filter then Join: {time1:.4f} seconds, {count1} records")
print(f"Join then Filter: {time2:.4f} seconds, {count2} records")
print(f"Performance difference: {(time2/time1):.2f}x")

Filter then Join: 0.3623 seconds, 699 records
Join then Filter: 0.3026 seconds, 699 records
Performance difference: 0.84x


### 3.3 Join Strategies

Spark supports several join strategies, and choosing the right one can significantly impact performance.

#### Common Join Types in Spark:

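1. **Broadcast Hash Join**: The small table is broadcast to all executors and joined locally, so the large table is not shuffled
2. **Shuffle Hash Join**: Both tables are shuffled by the join key, then a hash table is built from the smaller side in each partition
3. **Sort Merge Join**: Both tables are shuffled by the join key, sorted, and then merged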
4. **Broadcast Nested Loop Join**: Used for cross joins and some non-equi joins

#### What to Look For in the Plan:
- The join strategy being used
- Broadcast hints being applied
- Join order optimization

In [10]:
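# Example 1: Letting Spark choose the join strategy
# Note: DataFrames built from local Python data (Scan ExistingRDD) have no real size
# statistics; Spark falls back to spark.sql.defaultSizeInBytes (very large by default),
# so it plans a sort-merge join even though departments_df has only 5 rows.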
auto_join = employees_df.join(departments_df, employees_df["dept_id"] == departments_df["id"])
print("Automatic Join Strategy Selection:")
auto_join.explain()

Automatic Join Strategy Selection:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [dept_id#10], [id#0], Inner
   :- Sort [dept_id#10 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(dept_id#10, 10), ENSURE_REQUIREMENTS, [plan_id=647]
   :     +- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]
   +- Sort [id#0 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#0, 10), ENSURE_REQUIREMENTS, [plan_id=648]
         +- Scan ExistingRDD[id#0,department_name#1,location#2,budget#3]




In [11]:
# Example 2: Forcing a broadcast join
broadcast_join = employees_df.join(broadcast(departments_df), employees_df["dept_id"] == departments_df["id"])
print("Forced Broadcast Join:")
broadcast_join.explain()

Forced Broadcast Join:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- BroadcastHashJoin [dept_id#10], [id#0], Inner, BuildRight, false
   :- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=664]
      +- Scan ExistingRDD[id#0,department_name#1,location#2,budget#3]




#### Performance Impact of Join Strategies

Let's measure the performance of different join strategies:

In [12]:
# Broadcast join performance
broadcast_result = employees_df.join(broadcast(departments_df), employees_df["dept_id"] == departments_df["id"]) \
    .select("name", "department_name", "salary")
time_broadcast, count_broadcast = time_execution(broadcast_result)

# Default join strategy: the initial plan is a sort-merge join (no size statistics for the
# RDD-backed DataFrames), but with AQE enabled Spark may switch to a broadcast join at runtime
default_result = employees_df.join(departments_df, employees_df["dept_id"] == departments_df["id"]) \
    .select("name", "department_name", "salary")
time_default, count_default = time_execution(default_result)

print(f"Broadcast Join: {time_broadcast:.4f} seconds")
print(f"Default Strategy: {time_default:.4f} seconds")
print(f"Performance difference: {(time_default/time_broadcast):.2f}x")

Broadcast Join: 0.2784 seconds
Default Strategy: 0.2820 seconds
Performance difference: 1.01x


#### Join Strategy Selection Guidelines:

1. **Broadcast Hash Join**: For small tables that can fit in memory (< 10MB by default)
   - Controlled by `spark.sql.autoBroadcastJoinThreshold`
   - Good for dimension tables joining with fact tables
  
2. **Sort Merge Join**: For large tables with well-distributed keys
   - Becomes default when tables are too large to broadcast
   - Good for large-to-large table joins
  
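3. **Shuffle Hash Join**: When one side is too large to broadcast but small enough per partition to build an in-memory hash table
   - Generally chosen over sort-merge join only when `spark.sql.join.preferSortMergeJoin` is `false` (default `true`) or a `SHUFFLE_HASH` hint is given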
  
4. **Broadcast Nested Loop Join**: Last resort, typically for non-equality joins

In [13]:
# Changing the broadcast threshold
original_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
print(f"Original broadcast threshold: {original_threshold}")

# A threshold of -1 disables automatic broadcast joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
no_broadcast_join = employees_df.join(departments_df, employees_df["dept_id"] == departments_df["id"])
print("\nJoin strategy with broadcasting disabled:")
no_broadcast_join.explain()

# Reset to original value
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", original_threshold)

Original broadcast threshold: 10485760b

Join strategy with broadcasting disabled:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [dept_id#10], [id#0], Inner
   :- Sort [dept_id#10 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(dept_id#10, 10), ENSURE_REQUIREMENTS, [plan_id=976]
   :     +- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]
   +- Sort [id#0 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#0, 10), ENSURE_REQUIREMENTS, [plan_id=977]
         +- Scan ExistingRDD[id#0,department_name#1,location#2,budget#3]




### 3.4 Partition Pruning

Partition pruning is the ability of Spark to skip reading partitions that aren't relevant to the query, which can dramatically improve performance for large datasets.

#### What to Look For in the Plan:
- `PartitionFilters` in the scan operation
- Reduction in the number of files/partitions read

In [14]:
# Let's examine our partitioned table
print("Partitioned table structure:")
spark.sql("DESCRIBE TABLE EXTENDED partitioned_employees").show(truncate=False)

Partitioned table structure:
+----------------------------+--------------------------------------------------------------+-------+
|col_name                    |data_type                                                     |comment|
+----------------------------+--------------------------------------------------------------+-------+
|id                          |int                                                           |NULL   |
|name                        |string                                                        |NULL   |
|salary                      |double                                                        |NULL   |
|hire_date                   |date                                                          |NULL   |
|dept_id                     |int                                                           |NULL   |
|# Partition Information     |                                                              |       |
|# col_name                  |data_type              

In [15]:
# Example 1: Query without partition pruning
query_no_pruning = spark.table("partitioned_employees").filter(col("salary") > 100000)
print("Query without partition pruning:")
query_no_pruning.explain()

Query without partition pruning:
== Physical Plan ==
*(1) Filter (isnotnull(salary#10176) AND (salary#10176 > 100000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.partitioned_employees[id#10174,name#10175,salary#10176,hire_date#10177,dept_id#10178] Batched: true, DataFilters: [isnotnull(salary#10176), (salary#10176 > 100000.0)], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/partitioned_employees], PartitionFilters: [], PushedFilters: [IsNotNull(salary), GreaterThan(salary,100000.0)], ReadSchema: struct<id:int,name:string,salary:double,hire_date:date>




In [16]:
# Example 2: Query with partition pruning
query_with_pruning = spark.table("partitioned_employees").filter(col("dept_id") == 2)
print("Query with partition pruning:")
query_with_pruning.explain()

Query with partition pruning:
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet spark_catalog.default.partitioned_employees[id#10174,name#10175,salary#10176,hire_date#10177,dept_id#10178] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/partitioned_employees/dept_id=2], PartitionFilters: [isnotnull(dept_id#10178), (dept_id#10178 = 2)], PushedFilters: [], ReadSchema: struct<id:int,name:string,salary:double,hire_date:date>




In [17]:
# Example 3: Query with both partition pruning and additional filters
query_combined = spark.table("partitioned_employees") \
    .filter((col("dept_id") == 2) & (col("salary") > 100000))
print("Query with partition pruning and additional filters:")
query_combined.explain()

Query with partition pruning and additional filters:
== Physical Plan ==
*(1) Filter (isnotnull(salary#10176) AND (salary#10176 > 100000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.partitioned_employees[id#10174,name#10175,salary#10176,hire_date#10177,dept_id#10178] Batched: true, DataFilters: [isnotnull(salary#10176), (salary#10176 > 100000.0)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/partitioned_employees/dept_id=2], PartitionFilters: [isnotnull(dept_id#10178), (dept_id#10178 = 2)], PushedFilters: [IsNotNull(salary), GreaterThan(salary,100000.0)], ReadSchema: struct<id:int,name:string,salary:double,hire_date:date>




#### Performance Impact of Partition Pruning

Let's measure the performance difference between queries with and without partition pruning:

In [18]:
# Without partition pruning
result_no_pruning = spark.table("partitioned_employees") \
    .filter(col("salary") > 100000) \
    .select("id", "name", "salary", "dept_id")
time_no_pruning, count_no_pruning = time_execution(result_no_pruning)

# With partition pruning
result_with_pruning = spark.table("partitioned_employees") \
    .filter((col("dept_id") == 2) & (col("salary") > 100000)) \
    .select("id", "name", "salary", "dept_id")
time_with_pruning, count_with_pruning = time_execution(result_with_pruning)

print(f"Without Partition Pruning: {time_no_pruning:.4f} seconds, {count_no_pruning} records")
print(f"With Partition Pruning: {time_with_pruning:.4f} seconds, {count_with_pruning} records")

if time_no_pruning > time_with_pruning:
    print(f"Performance improvement: {(time_no_pruning/time_with_pruning):.2f}x faster with pruning")

Without Partition Pruning: 0.1598 seconds, 489 records
With Partition Pruning: 0.0615 seconds, 95 records
Performance improvement: 2.60x faster with pruning


#### Best Practices for Partition Optimization:

1. **Choose appropriate partition columns**:
   - Low to moderate cardinality (e.g., date, region, category); high-cardinality columns such as IDs produce a huge number of tiny partitions
   - Commonly used in filters
  
2. **Avoid over-partitioning**: 
   - Too many small partitions create small files and overhead
   - Aim for partition sizes between 128MB and 1GB
  
3. **Include partition columns in queries**: 
   - Ensure queries filter on partition columns when possible
  
4. **Consider bucketing for join performance**: 
   - Complement partitioning with bucketing for frequently joined columns

### 3.5 Shuffle Operations

Shuffles redistribute data across partitions and are often the most expensive operations in Spark. They involve disk I/O, serialization, network transfer, and deserialization.

#### What to Look For in the Plan:
- `Exchange` operations in the physical plan
- The type of exchange (e.g., HashPartitioning, RangePartitioning)
- The number of shuffle partitions

In [19]:
# Example 1: Operations that trigger shuffles
shuffled_df = employees_df.groupBy("dept_id").agg(count("*").alias("emp_count"), avg("salary").alias("avg_salary"))
print("GroupBy operation causing shuffle:")
shuffled_df.explain()

GroupBy operation causing shuffle:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[dept_id#10], functions=[count(1), avg(salary#11)])
   +- Exchange hashpartitioning(dept_id#10, 10), ENSURE_REQUIREMENTS, [plan_id=1137]
      +- HashAggregate(keys=[dept_id#10], functions=[partial_count(1), partial_avg(salary#11)])
         +- Project [dept_id#10, salary#11]
            +- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]




In [20]:
# Example 2: Join operation causing shuffle
joined_df = employees_df.join(departments_df, employees_df["dept_id"] == departments_df["id"])
print("Join operation causing shuffle:")
joined_df.explain()

Join operation causing shuffle:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [dept_id#10], [id#0], Inner
   :- Sort [dept_id#10 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(dept_id#10, 10), ENSURE_REQUIREMENTS, [plan_id=1152]
   :     +- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]
   +- Sort [id#0 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#0, 10), ENSURE_REQUIREMENTS, [plan_id=1153]
         +- Scan ExistingRDD[id#0,department_name#1,location#2,budget#3]




#### Impact of the Number of Shuffle Partitions

Let's measure the impact of different shuffle partition settings:

In [21]:
# Test with different shuffle partition counts
partition_counts = [2, 10, 50, 200]
test_results = []

for partitions in partition_counts:
    spark.conf.set("spark.sql.shuffle.partitions", partitions)
    
    # Run a query that involves shuffling
    result = employees_df.groupBy("dept_id").agg(
        count("*").alias("emp_count"),
        avg("salary").alias("avg_salary")
    )
    execution_time, _ = time_execution(result)
    
    test_results.append((partitions, execution_time))
    print(f"Shuffle partitions: {partitions}, Execution time: {execution_time:.4f} seconds")

# Reset to initial value
spark.conf.set("spark.sql.shuffle.partitions", 10)

Shuffle partitions: 2, Execution time: 0.2265 seconds
Shuffle partitions: 10, Execution time: 0.1719 seconds
Shuffle partitions: 50, Execution time: 0.1725 seconds
Shuffle partitions: 200, Execution time: 0.1646 seconds


#### Best Practices for Shuffle Optimization:

1. **Tune shuffle partitions**: 
   - Default is 200, which is often too high for small datasets
   - Rule of thumb: 2-3 * number of cores for small-medium datasets
   - For larger clusters, start with cluster cores * 3-4
  
2. **Use appropriate partitioning**: 
   - Pre-partition data by join keys to reduce shuffling
   - Consider repartitioning before expensive operations
  
3. **Minimize the number of stages**: 
   - Chain transformations that don't require shuffles
  
4. **Use broadcast joins**: 
   - When possible, to avoid shuffling larger tables
  
5. **Use Adaptive Query Execution (AQE)**:
   - Enabled by default since Spark 3.2; lets Spark coalesce small shuffle partitions at runtime

In [22]:
# Enable Adaptive Query Execution to automatically optimize shuffle partitions
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# Run a query with many initial partitions
spark.conf.set("spark.sql.shuffle.partitions", 50)
adaptive_query = employees_df.groupBy("dept_id").agg(
    count("*").alias("emp_count"),
    avg("salary").alias("avg_salary")
)

print("Query execution plan with Adaptive Execution enabled:")
adaptive_query.explain()

Query execution plan with Adaptive Execution enabled:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[dept_id#10], functions=[count(1), avg(salary#11)])
   +- Exchange hashpartitioning(dept_id#10, 50), ENSURE_REQUIREMENTS, [plan_id=1516]
      +- HashAggregate(keys=[dept_id#10], functions=[partial_count(1), partial_avg(salary#11)])
         +- Project [dept_id#10, salary#11]
            +- Scan ExistingRDD[id#8,name#9,dept_id#10,salary#11,hire_date#12]




### 3.6 Data Skew

Data skew occurs when data is unevenly distributed across partitions, causing some tasks to run much longer than others. This can significantly impact performance in both grouping and join operations.

#### What to Look For:
- In the Spark UI: tasks in a stage taking much longer than others
- Uneven partition sizes in the input or after a shuffle
- A large gap between the median and maximum task duration in the stage's summary metrics

Let's create a skewed dataset to demonstrate the issue:

In [23]:
# Create a skewed dataset
from pyspark.sql.functions import when

# Generate skewed data where 80% of records have dept_id = 1
skewed_employee_data = []
for i in range(1, 10001):  # 10,000 employees
    name = f"{random.choice(names)} {random.choice(surnames)}"
    # Create skew: 80% in dept_id 1
    dept_id = 1 if random.random() < 0.8 else random.randint(2, 5)
    salary = round(random.uniform(50000, 150000), 2)
    random_days = random.randint(0, delta)
    hire_date = start_date + timedelta(days=random_days)
    skewed_employee_data.append((i, name, dept_id, salary, hire_date))

skewed_employees_df = spark.createDataFrame(skewed_employee_data, employee_schema)
skewed_employees_df.createOrReplaceTempView("skewed_employees")

# Check the distribution
print("Distribution of employees across departments:")
skewed_employees_df.groupBy("dept_id").count().orderBy("dept_id").show()

Distribution of employees across departments:
+-------+-----+
|dept_id|count|
+-------+-----+
|      1| 8092|
|      2|  499|
|      3|  449|
|      4|  482|
|      5|  478|
+-------+-----+



#### Issues Caused by Data Skew

Let's observe the performance impact of skewed data in group by and join operations:

In [24]:
# GroupBy on skewed column
skewed_agg = skewed_employees_df.groupBy("dept_id").agg(
    count("*").alias("emp_count"),
    avg("salary").alias("avg_salary")
)
time_skewed_group, _ = time_execution(skewed_agg)

# Join on skewed column
skewed_join = skewed_employees_df.join(
    departments_df,
    skewed_employees_df["dept_id"] == departments_df["id"]
)
time_skewed_join, _ = time_execution(skewed_join)

print(f"GroupBy on skewed data execution time: {time_skewed_group:.4f} seconds")
print(f"Join on skewed data execution time: {time_skewed_join:.4f} seconds")

GroupBy on skewed data execution time: 0.1667 seconds
Join on skewed data execution time: 0.3821 seconds


#### Techniques to Handle Data Skew

1. **Salting**: Add a random number to the skewed key to distribute it
2. **Two-phase aggregation**: Local aggregation followed by global aggregation
3. **Separate processing**: Handle the skewed values separately

Let's implement the salting technique to handle skew in joins:

In [25]:
# Technique 1: Salting for skewed joins
from pyspark.sql.functions import monotonically_increasing_id, concat

# Step 1: Identify skewed values
skewed_keys = skewed_employees_df.groupBy("dept_id") \
    .count() \
    .filter(col("count") > 1000) \
    .select("dept_id") \
    .collect()

skewed_keys_list = [row[0] for row in skewed_keys]
print(f"Identified skewed keys: {skewed_keys_list}")

# Step 2: Add salt to skewed keys
salt_factor = 10  # Number of salt values to use

# Add a salt value to skewed records
salted_employees = skewed_employees_df.withColumn(
    "salt", 
    when(col("dept_id").isin(skewed_keys_list), 
         monotonically_increasing_id() % salt_factor)
    .otherwise(0)
)

# Create a salted key for joining
salted_employees = salted_employees.withColumn(
    "salted_dept_id", 
    when(col("salt") > 0, 
         concat(col("dept_id").cast("string"), lit("_"), col("salt").cast("string")))
    .otherwise(col("dept_id").cast("string"))
)

# Step 3: Explode the department table for skewed keys
from pyspark.sql.functions import explode, array, lit

# Generate salted keys for the departments table
exploded_depts = departments_df.withColumn(
    "is_skewed", col("id").isin(skewed_keys_list)
)

# For skewed keys, create multiple copies with salts
salted_depts = exploded_depts.withColumn(
    "salt_values",
    when(col("is_skewed"), 
         array([lit(i) for i in range(salt_factor)]))
    .otherwise(array(lit(0)))
)

# Explode to create multiple rows for skewed keys
salted_depts = salted_depts.withColumn("salt", explode("salt_values")).drop("salt_values")

# Create matching salted key
salted_depts = salted_depts.withColumn(
    "salted_id", 
    when(col("salt") > 0, 
         concat(col("id").cast("string"), lit("_"), col("salt").cast("string")))
    .otherwise(col("id").cast("string"))
)

# Step 4: Join on the salted keys
salted_join = salted_employees.join(
    salted_depts,
    salted_employees["salted_dept_id"] == salted_depts["salted_id"]
)

# Measure performance
time_salted_join, count_salted = time_execution(salted_join)

print(f"Regular join on skewed data: {time_skewed_join:.4f} seconds")
print(f"Salted join: {time_salted_join:.4f} seconds")
if time_skewed_join > time_salted_join:
    print(f"Performance improvement: {(time_skewed_join/time_salted_join):.2f}x faster with salting")

Identified skewed keys: [1]
Regular join on skewed data: 0.3821 seconds
Salted join: 0.3510 seconds
Performance improvement: 1.09x faster with salting
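A salted join returns the right rows only if every salted key on the large side has exactly one partner on the small side. The plain-Python model below follows the same scheme as the cell above: salt 0 keeps the plain key, other salts produce `"<key>_<salt>"`, and the small side is replicated once per salt for skewed keys. It then checks the salted result against an unsalted join. The sample rows and helper names are hypothetical. Salts are assigned round-robin here instead of with `monotonically_increasing_id()`.

```python
from collections import Counter, defaultdict

salt_factor = 10
skewed_keys = {1}

def salted_key(key, salt):
    # Mirrors the notebook: salt 0 keeps the plain key, others get "<key>_<salt>".
    return f"{key}_{salt}" if salt > 0 else str(key)

def hash_join(left, right, left_key, right_key):
    """Inner hash join of two lists of dicts, keyed by the given functions."""
    index = defaultdict(list)
    for right_row in right:
        index[right_key(right_row)].append(right_row)
    return [(left_row, right_row)
            for left_row in left
            for right_row in index.get(left_key(left_row), [])]

# Hypothetical data: dept 1 owns 80 of the 100 employees.
employees = [{"emp": i, "dept_id": 1 if i < 80 else 2} for i in range(100)]
departments = [{"id": 1, "name": "Eng"}, {"id": 2, "name": "Ops"}]

# Large side: rows of skewed keys get salts 0..salt_factor-1 (round-robin here).
salted_emps = [
    dict(e, salt=e["emp"] % salt_factor if e["dept_id"] in skewed_keys else 0)
    for e in employees
]
# Small side: one copy per salt for skewed keys, a single copy otherwise.
salted_depts = [
    dict(d, salt=s)
    for d in departments
    for s in (range(salt_factor) if d["id"] in skewed_keys else [0])
]

plain = hash_join(employees, departments,
                  lambda e: e["dept_id"], lambda d: d["id"])
salted = hash_join(salted_emps, salted_depts,
                   lambda e: salted_key(e["dept_id"], e["salt"]),
                   lambda d: salted_key(d["id"], d["salt"]))

def pairs(joined):
    return sorted((left_row["emp"], right_row["name"]) for left_row, right_row in joined)

print(pairs(plain) == pairs(salted))  # True: salting does not change the result
key_counts = Counter(salted_key(e["dept_id"], e["salt"]) for e in salted_emps)
print(key_counts["1"], key_counts["2"])  # 8 20: the hot key is now split 10 ways
```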


#### Alternative Approach: Handle Skewed Values Separately

Another strategy is to process skewed values separately from the rest:

In [26]:
# Technique 2: Process skewed values separately

# Split the dataset
skewed_records = skewed_employees_df.filter(col("dept_id").isin(skewed_keys_list))
normal_records = skewed_employees_df.filter(~col("dept_id").isin(skewed_keys_list))

print(f"Skewed records: {skewed_records.count()}, Normal records: {normal_records.count()}")

# Process the skewed records with a broadcast join
skewed_result = skewed_records.join(
    broadcast(departments_df),
    skewed_records["dept_id"] == departments_df["id"]
)

# Process normal records with regular join
normal_result = normal_records.join(
    departments_df,
    normal_records["dept_id"] == departments_df["id"]
)

# Combine results
combined_result = skewed_result.union(normal_result)
time_split_process, count_split = time_execution(combined_result)

print(f"Regular join on skewed data: {time_skewed_join:.4f} seconds")
print(f"Split processing: {time_split_process:.4f} seconds")
if time_skewed_join > time_split_process:
    print(f"Performance improvement: {(time_skewed_join/time_split_process):.2f}x faster with split processing")

Skewed records: 8092, Normal records: 1908
Regular join on skewed data: 0.3821 seconds
Split processing: 0.6445 seconds
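In this run, split processing was slower than the plain join (0.6445 s vs 0.3821 s). With only 10,000 rows, the extra filter passes, the second join, and the `union` likely cost more than the skew they avoid. The technique pays off when a hot key is large enough to stall a single reducer. The split itself is easy to verify, as this plain-Python model shows. It flags keys whose count exceeds a multiple of the median key count, which is an assumed alternative to the fixed `> 1000` threshold used earlier. It then joins the two halves separately and checks that their union equals the full join. The data, the `factor` parameter, and the function names are illustrative.

```python
from collections import Counter
from statistics import median

def find_skewed_keys(keys, factor=5.0):
    """Keys whose row count exceeds `factor` times the median count (assumed rule)."""
    counts = Counter(keys)
    threshold = factor * median(counts.values())
    return {k for k, c in counts.items() if c > threshold}

def lookup_join(rows, lookup):
    # Inner join against a small dict, like probing a broadcast hash table.
    return [(emp, dept, lookup[dept]) for emp, dept in rows if dept in lookup]

# Hypothetical data: dept 1 is hot, depts 2-6 are evenly sized.
employees = [(i, 1) for i in range(800)] + [(800 + i, 2 + i % 5) for i in range(200)]
departments = {d: f"dept-{d}" for d in range(1, 7)}

skewed = find_skewed_keys(dept for _, dept in employees)
skewed_part = [r for r in employees if r[1] in skewed]
normal_part = [r for r in employees if r[1] not in skewed]

combined = lookup_join(skewed_part, departments) + lookup_join(normal_part, departments)
print(skewed)                                                           # {1}
print(sorted(combined) == sorted(lookup_join(employees, departments)))  # True
```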


## 4. End-to-End Optimization Example

Let's bring everything together with an end-to-end example. We'll start with a suboptimal query and improve it step by step using the techniques we've discussed.

In [27]:
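# Baseline for this example: Spark's default shuffle partitions and broadcast
# threshold, with AQE turned off so each manual step can be measured on its own.
# (AQE itself is enabled by default since Spark 3.2.)
spark.conf.set("spark.sql.shuffle.partitions", 200)  # Spark's default value
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10485760)  # 10MB, the default

# Create larger datasets for this example
import random
from datetime import datetime, timedelta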

# Create a large fact table (orders)
orders_schema = StructType([
    StructField("order_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("product_id", IntegerType(), False),
    StructField("order_date", DateType(), False),
    StructField("quantity", IntegerType(), False),
    StructField("price", DoubleType(), False)
])

# Generate 100,000 orders
order_data = []
for i in range(1, 100001):
    customer_id = random.randint(1, 1000)
    product_id = random.randint(1, 100)
    random_days = random.randint(0, 365*3)  # Last 3 years
    order_date = (datetime.now() - timedelta(days=random_days)).date()  # DateType column
    quantity = random.randint(1, 10)
    price = round(random.uniform(10, 1000), 2)
    order_data.append((i, customer_id, product_id, order_date, quantity, price))

orders_df = spark.createDataFrame(order_data, orders_schema)
orders_df.createOrReplaceTempView("orders")

# Create dimension tables
# Customers
customer_schema = StructType([
    StructField("customer_id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("email", StringType(), True),
    StructField("city", StringType(), False),
    StructField("state", StringType(), False),
    StructField("signup_date", DateType(), False)
])

# Generate 1,000 customers
customer_data = []
cities = ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix", "Philadelphia", "San Antonio", "San Diego"]
states = ["NY", "CA", "IL", "TX", "AZ", "PA", "TX", "CA"]

for i in range(1, 1001):
    name = f"{random.choice(names)} {random.choice(surnames)}"
    email = f"{name.lower().replace(' ', '.')}@example.com"
    city_idx = random.randint(0, len(cities)-1)
    city = cities[city_idx]
    state = states[city_idx]
    random_days = random.randint(365, 365*5)  # 1-5 years ago
    signup_date = (datetime.now() - timedelta(days=random_days)).date()  # DateType column
    customer_data.append((i, name, email, city, state, signup_date))

customers_df = spark.createDataFrame(customer_data, customer_schema)
customers_df.createOrReplaceTempView("customers")

# Products
product_schema = StructType([
    StructField("product_id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("category", StringType(), False),
    StructField("base_price", DoubleType(), False)
])

# Generate 100 products
product_data = []
categories = ["Electronics", "Clothing", "Home", "Books", "Sports"]
product_names = ["Laptop", "Phone", "Tablet", "TV", "Camera", "Shirt", "Pants", "Shoes", "Jacket", "Sofa", 
                 "Chair", "Table", "Bed", "Novel", "Textbook", "Cookbook", "Basketball", "Tennis Racket", "Bicycle"]

for i in range(1, 101):
    name = random.choice(product_names)
    category = random.choice(categories)
    base_price = round(random.uniform(10, 500), 2)
    product_data.append((i, name, category, base_price))

products_df = spark.createDataFrame(product_data, product_schema)
products_df.createOrReplaceTempView("products")

print(f"Created datasets with {orders_df.count()} orders, {customers_df.count()} customers, and {products_df.count()} products.")

Created datasets with 100000 orders, 1000 customers, and 100 products.
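All of the queries below keep only orders from the last 365 days. Because `random_days` is drawn uniformly from 0 to `365*3` inclusive, the filter keeps 366 of the 1,096 possible offsets. That is about a third of the rows, ignoring time-zone effects on the boundary day. This matters when reading the plans: every join after the filter sees roughly 33,000 orders rather than 100,000. Here is a quick check in plain Python, with an arbitrary simulation seed:

```python
import random
from fractions import Fraction

max_offset = 365 * 3  # random_days is drawn from 0..max_offset inclusive
window = 365          # order_date >= date_sub(current_date(), 365)

# Offsets 0..window pass the filter: (window + 1) of (max_offset + 1) values.
selectivity = Fraction(window + 1, max_offset + 1)
print(selectivity, float(selectivity))  # 183/548 0.3339...

# Cross-check with a seeded simulation of 100,000 orders.
rng = random.Random(42)
kept = len([d for d in (rng.randint(0, max_offset) for _ in range(100_000)) if d <= window])
print(kept / 100_000)
```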


### Initial Suboptimal Query

Let's start with a suboptimal query that computes total sales by state and category for the last year:

In [28]:
# Original inefficient query
def run_original_query():
    query = """
    SELECT 
        c.state, 
        p.category, 
        COUNT(DISTINCT o.order_id) as num_orders,
        SUM(o.quantity * o.price) as total_sales
    FROM 
        orders o
    JOIN 
        customers c ON o.customer_id = c.customer_id
    JOIN 
        products p ON o.product_id = p.product_id
    WHERE 
        o.order_date >= date_sub(current_date(), 365)
    GROUP BY 
        c.state, p.category
    ORDER BY 
        total_sales DESC
    """
    return spark.sql(query)

# Get the execution plan
print("Original Query Plan:")
original_query = run_original_query()
original_query.explain()

# Measure performance
start = time.time()
original_query.collect()
original_time = time.time() - start
print(f"\nOriginal query execution time: {original_time:.4f} seconds")

Original Query Plan:
== Physical Plan ==
*(12) Sort [total_sales#10908 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_sales#10908 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=2770]
   +- *(11) HashAggregate(keys=[state#10860, category#10870], functions=[sum((cast(quantity#10848 as double) * price#10849)), count(distinct order_id#10844)])
      +- Exchange hashpartitioning(state#10860, category#10870, 200), ENSURE_REQUIREMENTS, [plan_id=2766]
         +- *(10) HashAggregate(keys=[state#10860, category#10870], functions=[merge_sum((cast(quantity#10848 as double) * price#10849)), partial_count(distinct order_id#10844)])
            +- *(10) HashAggregate(keys=[state#10860, category#10870, order_id#10844], functions=[merge_sum((cast(quantity#10848 as double) * price#10849))])
               +- Exchange hashpartitioning(state#10860, category#10870, order_id#10844, 200), ENSURE_REQUIREMENTS, [plan_id=2761]
                  +- *(9) HashAggregate(keys=[state#10860, cate

Original query execution time: 2.5330 seconds
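In the plan above, `COUNT(DISTINCT o.order_id)` becomes a two-step aggregation. Spark first groups by `(state, category, order_id)` to remove duplicates, which costs the extra `Exchange hashpartitioning(state, category, order_id, 200)`. Only then does it count per `(state, category)`. A plain count skips that step, but it gives the same answer only when no `(state, category, order_id)` combination repeats. The plain-Python model below, with made-up rows, shows both cases:

```python
from collections import defaultdict

def count_distinct_two_step(rows):
    """Model of a distinct aggregation: dedupe on (group, id), then count per group."""
    deduped = {(group, order_id) for group, order_id in rows}  # the extra shuffle key
    counts = defaultdict(int)
    for group, _ in deduped:
        counts[group] += 1
    return dict(counts)

def plain_count(rows):
    """Model of count(order_id): one pass, no deduplication."""
    counts = defaultdict(int)
    for group, _ in rows:
        counts[group] += 1
    return dict(counts)

unique_ids = [(("CA", "Books"), 1), (("CA", "Books"), 2), (("TX", "Home"), 3)]
print(count_distinct_two_step(unique_ids) == plain_count(unique_ids))  # True

# If an order_id could repeat (e.g. one row per order line), the rewrite is wrong.
repeated_ids = unique_ids + [(("CA", "Books"), 1)]
print(count_distinct_two_step(repeated_ids) == plain_count(repeated_ids))  # False
```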


### Optimization Step 1: Filter Pushdown and Early Projections

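Our first rewrite makes three changes. It filters `orders` by date before the joins, keeps only the columns each table contributes, and replaces `COUNT(DISTINCT o.order_id)` with a plain `count("order_id")`. That last substitution is safe only because `order_id` is unique. Catalyst already pushes the filter down and prunes columns for the SQL version, so much of this step's gain likely comes from dropping the distinct aggregation and the extra shuffle it needs: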

In [29]:
# Step 1: Filter pushdown and early projections
def run_optimized_query_step1():
    # First filter the orders
    filtered_orders = orders_df \
        .filter(col("order_date") >= expr("date_sub(current_date(), 365)")) \
        .select("order_id", "customer_id", "product_id", "quantity", "price")
    
    # Use only necessary columns from dimension tables
    customers_slim = customers_df.select("customer_id", "state")
    products_slim = products_df.select("product_id", "category")
    
    # Join and aggregate
    result = filtered_orders \
        .join(customers_slim, "customer_id") \
        .join(products_slim, "product_id") \
        .groupBy("state", "category") \
        .agg( \
            count("order_id").alias("num_orders"), \
            sum(col("quantity") * col("price")).alias("total_sales") \
        ) \
        .orderBy(col("total_sales").desc())
    
    return result

# Get the execution plan
print("Step 1 Query Plan (Filter Pushdown and Column Pruning):")
step1_query = run_optimized_query_step1()
step1_query.explain()

# Measure performance
start = time.time()
step1_query.collect()
step1_time = time.time() - start
print(f"\nStep 1 query execution time: {step1_time:.4f} seconds")
print(f"Improvement from original: {(original_time/step1_time):.2f}x faster")

Step 1 Query Plan (Filter Pushdown and Column Pruning):
== Physical Plan ==
*(11) Sort [total_sales#10963 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_sales#10963 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=2954]
   +- *(10) HashAggregate(keys=[state#10860, category#10870], functions=[count(1), sum((cast(quantity#10848 as double) * price#10849))])
      +- Exchange hashpartitioning(state#10860, category#10870, 200), ENSURE_REQUIREMENTS, [plan_id=2950]
         +- *(9) HashAggregate(keys=[state#10860, category#10870], functions=[partial_count(1), partial_sum((cast(quantity#10848 as double) * price#10849))])
            +- *(9) Project [quantity#10848, price#10849, state#10860, category#10870]
               +- *(9) SortMergeJoin [product_id#10846], [product_id#10868], Inner
                  :- *(6) Sort [product_id#10846 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(product_id#10846, 200), ENSURE_REQUIREMENTS, [plan_id=2935]
    

### Optimization Step 2: Broadcast Small Tables

Next, let's use broadcast joins for the dimension tables:

In [30]:
# Step 2: Add broadcast joins
def run_optimized_query_step2():
    # First filter the orders
    filtered_orders = orders_df \
        .filter(col("order_date") >= expr("date_sub(current_date(), 365)")) \
        .select("order_id", "customer_id", "product_id", "quantity", "price")
    
    # Use only necessary columns from dimension tables
    customers_slim = customers_df.select("customer_id", "state")
    products_slim = products_df.select("product_id", "category")
    
    # Broadcast the dimension tables for more efficient joins
    result = filtered_orders \
        .join(broadcast(customers_slim), "customer_id") \
        .join(broadcast(products_slim), "product_id") \
        .groupBy("state", "category") \
        .agg( \
            count("order_id").alias("num_orders"), \
            sum(col("quantity") * col("price")).alias("total_sales") \
        ) \
        .orderBy(col("total_sales").desc())
    
    return result

# Get the execution plan
print("Step 2 Query Plan (With Broadcast Joins):")
step2_query = run_optimized_query_step2()
step2_query.explain()

# Measure performance
start = time.time()
step2_query.collect()
step2_time = time.time() - start
print(f"\nStep 2 query execution time: {step2_time:.4f} seconds")
print(f"Improvement from original: {(original_time/step2_time):.2f}x faster")
print(f"Improvement from step 1: {(step1_time/step2_time):.2f}x faster")

Step 2 Query Plan (With Broadcast Joins):
== Physical Plan ==
*(5) Sort [total_sales#11007 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_sales#11007 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=3104]
   +- *(4) HashAggregate(keys=[state#10860, category#10870], functions=[count(1), sum((cast(quantity#10848 as double) * price#10849))])
      +- Exchange hashpartitioning(state#10860, category#10870, 200), ENSURE_REQUIREMENTS, [plan_id=3100]
         +- *(3) HashAggregate(keys=[state#10860, category#10870], functions=[partial_count(1), partial_sum((cast(quantity#10848 as double) * price#10849))])
            +- *(3) Project [quantity#10848, price#10849, state#10860, category#10870]
               +- *(3) BroadcastHashJoin [product_id#10846], [product_id#10868], Inner, BuildRight, false
                  :- *(3) Project [product_id#10846, quantity#10848, price#10849, state#10860]
                  :  +- *(3) BroadcastHashJoin [customer_id#10845], [customer_id#10856],
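The `BroadcastHashJoin` nodes above build a hash table from the small side once and ship a copy to every executor. Each task then probes that table against its own rows, so `orders` is not reshuffled for the joins. The `Exchange hashpartitioning(product_id...)` under the step 1 `SortMergeJoin` is gone. The plain-Python model below shows the probe, followed by a back-of-the-envelope count of rows moved. The cost model is a deliberate simplification with assumed figures: each shuffle sends every row once, a broadcast sends the whole small table once per executor, about a third of the 100,000 orders survive the date filter, and the 4-executor cluster is hypothetical.

```python
def broadcast_hash_join(fact_partitions, dim_rows, fact_key, dim_key):
    """Plain-Python model of a broadcast hash join (assumes unique dim keys)."""
    hash_table = {dim_key(d): d for d in dim_rows}  # built once from the small side
    joined = []
    for partition in fact_partitions:  # each task probes locally; fact rows stay put
        for row in partition:
            match = hash_table.get(fact_key(row))
            if match is not None:
                joined.append((row, match))
    return joined

def rows_moved(fact_rows, dim_sizes, executors):
    """Rough rows sent over the network for a fact table joined to several dims."""
    total_dim = 0
    for size in dim_sizes:
        total_dim += size
    # Sort-merge joins: fact rows are reshuffled once per join, plus every dim row.
    sort_merge_cost = fact_rows * len(dim_sizes) + total_dim
    # Broadcast joins: only dim rows move, one copy per executor.
    broadcast_cost = total_dim * executors
    return sort_merge_cost, broadcast_cost

# Two partitions of (order_id, customer_id) and a small (customer_id, state) table.
order_partitions = [[(1, 10), (2, 20)], [(3, 10)]]
customer_rows = [(10, "CA"), (20, "TX")]
print(broadcast_hash_join(order_partitions, customer_rows, lambda o: o[1], lambda c: c[0]))

# Assumed figures: ~33,400 filtered orders, 1,000 customers, 100 products, 4 executors.
print(rows_moved(33_400, [1_000, 100], executors=4))  # (67900, 4400)
```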

### Optimization Step 3: Tune Shuffle Partitions

Now, let's adjust the number of shuffle partitions to better match our dataset size:

In [31]:
# Step 3: Tune shuffle partitions
# Default is 200, which is likely too high for our relatively small dataset
spark.conf.set("spark.sql.shuffle.partitions", 20)  # Adjusted value

def run_optimized_query_step3():
    # Same query as step 2, but with adjusted shuffle partitions
    filtered_orders = orders_df \
        .filter(col("order_date") >= expr("date_sub(current_date(), 365)")) \
        .select("order_id", "customer_id", "product_id", "quantity", "price")
    
    customers_slim = customers_df.select("customer_id", "state")
    products_slim = products_df.select("product_id", "category")
    
    result = filtered_orders \
        .join(broadcast(customers_slim), "customer_id") \
        .join(broadcast(products_slim), "product_id") \
        .groupBy("state", "category") \
        .agg( \
            count("order_id").alias("num_orders"), \
            sum(col("quantity") * col("price")).alias("total_sales") \
        ) \
        .orderBy(col("total_sales").desc())
    
    return result

# Get the execution plan
print("Step 3 Query Plan (With Adjusted Shuffle Partitions):")
step3_query = run_optimized_query_step3()
step3_query.explain()

# Measure performance
start = time.time()
step3_query.collect()
step3_time = time.time() - start
print(f"\nStep 3 query execution time: {step3_time:.4f} seconds")
print(f"Improvement from original: {(original_time/step3_time):.2f}x faster")
print(f"Improvement from step 2: {(step2_time/step3_time):.2f}x faster")

Step 3 Query Plan (With Adjusted Shuffle Partitions):
== Physical Plan ==
*(5) Sort [total_sales#11051 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_sales#11051 DESC NULLS LAST, 20), ENSURE_REQUIREMENTS, [plan_id=3228]
   +- *(4) HashAggregate(keys=[state#10860, category#10870], functions=[count(1), sum((cast(quantity#10848 as double) * price#10849))])
      +- Exchange hashpartitioning(state#10860, category#10870, 20), ENSURE_REQUIREMENTS, [plan_id=3224]
         +- *(3) HashAggregate(keys=[state#10860, category#10870], functions=[partial_count(1), partial_sum((cast(quantity#10848 as double) * price#10849))])
            +- *(3) Project [quantity#10848, price#10849, state#10860, category#10870]
               +- *(3) BroadcastHashJoin [product_id#10846], [product_id#10868], Inner, BuildRight, false
                  :- *(3) Project [product_id#10846, quantity#10848, price#10849, state#10860]
                  :  +- *(3) BroadcastHashJoin [customer_id#10845], [customer_
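Why drop to 20? The final aggregation has at most 30 groups (6 distinct states × 5 categories in the generated data), and each group goes to exactly one shuffle partition. The model below hashes those 30 keys into `n` buckets and counts how many receive any data. It uses `zlib.crc32` as a stand-in for Spark's Murmur3-based hash partitioning, so the exact bucket counts differ from Spark's. The conclusion does not change: with 200 partitions, at least 170 reduce tasks have nothing to aggregate.

```python
import zlib
from itertools import product

states = ["NY", "CA", "IL", "TX", "AZ", "PA"]
categories = ["Electronics", "Clothing", "Home", "Books", "Sports"]
group_keys = [f"{s}|{c}" for s, c in product(states, categories)]  # 30 groups

def occupied_partitions(keys, num_partitions):
    """Count partitions that receive at least one key (crc32 as a stand-in hash)."""
    return len({zlib.crc32(k.encode()) % num_partitions for k in keys})

for n in (200, 20):
    used = occupied_partitions(group_keys, n)
    print(f"{n} partitions: {used} non-empty, {n - used} empty")
```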

### Optimization Step 4: Enable Adaptive Query Execution

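Finally, let's enable Adaptive Query Execution (AQE). This step also removes the explicit `broadcast()` hints, to see whether AQE picks the join strategy on its own. The initial plan still shows `SortMergeJoin` (`isFinalPlan=false`). AQE can convert it to a broadcast hash join at runtime once it sees how small the shuffled dimension tables actually are: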

In [32]:
# Step 4: Enable Adaptive Query Execution
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

def run_optimized_query_step4():
    # Same as step 3 but without broadcast() hints; with AQE on, Spark picks the join strategy at runtime
    filtered_orders = orders_df \
        .filter(col("order_date") >= expr("date_sub(current_date(), 365)")) \
        .select("order_id", "customer_id", "product_id", "quantity", "price")
    
    customers_slim = customers_df.select("customer_id", "state")
    products_slim = products_df.select("product_id", "category")
    
    result = filtered_orders \
        .join(customers_slim, "customer_id") \
        .join(products_slim, "product_id") \
        .groupBy("state", "category") \
        .agg( \
            count("order_id").alias("num_orders"), \
            sum(col("quantity") * col("price")).alias("total_sales") \
        ) \
        .orderBy(col("total_sales").desc())
    
    return result

# Get the execution plan
print("Step 4 Query Plan (With Adaptive Query Execution):")
step4_query = run_optimized_query_step4()
step4_query.explain()

# Measure performance
start = time.time()
step4_query.collect()
step4_time = time.time() - start
print(f"\nStep 4 query execution time: {step4_time:.4f} seconds")
print(f"Improvement from original: {(original_time/step4_time):.2f}x faster")
print(f"Improvement from step 3: {(step3_time/step4_time):.2f}x faster")

Step 4 Query Plan (With Adaptive Query Execution):
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [total_sales#11095 DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(total_sales#11095 DESC NULLS LAST, 20), ENSURE_REQUIREMENTS, [plan_id=3343]
      +- HashAggregate(keys=[state#10860, category#10870], functions=[count(1), sum((cast(quantity#10848 as double) * price#10849))])
         +- Exchange hashpartitioning(state#10860, category#10870, 20), ENSURE_REQUIREMENTS, [plan_id=3340]
            +- HashAggregate(keys=[state#10860, category#10870], functions=[partial_count(1), partial_sum((cast(quantity#10848 as double) * price#10849))])
               +- Project [quantity#10848, price#10849, state#10860, category#10870]
                  +- SortMergeJoin [product_id#10846], [product_id#10868], Inner
                     :- Sort [product_id#10846 ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(product_id#10846, 20), ENSURE_REQUIREMEN
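Once AQE is on, `spark.sql.adaptive.coalescePartitions.enabled` lets Spark inspect the real shuffle output sizes after a stage finishes. It then merges adjacent small partitions, aiming for roughly `spark.sql.adaptive.advisoryPartitionSizeInBytes` per task. The greedy pass below is a simplified plain-Python model of that idea, not Spark's exact algorithm, which has further settings and edge cases. The 20 per-partition sizes are made up to match the 20 shuffle partitions in the plan above.

```python
def coalesce_partitions(sizes, target):
    """Greedily merge adjacent partitions until adding the next would exceed target.

    Returns a list of (start, end) index ranges with end exclusive. A single
    partition larger than target stays on its own. Simplified model only.
    """
    groups = []
    start, running = 0, 0
    for i, size in enumerate(sizes):
        if i > start and running + size > target:
            groups.append((start, i))
            start, running = i, 0
        running += size
    if sizes:
        groups.append((start, len(sizes)))
    return groups

# Hypothetical shuffle output (KB per partition): mostly tiny, several empty.
sizes_kb = [5, 0, 12, 3, 0, 40, 8, 0, 1, 30, 2, 0, 0, 6, 9, 20, 4, 0, 7, 11]
groups = coalesce_partitions(sizes_kb, target=64)
print(len(sizes_kb), "->", len(groups), "tasks:", groups)
```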

### Comparing All Optimizations

Let's summarize the improvements from each optimization step:

In [33]:
# Summary of performance improvements
print("Performance Summary:")
print(f"Original query: {original_time:.4f} seconds")
print(f"Step 1 (Filter Pushdown & Column Pruning): {step1_time:.4f} seconds, {(original_time/step1_time):.2f}x faster")
print(f"Step 2 (Broadcast Joins): {step2_time:.4f} seconds, {(original_time/step2_time):.2f}x faster")
print(f"Step 3 (Tuned Shuffle Partitions): {step3_time:.4f} seconds, {(original_time/step3_time):.2f}x faster")
print(f"Step 4 (Adaptive Query Execution): {step4_time:.4f} seconds, {(original_time/step4_time):.2f}x faster")

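# Restore the baseline used at the start of this example
# (not Spark's own defaults: AQE is enabled by default since Spark 3.2)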
spark.conf.set("spark.sql.shuffle.partitions", 200)  # Default value
spark.conf.set("spark.sql.adaptive.enabled", "false")

Performance Summary:
Original query: 2.5330 seconds
Step 1 (Filter Pushdown & Column Pruning): 1.3588 seconds, 1.86x faster
Step 2 (Broadcast Joins): 0.6238 seconds, 4.06x faster
Step 3 (Tuned Shuffle Partitions): 0.3980 seconds, 6.36x faster
Step 4 (Adaptive Query Execution): 0.4916 seconds, 5.15x faster
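Each figure above comes from a single `collect()`, so warm-up effects and scheduling noise land directly in the numbers. That may account for part of why step 4 came out slower than step 3, alongside the removed broadcast hints. A sturdier comparison discards a warm-up run and reports the median of several runs. Below is a minimal helper for that. `median_runtime` is a hypothetical name; in the notebook you would pass something like `lambda: step3_query.collect()`:

```python
import statistics
import time

def median_runtime(fn, repeats=5, warmup=1):
    """Run fn `warmup` times untimed, then `repeats` times; return median seconds."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

# Usage with a stand-in workload instead of a Spark action.
t = median_runtime(lambda: sorted(range(100_000), reverse=True), repeats=5)
print(f"median: {t:.6f} s")
```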


## 5. Best Practices and Guidelines

Let's summarize the key best practices for optimizing Spark performance:

### Optimization by Area

#### 1. Data Reading and Filtering
- **Select only necessary columns**: Reduces I/O and memory usage
- **Apply filters as early as possible**: Reduces data volume for downstream operations
- **Use appropriate file formats**: Parquet/ORC > CSV/JSON for analytical workloads
- **Leverage partition pruning**: Choose good partition columns for large tables
- **Compute and maintain statistics**: For better query planning

#### 2. Joins
- **Optimize join order**: Join the most filtered/smallest tables first
- **Use broadcast joins for small tables**: Avoid shuffling large tables
- **Handle skew in joins**: Use salting or separate processing for skewed keys
- **Pre-partition data on join keys**: To reduce shuffling
- **Filter before joining**: Reduces the size of the join operation

#### 3. Aggregations
- **Tune shuffle partitions**: Based on data size and cluster capacity
- **Use window functions efficiently**: For complex analytical queries
- **Pre-aggregate data when possible**: Reduces data volume for global aggregations
- **Handle skew in groupBy**: Similar to join skew handling

#### 4. General Optimizations
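- **Use built-in functions over UDFs**: Built-ins take part in Catalyst optimization and whole-stage codegen, while Python UDFs are opaque to the optimizer and pay serialization costs between the JVM and Python workers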
- **Enable Adaptive Query Execution**: For dynamic optimizations
- **Cache judiciously**: Only for frequently accessed datasets
- **Monitor and analyze query plans**: Identify bottlenecks
- **Set appropriate configurations**: Memory, cores, shuffle partitions, etc.

### Key Configuration Parameters

Here are some important configuration parameters to consider:

In [34]:
# Show current configuration values
print("Current configuration:")
config_params = [
    "spark.sql.shuffle.partitions",
    "spark.sql.autoBroadcastJoinThreshold",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.sql.files.maxPartitionBytes",
    "spark.default.parallelism"
]

for param in config_params:
    try:
        value = spark.conf.get(param)
        print(f"{param}: {value}")
    except Exception:  # unset keys without a default raise an error
        print(f"{param}: Not set")

Current configuration:
spark.sql.shuffle.partitions: 200
spark.sql.autoBroadcastJoinThreshold: 10485760
spark.sql.adaptive.enabled: false
spark.sql.adaptive.coalescePartitions.enabled: true
spark.driver.memory: Not set
spark.executor.memory: Not set
spark.sql.files.maxPartitionBytes: 134217728b
spark.default.parallelism: Not set
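Size-valued settings come back as strings with a unit suffix, such as `134217728b` for `spark.sql.files.maxPartitionBytes`. By contrast, `spark.sql.autoBroadcastJoinThreshold` was set as a plain byte count. A small helper can normalize both to bytes before you compare them. `to_bytes` is hypothetical, not a Spark API. It accepts only integer values with the binary suffixes `b`, `k`, `m`, `g`, `t` (optionally followed by `b`), and it treats a bare number as bytes, which matches how these two settings are read.

```python
import re

_UNITS = {"b": 1, "k": 1024, "m": 1024**2, "g": 1024**3, "t": 1024**4}

def to_bytes(value):
    """Parse a byte size like '134217728b', '10m', '10MB' or 10485760 (hypothetical helper)."""
    match = re.fullmatch(r"\s*(\d+)\s*([bkmgt]?)b?\s*", str(value).lower())
    if match is None:
        raise ValueError(f"unrecognized size: {value!r}")
    number, unit = match.groups()
    return int(number) * _UNITS[unit or "b"]

max_partition = to_bytes("134217728b")
broadcast_threshold = to_bytes(10485760)
print(max_partition // 1024**2, "MB")        # 128 MB
print(broadcast_threshold // 1024**2, "MB")  # 10 MB
```

In the notebook, `to_bytes(spark.conf.get("spark.sql.files.maxPartitionBytes"))` would apply it to a live setting.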


### Recommendations for Configuration Tuning

1. **spark.sql.shuffle.partitions**:
   - Default: 200
   - Start with 2-3 times the number of cores for small-medium datasets
   - For large datasets (TB+), increase based on data size

2. **spark.sql.autoBroadcastJoinThreshold**:
   - Default: 10MB
   - Can be raised (e.g., to 50-100MB) when memory allows, since the broadcast table is collected on the driver and a copy is held by every executor
   - Set to -1 to disable automatic broadcasting

3. **spark.sql.adaptive.enabled**:
   - Default: true since Spark 3.2 (false in 3.0 and 3.1)
   - Recommended: true for most workloads

4. **spark.driver.memory** and **spark.executor.memory**:
   - Depends on your cluster resources
   - Avoid OOM errors with appropriate sizing

5. **spark.sql.files.maxPartitionBytes**:
   - Default: 128MB
   - Adjust based on file size and memory available
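The shuffle-partition guidance above (2-3× the core count for small and medium data, more for large data) can be turned into a starting-point calculation. The sketch below is a rule of thumb, not a Spark feature. It takes the larger of `cores × multiplier` and the number of partitions needed to keep each one near a target size. The 128MB target is an assumption borrowed from the `spark.sql.files.maxPartitionBytes` default, and the input sizes are illustrative.

```python
import math

def suggest_shuffle_partitions(shuffle_bytes, total_cores, multiplier=3,
                               target_partition_bytes=128 * 1024**2):
    """Rule-of-thumb starting value for spark.sql.shuffle.partitions (assumed heuristic)."""
    by_cores = total_cores * multiplier                          # keep every core busy
    by_size = math.ceil(shuffle_bytes / target_partition_bytes)  # cap per-task data
    return by_cores if by_cores >= by_size else by_size

print(suggest_shuffle_partitions(50 * 1024**2, total_cores=8))  # 24: cores dominate
print(suggest_shuffle_partitions(2 * 1024**4, total_cores=64))  # 16384: size dominates
```

Treat the result as a first guess to validate in the Spark UI. With AQE enabled, partition coalescing can also trim an overly high value at runtime.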

### When to Use Each Optimization Technique

| Technique               | When to Use                                 | Example Scenario                             |
|-------------------------|---------------------------------------------|----------------------------------------------|
| Filter Pushdown         | Large datasets with filtering conditions    | Querying subset of time-series data          |
| Broadcast Joins         | Joining with small dimension tables         | Fact table with dimension lookups            |
| Repartitioning          | Before joins on non-uniform data            | Pre-shuffle data on join key                 |
| Caching                 | Repeatedly accessed intermediate results    | Iterative algorithms, multiple queries       |
| Salting                 | Highly skewed join or groupBy keys          | User activity data with popular items        |
| Shuffle Partition Tuning| Performance tuning for specific data size   | Adjust based on cluster and dataset size     |
| Adaptive Execution      | Complex queries with unpredictable stats    | Ad-hoc analytical queries                    |

### Closing Tips

1. **Understand your data**: Distribution, size, access patterns
2. **Analyze query plans**: Identify bottlenecks early
3. **Test incrementally**: Apply one optimization at a time
4. **Monitor Spark UI**: For execution details and skew detection
5. **Prefer simple optimizations first**: Often the biggest gains come from basic techniques

By carefully analyzing execution plans and applying these optimization techniques, you can significantly improve the performance of your Spark queries.

In [35]:
# Clean up
spark.catalog.clearCache()
print("Resources cleaned up. Notebook complete!")

Resources cleaned up. Notebook complete!

