<a href="https://colab.research.google.com/github/suriarasai/BEAD2026/blob/main/colab/03a_Apache_Spark_User_Defined_Functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to PySpark User Defined Functions
In this demo we will learn to apply functional programming principles when writing effective user-defined functions (UDFs).

## Core Concept Introduction
In distributed computing with PySpark, UDFs scale well only when they avoid sharing state among worker nodes.
1. Recall from our earlier discussions that pure functions (no side effects, no state) can run safely on any worker node.
2. Stateful functions cause problems because state isn't shared across worker nodes.

## Functions with State (Bad practice)
### Example of Accumulator Anti-Pattern
Relying on external state inside a UDF is discouraged.

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

In [1]:
spark = SparkSession.builder.appName("FunctionalDemo").getOrCreate()

# BAD: Function with external state
counter = 0  # This won't work as expected in distributed setting

def count_processing_bad(value):
    global counter
    counter += 1  # Trying to maintain state
    return value * 2

# Create sample data
data = [(1,), (2,), (3,), (4,), (5,)]
df = spark.createDataFrame(data, ["value"])

# This will NOT work correctly in distributed mode
count_udf_bad = udf(count_processing_bad, IntegerType())
result_df = df.withColumn("doubled", count_udf_bad("value"))
result_df.show()

print(f"Counter value: {counter}")  # Will be 0 or unpredictable!

+-----+-------+
|value|doubled|
+-----+-------+
|    1|      2|
|    2|      4|
|    3|      6|
|    4|      8|
|    5|     10|
+-----+-------+

Counter value: 0


### Example of a Class with Mutable State
Mutable state held in an object leads to inconsistent results.

In [2]:
# BAD: Class with mutable state
class DataProcessor:
    def __init__(self):
        self.processed_count = 0
        self.sum_total = 0

    def process(self, value):
        self.processed_count += 1  # State mutation
        self.sum_total += value    # State mutation
        return value * 2

# This won't work in distributed setting
processor = DataProcessor()
process_udf = udf(lambda x: processor.process(x), IntegerType())

# The state will be lost/inconsistent
result_df = df.withColumn("processed", process_udf("value"))
result_df.show()
print(f"Processed: {processor.processed_count}")  # Wrong!

+-----+---------+
|value|processed|
+-----+---------+
|    1|        2|
|    2|        4|
|    3|        6|
|    4|        8|
|    5|       10|
+-----+---------+

Processed: 0
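
Why does the driver still report 0? PySpark serializes the UDF together with the objects it references and ships a copy to the Python worker processes, so any mutation happens on the copy. The sketch below imitates that round trip in plain Python, using `copy.deepcopy` as a stand-in for serialization. The names `driver_processor` and `worker_processor` are illustrative only, and no Spark session is involved.

```python
import copy

class DataProcessor:
    """Same shape as the stateful class above, reduced to the counter."""
    def __init__(self):
        self.processed_count = 0

    def process(self, value):
        self.processed_count += 1  # mutates whichever copy this runs on
        return value * 2

driver_processor = DataProcessor()

# Stand-in for shipping the object to a worker: the worker gets its own copy.
worker_processor = copy.deepcopy(driver_processor)

results = [worker_processor.process(v) for v in [1, 2, 3, 4, 5]]

print(results)                           # [2, 4, 6, 8, 10]
print(worker_processor.processed_count)  # 5 (the worker's copy)
print(driver_processor.processed_count)  # 0 (the driver never sees the updates)
```

The returned values are correct because they depend only on the inputs; only the side effect is lost.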


## Pure Functions (Scalable Approach)
To create a PySpark UDF from a pure function, wrap the function with udf() and specify its return type. The resulting UDF can then be applied to DataFrame columns, and Spark runs it on the workers in a distributed manner.

### Example of simple UDF

In [6]:
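# 1. The Pure Function (returns None for null input)
def my_func(x):
    return x.upper() if x is not None else None

# 2. Creating the UDF
upper_udf = udf(my_func, StringType())

# 3. Usage: "value" is a bigint column, so cast it to string before calling upper()
df.withColumn("new_col", upper_udf(df["value"].cast("string")))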

DataFrame[value: bigint, new_col: string]

### Example of UDF with Decorators

In [7]:
# The decorator wraps the function as a UDF with the given return type
@udf(returnType=StringType())
def my_pure_function(text):
    if text is None:
        return None
    return text.upper()

# Usage: cast the integer "value" column to string so text.upper() is valid
df_result = df.withColumn("new_col", my_pure_function(df["value"].cast("string")))
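
The `None` check in the decorated function is a pattern that tends to repeat in every UDF. One possible way to factor it out is a small plain-Python decorator, sketched below. `null_safe` and `to_upper` are hypothetical names introduced for illustration; they are not part of PySpark.

```python
from functools import wraps

def null_safe(func):
    """Return None if any positional argument is None; otherwise call func."""
    @wraps(func)
    def wrapper(*args):
        if any(arg is None for arg in args):
            return None
        return func(*args)
    return wrapper

@null_safe
def to_upper(text):
    # The body can now assume text is not None.
    return text.upper()

print(to_upper("spark"))  # SPARK
print(to_upper(None))     # None
```

Because the wrapper is still a pure function, it could in principle be passed to `udf(to_upper, StringType())` in the same way as the functions above.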

### Example of Pure Transformation Functions

In [8]:
from pyspark.sql.functions import col, lit, sum as spark_sum

# GOOD: Pure function with no side effects
def multiply_pure(value, factor=2):
    """Pure function: output depends only on input"""
    return value * factor

# Register as UDF
multiply_udf = udf(lambda x: multiply_pure(x, 2), IntegerType())

# Use the pure function
result_df = df.withColumn("doubled", multiply_udf("value"))
result_df.show()

# If you need counting, use Spark's built-in aggregations
count_result = result_df.count()
print(f"Total processed: {count_result}")

+-----+-------+
|value|doubled|
+-----+-------+
|    1|      2|
|    2|      4|
|    3|      6|
|    4|      8|
|    5|     10|
+-----+-------+

Total processed: 5


### Example of Complex Pure Functions with Multiple Inputs
(TO BE DISCUSSED AFTER SQL LESSONS)

In [9]:
from pyspark.sql.types import StructType, StructField, FloatType, BooleanType

# GOOD: Pure function for complex calculations
def calculate_metrics_pure(value, threshold, multiplier):
    """
    Pure function: No external dependencies or state
    All needed data passed as parameters
    """
    adjusted = value * multiplier
    is_above_threshold = adjusted > threshold
    return {
        'adjusted_value': adjusted,
        'above_threshold': is_above_threshold,
        'distance_from_threshold': adjusted - threshold
    }

# Create more complex dataset
complex_data = spark.createDataFrame(
    [(10, 25, 2.0), (20, 25, 1.5), (15, 25, 2.0)],
    ["value", "threshold", "multiplier"]
)

# Define return schema; each field type must match the Python value returned.
# above_threshold is a bool, so it needs BooleanType (FloatType would yield NULL).
return_schema = StructType([
    StructField("adjusted_value", FloatType()),
    StructField("above_threshold", BooleanType()),
    StructField("distance_from_threshold", FloatType())
])

# Create UDF with pure function
metrics_udf = udf(
    lambda v, t, m: calculate_metrics_pure(v, t, m),
    return_schema
)

# Apply the pure function
result_df = complex_data.withColumn(
    "metrics",
    metrics_udf(col("value"), col("threshold"), col("multiplier"))
)

result_df.select("value", "metrics.*").show()

+-----+--------------+---------------+-----------------------+
|value|adjusted_value|above_threshold|distance_from_threshold|
+-----+--------------+---------------+-----------------------+
|   10|          20.0|          false|                   -5.0|
|   20|          30.0|           true|                    5.0|
|   15|          30.0|           true|                    5.0|
+-----+--------------+---------------+-----------------------+



### Example of Function Composition

In [10]:
# GOOD: Composing pure functions
def normalize(value, min_val, max_val):
    """Pure normalization function"""
    return (value - min_val) / (max_val - min_val) if max_val > min_val else 0

def apply_sigmoid(value):
    """Pure sigmoid transformation"""
    import math
    return 1 / (1 + math.exp(-value))

def transform_pipeline(value, min_val, max_val):
    """Compose pure functions"""
    normalized = normalize(value, min_val, max_val)
    return apply_sigmoid(normalized)

# Use with PySpark
transform_udf = udf(
    lambda x: transform_pipeline(x, 0, 100),
    FloatType()
)

data_range = spark.range(0, 100, 10).toDF("value")
transformed_df = data_range.withColumn("transformed", transform_udf("value"))
transformed_df.show()

+-----+-----------+
|value|transformed|
+-----+-----------+
|    0|        0.5|
|   10|  0.5249792|
|   20|   0.549834|
|   30|  0.5744425|
|   40| 0.59868765|
|   50| 0.62245935|
|   60|  0.6456563|
|   70|  0.6681878|
|   80|  0.6899745|
|   90|  0.7109495|
+-----+-----------+
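
The composition above is written out by hand in `transform_pipeline`. As a minimal sketch of the same idea in generic form, a small helper can chain any number of single-argument pure functions. `compose`, `scale_to_unit`, and `sigmoid` are hypothetical names used only for this illustration, and the fixed range 0 to 100 is assumed to match the example above.

```python
import math
from functools import reduce

def compose(*funcs):
    """Chain single-argument functions, applying them left to right."""
    return lambda value: reduce(lambda acc, f: f(acc), funcs, value)

def scale_to_unit(value):
    """Normalize a value from the assumed range 0..100 to 0..1."""
    return value / 100

def sigmoid(value):
    """Pure sigmoid transformation."""
    return 1 / (1 + math.exp(-value))

pipeline = compose(scale_to_unit, sigmoid)

print(pipeline(0))             # 0.5
print(round(pipeline(50), 4))  # 0.6225
```

Since every stage is pure, the composed function is pure as well, so it is as safe to wrap in a UDF as its parts.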



## Handling State via PySpark Accumulators
When we genuinely need state-like behavior, we use Spark's built-in mechanisms.

### Example use of Accumulators

In [11]:
# GOOD: Using Spark's accumulator for distributed counting
from pyspark import AccumulatorParam

# Create an accumulator
processed_counter = spark.sparkContext.accumulator(0)

def process_with_accumulator(value):
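    processed_counter.add(1)  # Workers can only add; the driver reads the merged total
    return value * 2

# Note: Spark only guarantees exactly-once accumulator updates inside actions.
# Inside a transformation such as map(), a re-executed task can apply its update again.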
rdd = spark.sparkContext.parallelize([1, 2, 3, 4, 5])
result = rdd.map(lambda x: process_with_accumulator(x)).collect()

print(f"Processed count: {processed_counter.value}")

Processed count: 5
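
The count is exact here because every task ran once. Spark guarantees exactly-once accumulator updates only inside actions; an update made inside a transformation can be applied again if a task is re-executed. The plain-Python simulation below illustrates that failure mode. `CountingAccumulator` and `run_task` are hypothetical stand-ins, not Spark APIs, and the "retry" is simply a second call.

```python
class CountingAccumulator:
    """Hypothetical stand-in for a Spark accumulator: an add-only counter."""
    def __init__(self):
        self.value = 0

    def add(self, amount):
        self.value += amount

def run_task(partition, accumulator):
    """Process one partition, counting each record as a side effect."""
    doubled = []
    for value in partition:
        accumulator.add(1)
        doubled.append(value * 2)
    return doubled

acc = CountingAccumulator()
partition = [1, 2, 3, 4, 5]

run_task(partition, acc)           # first attempt (imagine its result was lost)
result = run_task(partition, acc)  # simulated re-execution of the same task

print(result)     # [2, 4, 6, 8, 10]
print(acc.value)  # 10, although only 5 records exist
```

The transformation's output is unaffected by the retry, while the side-effect counter is inflated. This is one reason to treat accumulator values from transformations as diagnostics rather than exact results.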


### Example of Preferring Aggregations over State
(TO BE DISCUSSED AFTER SQL LESSONS)

In [12]:
from pyspark.sql import Window
# Alias sum to avoid shadowing Python's built-in sum(); the code below uses spark_sum
from pyspark.sql.functions import row_number, avg, sum as spark_sum

# GOOD: Use window functions instead of stateful processing
data_with_groups = spark.createDataFrame(
    [("A", 10), ("A", 20), ("B", 15), ("B", 25), ("A", 30)],
    ["group", "value"]
)

# Instead of maintaining running totals in state, use window functions
window_spec = Window.partitionBy("group").orderBy("value")

result_df = data_with_groups.withColumn(
    "running_sum", spark_sum("value").over(window_spec)
).withColumn(
    "running_avg", avg("value").over(window_spec)
).withColumn(
    "rank", row_number().over(window_spec)
)

result_df.show()

+-----+-----+-----------+-----------+----+
|group|value|running_sum|running_avg|rank|
+-----+-----+-----------+-----------+----+
|    A|   10|         10|       10.0|   1|
|    A|   20|         30|       15.0|   2|
|    A|   30|         60|       20.0|   3|
|    B|   15|         15|       15.0|   1|
|    B|   25|         40|       20.0|   2|
+-----+-----+-----------+-----------+----+



## Summary and Key Takeaways

1. Determinism: Pure functions always produce the same output for the same input
2. Parallelization: Pure functions can run on any worker without coordination
3. Testability: Pure functions are easy to unit test
4. Debugging: Pure functions are easier to debug (no hidden state)
5. Caching: Results of pure functions can be safely cached
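
The determinism, testability, and caching points can be checked without a Spark session. The sketch below reuses the `normalize` function from the composition example, tests it with plain assertions, and memoizes it with the standard library's `functools.lru_cache`. The name `cached_normalize` is introduced here for illustration.

```python
from functools import lru_cache

def normalize(value, min_val, max_val):
    """Pure normalization function (same logic as in the composition example)."""
    return (value - min_val) / (max_val - min_val) if max_val > min_val else 0

# Testability: plain assertions, no cluster or session required.
assert normalize(50, 0, 100) == 0.5
assert normalize(5, 10, 10) == 0  # degenerate range falls back to 0

# Determinism: the same input always gives the same output.
assert normalize(50, 0, 100) == normalize(50, 0, 100)

# Caching: safe because the result depends only on the arguments.
cached_normalize = lru_cache(maxsize=None)(normalize)
cached_normalize(50, 0, 100)  # computed
cached_normalize(50, 0, 100)  # served from the cache

print(cached_normalize.cache_info().hits)  # 1
```

Caching a stateful function in this way would be unsafe, because a cached result would skip the side effect the caller was relying on.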