In [0]:
# Persistence and Caching

when to use -

-when performing iterative operations on same dataset
- when performing expensive operations like joins, groupBy, or aggregations
- when you know data can be fit in memory, caching is ideal for performance

In [0]:
from pyspark.sql import SparkSession
from pyspark import StorageLevel
spark = SparkSession.builder.appName("Caching and Persistance").getOrCreate()

empDf = spark.read.csv("dbfs:/FileStore/demo_folder/employee_dataset.csv",header=True,inferSchema=True)


In [0]:
empDf.show(10)

+----------+----------------+---+----------+-----------+------+------------+
|EmployeeID|            Name|Age|Department|JoiningDate|Salary|        City|
+----------+----------------+---+----------+-----------+------+------------+
|         1|      Jon Rivera| 56|     Sales| 2024-04-29|121250|     Houston|
|         2|  Nicole Daniels| 46|        HR| 2024-04-01|138633|    New York|
|         3|Monique Sullivan| 32|   Finance| 2020-05-02| 83619|Philadelphia|
|         4|    James Wright| 60| Marketing| 2023-02-21|129751|    New York|
|         5| Nicole Williams| 25|     Sales| 2018-04-29|123193|     Chicago|
|         6|     David Bates| 38|     Sales| 2019-03-12| 98719|      Dallas|
|         7|   Matthew Riggs| 56| Marketing| 2019-11-06| 71156| Los Angeles|
|         8|    Wendy Powers| 36|Operations| 2023-07-30| 73901|      Dallas|
|         9|  Thomas Collins| 40|   Support| 2018-10-12| 30418|Philadelphia|
|        10|     Joshua Wong| 28|        IT| 2022-05-17| 89252| San Antonio|

In [0]:
transformed_df = empDf.filter(empDf.Salary > 90000)

In [0]:
transformed_df.cache()

Out[5]: DataFrame[EmployeeID: int, Name: string, Age: int, Department: string, JoiningDate: date, Salary: int, City: string]

In [0]:
print("Count after cache:",transformed_df.count())

Count after cache: 4998626


In [0]:
transformed_df.show()

+----------+--------------------+---+----------+-----------+------+------------+
|EmployeeID|                Name|Age|Department|JoiningDate|Salary|        City|
+----------+--------------------+---+----------+-----------+------+------------+
|         1|          Jon Rivera| 56|     Sales| 2024-04-29|121250|     Houston|
|         2|      Nicole Daniels| 46|        HR| 2024-04-01|138633|    New York|
|         4|        James Wright| 60| Marketing| 2023-02-21|129751|    New York|
|         5|     Nicole Williams| 25|     Sales| 2018-04-29|123193|     Chicago|
|         6|         David Bates| 38|     Sales| 2019-03-12| 98719|      Dallas|
|        11|      Dakota Shields| 28|        HR| 2018-02-24|114528|      Dallas|
|        12|       Angelica Wade| 41|        HR| 2015-06-29|130914|     Phoenix|
|        13|      Allison Miller| 53| Marketing| 2019-09-19|108950|      Dallas|
|        14|        Ryan Morales| 57| Marketing| 2021-05-13|130007|      Dallas|
|        15|    Elizabeth La

In [0]:
result_df = transformed_df.groupBy("Department").count()

In [0]:
result_df.show()

+----------+------+
|Department| count|
+----------+------+
|     Sales|714347|
|        HR|715662|
|   Finance|713291|
| Marketing|713757|
|        IT|714311|
|   Support|713916|
|Operations|713342|
+----------+------+



In [0]:
transformed_df.unpersist() #Release the cache

Out[12]: DataFrame[EmployeeID: int, Name: string, Age: int, Department: string, JoiningDate: date, Salary: int, City: string]

In [0]:
spark.stop() #stop the spark session

In [0]:
#persist()

from pyspark.sql import SparkSession
from pyspark import StorageLevel
spark = SparkSession.builder.appName("Persistance example").getOrCreate()

data = [("James",35),("Aman",30),("Michel",29),("Sarah",25)]
columns = ["Name","Age"]

df = spark.createDataFrame(data,columns)

transformed_df = df.filter(df.Age > 28)



In [0]:
transformed_df.persist(StorageLevel.MEMORY_AND_DISK)

Out[2]: DataFrame[Name: string, Age: bigint]

In [0]:
print("Count after persist:",transformed_df.count())

Count after persist: 3


In [0]:
transformed_df.show()

+------+---+
|  Name|Age|
+------+---+
| James| 35|
|  Aman| 30|
|Michel| 29|
+------+---+



In [0]:
group_df = transformed_df.groupBy("Age").count()

In [0]:
group_df.show()

+---+-----+
|Age|count|
+---+-----+
| 35|    1|
| 30|    1|
| 29|    1|
+---+-----+



In [0]:
transformed_df.unpersist()

Out[7]: DataFrame[Name: string, Age: bigint]

In [0]:
spark.stop()

In [0]:
# Broadcast 

how it works

- Instend of sending both datasets to all nodes for shuffling and merging, spark sends the smaller dataset(broadcasted) to each executor
-This reduces netwrok traffic because only the smaller dataset is distributed and the large dataset remain in inplace, avoding a costly shuffle of the large data  

When to use Broadcast joins:
-when one dataset is much smaller than the other
-when the smaller dataset can fit into the memory executor.

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.appName("Broadcast Join Example").getOrCreate()
df_large = spark.createDataFrame([(1,'2024-01-01',100),
                                  (2,'2024-01-02',150),
                                  (3,'2024-01-03',200)],["product_id","transaction_date","amount"])

In [0]:
df_small = spark.createDataFrame([(1,"Product A","Category 1"),
                                  (2,"Product B","Category 2"),
                                  (3,"Product C","Category 3")],["product_id","product_name","category"])

In [0]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold","20MB")

In [0]:
df_joined = df_large.join(broadcast(df_small),on='product_id',how='inner')

In [0]:
df_joined.show()

+----------+----------------+------+------------+----------+
|product_id|transaction_date|amount|product_name|  category|
+----------+----------------+------+------------+----------+
|         1|      2024-01-01|   100|   Product A|Category 1|
|         2|      2024-01-02|   150|   Product B|Category 2|
|         3|      2024-01-03|   200|   Product C|Category 3|
+----------+----------------+------+------------+----------+



In [0]:
#Data skew - salting

How to Detect Data Skew:
Skewed Partitions: Some partitions may contain significantly more data than others.
Stragglers: Certain tasks take much longer to complete compared to others.
Skew in Keys: When joining or aggregating on a key, a few keys may have a disproportionately large number of records.
Strategies to Avoid Skew in Spark Jobs

-Salting the Keys
One common approach to avoid skew in joins or aggregations is salting. When one key has a large number of records, you artificially split 1f9c2_salt that key into multiple sub-keys, which helps distribute the load evenly across partitions.

In [0]:
# from pyspark.sql import SparkSession
from pyspark.sql import functions as F,SparkSession
import random
spark = SparkSession.builder.appName("Salting Example").getOrCreate()


In [0]:
df_large = spark.createDataFrame([(1,'2024-01-01',100),
                                  (2,'2024-01-02',150),
                                  (3,'2024-01-03',200)],["product_id","transaction_date","amount"])

df_small = spark.createDataFrame([(1,"Product A","Category 1"),
                                  (2,"Product B","Category 2"),
                                  (3,"Product C","Category 3")],["product_id","product_name","category"])

In [0]:
# Add salt to large dataframe (df_arge) by creating a random number as salt
df_large = df_large.withColumn("salt",F.lit((random.randint(0,2))))

#Add the same salt column in df_small by cross-joining with the salt values
df_small =df_small.crossJoin(spark.createDataFrame([(0,),(1,),(2,)],["salt"]))

In [0]:
df_joined = df_large.join(df_small,on=['product_id','salt'],how='inner')

In [0]:
df_joined.show()

# Explanation:
# We artificially introduced a salt column to distribute the product_id key more evenly.
# The cross join with the salt column ensures that for each product_id in df_small, we have multiple salted versions to match the larger dataset.
# This breaks up the skewed key into smaller pieces to spread the load across multiple partitions.

+----------+----+----------------+------+------------+----------+
|product_id|salt|transaction_date|amount|product_name|  category|
+----------+----+----------------+------+------------+----------+
|         1|   2|      2024-01-01|   100|   Product A|Category 1|
|         2|   2|      2024-01-02|   150|   Product B|Category 2|
|         3|   2|      2024-01-03|   200|   Product C|Category 3|
+----------+----+----------------+------+------------+----------+



In [0]:
spark.stop()