# Spark Partitioning and Optimization

This notebook covers essential concepts for optimizing Apache Spark applications through effective partitioning strategies, minimizing shuffles, and leveraging caching mechanisms.

## Table of Contents
1. [Data Partitioning in Spark](#data-partitioning)
2. [Shuffle Operations](#shuffle-operations)
3. [Broadcast Joins](#broadcast-joins)
4. [Partition Pruning](#partition-pruning)
5. [Caching and Persistence](#caching-persistence)
6. Additional Optimization Tips
7. [Key Takeaways](#takeaways)

---

## 1. Data Partitioning in Spark <a id='data-partitioning'></a>

**Partitioning** is the process of dividing data into smaller, manageable chunks (partitions) that can be processed in parallel across a cluster.

### Why Partitioning Matters

| Aspect | Impact |
|--------|--------|
| **Parallelism** | More partitions allow more tasks to run in parallel, up to the number of available cores |
| **Memory** | Oversized partitions can spill to disk or cause out-of-memory errors on executors |
| **Shuffles** | Poor partitioning leads to expensive data movement |
| **Skew** | Uneven partitions cause stragglers |

### Types of Partitioning

1. **Hash Partitioning**: Distributes data based on hash of partition key
2. **Range Partitioning**: Distributes data based on ordered ranges
3. **Round-Robin Partitioning**: Distributes data evenly regardless of content
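
The three strategies above can be sketched in plain Python. This is only an illustration of the assignment rules, not Spark's implementation: `zlib.crc32` stands in for Spark's internal hash function, and the helper names and the `boundaries` list are hypothetical.

```python
import bisect
import zlib


def hash_partition(key, num_partitions):
    """Hash partitioning: the same key always maps to the same partition."""
    # CRC32 is a stand-in here; Spark uses its own hash function internally.
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def range_partition(key, boundaries):
    """Range partitioning: sorted boundaries split the key space into ranges."""
    return bisect.bisect_right(boundaries, key)


def round_robin_partition(row_index, num_partitions):
    """Round-robin partitioning: rows are dealt out in turn, ignoring content."""
    return row_index % num_partitions


rows = [(i, f"user_{i % 100}") for i in range(1000)]
num_partitions = 4
boundaries = [250, 500, 750]  # ids below 250 go to partition 0, and so on

hash_sizes = [0] * num_partitions
range_sizes = [0] * num_partitions
rr_sizes = [0] * num_partitions

for index, (row_id, user) in enumerate(rows):
    hash_sizes[hash_partition(user, num_partitions)] += 1
    range_sizes[range_partition(row_id, boundaries)] += 1
    rr_sizes[round_robin_partition(index, num_partitions)] += 1

print("hash:       ", hash_sizes)   # sizes depend on how the keys hash
print("range:      ", range_sizes)  # [250, 250, 250, 250]
print("round-robin:", rr_sizes)     # [250, 250, 250, 250]
```

Only hash and range partitioning keep related rows together, which is why they matter for joins and aggregations; round-robin is mainly useful for evening out partition sizes.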

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast, spark_partition_id
import pyspark.sql.functions as F

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("PartitioningOptimization") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.sql.adaptive.enabled", "true") \
    .getOrCreate()

print(f"Spark version: {spark.version}")
print(f"Default parallelism: {spark.sparkContext.defaultParallelism}")

In [None]:
# Create sample data
data = [(i, f"user_{i % 100}", i * 10) for i in range(10000)]
df = spark.createDataFrame(data, ["id", "user", "amount"])

# Check current number of partitions
print(f"Number of partitions: {df.rdd.getNumPartitions()}")

# View partition distribution
df.withColumn("partition_id", spark_partition_id()) \
    .groupBy("partition_id") \
    .count() \
    .orderBy("partition_id") \
    .show(5)

### Controlling Partitions

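- **`repartition(n)`**: Increases or decreases the number of partitions (causes a full shuffle)
- **`coalesce(n)`**: Decreases the number of partitions by merging existing ones (avoids a full shuffle, but can leave partition sizes uneven)
- **`repartition(col)`**: Hash-partitions by the given column(s), so rows with the same key land in the same partition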
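
To see why `coalesce` is cheaper but can leave data unbalanced, here is a simplified pure-Python model. It assumes partitions are plain lists; the `repartition` and `coalesce` functions below are hypothetical stand-ins that mimic the behavior described above, not Spark's actual algorithms (Spark's `coalesce` also considers data locality, for example).

```python
def repartition(partitions, n):
    """Full shuffle: every row is redistributed round-robin over n partitions."""
    new_parts = [[] for _ in range(n)]
    position = 0
    for part in partitions:
        for row in part:
            new_parts[position % n].append(row)
            position += 1
    return new_parts


def coalesce(partitions, n):
    """Narrow merge: whole input partitions are concatenated, rows stay together."""
    n = min(n, len(partitions))  # coalesce cannot increase the partition count
    new_parts = [[] for _ in range(n)]
    for idx, part in enumerate(partitions):
        new_parts[idx * n // len(partitions)].extend(part)
    return new_parts


# Eight deliberately uneven input partitions (26 rows in total)
sizes = [10, 1, 1, 1, 10, 1, 1, 1]
partitions = [[(p, i) for i in range(size)] for p, size in enumerate(sizes)]

coalesced = coalesce(partitions, 4)
repartitioned = repartition(partitions, 4)

print([len(p) for p in coalesced])      # [11, 2, 11, 2]
print([len(p) for p in repartitioned])  # [7, 7, 6, 6]
```

The merged result inherits the imbalance of its inputs, while the full redistribution evens it out at the cost of moving every row.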

In [None]:
# Repartition by number
df_8_parts = df.repartition(8)
print(f"After repartition(8): {df_8_parts.rdd.getNumPartitions()} partitions")

# Coalesce to reduce partitions (no shuffle)
df_4_parts = df_8_parts.coalesce(4)
print(f"After coalesce(4): {df_4_parts.rdd.getNumPartitions()} partitions")

# Repartition by column (useful for joins)
df_by_user = df.repartition(10, "user")
print(f"After repartition by 'user': {df_by_user.rdd.getNumPartitions()} partitions")

# Check partition distribution after repartitioning by column
df_by_user.withColumn("partition_id", spark_partition_id()) \
    .groupBy("partition_id") \
    .count() \
    .orderBy("partition_id") \
    .show()

---

## 2. Shuffle Operations <a id='shuffle-operations'></a>

A **shuffle** is Spark's mechanism for redistributing data across partitions. It's one of the most expensive operations because it involves serialization, disk I/O, and network transfer between executors.

### Operations That Cause Shuffles

| Operation | Description |
|-----------|-------------|
| `groupBy()` | Aggregates require data with same key on same partition |
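| `join()` | Sort-merge join (the default for large tables) shuffles both datasets unless they are already partitioned on the join key |
| `distinct()` | Identical rows must land on the same partition to be deduplicated |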
| `repartition()` | Explicitly redistributes data |
| `orderBy()` / `sort()` | Global ordering requires data movement |
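
What these operations have in common is that rows with the same key must end up in the same partition. A rough pure-Python sketch of the idea follows; the `shuffle` function is hypothetical, partitions are plain lists, and `zlib.crc32` stands in for Spark's hash function.

```python
import zlib


def shuffle(input_partitions, num_output_partitions):
    """Redistribute (key, value) records so each key lives in one output partition."""
    # Map side: every input partition splits its records into one bucket per target.
    map_outputs = []
    for part in input_partitions:
        buckets = [[] for _ in range(num_output_partitions)]
        for key, value in part:
            target = zlib.crc32(key.encode("utf-8")) % num_output_partitions
            buckets[target].append((key, value))
        map_outputs.append(buckets)

    # Reduce side: every output partition collects its bucket from all inputs.
    return [
        [record for buckets in map_outputs for record in buckets[out_idx]]
        for out_idx in range(num_output_partitions)
    ]


inputs = [
    [("user_a", 1), ("user_b", 2)],
    [("user_a", 3), ("user_c", 4)],
    [("user_b", 5), ("user_c", 6)],
]
outputs = shuffle(inputs, 2)

# Record which output partition(s) each key ended up in
key_locations = {}
for idx, part in enumerate(outputs):
    for key, _ in part:
        key_locations.setdefault(key, set()).add(idx)

print(key_locations)  # every key maps to exactly one partition index
```

In a real cluster the map-side buckets are written to local disk and fetched over the network, which is where most of the cost comes from.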

### Minimizing Shuffles

1. **Pre-partition data** on join keys
2. **Use broadcast joins** for small tables
3. **Filter early** to reduce data volume
4. **Use `reduceByKey`** instead of `groupByKey` (for RDDs)
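
The last point can be made concrete by counting how many records would cross the shuffle boundary in each case. This is a pure-Python sketch with hypothetical function names, not the RDD API itself; it only models the map-side combine that makes `reduceByKey` cheaper.

```python
from collections import defaultdict


def group_then_sum(partitions):
    """groupByKey-style: every record is shuffled, then summed per key."""
    shuffled = [record for part in partitions for record in part]
    totals = defaultdict(int)
    for key, value in shuffled:
        totals[key] += value
    return dict(totals), len(shuffled)


def combine_then_sum(partitions):
    """reduceByKey-style: sum per key inside each partition before shuffling."""
    shuffled = []
    for part in partitions:
        local = defaultdict(int)
        for key, value in part:
            local[key] += value
        shuffled.extend(local.items())  # at most one record per key per partition
    totals = defaultdict(int)
    for key, value in shuffled:
        totals[key] += value
    return dict(totals), len(shuffled)


# Four partitions of 100 records each, spread over five distinct keys
partitions = [
    [(f"user_{i % 5}", i) for i in range(start, start + 100)]
    for start in (0, 100, 200, 300)
]

totals_grouped, shuffled_grouped = group_then_sum(partitions)
totals_combined, shuffled_combined = combine_then_sum(partitions)

print(shuffled_grouped)   # 400
print(shuffled_combined)  # 20
print(totals_grouped == totals_combined)  # True
```

DataFrame aggregations such as `groupBy().agg()` already apply a comparable partial aggregation before the shuffle, so this advice mainly concerns code written against the RDD API.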

In [None]:
# Example: Shuffle caused by groupBy
aggregated_df = df.groupBy("user").agg(
    F.sum("amount").alias("total_amount"),
    F.count("*").alias("transaction_count")
)

# Explain the execution plan to see shuffle
print("=== GroupBy Execution Plan ===")
aggregated_df.explain()

In [None]:
# Create two DataFrames for join example
orders = spark.createDataFrame(
    [(1, "user_1", 100), (2, "user_2", 200), (3, "user_1", 150)],
    ["order_id", "user_id", "amount"]
)

users = spark.createDataFrame(
    [("user_1", "Alice"), ("user_2", "Bob"), ("user_3", "Charlie")],
    ["user_id", "name"]
)

# Default join causes shuffle on both sides
joined_df = orders.join(users, "user_id")
print("=== Default Join (Sort-Merge) Execution Plan ===")
joined_df.explain()

In [None]:
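# Pre-partition both sides on the join key.
# Note: repartition() is itself a shuffle, so this pays off mainly when the
# partitioned DataFrames are reused (for example, cached) across several joins.
orders_partitioned = orders.repartition(4, "user_id")
users_partitioned = users.repartition(4, "user_id")

# Join: both sides are already hash-partitioned on user_id, so the plan
# should not need an additional Exchange right before the join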
joined_optimized = orders_partitioned.join(users_partitioned, "user_id")
print("=== Pre-partitioned Join Execution Plan ===")
joined_optimized.explain()

---

## 3. Broadcast Joins <a id='broadcast-joins'></a>

**Broadcast joins** send the smaller dataset to all executors, avoiding shuffle of the larger dataset.

### When to Use Broadcast Joins

- One table is **small enough to fit in memory** on the driver and on every executor (auto-broadcast applies below 10 MB by default, configurable)
- Joining a **fact table with dimension tables**
- **Lookup operations** with reference data

### Configuration

```python
# Auto-broadcast threshold (default 10MB)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10485760)  # bytes

# Disable auto-broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
```
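
To make the mechanism behind a broadcast join concrete, the following pure-Python sketch builds a lookup table from the small side once and probes it from each partition of the large side without moving any of the large side's rows. The function name and the tuple layouts are assumptions chosen for illustration.

```python
def broadcast_hash_join(large_partitions, small_table):
    """Join each large partition locally against a copy of the small table.

    Assumes large rows look like (order_id, product_id, quantity) and
    small rows look like (product_id, product_name, price).
    """
    # Build side: hash the small table once; each executor would receive a copy.
    lookup = {}
    for product in small_table:
        lookup.setdefault(product[0], []).append(product)

    # Probe side: rows of the large table never leave their partition.
    joined = []
    for part in large_partitions:
        out = []
        for order in part:
            for product in lookup.get(order[1], []):
                out.append(order + product[1:])
        joined.append(out)
    return joined


products = [(f"product_{i}", f"Product Name {i}", i * 10.0) for i in range(10)]
order_partitions = [
    [(i, f"product_{i % 10}", i * 5) for i in range(start, start + 5)]
    for start in (0, 5, 10)
]

joined = broadcast_hash_join(order_partitions, products)
print([len(part) for part in joined])  # [5, 5, 5]
print(joined[0][0])  # (0, 'product_0', 0, 'Product Name 0', 0.0)
```

The partition layout of the large side is unchanged after the join, which is the property that makes broadcast joins attractive when one side is small.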

In [None]:
# Create a large orders table and small lookup table
large_orders = spark.createDataFrame(
    [(i, f"product_{i % 10}", i * 5) for i in range(10000)],
    ["order_id", "product_id", "quantity"]
)

# Small dimension/lookup table
products = spark.createDataFrame(
    [(f"product_{i}", f"Product Name {i}", i * 10.0) for i in range(10)],
    ["product_id", "product_name", "price"]
)

print(f"Orders count: {large_orders.count()}")
print(f"Products count: {products.count()}")

In [None]:
# Explicit broadcast join
from pyspark.sql.functions import broadcast

broadcast_joined = large_orders.join(
    broadcast(products),  # Broadcast the small table
    "product_id"
)

print("=== Broadcast Join Execution Plan ===")
broadcast_joined.explain()

# Notice: No Exchange (shuffle) on the large table side

In [None]:
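# Compare performance: Regular join vs Broadcast join
# Note: a single timing on a dataset this small is noisy (JVM warm-up,
# scheduling overhead), so treat the numbers as illustrative only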
import time

# Disable auto-broadcast for comparison
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Regular Sort-Merge Join
start = time.time()
regular_join = large_orders.join(products, "product_id")
regular_join.count()  # Force execution
regular_time = time.time() - start

# Broadcast Join
start = time.time()
broadcast_join = large_orders.join(broadcast(products), "product_id")
broadcast_join.count()  # Force execution
broadcast_time = time.time() - start

print(f"Regular Join Time: {regular_time:.3f}s")
print(f"Broadcast Join Time: {broadcast_time:.3f}s")
print(f"Speedup: {regular_time / broadcast_time:.2f}x")

# Re-enable auto-broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10485760)

---

## 4. Partition Pruning <a id='partition-pruning'></a>

**Partition pruning** is an optimization that skips reading data partitions that don't match filter conditions.

### Types of Partition Pruning

1. **Static Partition Pruning**: Filter on partition column is known at query planning time (for example, a literal value in the query)
2. **Dynamic Partition Pruning (DPP)**: Filter values determined at runtime from another table

### Best Practices

- Partition data by commonly filtered columns (date, region, etc.)
- Use **Parquet** or **Delta Lake** for efficient partition pruning
- Filter on partition columns **before** other operations
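
Pruning works because a partitioned write encodes the partition values in directory names (`column=value`), so a reader can discard whole directories without opening any file. A minimal pure-Python sketch of that idea follows; the paths and file names are made up for illustration and the helper functions are hypothetical.

```python
def parse_partition_values(path):
    """Extract column=value pairs from a Hive-style partition path."""
    values = {}
    for segment in path.split("/"):
        if "=" in segment:
            column, value = segment.split("=", 1)
            values[column] = value
    return values


def prune(paths, filters):
    """Keep only the paths whose partition values match every filter."""
    return [
        path
        for path in paths
        if all(parse_partition_values(path).get(col) == val
               for col, val in filters.items())
    ]


# Hypothetical layout: 3 dates x 5 regions = 15 partition directories
paths = [
    f"sales_partitioned/sale_date=2024-0{month}-01/region=region_{region}/part-00000.parquet"
    for month in (1, 2, 3)
    for region in range(5)
]

selected = prune(paths, {"sale_date": "2024-01-01", "region": "region_0"})
print(len(paths), "->", len(selected))  # 15 -> 1
```

A filter on a non-partition column such as `amount` cannot benefit from this directory-level skipping, although Parquet may still skip row groups using file statistics.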

In [None]:
# Create partitioned data and write to disk
import tempfile
import os

# Sample sales data
sales_data = spark.createDataFrame(
    [(i, f"2024-0{(i % 3) + 1}-01", f"region_{i % 5}", i * 100) 
     for i in range(1000)],
    ["sale_id", "sale_date", "region", "amount"]
)

# Create temp directory for partitioned data
temp_dir = tempfile.mkdtemp()
partitioned_path = os.path.join(temp_dir, "sales_partitioned")

# Write partitioned by date and region
sales_data.write \
    .partitionBy("sale_date", "region") \
    .parquet(partitioned_path)

print(f"Data written to: {partitioned_path}")

In [None]:
# Read partitioned data
partitioned_sales = spark.read.parquet(partitioned_path)

# Static Partition Pruning - filter on partition column
filtered_sales = partitioned_sales.filter(
    (col("sale_date") == "2024-01-01") & 
    (col("region") == "region_0")
)

print("=== Static Partition Pruning ===")
filtered_sales.explain()
# Look for "PartitionFilters" in the plan - shows pruned partitions

In [None]:
# Dynamic Partition Pruning Example
# Enable DPP (enabled by default in Spark 3.0+)
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

# Small dimension table listing the regions to keep
regions_to_include = spark.createDataFrame(
    [("region_0",), ("region_1",)],
    ["region"]
)

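# Join with the partitioned fact table. DPP applies only when the optimizer
# considers it beneficial (typically a selective filter on a small,
# broadcastable dimension side); look for "dynamicpruning" in the plan to confirm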
dpp_result = partitioned_sales.join(
    regions_to_include,
    "region"
)

print("=== Dynamic Partition Pruning ===")
dpp_result.explain(extended=True)

---

## 5. Caching and Persistence <a id='caching-persistence'></a>

**Caching** stores intermediate DataFrames in memory (or disk) to avoid recomputation.

### Storage Levels

| Level | Description |
|-------|-------------|
| `MEMORY_ONLY` | Store as deserialized Java objects in the JVM (default for `RDD.cache()`) |
| `MEMORY_AND_DISK` | Spill to disk if memory is insufficient (default for `DataFrame.cache()`) |
| `MEMORY_ONLY_SER` | Store as serialized objects (more space-efficient); Scala/Java only, not exposed in PySpark |
| `DISK_ONLY` | Store only on disk |
| `OFF_HEAP` | Store in off-heap memory |

### When to Cache

- DataFrame is **used multiple times**
- Result of **expensive computation** (joins, aggregations)
- **Iterative algorithms** (ML training)

### When NOT to Cache

- DataFrame is used only once
- Data is too large to fit in memory
- Computation is simple (reads from Parquet with pushdown)
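
The trade-off can be illustrated with a toy lazy dataset that counts how often its computation runs. The `LazyDataset` class is a hypothetical pure-Python model of lazy evaluation with optional persistence; it does not reflect how Spark manages cached blocks internally.

```python
class LazyDataset:
    """Toy model: recomputes on every action unless persisted."""

    def __init__(self, compute):
        self._compute = compute
        self._persisted = False
        self._cached = None
        self.compute_calls = 0

    def persist(self):
        # Like Spark, this only marks the dataset; nothing is computed yet.
        self._persisted = True
        return self

    def unpersist(self):
        self._persisted = False
        self._cached = None
        return self

    def collect(self):
        if self._persisted and self._cached is not None:
            return self._cached  # served from cache, no recomputation
        self.compute_calls += 1
        result = self._compute()
        if self._persisted:
            self._cached = result
        return result


def expensive():
    return [x * x for x in range(5)]


uncached = LazyDataset(expensive)
cached = LazyDataset(expensive).persist()

for _ in range(3):
    uncached.collect()
    cached.collect()

print(uncached.compute_calls)  # 3
print(cached.compute_calls)    # 1
```

If a dataset is collected only once, both counters would read 1 and the cache would only cost memory, which is the reasoning behind the "used only once" rule above.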

In [None]:
from pyspark import StorageLevel

# Create an expensive DataFrame (simulating complex computation)
expensive_df = large_orders.join(broadcast(products), "product_id") \
    .withColumn("total_value", col("quantity") * col("price")) \
    .groupBy("product_name") \
    .agg(
        F.sum("total_value").alias("revenue"),
        F.count("*").alias("order_count")
    )

# Cache in memory and disk
expensive_df.persist(StorageLevel.MEMORY_AND_DISK)

# Force caching by triggering an action
expensive_df.count()

print("DataFrame cached!")
print(f"Storage Level: {expensive_df.storageLevel}")

In [None]:
# Multiple operations on cached DataFrame (no recomputation)
import time

# First action - may still compute if not fully cached
start = time.time()
print("Top products by revenue:")
expensive_df.orderBy(col("revenue").desc()).show(5)
print(f"Time: {time.time() - start:.3f}s")

# Second action - reads from cache
start = time.time()
print("\nTotal revenue:")
print(expensive_df.agg(F.sum("revenue")).collect()[0][0])
print(f"Time: {time.time() - start:.3f}s")

# Third action - reads from cache
start = time.time()
print(f"\nAverage order count: {expensive_df.agg(F.avg('order_count')).collect()[0][0]:.2f}")
print(f"Time: {time.time() - start:.3f}s")

In [None]:
# Check what's cached
print("=== Cached DataFrames ===")
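# Note: _jsc is a private PySpark attribute and may change between versions;
# the Storage tab of the Spark UI shows the same information
for rdd_id, rdd in spark.sparkContext._jsc.getPersistentRDDs().items():
    print(f"RDD ID: {rdd_id}, Name: {rdd.name()}, Storage Level: {rdd.getStorageLevel()}")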

In [None]:
# Unpersist when done to free memory
expensive_df.unpersist()
print("DataFrame unpersisted!")

# Alternative: for DataFrames, cache() is shorthand for persist() with the
# default storage level, which keeps data in memory and spills to disk
# df.cache()

### Checkpointing vs Caching

| Feature | Caching | Checkpointing |
|---------|---------|---------------|
| **Storage** | Memory/Disk (temporary) | HDFS/S3 (durable) |
| **Lineage** | Preserved | Truncated |
| **Use Case** | Iterative algorithms | Breaking long lineages, fault tolerance |
| **Speed** | Faster | Slower (writes to distributed storage) |
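
The "Lineage" row is the key difference. In the toy model below, each transformation keeps a reference to its parent, and a checkpoint materializes the data and drops that chain. The `Dataset` class is hypothetical and keeps everything in local memory, whereas a real checkpoint writes to the configured checkpoint directory.

```python
class Dataset:
    """Toy dataset that records its lineage as a chain of parents."""

    def __init__(self, data=None, parent=None, fn=None):
        self.data = data
        self.parent = parent
        self.fn = fn

    def map(self, fn):
        # Lazy: only records the transformation and its parent
        return Dataset(parent=self, fn=fn)

    def lineage_depth(self):
        depth, node = 0, self
        while node.parent is not None:
            depth += 1
            node = node.parent
        return depth

    def compute(self):
        if self.parent is None:
            return list(self.data)
        return [self.fn(x) for x in self.parent.compute()]

    def checkpoint(self):
        # Materialize the result and cut the link to all ancestors
        return Dataset(data=self.compute())


node = Dataset(data=[1, 2, 3])
for i in range(10):
    node = node.map(lambda x, inc=i: x + inc)

checkpointed = node.checkpoint()

print(node.lineage_depth())          # 10
print(checkpointed.lineage_depth())  # 0
print(checkpointed.compute())        # [46, 47, 48]
```

After the checkpoint, recovering a lost partition no longer requires replaying the ten transformations, at the price of having written the data out once.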

In [None]:
# Checkpointing example (requires checkpoint directory)
checkpoint_dir = os.path.join(temp_dir, "checkpoints")
spark.sparkContext.setCheckpointDir(checkpoint_dir)

# Checkpoint breaks lineage - useful for iterative algorithms
# Uncomment to use:
# long_lineage_df = df
# for i in range(10):
#     long_lineage_df = long_lineage_df.withColumn(f"col_{i}", col("amount") + i)
# 
# # Checkpoint to truncate lineage
# checkpointed_df = long_lineage_df.checkpoint()

print(f"Checkpoint directory set to: {checkpoint_dir}")

---

## 6. Additional Optimization Tips <a id='additional-tips'></a>

In [None]:
# Adaptive Query Execution (AQE) - Spark 3.0+
# Dynamically optimizes query plans at runtime

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

print("AQE Configuration:")
print(f"  AQE Enabled: {spark.conf.get('spark.sql.adaptive.enabled')}")
print(f"  Coalesce Partitions: {spark.conf.get('spark.sql.adaptive.coalescePartitions.enabled')}")
print(f"  Skew Join: {spark.conf.get('spark.sql.adaptive.skewJoin.enabled')}")

In [None]:
# Column Pruning - Only select needed columns
# Bad: Select all then filter
# all_cols = partitioned_sales.select("*").filter(col("amount") > 500)

# Good: Select only needed columns
pruned_cols = partitioned_sales.select("sale_id", "amount") \
    .filter(col("amount") > 500)

print("=== Column Pruning Plan ===")
pruned_cols.explain()

In [None]:
# Predicate Pushdown - Filter early
# Filters are pushed down to the data source level
# (look for "PushedFilters" in the plan)

filtered_early = spark.read.parquet(partitioned_path) \
    .filter(col("amount") > 50000)  # Filter pushed to Parquet reader

print("=== Predicate Pushdown Plan ===")
filtered_early.explain()

---

## 7. Key Takeaways <a id='takeaways'></a>

### Partitioning Best Practices

| Strategy | When to Use |
|----------|-------------|
| **Use `coalesce()` over `repartition()`** | When reducing partitions (avoids shuffle) |
| **Partition by join keys** | When joining same datasets repeatedly |
| **Match partitions to cluster cores** | 2-4 partitions per CPU core is a common starting point; tune from there based on data size |
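
As a rough way to turn the partitions-per-core guideline into a number, the helper below takes the larger of a core-based and a size-based estimate. The function name, the default of 3 partitions per core, and the 128 MB target partition size are assumptions for illustration, not values prescribed by this notebook.

```python
import math


def suggest_partition_count(total_cores, data_size_mb,
                            partitions_per_core=3, target_partition_mb=128):
    """Heuristic starting point for a partition count (assumed defaults)."""
    by_cores = total_cores * partitions_per_core             # keep every core busy
    by_size = math.ceil(data_size_mb / target_partition_mb)  # keep partitions modest
    return max(by_cores, by_size)


print(suggest_partition_count(total_cores=16, data_size_mb=10240))  # 80
print(suggest_partition_count(total_cores=16, data_size_mb=1024))   # 48
```

Treat the result as a first guess to validate in the Spark UI; with AQE enabled, Spark can also coalesce small shuffle partitions at runtime.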

### Shuffle Optimization

| Technique | Benefit |
|-----------|--------|
| **Broadcast joins** | Eliminates shuffle for small tables |
| **Pre-partitioning** | Reduces shuffle during joins |
| **Filter early** | Less data to shuffle |

### Caching Strategy

| Guideline | Recommendation |
|-----------|----------------|
| **Cache reused DataFrames** | Especially after expensive operations |
| **Use appropriate storage level** | `MEMORY_AND_DISK` for large datasets |
| **Unpersist when done** | Free memory for other operations |
| **Monitor cache usage** | Check Spark UI Storage tab |

### Quick Reference Configurations

```python
# Key optimization configs
spark.conf.set("spark.sql.shuffle.partitions", "200")  # Adjust based on data size
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")  # 10MB
spark.conf.set("spark.sql.adaptive.enabled", "true")  # Enable AQE
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")  # Handle skew
```

In [None]:
# Cleanup
import shutil

# Clean up temp directory
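try:
    shutil.rmtree(temp_dir)
    print(f"Cleaned up: {temp_dir}")
except OSError as exc:
    # Report the failure instead of silently swallowing every exception
    print(f"Could not remove {temp_dir}: {exc}")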

# Stop Spark session
spark.stop()
print("Spark session stopped.")