# Lesson 13 - Broadcast Variables and Accumulators

---

## PySpark Technical Notes: Lesson 13 - Shared Variables: Broadcast and Accumulators

**Objective:** This section covers Spark's mechanisms for sharing information efficiently across the cluster: Broadcast Variables for distributing large read-only data, and Accumulators for reliably aggregating results back from worker nodes to the driver program. Understanding these is crucial for optimizing certain types of Spark jobs and for effective debugging.

### 1. Efficiently Distributing Read-Only Data: Broadcast Variables

**Theory:**

In standard Spark execution, any external variable referenced within a transformation closure (like a function passed to `map`, `filter`, or a UDF) is serialized and sent along with the task's code to each executor responsible for processing a partition. If this external variable is large (e.g., a lookup table, a machine learning model, a large configuration map), sending a copy with potentially thousands of tasks becomes highly inefficient, consuming significant network bandwidth and increasing task serialization time.
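
To get a feel for this cost, the following back-of-the-envelope sketch compares shipping a serialized lookup table with every task against shipping it once per executor. The table contents, the task count, and the executor count are illustrative assumptions, and plain `pickle` stands in for Spark's own closure serializer.

```python
import pickle

# Hypothetical lookup table standing in for a large driver-side variable.
lookup = {f"key_{i}": f"value_{i}" for i in range(50_000)}
payload_bytes = len(pickle.dumps(lookup))

num_tasks = 2_000     # assumed number of tasks across the job
num_executors = 20    # assumed number of executors in the cluster

# Captured in the closure: one copy travels with every task.
per_task_total = payload_bytes * num_tasks
# Shared once: roughly one copy per executor, reused by all of its tasks.
shared_total = payload_bytes * num_executors

print(f"serialized size:        {payload_bytes / 1e6:.2f} MB")
print(f"sent with closures:     {per_task_total / 1e6:.0f} MB")
print(f"sent once per executor: {shared_total / 1e6:.0f} MB")
```

With these assumed numbers the per-task approach moves 100 times more data, because the ratio is simply tasks divided by executors.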

**Broadcast variables** solve this. They let the driver ship a potentially large, *read-only* value to each executor only *once* rather than once per task. Every task that runs on that executor then reads the value from the executor's local cache, avoiding a network transfer per task.

**Execution Flow:**
1.  **Creation:** The driver program creates a broadcast variable from a regular variable using `SparkContext.broadcast(variable)`.
2.  **Serialization & Distribution:** The driver serializes the value and splits it into chunks. Executors fetch these chunks from the driver and from one another through their Block Managers (Spark's default `TorrentBroadcast` uses a BitTorrent-like protocol), so the driver is not the only source.
3.  **Caching:** Each executor reassembles the chunks and caches the complete value in memory, spilling to disk if memory is insufficient.
4.  **Task Execution:** A task that needs the value reads it from the local cache through the `.value` attribute. The task closure carries only a small handle to the broadcast variable, not the data itself.

**When to Use Broadcast Variables:**
*   When you need to share a **read-only** dataset/object with tasks across multiple stages or iterations.
*   When this dataset is **large enough** that sending it with every task closure would be inefficient (as a rough guide, more than a few MB, though the threshold depends on network speed and the number of tasks) but **small enough** to fit comfortably in the memory of each executor.
*   Common examples include lookup tables, configuration settings, machine learning models, stop word lists, etc.
*   Often used to optimize "map-side joins" where one dataset is small enough to be broadcast and joined with partitions of a larger dataset locally on each executor.
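
The map-side join idea can be sketched without a cluster. In the minimal example below, a plain dictionary plays the role of the broadcast value and nested lists play the role of partitions; the names `state_names`, `partitions`, and `join_partition` are hypothetical and chosen only for illustration.

```python
# Small side of the join: fits in memory, so it plays the role of the broadcast value.
state_names = {"CA": "California", "NY": "New York", "TX": "Texas"}

# Large side, already split into partitions (one list per partition).
partitions = [
    [(1, "CA"), (2, "NY")],
    [(3, "TX"), (4, "ZZ")],
]

def join_partition(rows, lookup):
    """Join one partition against the in-memory lookup; no shuffle is needed."""
    return [(row_id, abbr, lookup.get(abbr, "Unknown")) for row_id, abbr in rows]

# Each partition is joined independently, as it would be on separate executors.
joined = [row for part in partitions for row in join_partition(part, state_names)]
print(joined)
```

In PySpark the lookup would be read from a broadcast variable's `.value`, and a function like `join_partition` could be applied with `rdd.mapPartitions`. The large dataset never has to be repartitioned by key, which is what makes this cheaper than a shuffle join.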

**Code Example:**

Let's use a broadcast variable to efficiently map US state abbreviations to full state names within a larger dataset.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Boilerplate Spark Session creation
spark = SparkSession.builder.appName("BroadcastVariableExample").getOrCreate()

# 1. Define the large-ish read-only lookup data (Python dictionary)
state_map = {
    "CA": "California", "NY": "New York", "TX": "Texas", "FL": "Florida",
    "IL": "Illinois", "PA": "Pennsylvania", "OH": "Ohio", "GA": "Georgia",
    "NC": "North Carolina", "MI": "Michigan"
    # ... potentially many more states
}

# 2. Create the broadcast variable on the driver
#    Accessed via spark.sparkContext
broadcast_state_map = spark.sparkContext.broadcast(state_map)

# Verify the type
print(f"Type of broadcast variable: {type(broadcast_state_map)}")

# Sample DataFrame with addresses needing state name lookup
data = [
    (1, "123 Main St", "Anytown", "CA"),
    (2, "456 Oak Ave", "Otherville", "NY"),
    (3, "789 Pine Ln", "Smalltown", "TX"),
    (4, "101 Maple Dr", "Bigcity", "CA"),
    (5, "202 Elm St", "Metropolis", "FL"),
    (6, "303 Birch Rd", "Somewhere", "ZZ"), # Unknown state code
]
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("address", StringType(), True),
    StructField("city", StringType(), True),
    StructField("state_abbr", StringType(), True)
])
address_df = spark.createDataFrame(data, schema)

print("--- Original Address Data ---")
address_df.show()

# 3. Define a UDF (User Defined Function) to perform the lookup using the broadcast variable
def get_state_name(abbr):
    # Access the actual dictionary using .value inside the UDF/closure
    state_dict = broadcast_state_map.value
    return state_dict.get(abbr, "Unknown") # Use .get for default value

# Wrap the Python function as a Spark UDF that returns a string
get_state_name_udf = F.udf(get_state_name, StringType())

# 4. Apply the UDF to the DataFrame
address_df_with_state = address_df.withColumn(
    "state_name", get_state_name_udf(F.col("state_abbr"))
)

print("--- Address Data with State Name (using Broadcast) ---")
address_df_with_state.show()

# Accessing the value on the driver (possible but less common usage)
print(f"\nAccessing broadcast value on driver: {broadcast_state_map.value.get('CA')}")

spark.stop()
```

**Code Explanation:**

1.  **`state_map = {...}`**: Defines a standard Python dictionary on the driver. This represents our lookup data. In a real scenario, this could be loaded from a file or database.
2.  **`broadcast_state_map = spark.sparkContext.broadcast(state_map)`**: This is the core step. We take the `state_map` dictionary and create a `Broadcast` object using the `SparkContext`. Spark will handle serializing and distributing this object efficiently.
3.  **`def get_state_name(abbr): ...`**: Defines a Python function intended to run on executors as part of a UDF.
4.  **`state_dict = broadcast_state_map.value`**: **Crucially**, inside the code that runs on the *executor* (the UDF), we access the actual underlying data (the dictionary) using the `.value` attribute of the broadcast variable.
5.  **`return state_dict.get(abbr, "Unknown")`**: Performs the lookup using the retrieved dictionary. Using `.get()` provides a default value if the abbreviation isn't found.
6.  **`get_state_name_udf = F.udf(...)`**: Wraps the Python function as a Spark UDF that can be applied to DataFrame columns.
7.  **`address_df.withColumn(...)`**: Applies the UDF to the `state_abbr` column. When this transformation is executed, Spark ensures the `broadcast_state_map` is available locally on each executor running the `get_state_name_udf` tasks, and the UDF accesses it via `.value`.
8.  **`print(f"\nAccessing broadcast value on driver: ...")`**: Demonstrates that `.value` can also be used on the driver to retrieve the original object, though the primary benefit is executor-side access.
9.  **`spark.stop()`**: Releases Spark resources.

**Performance & Best Practices:**
*   Broadcast variables are **read-only** after creation. Updates to the original variable on the driver *after* broadcasting will **not** be reflected on the executors.
*   Ensure the broadcast data **fits comfortably in each executor's memory**. Broadcasting excessively large objects can lead to OutOfMemoryErrors or performance degradation due to spilling.
*   Consider the overhead: For very small variables, the overhead of broadcasting might exceed the cost of sending them with task closures. Profile if unsure.
*   Spark SQL's query optimizer often performs **Broadcast Joins** automatically when one side of a join is sufficiently small (controlled by `spark.sql.autoBroadcastJoinThreshold`, 10 MB by default), effectively using broadcasting under the hood without explicit user code. The `pyspark.sql.functions.broadcast(df)` hint requests one explicitly. Manual broadcast variables remain useful for non-join scenarios or when finer control is needed.
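
The read-only point is easy to miss, so here is a small sketch of it. A `pickle` round trip stands in for the serialized snapshot that executors receive; this is a simulation of the behavior under that assumption, not Spark's actual transport.

```python
import pickle

state_map = {"CA": "California", "NY": "New York"}

# Stand-in for broadcasting: executors receive a serialized snapshot.
executor_copy = pickle.loads(pickle.dumps(state_map))

# A later change on the driver touches only the driver's object.
state_map["TX"] = "Texas"

print("TX" in state_map)       # True
print("TX" in executor_copy)   # False
```

To publish changed data, create a new broadcast variable from the updated object. The old handle's `unpersist()` method removes the cached copies from executors, and `destroy()` releases all of its data and metadata.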

---

### 2. Aggregating Information Reliably: Accumulators

**Theory:**

Accumulators are shared variables designed for **associative and commutative "add" operations** performed by tasks running in parallel on executors. Their primary purpose is to reliably aggregate simple information (like counters or sums) back to the driver program, even in the presence of task failures and retries.

Standard variables defined on the driver and modified within task closures (e.g., `my_counter += 1` inside a `map` function) **do not work reliably**. This is because:
1.  Closures capture copies of variables; updates happen on the executor's copy, not the driver's original.
2.  Task failures and retries mean a task might execute multiple times, potentially updating a naive counter incorrectly.
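
The first failure mode can be reproduced in plain Python. In this sketch a `pickle` round trip imitates shipping driver state to a task; `run_task` and the `counter` dictionary are hypothetical stand-ins.

```python
import pickle

counter = {"bad_records": 0}   # lives on the "driver"

def run_task(serialized_state, records):
    """Simulated task: works on its own deserialized copy of the driver state."""
    local_state = pickle.loads(serialized_state)
    for record in records:
        if record is None:
            local_state["bad_records"] += 1
    return local_state["bad_records"]

seen_by_task = run_task(pickle.dumps(counter), ["CA", None, "NY", None])

print(seen_by_task)              # 2: the task's private copy was updated
print(counter["bad_records"])    # 0: the driver's variable never changed
```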

Accumulators solve this by providing a mechanism where:
1.  The driver initializes an accumulator with an identity ("zero") value.
2.  Tasks receive a reference to the accumulator and can only use an `add` operation.
3.  Spark merges these partial updates on the driver as tasks complete. For updates made inside actions, each task's contribution is applied exactly once even if the task is retried (see the caveat below for transformations).
4.  Only the **driver program can read** the accumulator's final value using its `.value` attribute. Tasks cannot read the accumulator's value.
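
The four steps above can be sketched as a small simulation. It assumes the behavior described for actions: a failed attempt reports nothing, and the driver merges exactly one update per partition. The function `run_attempt` and the simulated failure on the second partition are illustrative assumptions.

```python
def run_attempt(partition, fail=False):
    """One task attempt: start from zero, add locally, report only on success."""
    local = 0                      # identity ("zero") value
    for record in partition:
        if record is None:
            local += 1             # the only operation a task performs: add
    if fail:
        raise RuntimeError("simulated executor failure")
    return local

partitions = [["CA", None], [None, None, "NY"], ["TX"]]
driver_total = 0

for index, partition in enumerate(partitions):
    try:
        # Assumption: the first attempt on partition 1 fails and is retried.
        update = run_attempt(partition, fail=(index == 1))
    except RuntimeError:
        update = run_attempt(partition)   # retry; the failed attempt reported nothing
    driver_total += update                # driver merges one update per partition

print(driver_total)   # 3
```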

**Types of Accumulators:**
*   **Numeric:** `SparkContext.accumulator(zeroValue)` supports `int`, `float`, and `complex` values out of the box. The `longAccumulator(name)` and `doubleAccumulator(name)` methods belong to the Scala/Java API and are not available on PySpark's `SparkContext`.
*   **Collection:** The Scala/Java API offers `collectionAccumulator(name)`. PySpark has no built-in equivalent; collect elements with a custom accumulator instead (use with caution for large collections).
*   **Custom:** In PySpark, subclass `AccumulatorParam`, implement `zero` and `addInPlace`, and pass an instance as the second argument to `SparkContext.accumulator`. (`AccumulatorV2` is the Scala/Java base class.)
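
In PySpark, a custom accumulator is described by an `AccumulatorParam` with two methods, `zero` and `addInPlace`. To keep the sketch below runnable without a cluster, `SetUnionParam` is written as a plain Python class with that shape, and `functools.reduce` imitates the driver merging per-task results. The class name and the sample updates are hypothetical.

```python
from functools import reduce

class SetUnionParam:
    """Shape of a PySpark AccumulatorParam: a zero value plus an in-place merge."""

    def zero(self, initial_value):
        # Identity element: merging it with any set leaves that set unchanged.
        return set()

    def addInPlace(self, left, right):
        # Must be associative and commutative; set union is both.
        left |= right
        return left

param = SetUnionParam()

# Simulated per-task partial results (e.g. distinct invalid codes seen by each task).
task_updates = [{"ZZ"}, {"", "ZZ"}, {"XX"}]

merged = reduce(param.addInPlace, task_updates, param.zero(set()))
print(sorted(merged))   # ['', 'XX', 'ZZ']
```

With PySpark installed, the same class would subclass `pyspark.accumulators.AccumulatorParam` and be passed as `sc.accumulator(set(), SetUnionParam())`, after which tasks would call `.add({"ZZ"})`. Because every merged element travels back to the driver, this pattern suits small diagnostic sets only.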

**When to Use Accumulators:**
*   **Debugging:** Counting specific events, errors, or conditions occurring during distributed processing (e.g., number of malformed records, number of times a specific branch of code is hit).
*   **Simple Monitoring:** Gathering basic metrics like total rows processed, number of records filtered out by a certain criterion.
*   **Diagnostic Information:** Collecting a small set of unique IDs or error messages (in PySpark, with a custom `AccumulatorParam` that merges lists or sets).

**Important Caveat:** Spark guarantees that updates within **actions** (like `foreach`, `count`, `collect`, `save`) will be applied only *once* for each task's contribution to the final result. However, updates performed inside **transformations** (like `map`, `filter`) might be executed multiple times if stages are recomputed due to failures or speculative execution. Therefore, rely on accumulator values primarily *after* an action has completed, and be cautious interpreting values updated within transformations, especially for exact counts.
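
Retries are not the only source of over-counting: running two actions over the same uncached lineage re-executes the transformation as well. The sketch below imitates that with a lazy generator; `flag_malformed`, `run_action`, and the `counter` dictionary are hypothetical stand-ins for a transformation, an action, and an accumulator.

```python
counter = {"malformed": 0}   # stand-in for an accumulator updated in a transformation
source = ["CA", "", None, "NY"]

def flag_malformed(records):
    """Lazy 'transformation': nothing runs until something iterates over it."""
    for record in records:
        if not record:
            counter["malformed"] += 1
        yield record

def run_action(records):
    """Each 'action' re-evaluates the uncached lineage from the source."""
    return len(list(flag_malformed(records)))

run_action(source)
after_first_action = counter["malformed"]    # 2
run_action(source)
after_second_action = counter["malformed"]   # 4: the same rows were counted again

print(after_first_action, after_second_action)
```

Caching the intermediate result before the first action reduces this kind of recomputation, but it does not protect against task retries, so exact counts still belong in actions.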

**Code Example:**

Let's use an accumulator to count the number of records with missing or invalid state abbreviations during processing.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Boilerplate Spark Session creation
spark = SparkSession.builder.appName("AccumulatorExample").getOrCreate()

# 1. Initialize an accumulator on the driver
#    SparkContext.accumulator is PySpark's accumulator API (int, float, or complex)
invalid_state_counter = spark.sparkContext.accumulator(0)
# Note: longAccumulator("name") exists only in the Scala/Java API. PySpark
# accumulators are unnamed and are not listed in the Spark UI.

# Sample DataFrame with some potentially invalid state abbreviations
data = [
    (1, "CA"), (2, "NY"), (3, "TX"),
    (4, "CA"), (5, "FL"), (6, "ZZ"), # Invalid
    (7, "PA"), (8, ""),             # Invalid (empty string)
    (9, None)                       # Invalid (null)
]
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("state_abbr", StringType(), True)
])
address_df = spark.createDataFrame(data, schema=schema)

print("--- Original Address Data ---")
address_df.show()

# Valid state codes (could be from a broadcast variable too)
valid_states = {"CA", "NY", "TX", "FL", "PA", "OH", "GA", "NC", "MI"} # Example set

# 2. Use the accumulator within an action (e.g., foreach)
#    Updates inside actions are more reliable for final counts.
def validate_state(row):
    state = row["state_abbr"]
    if state is None or state == "" or state not in valid_states:
        # Use .add() to increment the accumulator
        invalid_state_counter.add(1)

# Apply the function to each row using foreach (an ACTION)
print("\nRunning validation using foreach...")
address_df.foreach(validate_state)
print("Validation complete.")

# 3. Read the accumulator's value on the driver AFTER the action completes
final_invalid_count = invalid_state_counter.value
print(f"\nTotal invalid state entries found: {final_invalid_count}") # Expected: 3

# Example: Using accumulator inside a transformation (less reliable for exact count)
# Note: This count might be higher if tasks are re-executed.
malformed_records_in_map = spark.sparkContext.accumulator(0)

def check_malformed_in_map(state):
    if state is None or state == "" or len(state) != 2:
         malformed_records_in_map.add(1)
    return state # map needs to return something

# Apply in a map transformation (LAZY)
processed_df = address_df.withColumn("state_processed",
                                     F.udf(check_malformed_in_map, StringType())(F.col("state_abbr")))

# No action has run yet, so the UDF has not executed and the accumulator is still 0
print(f"\nAccumulator value immediately after map transformation defined: {malformed_records_in_map.value}")

# Trigger an action that materializes the new column. A bare count() may let the
# optimizer prune the unused UDF column, in which case the UDF never runs.
processed_df.collect()

# Value after the action: 2 for this data ("" and None; "ZZ" has length 2 and passes).
# It can be higher if tasks or stages were re-executed.
print(f"Accumulator value after action triggering map: {malformed_records_in_map.value}")


spark.stop()
```

**Code Explanation:**

1.  **`invalid_state_counter = spark.sparkContext.accumulator(0)`**: Initializes a default (unnamed, integer-based) accumulator with a starting value of 0 on the driver.
2.  **`valid_states = {...}`**: A simple set of valid states for demonstration.
3.  **`def validate_state(row): ...`**: Defines a function that takes a DataFrame `Row` object.
4.  **`if state is None or ...:`**: Logic to check if the state abbreviation is invalid.
5.  **`invalid_state_counter.add(1)`**: If the state is invalid, adds 1 to the accumulator. Each task accumulates into its own local copy, and Spark merges those copies on the driver when tasks finish, so calling `add` from many parallel tasks is safe.
6.  **`address_df.foreach(validate_state)`**: Executes the `validate_state` function for *each row* in the DataFrame. `foreach` is an **action**, so it triggers computation, and Spark ensures accumulator updates within it are reliable for the final count.
7.  **`final_invalid_count = invalid_state_counter.value`**: **After** the `foreach` action completes, the driver retrieves the final aggregated value from the accumulator using `.value`.
8.  **`malformed_records_in_map = ...`**: Initializes a second accumulator.
9.  **`def check_malformed_in_map(state): ...`**: Defines a function intended for a `map`-like transformation (here used within a UDF called by `withColumn`). It also increments its accumulator.
10. **`processed_df = address_df.withColumn(...)`**: Defines a transformation. This is **lazy**; no computation happens yet.
11. **`print(f"Accumulator value immediately after map ...")`**: Shows the accumulator value *before* any action. It is 0 because the transformation has not run yet.
12. **Triggering action**: The UDF runs only when an action needs the `state_processed` column. `collect()` or `show()` materializes it. A bare `count()` may not, because the optimizer can prune the unused column and skip the UDF, leaving the accumulator at 0.
13. **`print(f"Accumulator value after action ...")`**: Shows the accumulator value *after* the action. For this data the UDF flags two rows (the empty string and the null; `"ZZ"` has length 2 and passes), and the value could be inflated if any tasks were retried. This is why updates in actions are preferred for definitive counts.

**Performance & Best Practices:**
*   Use accumulators primarily for **debugging or low-volume metrics**. Excessive use or accumulating large amounts of data (especially when collecting elements into a list or set) can impact performance.
*   Prefer updating accumulators within **actions** (`foreach`, `foreachPartition`) for reliable, exactly-once counts related to the final output. Be cautious interpreting values updated within transformations due to potential re-computations.
*   In the Scala/Java API, named accumulators (`longAccumulator("name")`) appear in the Spark UI, which aids debugging. PySpark accumulators cannot be named and are not shown there, so log or print their values from the driver.
*   Remember that only the driver can read the final `.value`. Tasks cannot access the current aggregated value.

**Conclusion:** Broadcast variables and accumulators are specialized tools in Spark's shared variable arsenal. Broadcast variables optimize the distribution of large, read-only data to executors, while accumulators provide a reliable way to aggregate simple counts or sums back to the driver from distributed tasks. Using them appropriately can lead to significant performance gains and provide valuable insights during debugging and monitoring.
