# 03 - Join Optimization: Sort-Merge vs. Broadcast Hash Join

**Objective:** Compare performance of different join strategies in Spark.

This notebook demonstrates:
1. **Sort-Merge Join**: Default join for large tables (requires shuffle)
2. **Broadcast Hash Join**: Optimized join for small dimension tables
3. **Impact of Z-Ordering**: Whether Z-Ordered Delta tables change join performance
4. **Physical Plan Analysis**: Understanding Spark's execution plan

---

## Setup and Imports

In [None]:
# Add src directory to path
import sys
from pathlib import Path

notebook_dir = Path.cwd()
project_root = notebook_dir.parent
src_dir = project_root / "src"
sys.path.insert(0, str(src_dir))

print(f"Project root: {project_root}")
print(f"Src directory: {src_dir}")

In [None]:
# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Import project modules
from config import (
    get_data_path,
    FACT_SALES_TABLE,
    DIM_CUSTOMERS_TABLE,
    SPARK_APP_NAME,
    PLOTS_DIR
)
from benchmark_utils import BenchmarkTimer, print_benchmark_summary

# Set plotting style
sns.set_theme(style="whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ All imports successful")

## Initialize Spark Session

We'll configure Spark with a small, fixed number of shuffle partitions and an explicit broadcast threshold so that join behavior is reproducible across runs.

In [None]:
# Create Spark session with Delta Lake support
spark = (
    SparkSession.builder
    .appName(f"{SPARK_APP_NAME} - Join Optimization")
    .master("local[*]")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.driver.memory", "4g")
    .config("spark.sql.shuffle.partitions", "8")
    # Configure broadcast join threshold (10MB default)
    .config("spark.sql.autoBroadcastJoinThreshold", "10485760")  # 10 MB
    .getOrCreate()
)

print(f"✓ Spark {spark.version} session initialized")
print(f"✓ App Name: {spark.sparkContext.appName}")
print(f"✓ Broadcast Join Threshold: {spark.conf.get('spark.sql.autoBroadcastJoinThreshold')} bytes")

## Load Data from Parquet Format

Scenarios A and B read both tables from Parquet, a columnar format, so each query scans only the columns it needs. Scenarios C and D repeat the joins against the Delta Lake copies.

In [None]:
# Load fact_sales from Parquet
print("Loading fact_sales table...")
fact_sales_path = str(get_data_path("parquet", FACT_SALES_TABLE))
fact_sales = spark.read.parquet(fact_sales_path)
print(f"✓ Loaded {fact_sales.count():,} sales records")
print("\nSchema:")
fact_sales.printSchema()

In [None]:
# Load dim_customers from Parquet
print("Loading dim_customers table...")
dim_customers_path = str(get_data_path("parquet", DIM_CUSTOMERS_TABLE))
dim_customers = spark.read.parquet(dim_customers_path)
print(f"✓ Loaded {dim_customers.count():,} customer records")
print("\nSchema:")
dim_customers.printSchema()

## Understanding Join Strategies

### Sort-Merge Join
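- Both sides are shuffled so that matching keys land in the same partition, then each partition is sorted by the join key and merged
- Requires a shuffle (data movement across nodes) unless both sides are already partitioned on the join key
- Suited to joins where both tables are large
- Usually the more expensive strategy because of the shuffle and the sort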
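
The merge step described above can be illustrated without Spark. The following is a minimal single-machine sketch, assuming rows are plain dictionaries; the function name `sort_merge_join` and the sample rows are made up for illustration, and the sketch leaves out the shuffle that Spark performs before sorting.

```python
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]


def sort_merge_join(left: List[Row], right: List[Row], key: str) -> List[Tuple[Row, Row]]:
    """Inner-join two lists of dict rows on `key` by sorting both, then merging."""
    # Step 1: sort both sides by the join key.
    left_sorted = sorted(left, key=lambda r: r[key])
    right_sorted = sorted(right, key=lambda r: r[key])

    pairs: List[Tuple[Row, Row]] = []
    i = j = 0
    # Step 2: advance two cursors, always moving the side with the smaller key.
    while i < len(left_sorted) and j < len(right_sorted):
        left_key = left_sorted[i][key]
        right_key = right_sorted[j][key]
        if left_key < right_key:
            i += 1
        elif left_key > right_key:
            j += 1
        else:
            # Find the run of right rows that share this key.
            j_end = j
            while j_end < len(right_sorted) and right_sorted[j_end][key] == left_key:
                j_end += 1
            # Pair every left row holding this key with every row in the run.
            while i < len(left_sorted) and left_sorted[i][key] == left_key:
                for right_row in right_sorted[j:j_end]:
                    pairs.append((left_sorted[i], right_row))
                i += 1
            j = j_end
    return pairs


# Hypothetical sample data, not taken from the benchmark tables.
sales = [
    {"customer_id": 2, "amount": 10.0},
    {"customer_id": 1, "amount": 5.0},
    {"customer_id": 2, "amount": 7.5},
    {"customer_id": 9, "amount": 1.0},
]
customers = [
    {"customer_id": 1, "region": "North"},
    {"customer_id": 2, "region": "South"},
    {"customer_id": 3, "region": "East"},
]

pairs = sort_merge_join(sales, customers, "customer_id")
for sale, customer in pairs:
    print(sale["customer_id"], customer["region"], sale["amount"])
```

Customer 9 has no match and customer 3 has no sales, so neither appears in the inner-join result.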

### Broadcast Hash Join
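- The small table is collected to the driver and broadcast to every executor, where it is built into a hash table
- The large table is not shuffled
- Well suited to large-small joins such as fact-to-dimension lookups
- Chosen automatically only when the small side's estimated size is at most `spark.sql.autoBroadcastJoinThreshold`; the broadcast table must also fit in driver and executor memory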
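
The build-and-probe idea behind a broadcast hash join can be sketched in plain Python. This is an illustration under simplifying assumptions, not Spark's implementation: the helper names `build_hash_table` and `probe` are hypothetical, the "partitions" are ordinary lists, and "broadcasting" is modelled by reusing one dictionary for every partition.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, Iterator, List

Row = Dict[str, Any]


def build_hash_table(small: Iterable[Row], key: str) -> Dict[Any, List[Row]]:
    """Build phase: index the small table by join key."""
    table: Dict[Any, List[Row]] = defaultdict(list)
    for row in small:
        table[row[key]].append(row)
    return dict(table)


def probe(partition: Iterable[Row], table: Dict[Any, List[Row]], key: str) -> Iterator[Row]:
    """Probe phase: stream one partition of the large table past the hash table."""
    for row in partition:
        for match in table.get(row[key], []):
            yield {**match, **row}


# Hypothetical data: the fact table is split into two partitions that never move.
fact_partitions = [
    [{"customer_id": 1, "amount": 5.0}, {"customer_id": 2, "amount": 10.0}],
    [{"customer_id": 2, "amount": 7.5}, {"customer_id": 9, "amount": 1.0}],
]
dim = [
    {"customer_id": 1, "region": "North"},
    {"customer_id": 2, "region": "South"},
]

hash_table = build_hash_table(dim, "customer_id")  # built once, shared by all partitions
joined = [
    row
    for partition in fact_partitions
    for row in probe(partition, hash_table, "customer_id")
]
print(joined)
```

Each fact partition is joined where it already sits, which is why the large side needs no shuffle.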

## Scenario A: Sort-Merge Join (Without Optimization)

First, we'll disable Adaptive Query Execution (AQE) so the plan cannot change at runtime, then set the broadcast threshold to -1 so that Spark falls back to a Sort-Merge Join.

In [None]:
# Disable Adaptive Query Execution for controlled testing
spark.conf.set("spark.sql.adaptive.enabled", "false")
print("✓ AQE disabled for controlled join testing")
print(f"✓ spark.sql.adaptive.enabled = {spark.conf.get('spark.sql.adaptive.enabled')}")

In [None]:
# Set broadcast threshold to -1 to disable automatic broadcast joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
print("✓ Broadcast joins disabled (threshold = -1)")
print("✓ This will force Sort-Merge Join")

In [None]:
# Perform Sort-Merge Join
with BenchmarkTimer(
    "Sort-Merge Join (Parquet)",
    description="Join fact_sales with dim_customers using Sort-Merge",
    spark=spark,
    clear_cache=True
):
    # Reload data to ensure clean state
    sales_df = spark.read.parquet(fact_sales_path)
    customers_df = spark.read.parquet(dim_customers_path)
    
    # Perform join
    joined_df = sales_df.join(
        customers_df,
        on="customer_id",
        how="inner"
    )
    
    # Trigger execution with an aggregation
    result = joined_df.groupBy("region").agg(
        F.sum("amount").alias("total_sales"),
        F.count("*").alias("num_transactions")
    ).collect()
    
    print(f"\nRegions found: {len(result)}")
    for row in result:
        print(f"  {row['region']}: ${row['total_sales']:,.2f} ({row['num_transactions']:,} transactions)")

In [None]:
# Examine the physical plan to verify Sort-Merge Join was used
print("\nPhysical Plan (Sort-Merge Join):")
print("="*80)
sales_df = spark.read.parquet(fact_sales_path)
customers_df = spark.read.parquet(dim_customers_path)
joined_df = sales_df.join(customers_df, on="customer_id", how="inner")
joined_df.explain(mode="formatted")
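
Reading the plan by eye works, but the check can also be automated by searching the plan text for Spark's physical join operator names. The sketch below assumes the plan is already available as a string; one possible way to obtain it in a live session is to wrap `joined_df.explain(...)` in `contextlib.redirect_stdout`, since PySpark prints the plan. The helper name `join_operators_in_plan` is hypothetical, and `sample_plan` is a hand-written, abbreviated illustration rather than real output from this notebook.

```python
import re
from typing import List

# Physical join operator names as they appear in Spark plan output.
JOIN_OPERATORS = (
    "BroadcastHashJoin",
    "SortMergeJoin",
    "ShuffledHashJoin",
    "BroadcastNestedLoopJoin",
    "CartesianProduct",
)


def join_operators_in_plan(plan_text: str) -> List[str]:
    """Return the join operators found in a plan string, in JOIN_OPERATORS order."""
    return [op for op in JOIN_OPERATORS if re.search(rf"\b{op}\b", plan_text)]


# Hand-written, abbreviated example of a formatted plan (not real output).
sample_plan = """
== Physical Plan ==
* Project (9)
+- * SortMergeJoin Inner (8)
   :- * Sort (4)
   :  +- Exchange (3)
   +- * Sort (7)
      +- Exchange (6)
"""

found = join_operators_in_plan(sample_plan)
print(found)
assert found == ["SortMergeJoin"], "expected a Sort-Merge Join in this plan"
```

An assertion like the last line turns a silent strategy change, for example after a configuration edit, into an immediate failure instead of a misleading timing.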

## Scenario B: Broadcast Hash Join (Optimized)

Now we'll use the `broadcast()` hint to force a Broadcast Hash Join.
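
How the hint interacts with the threshold can be summarized as a small decision function. This is a simplified model for an inner equi-join, not Spark's planner code: the function name `will_broadcast` is hypothetical, and it ignores join-type restrictions and inaccurate size estimates.

```python
def will_broadcast(
    estimated_size_bytes: int,
    threshold_bytes: int,
    has_broadcast_hint: bool = False,
) -> bool:
    """Simplified model of whether one join side gets broadcast."""
    # An explicit broadcast() hint takes precedence over the threshold.
    if has_broadcast_hint:
        return True
    # A threshold of -1 disables automatic broadcasting.
    if threshold_bytes < 0:
        return False
    # Otherwise the estimated size must not exceed the threshold.
    return 0 <= estimated_size_bytes <= threshold_bytes


TEN_MB = 10 * 1024 * 1024  # 10485760 bytes, the value configured in this notebook

print(will_broadcast(5 * 1024 * 1024, TEN_MB))         # small table, automatic broadcast
print(will_broadcast(50 * 1024 * 1024, TEN_MB))        # too large without a hint
print(will_broadcast(50 * 1024 * 1024, TEN_MB, True))  # the hint overrides the threshold
print(will_broadcast(5 * 1024 * 1024, -1))             # automatic broadcast disabled
```

Because the hint bypasses the size check, it is worth confirming that the hinted table really is small before relying on it.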

In [None]:
# Re-enable automatic broadcast joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")  # 10 MB
print("✓ Broadcast joins re-enabled")
print(f"✓ Threshold: {spark.conf.get('spark.sql.autoBroadcastJoinThreshold')} bytes")

In [None]:
# Perform Broadcast Hash Join with explicit broadcast hint
with BenchmarkTimer(
    "Broadcast Hash Join (Parquet)",
    description="Join with broadcast hint on dim_customers",
    spark=spark,
    clear_cache=True
):
    # Reload data to ensure clean state
    sales_df = spark.read.parquet(fact_sales_path)
    customers_df = spark.read.parquet(dim_customers_path)
    
    # Perform join with broadcast hint
    joined_df = sales_df.join(
        broadcast(customers_df),  # Broadcast the smaller dimension table
        on="customer_id",
        how="inner"
    )
    
    # Trigger execution with the same aggregation
    result = joined_df.groupBy("region").agg(
        F.sum("amount").alias("total_sales"),
        F.count("*").alias("num_transactions")
    ).collect()
    
    print(f"\nRegions found: {len(result)}")
    for row in result:
        print(f"  {row['region']}: ${row['total_sales']:,.2f} ({row['num_transactions']:,} transactions)")

In [None]:
# Examine the physical plan to verify Broadcast Hash Join was used
print("\nPhysical Plan (Broadcast Hash Join):")
print("="*80)
sales_df = spark.read.parquet(fact_sales_path)
customers_df = spark.read.parquet(dim_customers_path)
joined_df = sales_df.join(broadcast(customers_df), on="customer_id", how="inner")
joined_df.explain(mode="formatted")

## Scenario C: Join with Delta Lake (Z-Ordered)

Next, we repeat the broadcast join against the Delta Lake copies of both tables, which this notebook expects to be Z-Ordered on `customer_id` already.
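
Z-Ordering mainly pays off through data skipping: when rows with similar keys are clustered into the same files, per-file min/max statistics let the reader ignore files that cannot contain a requested key range. The sketch below illustrates that effect with made-up file statistics; the function name `files_to_read` and all numbers are hypothetical. Whether this helps a join that reads every key is exactly what the benchmark below measures, so the sketch should not be read as a prediction.

```python
from typing import List, Tuple

FileStats = Tuple[int, int]  # (min customer_id, max customer_id) stored per file


def files_to_read(files: List[FileStats], lo: int, hi: int) -> int:
    """Count files whose [min, max] range overlaps the requested key range [lo, hi]."""
    return sum(1 for file_min, file_max in files if file_min <= hi and file_max >= lo)


# Unclustered layout: every file spans almost the whole key range.
unclustered = [(1, 1000), (3, 998), (2, 999), (5, 1000)]
# Clustered layout: each file covers a narrow, non-overlapping key range.
clustered = [(1, 250), (251, 500), (501, 750), (751, 1000)]

print(files_to_read(unclustered, 300, 320))  # 4: no file can be skipped
print(files_to_read(clustered, 300, 320))    # 1: three files are skipped
```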

In [None]:
# Re-enable AQE for Delta Lake testing
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")
print("✓ AQE re-enabled for Delta Lake testing")

In [None]:
# Perform join with Delta Lake (Z-Ordered on customer_id)
with BenchmarkTimer(
    "Broadcast Hash Join (Delta Z-Ordered)",
    description="Join with Z-Ordered Delta table",
    spark=spark,
    clear_cache=True
):
    # Load from Delta Lake
    delta_sales_path = str(get_data_path("delta", FACT_SALES_TABLE))
    delta_customers_path = str(get_data_path("delta", DIM_CUSTOMERS_TABLE))
    
    sales_df = spark.read.format("delta").load(delta_sales_path)
    customers_df = spark.read.format("delta").load(delta_customers_path)
    
    # Perform join with broadcast hint
    joined_df = sales_df.join(
        broadcast(customers_df),
        on="customer_id",
        how="inner"
    )
    
    # Trigger execution
    result = joined_df.groupBy("region").agg(
        F.sum("amount").alias("total_sales"),
        F.count("*").alias("num_transactions")
    ).collect()
    
    print(f"\nRegions found: {len(result)}")
    for row in result:
        print(f"  {row['region']}: ${row['total_sales']:,.2f} ({row['num_transactions']:,} transactions)")

## Scenario D: Sort-Merge Join with Delta (for comparison)

In [None]:
# Disable broadcast for Delta Sort-Merge comparison
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.adaptive.enabled", "false")

with BenchmarkTimer(
    "Sort-Merge Join (Delta Z-Ordered)",
    description="Sort-Merge with Z-Ordered Delta table",
    spark=spark,
    clear_cache=True
):
    # Load from Delta Lake
    delta_sales_path = str(get_data_path("delta", FACT_SALES_TABLE))
    delta_customers_path = str(get_data_path("delta", DIM_CUSTOMERS_TABLE))
    
    sales_df = spark.read.format("delta").load(delta_sales_path)
    customers_df = spark.read.format("delta").load(delta_customers_path)
    
    # Perform join without broadcast hint
    joined_df = sales_df.join(
        customers_df,
        on="customer_id",
        how="inner"
    )
    
    # Trigger execution
    result = joined_df.groupBy("region").agg(
        F.sum("amount").alias("total_sales"),
        F.count("*").alias("num_transactions")
    ).collect()
    
    print(f"\nRegions found: {len(result)}")

# Reset configurations
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")
spark.conf.set("spark.sql.adaptive.enabled", "true")

## Results Analysis and Visualization

In [None]:
# Load benchmark results for join tests
from config import BENCHMARK_LOG_FILE
import csv

join_results = []
with open(BENCHMARK_LOG_FILE, 'r', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        if 'Join' in row['test_name'] and row['status'] == 'SUCCESS':
            join_results.append(row)

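# Create DataFrame for analysis
results_df = pd.DataFrame(join_results)
if results_df.empty:
    # Keep the expected columns so the cells below do not fail on an empty log
    results_df = pd.DataFrame(columns=['test_name', 'duration_seconds'])
results_df['duration_seconds'] = results_df['duration_seconds'].astype(float)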

# Display recent join benchmarks
print("\nJoin Benchmark Results:")
print(results_df[['test_name', 'duration_seconds']].tail(10).to_string(index=False))

In [None]:
# Visualize join performance comparison
# Assumes the last four logged rows are Scenarios A-D from this run, in that order
recent_joins = results_df.tail(4).copy()

if len(recent_joins) >= 4:
    fig, ax = plt.subplots(figsize=(12, 6))
    
    # Shorten test names for display
    labels = [
        'Sort-Merge\n(Parquet)',
        'Broadcast\n(Parquet)',
        'Broadcast\n(Delta Z-Order)',
        'Sort-Merge\n(Delta Z-Order)'
    ]
    
    durations = recent_joins['duration_seconds'].values
    colors = ['#e74c3c', '#2ecc71', '#3498db', '#f39c12']
    
    bars = ax.bar(labels, durations, color=colors, width=0.6)
    
    # Add value labels on bars
    for bar, duration in zip(bars, durations):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{duration:.2f}s',
                ha='center', va='bottom', fontweight='bold', fontsize=11)
    
    ax.set_title('Join Strategy Performance Comparison', fontsize=16, fontweight='bold')
    ax.set_ylabel('Duration (seconds)', fontsize=12)
    ax.set_xlabel('Join Strategy', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    
    # Calculate and display speedup (guard the divisor, which is the broadcast time)
    if durations[1] > 0:
        speedup = durations[0] / durations[1]
        ax.text(0.98, 0.98, f'Broadcast Speedup: {speedup:.1f}x',
                transform=ax.transAxes, fontsize=12, verticalalignment='top',
                horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig(PLOTS_DIR / 'join_performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Plot saved to: {PLOTS_DIR / 'join_performance_comparison.png'}")
else:
    print("⚠ Not enough join benchmark data for visualization")

## Performance Analysis and Insights

In [None]:
# Calculate performance improvements
# Rows 0 and 1 are assumed to be the Parquet Sort-Merge and Broadcast runs (Scenarios A and B)
if len(recent_joins) >= 2:
    sort_merge_time = recent_joins.iloc[0]['duration_seconds']
    broadcast_time = recent_joins.iloc[1]['duration_seconds']
    
    improvement = ((sort_merge_time - broadcast_time) / sort_merge_time) * 100
    speedup = sort_merge_time / broadcast_time
    
    print("\n" + "="*80)
    print("PERFORMANCE ANALYSIS")
    print("="*80)
    print(f"\nSort-Merge Join Time:     {sort_merge_time:.3f} seconds")
    print(f"Broadcast Hash Join Time: {broadcast_time:.3f} seconds")
    print(f"\nImprovement: {improvement:.1f}%")
    print(f"Speedup:     {speedup:.2f}x faster")
    print("="*80)
else:
    print("⚠ Not enough data for performance analysis")
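
The positional lookups above depend on the order of rows in the benchmark log. A more defensive alternative is to select results by `test_name`. The sketch below works on rows shaped like those returned by `csv.DictReader`, using the `test_name`, `status`, and `duration_seconds` columns read earlier; the helper names, the `FAILED` status value, and the sample numbers are assumptions made for illustration.

```python
from typing import Dict, Iterable, Mapping, Optional


def latest_durations(rows: Iterable[Mapping[str, str]]) -> Dict[str, float]:
    """Map each test_name to the duration of its most recent SUCCESS row."""
    latest: Dict[str, float] = {}
    for row in rows:  # rows arrive in file order, so later rows overwrite earlier ones
        if row["status"] == "SUCCESS":
            latest[row["test_name"]] = float(row["duration_seconds"])
    return latest


def speedup(baseline_s: float, candidate_s: float) -> Optional[float]:
    """Return baseline / candidate, or None when the candidate time is not positive."""
    if candidate_s <= 0:
        return None
    return baseline_s / candidate_s


# Made-up log rows; the durations are not measurements.
sample_rows = [
    {"test_name": "Sort-Merge Join (Parquet)", "status": "SUCCESS", "duration_seconds": "8.0"},
    {"test_name": "Broadcast Hash Join (Parquet)", "status": "FAILED", "duration_seconds": "0.5"},
    {"test_name": "Broadcast Hash Join (Parquet)", "status": "SUCCESS", "duration_seconds": "2.0"},
    {"test_name": "Sort-Merge Join (Parquet)", "status": "SUCCESS", "duration_seconds": "6.0"},
]

durations = latest_durations(sample_rows)
ratio = speedup(
    durations["Sort-Merge Join (Parquet)"],
    durations["Broadcast Hash Join (Parquet)"],
)
print(durations)
print(ratio)
```

Keying on the test name keeps the comparison correct even when the log contains older runs or a scenario was re-run out of order.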

## Summary Report

In [None]:
# Print comprehensive summary
print_benchmark_summary()

print("\n" + "="*80)
print("KEY FINDINGS - JOIN OPTIMIZATION")
print("="*80)
print("""
1. BROADCAST HASH JOIN:
   - Usually faster than Sort-Merge Join when one side is small (see the measured speedup above)
   - No shuffle of the fact table - the dimension table is sent to every executor
   - Ideal for large fact table + small dimension table (star schema)
   - Auto-selected only below spark.sql.autoBroadcastJoinThreshold (default 10 MB)

2. SORT-MERGE JOIN:
   - Spark's default when neither side is small enough to broadcast
   - Involves expensive shuffle operations
   - Both sides must be sorted by join key
   - Bucketing both tables on the join key can avoid the shuffle

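3. DELTA LAKE Z-ORDERING:
   - Clusters rows with similar key values into the same files
   - Can reduce I/O through data skipping when queries filter on the Z-Ordered column
   - Does not remove the shuffle or sort of a Sort-Merge Join (compare Scenarios A and D)
   - Benefit for unfiltered joins should be judged from the measured results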

4. ADAPTIVE QUERY EXECUTION (AQE):
   - Dynamically optimizes joins at runtime
   - Can convert Sort-Merge to Broadcast automatically
   - Reduces shuffle overhead when possible

5. BEST PRACTICES:
   - Use broadcast() hint only for dimension tables that fit comfortably in memory (rule of thumb: < 100MB)
   - Consider bucketing large tables on join keys that are joined repeatedly
   - Enable AQE in production (spark.sql.adaptive.enabled=true)
   - Monitor Spark UI to verify join strategy selection
   - Use Z-Ordering on columns that are frequently filtered or joined in Delta Lake
""")
print("="*80)
print("\n✓ Join optimization benchmark completed!")
print("✓ All benchmarks finished. Check results/ directory for logs and plots.")
print("="*80)