# PySpark Best Practices

This notebook covers key best practices and optimization techniques for working with PySpark.

## 1. SparkSession Configuration

Properly configuring your SparkSession is the first step toward optimized performance.

In [1]:
from pyspark.sql import SparkSession

# Create a well-configured SparkSession
spark = (SparkSession.builder
    .appName("PySpark Best Practices")
    .config("spark.sql.shuffle.partitions", 200)
    .config("spark.executor.memory", "2g")
    .config("spark.driver.memory", "2g")
    .config("spark.default.parallelism", 8)
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

# Get the current configuration
print(spark.sparkContext.getConf().getAll())



[('spark.executor.memory', '2g'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.startTime', '1744958925326'), ('spark.app.id', 'local-174495

### Key Configuration Parameters

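- **spark.sql.shuffle.partitions**: Number of partitions used for DataFrame/SQL shuffles such as joins and aggregations (default: 200)
- **spark.executor.memory**: Heap memory per executor
- **spark.driver.memory**: Heap memory for the driver process
- **spark.default.parallelism**: Default number of partitions for RDD operations such as `parallelize` and `reduceByKey`; DataFrame shuffles use `spark.sql.shuffle.partitions` instead
- **spark.sql.adaptive.enabled**: Enables adaptive query execution (AQE), which re-optimizes the plan at runtime (on by default since Spark 3.2)
- **spark.sql.adaptive.coalescePartitions.enabled**: Lets AQE merge small shuffle partitions after a shuffle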

💡 **Best Practice**: Tune these parameters based on your cluster size and workload characteristics.
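
One way to turn that advice into a starting number is a small sizing heuristic. The sketch below is not part of Spark: the function name, the 128 MiB target partition size, and the floor of two tasks per core are assumptions chosen for illustration, and the result is only a first guess to refine with the Spark UI.

```python
import math

MIB = 1024 * 1024

def suggest_shuffle_partitions(shuffle_bytes, total_cores,
                               target_partition_bytes=128 * MIB,
                               tasks_per_core=2):
    """Hypothetical helper: estimate a starting value for spark.sql.shuffle.partitions."""
    # Enough partitions that each one holds roughly target_partition_bytes
    by_size = math.ceil(shuffle_bytes / target_partition_bytes)
    # At least a couple of tasks per core so every core stays busy
    by_cores = total_cores * tasks_per_core
    return max(by_size, by_cores, 1)

# 10 GiB of shuffle data on 8 cores: data size dominates
print(suggest_shuffle_partitions(10 * 1024 * MIB, total_cores=8))   # 80
# 100 MiB on 8 cores: the per-core floor dominates
print(suggest_shuffle_partitions(100 * MIB, total_cores=8))         # 16
```

The result could be passed to `.config("spark.sql.shuffle.partitions", ...)`. With adaptive query execution enabled, Spark may still coalesce small partitions at runtime.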

## 2. Data Loading and Schema Management

In [2]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

# Define schema explicitly
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("salary", DoubleType(), True)
])

# Create sample data
data = [
    (1, "Alice", 30, 50000.0),
    (2, "Bob", 32, 60000.0),
    (3, "Charlie", 28, 55000.0)
]

# Create DataFrame with explicit schema
df = spark.createDataFrame(data, schema=schema)
df.printSchema()
df.show()

root
 |-- id: integer (nullable = false)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- salary: double (nullable = true)

+---+-------+---+-------+
| id|   name|age| salary|
+---+-------+---+-------+
|  1|  Alice| 30|50000.0|
|  2|    Bob| 32|60000.0|
|  3|Charlie| 28|55000.0|
+---+-------+---+-------+



### Best Practices for Data Loading

1. **Always Define Schema Explicitly**
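   - Avoids costly schema inference (CSV and JSON inference needs an extra pass over the data)
   - Ensures correct data types
   - Matters most for large text-based datasets; Parquet and ORC already store their schema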

2. **Use Appropriate File Formats**
   - Parquet is usually the best choice (columnar, compressed, schema-preserved)
   - ORC is good for Hive compatibility
   - Avoid text/CSV for large datasets when possible

3. **Partition Data Appropriately**
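   - Choose low-cardinality partition columns that distribute data evenly (for example, a date or region column)
   - Avoid too many small partitions, since each partition value becomes its own directory of files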

In [3]:
# Example of reading/writing with best practices

# Write data using Parquet format with partitioning
# (age is used here only for illustration; in practice prefer a low-cardinality column)
df.write \
    .mode("overwrite") \
    .partitionBy("age") \
    .parquet("/tmp/example-data")

# Read data with explicit schema
df_read = spark.read \
    .schema(schema) \
    .parquet("/tmp/example-data")

df_read.explain()  # Show execution plan

== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet [id#29,name#30,salary#31,age#32] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/example-data], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,name:string,salary:double>




                                                                                

## 3. DataFrame Operations - Transformations and Actions

In [4]:
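# Note: importing sum, max, and min this way shadows Python's built-ins for the rest of the notebook;
# `import pyspark.sql.functions as F` avoids that when both are needed
from pyspark.sql.functions import col, when, expr, avg, sum, max, min, count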

# Column operations - Good practice
df_transformed = df \
    .select(
        col("id"),
        col("name"),
        col("age"),
        col("salary"),  # Keep the original salary column
        (col("salary") * 1.1).alias("adjusted_salary")
    ) \
    .withColumn(
        "salary_category",
        when(col("salary") < 55000, "Entry")
        .when(col("salary") < 65000, "Mid")
        .otherwise("Senior")
    ) \
    .filter(col("age") > 25)

df_transformed.show()

+---+-------+---+-------+-----------------+---------------+
| id|   name|age| salary|  adjusted_salary|salary_category|
+---+-------+---+-------+-----------------+---------------+
|  1|  Alice| 30|50000.0|55000.00000000001|          Entry|
|  2|    Bob| 32|60000.0|          66000.0|            Mid|
|  3|Charlie| 28|55000.0|60500.00000000001|            Mid|
+---+-------+---+-------+-----------------+---------------+
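
The trailing digits in `adjusted_salary` come from binary floating-point arithmetic, not from Spark. A plain-Python sketch of the same effect follows; `Decimal` stands in here for Spark's `DecimalType`, which is an analogy rather than the same implementation.

```python
from decimal import Decimal

raw = 50000.0 * 1.1
print(raw)                 # not exactly 55000.0: 1.1 has no exact binary representation
print(round(raw, 2))       # 55000.0

# Decimal arithmetic keeps the value exact
exact = Decimal("50000.00") * Decimal("1.1")
print(exact)               # 55000.000
```

In the notebook, wrapping the expression in `round(col("salary") * 1.1, 2)` from `pyspark.sql.functions`, or storing monetary values as `DecimalType`, would give cleaner output.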



### Best Practices for Transformations

1. **Chain Operations Efficiently**
   - Transformations are lazy, so chain them and trigger an action only when you need a result
   - Use method chaining for readability

2. **Use Column Expressions**
   - Prefer `col()` and expressions over UDFs when possible
   - Use SQL functions from `pyspark.sql.functions`

3. **Limit Shuffling Operations**
   - Operations like `groupBy`, `join`, `repartition` cause shuffling
   - Try to minimize these operations

4. **Filter Early**
   - Apply filters as early as possible to reduce data volume

## 4. Optimization Techniques - Caching and Persistence

In [5]:
from pyspark.storagelevel import StorageLevel

# Cache data that will be reused
df_cached = df.cache()  # or df.persist()
df_cached.count()  # Materialize the cache

# Release the cache before persisting the same DataFrame with another storage level;
# otherwise Spark warns "Asked to cache already cached data" and keeps the existing level
df_cached.unpersist()

# More control with specific storage level
df_persisted = df.persist(StorageLevel.MEMORY_AND_DISK)
df_persisted.count()  # Materialize the persistence

# Clean up when done
df_persisted.unpersist()


DataFrame[id: int, name: string, age: int, salary: double]

### Caching Best Practices

1. **When to Cache**
   - Cache DataFrames used multiple times
   - Cache after expensive transformations
   - Cache after filtering down large datasets

2. **Storage Levels**
   - `MEMORY_AND_DISK`: Default for `DataFrame.cache()`; spills to disk when memory runs short
   - `MEMORY_ONLY`: Default for `RDD.cache()`; partitions that do not fit are recomputed when needed
   - `DISK_ONLY`: When memory is limited

3. **Clean Up Cache**
   - Call `unpersist()` when done to free up resources

## 5. Joins and Aggregations

In [6]:
# Create a department DataFrame
dept_data = [(1, "Engineering"), (2, "HR"), (3, "Marketing")]
dept_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("department", StringType(), True)
])
dept_df = spark.createDataFrame(dept_data, schema=dept_schema)

# Broadcast join (efficient for small + large table joins)
from pyspark.sql.functions import broadcast

# Broadcast the smaller DataFrame
joined_df = df.join(broadcast(dept_df), df.id == dept_df.id)
joined_df.explain()
joined_df.show()

# Efficient aggregations
agg_df = df.groupBy("age").agg(
    count("id").alias("count"),
    avg("salary").alias("avg_salary"),
    max("salary").alias("max_salary")
)

agg_df.show()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- BroadcastHashJoin [id#0], [id#313], Inner, BuildRight, false
   :- Scan ExistingRDD[id#0,name#1,age#2,salary#3]
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=166]
      +- Scan ExistingRDD[id#313,department#314]


+---+-------+---+-------+---+-----------+
| id|   name|age| salary| id| department|
+---+-------+---+-------+---+-----------+
|  1|  Alice| 30|50000.0|  1|Engineering|
|  2|    Bob| 32|60000.0|  2|         HR|
|  3|Charlie| 28|55000.0|  3|  Marketing|
+---+-------+---+-------+---+-----------+

+---+-----+----------+----------+
|age|count|avg_salary|max_salary|
+---+-----+----------+----------+
| 30|    1|   50000.0|   50000.0|
| 32|    1|   60000.0|   60000.0|
| 28|    1|   55000.0|   55000.0|
+---+-----+----------+----------+



### Best Practices for Joins and Aggregations

1. **Join Strategies**
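   - Use broadcast joins when one DataFrame is small; Spark broadcasts automatically below `spark.sql.autoBroadcastJoinThreshold` (default: 10MB)
   - Watch for skewed join keys, where a few values account for most rows and overload individual tasks
   - State join conditions explicitly to avoid accidental cross joins; the DataFrame API and SQL compile to the same plans

2. **Join Types**
   - Inner joins usually produce the least data
   - Left/right outer joins preserve all rows from one side
   - Full outer joins preserve both sides and are typically the most expensive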

3. **Efficient Aggregations**
   - Combine multiple aggregations in a single call
   - Filter before grouping when possible
   - Consider `approx_count_distinct()` for approximate counts
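
To see why a broadcast join avoids a shuffle, it can help to picture what each executor does: it receives a full copy of the small table, builds a hash map from it, and probes that map with its local rows of the large table. The plain-Python sketch below mimics that idea with lists of dicts. It is a conceptual model, not Spark's actual implementation, and the function name and the extra "Dana" row are made up for this example.

```python
from collections import defaultdict

def broadcast_hash_join(large_rows, small_rows, key):
    """Conceptual inner join: hash the small side, stream the large side."""
    # Build phase: index the small ("broadcast") side by join key
    lookup = defaultdict(list)
    for row in small_rows:
        lookup[row[key]].append(row)

    # Probe phase: each large-side row is matched locally, with no data movement
    for row in large_rows:
        for match in lookup.get(row[key], []):
            yield {**row, **match}

employees = [
    {"id": 1, "name": "Alice"},
    {"id": 2, "name": "Bob"},
    {"id": 4, "name": "Dana"},   # hypothetical row with no matching department
]
departments = [
    {"id": 1, "department": "Engineering"},
    {"id": 2, "department": "HR"},
]

joined = list(broadcast_hash_join(employees, departments, "id"))
print(joined)
```

The small side has to fit in memory on every executor, which is why the size threshold matters.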

## 6. Partitioning and Bucketing

In [7]:
# Repartitioning to control parallelism
df_repartitioned = df.repartition(8)
print(f"Number of partitions: {df_repartitioned.rdd.getNumPartitions()}")

# Repartition by specific column (good for joins)
df_repart_by_col = df.repartition("age")

# Coalesce to reduce partitions (no shuffle)
df_coalesced = df_repartitioned.coalesce(2)
print(f"Number of partitions after coalesce: {df_coalesced.rdd.getNumPartitions()}")

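# Write with bucketing (good for repeated joins)
# mode("overwrite") lets the cell be re-run without a "table already exists" error
df.write \
    .bucketBy(4, "id") \
    .sortBy("id") \
    .mode("overwrite") \
    .saveAsTable("bucketed_table")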

Number of partitions: 8
Number of partitions after coalesce: 2


### Partitioning Best Practices

1. **Choosing Number of Partitions**
   - Rule of thumb: 2-3 × number of CPU cores
   - Too few: underutilization, potential OOM
   - Too many: task scheduling overhead

2. **When to Repartition**
   - Before several wide operations (joins, groupBy) on the same key, by repartitioning on that key once
   - When partition sizes are skewed
   - When the number of partitions is too low or too high

3. **Coalesce vs. Repartition**
   - Use `coalesce()` to reduce partitions (no shuffle)
   - Use `repartition()` to increase partitions or rebalance uneven ones (full shuffle)

4. **Bucketing**
   - Good for repeated joins on same column
   - Pre-organizes data to avoid shuffles
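
The points about skew and partition counts can be illustrated without a cluster. The sketch below places rows into partitions by hashing a key, loosely in the spirit of `repartition(col)`. It uses Python's built-in `hash` on integers, whereas Spark uses its own Murmur3-based hash, so the exact partition numbers are not what Spark would produce; only the shape of the result carries over. The helper name is made up for this example.

```python
from collections import Counter

def partition_sizes(keys, num_partitions):
    """Count rows per partition when rows are placed by hash(key) % num_partitions."""
    sizes = Counter(hash(k) % num_partitions for k in keys)
    return [sizes.get(p, 0) for p in range(num_partitions)]

# Only three distinct ages: at most three of the eight partitions receive data
ages = [28, 30, 32] * 100
print(partition_sizes(ages, 8))        # [100, 0, 0, 0, 100, 0, 100, 0]

# A skewed key: one value dominates, so one partition does most of the work
skewed = [1] * 900 + list(range(2, 102))
sizes = partition_sizes(skewed, 8)
print(max(sizes), min(sizes))          # 912 12
```

Empty partitions waste scheduling effort, and the single oversized partition determines how long the stage takes.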

## 7. UDFs and Performance

In [8]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Avoid this approach for simple operations
@udf(returnType=IntegerType())
def slow_add_one(x):
    if x is not None:
        return x + 1
    return None

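# UDF version: rows are serialized to a Python worker and back
df_slow = df.withColumn("age_plus_one", slow_add_one(col("age")))
# Prefer built-in functions (much faster): the expression is evaluated inside the JVM
df_fast = df.withColumn("age_plus_one", col("age") + 1)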

# Compare execution plans
print("With UDF:")
df_slow.explain()
print("\nWith built-in function:")
df_fast.explain()

With UDF:
== Physical Plan ==
*(2) Project [id#0, name#1, age#2, salary#3, pythonUDF0#413 AS age_plus_one#401]
+- BatchEvalPython [slow_add_one(age#2)#400], [pythonUDF0#413]
   +- *(1) Scan ExistingRDD[id#0,name#1,age#2,salary#3]



With built-in function:
== Physical Plan ==
*(1) Project [id#0, name#1, age#2, salary#3, (age#2 + 1) AS age_plus_one#407]
+- *(1) Scan ExistingRDD[id#0,name#1,age#2,salary#3]




### UDF Best Practices

1. **Avoid UDFs When Possible**
   - Use built-in functions from `pyspark.sql.functions`
   - Use SQL expressions with `expr()`
   - UDFs require serialization/deserialization overhead

2. **When to Use UDFs**
   - Complex logic not available in built-in functions
   - Operations requiring external libraries

3. **Pandas UDFs**
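   - Use Pandas UDFs (vectorized UDFs) when custom Python logic is unavoidable
   - Often much faster than row-at-a-time UDFs because data is processed in batches; the gain depends on the workload
   - Requires PyArrow, which transfers the batches between the JVM and Python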

In [9]:
! pip install pandas pyarrow



In [10]:
# Pandas UDF example (vectorized; typically faster than a regular UDF)
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf(IntegerType())
def pandas_add_one(s: pd.Series) -> pd.Series:
    return s + 1

df_pandas_udf = df.withColumn("age_plus_one", pandas_add_one(col("age")))
df_pandas_udf.show()

                                                                                

+---+-------+---+-------+------------+
| id|   name|age| salary|age_plus_one|
+---+-------+---+-------+------------+
|  1|  Alice| 30|50000.0|          31|
|  2|    Bob| 32|60000.0|          33|
|  3|Charlie| 28|55000.0|          29|
+---+-------+---+-------+------------+



## 8. Window Functions

In [11]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lead, lag, sum

# Create window spec
window_spec = Window.partitionBy("age").orderBy("salary")

# Apply window functions
df_window = df.withColumn("rank", rank().over(window_spec)) \
              .withColumn("dense_rank", dense_rank().over(window_spec)) \
              .withColumn("row_number", row_number().over(window_spec)) \
              .withColumn("next_salary", lead("salary", 1).over(window_spec)) \
              .withColumn("prev_salary", lag("salary", 1).over(window_spec))

# Window for running totals
sum_window = Window.partitionBy("age").orderBy("salary").rowsBetween(Window.unboundedPreceding, Window.currentRow)
df_window = df_window.withColumn("running_total", sum("salary").over(sum_window))

df_window.show()

+---+-------+---+-------+----+----------+----------+-----------+-----------+-------------+
| id|   name|age| salary|rank|dense_rank|row_number|next_salary|prev_salary|running_total|
+---+-------+---+-------+----+----------+----------+-----------+-----------+-------------+
|  3|Charlie| 28|55000.0|   1|         1|         1|       NULL|       NULL|      55000.0|
|  1|  Alice| 30|50000.0|   1|         1|         1|       NULL|       NULL|      50000.0|
|  2|    Bob| 32|60000.0|   1|         1|         1|       NULL|       NULL|      60000.0|
+---+-------+---+-------+----+----------+----------+-----------+-----------+-------------+
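
Because every `age` in the sample data is unique, each window partition in this output holds a single row, so all three ranking columns show 1 and `lead`/`lag` return NULL. The ranking functions only differ when a partition contains ties. A plain-Python sketch of the SQL semantics follows; the helper name and the salary values are made up for this example.

```python
def rank_rows(values):
    """Return (value, rank, dense_rank, row_number) tuples for values in ascending order."""
    out = []
    rank = dense = 0
    prev = object()                      # sentinel that equals no real value
    for row_number, value in enumerate(sorted(values), start=1):
        if value != prev:
            rank = row_number            # rank jumps to the current position after ties
            dense += 1                   # dense_rank advances by one per distinct value
            prev = value
        out.append((value, rank, dense, row_number))
    return out

for row in rank_rows([50000, 55000, 55000, 60000]):
    print(row)
# (50000, 1, 1, 1)
# (55000, 2, 2, 2)
# (55000, 2, 2, 3)
# (60000, 4, 3, 4)
```

`rank` leaves a gap after the tie, `dense_rank` does not, and `row_number` is always unique within the partition.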



### Window Function Best Practices

1. **Reuse Window Specifications**
   - Define window specs once and reuse them
   - Improves readability, and functions that share a spec can be evaluated in the same Window step

2. **Window Function Types**
   - Ranking functions: `rank()`, `dense_rank()`, `row_number()`
   - Analytic functions: `lead()`, `lag()`
   - Aggregate functions: `sum()`, `avg()`, `min()`, `max()`

3. **Bounded vs. Unbounded Windows**
   - Bounded windows (e.g., `rowsBetween(-2, 2)`) touch only a few rows per frame and are usually cheaper
   - Unbounded windows need more resources

4. **Partitioning Considerations**
   - Partition by columns with reasonable cardinality
   - Very few, large partitions concentrate work in a handful of tasks; very many tiny partitions add overhead
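
The frame arguments are row offsets relative to the current row. As a rough model of how `rowsBetween(start, end)` selects rows for `sum`, here is a plain-Python sketch in which `None` stands in for `Window.unboundedPreceding` or `Window.unboundedFollowing`. That mapping and the function name are conventions of this example only.

```python
import builtins

def rows_between_sum(values, start, end):
    """Sum over a row-based frame; start/end are offsets from the current row (None = unbounded)."""
    n = len(values)
    result = []
    for i in range(n):
        lo = 0 if start is None else max(0, i + start)
        hi = n - 1 if end is None else min(n - 1, i + end)
        if hi < lo:
            result.append(None)          # empty frame: SQL aggregates yield NULL
        else:
            # builtins.sum, because the notebook imports pyspark's sum under the same name
            result.append(builtins.sum(values[lo:hi + 1]))
    return result

salaries = [10, 20, 30, 40]
print(rows_between_sum(salaries, None, 0))   # running total: [10, 30, 60, 100]
print(rows_between_sum(salaries, -1, 1))     # centered 3-row frame: [30, 60, 90, 70]
```

The running total corresponds to `rowsBetween(Window.unboundedPreceding, Window.currentRow)` as used in the cell for this section.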

## 9. Performance Monitoring

In [12]:
# Get execution plan
df_complex = df.join(dept_df, df.id == dept_df.id) \
               .groupBy("department") \
               .agg(avg("salary").alias("avg_salary")) \
               .filter(col("avg_salary") > 50000)

# Physical plan only (use explain(extended=True) to also print the logical plans)
print("Physical Plan:")
df_complex.explain()

print("\nDetailed Physical Plan:")
df_complex.explain("formatted")

# Time an action; the result is stored as row_count so the imported count() function is not overwritten
from time import time
start = time()
row_count = df_complex.count()
end = time()
print(f"Count: {row_count}, Time: {end - start:.2f}s")

Physical Plan:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Filter (isnotnull(avg_salary#573) AND (avg_salary#573 > 50000.0))
   +- HashAggregate(keys=[department#314], functions=[avg(salary#3)])
      +- Exchange hashpartitioning(department#314, 200), ENSURE_REQUIREMENTS, [plan_id=464]
         +- HashAggregate(keys=[department#314], functions=[partial_avg(salary#3)])
            +- Project [salary#3, department#314]
               +- SortMergeJoin [id#0], [id#313], Inner
                  :- Sort [id#0 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(id#0, 200), ENSURE_REQUIREMENTS, [plan_id=456]
                  :     +- Project [id#0, salary#3]
                  :        +- Scan ExistingRDD[id#0,name#1,age#2,salary#3]
                  +- Sort [id#313 ASC NULLS FIRST], false, 0
                     +- Exchange hashpartitioning(id#313, 200), ENSURE_REQUIREMENTS, [plan_id=457]
                        +- Scan ExistingRDD[id#313,department#314]

### Performance Monitoring Best Practices

1. **Use explain() to Understand Plans**
   - Check physical plan for expensive operations
   - Look for broadcast joins, exchange (shuffle) operations

2. **Monitor Spark UI**
   - Available at http://localhost:4040 by default
   - Check stage durations, executor utilization
   - Identify skew in task durations

3. **Performance Metrics to Watch**
   - Shuffle read/write size
   - Spill (memory and disk)
   - Task durations
   - Fraction of cached data held in memory (Storage tab)
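
For quick measurements outside the Spark UI, a small timing helper keeps the bookkeeping out of the way. This is a minimal sketch using only the standard library; the `timed` name and the stand-in workload are made up for this example, and in the notebook the body would be a Spark action such as `df_complex.count()`.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results=None):
    """Measure wall-clock time of the enclosed block with a monotonic clock."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if results is not None:
            results[label] = elapsed     # keep the number for later comparison
        print(f"{label}: {elapsed:.2f}s")

timings = {}
with timed("sort", timings):
    # Stand-in workload; replace with a Spark action when running in the notebook
    data = sorted(range(100_000), reverse=True)
```

Because Spark evaluates lazily, the first timed action also pays for all upstream transformations, so comparing a first run with a repeat run can be misleading unless caching is taken into account.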

## 10. Common Optimizations

In [13]:
# 1. Broadcast variables for lookup tables
lookup_dict = {1: "A", 2: "B", 3: "C"}
lookup_broadcast = spark.sparkContext.broadcast(lookup_dict)

@udf(StringType())
def lookup_value(key):
    return lookup_broadcast.value.get(key, "Unknown")

df.withColumn("lookup_result", lookup_value(col("id"))).show()

# 2. Avoid collect() on large DataFrames
# Good: Take just what you need
sample_rows = df.limit(10).collect()

# 3. Use SQL when it's more intuitive
df.createOrReplaceTempView("employees")
dept_df.createOrReplaceTempView("departments")

sql_result = spark.sql("""
    SELECT d.department, AVG(e.salary) as avg_salary
    FROM employees e JOIN departments d ON e.id = d.id
    GROUP BY d.department
    HAVING AVG(e.salary) > 50000
""")

sql_result.show()

+---+-------+---+-------+-------------+
| id|   name|age| salary|lookup_result|
+---+-------+---+-------+-------------+
|  1|  Alice| 30|50000.0|            A|
|  2|    Bob| 32|60000.0|            B|
|  3|Charlie| 28|55000.0|            C|
+---+-------+---+-------+-------------+

+----------+----------+
|department|avg_salary|
+----------+----------+
|        HR|   60000.0|
| Marketing|   55000.0|
+----------+----------+



### Summary of PySpark Best Practices

1. **Data Management**
   - Use Parquet format for efficient storage
   - Define schemas explicitly
   - Partition data wisely

2. **Performance Optimization**
   - Cache/persist reused DataFrames
   - Use broadcast joins for small tables
   - Minimize shuffling operations
   - Filter early to reduce data volume

3. **Code Practices**
   - Prefer built-in functions over UDFs
   - Use Pandas UDFs when a UDF is unavoidable
   - Chain transformations efficiently
   - Monitor the Spark UI and physical plans

4. **Resource Management**
   - Tune partitioning based on cluster size
   - Configure memory settings appropriately
   - Clean up cached data when no longer needed