# Advanced Spark Query Plan Analysis and Optimization

This notebook provides an in-depth exploration of Spark's query plans and advanced optimization techniques, building on top of basic optimizations. We'll analyze logical and physical plans in detail to identify performance bottlenecks and apply targeted optimizations.

## Table of Contents

1. Introduction to Advanced Query Plan Analysis
2. Setting Up the Environment
3. Deep Dive into Logical and Physical Plans
   - Analyzing Scan Operations
   - Understanding Filter and Projection Pushdown
   - Detailed Join Strategy Analysis
   - Optimizing Shuffle Operations
   - Partition Tuning for Performance
   - Broadcast Operations and Memory Management
   - Aggregation Optimization Techniques
   - Spill to Disk Detection and Prevention
   - Skew Detection and Handling
   - Whole-Stage Codegen Analysis
4. Adaptive Query Execution Deep Dive
5. Cost-Based Optimization in Spark
6. Advanced Performance Tuning Configurations
7. Plan Metrics Analysis and Optimization
8. Real-world Examples and Solutions

Let's start by understanding what makes query plan analysis essential for advanced Spark optimization.

## 1. Introduction to Advanced Query Plan Analysis

While basic optimization strategies can significantly improve Spark performance, advanced optimization requires deep understanding of how Spark transforms queries into executable code. By analyzing both logical and physical plans, we can identify inefficiencies that aren't obvious from the DataFrame API or SQL interface.

Key components of advanced plan analysis include:

- **Logical Plan Transformation Rules**: Understanding how Catalyst applies rules to optimize logical plans
- **Physical Plan Selection Criteria**: How Spark decides which physical strategies to use
- **Cost Model Analysis**: How statistics influence plan decisions
- **Runtime Adaptations**: How plans are modified during execution
- **Performance Bottleneck Identification**: Finding specific operations that limit performance
- **Resource Usage Patterns**: Understanding memory, CPU, and I/O patterns

The `explain()` method with extended modes and additional utilities in the `ExplainCommand` provide the tools needed for this deep analysis.

## 2. Setting Up the Environment

Let's set up our environment and create test datasets for our advanced analysis:

In [1]:
!pip install matplotlib pandas numpy

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, avg, sum, max, min, lit, concat, broadcast, expr, 
    when, coalesce, array, explode, struct, to_json, from_json, 
    window, row_number, rank, dense_rank, ntile, lead, lag,
    udf, pandas_udf, PandasUDFType
)
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, 
    DateType, TimestampType, ArrayType, MapType, BooleanType
)
import time
import random
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import matplotlib.pyplot as plt

# Initialize Spark Session with detailed configurations
spark = SparkSession.builder \
    .appName("Advanced Spark Optimization") \
    .config("spark.sql.shuffle.partitions", 10) \
    .config("spark.executor.memory", "2g") \
    .config("spark.driver.memory", "4g") \
    .config("spark.sql.autoBroadcastJoinThreshold", "10m") \
    .config("spark.sql.adaptive.enabled", "false") \
    .getOrCreate()

# Enable us to view full query plans
spark.conf.set("spark.sql.adaptive.enabled", "false")  # Initially disabled for clearer analysis
spark.conf.set("spark.sql.codegen.wholeStage", "true")  # Enable whole-stage codegen

print(f"Spark version: {spark.version}")
print("Session initialized with detailed configuration!")

/opt/spark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/18 10:11:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Spark version: 3.5.1
Session initialized with detailed configuration!


In [3]:
# Helper function to time execution and get metrics
def time_execution_with_metrics(df, action="count", name="query"):
    """Time and collect metrics for a DataFrame action"""
    start = time.time()
    
    if action == "count":
        result = df.count()
    elif action == "collect":
        result = df.collect()
    elif action == "show":
        result = df.show(n=10, truncate=False)
    else:
        raise ValueError(f"Unsupported action: {action}")
        
    execution_time = time.time() - start
    
    print(f"{name} execution time: {execution_time:.4f} seconds")
    return execution_time, result

# Create a function to show both logical and physical plans with different detail levels
def analyze_plans(df, name="Query"):
    """Show logical and physical plans at different detail levels"""
    print(f"\n{'='*20} {name} Plan Analysis {'='*20}\n")
    
    print("--- Logical Plan (Parsed) ---")
    df._jdf.queryExecution().logical().toJSON()
    df._jdf.queryExecution().logical().toString()
    
    print("\n--- Logical Plan (Analyzed) ---")
    df._jdf.queryExecution().analyzed().toString()
    
    print("\n--- Logical Plan (Optimized) ---")
    df._jdf.queryExecution().optimizedPlan().toString()
    
    print("\n--- Physical Plan ---")
    df.explain()
    
    print("\n--- Detailed Physical Plan ---")
    df.explain(mode="extended")
    
    print("\n--- Cost Analysis ---")
    df.explain(mode="cost")
    
    print("\n--- Codegen Details ---")
    df.explain(mode="codegen")
    
    print(f"\n{'='*50}\n")
    
    return None

Now let's create a complex dataset with various data characteristics that will allow us to explore advanced optimization techniques:

In [4]:
# Create a larger, more complex dataset with multiple tables
# 1. Fact table (transactions) with millions of rows
# 2. Multiple dimension tables with different sizes
# 3. Complex data types (arrays, maps, structs)
# 4. Skewed data distributions
# 5. Partitioned and bucketed tables

# First, let's create dimension tables

# Products dimension
product_schema = StructType([
    StructField("product_id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("category_id", IntegerType(), False),
    StructField("subcategory_id", IntegerType(), True),
    StructField("price", DoubleType(), False),
    StructField("attributes", MapType(StringType(), StringType()), True),
    StructField("tags", ArrayType(StringType()), True),
    StructField("created_at", TimestampType(), False)
])

# Generate 5,000 products
product_data = []
product_names = ["Laptop", "Phone", "Tablet", "TV", "Camera", "Headphones", "Speaker", "Watch", "Keyboard", "Mouse"]
product_prefixes = ["Pro", "Elite", "Ultra", "Mini", "Max", "Lite", "Premium", "Basic", "Standard", "Advanced"]
product_suffixes = ["2023", "Plus", "XL", "SE", "X", "S", "Air", "Book", "Pad", "Vision"]

categories = 10
subcategories = 5  # per category

# Tag options
all_tags = ["new", "bestseller", "sale", "clearance", "limited", "exclusive", "featured", "discontinued", "popular", "trending"]

# Attribute options
attribute_keys = ["color", "size", "weight", "material", "connectivity", "power", "warranty", "origin"]
attribute_values = {
    "color": ["black", "white", "silver", "gold", "blue", "red", "green"],
    "size": ["small", "medium", "large", "XL", "compact"],
    "weight": ["light", "medium", "heavy", "ultra-light"],
    "material": ["plastic", "metal", "glass", "aluminum", "carbon fiber"],
    "connectivity": ["bluetooth", "wifi", "wired", "usb-c", "lightning"],
    "power": ["battery", "plug-in", "solar", "hybrid"],
    "warranty": ["1-year", "2-year", "lifetime", "extended"],
    "origin": ["USA", "China", "Japan", "Korea", "Germany"]
}

start_date = datetime(2020, 1, 1)
end_date = datetime.now()
date_range = (end_date - start_date).days

for i in range(1, 5001):
    # Create skewed distribution for categories (more products in lower categories)
    category_id = int(random.paretovariate(1.5))
    if category_id > categories:
        category_id = categories
    subcategory_id = random.randint(1, subcategories) if random.random() > 0.1 else None
    
    # Generate product name
    name = f"{random.choice(product_prefixes)} {random.choice(product_names)} {random.choice(product_suffixes)}"
    
    # Generate price with some skew
    if category_id <= 3:  # Higher priced categories
        price = round(random.uniform(500, 2000), 2)
    else:
        price = round(random.uniform(10, 500), 2)
    
    # Generate random tags (0-5 tags)
    num_tags = random.randint(0, 5)
    tags = random.sample(all_tags, num_tags) if num_tags > 0 else None
    
    # Generate random attributes (0-5 attributes)
    num_attrs = random.randint(0, 5)
    if num_attrs > 0:
        selected_keys = random.sample(attribute_keys, num_attrs)
        attributes = {k: random.choice(attribute_values[k]) for k in selected_keys}
    else:
        attributes = None
    
    # Generate creation date
    random_days = random.randint(0, date_range)
    created_at = start_date + timedelta(days=random_days)
    
    product_data.append((i, name, category_id, subcategory_id, price, attributes, tags, created_at))

products_df = spark.createDataFrame(product_data, product_schema)
products_df.createOrReplaceTempView("products")

# Save as a permanent table for later use
products_df.write.mode("overwrite").saveAsTable("products_table")

print(f"Created products dimension table with {products_df.count()} rows")

                                                                                

Created products dimension table with 5000 rows


In [5]:
# Create customers dimension table
customer_schema = StructType([
    StructField("customer_id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("email", StringType(), True),
    StructField("address", StructType([
        StructField("street", StringType(), True),
        StructField("city", StringType(), False),
        StructField("state", StringType(), False),
        StructField("zip", StringType(), True),
        StructField("country", StringType(), False)
    ]), True),
    StructField("phone", StringType(), True),
    StructField("signup_date", DateType(), False),
    StructField("last_activity", TimestampType(), True),
    StructField("tier", StringType(), False),
    StructField("preferences", MapType(StringType(), StringType()), True)
])

# Generate 50,000 customers
customer_data = []
first_names = ["James", "Mary", "John", "Patricia", "Robert", "Jennifer", "Michael", "Linda", "William", "Elizabeth"]
last_names = ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Davis", "Garcia", "Rodriguez", "Wilson"]
cities = ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix", "Philadelphia", "San Antonio", "San Diego", "Dallas", "San Jose"]
states = ["NY", "CA", "IL", "TX", "AZ", "PA", "TX", "CA", "TX", "CA"]
countries = ["USA" for _ in range(10)]
tiers = ["Bronze", "Silver", "Gold", "Platinum"]
tier_weights = [0.5, 0.3, 0.15, 0.05]  # Distribution weights

preference_keys = ["communication", "payment", "shipping", "recommendations", "notifications"]
preference_values = {
    "communication": ["email", "sms", "phone", "mail"],
    "payment": ["credit", "debit", "paypal", "crypto", "bank_transfer"],
    "shipping": ["standard", "express", "overnight", "pickup"],
    "recommendations": ["enabled", "disabled"],
    "notifications": ["high", "medium", "low", "none"]
}

for i in range(1, 50001):
    # Name and email
    first = random.choice(first_names)
    last = random.choice(last_names)
    name = f"{first} {last}"
    email = f"{first.lower()}.{last.lower()}@example.com" if random.random() > 0.1 else None
    
    # Address
    if random.random() > 0.05:  # 5% have no address
        street = f"{random.randint(100, 9999)} Main St" if random.random() > 0.1 else None
        city_idx = random.randint(0, 9)
        city = cities[city_idx]
        state = states[city_idx]
        zip_code = f"{random.randint(10000, 99999)}" if random.random() > 0.1 else None
        country = countries[city_idx]
        address = (street, city, state, zip_code, country)
    else:
        address = None
    
    # Phone
    phone = f"{random.randint(100, 999)}-{random.randint(100, 999)}-{random.randint(1000, 9999)}" if random.random() > 0.2 else None
    
    # Dates
    signup_days = random.randint(0, date_range)
    signup_date = start_date.date() + timedelta(days=signup_days)
    
    if random.random() > 0.1:  # 10% have no activity
        activity_days = random.randint(0, date_range - signup_days)  # Activity after signup
        last_activity = start_date + timedelta(days=signup_days + activity_days)
    else:
        last_activity = None
    
    # Tier (weighted selection)
    tier = random.choices(tiers, weights=tier_weights)[0]
    
    # Preferences
    if random.random() > 0.3:  # 30% have no preferences
        num_prefs = random.randint(1, len(preference_keys))
        selected_prefs = random.sample(preference_keys, num_prefs)
        preferences = {k: random.choice(preference_values[k]) for k in selected_prefs}
    else:
        preferences = None
        
    customer_data.append((i, name, email, address, phone, signup_date, last_activity, tier, preferences))

customers_df = spark.createDataFrame(customer_data, customer_schema)
customers_df.createOrReplaceTempView("customers")

# Save as a permanent table for later use
customers_df.write.mode("overwrite").saveAsTable("customers_table")

print(f"Created customers dimension table with {customers_df.count()} rows")

Created customers dimension table with 50000 rows


In [6]:
# Create a fact table (transactions) with highly skewed data
# Import Python's built-in sum function
from builtins import sum as py_sum

transaction_schema = StructType([
    StructField("transaction_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("transaction_date", TimestampType(), False),
    StructField("items", ArrayType(StructType([
        StructField("product_id", IntegerType(), False),
        StructField("quantity", IntegerType(), False),
        StructField("price", DoubleType(), False),
        StructField("discount", DoubleType(), True)
    ])), False),
    StructField("payment_method", StringType(), False),
    StructField("status", StringType(), False),
    StructField("total_amount", DoubleType(), False),
    StructField("store_id", IntegerType(), True),
    StructField("metadata", MapType(StringType(), StringType()), True),
    StructField("year", IntegerType(), False),
    StructField("month", IntegerType(), False)
])

# Generate a smaller sample of 100,000 transactions for notebook performance
transaction_data = []
payment_methods = ["credit_card", "debit_card", "paypal", "apple_pay", "bank_transfer", "gift_card", "crypto"]
payment_weights = [0.4, 0.3, 0.1, 0.1, 0.05, 0.04, 0.01]  # Distribution weights
statuses = ["completed", "shipped", "delivered", "cancelled", "refunded", "processing"]
status_weights = [0.6, 0.2, 0.1, 0.05, 0.03, 0.02]  # Distribution weights
store_count = 50

metadata_keys = ["source", "device", "coupon_code", "promotion", "referral"]
metadata_values = {
    "source": ["web", "mobile_app", "store", "phone", "partner"],
    "device": ["desktop", "mobile", "tablet", "kiosk", "pos"],
    "coupon_code": ["SAVE10", "WELCOME20", "FLASH30", "SEASON25", None],
    "promotion": ["holiday_sale", "clearance", "new_customer", "loyalty", None],
    "referral": ["friend", "search", "social", "email", None]
}

# Create some skewed distributions
# 1. Few customers make many purchases (80/20 rule)
# 2. Some products are much more popular than others
# 3. Some days have much higher transaction volumes

# Calculate weighted customer IDs (power law distribution)
customer_weights = [1/(i**0.8) for i in range(1, 50001)]
# Then use it instead of the PySpark sum
total_weight = py_sum(customer_weights)
customer_weights = [w/total_weight for w in customer_weights]

# Calculate weighted product IDs (power law distribution)
product_weights = [1/(i**0.9) for i in range(1, 5001)]
# Then use it instead of the PySpark sum
total_weight = py_sum(product_weights)
product_weights = [w/total_weight for w in product_weights]

# Generate transactions
for i in range(1, 100001):
    # Select customer with power law distribution
    customer_id = int(np.random.choice(range(1, 50001), p=customer_weights))
    
    # Generate transaction date with seasonal patterns
    # More transactions during holidays and weekends
    random_days = random.randint(0, date_range)
    base_date = start_date + timedelta(days=random_days)
    
    # Adjust for seasonal patterns
    month = base_date.month
    day_of_week = base_date.weekday()
    
    # Boost December (holiday season)
    if month == 12 and random.random() < 0.6:
        # Shift to December
        holiday_year = random.choice([2020, 2021, 2022])
        base_date = datetime(holiday_year, 12, random.randint(1, 31))
    
    # Boost weekends
    if day_of_week < 5 and random.random() < 0.4:  # Weekday
        # Shift to a weekend
        weekend_offset = random.choice([5, 6]) - day_of_week  # Shift to Saturday or Sunday
        base_date = base_date + timedelta(days=weekend_offset)
    
    transaction_date = base_date
    
    # Generate 1-5 items per transaction
    num_items = random.choices([1, 2, 3, 4, 5], weights=[0.5, 0.25, 0.15, 0.07, 0.03])[0]
    items = []
    
    total_amount = 0.0
    for _ in range(num_items):
        # Select product with power law distribution
        product_id = int(np.random.choice(range(1, 5001), p=product_weights))
        quantity = random.randint(1, 5)
        
        # Price depends on product category
        if product_id <= 1000:  # High-priced items (ids 1-1000)
            price = round(random.uniform(500, 2000), 2)
        else:
            price = round(random.uniform(10, 500), 2)
        
        # Apply discount sometimes
        discount = round(price * random.uniform(0.05, 0.3), 2) if random.random() < 0.3 else None
        item_total = price * quantity
        if discount:
            item_total -= discount * quantity
            
        total_amount += item_total
        items.append((product_id, quantity, price, discount))
    
    # Select payment method and status with weighted distribution
    payment_method = random.choices(payment_methods, weights=payment_weights)[0]
    status = random.choices(statuses, weights=status_weights)[0]
    
    # Store ID (NULL for online orders)
    store_id = random.randint(1, store_count) if random.random() < 0.3 else None
    
    # Transaction metadata
    if random.random() > 0.2:  # 20% have no metadata
        num_meta = random.randint(1, 3)
        selected_meta = random.sample(metadata_keys, num_meta)
        metadata = {k: random.choice(metadata_values[k]) for k in selected_meta if metadata_values[k][-1] is not None or random.random() > 0.2}
    else:
        metadata = None
    
    # Add transaction to dataset
    transaction_data.append((
        i, customer_id, transaction_date, items, payment_method, status, 
        total_amount, store_id, metadata, transaction_date.year, transaction_date.month
    ))

transactions_df = spark.createDataFrame(transaction_data, transaction_schema)
transactions_df.createOrReplaceTempView("transactions")

# Save as partitioned table by year and month for partition pruning demos
transactions_df.write.partitionBy("year", "month").mode("overwrite").saveAsTable("transactions_table")

print(f"Created transactions fact table with {transactions_df.count()} rows")

                                                                                

Created transactions fact table with 100000 rows


## 3. Deep Dive into Logical and Physical Plans

Now that we have our test datasets ready, let's explore the various aspects of Spark's query planning and optimization in detail.

### 3.1 Analyzing Scan Operations

Scan operations determine how Spark reads data from sources and are often the first point of optimization. Let's explore different types of scans and their performance implications.

In [7]:
# Let's start with basic scan operations and examine their plans
# 1. Full table scan
full_scan = spark.table("products_table")
analyze_plans(full_scan, "Full Table Scan")

# Check execution time
time_execution_with_metrics(full_scan, name="Full table scan")



--- Logical Plan (Parsed) ---

--- Logical Plan (Analyzed) ---

--- Logical Plan (Optimized) ---

--- Physical Plan ---
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet spark_catalog.default.products_table[product_id#149,name#150,category_id#151,subcategory_id#152,price#153,attributes#154,tags#155,created_at#156] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,category_id:int,subcategory_id:int,price:double,attributes:map<...



--- Detailed Physical Plan ---
== Parsed Logical Plan ==
'UnresolvedRelation [products_table], [], false

== Analyzed Logical Plan ==
product_id: int, name: string, category_id: int, subcategory_id: int, price: double, attributes: map<string,string>, tags: array<string>, created_at: timestamp
SubqueryAlias spark_catalog.default.products_table
+- Relation spark_catalog.d

(0.2021493911743164, 5000)

In [8]:
# 2. Column Pruning - Selecting only specific columns
column_pruning = spark.table("products_table").select("product_id", "name", "price")
analyze_plans(column_pruning, "Column Pruning")

# Performance comparison
pruning_time, _ = time_execution_with_metrics(column_pruning, name="Column pruning scan")

# Examine file format statistics
print("\nParquet file metadata:")
spark.sql("""
    DESCRIBE FORMATTED products_table
""").show(truncate=False)

# # Alternatively, you can use the DataFrame API directly:
# # Get table metadata using DataFrame API
# spark.table("products_table").describe().show(truncate=False)

# # For storage information
# print("\nParquet file metadata:")
# metadata = spark.sql("SHOW TABLE EXTENDED LIKE 'products_table'").collect()
# for row in metadata:
#     print(row)



--- Logical Plan (Parsed) ---

--- Logical Plan (Analyzed) ---

--- Logical Plan (Optimized) ---

--- Physical Plan ---
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet spark_catalog.default.products_table[product_id#149,name#150,price#153] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,price:double>



--- Detailed Physical Plan ---
== Parsed Logical Plan ==
'Project ['product_id, 'name, 'price]
+- SubqueryAlias spark_catalog.default.products_table
   +- Relation spark_catalog.default.products_table[product_id#149,name#150,category_id#151,subcategory_id#152,price#153,attributes#154,tags#155,created_at#156] parquet

== Analyzed Logical Plan ==
product_id: int, name: string, price: double
Project [product_id#149, name#150, price#153]
+- SubqueryAlias spark_catalog.default.products_table
   +- R

In [9]:
# 3. Complex types handling in scans
# Reading and projecting into complex types can impact performance
complex_scan = spark.table("products_table").select(
    "product_id", 
    "name", 
    col("attributes").getItem("color").alias("color"),
    col("tags")[0].alias("first_tag")
)
analyze_plans(complex_scan, "Complex Type Field Access")

# Performance measurement
complex_time, _ = time_execution_with_metrics(complex_scan, name="Complex field access")



--- Logical Plan (Parsed) ---

--- Logical Plan (Analyzed) ---

--- Logical Plan (Optimized) ---

--- Physical Plan ---
== Physical Plan ==
*(1) Project [product_id#149, name#150, attributes#154[color] AS color#233, tags#155[0] AS first_tag#234]
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.products_table[product_id#149,name#150,attributes#154,tags#155] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,attributes:map<string,string>,tags:array<string>>



--- Detailed Physical Plan ---
== Parsed Logical Plan ==
'Project ['product_id, 'name, 'attributes[color] AS color#233, 'tags[0] AS first_tag#234]
+- SubqueryAlias spark_catalog.default.products_table
   +- Relation spark_catalog.default.products_table[product_id#149,name#150,category_id#151,subcategory_id#152,price#153,attributes#154,

In [10]:
# 4. File format impact on scan
# Compare Parquet read vs JSON read

# First, save a sample in both JSON and Parquet format
sample_df = spark.table("products_table").limit(1000)
sample_df.write.json("/tmp/products_json", mode="overwrite")
sample_df.write.parquet("/tmp/products_parquet", mode="overwrite")  # Save to different path

# Now read them back and compare scan plans
json_scan = spark.read.json("/tmp/products_json")
parquet_scan = spark.read.parquet("/tmp/products_parquet")  # Read from the correct path

print("JSON Scan Plan:")
json_scan.explain()

print("\nParquet Scan Plan:")
parquet_scan.explain()

# Compare performance
json_time, _ = time_execution_with_metrics(json_scan, name="JSON scan")
parquet_time, _ = time_execution_with_metrics(parquet_scan, name="Parquet scan")

print(f"Performance difference: Parquet is {json_time/parquet_time:.2f}x faster than JSON")

JSON Scan Plan:
== Physical Plan ==
FileScan json [attributes#280,category_id#281L,created_at#282,name#283,price#284,product_id#285L,subcategory_id#286L,tags#287] Batched: false, DataFilters: [], Format: JSON, Location: InMemoryFileIndex(1 paths)[file:/tmp/products_json], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<attributes:struct<color:string,connectivity:string,material:string,origin:string,power:str...



Parquet Scan Plan:
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet [product_id#296,name#297,category_id#298,subcategory_id#299,price#300,attributes#301,tags#302,created_at#303] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/products_parquet], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,category_id:int,subcategory_id:int,price:double,attributes:map<...


JSON scan execution time: 0.0636 seconds
Parquet scan execution time: 0.0362 seconds
Performance difference: Pa

#### Advanced Scan Optimization Techniques

1. **File Format Selection**: The file format significantly impacts scan performance:
   - Parquet and ORC use columnar storage, allowing efficient column pruning and predicate pushdown
   - JSON and CSV require parsing the entire row, even if you only need a few columns
   - Avro provides good compression and schema evolution

2. **File Size and Splitting**: 
   - Too many small files create overhead ("small file problem")
   - Files that are too large might not utilize parallelism efficiently
   - Ideal file size typically ranges from 64MB to 1GB

3. **Statistics Utilization**:
   - File formats like Parquet store statistics (min/max values) for columns
   - These statistics enable file pruning before reading data
   - Computing and maintaining statistics for tables improves planner decisions

4. **Vectorization**:
   - Modern file formats support vectorized reads (batch processing)
   - Column vectors are processed more efficiently than row-by-row
   - Parquet and ORC have built-in support for vectorized reading

Let's examine file pruning using statistics:

In [11]:
# Examine statistics-based file pruning
# First, let's analyze a partitioned table to see partition pruning
partition_info = spark.sql("""
    SHOW PARTITIONS transactions_table
""").collect()

print(f"Number of partitions in transactions_table: {len(partition_info)}")
for p in partition_info[:10]:
    print(p[0])
print("...")

# Statistics-based file skipping with Parquet
high_price_scan = spark.table("products_table").filter(col("price") > 1500)
analyze_plans(high_price_scan, "Statistics-based Pruning")

# Force enable/disable stats collection for comparison
current_stats = spark.conf.get("spark.sql.statistics.histogram.enabled")
print(f"Current histogram stats setting: {current_stats}")

# Enable detailed statistics
spark.conf.set("spark.sql.statistics.histogram.enabled", "true")
spark.sql("ANALYZE TABLE products_table COMPUTE STATISTICS FOR COLUMNS price")

# Check collected statistics
spark.sql("""
    DESCRIBE EXTENDED products_table price
""").show(truncate=False)

Number of partitions in transactions_table: 64
year=2020/month=1
year=2020/month=10
year=2020/month=11
year=2020/month=12
year=2020/month=2
year=2020/month=3
year=2020/month=4
year=2020/month=5
year=2020/month=6
year=2020/month=7
...


--- Logical Plan (Parsed) ---

--- Logical Plan (Analyzed) ---

--- Logical Plan (Optimized) ---

--- Physical Plan ---
== Physical Plan ==
*(1) Filter (isnotnull(price#153) AND (price#153 > 1500.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.products_table[product_id#149,name#150,category_id#151,subcategory_id#152,price#153,attributes#154,tags#155,created_at#156] Batched: true, DataFilters: [isnotnull(price#153), (price#153 > 1500.0)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [IsNotNull(price), GreaterThan(price,1500.0)], ReadSchema: struct<product_id:int,name:string,category_id:int,subcategory_id:int,price:double,attribute

### 3.2 Understanding Filter and Projection Pushdown

Filter and projection pushdown are optimization techniques that reduce the amount of data read from disk and transferred between nodes. Let's examine these in detail:

In [12]:
# Investigate filter pushdown capabilities
simple_filter = spark.table("transactions_table").filter(col("total_amount") > 1000)
print("Simple Filter Query Plan:")
simple_filter.explain()

# Examine PushedFilters in the plan
print("\nDetailed Simple Filter Query Plan:")
simple_filter.explain(mode="extended")

Simple Filter Query Plan:
== Physical Plan ==
*(1) Filter (isnotnull(total_amount#6375) AND (total_amount#6375 > 1000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 1000.0)], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [IsNotNull(total_amount), GreaterThan(total_amount,1000.0)], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,items:array<struct<product_i...



Detailed Simple Filter Query Plan:
== Parsed Logical Plan ==
'Filter ('total_amount > 1000)
+- SubqueryAlias spark_catalog.default.transactions_table
   +- Relation spark_catalog.default.transactions_tab

In [13]:
# Complex expressions that may or may not be pushed down
from pyspark.sql.functions import col, year, udf
# 1. Simple comparison (usually pushes down)
filter1 = spark.table("transactions_table").filter(col("total_amount") > 1000)

# 2. Compound condition (usually pushes down)
filter2 = spark.table("transactions_table").filter(
    (col("total_amount") > 1000) & 
    (col("payment_method") == "credit_card")
)

# 3. Functions that may not push down
filter3 = spark.table("transactions_table").filter(
    year(col("transaction_date")) == 2022
)

# 4. UDFs (typically don't push down)
@udf("boolean")
def is_expensive(amount):
    return amount > 1000

filter4 = spark.table("transactions_table").filter(
    is_expensive(col("total_amount"))
)

# Display plans to compare
print("Filter 1 - Simple Comparison:")
filter1.explain()

print("\nFilter 2 - Compound Condition:")
filter2.explain()

print("\nFilter 3 - Built-in Function:")
filter3.explain()

print("\nFilter 4 - UDF:")
filter4.explain()

Filter 1 - Simple Comparison:
== Physical Plan ==
*(1) Filter (isnotnull(total_amount#6375) AND (total_amount#6375 > 1000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 1000.0)], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [IsNotNull(total_amount), GreaterThan(total_amount,1000.0)], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,items:array<struct<product_i...



Filter 2 - Compound Condition:
== Physical Plan ==
*(1) Filter (((isnotnull(total_amount#6375) AND isnotnull(payment_method#6373)) AND (total_amount#6375 > 1000.0)) AND (payment_method#6373 = credit_

In [14]:
# Performance impact of filter pushdown vs. no pushdown
# Using built-in functions (can be pushed down) vs UDFs (cannot be pushed down)

# Filter using built-in functions
time_start = time.time()
builtin_result = spark.table("transactions_table").filter(
    (col("total_amount") > 1000) & 
    (col("status") == "completed")
).select("transaction_id", "customer_id", "total_amount").count()
builtin_time = time.time() - time_start

# Filter using UDFs (prevents pushdown)
@udf("boolean")
def matches_criteria(amount, status):
    return amount > 1000 and status == "completed"

time_start = time.time()
udf_result = spark.table("transactions_table").filter(
    matches_criteria(col("total_amount"), col("status"))
).select("transaction_id", "customer_id", "total_amount").count()
udf_time = time.time() - time_start

print(f"Built-in function filter: {builtin_time:.4f} seconds")
print(f"UDF filter (no pushdown): {udf_time:.4f} seconds")
print(f"Performance difference: {udf_time/builtin_time:.2f}x slower without pushdown")



Built-in function filter: 1.7885 seconds
UDF filter (no pushdown): 1.1917 seconds
Performance difference: 0.67x slower without pushdown


                                                                                

#### Projection Pushdown and Column Pruning

Projection pushdown (also known as column pruning) works by reading only the columns needed for a query. This reduces I/O, network transfer, and memory usage.

In [15]:
# Compare column pruning effectiveness with different queries

# 1. Select few columns
few_columns = spark.table("products_table").select("product_id", "name", "price")

# 2. Select many columns
many_columns = spark.table("products_table").select("*")

# 3. Select with complex derived columns
complex_columns = spark.table("products_table").select(
    "product_id", 
    "name", 
    "price",
    col("attributes").getItem("color").alias("color"),
    col("attributes").getItem("size").alias("size"),
    col("tags")[0].alias("first_tag"),
    (col("price") * 1.1).alias("price_with_tax")
)

# Compare plans and execution times
print("Few Columns Query Plan:")
few_columns.explain()

print("\nMany Columns Query Plan:")
many_columns.explain()

print("\nComplex Columns Query Plan:")
complex_columns.explain()

# Measure performance
few_time, _ = time_execution_with_metrics(few_columns, name="Few columns")
many_time, _ = time_execution_with_metrics(many_columns, name="Many columns")
complex_time, _ = time_execution_with_metrics(complex_columns, name="Complex columns")

print(f"\nPerformance comparison:")
print(f"Reading all columns is {many_time/few_time:.2f}x slower than selecting few columns")
print(f"Complex column selection is {complex_time/few_time:.2f}x slower than simple column selection")

Few Columns Query Plan:
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet spark_catalog.default.products_table[product_id#6344,name#6345,price#6348] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,price:double>



Many Columns Query Plan:
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet spark_catalog.default.products_table[product_id#6344,name#6345,category_id#6346,subcategory_id#6347,price#6348,attributes#6349,tags#6350,created_at#6351] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/products_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<product_id:int,name:string,category_id:int,subcategory_id:int,price:double,attributes:map<...



Complex Columns Query Plan:
== Physical Plan ==
*(1) Pr

#### Advanced Filter Optimization Techniques

1. **Filter Order**: Apply the most selective filters first to reduce data volume early
2. **Partition Pruning**: Filter on partition columns for massive performance gains
3. **File Skipping**: Use filters that can leverage statistics for file-level skipping
4. **Predicate Reordering**: The optimizer can reorder predicates for better performance
5. **Avoid UDFs in Filters**: UDFs prevent pushdown, use built-in functions when possible

In [16]:
# Let's examine partition pruning with different filter orderings

# 1. Filter on partition columns first
partition_first = spark.table("transactions_table") \
    .filter(col("year") == 2022) \
    .filter(col("month") == 12) \
    .filter(col("total_amount") > 1000)

# 2. Filter on non-partition columns first
non_partition_first = spark.table("transactions_table") \
    .filter(col("total_amount") > 1000) \
    .filter(col("year") == 2022) \
    .filter(col("month") == 12)

# Compare plans
print("Partition First Query Plan:")
partition_first.explain()

print("\nNon-Partition First Query Plan:")
non_partition_first.explain()

# Measure performance
partition_time, _ = time_execution_with_metrics(partition_first, name="Partition first")
non_partition_time, _ = time_execution_with_metrics(non_partition_first, name="Non-partition first")

print(f"\nNote: Both plans should have similar performance due to the query optimizer's ability to reorder filters.")
print(f"Actual performance ratio: {non_partition_time/partition_time:.2f}x")

Partition First Query Plan:
== Physical Plan ==
*(1) Filter (isnotnull(total_amount#6375) AND (total_amount#6375 > 1000.0))
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 1000.0)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2022/..., PartitionFilters: [isnotnull(year#6378), isnotnull(month#6379), (year#6378 = 2022), (month#6379 = 12)], PushedFilters: [IsNotNull(total_amount), GreaterThan(total_amount,1000.0)], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,items:array<struct<product_i...



Non-Partition First Query Plan:
== Physical Plan ==
*(1) Filter (isnotnull(total_amount#6375) AND (total_

#### Dynamic Partition Pruning

Dynamic partition pruning is an optimization technique that determines which partitions to read based on filters that are only known at runtime. This often happens with join conditions.

In [17]:
# Create a sample partitioned table for this example
spark.sql("""
CREATE OR REPLACE TEMPORARY VIEW sales_by_month AS
SELECT
    year,
    month,
    COUNT(*) as transaction_count,
    SUM(total_amount) as total_sales
FROM transactions_table
GROUP BY year, month
""")

# Now let's try dynamic partition pruning with a join
query = spark.sql("""
SELECT t.transaction_id, t.customer_id, t.total_amount
FROM transactions_table t
JOIN sales_by_month s ON t.year = s.year AND t.month = s.month
WHERE s.total_sales > 100000
""")

print("Join with Dynamic Partition Pruning:")
query.explain()

# Check execution metrics
time_execution_with_metrics(query, name="Dynamic partition pruning")

Join with Dynamic Partition Pruning:
== Physical Plan ==
*(3) Project [transaction_id#6369, customer_id#6370, total_amount#6375]
+- *(3) BroadcastHashJoin [year#6378, month#6379], [year#6639, month#6640], Inner, BuildRight, false
   :- *(3) ColumnarToRow
   :  +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(64 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2024..., PartitionFilters: [isnotnull(year#6378), isnotnull(month#6379), dynamicpruningexpression(year#6378 IN dynamicprunin..., PushedFilters: [], ReadSchema: struct<transaction_id:int,customer_id:int,total_amount:double>
   :        :- SubqueryBroadcast dynamicpruning#6645, 0, [year#6639, month#6640], [id=#1022]
   :        :  +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(inpu

                                                                                

Dynamic partition pruning execution time: 1.2028 seconds


(1.202821969985962, 100000)

#### Optimizing Filter and Projection Pushdown

To maximize the benefits of filter and projection pushdown:

1. **Use compatible file formats**: Parquet, ORC, and Delta support pushdown effectively
2. **Select only required columns**: Avoid `select("*")` when possible
3. **Use built-in functions**: Avoid UDFs in filter conditions
4. **Filter on partition columns**: Design your partitioning scheme for common query patterns
5. **Compute and maintain statistics**: Enable better file skipping
6. **Apply filters early**: Push filters as close to the data source as possible

### 3.3 Detailed Join Strategy Analysis

Joins are often the most expensive operations in Spark queries. Understanding different join strategies and their performance characteristics is crucial for optimization.

In [18]:
# Let's first check the current broadcast join threshold configuration
broadcast_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
print(f"Current broadcast join threshold: {broadcast_threshold}")

# Check table sizes to understand broadcasting decisions
print("\nTable sizes:")
spark.sql("SELECT COUNT(*) AS transactions_count FROM transactions_table").show()
spark.sql("SELECT COUNT(*) AS customers_count FROM customers_table").show()
spark.sql("SELECT COUNT(*) AS products_count FROM products_table").show()

Current broadcast join threshold: 10m

Table sizes:
+------------------+
|transactions_count|
+------------------+
|            100000|
+------------------+

+---------------+
|customers_count|
+---------------+
|          50000|
+---------------+

+--------------+
|products_count|
+--------------+
|          5000|
+--------------+



In [19]:
# Let's examine different join types and how Spark chooses them
# First, create a common join query
customer_join = spark.sql("""
SELECT 
    t.transaction_id, 
    c.name as customer_name, 
    t.total_amount
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
WHERE 
    t.year = 2022 AND t.total_amount > 500
""")

# Analyze the join plan
print("Default Join Strategy:")
customer_join.explain()

Default Join Strategy:
== Physical Plan ==
*(2) Project [transaction_id#6369, name#6698 AS customer_name#6724, total_amount#6375]
+- *(2) BroadcastHashJoin [customer_id#6370], [customer_id#6697], Inner, BuildLeft, false
   :- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[1, int, true] as bigint)),false), [plan_id=1594]
   :  +- *(1) Project [transaction_id#6369, customer_id#6370, total_amount#6375]
   :     +- *(1) Filter ((isnotnull(total_amount#6375) AND (total_amount#6375 > 500.0)) AND isnotnull(customer_id#6370))
   :        +- *(1) ColumnarToRow
   :           +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 500.0), isnotnull(customer_id#6370)], Format: Parquet, Location: InMemoryFileIndex(12 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2022..., PartitionFilters: [isnotnul

In [20]:
# Now let's disable auto-broadcasting and force a shuffle hash join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

# Run the same query with broadcast disabled
no_broadcast_join = spark.sql("""
SELECT 
    t.transaction_id, 
    c.name as customer_name, 
    t.total_amount
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
WHERE 
    t.year = 2022 AND t.total_amount > 500
""")

print("Join without broadcasting:")
no_broadcast_join.explain()

Join without broadcasting:
== Physical Plan ==
*(5) Project [transaction_id#6369, name#6698 AS customer_name#6729, total_amount#6375]
+- *(5) SortMergeJoin [customer_id#6370], [customer_id#6697], Inner
   :- *(2) Sort [customer_id#6370 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#6370, 10), ENSURE_REQUIREMENTS, [plan_id=1659]
   :     +- *(1) Project [transaction_id#6369, customer_id#6370, total_amount#6375]
   :        +- *(1) Filter ((isnotnull(total_amount#6375) AND (total_amount#6375 > 500.0)) AND isnotnull(customer_id#6370))
   :           +- *(1) ColumnarToRow
   :              +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 500.0), isnotnull(customer_id#6370)], Format: Parquet, Location: InMemoryFileIndex(12 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2022

In [21]:
# Explicitly force a broadcast join using the broadcast hint
# Reset broadcast threshold first
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", broadcast_threshold)

# Use broadcast hint
forced_broadcast = spark.sql("""
SELECT /*+ BROADCAST(c) */
    t.transaction_id, 
    c.name as customer_name, 
    t.total_amount
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
WHERE 
    t.year = 2022 AND t.total_amount > 500
""")

print("Join with broadcast hint:")
forced_broadcast.explain()

Join with broadcast hint:
== Physical Plan ==
*(2) Project [transaction_id#6369, name#6698 AS customer_name#6733, total_amount#6375]
+- *(2) BroadcastHashJoin [customer_id#6370], [customer_id#6697], Inner, BuildRight, false
   :- *(2) Project [transaction_id#6369, customer_id#6370, total_amount#6375]
   :  +- *(2) Filter ((isnotnull(total_amount#6375) AND (total_amount#6375 > 500.0)) AND isnotnull(customer_id#6370))
   :     +- *(2) ColumnarToRow
   :        +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 500.0), isnotnull(customer_id#6370)], Format: Parquet, Location: InMemoryFileIndex(12 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2022..., PartitionFilters: [isnotnull(year#6378), (year#6378 = 2022)], PushedFilters: [IsNotNull(total_amount), GreaterThan(total_amount,500.0), IsNotNull(custom

#### Comparison of Join Strategies

Let's compare the performance of different join strategies on the same query:

In [22]:
# Let's systematically compare different join strategies
# We'll use DataFrame API for easier control

# Create the base tables for joins
transactions = spark.table("transactions_table").filter((col("year") == 2022) & (col("total_amount") > 500))
customers = spark.table("customers_table")

# 1. Broadcast Hash Join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50m")  # Ensure broadcasting
broadcast_join = transactions.join(broadcast(customers), transactions["customer_id"] == customers["customer_id"])
print("Broadcast Hash Join Plan:")
broadcast_join.explain()
broadcast_time, _ = time_execution_with_metrics(broadcast_join, name="Broadcast Hash Join")

# 2. Sort Merge Join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # Disable broadcasting
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")  # Prefer sort merge
sort_merge_join = transactions.join(customers, transactions["customer_id"] == customers["customer_id"])
print("\nSort Merge Join Plan:")
sort_merge_join.explain()
sort_merge_time, _ = time_execution_with_metrics(sort_merge_join, name="Sort Merge Join")

# 3. Shuffle Hash Join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # Disable broadcasting
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")  # Disable preference for sort merge
shuffle_hash_join = transactions.join(customers, transactions["customer_id"] == customers["customer_id"])
print("\nShuffle Hash Join Plan:")
shuffle_hash_join.explain()
shuffle_hash_time, _ = time_execution_with_metrics(shuffle_hash_join, name="Shuffle Hash Join")

# Reset configs to default
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", broadcast_threshold)
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")

# Summarize results
print("\nJoin Strategy Performance Comparison:")
print(f"Broadcast Hash Join: {broadcast_time:.4f} seconds (baseline)")
print(f"Sort Merge Join: {sort_merge_time:.4f} seconds ({sort_merge_time/broadcast_time:.2f}x vs broadcast)")
print(f"Shuffle Hash Join: {shuffle_hash_time:.4f} seconds ({shuffle_hash_time/broadcast_time:.2f}x vs broadcast)")

Broadcast Hash Join Plan:
== Physical Plan ==
*(2) BroadcastHashJoin [customer_id#6370], [customer_id#6697], Inner, BuildRight, false
:- *(2) Filter ((isnotnull(total_amount#6375) AND (total_amount#6375 > 500.0)) AND isnotnull(customer_id#6370))
:  +- *(2) ColumnarToRow
:     +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(total_amount#6375), (total_amount#6375 > 500.0), isnotnull(customer_id#6370)], Format: Parquet, Location: InMemoryFileIndex(12 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table/year=2022..., PartitionFilters: [isnotnull(year#6378), (year#6378 = 2022)], PushedFilters: [IsNotNull(total_amount), GreaterThan(total_amount,500.0), IsNotNull(customer_id)], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,ite

#### Join Order and Join Reordering

Join order can significantly impact performance. The Catalyst optimizer attempts to reorder joins optimally, but sometimes it needs help:

In [23]:
# Let's analyze join order optimization with a 3-way join
from pyspark.sql.functions import explode, col

# Check the schema to understand the structure
print("Transaction table schema:")
spark.table("transactions_table").printSchema()

# Use DataFrame API approach which is more reliable for nested structures
# Version without hints - fixing the ambiguous column references
complex_join_df = (
    spark.table("transactions_table").filter((col("year") == 2022) & (col("month") == 12))
    .select("transaction_id", "customer_id", "items")
    .join(spark.table("customers_table"), "customer_id")
    .select("transaction_id", col("name").alias("customer_name"), "items")
    .withColumn("item", explode("items"))
    .select("transaction_id", "customer_name", 
            col("item.product_id").alias("product_id"), 
            col("item.quantity"), col("item.price").alias("item_price"))
    .join(spark.table("products_table"), "product_id")
    .select("transaction_id", "customer_name", col("name").alias("product_name"), 
           "quantity", col("item_price").alias("price"))
)

# Analyze the plan
print("Complex Join Plan:")
complex_join_df.explain()

# Now try with hints using DataFrame API 
hinted_join_df = (
    spark.table("transactions_table").filter((col("year") == 2022) & (col("month") == 12))
    .select("transaction_id", "customer_id", "items")
    .join(spark.table("customers_table").hint("broadcast"), "customer_id")
    .select("transaction_id", col("name").alias("customer_name"), "items")
    .withColumn("item", explode("items"))
    .select("transaction_id", "customer_name", 
            col("item.product_id").alias("product_id"), 
            col("item.quantity"), col("item.price").alias("item_price"))
    .join(spark.table("products_table").hint("broadcast"), "product_id")
    .select("transaction_id", "customer_name", col("name").alias("product_name"), 
           "quantity", col("item_price").alias("price"))
)

print("\nComplex Join with Hints Plan:")
hinted_join_df.explain()

# Compare performance
complex_time, _ = time_execution_with_metrics(complex_join_df, name="Standard Join Order")
hinted_time, _ = time_execution_with_metrics(hinted_join_df, name="Hinted Join Order")

print(f"\nPerformance impact of join order: {complex_time/hinted_time:.2f}x improvement with hints")

Transaction table schema:
root
 |-- transaction_id: integer (nullable = true)
 |-- customer_id: integer (nullable = true)
 |-- transaction_date: timestamp (nullable = true)
 |-- items: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- price: double (nullable = true)
 |    |    |-- discount: double (nullable = true)
 |-- payment_method: string (nullable = true)
 |-- status: string (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- store_id: integer (nullable = true)
 |-- metadata: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)

Complex Join Plan:
== Physical Plan ==
*(3) Project [transaction_id#6369, customer_name#6997, name#6345 AS product_name#7037, quantity#7010, item_price#7008 AS price#7038]
+- *(3) Broadca

#### Handling Data Skew in Joins

Data skew can cause severe performance problems in joins, especially with sort-merge and shuffle hash joins. Let's examine techniques to identify and handle skew:

In [24]:
# Let's create a query that will exhibit skew due to the distribution of our data
# Remember that we created transactions with power-law distribution of customer_ids
skewed_join = spark.sql("""
SELECT 
    t.transaction_id, 
    c.name as customer_name, 
    t.total_amount,
    t.transaction_date
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
""")

# Force a sort merge join to see the impact of skew
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
print("Join with potential skew:")
skewed_join.explain()

# Check distribution of customer_ids in our transactions
customer_distribution = spark.sql("""
SELECT 
    customer_id, 
    COUNT(*) as transaction_count
FROM 
    transactions_table
GROUP BY 
    customer_id
ORDER BY 
    transaction_count DESC
LIMIT 10
""")

print("\nTop 10 customers by transaction count:")
customer_distribution.show()

Join with potential skew:
== Physical Plan ==
*(5) Project [transaction_id#6369, name#6698 AS customer_name#7173, total_amount#6375, transaction_date#6371]
+- *(5) SortMergeJoin [customer_id#6370], [customer_id#6697], Inner
   :- *(2) Sort [customer_id#6370 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#6370, 10), ENSURE_REQUIREMENTS, [plan_id=2904]
   :     +- *(1) Project [transaction_id#6369, customer_id#6370, transaction_date#6371, total_amount#6375]
   :        +- *(1) Filter isnotnull(customer_id#6370)
   :           +- *(1) ColumnarToRow
   :              +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(customer_id#6370)], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [IsNotNull(customer_id)], ReadSch

In [26]:
# Techniques to handle skew in joins
from pyspark.sql.functions import when, lit, rand, col, explode, array

# 1. Using salting to distribute skewed keys
# First, identify skewed keys
skewed_keys = spark.sql("""
SELECT 
    customer_id 
FROM (
    SELECT 
        customer_id, 
        COUNT(*) as cnt
    FROM 
        transactions_table
    GROUP BY 
        customer_id
) t
WHERE 
    cnt > 10
""").collect()

skewed_key_list = [row[0] for row in skewed_keys]
print(f"Identified {len(skewed_key_list)} skewed customer keys")

# Apply salting technique for skewed keys
transactions_df = spark.table("transactions_table")
customers_df = spark.table("customers_table")

# Number of salt values
num_salts = 10

# Add salt to skewed transactions
salted_transactions = transactions_df.withColumn(
    "salt", 
    when(col("customer_id").isin(skewed_key_list), 
         (rand() * num_salts).cast("int"))
    .otherwise(lit(0))
)

# Duplicate customers for skewed keys
from pyspark.sql.functions import explode, array

# Create salt values for skewed customers
expanded_customers = customers_df.withColumn(
    "salt_values",
    when(col("customer_id").isin(skewed_key_list),
         array([lit(i) for i in range(num_salts)]))
    .otherwise(array(lit(0)))
)

# Explode to create multiple rows for skewed customers
salted_customers = expanded_customers \
    .withColumn("salt", explode("salt_values")) \
    .drop("salt_values")

# Join with salted keys
salted_join = salted_transactions.join(
    salted_customers,
    (salted_transactions["customer_id"] == salted_customers["customer_id"]) &
    (salted_transactions["salt"] == salted_customers["salt"])
)

print("\nPlan with salting technique:")
salted_join.explain()

# Compare performance
standard_time, _ = time_execution_with_metrics(skewed_join, name="Standard join with skew")
salted_time, _ = time_execution_with_metrics(salted_join, name="Join with salting")

print(f"\nPerformance impact of salting: {standard_time/salted_time:.2f}x improvement")

Identified 1046 skewed customer keys

Plan with salting technique:
== Physical Plan ==
*(5) SortMergeJoin [customer_id#6370, salt#7243], [customer_id#6697, salt#7268], Inner
:- *(2) Sort [customer_id#6370 ASC NULLS FIRST, salt#7243 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(customer_id#6370, salt#7243, 10), ENSURE_REQUIREMENTS, [plan_id=3135]
:     +- *(1) Filter (isnotnull(customer_id#6370) AND isnotnull(salt#7243))
:        +- *(1) Project [transaction_id#6369, customer_id#6370, transaction_date#6371, items#6372, payment_method#6373, status#6374, total_amount#6375, store_id#6376, metadata#6377, year#6378, month#6379, CASE WHEN customer_id#6370 INSET 1, 10, 100, 1005, 1007, 1008, 1009, 101, 1012, 1013, 1016, 1017, 1018, 1019, 102, 1020, 1021, 1022, 1025, 1028, 103, 1031, 1033, 1039, 104, 1041, 1042, 1043, 1044, 1045, 1047, 105, 1053, 1059, 106, 1061, 1062, 1063, 107, 1070, 1071, 1072, 1074, 108, 1080, 1081, 1084, 109, 1092, 1098, 1099, 11, 110, 1103, 1107, 1108, 111, 1

#### Advanced Join Optimization Techniques

1. **Choose the Right Join Strategy**:
   - Broadcast Hash Join: Use for small-to-medium tables that can fit in memory
   - Sort Merge Join: Best for large tables with evenly distributed keys
   - Shuffle Hash Join: Can be faster than Sort Merge for medium-sized tables

2. **Join Order Optimization**:
   - Join the most filtered/smallest tables first
   - Use the LEADING hint to control join order
   - Filter tables before joining when possible

3. **Data Skew Handling**:
   - Use salting for skewed keys
   - Split processing of skewed and non-skewed data
   - Consider pre-aggregating data before joins

4. **Configuration Tuning**:
   - Adjust broadcast threshold based on your data size
   - Set appropriate shuffle partition count
   - Consider enabling adaptive query execution for dynamic optimization

5. **Schema Optimization**:
   - Ensure join keys have the same data type to avoid implicit conversions
   - Consider clustering or bucketing tables on join keys

6. **Join Rewriting**:
   - Convert complex multi-way joins to simpler joins when possible
   - Use subqueries or CTEs to break down complex joins

### 3.4 Optimizing Shuffle Operations

Shuffles are among the most expensive operations in Spark, involving disk I/O, serialization, network transfer, and deserialization. Understanding and optimizing shuffles is critical for performance.

In [27]:
# Let's examine the available shuffle configuration
print("Current shuffle configuration:")

# Use default values for all configs to avoid errors
print(f"spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions', 'not set')}")

# Get parallelism through SparkContext
parallelism = spark.sparkContext.defaultParallelism
print(f"Default parallelism: {parallelism}")

# Check adaptive execution settings
print(f"spark.sql.adaptive.enabled: {spark.conf.get('spark.sql.adaptive.enabled', 'not set')}")

# Check all available configurations
print("\nListing all available SQL configurations:")
all_configs = spark.sql("SET -v").collect()
print(f"Total configurations: {len(all_configs)}")

# Print a few examples
for i, row in enumerate(all_configs[:5]):  # Just show first 5
    print(f"{row[0]}: {row[1]}")

Current shuffle configuration:
spark.sql.shuffle.partitions: 10
Default parallelism: 12
spark.sql.adaptive.enabled: false

Listing all available SQL configurations:
Total configurations: 206
spark.sql.adaptive.advisoryPartitionSizeInBytes: <value of spark.sql.adaptive.shuffle.targetPostShuffleInputSize>
spark.sql.adaptive.autoBroadcastJoinThreshold: <undefined>
spark.sql.adaptive.coalescePartitions.enabled: true
spark.sql.adaptive.coalescePartitions.initialPartitionNum: <undefined>
spark.sql.adaptive.coalescePartitions.minPartitionSize: 1MB


In [28]:
# Let's create queries that trigger different types of shuffles

# 1. Shuffle due to repartition
repartition_shuffle = spark.table("transactions_table").repartition(10)
print("Repartition Shuffle Plan:")
repartition_shuffle.explain()

# 2. Shuffle due to groupBy
group_shuffle = spark.table("transactions_table").groupBy("payment_method").count()
print("\nGroupBy Shuffle Plan:")
group_shuffle.explain()

# 3. Shuffle due to join (without broadcast)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
join_shuffle = spark.table("transactions_table").join(
    spark.table("customers_table"),
    "customer_id"
)
print("\nJoin Shuffle Plan:")
join_shuffle.explain()

# 4. Shuffle due to window function
from pyspark.sql.window import Window
window_spec = Window.partitionBy("payment_method").orderBy("total_amount")
window_shuffle = spark.table("transactions_table").withColumn(
    "rank", rank().over(window_spec)
)
print("\nWindow Function Shuffle Plan:")
window_shuffle.explain()

Repartition Shuffle Plan:
== Physical Plan ==
Exchange RoundRobinPartitioning(10), REPARTITION_BY_NUM, [plan_id=3445]
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,items:array<struct<product_i...



GroupBy Shuffle Plan:
== Physical Plan ==
*(2) HashAggregate(keys=[payment_method#6373], functions=[count(1)])
+- Exchange hashpartitioning(payment_method#6373, 10), ENSURE_REQUIREMENTS, [plan_id=3475]
   +- *(1) HashAggregate(keys=[payment_method#6373], functions=[partial_count(1)])
      +- *(1) Project [payment_method

#### Impact of Shuffle Partition Count

The number of shuffle partitions significantly impacts performance. Too few can lead to memory pressure, while too many create small tasks with overhead.

In [29]:
# First fix the time_execution_with_metrics function
import time

def time_execution_with_metrics(df, action="count", name="query"):
    """Time and collect metrics for a DataFrame action"""
    start = time.time()
    
    if action == "count":
        result = df.count()
    elif action == "collect":
        result = df.collect()
    else:
        result = df.count()  # Default to count
        
    end = time.time()
    duration = end - start
    
    print(f"{name} executed in {duration:.4f} seconds")
    return duration, result

# Now test different shuffle partition counts
from builtins import min as py_min  # Import Python's built-in min

partition_counts = [5, 20, 100, 200]
test_query = """
SELECT 
    t.customer_id,
    c.name,
    COUNT(*) as transaction_count,
    SUM(t.total_amount) as total_spent,
    AVG(t.total_amount) as avg_transaction
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
GROUP BY 
    t.customer_id, c.name
ORDER BY 
    total_spent DESC
"""

results = []
for partitions in partition_counts:
    # Set the shuffle partitions
    spark.conf.set("spark.sql.shuffle.partitions", partitions)
    
    # Run the test query
    test_df = spark.sql(test_query)
    
    # Measure performance
    print(f"\nTesting with {partitions} shuffle partitions:")
    test_df.explain()
    execution_time, _ = time_execution_with_metrics(test_df, name=f"Shuffle with {partitions} partitions")
    
    results.append((partitions, execution_time))

# Summarize results
print("\nShuffle Partition Performance Summary:")
for partitions, time_value in results:
    print(f"Partitions: {partitions}, Time: {time_value:.4f} seconds")

# Find optimal partition count using Python's built-in min function
optimal_result = py_min(results, key=lambda x: x[1])
optimal_partitions, min_time = optimal_result

print(f"\nOptimal partition count: {optimal_partitions} with {min_time:.4f} seconds")

# Reset to a reasonable value
spark.conf.set("spark.sql.shuffle.partitions", optimal_partitions)


Testing with 5 shuffle partitions:
== Physical Plan ==
*(6) Sort [total_spent#7501 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_spent#7501 DESC NULLS LAST, 5), ENSURE_REQUIREMENTS, [plan_id=3686]
   +- *(5) HashAggregate(keys=[customer_id#6370, name#6698], functions=[count(1), sum(total_amount#6375), avg(total_amount#6375)])
      +- *(5) HashAggregate(keys=[customer_id#6370, name#6698], functions=[partial_count(1), partial_sum(total_amount#6375), partial_avg(total_amount#6375)])
         +- *(5) Project [customer_id#6370, total_amount#6375, name#6698]
            +- *(5) SortMergeJoin [customer_id#6370], [customer_id#6697], Inner
               :- *(2) Sort [customer_id#6370 ASC NULLS FIRST], false, 0
               :  +- Exchange hashpartitioning(customer_id#6370, 5), ENSURE_REQUIREMENTS, [plan_id=3667]
               :     +- *(1) Project [customer_id#6370, total_amount#6375]
               :        +- *(1) Filter isnotnull(customer_id#6370)
               :       

                                                                                

Shuffle with 100 partitions executed in 0.8455 seconds

Testing with 200 shuffle partitions:
== Physical Plan ==
*(6) Sort [total_spent#7615 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_spent#7615 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=4508]
   +- *(5) HashAggregate(keys=[customer_id#6370, name#6698], functions=[count(1), sum(total_amount#6375), avg(total_amount#6375)])
      +- *(5) HashAggregate(keys=[customer_id#6370, name#6698], functions=[partial_count(1), partial_sum(total_amount#6375), partial_avg(total_amount#6375)])
         +- *(5) Project [customer_id#6370, total_amount#6375, name#6698]
            +- *(5) SortMergeJoin [customer_id#6370], [customer_id#6697], Inner
               :- *(2) Sort [customer_id#6370 ASC NULLS FIRST], false, 0
               :  +- Exchange hashpartitioning(customer_id#6370, 200), ENSURE_REQUIREMENTS, [plan_id=4489]
               :     +- *(1) Project [customer_id#6370, total_amount#6375]
               :        +- *(



Shuffle with 200 partitions executed in 1.0131 seconds

Shuffle Partition Performance Summary:
Partitions: 5, Time: 0.6657 seconds
Partitions: 20, Time: 0.5944 seconds
Partitions: 100, Time: 0.8455 seconds
Partitions: 200, Time: 1.0131 seconds

Optimal partition count: 20 with 0.5944 seconds


                                                                                

#### Avoiding Unnecessary Shuffles

One of the best optimizations is to avoid unnecessary shuffles entirely. Let's explore techniques for this:

In [30]:
# Technique 1: Use broadcast joins to avoid shuffles
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10m")

# Join with shuffle
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
shuffle_join = spark.table("transactions_table").filter(col("year") == 2022).join(
    spark.table("customers_table"),
    "customer_id"
)

# Join with broadcast to avoid shuffle
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50m")
broadcast_join = spark.table("transactions_table").filter(col("year") == 2022).join(
    broadcast(spark.table("customers_table")),
    "customer_id"
)

print("Join with shuffle:")
shuffle_join.explain()

print("\nJoin with broadcast (no shuffle):")
broadcast_join.explain()

# Measure performance
shuffle_time, _ = time_execution_with_metrics(shuffle_join, name="Join with shuffle")
broadcast_time, _ = time_execution_with_metrics(broadcast_join, name="Join with broadcast")

print(f"\nPerformance improvement: {shuffle_time/broadcast_time:.2f}x faster with broadcast")

Join with shuffle:
== Physical Plan ==
*(2) Project [customer_id#6370, transaction_id#6369, transaction_date#6371, items#6372, payment_method#6373, status#6374, total_amount#6375, store_id#6376, metadata#6377, year#6378, month#6379, name#6698, email#6699, address#6700, phone#6701, signup_date#6702, last_activity#6703, tier#6704, preferences#6705]
+- *(2) BroadcastHashJoin [customer_id#6370], [customer_id#6697], Inner, BuildLeft, false
   :- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[1, int, false] as bigint)),false), [plan_id=4731]
   :  +- *(1) Filter isnotnull(customer_id#6370)
   :     +- *(1) ColumnarToRow
   :        +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [isnotnull(customer_id#6370)], Format: Parquet, Location: InMemoryFileIndex(12 paths)[file:/

In [31]:
# Technique 2: Use map-side aggregation to reduce shuffle data volume

# First check the configuration
print(f"spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"spark.sql.adaptive.enabled: {spark.conf.get('spark.sql.adaptive.enabled')}")

# Standard aggregation (full shuffle)
standard_agg = spark.table("transactions_table").groupBy("customer_id").sum("total_amount")

# Two-phase aggregation (map-side + reduce)
# We manually repartition on the groupBy key first to colocate data
two_phase_agg = spark.table("transactions_table") \
    .repartition(20, "customer_id") \
    .groupBy("customer_id") \
    .sum("total_amount")

print("Standard Aggregation Plan:")
standard_agg.explain()

print("\nTwo-Phase Aggregation Plan:")
two_phase_agg.explain()

# Measure performance
standard_time, _ = time_execution_with_metrics(standard_agg, name="Standard aggregation")
two_phase_time, _ = time_execution_with_metrics(two_phase_agg, name="Two-phase aggregation")

print(f"\nPerformance comparison: Two-phase is {standard_time/two_phase_time:.2f}x vs standard")

spark.sql.shuffle.partitions: 20
spark.sql.adaptive.enabled: false
Standard Aggregation Plan:
== Physical Plan ==
*(2) HashAggregate(keys=[customer_id#6370], functions=[sum(total_amount#6375)])
+- Exchange hashpartitioning(customer_id#6370, 20), ENSURE_REQUIREMENTS, [plan_id=5008]
   +- *(1) HashAggregate(keys=[customer_id#6370], functions=[partial_sum(total_amount#6375)])
      +- *(1) Project [customer_id#6370, total_amount#6375]
         +- *(1) ColumnarToRow
            +- FileScan parquet spark_catalog.default.transactions_table[customer_id#6370,total_amount#6375,year#6378,month#6379] Batched: true, DataFilters: [], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<customer_id:int,total_amount:double>



Two-Phase Aggregation Plan:
== Physical Plan ==
*(2) HashAggregate(keys=[customer_id#6370], functions=[sum(total_amount#6375)])
+- *(2) HashAggregate(keys=[

In [32]:
# Technique 3: Coalesce instead of repartition when reducing partitions

# Repartition creates a full shuffle
repartition_df = spark.table("transactions_table").repartition(10)

# Coalesce avoids a full shuffle when reducing partitions
coalesce_df = spark.table("transactions_table").coalesce(10)

print("Repartition Plan (full shuffle):")
repartition_df.explain()

print("\nCoalesce Plan (partial shuffle):")
coalesce_df.explain()

# Measure performance
repartition_time, _ = time_execution_with_metrics(repartition_df, name="Repartition")
coalesce_time, _ = time_execution_with_metrics(coalesce_df, name="Coalesce")

print(f"\nPerformance comparison: Coalesce is {repartition_time/coalesce_time:.2f}x faster than repartition")

Repartition Plan (full shuffle):
== Physical Plan ==
Exchange RoundRobinPartitioning(10), REPARTITION_BY_NUM, [plan_id=5200]
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#6379] Batched: true, DataFilters: [], Format: Parquet, Location: CatalogFileIndex(1 paths)[file:/opt/spark/work-dir/spark-warehouse/transactions_table], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transaction_id:int,customer_id:int,transaction_date:timestamp,items:array<struct<product_i...



Coalesce Plan (partial shuffle):
== Physical Plan ==
Coalesce 10
+- *(1) ColumnarToRow
   +- FileScan parquet spark_catalog.default.transactions_table[transaction_id#6369,customer_id#6370,transaction_date#6371,items#6372,payment_method#6373,status#6374,total_amount#6375,store_id#6376,metadata#6377,year#6378,month#

#### Adaptive Query Execution for Shuffle Optimization

Adaptive Query Execution (AQE) can dynamically optimize shuffles at runtime. Let's explore its impact:

In [33]:
# Let's test with and without AQE
test_query = """
SELECT 
    c.tier,
    t.payment_method,
    COUNT(*) as transaction_count,
    SUM(t.total_amount) as total_sales
FROM 
    transactions_table t
JOIN 
    customers_table c
ON 
    t.customer_id = c.customer_id
GROUP BY 
    c.tier, t.payment_method
ORDER BY 
    total_sales DESC
"""

# Set a high shuffle partition count to see AQE's impact
spark.conf.set("spark.sql.shuffle.partitions", 200)

# Test without AQE
spark.conf.set("spark.sql.adaptive.enabled", "false")
non_adaptive_df = spark.sql(test_query)

print("Plan without Adaptive Execution:")
non_adaptive_df.explain()

# Test with AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
adaptive_df = spark.sql(test_query)

print("\nPlan with Adaptive Execution:")
adaptive_df.explain("formatted")

# Measure performance
non_adaptive_time, _ = time_execution_with_metrics(non_adaptive_df, name="Non-adaptive execution")
adaptive_time, _ = time_execution_with_metrics(adaptive_df, name="Adaptive execution")

print(f"\nPerformance comparison: AQE is {non_adaptive_time/adaptive_time:.2f}x faster")

# Reset configurations
spark.conf.set("spark.sql.shuffle.partitions", optimal_partitions)

Plan without Adaptive Execution:
== Physical Plan ==
*(4) Sort [total_sales#7904 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(total_sales#7904 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=5393]
   +- *(3) HashAggregate(keys=[tier#6704, payment_method#6373], functions=[count(1), sum(total_amount#6375)])
      +- Exchange hashpartitioning(tier#6704, payment_method#6373, 200), ENSURE_REQUIREMENTS, [plan_id=5389]
         +- *(2) HashAggregate(keys=[tier#6704, payment_method#6373], functions=[partial_count(1), partial_sum(total_amount#6375)])
            +- *(2) Project [payment_method#6373, total_amount#6375, tier#6704]
               +- *(2) BroadcastHashJoin [customer_id#6370], [customer_id#6697], Inner, BuildRight, false
                  :- *(2) Project [customer_id#6370, payment_method#6373, total_amount#6375]
                  :  +- *(2) Filter isnotnull(customer_id#6370)
                  :     +- *(2) ColumnarToRow
                  :        +- FileScan parquet 

#### Advanced Shuffle Optimization Techniques

1. **Tune Shuffle Partitions**: 
   - For small-medium datasets: Use 2-3× the number of CPU cores
   - For large datasets: Consider 1-2× the number of cores per GB of shuffle data
   - Monitor task duration - target 50-200ms per task

2. **Avoid Unnecessary Shuffles**: 
   - Use broadcast joins when possible
   - Use coalesce instead of repartition when reducing partitions
   - Reuse existing partitioning when possible

3. **Optimize Data Size**:
   - Pre-aggregate or filter data before shuffling
   - Select only required columns before shuffling
   - Use efficient serialization formats

4. **Balance Data Distribution**:
   - Address data skew (as covered in join optimization)
   - Consider custom partitioners for better balance

5. **Use Adaptive Execution**:
   - Enable adaptive query execution for dynamic optimization
   - Let Spark automatically adjust partition counts
   - Enable skew join optimization

6. **Monitor Shuffle Service**:
   - Check for external shuffle service health
   - Ensure sufficient disk space for shuffle files
   - Consider configuring `spark.local.dir` for faster disks