# Imports & Configuration

In [0]:
import os

In [0]:
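from pyspark.sql.types import *
import pyspark.sql.functions as sf
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()  # reuses the notebook's session (e.g. Databricks) or starts one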

# Reading Files

In [0]:
transactions_fp = os.path.abspath("../data/data_skew/transactions.parquet")
transactions_fp = "file:" + transactions_fp
df_transactions = spark.read.parquet(transactions_fp)

In [0]:
df_transactions.show(5, truncate=False)

In [0]:
customers_fp = os.path.abspath("../data/data_skew/customers.parquet")
customers_fp = "file:" + customers_fp
df_customers = spark.read.parquet(customers_fp)

In [0]:
df_customers.show(5, truncate=False)

# Spark Optimization Exercises

Complete the following exercises to improve the performance of Spark code. For each exercise, identify the performance issues and implement optimizations.

## Exercise 1: Inefficient Filtering

**Problem:** The following code applies several separate filters to a large dataset. Identify and fix the performance issues.

```python
# Inefficient code - DO NOT RUN AS IS
result = df_transactions.filter(sf.col("amount") > 100)
result = result.filter(sf.col("status") == "completed")
result = result.filter(sf.col("customer_id").isNotNull())
result = result.select("transaction_id", "customer_id", "amount", "date")
result = result.filter(sf.year("date") == 2024)
```

**Your Task:**
1. Identify the performance issues
2. Optimize the code below
3. Explain what optimizations you made

In [None]:
# Write your optimized code here

## Exercise 2: Join Optimization

**Problem:** The code below performs a join that could be optimized using a broadcast join.

```python
# Inefficient code
large_df = df_transactions
small_df = df_customers.select("customer_id", "customer_segment").distinct()

result = large_df.join(small_df, "customer_id", "inner")
result = result.groupBy("customer_segment").agg(sf.sum("amount").alias("total_amount"))
```

**Your Task:**
1. Determine if a broadcast join is appropriate
2. Implement the optimization
3. Compare the query plans before and after optimization

In [None]:
# Write your optimized code here

## Exercise 3: Unnecessary Shuffles

**Problem:** The following code causes multiple unnecessary shuffles.

```python
# Inefficient code
df = df_transactions.groupBy("customer_id").agg(sf.sum("amount").alias("total_spent"))
df = df.filter(df.total_spent > 500)
df = df.join(df_customers, "customer_id")
df = df.groupBy("customer_segment").agg(sf.avg("total_spent").alias("avg_spent"))
```

**Your Task:**
1. Identify where unnecessary shuffles occur
2. Optimize the code to reduce shuffles
3. Use `.explain()` to verify the reduction in shuffle operations

In [None]:
# Write your optimized code here

## Exercise 4: Caching Strategy

**Problem:** The code below reuses a DataFrame in several actions but never caches it, so each action recomputes it from the source.

```python
# Inefficient code
filtered_transactions = df_transactions.filter(sf.col("date") >= "2024-01-01")
filtered_transactions = filtered_transactions.filter(sf.col("amount") > 0)

# This DataFrame is used multiple times
high_value = filtered_transactions.filter(sf.col("amount") > 1000).count()
by_status = filtered_transactions.groupBy("status").count().collect()
avg_amount = filtered_transactions.agg(sf.avg("amount")).collect()
```

**Your Task:**
1. Identify where caching would be beneficial
2. Implement appropriate caching with the right storage level
3. Explain when to use `cache()` vs `persist()` and which storage level to use

In [None]:
# Write your optimized code here

## Exercise 5: Partition Optimization

**Problem:** The code repartitions the data to an arbitrary partition count before filtering, without considering data size or the operations that follow.

```python
# Inefficient code
df = df_transactions.repartition(200)  # Arbitrary partition count
result = df.filter(sf.col("customer_id") == "CUST001").groupBy("date").sum("amount")
```

**Your Task:**
1. Analyze the data size and the operations involved to determine an appropriate partition count
2. Consider whether repartition or coalesce is more appropriate
3. Optimize partitioning based on the subsequent operations

In [None]:
# Write your optimized code here

## Exercise 6: UDF Optimization

**Problem:** The code uses a Python UDF, which is inefficient.

```python
# Inefficient code
from pyspark.sql.functions import udf

@udf(returnType=StringType())
def categorize_amount(amount):
    if amount < 100:
        return "low"
    elif amount < 500:
        return "medium"
    else:
        return "high"

result = df_transactions.withColumn("category", categorize_amount(sf.col("amount")))
```

**Your Task:**
1. Identify why the UDF is inefficient
2. Rewrite using built-in Spark functions (`when`, `otherwise`)
3. Compare performance if possible

In [None]:
# Write your optimized code here

## Exercise 7: Data Skew Challenge

**Problem:** The transactions dataset is skewed: a few customers account for a disproportionate share of all transactions.

```python
# This code may suffer from data skew
result = df_transactions.groupBy("customer_id").agg(
    sf.count("*").alias("transaction_count"),
    sf.sum("amount").alias("total_amount"),
    sf.avg("amount").alias("avg_amount")
)
```

**Your Task:**
1. Identify the skew in the data (check customer_id distribution)
2. Implement a salting strategy or other technique to handle the skew
3. Compare execution time before and after optimization

In [None]:
# Write your optimized code here

## Bonus Exercise: Complex Query Optimization

**Problem:** Optimize this complex query that combines multiple operations.

```python
# Complex inefficient code
df1 = df_transactions.filter(sf.col("amount") > 100)
df2 = df1.groupBy("customer_id").agg(sf.sum("amount").alias("total"))
df3 = df2.filter(sf.col("total") > 1000)
df4 = df3.join(df_customers, "customer_id")
df5 = df4.select("customer_id", "customer_name", "total", "customer_segment")
df6 = df5.orderBy("total", ascending=False)
result = df6.limit(10)
```

**Your Task:**
Apply all optimization techniques learned:
- Predicate pushdown
- Column pruning
- Join optimization
- Appropriate use of cache if needed
- Efficient shuffle operations

In [None]:
# Write your optimized code here